1use neo4rs::{Graph, Query};
8use rig::{
9 embeddings::{Embedding, EmbeddingModel},
10 vector_store::{
11 VectorStoreError, VectorStoreIndex,
12 request::{SearchFilter, VectorSearchRequest},
13 },
14};
15use serde::{Deserialize, Serialize, de::Error};
16
17use crate::{Neo4jClient, Neo4jSearchFilter};
18
19pub struct Neo4jVectorIndex<M>
20where
21 M: EmbeddingModel,
22{
23 graph: Graph,
24 embedding_model: M,
25 index_config: IndexConfig,
26}
27
28#[derive(Serialize, Deserialize, Clone)]
36pub struct IndexConfig {
37 pub index_name: String,
38 pub embedding_property: String,
39 pub similarity_function: VectorSimilarityFunction,
40}
41
42impl Default for IndexConfig {
43 fn default() -> Self {
44 Self {
45 index_name: "vector_index".to_string(),
46 embedding_property: "embedding".to_string(),
47 similarity_function: VectorSimilarityFunction::Cosine,
48 }
49 }
50}
51
52impl IndexConfig {
53 pub fn new(index_name: impl Into<String>) -> Self {
54 Self {
55 index_name: index_name.into(),
56 embedding_property: "embedding".to_string(),
57 similarity_function: VectorSimilarityFunction::Cosine,
58 }
59 }
60
61 pub fn index_name(mut self, index_name: &str) -> Self {
62 self.index_name = index_name.to_string();
63 self
64 }
65
66 pub fn similarity_function(mut self, similarity_function: VectorSimilarityFunction) -> Self {
67 self.similarity_function = similarity_function;
68 self
69 }
70
71 pub fn embedding_property(mut self, embedding_property: &str) -> Self {
72 self.embedding_property = embedding_property.to_string();
73 self
74 }
75}
76
77#[derive(Default, Serialize, Deserialize, Clone)]
81#[serde(rename_all = "lowercase")]
82pub enum VectorSimilarityFunction {
83 #[default]
84 Cosine,
85 Euclidean,
86}
87
88use std::str::FromStr;
89
90impl FromStr for VectorSimilarityFunction {
91 type Err = VectorStoreError;
92
93 fn from_str(s: &str) -> Result<Self, VectorStoreError> {
94 match s.to_lowercase().as_str() {
95 "cosine" => Ok(VectorSimilarityFunction::Cosine),
96 "euclidean" => Ok(VectorSimilarityFunction::Euclidean),
97 _ => Err(VectorStoreError::JsonError(serde_json::Error::custom(
98 format!("Invalid similarity function: {s}"),
99 ))),
100 }
101 }
102}
103
/// Cypher prefix shared by every vector search built in this file.
///
/// Calls `db.index.vector.queryNodes` with the `$index_name`, `$num_candidates` and
/// `$queryVector` parameters and yields `node` and `score`; callers append an optional
/// filter clause and a `RETURN` clause.
const BASE_VECTOR_SEARCH_QUERY: &str = "
    CALL db.index.vector.queryNodes($index_name, $num_candidates, $queryVector)
    YIELD node, score
";
108
109impl<M> Neo4jVectorIndex<M>
110where
111 M: EmbeddingModel,
112{
113 pub fn new(graph: Graph, embedding_model: M, index_config: IndexConfig) -> Self {
114 Self {
115 graph,
116 embedding_model,
117 index_config,
118 }
119 }
120
121 pub fn build_vector_search_query(
132 &self,
133 prompt_embedding: Embedding,
134 return_node: bool,
135 req: &VectorSearchRequest<Neo4jSearchFilter>,
136 ) -> Query {
137 let where_clause = match (req.threshold(), req.filter()) {
138 (Some(thresh), Some(filt)) => Neo4jSearchFilter::gt("distance", thresh.into())
139 .and(filt.clone())
140 .render(),
141 (Some(thresh), _) => Neo4jSearchFilter::gt("distance", thresh.into()).render(),
142 (_, Some(filt)) => filt.clone().render(),
143 _ => String::new(),
144 };
145
146 let query = format!(
148 "\
149 {}\
150 \t{}\n\
151 \tRETURN score, ID(node) as element_id {}
152 ",
153 BASE_VECTOR_SEARCH_QUERY,
154 where_clause,
155 if return_node {
156 format!(
157 ", node {{.*, {}:null }} as node",
158 self.index_config.embedding_property
159 )
160 } else {
161 "".to_string()
162 }
163 );
164
165 tracing::debug!("Query before params: {}", query);
166
167 Query::new(query)
168 .param("queryVector", prompt_embedding.vec)
169 .param("num_candidates", req.samples() as i64)
170 .param("index_name", self.index_config.index_name.clone())
171 }
172}
173
/// Extra parameters for a vector search.
///
/// `Debug` and `Clone` are derived so the parameters can be logged and reused.
#[derive(Debug, Clone)]
pub struct SearchParams {
    /// Optional filter text, stored verbatim as supplied by the caller.
    // NOTE(review): nothing in this file reads this field — presumably a filter applied
    // after the vector search, as the name suggests; confirm against callers.
    post_vector_search_filter: Option<String>,
}
180
181impl SearchParams {
182 pub fn new(filter: Option<String>) -> Self {
184 Self {
185 post_vector_search_filter: filter,
186 }
187 }
188
189 pub fn filter(mut self, filter: String) -> Self {
190 self.post_vector_search_filter = Some(filter);
191 self
192 }
193}
194
195impl Default for SearchParams {
196 fn default() -> Self {
197 Self::new(None)
198 }
199}
200
201#[derive(Debug, Deserialize)]
202pub struct RowResultNode<T> {
203 score: f64,
204 element_id: i64,
205 node: T,
206}
207
208#[derive(Debug, Deserialize)]
209struct RowResult {
210 score: f64,
211 element_id: i64,
212}
213
214impl<M> VectorStoreIndex for Neo4jVectorIndex<M>
215where
216 M: EmbeddingModel + std::marker::Sync + Send,
217{
218 type Filter = Neo4jSearchFilter;
219
220 async fn top_n<T: for<'a> Deserialize<'a> + std::marker::Send>(
235 &self,
236 req: VectorSearchRequest<Neo4jSearchFilter>,
237 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
238 let prompt_embedding = self.embedding_model.embed_text(req.query()).await?;
239 let query = self.build_vector_search_query(prompt_embedding, true, &req);
240
241 let rows = Neo4jClient::execute_and_collect::<RowResultNode<T>>(&self.graph, query).await?;
242
243 let results = rows
244 .into_iter()
245 .map(|row| (row.score, row.element_id.to_string(), row.node))
246 .collect::<Vec<_>>();
247
248 Ok(results)
249 }
250
251 async fn top_n_ids(
254 &self,
255 req: VectorSearchRequest<Neo4jSearchFilter>,
256 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
257 let prompt_embedding = self.embedding_model.embed_text(req.query()).await?;
258
259 let query = self.build_vector_search_query(prompt_embedding, true, &req);
260
261 let rows = Neo4jClient::execute_and_collect::<RowResult>(&self.graph, query).await?;
262
263 let results = rows
264 .into_iter()
265 .map(|row| (row.score, row.element_id.to_string()))
266 .collect::<Vec<_>>();
267
268 Ok(results)
269 }
270}