use neo4rs::{Graph, Query};
use rig::{
embeddings::{Embedding, EmbeddingModel},
vector_store::{VectorStoreError, VectorStoreIndex},
};
use serde::{Deserialize, Serialize, de::Error};
use crate::Neo4jClient;
/// Vector similarity index backed by a Neo4j vector index.
///
/// Pairs an open Neo4j connection with an embedding model so that text
/// queries can be embedded and searched against the index named in
/// `index_config`.
pub struct Neo4jVectorIndex<M: EmbeddingModel> {
/// Handle to the Neo4j database connection.
graph: Graph,
/// Model used to embed query text before searching.
embedding_model: M,
/// Optional post-vector-search filtering (spliced in as a `WHERE` clause).
search_params: SearchParams,
/// Which Neo4j index to query and how embeddings are stored on nodes.
index_config: IndexConfig,
}
/// Configuration describing a Neo4j vector index.
#[derive(Serialize, Deserialize, Clone)]
pub struct IndexConfig {
/// Name of the vector index, passed as `$index_name` to
/// `db.index.vector.queryNodes`.
pub index_name: String,
/// Node property that stores the embedding vector.
pub embedding_property: String,
/// Similarity function associated with the index.
pub similarity_function: VectorSimilarityFunction,
}
impl Default for IndexConfig {
fn default() -> Self {
Self {
index_name: "vector_index".to_string(),
embedding_property: "embedding".to_string(),
similarity_function: VectorSimilarityFunction::Cosine,
}
}
}
impl IndexConfig {
pub fn new(index_name: impl Into<String>) -> Self {
Self {
index_name: index_name.into(),
embedding_property: "embedding".to_string(),
similarity_function: VectorSimilarityFunction::Cosine,
}
}
pub fn index_name(mut self, index_name: &str) -> Self {
self.index_name = index_name.to_string();
self
}
pub fn similarity_function(mut self, similarity_function: VectorSimilarityFunction) -> Self {
self.similarity_function = similarity_function;
self
}
pub fn embedding_property(mut self, embedding_property: &str) -> Self {
self.embedding_property = embedding_property.to_string();
self
}
}
/// Similarity function used by a Neo4j vector index.
///
/// Serialized/deserialized in lowercase (`"cosine"` / `"euclidean"`) per the
/// `rename_all` attribute, matching the strings accepted by [`FromStr`].
///
/// `Debug`, `Copy`, `PartialEq`/`Eq` and `Hash` are derived so the type can
/// be logged, asserted on and used as a map key; the enum is fieldless, so
/// all of these are free and backward-compatible.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum VectorSimilarityFunction {
    /// Cosine similarity (the default).
    #[default]
    Cosine,
    /// Euclidean distance.
    Euclidean,
}
use std::str::FromStr;
impl FromStr for VectorSimilarityFunction {
type Err = VectorStoreError;
fn from_str(s: &str) -> Result<Self, VectorStoreError> {
match s.to_lowercase().as_str() {
"cosine" => Ok(VectorSimilarityFunction::Cosine),
"euclidean" => Ok(VectorSimilarityFunction::Euclidean),
_ => Err(VectorStoreError::JsonError(serde_json::Error::custom(
format!("Invalid similarity function: {s}"),
))),
}
}
}
/// Cypher prefix shared by every vector search: calls
/// `db.index.vector.queryNodes` with the index name, candidate count and
/// query vector bound as parameters, yielding each matched `node` and its
/// similarity `score`. The rest of the query (optional WHERE and the RETURN
/// clause) is appended in `build_vector_search_query`.
const BASE_VECTOR_SEARCH_QUERY: &str = "
CALL db.index.vector.queryNodes($index_name, $num_candidates, $queryVector)
YIELD node, score
";
impl<M: EmbeddingModel> Neo4jVectorIndex<M> {
/// Creates a new index wrapper from an open connection, an embedding model,
/// the index configuration and the search parameters.
pub fn new(
graph: Graph,
embedding_model: M,
index_config: IndexConfig,
search_params: SearchParams,
) -> Self {
Self {
graph,
embedding_model,
index_config,
search_params,
}
}
/// Builds the parameterized Cypher query for a vector search.
///
/// * `prompt_embedding` — embedding of the user query; bound as `$queryVector`.
/// * `return_node` — when true, each matched node is also returned, with the
///   embedding property nulled out so the vector itself is not shipped back.
/// * `n` — number of candidates to request; bound as `$num_candidates`.
pub fn build_vector_search_query(
&self,
prompt_embedding: Embedding,
return_node: bool,
n: usize,
) -> Query {
// Optional filter applied AFTER the vector search, as a WHERE clause over
// the yielded `node`/`score`. Empty string when no filter is configured.
let where_clause = match &self.search_params.post_vector_search_filter {
Some(filter) => format!("WHERE {filter}"),
None => "".to_string(),
};
// Assemble: BASE query + optional WHERE + RETURN clause. The trailing `\`
// in the literal splices lines together (it eats the newline and any
// leading whitespace of the following line).
let query = format!(
"\
{}\
\t{}\n\
\tRETURN score, ID(node) as element_id {}
",
BASE_VECTOR_SEARCH_QUERY,
where_clause,
if return_node {
// Return all node properties, but null out the (potentially large)
// embedding property.
format!(
", node {{.*, {}:null }} as node",
self.index_config.embedding_property
)
} else {
"".to_string()
}
);
tracing::debug!("Query before params: {}", query);
// Bind parameters; `n` is cast to i64 to be bound as a Neo4j integer.
// NOTE(review): `ID(node)` returns Neo4j's internal numeric id — confirm
// this is the intended identifier for callers (vs. `elementId`).
Query::new(query)
.param("queryVector", prompt_embedding.vec)
.param("num_candidates", n as i64)
.param("index_name", self.index_config.index_name.clone())
}
}
/// Options controlling a vector search.
pub struct SearchParams {
// Optional Cypher predicate applied after the vector search; rendered as
// `WHERE <filter>` in the generated query.
post_vector_search_filter: Option<String>,
}
impl SearchParams {
    /// Creates search parameters with an optional post-vector-search filter.
    ///
    /// The filter, when present, is spliced into the generated query as the
    /// body of a Cypher `WHERE` clause.
    pub fn new(filter: Option<String>) -> Self {
        Self {
            post_vector_search_filter: filter,
        }
    }

    /// Sets the post-vector-search filter (builder style).
    ///
    /// Generalized from `String` to `impl Into<String>` so callers can pass
    /// `&str` without allocating first; existing `String` callers are
    /// unaffected.
    pub fn filter(mut self, filter: impl Into<String>) -> Self {
        self.post_vector_search_filter = Some(filter.into());
        self
    }
}
impl Default for SearchParams {
fn default() -> Self {
Self::new(None)
}
}
/// One result row from a vector search that also returns the matched node.
#[derive(Debug, Deserialize)]
pub struct RowResultNode<T> {
// Similarity score yielded by the index.
score: f64,
// Internal Neo4j node id (`ID(node)` in the generated query).
element_id: i64,
// The node's properties, deserialized into `T`.
node: T,
}
/// One result row from an ids-only vector search (no node payload).
#[derive(Debug, Deserialize)]
struct RowResult {
// Similarity score yielded by the index.
score: f64,
// Internal Neo4j node id (`ID(node)` in the generated query).
element_id: i64,
}
impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for Neo4jVectorIndex<M> {
    /// Embeds `query`, runs the vector search and returns up to `n`
    /// `(score, id, node)` triples, deserializing each node into `T`.
    async fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>(
        &self,
        query: &str,
        n: usize,
    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
        let embedding = self.embedding_model.embed_text(query).await?;
        let search = self.build_vector_search_query(embedding, true, n);
        let rows =
            Neo4jClient::execute_and_collect::<RowResultNode<T>>(&self.graph, search).await?;
        Ok(rows
            .into_iter()
            .map(|r| (r.score, r.element_id.to_string(), r.node))
            .collect())
    }

    /// Embeds `query`, runs the vector search and returns up to `n`
    /// `(score, id)` pairs without fetching node contents.
    async fn top_n_ids(
        &self,
        query: &str,
        n: usize,
    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
        let embedding = self.embedding_model.embed_text(query).await?;
        let search = self.build_vector_search_query(embedding, false, n);
        let rows = Neo4jClient::execute_and_collect::<RowResult>(&self.graph, search).await?;
        Ok(rows
            .into_iter()
            .map(|r| (r.score, r.element_id.to_string()))
            .collect())
    }
}