1#[cfg(any(feature = "text-index", feature = "hybrid-search"))]
4use grafeo_common::types::NodeId;
5#[cfg(feature = "vector-index")]
6use grafeo_common::types::Value;
7#[cfg(any(feature = "text-index", feature = "hybrid-search"))]
8use grafeo_common::utils::error::Error;
9#[cfg(any(
10 feature = "vector-index",
11 feature = "text-index",
12 feature = "hybrid-search"
13))]
14use grafeo_common::utils::error::Result;
15
16impl super::GrafeoDB {
17 #[cfg(feature = "vector-index")]
25 fn compute_filter_allowlist(
26 &self,
27 label: &str,
28 filters: Option<&std::collections::HashMap<String, Value>>,
29 ) -> Option<std::collections::HashSet<NodeId>> {
30 let filters = filters.filter(|f| !f.is_empty())?;
31
32 let label_nodes: std::collections::HashSet<NodeId> =
34 self.store.nodes_by_label(label).into_iter().collect();
35
36 let mut allowlist = label_nodes;
37
38 for (key, filter_value) in filters {
39 let is_operator_filter = matches!(filter_value, Value::Map(ops) if ops.keys().any(|k| k.as_str().starts_with('$')));
41
42 let matching: std::collections::HashSet<NodeId> = if is_operator_filter {
43 self.store
45 .find_nodes_matching_filter(key, filter_value)
46 .into_iter()
47 .collect()
48 } else {
49 self.store
51 .find_nodes_by_property(key, filter_value)
52 .into_iter()
53 .collect()
54 };
55 allowlist = allowlist.intersection(&matching).copied().collect();
56
57 if allowlist.is_empty() {
59 return Some(allowlist);
60 }
61 }
62
63 Some(allowlist)
64 }
65
66 #[cfg(feature = "vector-index")]
84 pub fn vector_search(
85 &self,
86 label: &str,
87 property: &str,
88 query: &[f32],
89 k: usize,
90 ef: Option<usize>,
91 filters: Option<&std::collections::HashMap<String, Value>>,
92 ) -> Result<Vec<(grafeo_common::types::NodeId, f32)>> {
93 let index = self.store.get_vector_index(label, property).ok_or_else(|| {
94 grafeo_common::utils::error::Error::Internal(format!(
95 "No vector index found for :{label}({property}). Call create_vector_index() first."
96 ))
97 })?;
98
99 let accessor =
100 grafeo_core::index::vector::PropertyVectorAccessor::new(&self.store, property);
101
102 let results = match self.compute_filter_allowlist(label, filters) {
103 Some(allowlist) => match ef {
104 Some(ef_val) => {
105 index.search_with_ef_and_filter(query, k, ef_val, &allowlist, &accessor)
106 }
107 None => index.search_with_filter(query, k, &allowlist, &accessor),
108 },
109 None => match ef {
110 Some(ef_val) => index.search_with_ef(query, k, ef_val, &accessor),
111 None => index.search(query, k, &accessor),
112 },
113 };
114
115 Ok(results)
116 }
117
118 #[cfg(feature = "vector-index")]
131 pub fn batch_vector_search(
132 &self,
133 label: &str,
134 property: &str,
135 queries: &[Vec<f32>],
136 k: usize,
137 ef: Option<usize>,
138 filters: Option<&std::collections::HashMap<String, Value>>,
139 ) -> Result<Vec<Vec<(grafeo_common::types::NodeId, f32)>>> {
140 let index = self.store.get_vector_index(label, property).ok_or_else(|| {
141 grafeo_common::utils::error::Error::Internal(format!(
142 "No vector index found for :{label}({property}). Call create_vector_index() first."
143 ))
144 })?;
145
146 let accessor =
147 grafeo_core::index::vector::PropertyVectorAccessor::new(&self.store, property);
148
149 let results = match self.compute_filter_allowlist(label, filters) {
150 Some(allowlist) => match ef {
151 Some(ef_val) => {
152 index.batch_search_with_ef_and_filter(queries, k, ef_val, &allowlist, &accessor)
153 }
154 None => index.batch_search_with_filter(queries, k, &allowlist, &accessor),
155 },
156 None => match ef {
157 Some(ef_val) => index.batch_search_with_ef(queries, k, ef_val, &accessor),
158 None => index.batch_search(queries, k, &accessor),
159 },
160 };
161
162 Ok(results)
163 }
164
165 #[cfg(feature = "vector-index")]
188 #[allow(clippy::too_many_arguments)]
189 pub fn mmr_search(
190 &self,
191 label: &str,
192 property: &str,
193 query: &[f32],
194 k: usize,
195 fetch_k: Option<usize>,
196 lambda: Option<f32>,
197 ef: Option<usize>,
198 filters: Option<&std::collections::HashMap<String, Value>>,
199 ) -> Result<Vec<(grafeo_common::types::NodeId, f32)>> {
200 use grafeo_core::index::vector::mmr_select;
201
202 let index = self.store.get_vector_index(label, property).ok_or_else(|| {
203 grafeo_common::utils::error::Error::Internal(format!(
204 "No vector index found for :{label}({property}). Call create_vector_index() first."
205 ))
206 })?;
207
208 let accessor =
209 grafeo_core::index::vector::PropertyVectorAccessor::new(&self.store, property);
210
211 let fetch_k = fetch_k.unwrap_or(k.saturating_mul(4).max(k));
212 let lambda = lambda.unwrap_or(0.5);
213
214 let initial_results = match self.compute_filter_allowlist(label, filters) {
216 Some(allowlist) => match ef {
217 Some(ef_val) => {
218 index.search_with_ef_and_filter(query, fetch_k, ef_val, &allowlist, &accessor)
219 }
220 None => index.search_with_filter(query, fetch_k, &allowlist, &accessor),
221 },
222 None => match ef {
223 Some(ef_val) => index.search_with_ef(query, fetch_k, ef_val, &accessor),
224 None => index.search(query, fetch_k, &accessor),
225 },
226 };
227
228 if initial_results.is_empty() {
229 return Ok(Vec::new());
230 }
231
232 use grafeo_core::index::vector::VectorAccessor;
234 let candidates: Vec<(grafeo_common::types::NodeId, f32, std::sync::Arc<[f32]>)> =
235 initial_results
236 .into_iter()
237 .filter_map(|(id, dist)| accessor.get_vector(id).map(|vec| (id, dist, vec)))
238 .collect();
239
240 let candidate_refs: Vec<(grafeo_common::types::NodeId, f32, &[f32])> = candidates
242 .iter()
243 .map(|(id, dist, vec)| (*id, *dist, vec.as_ref()))
244 .collect();
245
246 let metric = index.config().metric;
248 Ok(mmr_select(query, &candidate_refs, k, lambda, metric))
249 }
250
251 #[cfg(feature = "text-index")]
260 pub fn text_search(
261 &self,
262 label: &str,
263 property: &str,
264 query: &str,
265 k: usize,
266 ) -> Result<Vec<(NodeId, f64)>> {
267 let index = self.store.get_text_index(label, property).ok_or_else(|| {
268 Error::Internal(format!(
269 "No text index found for :{label}({property}). Call create_text_index() first."
270 ))
271 })?;
272
273 Ok(index.read().search(query, k))
274 }
275
276 #[cfg(feature = "hybrid-search")]
295 #[allow(clippy::too_many_arguments)]
296 pub fn hybrid_search(
297 &self,
298 label: &str,
299 text_property: &str,
300 vector_property: &str,
301 query_text: &str,
302 query_vector: Option<&[f32]>,
303 k: usize,
304 fusion: Option<grafeo_core::index::text::FusionMethod>,
305 ) -> Result<Vec<(NodeId, f64)>> {
306 use grafeo_core::index::text::fuse_results;
307
308 let fusion_method = fusion.unwrap_or_default();
309 let mut sources: Vec<Vec<(NodeId, f64)>> = Vec::new();
310
311 if let Some(text_index) = self.store.get_text_index(label, text_property) {
313 let text_results = text_index.read().search(query_text, k * 2);
314 if !text_results.is_empty() {
315 sources.push(text_results);
316 }
317 }
318
319 if let Some(query_vec) = query_vector
321 && let Some(vector_index) = self.store.get_vector_index(label, vector_property)
322 {
323 let accessor = grafeo_core::index::vector::PropertyVectorAccessor::new(
324 &self.store,
325 vector_property,
326 );
327 let vector_results = vector_index.search(query_vec, k * 2, &accessor);
328 if !vector_results.is_empty() {
329 sources.push(
330 vector_results
331 .into_iter()
332 .map(|(id, dist)| (id, f64::from(dist)))
333 .collect(),
334 );
335 }
336 }
337
338 if sources.is_empty() {
339 return Ok(Vec::new());
340 }
341
342 Ok(fuse_results(&sources, &fusion_method, k))
343 }
344}