1pub mod vector_index;
89use std::str::FromStr;
90
91use futures::TryStreamExt;
92use neo4rs::*;
93use bep::{embeddings::EmbeddingModel, vector_store::VectorStoreError};
94use serde::Deserialize;
95use vector_index::{IndexConfig, Neo4jVectorIndex, SearchParams, VectorSimilarityFunction};
96
/// Client wrapper around a `neo4rs` `Graph` connection.
pub struct Neo4jClient {
    /// The underlying graph connection; public so callers can run their own queries on it.
    pub graph: Graph,
}
100
101fn neo4j_to_bep_error(e: neo4rs::Error) -> VectorStoreError {
102 VectorStoreError::DatastoreError(Box::new(e))
103}
104
/// Conversion of a value into a `neo4rs` `BoltType`.
pub trait ToBoltType {
    /// Returns the `BoltType` representation of `self`.
    fn to_bolt_type(&self) -> BoltType;
}
108
109impl<T> ToBoltType for T
110where
111 T: serde::Serialize,
112{
113 fn to_bolt_type(&self) -> BoltType {
114 match serde_json::to_value(self) {
115 Ok(json_value) => match json_value {
116 serde_json::Value::Null => BoltType::Null(BoltNull),
117 serde_json::Value::Bool(b) => BoltType::Boolean(BoltBoolean::new(b)),
118 serde_json::Value::Number(num) => {
119 if let Some(i) = num.as_i64() {
120 BoltType::Integer(BoltInteger::new(i))
121 } else if let Some(f) = num.as_f64() {
122 BoltType::Float(BoltFloat::new(f))
123 } else {
124 println!("Couldn't map to BoltType, will ignore.");
125 BoltType::Null(BoltNull) }
127 }
128 serde_json::Value::String(s) => BoltType::String(BoltString::new(&s)),
129 serde_json::Value::Array(arr) => BoltType::List(
130 arr.iter()
131 .map(|v| v.to_bolt_type())
132 .collect::<Vec<BoltType>>()
133 .into(),
134 ),
135 serde_json::Value::Object(obj) => {
136 let mut bolt_map = BoltMap::new();
137 for (k, v) in obj {
138 bolt_map.put(BoltString::new(&k), v.to_bolt_type());
139 }
140 BoltType::Map(bolt_map)
141 }
142 },
143 Err(_) => {
144 println!("Couldn't serialize to JSON, will ignore.");
145 BoltType::Null(BoltNull) }
147 }
148 }
149}
150
151impl Neo4jClient {
152 const GET_INDEX_QUERY: &'static str = "
153 SHOW VECTOR INDEXES
154 YIELD name, properties, options
155 WHERE name=$index_name
156 RETURN name, properties, options
157 ";
158
159 const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";
160
161 pub fn new(graph: Graph) -> Self {
162 Self { graph }
163 }
164
165 pub async fn connect(uri: &str, user: &str, password: &str) -> Result<Self, VectorStoreError> {
166 tracing::info!("Connecting to Neo4j DB at {} ...", uri);
167 let graph = Graph::new(uri, user, password)
168 .await
169 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
170 tracing::info!("Connected to Neo4j");
171 Ok(Self { graph })
172 }
173
174 pub async fn from_config(config: Config) -> Result<Self, VectorStoreError> {
175 let graph = Graph::connect(config)
176 .await
177 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
178 Ok(Self { graph })
179 }
180
181 pub async fn execute_and_collect<T: for<'a> Deserialize<'a>>(
182 graph: &Graph,
183 query: Query,
184 ) -> Result<Vec<T>, VectorStoreError> {
185 graph
186 .execute(query)
187 .await
188 .map_err(neo4j_to_bep_error)?
189 .into_stream_as::<T>()
190 .try_collect::<Vec<T>>()
191 .await
192 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
193 }
194
195 pub async fn get_index<M: EmbeddingModel>(
202 &self,
203 model: M,
204 index_name: &str,
205 search_params: SearchParams,
206 ) -> Result<Neo4jVectorIndex<M>, VectorStoreError> {
207 #[derive(Deserialize)]
208 struct IndexInfo {
209 name: String,
210 properties: Vec<String>,
211 options: IndexOptions,
212 }
213
214 #[derive(Deserialize)]
215 #[serde(rename_all = "camelCase")]
216 struct IndexOptions {
217 _index_provider: String,
218 index_config: IndexConfigDetails,
219 }
220
221 #[derive(Deserialize)]
222 struct IndexConfigDetails {
223 #[serde(rename = "vector.dimensions")]
224 vector_dimensions: i64,
225 #[serde(rename = "vector.similarity_function")]
226 vector_similarity_function: String,
227 }
228
229 let index_info = Self::execute_and_collect::<IndexInfo>(
230 &self.graph,
231 neo4rs::query(Self::GET_INDEX_QUERY).param("index_name", index_name),
232 )
233 .await?;
234
235 let index_config = if let Some(index) = index_info.first() {
236 if index.options.index_config.vector_dimensions != model.ndims() as i64 {
237 tracing::warn!(
238 "The embedding vector dimensions of the existing Neo4j DB index ({}) do not match the provided model dimensions ({}). This may affect search performance.",
239 index.options.index_config.vector_dimensions,
240 model.ndims()
241 );
242 }
243 IndexConfig::new(index.name.clone())
244 .embedding_property(index.properties.first().unwrap())
245 .similarity_function(VectorSimilarityFunction::from_str(
246 &index.options.index_config.vector_similarity_function,
247 )?)
248 } else {
249 let indexes = Self::execute_and_collect::<String>(
250 &self.graph,
251 neo4rs::query(Self::SHOW_INDEXES_QUERY),
252 )
253 .await?;
254 return Err(VectorStoreError::DatastoreError(Box::new(
255 std::io::Error::new(
256 std::io::ErrorKind::NotFound,
257 format!(
258 "Index `{}` not found in database. Available indexes: {:?}",
259 index_name, indexes
260 ),
261 ),
262 )));
263 };
264 Ok(Neo4jVectorIndex::new(
265 self.graph.clone(),
266 model,
267 index_config,
268 search_params,
269 ))
270 }
271
272 pub async fn create_vector_index(
284 &self,
285 index_config: IndexConfig,
286 node_label: &str,
287 model: &impl EmbeddingModel,
288 ) -> Result<(), VectorStoreError> {
289 tracing::info!("Creating vector index {} ...", index_config.index_name);
291
292 let create_vector_index_query = format!(
293 "
294 CREATE VECTOR INDEX $index_name IF NOT EXISTS
295 FOR (m:{})
296 ON m.{}
297 OPTIONS {{
298 indexConfig: {{
299 `vector.dimensions`: $dimensions,
300 `vector.similarity_function`: $similarity_function
301 }}
302 }}",
303 node_label, index_config.embedding_property
304 );
305
306 self.graph
307 .run(
308 neo4rs::query(&create_vector_index_query)
309 .param("index_name", index_config.index_name.clone())
310 .param(
311 "similarity_function",
312 index_config.similarity_function.clone().to_bolt_type(),
313 )
314 .param("dimensions", model.ndims() as i64),
315 )
316 .await
317 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
318
319 let index_exists = self
321 .graph
322 .run(
323 neo4rs::query("CALL db.awaitIndex($index_name, 10000)")
324 .param("index_name", index_config.index_name.clone()),
325 )
326 .await;
327
328 if index_exists.is_err() {
329 tracing::warn!(
330 "Index with name `{}` is not ready or could not be created.",
331 index_config.index_name.clone()
332 );
333 }
334
335 tracing::info!(
336 "Index created successfully with name: {}",
337 index_config.index_name
338 );
339 Ok(())
340 }
341}
342
// `Movie`'s fields are filled in by deserialization but never read by these tests.
#[allow(dead_code)]
#[cfg(test)]
mod tests {
    use super::*;
    use neo4rs::ConfigBuilder;
    use bep::{
        providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
        vector_store::VectorStoreIndex,
    };
    use serde::Deserialize;
    use std::env;

    const NEO4J_URI: &str = "neo4j+s://demo.neo4jlabs.com:7687";
    const NEO4J_DB: &str = "recommendations";
    const NEO4J_USERNAME: &str = "recommendations";
    const NEO4J_PASSWORD: &str = "recommendations";

    /// Row shape requested from the vector index in the search test.
    #[derive(Debug, Deserialize)]
    struct Movie {
        title: String,
        plot: String,
    }

    /// Connection settings built from the constants above, shared by every test.
    fn demo_config() -> neo4rs::Config {
        ConfigBuilder::default()
            .uri(NEO4J_URI)
            .db(NEO4J_DB)
            .user(NEO4J_USERNAME)
            .password(NEO4J_PASSWORD)
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_connect() {
        let connection = Neo4jClient::from_config(demo_config()).await;
        assert!(connection.is_ok());
    }

    #[tokio::test]
    async fn test_vector_search_no_display() {
        let hits = vector_search().await.unwrap();
        assert!(!hits.is_empty());
    }

    /// Embeds the query "Batman" with the OpenAI model and returns the top three matches
    /// from the `moviePlotsEmbedding` index. Requires `OPENAI_API_KEY` to be set.
    async fn vector_search() -> Result<Vec<(f64, String, Movie)>, VectorStoreError> {
        let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
        let model = Client::new(&api_key).embedding_model(TEXT_EMBEDDING_ADA_002);

        let client = Neo4jClient::from_config(demo_config()).await.unwrap();

        let index = client
            .get_index(model, "moviePlotsEmbedding", SearchParams::default())
            .await?;
        let hits = index.top_n::<Movie>("Batman", 3).await?;
        Ok(hits)
    }
}
409}