1pub mod vector_index;
88use std::str::FromStr;
89
90use futures::TryStreamExt;
91use neo4rs::*;
92use rig::{
93 embeddings::EmbeddingModel,
94 vector_store::{VectorStoreError, request::SearchFilter},
95};
96use serde::{Deserialize, Serialize};
97use vector_index::{IndexConfig, Neo4jVectorIndex, VectorSimilarityFunction};
98
/// Thin wrapper around a neo4rs [`Graph`] connection, used to look up and create
/// Neo4j vector indexes and to build [`Neo4jVectorIndex`] handles.
pub struct Neo4jClient {
    /// Underlying connection handle; a clone of it is handed to every index built by
    /// [`Neo4jClient::get_index`].
    pub graph: Graph,
}
102
103fn neo4j_to_rig_error(e: neo4rs::Error) -> VectorStoreError {
104 VectorStoreError::DatastoreError(Box::new(e))
105}
106
/// A boolean Cypher predicate used to filter vector-search results.
///
/// Predicates are written against the variable `n` (e.g. `n.age > 3`), so presumably
/// the query this is rendered into binds the searched node as `n` — confirm at the
/// call site. [`Neo4jSearchFilter::render`] turns the predicate into a `WHERE` clause.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Neo4jSearchFilter(String);
109
110impl SearchFilter for Neo4jSearchFilter {
111 type Value = serde_json::Value;
112
113 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
114 Self(format!("n.{} = {}", key.as_ref(), serialize_cypher(value)))
115 }
116
117 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
118 Self(format!("n.{} > {}", key.as_ref(), serialize_cypher(value)))
119 }
120
121 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
122 Self(format!("n.{} < {}", key.as_ref(), serialize_cypher(value)))
123 }
124
125 fn and(self, rhs: Self) -> Self {
126 Self(format!("({}) AND ({})", self.0, rhs.0))
127 }
128
129 fn or(self, rhs: Self) -> Self {
130 Self(format!("({}) OR ({})", self.0, rhs.0))
131 }
132}
133
134impl Neo4jSearchFilter {
135 pub fn render(self) -> String {
136 format!("WHERE {}", self.0)
137 }
138
139 #[allow(clippy::should_implement_trait)]
140 pub fn not(self) -> Self {
141 Self(format!("NOT ({})", self.0))
142 }
143
144 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
145 Self(format!("n.{key} >= {}", serialize_cypher(value)))
146 }
147
148 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
149 Self(format!("n.{key} <= {}", serialize_cypher(value)))
150 }
151
152 pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
153 Self(format!(
154 "n.{key} IN {}",
155 serialize_cypher(serde_json::Value::Array(values))
156 ))
157 }
158
159 pub fn contains<S>(key: String, pattern: S) -> Self
163 where
164 S: AsRef<str>,
165 {
166 Self(format!(
167 "n.{key} CONTAINS {}",
168 serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
169 ))
170 }
171
172 pub fn starts_with<S>(key: String, pattern: S) -> Self
174 where
175 S: AsRef<str>,
176 {
177 Self(format!(
178 "n.{key} STARTS WITH {}",
179 serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
180 ))
181 }
182
183 pub fn ends_with<S>(key: String, pattern: S) -> Self
185 where
186 S: AsRef<str>,
187 {
188 Self(format!(
189 "n.{key} ENDS WITH {}",
190 serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
191 ))
192 }
193
194 pub fn matches<S>(key: String, pattern: S) -> Self
195 where
196 S: AsRef<str>,
197 {
198 Self(format!(
199 "n.{key} =~ {}",
200 serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
201 ))
202 }
203}
204
205fn serialize_cypher(value: serde_json::Value) -> String {
206 use serde_json::Value::*;
207 match value {
208 Null => "null".into(),
209 Bool(b) => b.to_string(),
210 Number(n) => n.to_string(),
211 String(s) => format!("'{}'", s.replace('\'', "\\'")),
212 Array(arr) => {
213 format!(
214 "[{}]",
215 arr.into_iter()
216 .map(serialize_cypher)
217 .collect::<Vec<std::string::String>>()
218 .join(", ")
219 )
220 }
221 Object(obj) => {
222 format!(
223 "{{{}}}",
224 obj.into_iter()
225 .map(|(k, v)| format!("{k}: {}", serialize_cypher(v)))
226 .collect::<Vec<std::string::String>>()
227 .join(", ")
228 )
229 }
230 }
231}
232
/// Conversion of a value into a neo4rs [`BoltType`], e.g. so it can be bound as a
/// query parameter.
pub trait ToBoltType {
    /// Returns the Bolt representation of `self`.
    fn to_bolt_type(&self) -> BoltType;
}
236
237impl<T> ToBoltType for T
238where
239 T: serde::Serialize,
240{
241 fn to_bolt_type(&self) -> BoltType {
242 match serde_json::to_value(self) {
243 Ok(json_value) => match json_value {
244 serde_json::Value::Null => BoltType::Null(BoltNull),
245 serde_json::Value::Bool(b) => BoltType::Boolean(BoltBoolean::new(b)),
246 serde_json::Value::Number(num) => {
247 if let Some(i) = num.as_i64() {
248 BoltType::Integer(BoltInteger::new(i))
249 } else if let Some(f) = num.as_f64() {
250 BoltType::Float(BoltFloat::new(f))
251 } else {
252 println!("Couldn't map to BoltType, will ignore.");
253 BoltType::Null(BoltNull) }
255 }
256 serde_json::Value::String(s) => BoltType::String(BoltString::new(&s)),
257 serde_json::Value::Array(arr) => BoltType::List(
258 arr.iter()
259 .map(|v| v.to_bolt_type())
260 .collect::<Vec<BoltType>>()
261 .into(),
262 ),
263 serde_json::Value::Object(obj) => {
264 let mut bolt_map = BoltMap::new();
265 for (k, v) in obj {
266 bolt_map.put(BoltString::new(&k), v.to_bolt_type());
267 }
268 BoltType::Map(bolt_map)
269 }
270 },
271 Err(_) => {
272 println!("Couldn't serialize to JSON, will ignore.");
273 BoltType::Null(BoltNull) }
275 }
276 }
277}
278
279impl Neo4jClient {
280 const GET_INDEX_QUERY: &'static str = "
281 SHOW VECTOR INDEXES
282 YIELD name, properties, options
283 WHERE name=$index_name
284 RETURN name, properties, options
285 ";
286
287 const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";
288
289 pub fn new(graph: Graph) -> Self {
290 Self { graph }
291 }
292
293 pub async fn connect(uri: &str, user: &str, password: &str) -> Result<Self, VectorStoreError> {
294 tracing::info!("Connecting to Neo4j DB at {} ...", uri);
295 let graph = Graph::new(uri, user, password)
296 .await
297 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
298 tracing::info!("Connected to Neo4j");
299 Ok(Self { graph })
300 }
301
302 pub async fn from_config(config: Config) -> Result<Self, VectorStoreError> {
303 let graph = Graph::connect(config)
304 .await
305 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
306 Ok(Self { graph })
307 }
308
309 pub async fn execute_and_collect<T: for<'a> Deserialize<'a>>(
310 graph: &Graph,
311 query: Query,
312 ) -> Result<Vec<T>, VectorStoreError> {
313 graph
314 .execute(query)
315 .await
316 .map_err(neo4j_to_rig_error)?
317 .into_stream_as::<T>()
318 .try_collect::<Vec<T>>()
319 .await
320 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
321 }
322
323 pub async fn get_index<M: EmbeddingModel>(
330 &self,
331 model: M,
332 index_name: &str,
333 ) -> Result<Neo4jVectorIndex<M>, VectorStoreError> {
334 #[derive(Deserialize)]
335 struct IndexInfo {
336 name: String,
337 properties: Vec<String>,
338 options: IndexOptions,
339 }
340
341 #[derive(Deserialize)]
342 #[serde(rename_all = "camelCase")]
343 struct IndexOptions {
344 #[allow(dead_code)]
345 index_provider: Option<String>,
346 index_config: IndexConfigDetails,
347 }
348
349 #[derive(Deserialize)]
350 struct IndexConfigDetails {
351 #[serde(rename = "vector.dimensions")]
352 vector_dimensions: i64,
353 #[serde(rename = "vector.similarity_function")]
354 vector_similarity_function: String,
355 }
356
357 let index_info = Self::execute_and_collect::<IndexInfo>(
358 &self.graph,
359 neo4rs::query(Self::GET_INDEX_QUERY).param("index_name", index_name),
360 )
361 .await?;
362
363 let index_config = if let Some(index) = index_info.first() {
364 if index.options.index_config.vector_dimensions != model.ndims() as i64 {
365 tracing::warn!(
366 "The embedding vector dimensions of the existing Neo4j DB index ({}) do not match the provided model dimensions ({}). This may affect search performance.",
367 index.options.index_config.vector_dimensions,
368 model.ndims()
369 );
370 }
371 let embedding_property = index.properties.first().ok_or_else(|| {
372 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
373 "Neo4j index is missing an embedding property",
374 )))
375 })?;
376 IndexConfig::new(index.name.clone())
377 .embedding_property(embedding_property)
378 .similarity_function(VectorSimilarityFunction::from_str(
379 &index.options.index_config.vector_similarity_function,
380 )?)
381 } else {
382 let indexes = Self::execute_and_collect::<String>(
383 &self.graph,
384 neo4rs::query(Self::SHOW_INDEXES_QUERY),
385 )
386 .await?;
387 return Err(VectorStoreError::DatastoreError(Box::new(
388 std::io::Error::new(
389 std::io::ErrorKind::NotFound,
390 format!(
391 "Index `{index_name}` not found in database. Available indexes: {indexes:?}"
392 ),
393 ),
394 )));
395 };
396 Ok(Neo4jVectorIndex::new(
397 self.graph.clone(),
398 model,
399 index_config,
400 ))
401 }
402
403 pub async fn create_vector_index(
415 &self,
416 index_config: IndexConfig,
417 node_label: &str,
418 model: &impl EmbeddingModel,
419 ) -> Result<(), VectorStoreError> {
420 tracing::info!("Creating vector index {} ...", index_config.index_name);
422
423 let create_vector_index_query = format!(
424 "
425 CREATE VECTOR INDEX $index_name IF NOT EXISTS
426 FOR (m:{})
427 ON m.{}
428 OPTIONS {{
429 indexConfig: {{
430 `vector.dimensions`: $dimensions,
431 `vector.similarity_function`: $similarity_function
432 }}
433 }}",
434 node_label, index_config.embedding_property
435 );
436
437 self.graph
438 .run(
439 neo4rs::query(&create_vector_index_query)
440 .param("index_name", index_config.index_name.clone())
441 .param(
442 "similarity_function",
443 index_config.similarity_function.clone().to_bolt_type(),
444 )
445 .param("dimensions", model.ndims() as i64),
446 )
447 .await
448 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
449
450 let index_exists = self
452 .graph
453 .run(
454 neo4rs::query("CALL db.awaitIndex($index_name, 10000)")
455 .param("index_name", index_config.index_name.clone()),
456 )
457 .await;
458
459 if index_exists.is_err() {
460 tracing::warn!(
461 "Index with name `{}` is not ready or could not be created.",
462 index_config.index_name.clone()
463 );
464 }
465
466 tracing::info!(
467 "Index created successfully with name: {}",
468 index_config.index_name
469 );
470 Ok(())
471 }
472}