Skip to main content

agentic_memory/engine/
text_search.rs

1//! BM25 text search and hybrid search (queries 8-9).
2
3use std::collections::HashMap;
4
5use crate::engine::tokenizer::Tokenizer;
6use crate::graph::MemoryGraph;
7use crate::index::cosine_similarity;
8use crate::index::{DocLengths, TermIndex};
9use crate::types::{AmemResult, EventType};
10
/// BM25 `k1` parameter: scales the term-frequency part of the score formula.
const BM25_K1: f32 = 1.2;
/// BM25 `b` parameter: how strongly the document-length ratio `dl / avgdl`
/// influences the score (0.0 would disable length normalisation).
const BM25_B: f32 = 0.75;
13
14/// Parameters for BM25 text search.
15pub struct TextSearchParams {
16    /// The search query string (will be tokenized).
17    pub query: String,
18    /// Maximum number of results.
19    pub max_results: usize,
20    /// Filter by event type(s). Empty = all types.
21    pub event_types: Vec<EventType>,
22    /// Filter by session ID(s). Empty = all sessions.
23    pub session_ids: Vec<u32>,
24    /// Minimum BM25 score to include (default: 0.0).
25    pub min_score: f32,
26}
27
/// One node returned by a BM25 text search.
#[derive(Debug, Clone, PartialEq)]
pub struct TextMatch {
    /// ID of the matching node.
    pub node_id: u64,
    /// BM25 score summed over the query terms found in the node.
    pub score: f32,
    /// Query terms that occur in this node's content (each listed once).
    pub matched_terms: Vec<String>,
}
35
36/// Parameters for hybrid BM25 + vector search.
37pub struct HybridSearchParams {
38    /// Text query for BM25 component.
39    pub query_text: String,
40    /// Feature vector for similarity component. If None, runs BM25-only.
41    pub query_vec: Option<Vec<f32>>,
42    /// Maximum number of final results.
43    pub max_results: usize,
44    /// Filter by event type(s). Empty = all types.
45    pub event_types: Vec<EventType>,
46    /// Weight for BM25 component (0.0 to 1.0, default: 0.5).
47    pub text_weight: f32,
48    /// Weight for vector component (0.0 to 1.0, default: 0.5).
49    pub vector_weight: f32,
50    /// RRF constant k (default: 60).
51    pub rrf_k: u32,
52}
53
/// One node returned by a hybrid search.
#[derive(Debug, Clone, PartialEq)]
pub struct HybridMatch {
    /// ID of the matching node.
    pub node_id: u64,
    /// Combined RRF score.
    pub combined_score: f32,
    /// BM25 rank (1-based, 0 if not in BM25 results).
    pub text_rank: u32,
    /// Similarity rank (1-based, 0 if not in similarity results).
    pub vector_rank: u32,
    /// Raw BM25 score (0.0 if not in BM25 results).
    pub text_score: f32,
    /// Raw cosine similarity (0.0 if not in similarity results).
    pub vector_similarity: f32,
}
68
69impl super::query::QueryEngine {
70    /// BM25 text search over node contents.
71    /// Uses TermIndex if available, falls back to full scan.
72    pub fn text_search(
73        &self,
74        graph: &MemoryGraph,
75        term_index: Option<&TermIndex>,
76        doc_lengths: Option<&DocLengths>,
77        params: TextSearchParams,
78    ) -> AmemResult<Vec<TextMatch>> {
79        let tokenizer = Tokenizer::new();
80        let query_terms = tokenizer.tokenize(&params.query);
81
82        if query_terms.is_empty() {
83            return Ok(Vec::new());
84        }
85
86        // Build type/session filter sets
87        let type_filter: std::collections::HashSet<EventType> =
88            params.event_types.iter().copied().collect();
89        let session_filter: std::collections::HashSet<u32> =
90            params.session_ids.iter().copied().collect();
91
92        let matches = if let (Some(ti), Some(dl)) = (term_index, doc_lengths) {
93            // Fast path: use pre-built indexes
94            self.bm25_fast_path(graph, ti, dl, &query_terms, &type_filter, &session_filter)
95        } else {
96            // Slow path: full scan
97            self.bm25_slow_path(
98                graph,
99                &tokenizer,
100                &query_terms,
101                &type_filter,
102                &session_filter,
103            )
104        };
105
106        let mut results: Vec<TextMatch> = matches
107            .into_iter()
108            .filter(|m| m.score >= params.min_score)
109            .collect();
110
111        results.sort_by(|a, b| {
112            b.score
113                .partial_cmp(&a.score)
114                .unwrap_or(std::cmp::Ordering::Equal)
115        });
116        results.truncate(params.max_results);
117
118        Ok(results)
119    }
120
121    fn bm25_fast_path(
122        &self,
123        graph: &MemoryGraph,
124        term_index: &TermIndex,
125        doc_lengths: &DocLengths,
126        query_terms: &[String],
127        type_filter: &std::collections::HashSet<EventType>,
128        session_filter: &std::collections::HashSet<u32>,
129    ) -> Vec<TextMatch> {
130        let n = term_index.doc_count() as f32;
131        let avgdl = term_index.avg_doc_length();
132
133        // Collect all candidate node IDs from posting lists
134        let mut scores: HashMap<u64, (f32, Vec<String>)> = HashMap::new();
135
136        for term in query_terms {
137            let postings = term_index.get(term);
138            let df = postings.len() as f32;
139            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
140
141            for &(node_id, tf) in postings {
142                // Apply filters
143                if let Some(node) = graph.get_node(node_id) {
144                    if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
145                        continue;
146                    }
147                    if !session_filter.is_empty() && !session_filter.contains(&node.session_id) {
148                        continue;
149                    }
150                }
151
152                let dl = doc_lengths.get(node_id) as f32;
153                let tf_f = tf as f32;
154                let bm25_term = idf * (tf_f * (BM25_K1 + 1.0))
155                    / (tf_f + BM25_K1 * (1.0 - BM25_B + BM25_B * dl / avgdl.max(1.0)));
156
157                let entry = scores.entry(node_id).or_insert((0.0, Vec::new()));
158                entry.0 += bm25_term;
159                if !entry.1.contains(term) {
160                    entry.1.push(term.clone());
161                }
162            }
163        }
164
165        scores
166            .into_iter()
167            .map(|(node_id, (score, matched_terms))| TextMatch {
168                node_id,
169                score,
170                matched_terms,
171            })
172            .collect()
173    }
174
175    fn bm25_slow_path(
176        &self,
177        graph: &MemoryGraph,
178        tokenizer: &Tokenizer,
179        query_terms: &[String],
180        type_filter: &std::collections::HashSet<EventType>,
181        session_filter: &std::collections::HashSet<u32>,
182    ) -> Vec<TextMatch> {
183        let nodes = graph.nodes();
184        if nodes.is_empty() {
185            return Vec::new();
186        }
187
188        // Build temporary term data
189        let n = nodes.len() as f32;
190        let mut doc_freqs: HashMap<String, usize> = HashMap::new();
191        let mut node_data: Vec<(u64, HashMap<String, u32>, u32)> = Vec::new();
192        let mut total_tokens: u64 = 0;
193
194        for node in nodes {
195            if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
196                continue;
197            }
198            if !session_filter.is_empty() && !session_filter.contains(&node.session_id) {
199                continue;
200            }
201
202            let freqs = tokenizer.term_frequencies(&node.content);
203            let doc_len: u32 = freqs.values().sum();
204            total_tokens += doc_len as u64;
205
206            for term in freqs.keys() {
207                *doc_freqs.entry(term.clone()).or_insert(0) += 1;
208            }
209
210            node_data.push((node.id, freqs, doc_len));
211        }
212
213        let avgdl = if node_data.is_empty() {
214            0.0
215        } else {
216            total_tokens as f32 / node_data.len() as f32
217        };
218
219        let mut results = Vec::new();
220
221        for (node_id, freqs, doc_len) in &node_data {
222            let mut score = 0.0f32;
223            let mut matched = Vec::new();
224
225            for term in query_terms {
226                if let Some(&tf) = freqs.get(term) {
227                    let df = *doc_freqs.get(term).unwrap_or(&0) as f32;
228                    let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
229                    let tf_f = tf as f32;
230                    let dl = *doc_len as f32;
231                    let bm25_term = idf * (tf_f * (BM25_K1 + 1.0))
232                        / (tf_f + BM25_K1 * (1.0 - BM25_B + BM25_B * dl / avgdl.max(1.0)));
233                    score += bm25_term;
234                    if !matched.contains(term) {
235                        matched.push(term.clone());
236                    }
237                }
238            }
239
240            if score > 0.0 {
241                results.push(TextMatch {
242                    node_id: *node_id,
243                    score,
244                    matched_terms: matched,
245                });
246            }
247        }
248
249        results
250    }
251
252    /// Hybrid BM25 + vector search with Reciprocal Rank Fusion.
253    pub fn hybrid_search(
254        &self,
255        graph: &MemoryGraph,
256        term_index: Option<&TermIndex>,
257        doc_lengths: Option<&DocLengths>,
258        params: HybridSearchParams,
259    ) -> AmemResult<Vec<HybridMatch>> {
260        let overfetch = params.max_results * 3;
261
262        // Normalize weights
263        let total_weight = params.text_weight + params.vector_weight;
264        let (tw, vw) = if total_weight > 0.0 {
265            (
266                params.text_weight / total_weight,
267                params.vector_weight / total_weight,
268            )
269        } else {
270            (0.5, 0.5)
271        };
272
273        // Run BM25 search
274        let bm25_results = self.text_search(
275            graph,
276            term_index,
277            doc_lengths,
278            TextSearchParams {
279                query: params.query_text.clone(),
280                max_results: overfetch,
281                event_types: params.event_types.clone(),
282                session_ids: Vec::new(),
283                min_score: 0.0,
284            },
285        )?;
286
287        // Build BM25 rank map
288        let mut bm25_map: HashMap<u64, (u32, f32)> = HashMap::new();
289        for (rank, m) in bm25_results.iter().enumerate() {
290            bm25_map.insert(m.node_id, ((rank + 1) as u32, m.score));
291        }
292
293        // Run vector search if available
294        let mut vec_map: HashMap<u64, (u32, f32)> = HashMap::new();
295        let has_vectors = params.query_vec.is_some()
296            && graph
297                .nodes()
298                .iter()
299                .any(|n| n.feature_vec.iter().any(|&x| x != 0.0));
300
301        if has_vectors {
302            if let Some(ref qvec) = params.query_vec {
303                let type_filter: std::collections::HashSet<EventType> =
304                    params.event_types.iter().copied().collect();
305                let mut sim_results: Vec<(u64, f32)> = Vec::new();
306
307                for node in graph.nodes() {
308                    if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
309                        continue;
310                    }
311                    if node.feature_vec.iter().all(|&x| x == 0.0) {
312                        continue;
313                    }
314                    let sim = cosine_similarity(qvec, &node.feature_vec);
315                    if sim > 0.0 {
316                        sim_results.push((node.id, sim));
317                    }
318                }
319
320                sim_results
321                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
322                sim_results.truncate(overfetch);
323
324                for (rank, (node_id, sim)) in sim_results.iter().enumerate() {
325                    vec_map.insert(*node_id, ((rank + 1) as u32, *sim));
326                }
327            }
328        }
329
330        // Combine all candidate node IDs
331        let mut all_ids: std::collections::HashSet<u64> = std::collections::HashSet::new();
332        all_ids.extend(bm25_map.keys());
333        all_ids.extend(vec_map.keys());
334
335        let max_bm25_rank = (bm25_results.len() + 1) as u32;
336        let max_vec_rank = (vec_map.len() + 1) as u32;
337        let rrf_k = params.rrf_k as f32;
338
339        let mut hybrid_results: Vec<HybridMatch> = all_ids
340            .into_iter()
341            .map(|node_id| {
342                let (text_rank, text_score) = bm25_map
343                    .get(&node_id)
344                    .copied()
345                    .unwrap_or((max_bm25_rank, 0.0));
346                let (vector_rank, vector_similarity) = vec_map
347                    .get(&node_id)
348                    .copied()
349                    .unwrap_or((max_vec_rank, 0.0));
350
351                let rrf_text = tw / (rrf_k + text_rank as f32);
352                let rrf_vec = if has_vectors {
353                    vw / (rrf_k + vector_rank as f32)
354                } else {
355                    0.0
356                };
357                let combined_score = rrf_text + rrf_vec;
358
359                HybridMatch {
360                    node_id,
361                    combined_score,
362                    text_rank,
363                    vector_rank,
364                    text_score,
365                    vector_similarity,
366                }
367            })
368            .collect();
369
370        hybrid_results.sort_by(|a, b| {
371            b.combined_score
372                .partial_cmp(&a.combined_score)
373                .unwrap_or(std::cmp::Ordering::Equal)
374        });
375        hybrid_results.truncate(params.max_results);
376
377        Ok(hybrid_results)
378    }
379}