Skip to main content

rig_neo4j/
vector_index.rs

1//! A vector index for a Neo4j graph DB.
2//!
3//! This module provides a way to perform vector searches on a Neo4j graph DB.
4//! It uses the [Neo4j vector index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/)
5//! to search for similar nodes based on a query.
6
7use neo4rs::{Graph, Query};
8use rig::{
9    embeddings::{Embedding, EmbeddingModel},
10    vector_store::{
11        VectorStoreError, VectorStoreIndex,
12        request::{SearchFilter, VectorSearchRequest},
13    },
14};
15use serde::{Deserialize, Serialize, de::Error};
16
17use crate::{Neo4jClient, Neo4jSearchFilter};
18
19pub struct Neo4jVectorIndex<M>
20where
21    M: EmbeddingModel,
22{
23    graph: Graph,
24    embedding_model: M,
25    index_config: IndexConfig,
26}
27
28/// The index name must be unique among both indexes and constraints.
29/// A newly created index is not immediately available but is created in the background.
30///
31/// #### Default Values
32/// - `index_name`: "vector_index"
33/// - `embedding_property`: "embedding"
34/// - `similarity_function`: VectorSimilarityFunction::Cosine
35#[derive(Serialize, Deserialize, Clone)]
36pub struct IndexConfig {
37    pub index_name: String,
38    pub embedding_property: String,
39    pub similarity_function: VectorSimilarityFunction,
40}
41
42impl Default for IndexConfig {
43    fn default() -> Self {
44        Self {
45            index_name: "vector_index".to_string(),
46            embedding_property: "embedding".to_string(),
47            similarity_function: VectorSimilarityFunction::Cosine,
48        }
49    }
50}
51
52impl IndexConfig {
53    pub fn new(index_name: impl Into<String>) -> Self {
54        Self {
55            index_name: index_name.into(),
56            embedding_property: "embedding".to_string(),
57            similarity_function: VectorSimilarityFunction::Cosine,
58        }
59    }
60
61    pub fn index_name(mut self, index_name: &str) -> Self {
62        self.index_name = index_name.to_string();
63        self
64    }
65
66    pub fn similarity_function(mut self, similarity_function: VectorSimilarityFunction) -> Self {
67        self.similarity_function = similarity_function;
68        self
69    }
70
71    pub fn embedding_property(mut self, embedding_property: &str) -> Self {
72        self.embedding_property = embedding_property.to_string();
73        self
74    }
75}
76
77/// Cosine is most commonly used, but Euclidean is also supported.
78/// See [Neo4j vector similarity functions](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/#similarity-functions)
79/// for more information.
80#[derive(Default, Serialize, Deserialize, Clone)]
81#[serde(rename_all = "lowercase")]
82pub enum VectorSimilarityFunction {
83    #[default]
84    Cosine,
85    Euclidean,
86}
87
88use std::str::FromStr;
89
90impl FromStr for VectorSimilarityFunction {
91    type Err = VectorStoreError;
92
93    fn from_str(s: &str) -> Result<Self, VectorStoreError> {
94        match s.to_lowercase().as_str() {
95            "cosine" => Ok(VectorSimilarityFunction::Cosine),
96            "euclidean" => Ok(VectorSimilarityFunction::Euclidean),
97            _ => Err(VectorStoreError::JsonError(serde_json::Error::custom(
98                format!("Invalid similarity function: {s}"),
99            ))),
100        }
101    }
102}
103
/// Cypher prefix shared by every vector search: calls the vector index named
/// `$index_name` with `$queryVector`, requesting `$num_candidates` results, and
/// yields each matching `node` together with its `score`.
const BASE_VECTOR_SEARCH_QUERY: &str = concat!(
    "\n",
    "    CALL db.index.vector.queryNodes($index_name, $num_candidates, $queryVector)\n",
    "    YIELD node, score\n"
);
108
109impl<M> Neo4jVectorIndex<M>
110where
111    M: EmbeddingModel,
112{
113    pub fn new(graph: Graph, embedding_model: M, index_config: IndexConfig) -> Self {
114        Self {
115            graph,
116            embedding_model,
117            index_config,
118        }
119    }
120
121    /// Build a Neo4j query that performs a vector search against an index.
122    /// See [Query vector index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/#query-vector-index) for more information.
123    ///
124    /// Query template:
125    /// ```text
126    /// CALL db.index.vector.queryNodes($index_name, $num_candidates, $queryVector)
127    /// YIELD node, score
128    /// WHERE {where_clause}
129    /// RETURN score, ID(node) as element_id, node {.*, embedding:null } as node
130    /// ```
131    pub fn build_vector_search_query(
132        &self,
133        prompt_embedding: Embedding,
134        return_node: bool,
135        req: &VectorSearchRequest<Neo4jSearchFilter>,
136    ) -> Query {
137        let where_clause = match (req.threshold(), req.filter()) {
138            (Some(thresh), Some(filt)) => Neo4jSearchFilter::gt("distance", thresh.into())
139                .and(filt.clone())
140                .render(),
141            (Some(thresh), _) => Neo4jSearchFilter::gt("distance", thresh.into()).render(),
142            (_, Some(filt)) => filt.clone().render(),
143            _ => String::new(),
144        };
145
146        // Propertiy containing the embedding vectors are excluded from the returned node
147        let query = format!(
148            "\
149            {}\
150            \t{}\n\
151            \tRETURN score, ID(node) as element_id {}
152            ",
153            BASE_VECTOR_SEARCH_QUERY,
154            where_clause,
155            if return_node {
156                format!(
157                    ", node {{.*, {}:null }} as node",
158                    self.index_config.embedding_property
159                )
160            } else {
161                "".to_string()
162            }
163        );
164
165        tracing::debug!("Query before params: {}", query);
166
167        Query::new(query)
168            .param("queryVector", prompt_embedding.vec)
169            .param("num_candidates", req.samples() as i64)
170            .param("index_name", self.index_config.index_name.clone())
171    }
172}
173
/// Search parameters for a vector search. Neo4j currently only supports post-vector-search filtering.
#[derive(Debug, Clone)]
pub struct SearchParams {
    /// Optional **post-filter**, applied after the vector search as a WHERE clause.
    /// See [Neo4j WHERE clause](https://neo4j.com/docs/cypher-manual/current/clauses/where/) for more information.
    post_vector_search_filter: Option<String>,
}
180
181impl SearchParams {
182    /// Initializes a new `SearchParams` with default values.
183    pub fn new(filter: Option<String>) -> Self {
184        Self {
185            post_vector_search_filter: filter,
186        }
187    }
188
189    pub fn filter(mut self, filter: String) -> Self {
190        self.post_vector_search_filter = Some(filter);
191        self
192    }
193}
194
195impl Default for SearchParams {
196    fn default() -> Self {
197        Self::new(None)
198    }
199}
200
201#[derive(Debug, Deserialize)]
202pub struct RowResultNode<T> {
203    score: f64,
204    element_id: i64,
205    node: T,
206}
207
208#[derive(Debug, Deserialize)]
209struct RowResult {
210    score: f64,
211    element_id: i64,
212}
213
214impl<M> VectorStoreIndex for Neo4jVectorIndex<M>
215where
216    M: EmbeddingModel + std::marker::Sync + Send,
217{
218    type Filter = Neo4jSearchFilter;
219
220    /// Get the top n nodes and scores matching the query.
221    ///
222    /// #### Generic Type Parameters
223    ///
224    /// - `T`: The type used to deserialize the result from the Neo4j query.
225    ///   It must implement the `serde::Deserialize` trait.
226    ///
227    /// #### Returns
228    ///
229    /// Returns a `Result` containing a vector of tuples. Each tuple contains:
230    /// - A `f64` representing the similarity score
231    /// - A `String` representing the node ID
232    /// - A value of type `T` representing the deserialized node data
233    ///
234    async fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>(
235        &self,
236        req: VectorSearchRequest<Neo4jSearchFilter>,
237    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
238        let prompt_embedding = self.embedding_model.embed_text(req.query()).await?;
239        let query = self.build_vector_search_query(prompt_embedding, true, &req);
240
241        let rows = Neo4jClient::execute_and_collect::<RowResultNode<T>>(&self.graph, query).await?;
242
243        let results = rows
244            .into_iter()
245            .map(|row| (row.score, row.element_id.to_string(), row.node))
246            .collect::<Vec<_>>();
247
248        Ok(results)
249    }
250
251    /// Get the top n ids and scores matching the query. Runs faster than top_n since it doesn't need to transfer and parse
252    /// the full nodes and embeddings to the client.
253    async fn top_n_ids(
254        &self,
255        req: VectorSearchRequest<Neo4jSearchFilter>,
256    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
257        let prompt_embedding = self.embedding_model.embed_text(req.query()).await?;
258
259        let query = self.build_vector_search_query(prompt_embedding, true, &req);
260
261        let rows = Neo4jClient::execute_and_collect::<RowResult>(&self.graph, query).await?;
262
263        let results = rows
264            .into_iter()
265            .map(|row| (row.score, row.element_id.to_string()))
266            .collect::<Vec<_>>();
267
268        Ok(results)
269    }
270}