Skip to main content

grafeo_engine/database/
search.rs

1//! Vector, text, and hybrid search operations for GrafeoDB.
2
3#[cfg(any(feature = "text-index", feature = "hybrid-search"))]
4use grafeo_common::types::NodeId;
5#[cfg(feature = "vector-index")]
6use grafeo_common::types::Value;
7#[cfg(any(feature = "text-index", feature = "hybrid-search"))]
8use grafeo_common::utils::error::Error;
9#[cfg(any(
10    feature = "vector-index",
11    feature = "text-index",
12    feature = "hybrid-search"
13))]
14use grafeo_common::utils::error::Result;
15
16impl super::GrafeoDB {
17    /// Computes a node allowlist from property filters.
18    ///
19    /// Supports equality filters (scalar values) and operator filters (Map values
20    /// with `$`-prefixed keys like `$gt`, `$lt`, `$in`, `$contains`).
21    ///
22    /// Returns `None` if filters is `None` or empty (meaning no filtering),
23    /// or `Some(set)` with the intersection (possibly empty).
24    #[cfg(feature = "vector-index")]
25    fn compute_filter_allowlist(
26        &self,
27        label: &str,
28        filters: Option<&std::collections::HashMap<String, Value>>,
29    ) -> Option<std::collections::HashSet<NodeId>> {
30        let filters = filters.filter(|f| !f.is_empty())?;
31
32        // Start with all nodes for this label
33        let label_nodes: std::collections::HashSet<NodeId> =
34            self.store.nodes_by_label(label).into_iter().collect();
35
36        let mut allowlist = label_nodes;
37
38        for (key, filter_value) in filters {
39            // Check if this is an operator filter (Map with $-prefixed keys)
40            let is_operator_filter = matches!(filter_value, Value::Map(ops) if ops.keys().any(|k| k.as_str().starts_with('$')));
41
42            let matching: std::collections::HashSet<NodeId> = if is_operator_filter {
43                // Operator filter: must scan nodes and check each
44                self.store
45                    .find_nodes_matching_filter(key, filter_value)
46                    .into_iter()
47                    .collect()
48            } else {
49                // Equality filter: use indexed lookup when available
50                self.store
51                    .find_nodes_by_property(key, filter_value)
52                    .into_iter()
53                    .collect()
54            };
55            allowlist = allowlist.intersection(&matching).copied().collect();
56
57            // Short-circuit: empty intersection means no results possible
58            if allowlist.is_empty() {
59                return Some(allowlist);
60            }
61        }
62
63        Some(allowlist)
64    }
65
66    /// Searches for the k nearest neighbors of a query vector.
67    ///
68    /// Uses the HNSW index created by [`create_vector_index`](Self::create_vector_index).
69    ///
70    /// # Arguments
71    ///
72    /// * `label` - Node label that was indexed
73    /// * `property` - Property that was indexed
74    /// * `query` - Query vector (slice of floats)
75    /// * `k` - Number of nearest neighbors to return
76    /// * `ef` - Search beam width (higher = better recall, slower). Uses index default if `None`.
77    /// * `filters` - Optional property equality filters. Only nodes matching all
78    ///   `(key, value)` pairs will appear in results.
79    ///
80    /// # Returns
81    ///
82    /// Vector of `(NodeId, distance)` pairs sorted by distance ascending.
83    #[cfg(feature = "vector-index")]
84    pub fn vector_search(
85        &self,
86        label: &str,
87        property: &str,
88        query: &[f32],
89        k: usize,
90        ef: Option<usize>,
91        filters: Option<&std::collections::HashMap<String, Value>>,
92    ) -> Result<Vec<(grafeo_common::types::NodeId, f32)>> {
93        let index = self.store.get_vector_index(label, property).ok_or_else(|| {
94            grafeo_common::utils::error::Error::Internal(format!(
95                "No vector index found for :{label}({property}). Call create_vector_index() first."
96            ))
97        })?;
98
99        let accessor =
100            grafeo_core::index::vector::PropertyVectorAccessor::new(&self.store, property);
101
102        let results = match self.compute_filter_allowlist(label, filters) {
103            Some(allowlist) => match ef {
104                Some(ef_val) => {
105                    index.search_with_ef_and_filter(query, k, ef_val, &allowlist, &accessor)
106                }
107                None => index.search_with_filter(query, k, &allowlist, &accessor),
108            },
109            None => match ef {
110                Some(ef_val) => index.search_with_ef(query, k, ef_val, &accessor),
111                None => index.search(query, k, &accessor),
112            },
113        };
114
115        Ok(results)
116    }
117
118    /// Searches for nearest neighbors for multiple query vectors in parallel.
119    ///
120    /// Uses rayon parallel iteration under the hood for multi-core throughput.
121    ///
122    /// # Arguments
123    ///
124    /// * `label` - Node label that was indexed
125    /// * `property` - Property that was indexed
126    /// * `queries` - Batch of query vectors
127    /// * `k` - Number of nearest neighbors per query
128    /// * `ef` - Search beam width (uses index default if `None`)
129    /// * `filters` - Optional property equality filters
130    #[cfg(feature = "vector-index")]
131    pub fn batch_vector_search(
132        &self,
133        label: &str,
134        property: &str,
135        queries: &[Vec<f32>],
136        k: usize,
137        ef: Option<usize>,
138        filters: Option<&std::collections::HashMap<String, Value>>,
139    ) -> Result<Vec<Vec<(grafeo_common::types::NodeId, f32)>>> {
140        let index = self.store.get_vector_index(label, property).ok_or_else(|| {
141            grafeo_common::utils::error::Error::Internal(format!(
142                "No vector index found for :{label}({property}). Call create_vector_index() first."
143            ))
144        })?;
145
146        let accessor =
147            grafeo_core::index::vector::PropertyVectorAccessor::new(&self.store, property);
148
149        let results = match self.compute_filter_allowlist(label, filters) {
150            Some(allowlist) => match ef {
151                Some(ef_val) => {
152                    index.batch_search_with_ef_and_filter(queries, k, ef_val, &allowlist, &accessor)
153                }
154                None => index.batch_search_with_filter(queries, k, &allowlist, &accessor),
155            },
156            None => match ef {
157                Some(ef_val) => index.batch_search_with_ef(queries, k, ef_val, &accessor),
158                None => index.batch_search(queries, k, &accessor),
159            },
160        };
161
162        Ok(results)
163    }
164
165    /// Searches for diverse nearest neighbors using Maximal Marginal Relevance (MMR).
166    ///
167    /// MMR balances relevance (similarity to query) with diversity (dissimilarity
168    /// among selected results). This is the algorithm used by LangChain's
169    /// `mmr_traversal_search()` for RAG applications.
170    ///
171    /// # Arguments
172    ///
173    /// * `label` - Node label that was indexed
174    /// * `property` - Property that was indexed
175    /// * `query` - Query vector
176    /// * `k` - Number of diverse results to return
177    /// * `fetch_k` - Number of initial candidates from HNSW (default: `4 * k`)
178    /// * `lambda` - Relevance vs. diversity in \[0, 1\] (default: 0.5).
179    ///   1.0 = pure relevance, 0.0 = pure diversity.
180    /// * `ef` - HNSW search beam width (uses index default if `None`)
181    /// * `filters` - Optional property equality filters
182    ///
183    /// # Returns
184    ///
185    /// `(NodeId, distance)` pairs in MMR selection order. The f32 is the original
186    /// distance from the query, matching [`vector_search`](Self::vector_search).
187    #[cfg(feature = "vector-index")]
188    #[allow(clippy::too_many_arguments)]
189    pub fn mmr_search(
190        &self,
191        label: &str,
192        property: &str,
193        query: &[f32],
194        k: usize,
195        fetch_k: Option<usize>,
196        lambda: Option<f32>,
197        ef: Option<usize>,
198        filters: Option<&std::collections::HashMap<String, Value>>,
199    ) -> Result<Vec<(grafeo_common::types::NodeId, f32)>> {
200        use grafeo_core::index::vector::mmr_select;
201
202        let index = self.store.get_vector_index(label, property).ok_or_else(|| {
203            grafeo_common::utils::error::Error::Internal(format!(
204                "No vector index found for :{label}({property}). Call create_vector_index() first."
205            ))
206        })?;
207
208        let accessor =
209            grafeo_core::index::vector::PropertyVectorAccessor::new(&self.store, property);
210
211        let fetch_k = fetch_k.unwrap_or(k.saturating_mul(4).max(k));
212        let lambda = lambda.unwrap_or(0.5);
213
214        // Step 1: Fetch candidates from HNSW (with optional filter)
215        let initial_results = match self.compute_filter_allowlist(label, filters) {
216            Some(allowlist) => match ef {
217                Some(ef_val) => {
218                    index.search_with_ef_and_filter(query, fetch_k, ef_val, &allowlist, &accessor)
219                }
220                None => index.search_with_filter(query, fetch_k, &allowlist, &accessor),
221            },
222            None => match ef {
223                Some(ef_val) => index.search_with_ef(query, fetch_k, ef_val, &accessor),
224                None => index.search(query, fetch_k, &accessor),
225            },
226        };
227
228        if initial_results.is_empty() {
229            return Ok(Vec::new());
230        }
231
232        // Step 2: Retrieve stored vectors for MMR pairwise comparison
233        use grafeo_core::index::vector::VectorAccessor;
234        let candidates: Vec<(grafeo_common::types::NodeId, f32, std::sync::Arc<[f32]>)> =
235            initial_results
236                .into_iter()
237                .filter_map(|(id, dist)| accessor.get_vector(id).map(|vec| (id, dist, vec)))
238                .collect();
239
240        // Step 3: Build slice-based candidates for mmr_select
241        let candidate_refs: Vec<(grafeo_common::types::NodeId, f32, &[f32])> = candidates
242            .iter()
243            .map(|(id, dist, vec)| (*id, *dist, vec.as_ref()))
244            .collect();
245
246        // Step 4: Run MMR selection
247        let metric = index.config().metric;
248        Ok(mmr_select(query, &candidate_refs, k, lambda, metric))
249    }
250
251    /// Searches a text index using BM25 scoring.
252    ///
253    /// Returns up to `k` results as `(NodeId, score)` pairs sorted by
254    /// descending relevance score.
255    ///
256    /// # Errors
257    ///
258    /// Returns an error if no text index exists for this label+property.
259    #[cfg(feature = "text-index")]
260    pub fn text_search(
261        &self,
262        label: &str,
263        property: &str,
264        query: &str,
265        k: usize,
266    ) -> Result<Vec<(NodeId, f64)>> {
267        let index = self.store.get_text_index(label, property).ok_or_else(|| {
268            Error::Internal(format!(
269                "No text index found for :{label}({property}). Call create_text_index() first."
270            ))
271        })?;
272
273        Ok(index.read().search(query, k))
274    }
275
276    /// Performs hybrid search combining text (BM25) and vector similarity.
277    ///
278    /// Runs both text search and vector search, then fuses results using
279    /// the specified method (default: Reciprocal Rank Fusion).
280    ///
281    /// # Arguments
282    ///
283    /// * `label` - Node label to search within
284    /// * `text_property` - Property indexed for text search
285    /// * `vector_property` - Property indexed for vector search
286    /// * `query_text` - Text query for BM25 search
287    /// * `query_vector` - Vector query for similarity search (optional)
288    /// * `k` - Number of results to return
289    /// * `fusion` - Score fusion method (default: RRF with k=60)
290    ///
291    /// # Errors
292    ///
293    /// Returns an error if the required indexes don't exist.
294    #[cfg(feature = "hybrid-search")]
295    #[allow(clippy::too_many_arguments)]
296    pub fn hybrid_search(
297        &self,
298        label: &str,
299        text_property: &str,
300        vector_property: &str,
301        query_text: &str,
302        query_vector: Option<&[f32]>,
303        k: usize,
304        fusion: Option<grafeo_core::index::text::FusionMethod>,
305    ) -> Result<Vec<(NodeId, f64)>> {
306        use grafeo_core::index::text::fuse_results;
307
308        let fusion_method = fusion.unwrap_or_default();
309        let mut sources: Vec<Vec<(NodeId, f64)>> = Vec::new();
310
311        // Text search
312        if let Some(text_index) = self.store.get_text_index(label, text_property) {
313            let text_results = text_index.read().search(query_text, k * 2);
314            if !text_results.is_empty() {
315                sources.push(text_results);
316            }
317        }
318
319        // Vector search (if query vector provided)
320        if let Some(query_vec) = query_vector
321            && let Some(vector_index) = self.store.get_vector_index(label, vector_property)
322        {
323            let accessor = grafeo_core::index::vector::PropertyVectorAccessor::new(
324                &self.store,
325                vector_property,
326            );
327            let vector_results = vector_index.search(query_vec, k * 2, &accessor);
328            if !vector_results.is_empty() {
329                sources.push(
330                    vector_results
331                        .into_iter()
332                        .map(|(id, dist)| (id, f64::from(dist)))
333                        .collect(),
334                );
335            }
336        }
337
338        if sources.is_empty() {
339            return Ok(Vec::new());
340        }
341
342        Ok(fuse_results(&sources, &fusion_method, k))
343    }
344}