Skip to main content

rig_neo4j/
lib.rs

//! A Rig vector store for Neo4j.
//!
//! This crate is a companion crate to the [rig-core crate](https://github.com/0xPlaygrounds/rig).
//! It provides a vector store implementation that uses Neo4j as the underlying datastore.
//!
//! See the [README](https://github.com/0xPlaygrounds/rig/tree/main/rig-neo4j) for more information.
//!
//! ## Prerequisites
//!
//! ### GenAI Plugin
//! The GenAI plugin is enabled by default in Neo4j Aura.
//!
//! On self-managed instances the plugin must be installed. To do so, move the `neo4j-genai.jar`
//! file from `products/` to `plugins/` in the Neo4j home directory or, if you are using Docker,
//! start the container with the extra parameter `--env NEO4J_PLUGINS='["genai"]'`.
//!
//! For more information, see [Operations Manual → Configure plugins](https://neo4j.com/docs/upgrade-migration-guide/current/version-5/migration/install-and-configure/#_plugins).
//!
//! ### Pre-existing Vector Index
//!
//! The [`Neo4jVectorIndex`](vector_index::Neo4jVectorIndex) struct is designed to work with a
//! pre-existing Neo4j vector index. You can create the index using the Neo4j browser, a raw Cypher
//! query, or the [`Neo4jClient::create_vector_index`] method.
//! See the [Neo4j documentation](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/setup/vector-index/)
//! for more information.
//!
//! The index name must be unique among both indexes and constraints.
//! ❗ A newly created index is not immediately available; it is populated in the background.
//!
//! ```cypher
//! CREATE VECTOR INDEX moviePlots
//!     FOR (m:Movie)
//!     ON m.embedding
//!     OPTIONS {indexConfig: {
//!         `vector.dimensions`: 1536,
//!         `vector.similarity_function`: 'cosine'
//!     }}
//! ```
//!
//! ## Simple example
//! More examples can be found in the [/examples](https://github.com/0xPlaygrounds/rig/tree/main/rig-neo4j/examples) folder.
//!
//! The example reads `OPENAI_API_KEY` from the environment and connects to a remote database,
//! so it is compiled but not executed as a doctest.
//! ```no_run
//! use rig_neo4j::{vector_index::*, Neo4jClient};
//! use neo4rs::ConfigBuilder;
//! use rig::{providers::openai::*, vector_store::VectorStoreIndex};
//! use serde::Deserialize;
//! use std::env;
//!
//! #[tokio::main]
//! async fn main() {
//!     let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
//!     let openai_client = Client::new(&openai_api_key);
//!     let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
//!
//!     const NEO4J_URI: &str = "neo4j+s://demo.neo4jlabs.com:7687";
//!     const NEO4J_DB: &str = "recommendations";
//!     const NEO4J_USERNAME: &str = "recommendations";
//!     const NEO4J_PASSWORD: &str = "recommendations";
//!
//!     let client = Neo4jClient::from_config(
//!         ConfigBuilder::default()
//!             .uri(NEO4J_URI)
//!             .db(NEO4J_DB)
//!             .user(NEO4J_USERNAME)
//!             .password(NEO4J_PASSWORD)
//!             .build()
//!             .unwrap(),
//!     )
//!     .await
//!     .unwrap();
//!
//!     let index = client.get_index(
//!         model,
//!         "moviePlotsEmbedding"
//!     ).await.unwrap();
//!
//!     #[derive(Debug, Deserialize)]
//!     struct Movie {
//!         title: String,
//!         plot: String,
//!     }
//!     let results = index.top_n::<Movie>("Batman", 3).await.unwrap();
//!     println!("{:#?}", results);
//! }
//! ```
87pub mod vector_index;
88use std::str::FromStr;
89
90use futures::TryStreamExt;
91use neo4rs::*;
92use rig::{
93    embeddings::EmbeddingModel,
94    vector_store::{VectorStoreError, request::SearchFilter},
95};
96use serde::{Deserialize, Serialize};
97use vector_index::{IndexConfig, Neo4jVectorIndex, VectorSimilarityFunction};
98
99pub struct Neo4jClient {
100    pub graph: Graph,
101}
102
103fn neo4j_to_rig_error(e: neo4rs::Error) -> VectorStoreError {
104    VectorStoreError::DatastoreError(Box::new(e))
105}
106
107#[derive(Clone, Debug, Serialize, Deserialize)]
108pub struct Neo4jSearchFilter(String);
109
110impl SearchFilter for Neo4jSearchFilter {
111    type Value = serde_json::Value;
112
113    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
114        Self(format!("n.{} = {}", key.as_ref(), serialize_cypher(value)))
115    }
116
117    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
118        Self(format!("n.{} > {}", key.as_ref(), serialize_cypher(value)))
119    }
120
121    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
122        Self(format!("n.{} < {}", key.as_ref(), serialize_cypher(value)))
123    }
124
125    fn and(self, rhs: Self) -> Self {
126        Self(format!("({}) AND ({})", self.0, rhs.0))
127    }
128
129    fn or(self, rhs: Self) -> Self {
130        Self(format!("({}) OR ({})", self.0, rhs.0))
131    }
132}
133
134impl Neo4jSearchFilter {
135    pub fn render(self) -> String {
136        format!("WHERE {}", self.0)
137    }
138
139    #[allow(clippy::should_implement_trait)]
140    pub fn not(self) -> Self {
141        Self(format!("NOT ({})", self.0))
142    }
143
144    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
145        Self(format!("n.{key} >= {}", serialize_cypher(value)))
146    }
147
148    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
149        Self(format!("n.{key} <= {}", serialize_cypher(value)))
150    }
151
152    pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
153        Self(format!(
154            "n.{key} IN {}",
155            serialize_cypher(serde_json::Value::Array(values))
156        ))
157    }
158
159    // String matching
160
161    /// Tests whether the value at `key` contains the pattern
162    pub fn contains<S>(key: String, pattern: S) -> Self
163    where
164        S: AsRef<str>,
165    {
166        Self(format!(
167            "n.{key} CONTAINS {}",
168            serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
169        ))
170    }
171
172    /// Tests whether the value at `key` starts with the pattern
173    pub fn starts_with<S>(key: String, pattern: S) -> Self
174    where
175        S: AsRef<str>,
176    {
177        Self(format!(
178            "n.{key} STARTS WITH {}",
179            serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
180        ))
181    }
182
183    /// Tests whether the value at `key` ends with the pattern
184    pub fn ends_with<S>(key: String, pattern: S) -> Self
185    where
186        S: AsRef<str>,
187    {
188        Self(format!(
189            "n.{key} ENDS WITH {}",
190            serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
191        ))
192    }
193
194    pub fn matches<S>(key: String, pattern: S) -> Self
195    where
196        S: AsRef<str>,
197    {
198        Self(format!(
199            "n.{key} =~ {}",
200            serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
201        ))
202    }
203}
204
205fn serialize_cypher(value: serde_json::Value) -> String {
206    use serde_json::Value::*;
207    match value {
208        Null => "null".into(),
209        Bool(b) => b.to_string(),
210        Number(n) => n.to_string(),
211        String(s) => format!("'{}'", s.replace('\'', "\\'")),
212        Array(arr) => {
213            format!(
214                "[{}]",
215                arr.into_iter()
216                    .map(serialize_cypher)
217                    .collect::<Vec<std::string::String>>()
218                    .join(", ")
219            )
220        }
221        Object(obj) => {
222            format!(
223                "{{{}}}",
224                obj.into_iter()
225                    .map(|(k, v)| format!("{k}: {}", serialize_cypher(v)))
226                    .collect::<Vec<std::string::String>>()
227                    .join(", ")
228            )
229        }
230    }
231}
232
233pub trait ToBoltType {
234    fn to_bolt_type(&self) -> BoltType;
235}
236
237impl<T> ToBoltType for T
238where
239    T: serde::Serialize,
240{
241    fn to_bolt_type(&self) -> BoltType {
242        match serde_json::to_value(self) {
243            Ok(json_value) => match json_value {
244                serde_json::Value::Null => BoltType::Null(BoltNull),
245                serde_json::Value::Bool(b) => BoltType::Boolean(BoltBoolean::new(b)),
246                serde_json::Value::Number(num) => {
247                    if let Some(i) = num.as_i64() {
248                        BoltType::Integer(BoltInteger::new(i))
249                    } else if let Some(f) = num.as_f64() {
250                        BoltType::Float(BoltFloat::new(f))
251                    } else {
252                        println!("Couldn't map to BoltType, will ignore.");
253                        BoltType::Null(BoltNull) // Handle unexpected number type
254                    }
255                }
256                serde_json::Value::String(s) => BoltType::String(BoltString::new(&s)),
257                serde_json::Value::Array(arr) => BoltType::List(
258                    arr.iter()
259                        .map(|v| v.to_bolt_type())
260                        .collect::<Vec<BoltType>>()
261                        .into(),
262                ),
263                serde_json::Value::Object(obj) => {
264                    let mut bolt_map = BoltMap::new();
265                    for (k, v) in obj {
266                        bolt_map.put(BoltString::new(&k), v.to_bolt_type());
267                    }
268                    BoltType::Map(bolt_map)
269                }
270            },
271            Err(_) => {
272                println!("Couldn't serialize to JSON, will ignore.");
273                BoltType::Null(BoltNull) // Handle serialization error
274            }
275        }
276    }
277}
278
279impl Neo4jClient {
280    const GET_INDEX_QUERY: &'static str = "
281    SHOW VECTOR INDEXES
282    YIELD name, properties, options
283    WHERE name=$index_name
284    RETURN name, properties, options
285    ";
286
287    const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";
288
289    pub fn new(graph: Graph) -> Self {
290        Self { graph }
291    }
292
293    pub async fn connect(uri: &str, user: &str, password: &str) -> Result<Self, VectorStoreError> {
294        tracing::info!("Connecting to Neo4j DB at {} ...", uri);
295        let graph = Graph::new(uri, user, password)
296            .await
297            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
298        tracing::info!("Connected to Neo4j");
299        Ok(Self { graph })
300    }
301
302    pub async fn from_config(config: Config) -> Result<Self, VectorStoreError> {
303        let graph = Graph::connect(config)
304            .await
305            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
306        Ok(Self { graph })
307    }
308
309    pub async fn execute_and_collect<T: for<'a> Deserialize<'a>>(
310        graph: &Graph,
311        query: Query,
312    ) -> Result<Vec<T>, VectorStoreError> {
313        graph
314            .execute(query)
315            .await
316            .map_err(neo4j_to_rig_error)?
317            .into_stream_as::<T>()
318            .try_collect::<Vec<T>>()
319            .await
320            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
321    }
322
323    /// Returns a `Neo4jVectorIndex` that mirrors an existing Neo4j Vector Index.
324    ///
325    /// An index (of type "vector") of the same name as `index_name` must already exist for the Neo4j database.
326    /// See the Neo4j [documentation (Create vector index)](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/setup/vector-index/) for more information on creating indexes.
327    ///
328    /// ❗IMPORTANT: The index must be created with the same embedding model that will be used to query the index.
329    pub async fn get_index<M: EmbeddingModel>(
330        &self,
331        model: M,
332        index_name: &str,
333    ) -> Result<Neo4jVectorIndex<M>, VectorStoreError> {
334        #[derive(Deserialize)]
335        struct IndexInfo {
336            name: String,
337            properties: Vec<String>,
338            options: IndexOptions,
339        }
340
341        #[derive(Deserialize)]
342        #[serde(rename_all = "camelCase")]
343        struct IndexOptions {
344            _index_provider: String,
345            index_config: IndexConfigDetails,
346        }
347
348        #[derive(Deserialize)]
349        struct IndexConfigDetails {
350            #[serde(rename = "vector.dimensions")]
351            vector_dimensions: i64,
352            #[serde(rename = "vector.similarity_function")]
353            vector_similarity_function: String,
354        }
355
356        let index_info = Self::execute_and_collect::<IndexInfo>(
357            &self.graph,
358            neo4rs::query(Self::GET_INDEX_QUERY).param("index_name", index_name),
359        )
360        .await?;
361
362        let index_config = if let Some(index) = index_info.first() {
363            if index.options.index_config.vector_dimensions != model.ndims() as i64 {
364                tracing::warn!(
365                    "The embedding vector dimensions of the existing Neo4j DB index ({}) do not match the provided model dimensions ({}). This may affect search performance.",
366                    index.options.index_config.vector_dimensions,
367                    model.ndims()
368                );
369            }
370            IndexConfig::new(index.name.clone())
371                .embedding_property(index.properties.first().unwrap())
372                .similarity_function(VectorSimilarityFunction::from_str(
373                    &index.options.index_config.vector_similarity_function,
374                )?)
375        } else {
376            let indexes = Self::execute_and_collect::<String>(
377                &self.graph,
378                neo4rs::query(Self::SHOW_INDEXES_QUERY),
379            )
380            .await?;
381            return Err(VectorStoreError::DatastoreError(Box::new(
382                std::io::Error::new(
383                    std::io::ErrorKind::NotFound,
384                    format!(
385                        "Index `{index_name}` not found in database. Available indexes: {indexes:?}"
386                    ),
387                ),
388            )));
389        };
390        Ok(Neo4jVectorIndex::new(
391            self.graph.clone(),
392            model,
393            index_config,
394        ))
395    }
396
397    /// Calls the `CREATE VECTOR INDEX` Neo4j query and waits for the index to be created.
398    /// A newly created index is not immediately fully available but is created (i.e. data is indexed) in the background.
399    ///
400    /// ❗ If there is already an index targeting the same node label and property, the new index creation will fail.
401    ///
402    /// ### Arguments
403    /// * `index_name` - The name of the index to create.
404    /// * `node_label` - The label of the nodes to which the index will be applied. For example, if your nodes have
405    ///   the label `:Movie`, pass "Movie" as the `node_label` parameter.
406    /// * `embedding_prop_name` (optional) - The name of the property that contains the embedding vectors. Defaults to "embedding".
407    ///
408    pub async fn create_vector_index(
409        &self,
410        index_config: IndexConfig,
411        node_label: &str,
412        model: &impl EmbeddingModel,
413    ) -> Result<(), VectorStoreError> {
414        // Create a vector index on our vector store
415        tracing::info!("Creating vector index {} ...", index_config.index_name);
416
417        let create_vector_index_query = format!(
418            "
419            CREATE VECTOR INDEX $index_name IF NOT EXISTS
420            FOR (m:{})
421            ON m.{}
422            OPTIONS {{
423                indexConfig: {{
424                    `vector.dimensions`: $dimensions,
425                    `vector.similarity_function`: $similarity_function
426                }}
427            }}",
428            node_label, index_config.embedding_property
429        );
430
431        self.graph
432            .run(
433                neo4rs::query(&create_vector_index_query)
434                    .param("index_name", index_config.index_name.clone())
435                    .param(
436                        "similarity_function",
437                        index_config.similarity_function.clone().to_bolt_type(),
438                    )
439                    .param("dimensions", model.ndims() as i64),
440            )
441            .await
442            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
443
444        // Check if the index exists with db.awaitIndex(), the call timeouts if the index is not ready
445        let index_exists = self
446            .graph
447            .run(
448                neo4rs::query("CALL db.awaitIndex($index_name, 10000)")
449                    .param("index_name", index_config.index_name.clone()),
450            )
451            .await;
452
453        if index_exists.is_err() {
454            tracing::warn!(
455                "Index with name `{}` is not ready or could not be created.",
456                index_config.index_name.clone()
457            );
458        }
459
460        tracing::info!(
461            "Index created successfully with name: {}",
462            index_config.index_name
463        );
464        Ok(())
465    }
466}