rig_neo4j/
lib.rs

1//! A Rig vector store for Neo4j.
2//!
3//! This crate is a companion crate to the [rig-core crate](https://github.com/0xPlaygrounds/rig).
4//! It provides a vector store implementation that uses Neo4j as the underlying datastore.
5//!
6//! See the [README](https://github.com/0xPlaygrounds/rig/tree/main/rig-neo4j) for more information.
7//!
8//! ## Prerequisites
9//!
10//! ### GenAI Plugin
11//! The GenAI plugin is enabled by default in Neo4j Aura.
12//!
13//! The plugin needs to be installed on self-managed instances. This is done by moving the neo4j-genai.jar
14//! file from /products to /plugins in the Neo4j home directory, or, if you are using Docker, by starting
15//! the Docker container with the extra parameter `--env NEO4J_PLUGINS='["genai"]'`.
16//!
17//! For more information, see [Operations Manual → Configure plugins](https://neo4j.com/docs/upgrade-migration-guide/current/version-5/migration/install-and-configure/#_plugins).
18//!
19//! ### Pre-existing Vector Index
20//!
//! The [Neo4jVectorIndex](vector_index::Neo4jVectorIndex) struct is designed to work with a
//! pre-existing Neo4j vector index. You can create the index using the Neo4j browser, a raw Cypher
//! query, or the [Neo4jClient::create_vector_index] method.
24//! See the [Neo4j documentation](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/setup/vector-index/)
25//! for more information.
26//!
27//! The index name must be unique among both indexes and constraints.
28//! ❗A newly created index is not immediately available but is created in the background.
29//!
30//! ```cypher
31//! CREATE VECTOR INDEX moviePlots
32//!     FOR (m:Movie)
33//!     ON m.embedding
34//!     OPTIONS {indexConfig: {
35//!         `vector.dimensions`: 1536,
36//!         `vector.similarity_function`: 'cosine'
37//!     }}
38//! ```
39//!
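//! ## Simple example
//! More examples can be found in the [/examples](https://github.com/0xPlaygrounds/rig/tree/main/rig-neo4j/examples) folder.
//!
//! The example needs network access and an `OPENAI_API_KEY`, so it is compiled but not run as a doctest.
//! ```no_run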
43//! use rig_neo4j::{vector_index::*, Neo4jClient};
44//! use neo4rs::ConfigBuilder;
45//! use rig::{providers::openai::*, vector_store::VectorStoreIndex};
46//! use serde::Deserialize;
47//! use std::env;
48//!
49//! #[tokio::main]
50//! async fn main() {
51//!     let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
52//!     let openai_client = Client::new(&openai_api_key);
53//!     let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
54//!
55//!
56//!     const NEO4J_URI: &str = "neo4j+s://demo.neo4jlabs.com:7687";
57//!     const NEO4J_DB: &str = "recommendations";
58//!     const NEO4J_USERNAME: &str = "recommendations";
59//!     const NEO4J_PASSWORD: &str = "recommendations";
60//!
61//!     let client = Neo4jClient::from_config(
62//!         ConfigBuilder::default()
63//!             .uri(NEO4J_URI)
64//!             .db(NEO4J_DB)
65//!             .user(NEO4J_USERNAME)
66//!             .password(NEO4J_PASSWORD)
67//!             .build()
68//!             .unwrap(),
69//!     )
70//!    .await
71//!    .unwrap();
72//!
73//!     let index = client.get_index(
74//!         model,
75//!         "moviePlotsEmbedding",
76//!         SearchParams::default()
77//!     ).await.unwrap();
78//!
79//!     #[derive(Debug, Deserialize)]
80//!     struct Movie {
81//!         title: String,
82//!         plot: String,
83//!     }
84//!     let results = index.top_n::<Movie>("Batman", 3).await.unwrap();
85//!     println!("{:#?}", results);
86//! }
87//! ```
88pub mod vector_index;
89use std::str::FromStr;
90
91use futures::TryStreamExt;
92use neo4rs::*;
93use rig::{embeddings::EmbeddingModel, vector_store::VectorStoreError};
94use serde::Deserialize;
95use vector_index::{IndexConfig, Neo4jVectorIndex, SearchParams, VectorSimilarityFunction};
96
97pub struct Neo4jClient {
98    pub graph: Graph,
99}
100
101fn neo4j_to_rig_error(e: neo4rs::Error) -> VectorStoreError {
102    VectorStoreError::DatastoreError(Box::new(e))
103}
104
105pub trait ToBoltType {
106    fn to_bolt_type(&self) -> BoltType;
107}
108
109impl<T> ToBoltType for T
110where
111    T: serde::Serialize,
112{
113    fn to_bolt_type(&self) -> BoltType {
114        match serde_json::to_value(self) {
115            Ok(json_value) => match json_value {
116                serde_json::Value::Null => BoltType::Null(BoltNull),
117                serde_json::Value::Bool(b) => BoltType::Boolean(BoltBoolean::new(b)),
118                serde_json::Value::Number(num) => {
119                    if let Some(i) = num.as_i64() {
120                        BoltType::Integer(BoltInteger::new(i))
121                    } else if let Some(f) = num.as_f64() {
122                        BoltType::Float(BoltFloat::new(f))
123                    } else {
124                        println!("Couldn't map to BoltType, will ignore.");
125                        BoltType::Null(BoltNull) // Handle unexpected number type
126                    }
127                }
128                serde_json::Value::String(s) => BoltType::String(BoltString::new(&s)),
129                serde_json::Value::Array(arr) => BoltType::List(
130                    arr.iter()
131                        .map(|v| v.to_bolt_type())
132                        .collect::<Vec<BoltType>>()
133                        .into(),
134                ),
135                serde_json::Value::Object(obj) => {
136                    let mut bolt_map = BoltMap::new();
137                    for (k, v) in obj {
138                        bolt_map.put(BoltString::new(&k), v.to_bolt_type());
139                    }
140                    BoltType::Map(bolt_map)
141                }
142            },
143            Err(_) => {
144                println!("Couldn't serialize to JSON, will ignore.");
145                BoltType::Null(BoltNull) // Handle serialization error
146            }
147        }
148    }
149}
150
151impl Neo4jClient {
152    const GET_INDEX_QUERY: &'static str = "
153    SHOW VECTOR INDEXES
154    YIELD name, properties, options
155    WHERE name=$index_name
156    RETURN name, properties, options
157    ";
158
159    const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";
160
161    pub fn new(graph: Graph) -> Self {
162        Self { graph }
163    }
164
165    pub async fn connect(uri: &str, user: &str, password: &str) -> Result<Self, VectorStoreError> {
166        tracing::info!("Connecting to Neo4j DB at {} ...", uri);
167        let graph = Graph::new(uri, user, password)
168            .await
169            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
170        tracing::info!("Connected to Neo4j");
171        Ok(Self { graph })
172    }
173
174    pub async fn from_config(config: Config) -> Result<Self, VectorStoreError> {
175        let graph = Graph::connect(config)
176            .await
177            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
178        Ok(Self { graph })
179    }
180
181    pub async fn execute_and_collect<T: for<'a> Deserialize<'a>>(
182        graph: &Graph,
183        query: Query,
184    ) -> Result<Vec<T>, VectorStoreError> {
185        graph
186            .execute(query)
187            .await
188            .map_err(neo4j_to_rig_error)?
189            .into_stream_as::<T>()
190            .try_collect::<Vec<T>>()
191            .await
192            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
193    }
194
195    /// Returns a `Neo4jVectorIndex` that mirrors an existing Neo4j Vector Index.
196    ///
197    /// An index (of type "vector") of the same name as `index_name` must already exist for the Neo4j database.
198    /// See the Neo4j [documentation (Create vector index)](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/setup/vector-index/) for more information on creating indexes.
199    ///
200    /// ❗IMPORTANT: The index must be created with the same embedding model that will be used to query the index.
201    pub async fn get_index<M: EmbeddingModel>(
202        &self,
203        model: M,
204        index_name: &str,
205        search_params: SearchParams,
206    ) -> Result<Neo4jVectorIndex<M>, VectorStoreError> {
207        #[derive(Deserialize)]
208        struct IndexInfo {
209            name: String,
210            properties: Vec<String>,
211            options: IndexOptions,
212        }
213
214        #[derive(Deserialize)]
215        #[serde(rename_all = "camelCase")]
216        struct IndexOptions {
217            _index_provider: String,
218            index_config: IndexConfigDetails,
219        }
220
221        #[derive(Deserialize)]
222        struct IndexConfigDetails {
223            #[serde(rename = "vector.dimensions")]
224            vector_dimensions: i64,
225            #[serde(rename = "vector.similarity_function")]
226            vector_similarity_function: String,
227        }
228
229        let index_info = Self::execute_and_collect::<IndexInfo>(
230            &self.graph,
231            neo4rs::query(Self::GET_INDEX_QUERY).param("index_name", index_name),
232        )
233        .await?;
234
235        let index_config = if let Some(index) = index_info.first() {
236            if index.options.index_config.vector_dimensions != model.ndims() as i64 {
237                tracing::warn!(
238                    "The embedding vector dimensions of the existing Neo4j DB index ({}) do not match the provided model dimensions ({}). This may affect search performance.",
239                    index.options.index_config.vector_dimensions,
240                    model.ndims()
241                );
242            }
243            IndexConfig::new(index.name.clone())
244                .embedding_property(index.properties.first().unwrap())
245                .similarity_function(VectorSimilarityFunction::from_str(
246                    &index.options.index_config.vector_similarity_function,
247                )?)
248        } else {
249            let indexes = Self::execute_and_collect::<String>(
250                &self.graph,
251                neo4rs::query(Self::SHOW_INDEXES_QUERY),
252            )
253            .await?;
254            return Err(VectorStoreError::DatastoreError(Box::new(
255                std::io::Error::new(
256                    std::io::ErrorKind::NotFound,
257                    format!(
258                        "Index `{index_name}` not found in database. Available indexes: {indexes:?}"
259                    ),
260                ),
261            )));
262        };
263        Ok(Neo4jVectorIndex::new(
264            self.graph.clone(),
265            model,
266            index_config,
267            search_params,
268        ))
269    }
270
271    /// Calls the `CREATE VECTOR INDEX` Neo4j query and waits for the index to be created.
272    /// A newly created index is not immediately fully available but is created (i.e. data is indexed) in the background.
273    ///
274    /// ❗ If there is already an index targeting the same node label and property, the new index creation will fail.
275    ///
276    /// ### Arguments
277    /// * `index_name` - The name of the index to create.
278    /// * `node_label` - The label of the nodes to which the index will be applied. For example, if your nodes have
279    ///   the label `:Movie`, pass "Movie" as the `node_label` parameter.
280    /// * `embedding_prop_name` (optional) - The name of the property that contains the embedding vectors. Defaults to "embedding".
281    ///
282    pub async fn create_vector_index(
283        &self,
284        index_config: IndexConfig,
285        node_label: &str,
286        model: &impl EmbeddingModel,
287    ) -> Result<(), VectorStoreError> {
288        // Create a vector index on our vector store
289        tracing::info!("Creating vector index {} ...", index_config.index_name);
290
291        let create_vector_index_query = format!(
292            "
293            CREATE VECTOR INDEX $index_name IF NOT EXISTS
294            FOR (m:{})
295            ON m.{}
296            OPTIONS {{
297                indexConfig: {{
298                    `vector.dimensions`: $dimensions,
299                    `vector.similarity_function`: $similarity_function
300                }}
301            }}",
302            node_label, index_config.embedding_property
303        );
304
305        self.graph
306            .run(
307                neo4rs::query(&create_vector_index_query)
308                    .param("index_name", index_config.index_name.clone())
309                    .param(
310                        "similarity_function",
311                        index_config.similarity_function.clone().to_bolt_type(),
312                    )
313                    .param("dimensions", model.ndims() as i64),
314            )
315            .await
316            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
317
318        // Check if the index exists with db.awaitIndex(), the call timeouts if the index is not ready
319        let index_exists = self
320            .graph
321            .run(
322                neo4rs::query("CALL db.awaitIndex($index_name, 10000)")
323                    .param("index_name", index_config.index_name.clone()),
324            )
325            .await;
326
327        if index_exists.is_err() {
328            tracing::warn!(
329                "Index with name `{}` is not ready or could not be created.",
330                index_config.index_name.clone()
331            );
332        }
333
334        tracing::info!(
335            "Index created successfully with name: {}",
336            index_config.index_name
337        );
338        Ok(())
339    }
340}