bep_neo4j/
lib.rs

1//! A Bep vector store for Neo4j.
2//!
3//! This crate is a companion crate to the [bep-core crate](https://github.com/bep-ai/bep-core).
4//! It provides a vector store implementation that uses Neo4j as the underlying datastore.
5//!
6//! See the [README](https://github.com/bepdotai/bep/tree/main/bep-neo4j) for more information.
7//!
8//! ## Prerequisites
9//!
10//! ### GenAI Plugin
11//! The GenAI plugin is enabled by default in Neo4j Aura.
12//!
13//! The plugin needs to be installed on self-managed instances. This is done by moving the neo4j-genai.jar
14//! file from /products to /plugins in the Neo4j home directory, or, if you are using Docker, by starting
15//! the Docker container with the extra parameter `--env NEO4J_PLUGINS='["genai"]'`.
16//!
17//! For more information, see [Operations Manual → Configure plugins](https://neo4j.com/docs/operations-manual/current/plugins/configure/).
18//!
19//! ### Pre-existing Vector Index
20//!
//! The [`Neo4jVectorIndex`] struct is designed to work with a pre-existing
22//! Neo4j vector index. You can create the index using the Neo4j browser, a raw Cypher query, or the
23//! [Neo4jClient::create_vector_index] method.
24//! See the [Neo4j documentation](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/setup/vector-index/)
25//! for more information.
26//!
27//! The index name must be unique among both indexes and constraints.
28//! ❗A newly created index is not immediately available but is created in the background.
29//!
30//! ```cypher
31//! CREATE VECTOR INDEX moviePlots
32//!     FOR (m:Movie)
33//!     ON m.embedding
34//!     OPTIONS {indexConfig: {
35//!         `vector.dimensions`: 1536,
36//!         `vector.similarity_function`: 'cosine'
37//!     }}
38//! ```
39//!
40//! ## Simple example:
41//! More examples can be found in the [/examples](https://github.com/bepdotai/bep/tree/main/bep-neo4j/examples) folder.
42//! ```
43//! use bep_neo4j::{vector_index::*, Neo4jClient};
44//! use neo4rs::ConfigBuilder;
45//! use bep::{providers::openai::*, vector_store::VectorStoreIndex};
46//! use serde::Deserialize;
47//! use std::env;
48//!
49//! #[tokio::main]
50//! async fn main() {
51//!     let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
52//!     let openai_client = Client::new(&openai_api_key);
53//!     let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
54//!
55//!
56//!     const NEO4J_URI: &str = "neo4j+s://demo.neo4jlabs.com:7687";
57//!     const NEO4J_DB: &str = "recommendations";
58//!     const NEO4J_USERNAME: &str = "recommendations";
59//!     const NEO4J_PASSWORD: &str = "recommendations";
60//!
61//!     let client = Neo4jClient::from_config(
62//!         ConfigBuilder::default()
63//!             .uri(NEO4J_URI)
64//!             .db(NEO4J_DB)
65//!             .user(NEO4J_USERNAME)
66//!             .password(NEO4J_PASSWORD)
67//!             .build()
68//!             .unwrap(),
69//!     )
70//!    .await
71//!    .unwrap();
72//!
73//!     let index = client.get_index(
74//!         model,
75//!         "moviePlotsEmbedding",
76//!         SearchParams::default()
77//!     ).await.unwrap();
78//!
79//!     #[derive(Debug, Deserialize)]
80//!     struct Movie {
81//!         title: String,
82//!         plot: String,
83//!     }
84//!     let results = index.top_n::<Movie>("Batman", 3).await.unwrap();
85//!     println!("{:#?}", results);
86//! }
87//! ```
88pub mod vector_index;
89use std::str::FromStr;
90
91use futures::TryStreamExt;
92use neo4rs::*;
93use bep::{embeddings::EmbeddingModel, vector_store::VectorStoreError};
94use serde::Deserialize;
95use vector_index::{IndexConfig, Neo4jVectorIndex, SearchParams, VectorSimilarityFunction};
96
97pub struct Neo4jClient {
98    pub graph: Graph,
99}
100
101fn neo4j_to_bep_error(e: neo4rs::Error) -> VectorStoreError {
102    VectorStoreError::DatastoreError(Box::new(e))
103}
104
105pub trait ToBoltType {
106    fn to_bolt_type(&self) -> BoltType;
107}
108
109impl<T> ToBoltType for T
110where
111    T: serde::Serialize,
112{
113    fn to_bolt_type(&self) -> BoltType {
114        match serde_json::to_value(self) {
115            Ok(json_value) => match json_value {
116                serde_json::Value::Null => BoltType::Null(BoltNull),
117                serde_json::Value::Bool(b) => BoltType::Boolean(BoltBoolean::new(b)),
118                serde_json::Value::Number(num) => {
119                    if let Some(i) = num.as_i64() {
120                        BoltType::Integer(BoltInteger::new(i))
121                    } else if let Some(f) = num.as_f64() {
122                        BoltType::Float(BoltFloat::new(f))
123                    } else {
124                        println!("Couldn't map to BoltType, will ignore.");
125                        BoltType::Null(BoltNull) // Handle unexpected number type
126                    }
127                }
128                serde_json::Value::String(s) => BoltType::String(BoltString::new(&s)),
129                serde_json::Value::Array(arr) => BoltType::List(
130                    arr.iter()
131                        .map(|v| v.to_bolt_type())
132                        .collect::<Vec<BoltType>>()
133                        .into(),
134                ),
135                serde_json::Value::Object(obj) => {
136                    let mut bolt_map = BoltMap::new();
137                    for (k, v) in obj {
138                        bolt_map.put(BoltString::new(&k), v.to_bolt_type());
139                    }
140                    BoltType::Map(bolt_map)
141                }
142            },
143            Err(_) => {
144                println!("Couldn't serialize to JSON, will ignore.");
145                BoltType::Null(BoltNull) // Handle serialization error
146            }
147        }
148    }
149}
150
151impl Neo4jClient {
152    const GET_INDEX_QUERY: &'static str = "
153    SHOW VECTOR INDEXES
154    YIELD name, properties, options
155    WHERE name=$index_name
156    RETURN name, properties, options
157    ";
158
159    const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";
160
161    pub fn new(graph: Graph) -> Self {
162        Self { graph }
163    }
164
165    pub async fn connect(uri: &str, user: &str, password: &str) -> Result<Self, VectorStoreError> {
166        tracing::info!("Connecting to Neo4j DB at {} ...", uri);
167        let graph = Graph::new(uri, user, password)
168            .await
169            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
170        tracing::info!("Connected to Neo4j");
171        Ok(Self { graph })
172    }
173
174    pub async fn from_config(config: Config) -> Result<Self, VectorStoreError> {
175        let graph = Graph::connect(config)
176            .await
177            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
178        Ok(Self { graph })
179    }
180
181    pub async fn execute_and_collect<T: for<'a> Deserialize<'a>>(
182        graph: &Graph,
183        query: Query,
184    ) -> Result<Vec<T>, VectorStoreError> {
185        graph
186            .execute(query)
187            .await
188            .map_err(neo4j_to_bep_error)?
189            .into_stream_as::<T>()
190            .try_collect::<Vec<T>>()
191            .await
192            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
193    }
194
195    /// Returns a `Neo4jVectorIndex` that mirrors an existing Neo4j Vector Index.
196    ///
197    /// An index (of type "vector") of the same name as `index_name` must already exist for the Neo4j database.
198    /// See the Neo4j [documentation (Create vector index)](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/setup/vector-index/) for more information on creating indexes.
199    ///
200    /// ❗IMPORTANT: The index must be created with the same embedding model that will be used to query the index.
201    pub async fn get_index<M: EmbeddingModel>(
202        &self,
203        model: M,
204        index_name: &str,
205        search_params: SearchParams,
206    ) -> Result<Neo4jVectorIndex<M>, VectorStoreError> {
207        #[derive(Deserialize)]
208        struct IndexInfo {
209            name: String,
210            properties: Vec<String>,
211            options: IndexOptions,
212        }
213
214        #[derive(Deserialize)]
215        #[serde(rename_all = "camelCase")]
216        struct IndexOptions {
217            _index_provider: String,
218            index_config: IndexConfigDetails,
219        }
220
221        #[derive(Deserialize)]
222        struct IndexConfigDetails {
223            #[serde(rename = "vector.dimensions")]
224            vector_dimensions: i64,
225            #[serde(rename = "vector.similarity_function")]
226            vector_similarity_function: String,
227        }
228
229        let index_info = Self::execute_and_collect::<IndexInfo>(
230            &self.graph,
231            neo4rs::query(Self::GET_INDEX_QUERY).param("index_name", index_name),
232        )
233        .await?;
234
235        let index_config = if let Some(index) = index_info.first() {
236            if index.options.index_config.vector_dimensions != model.ndims() as i64 {
237                tracing::warn!(
238                    "The embedding vector dimensions of the existing Neo4j DB index ({}) do not match the provided model dimensions ({}). This may affect search performance.",
239                    index.options.index_config.vector_dimensions,
240                    model.ndims()
241                );
242            }
243            IndexConfig::new(index.name.clone())
244                .embedding_property(index.properties.first().unwrap())
245                .similarity_function(VectorSimilarityFunction::from_str(
246                    &index.options.index_config.vector_similarity_function,
247                )?)
248        } else {
249            let indexes = Self::execute_and_collect::<String>(
250                &self.graph,
251                neo4rs::query(Self::SHOW_INDEXES_QUERY),
252            )
253            .await?;
254            return Err(VectorStoreError::DatastoreError(Box::new(
255                std::io::Error::new(
256                    std::io::ErrorKind::NotFound,
257                    format!(
258                        "Index `{}` not found in database. Available indexes: {:?}",
259                        index_name, indexes
260                    ),
261                ),
262            )));
263        };
264        Ok(Neo4jVectorIndex::new(
265            self.graph.clone(),
266            model,
267            index_config,
268            search_params,
269        ))
270    }
271
272    /// Calls the `CREATE VECTOR INDEX` Neo4j query and waits for the index to be created.
273    /// A newly created index is not immediately fully available but is created (i.e. data is indexed) in the background.
274    ///
275    /// ❗ If there is already an index targetting the same node label and property, the new index creation will fail.
276    ///
277    /// ### Arguments
278    /// * `index_name` - The name of the index to create.
279    /// * `node_label` - The label of the nodes to which the index will be applied. For example, if your nodes have
280    ///                  the label `:Movie`, pass "Movie" as the `node_label` parameter.
281    /// * `embedding_prop_name` (optional) - The name of the property that contains the embedding vectors. Defaults to "embedding".
282    ///
283    pub async fn create_vector_index(
284        &self,
285        index_config: IndexConfig,
286        node_label: &str,
287        model: &impl EmbeddingModel,
288    ) -> Result<(), VectorStoreError> {
289        // Create a vector index on our vector store
290        tracing::info!("Creating vector index {} ...", index_config.index_name);
291
292        let create_vector_index_query = format!(
293            "
294            CREATE VECTOR INDEX $index_name IF NOT EXISTS
295            FOR (m:{})
296            ON m.{}
297            OPTIONS {{
298                indexConfig: {{
299                    `vector.dimensions`: $dimensions,
300                    `vector.similarity_function`: $similarity_function
301                }}
302            }}",
303            node_label, index_config.embedding_property
304        );
305
306        self.graph
307            .run(
308                neo4rs::query(&create_vector_index_query)
309                    .param("index_name", index_config.index_name.clone())
310                    .param(
311                        "similarity_function",
312                        index_config.similarity_function.clone().to_bolt_type(),
313                    )
314                    .param("dimensions", model.ndims() as i64),
315            )
316            .await
317            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
318
319        // Check if the index exists with db.awaitIndex(), the call timeouts if the index is not ready
320        let index_exists = self
321            .graph
322            .run(
323                neo4rs::query("CALL db.awaitIndex($index_name, 10000)")
324                    .param("index_name", index_config.index_name.clone()),
325            )
326            .await;
327
328        if index_exists.is_err() {
329            tracing::warn!(
330                "Index with name `{}` is not ready or could not be created.",
331                index_config.index_name.clone()
332            );
333        }
334
335        tracing::info!(
336            "Index created successfully with name: {}",
337            index_config.index_name
338        );
339        Ok(())
340    }
341}
342
// The `Movie` fields are only populated through deserialization and never read directly.
#[allow(dead_code)]
#[cfg(test)]
mod tests {
    use super::*;
    use neo4rs::ConfigBuilder;
    use bep::{
        providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
        vector_store::VectorStoreIndex,
    };
    use serde::Deserialize;
    use std::env;

    // Connection settings used by every test in this module.
    const NEO4J_URI: &str = "neo4j+s://demo.neo4jlabs.com:7687";
    const NEO4J_DB: &str = "recommendations";
    const NEO4J_USERNAME: &str = "recommendations";
    const NEO4J_PASSWORD: &str = "recommendations";

    /// Node payload deserialized from the search results.
    #[derive(Debug, Deserialize)]
    struct Movie {
        title: String,
        plot: String,
    }

    /// Builds the connection configuration from the constants above.
    fn demo_config() -> neo4rs::Config {
        let builder = ConfigBuilder::default()
            .uri(NEO4J_URI)
            .db(NEO4J_DB)
            .user(NEO4J_USERNAME)
            .password(NEO4J_PASSWORD);
        builder.build().unwrap()
    }

    /// Connecting with the test configuration must succeed.
    #[tokio::test]
    async fn test_connect() {
        let connection = Neo4jClient::from_config(demo_config()).await;
        assert!(connection.is_ok());
    }

    /// A search against the `moviePlotsEmbedding` index must return at least one hit.
    #[tokio::test]
    async fn test_vector_search_no_display() {
        let hits = vector_search().await.unwrap();
        assert!(!hits.is_empty());
    }

    /// Embeds "Batman" with OpenAI (requires `OPENAI_API_KEY`) and returns the
    /// top 3 matches from the `moviePlotsEmbedding` index.
    async fn vector_search() -> Result<Vec<(f64, String, Movie)>, VectorStoreError> {
        let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
        let model = Client::new(&openai_api_key).embedding_model(TEXT_EMBEDDING_ADA_002);

        let client = Neo4jClient::from_config(demo_config()).await.unwrap();

        let index = client
            .get_index(model, "moviePlotsEmbedding", SearchParams::default())
            .await?;
        let hits = index.top_n::<Movie>("Batman", 3).await?;
        Ok(hits)
    }
}
407        Ok(index.top_n::<Movie>("Batman", 3).await?)
408    }
409}