1pub mod vector_index;
88use std::str::FromStr;
89
90use futures::TryStreamExt;
91use neo4rs::*;
92use rig::{
93 embeddings::EmbeddingModel,
94 vector_store::{VectorStoreError, request::SearchFilter},
95};
96use serde::{Deserialize, Serialize};
97use vector_index::{IndexConfig, Neo4jVectorIndex, VectorSimilarityFunction};
98
99pub struct Neo4jClient {
100 pub graph: Graph,
101}
102
103fn neo4j_to_rig_error(e: neo4rs::Error) -> VectorStoreError {
104 VectorStoreError::DatastoreError(Box::new(e))
105}
106
107#[derive(Clone, Debug, Serialize, Deserialize)]
108pub struct Neo4jSearchFilter(String);
109
110impl SearchFilter for Neo4jSearchFilter {
111 type Value = serde_json::Value;
112
113 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
114 Self(format!("n.{} = {}", key.as_ref(), serialize_cypher(value)))
115 }
116
117 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
118 Self(format!("n.{} > {}", key.as_ref(), serialize_cypher(value)))
119 }
120
121 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
122 Self(format!("n.{} < {}", key.as_ref(), serialize_cypher(value)))
123 }
124
125 fn and(self, rhs: Self) -> Self {
126 Self(format!("({}) AND ({})", self.0, rhs.0))
127 }
128
129 fn or(self, rhs: Self) -> Self {
130 Self(format!("({}) OR ({})", self.0, rhs.0))
131 }
132}
133
134impl Neo4jSearchFilter {
135 pub fn render(self) -> String {
136 format!("WHERE {}", self.0)
137 }
138
139 #[allow(clippy::should_implement_trait)]
140 pub fn not(self) -> Self {
141 Self(format!("NOT ({})", self.0))
142 }
143
144 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
145 Self(format!("n.{key} >= {}", serialize_cypher(value)))
146 }
147
148 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
149 Self(format!("n.{key} <= {}", serialize_cypher(value)))
150 }
151
152 pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
153 Self(format!(
154 "n.{key} IN {}",
155 serialize_cypher(serde_json::Value::Array(values))
156 ))
157 }
158
159 pub fn contains<S>(key: String, pattern: S) -> Self
163 where
164 S: AsRef<str>,
165 {
166 Self(format!(
167 "n.{key} CONTAINS {}",
168 serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
169 ))
170 }
171
172 pub fn starts_with<S>(key: String, pattern: S) -> Self
174 where
175 S: AsRef<str>,
176 {
177 Self(format!(
178 "n.{key} STARTS WITH {}",
179 serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
180 ))
181 }
182
183 pub fn ends_with<S>(key: String, pattern: S) -> Self
185 where
186 S: AsRef<str>,
187 {
188 Self(format!(
189 "n.{key} ENDS WITH {}",
190 serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
191 ))
192 }
193
194 pub fn matches<S>(key: String, pattern: S) -> Self
195 where
196 S: AsRef<str>,
197 {
198 Self(format!(
199 "n.{key} =~ {}",
200 serialize_cypher(serde_json::Value::String(pattern.as_ref().into()))
201 ))
202 }
203}
204
205fn serialize_cypher(value: serde_json::Value) -> String {
206 use serde_json::Value::*;
207 match value {
208 Null => "null".into(),
209 Bool(b) => b.to_string(),
210 Number(n) => n.to_string(),
211 String(s) => format!("'{}'", s.replace('\'', "\\'")),
212 Array(arr) => {
213 format!(
214 "[{}]",
215 arr.into_iter()
216 .map(serialize_cypher)
217 .collect::<Vec<std::string::String>>()
218 .join(", ")
219 )
220 }
221 Object(obj) => {
222 format!(
223 "{{{}}}",
224 obj.into_iter()
225 .map(|(k, v)| format!("{k}: {}", serialize_cypher(v)))
226 .collect::<Vec<std::string::String>>()
227 .join(", ")
228 )
229 }
230 }
231}
232
233pub trait ToBoltType {
234 fn to_bolt_type(&self) -> BoltType;
235}
236
237impl<T> ToBoltType for T
238where
239 T: serde::Serialize,
240{
241 fn to_bolt_type(&self) -> BoltType {
242 match serde_json::to_value(self) {
243 Ok(json_value) => match json_value {
244 serde_json::Value::Null => BoltType::Null(BoltNull),
245 serde_json::Value::Bool(b) => BoltType::Boolean(BoltBoolean::new(b)),
246 serde_json::Value::Number(num) => {
247 if let Some(i) = num.as_i64() {
248 BoltType::Integer(BoltInteger::new(i))
249 } else if let Some(f) = num.as_f64() {
250 BoltType::Float(BoltFloat::new(f))
251 } else {
252 println!("Couldn't map to BoltType, will ignore.");
253 BoltType::Null(BoltNull) }
255 }
256 serde_json::Value::String(s) => BoltType::String(BoltString::new(&s)),
257 serde_json::Value::Array(arr) => BoltType::List(
258 arr.iter()
259 .map(|v| v.to_bolt_type())
260 .collect::<Vec<BoltType>>()
261 .into(),
262 ),
263 serde_json::Value::Object(obj) => {
264 let mut bolt_map = BoltMap::new();
265 for (k, v) in obj {
266 bolt_map.put(BoltString::new(&k), v.to_bolt_type());
267 }
268 BoltType::Map(bolt_map)
269 }
270 },
271 Err(_) => {
272 println!("Couldn't serialize to JSON, will ignore.");
273 BoltType::Null(BoltNull) }
275 }
276 }
277}
278
279impl Neo4jClient {
280 const GET_INDEX_QUERY: &'static str = "
281 SHOW VECTOR INDEXES
282 YIELD name, properties, options
283 WHERE name=$index_name
284 RETURN name, properties, options
285 ";
286
287 const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";
288
289 pub fn new(graph: Graph) -> Self {
290 Self { graph }
291 }
292
293 pub async fn connect(uri: &str, user: &str, password: &str) -> Result<Self, VectorStoreError> {
294 tracing::info!("Connecting to Neo4j DB at {} ...", uri);
295 let graph = Graph::new(uri, user, password)
296 .await
297 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
298 tracing::info!("Connected to Neo4j");
299 Ok(Self { graph })
300 }
301
302 pub async fn from_config(config: Config) -> Result<Self, VectorStoreError> {
303 let graph = Graph::connect(config)
304 .await
305 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
306 Ok(Self { graph })
307 }
308
309 pub async fn execute_and_collect<T: for<'a> Deserialize<'a>>(
310 graph: &Graph,
311 query: Query,
312 ) -> Result<Vec<T>, VectorStoreError> {
313 graph
314 .execute(query)
315 .await
316 .map_err(neo4j_to_rig_error)?
317 .into_stream_as::<T>()
318 .try_collect::<Vec<T>>()
319 .await
320 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
321 }
322
323 pub async fn get_index<M: EmbeddingModel>(
330 &self,
331 model: M,
332 index_name: &str,
333 ) -> Result<Neo4jVectorIndex<M>, VectorStoreError> {
334 #[derive(Deserialize)]
335 struct IndexInfo {
336 name: String,
337 properties: Vec<String>,
338 options: IndexOptions,
339 }
340
341 #[derive(Deserialize)]
342 #[serde(rename_all = "camelCase")]
343 struct IndexOptions {
344 _index_provider: String,
345 index_config: IndexConfigDetails,
346 }
347
348 #[derive(Deserialize)]
349 struct IndexConfigDetails {
350 #[serde(rename = "vector.dimensions")]
351 vector_dimensions: i64,
352 #[serde(rename = "vector.similarity_function")]
353 vector_similarity_function: String,
354 }
355
356 let index_info = Self::execute_and_collect::<IndexInfo>(
357 &self.graph,
358 neo4rs::query(Self::GET_INDEX_QUERY).param("index_name", index_name),
359 )
360 .await?;
361
362 let index_config = if let Some(index) = index_info.first() {
363 if index.options.index_config.vector_dimensions != model.ndims() as i64 {
364 tracing::warn!(
365 "The embedding vector dimensions of the existing Neo4j DB index ({}) do not match the provided model dimensions ({}). This may affect search performance.",
366 index.options.index_config.vector_dimensions,
367 model.ndims()
368 );
369 }
370 IndexConfig::new(index.name.clone())
371 .embedding_property(index.properties.first().unwrap())
372 .similarity_function(VectorSimilarityFunction::from_str(
373 &index.options.index_config.vector_similarity_function,
374 )?)
375 } else {
376 let indexes = Self::execute_and_collect::<String>(
377 &self.graph,
378 neo4rs::query(Self::SHOW_INDEXES_QUERY),
379 )
380 .await?;
381 return Err(VectorStoreError::DatastoreError(Box::new(
382 std::io::Error::new(
383 std::io::ErrorKind::NotFound,
384 format!(
385 "Index `{index_name}` not found in database. Available indexes: {indexes:?}"
386 ),
387 ),
388 )));
389 };
390 Ok(Neo4jVectorIndex::new(
391 self.graph.clone(),
392 model,
393 index_config,
394 ))
395 }
396
397 pub async fn create_vector_index(
409 &self,
410 index_config: IndexConfig,
411 node_label: &str,
412 model: &impl EmbeddingModel,
413 ) -> Result<(), VectorStoreError> {
414 tracing::info!("Creating vector index {} ...", index_config.index_name);
416
417 let create_vector_index_query = format!(
418 "
419 CREATE VECTOR INDEX $index_name IF NOT EXISTS
420 FOR (m:{})
421 ON m.{}
422 OPTIONS {{
423 indexConfig: {{
424 `vector.dimensions`: $dimensions,
425 `vector.similarity_function`: $similarity_function
426 }}
427 }}",
428 node_label, index_config.embedding_property
429 );
430
431 self.graph
432 .run(
433 neo4rs::query(&create_vector_index_query)
434 .param("index_name", index_config.index_name.clone())
435 .param(
436 "similarity_function",
437 index_config.similarity_function.clone().to_bolt_type(),
438 )
439 .param("dimensions", model.ndims() as i64),
440 )
441 .await
442 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
443
444 let index_exists = self
446 .graph
447 .run(
448 neo4rs::query("CALL db.awaitIndex($index_name, 10000)")
449 .param("index_name", index_config.index_name.clone()),
450 )
451 .await;
452
453 if index_exists.is_err() {
454 tracing::warn!(
455 "Index with name `{}` is not ready or could not be created.",
456 index_config.index_name.clone()
457 );
458 }
459
460 tracing::info!(
461 "Index created successfully with name: {}",
462 index_config.index_name
463 );
464 Ok(())
465 }
466}