rig_neo4j/lib.rs
1//! A Rig vector store for Neo4j.
2//!
3//! This crate is a companion crate to the [rig-core crate](https://github.com/0xPlaygrounds/rig).
4//! It provides a vector store implementation that uses Neo4j as the underlying datastore.
5//!
6//! See the [README](https://github.com/0xPlaygrounds/rig/tree/main/rig-neo4j) for more information.
7//!
8//! ## Prerequisites
9//!
10//! ### GenAI Plugin
11//! The GenAI plugin is enabled by default in Neo4j Aura.
12//!
13//! The plugin needs to be installed on self-managed instances. This is done by moving the neo4j-genai.jar
14//! file from /products to /plugins in the Neo4j home directory, or, if you are using Docker, by starting
15//! the Docker container with the extra parameter `--env NEO4J_PLUGINS='["genai"]'`.
16//!
17//! For more information, see [Operations Manual → Configure plugins](https://neo4j.com/docs/upgrade-migration-guide/current/version-5/migration/install-and-configure/#_plugins).
18//!
19//! ### Pre-existing Vector Index
20//!
//! The [`Neo4jVectorIndex`] struct is designed to work with a pre-existing
22//! Neo4j vector index. You can create the index using the Neo4j browser, a raw Cypher query, or the
23//! [Neo4jClient::create_vector_index] method.
24//! See the [Neo4j documentation](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/setup/vector-index/)
25//! for more information.
26//!
27//! The index name must be unique among both indexes and constraints.
28//! ❗A newly created index is not immediately available but is created in the background.
29//!
30//! ```cypher
31//! CREATE VECTOR INDEX moviePlots
32//! FOR (m:Movie)
33//! ON m.embedding
34//! OPTIONS {indexConfig: {
35//! `vector.dimensions`: 1536,
36//! `vector.similarity_function`: 'cosine'
37//! }}
38//! ```
39//!
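//! ## Simple example
//! More examples can be found in the [/examples](https://github.com/0xPlaygrounds/rig/tree/main/rig-neo4j/examples) folder.
//! Running this example requires the `OPENAI_API_KEY` environment variable and network access
//! to the public Neo4j demo database, so it is only compiled (not executed) as a doctest.
//! ```no_run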
43//! use rig_neo4j::{vector_index::*, Neo4jClient};
44//! use neo4rs::ConfigBuilder;
45//! use rig::{providers::openai::*, vector_store::VectorStoreIndex};
46//! use serde::Deserialize;
47//! use std::env;
48//!
49//! #[tokio::main]
50//! async fn main() {
51//! let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
52//! let openai_client = Client::new(&openai_api_key);
53//! let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
54//!
55//!
56//! const NEO4J_URI: &str = "neo4j+s://demo.neo4jlabs.com:7687";
57//! const NEO4J_DB: &str = "recommendations";
58//! const NEO4J_USERNAME: &str = "recommendations";
59//! const NEO4J_PASSWORD: &str = "recommendations";
60//!
61//! let client = Neo4jClient::from_config(
62//! ConfigBuilder::default()
63//! .uri(NEO4J_URI)
64//! .db(NEO4J_DB)
65//! .user(NEO4J_USERNAME)
66//! .password(NEO4J_PASSWORD)
67//! .build()
68//! .unwrap(),
69//! )
70//! .await
71//! .unwrap();
72//!
73//! let index = client.get_index(
74//! model,
75//! "moviePlotsEmbedding",
76//! SearchParams::default()
77//! ).await.unwrap();
78//!
79//! #[derive(Debug, Deserialize)]
80//! struct Movie {
81//! title: String,
82//! plot: String,
83//! }
84//! let results = index.top_n::<Movie>("Batman", 3).await.unwrap();
85//! println!("{:#?}", results);
86//! }
87//! ```
88pub mod vector_index;
89use std::str::FromStr;
90
91use futures::TryStreamExt;
92use neo4rs::*;
93use rig::{embeddings::EmbeddingModel, vector_store::VectorStoreError};
94use serde::Deserialize;
95use vector_index::{IndexConfig, Neo4jVectorIndex, SearchParams, VectorSimilarityFunction};
96
/// Client holding a `neo4rs` [`Graph`] connection to a Neo4j database.
///
/// Use it to look up an existing vector index ([`Neo4jClient::get_index`]) or to create a
/// new one ([`Neo4jClient::create_vector_index`]).
pub struct Neo4jClient {
    /// The underlying `neo4rs` connection handle. `get_index` hands a clone of it to every
    /// [`Neo4jVectorIndex`] it returns.
    pub graph: Graph,
}
100
101fn neo4j_to_rig_error(e: neo4rs::Error) -> VectorStoreError {
102 VectorStoreError::DatastoreError(Box::new(e))
103}
104
/// Converts a value into a Neo4j [`BoltType`] so it can be bound as a query parameter.
///
/// A blanket implementation is provided for every `serde::Serialize` type. It goes through
/// `serde_json::Value`; values that fail to serialize become `BoltType::Null`.
pub trait ToBoltType {
    /// Returns the Bolt representation of `self`.
    fn to_bolt_type(&self) -> BoltType;
}
108
109impl<T> ToBoltType for T
110where
111 T: serde::Serialize,
112{
113 fn to_bolt_type(&self) -> BoltType {
114 match serde_json::to_value(self) {
115 Ok(json_value) => match json_value {
116 serde_json::Value::Null => BoltType::Null(BoltNull),
117 serde_json::Value::Bool(b) => BoltType::Boolean(BoltBoolean::new(b)),
118 serde_json::Value::Number(num) => {
119 if let Some(i) = num.as_i64() {
120 BoltType::Integer(BoltInteger::new(i))
121 } else if let Some(f) = num.as_f64() {
122 BoltType::Float(BoltFloat::new(f))
123 } else {
124 println!("Couldn't map to BoltType, will ignore.");
125 BoltType::Null(BoltNull) // Handle unexpected number type
126 }
127 }
128 serde_json::Value::String(s) => BoltType::String(BoltString::new(&s)),
129 serde_json::Value::Array(arr) => BoltType::List(
130 arr.iter()
131 .map(|v| v.to_bolt_type())
132 .collect::<Vec<BoltType>>()
133 .into(),
134 ),
135 serde_json::Value::Object(obj) => {
136 let mut bolt_map = BoltMap::new();
137 for (k, v) in obj {
138 bolt_map.put(BoltString::new(&k), v.to_bolt_type());
139 }
140 BoltType::Map(bolt_map)
141 }
142 },
143 Err(_) => {
144 println!("Couldn't serialize to JSON, will ignore.");
145 BoltType::Null(BoltNull) // Handle serialization error
146 }
147 }
148 }
149}
150
151impl Neo4jClient {
152 const GET_INDEX_QUERY: &'static str = "
153 SHOW VECTOR INDEXES
154 YIELD name, properties, options
155 WHERE name=$index_name
156 RETURN name, properties, options
157 ";
158
159 const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";
160
161 pub fn new(graph: Graph) -> Self {
162 Self { graph }
163 }
164
165 pub async fn connect(uri: &str, user: &str, password: &str) -> Result<Self, VectorStoreError> {
166 tracing::info!("Connecting to Neo4j DB at {} ...", uri);
167 let graph = Graph::new(uri, user, password)
168 .await
169 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
170 tracing::info!("Connected to Neo4j");
171 Ok(Self { graph })
172 }
173
174 pub async fn from_config(config: Config) -> Result<Self, VectorStoreError> {
175 let graph = Graph::connect(config)
176 .await
177 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
178 Ok(Self { graph })
179 }
180
181 pub async fn execute_and_collect<T: for<'a> Deserialize<'a>>(
182 graph: &Graph,
183 query: Query,
184 ) -> Result<Vec<T>, VectorStoreError> {
185 graph
186 .execute(query)
187 .await
188 .map_err(neo4j_to_rig_error)?
189 .into_stream_as::<T>()
190 .try_collect::<Vec<T>>()
191 .await
192 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
193 }
194
195 /// Returns a `Neo4jVectorIndex` that mirrors an existing Neo4j Vector Index.
196 ///
197 /// An index (of type "vector") of the same name as `index_name` must already exist for the Neo4j database.
198 /// See the Neo4j [documentation (Create vector index)](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/setup/vector-index/) for more information on creating indexes.
199 ///
200 /// ❗IMPORTANT: The index must be created with the same embedding model that will be used to query the index.
201 pub async fn get_index<M: EmbeddingModel>(
202 &self,
203 model: M,
204 index_name: &str,
205 search_params: SearchParams,
206 ) -> Result<Neo4jVectorIndex<M>, VectorStoreError> {
207 #[derive(Deserialize)]
208 struct IndexInfo {
209 name: String,
210 properties: Vec<String>,
211 options: IndexOptions,
212 }
213
214 #[derive(Deserialize)]
215 #[serde(rename_all = "camelCase")]
216 struct IndexOptions {
217 _index_provider: String,
218 index_config: IndexConfigDetails,
219 }
220
221 #[derive(Deserialize)]
222 struct IndexConfigDetails {
223 #[serde(rename = "vector.dimensions")]
224 vector_dimensions: i64,
225 #[serde(rename = "vector.similarity_function")]
226 vector_similarity_function: String,
227 }
228
229 let index_info = Self::execute_and_collect::<IndexInfo>(
230 &self.graph,
231 neo4rs::query(Self::GET_INDEX_QUERY).param("index_name", index_name),
232 )
233 .await?;
234
235 let index_config = if let Some(index) = index_info.first() {
236 if index.options.index_config.vector_dimensions != model.ndims() as i64 {
237 tracing::warn!(
238 "The embedding vector dimensions of the existing Neo4j DB index ({}) do not match the provided model dimensions ({}). This may affect search performance.",
239 index.options.index_config.vector_dimensions,
240 model.ndims()
241 );
242 }
243 IndexConfig::new(index.name.clone())
244 .embedding_property(index.properties.first().unwrap())
245 .similarity_function(VectorSimilarityFunction::from_str(
246 &index.options.index_config.vector_similarity_function,
247 )?)
248 } else {
249 let indexes = Self::execute_and_collect::<String>(
250 &self.graph,
251 neo4rs::query(Self::SHOW_INDEXES_QUERY),
252 )
253 .await?;
254 return Err(VectorStoreError::DatastoreError(Box::new(
255 std::io::Error::new(
256 std::io::ErrorKind::NotFound,
257 format!(
258 "Index `{index_name}` not found in database. Available indexes: {indexes:?}"
259 ),
260 ),
261 )));
262 };
263 Ok(Neo4jVectorIndex::new(
264 self.graph.clone(),
265 model,
266 index_config,
267 search_params,
268 ))
269 }
270
271 /// Calls the `CREATE VECTOR INDEX` Neo4j query and waits for the index to be created.
272 /// A newly created index is not immediately fully available but is created (i.e. data is indexed) in the background.
273 ///
274 /// ❗ If there is already an index targeting the same node label and property, the new index creation will fail.
275 ///
276 /// ### Arguments
277 /// * `index_name` - The name of the index to create.
278 /// * `node_label` - The label of the nodes to which the index will be applied. For example, if your nodes have
279 /// the label `:Movie`, pass "Movie" as the `node_label` parameter.
280 /// * `embedding_prop_name` (optional) - The name of the property that contains the embedding vectors. Defaults to "embedding".
281 ///
282 pub async fn create_vector_index(
283 &self,
284 index_config: IndexConfig,
285 node_label: &str,
286 model: &impl EmbeddingModel,
287 ) -> Result<(), VectorStoreError> {
288 // Create a vector index on our vector store
289 tracing::info!("Creating vector index {} ...", index_config.index_name);
290
291 let create_vector_index_query = format!(
292 "
293 CREATE VECTOR INDEX $index_name IF NOT EXISTS
294 FOR (m:{})
295 ON m.{}
296 OPTIONS {{
297 indexConfig: {{
298 `vector.dimensions`: $dimensions,
299 `vector.similarity_function`: $similarity_function
300 }}
301 }}",
302 node_label, index_config.embedding_property
303 );
304
305 self.graph
306 .run(
307 neo4rs::query(&create_vector_index_query)
308 .param("index_name", index_config.index_name.clone())
309 .param(
310 "similarity_function",
311 index_config.similarity_function.clone().to_bolt_type(),
312 )
313 .param("dimensions", model.ndims() as i64),
314 )
315 .await
316 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
317
318 // Check if the index exists with db.awaitIndex(), the call timeouts if the index is not ready
319 let index_exists = self
320 .graph
321 .run(
322 neo4rs::query("CALL db.awaitIndex($index_name, 10000)")
323 .param("index_name", index_config.index_name.clone()),
324 )
325 .await;
326
327 if index_exists.is_err() {
328 tracing::warn!(
329 "Index with name `{}` is not ready or could not be created.",
330 index_config.index_name.clone()
331 );
332 }
333
334 tracing::info!(
335 "Index created successfully with name: {}",
336 index_config.index_name
337 );
338 Ok(())
339 }
340}