pub mod vector_index;
use std::str::FromStr;
use futures::TryStreamExt;
use neo4rs::*;
use rig::{embeddings::EmbeddingModel, vector_store::VectorStoreError};
use serde::Deserialize;
use vector_index::{IndexConfig, Neo4jVectorIndex, SearchParams, VectorSimilarityFunction};
/// Thin wrapper around a [`neo4rs::Graph`] connection, providing helpers for
/// creating and looking up Neo4j vector indexes for the rig vector store.
pub struct Neo4jClient {
    // Underlying neo4rs connection pool; public so callers can run their own
    // queries against the same database.
    pub graph: Graph,
}
fn neo4j_to_rig_error(e: neo4rs::Error) -> VectorStoreError {
VectorStoreError::DatastoreError(Box::new(e))
}
/// Conversion into the Bolt protocol value representation used by `neo4rs`.
pub trait ToBoltType {
    /// Converts `self` into a [`BoltType`] suitable for use as a query parameter.
    fn to_bolt_type(&self) -> BoltType;
}
impl<T> ToBoltType for T
where
T: serde::Serialize,
{
fn to_bolt_type(&self) -> BoltType {
match serde_json::to_value(self) {
Ok(json_value) => match json_value {
serde_json::Value::Null => BoltType::Null(BoltNull),
serde_json::Value::Bool(b) => BoltType::Boolean(BoltBoolean::new(b)),
serde_json::Value::Number(num) => {
if let Some(i) = num.as_i64() {
BoltType::Integer(BoltInteger::new(i))
} else if let Some(f) = num.as_f64() {
BoltType::Float(BoltFloat::new(f))
} else {
println!("Couldn't map to BoltType, will ignore.");
BoltType::Null(BoltNull) }
}
serde_json::Value::String(s) => BoltType::String(BoltString::new(&s)),
serde_json::Value::Array(arr) => BoltType::List(
arr.iter()
.map(|v| v.to_bolt_type())
.collect::<Vec<BoltType>>()
.into(),
),
serde_json::Value::Object(obj) => {
let mut bolt_map = BoltMap::new();
for (k, v) in obj {
bolt_map.put(BoltString::new(&k), v.to_bolt_type());
}
BoltType::Map(bolt_map)
}
},
Err(_) => {
println!("Couldn't serialize to JSON, will ignore.");
BoltType::Null(BoltNull) }
}
}
}
impl Neo4jClient {
const GET_INDEX_QUERY: &'static str = "
SHOW VECTOR INDEXES
YIELD name, properties, options
WHERE name=$index_name
RETURN name, properties, options
";
const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";
pub fn new(graph: Graph) -> Self {
Self { graph }
}
pub async fn connect(uri: &str, user: &str, password: &str) -> Result<Self, VectorStoreError> {
tracing::info!("Connecting to Neo4j DB at {} ...", uri);
let graph = Graph::new(uri, user, password)
.await
.map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
tracing::info!("Connected to Neo4j");
Ok(Self { graph })
}
pub async fn from_config(config: Config) -> Result<Self, VectorStoreError> {
let graph = Graph::connect(config)
.await
.map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
Ok(Self { graph })
}
pub async fn execute_and_collect<T: for<'a> Deserialize<'a>>(
graph: &Graph,
query: Query,
) -> Result<Vec<T>, VectorStoreError> {
graph
.execute(query)
.await
.map_err(neo4j_to_rig_error)?
.into_stream_as::<T>()
.try_collect::<Vec<T>>()
.await
.map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
}
pub async fn get_index<M: EmbeddingModel>(
&self,
model: M,
index_name: &str,
search_params: SearchParams,
) -> Result<Neo4jVectorIndex<M>, VectorStoreError> {
#[derive(Deserialize)]
struct IndexInfo {
name: String,
properties: Vec<String>,
options: IndexOptions,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct IndexOptions {
_index_provider: String,
index_config: IndexConfigDetails,
}
#[derive(Deserialize)]
struct IndexConfigDetails {
#[serde(rename = "vector.dimensions")]
vector_dimensions: i64,
#[serde(rename = "vector.similarity_function")]
vector_similarity_function: String,
}
let index_info = Self::execute_and_collect::<IndexInfo>(
&self.graph,
neo4rs::query(Self::GET_INDEX_QUERY).param("index_name", index_name),
)
.await?;
let index_config = if let Some(index) = index_info.first() {
if index.options.index_config.vector_dimensions != model.ndims() as i64 {
tracing::warn!(
"The embedding vector dimensions of the existing Neo4j DB index ({}) do not match the provided model dimensions ({}). This may affect search performance.",
index.options.index_config.vector_dimensions,
model.ndims()
);
}
IndexConfig::new(index.name.clone())
.embedding_property(index.properties.first().unwrap())
.similarity_function(VectorSimilarityFunction::from_str(
&index.options.index_config.vector_similarity_function,
)?)
} else {
let indexes = Self::execute_and_collect::<String>(
&self.graph,
neo4rs::query(Self::SHOW_INDEXES_QUERY),
)
.await?;
return Err(VectorStoreError::DatastoreError(Box::new(
std::io::Error::new(
std::io::ErrorKind::NotFound,
format!(
"Index `{index_name}` not found in database. Available indexes: {indexes:?}"
),
),
)));
};
Ok(Neo4jVectorIndex::new(
self.graph.clone(),
model,
index_config,
search_params,
))
}
pub async fn create_vector_index(
&self,
index_config: IndexConfig,
node_label: &str,
model: &impl EmbeddingModel,
) -> Result<(), VectorStoreError> {
tracing::info!("Creating vector index {} ...", index_config.index_name);
let create_vector_index_query = format!(
"
CREATE VECTOR INDEX $index_name IF NOT EXISTS
FOR (m:{})
ON m.{}
OPTIONS {{
indexConfig: {{
`vector.dimensions`: $dimensions,
`vector.similarity_function`: $similarity_function
}}
}}",
node_label, index_config.embedding_property
);
self.graph
.run(
neo4rs::query(&create_vector_index_query)
.param("index_name", index_config.index_name.clone())
.param(
"similarity_function",
index_config.similarity_function.clone().to_bolt_type(),
)
.param("dimensions", model.ndims() as i64),
)
.await
.map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
let index_exists = self
.graph
.run(
neo4rs::query("CALL db.awaitIndex($index_name, 10000)")
.param("index_name", index_config.index_name.clone()),
)
.await;
if index_exists.is_err() {
tracing::warn!(
"Index with name `{}` is not ready or could not be created.",
index_config.index_name.clone()
);
}
tracing::info!(
"Index created successfully with name: {}",
index_config.index_name
);
Ok(())
}
}