1use std::collections::HashMap;
4
5use crate::engine::tokenizer::Tokenizer;
6use crate::graph::MemoryGraph;
7use crate::index::cosine_similarity;
8use crate::index::{DocLengths, TermIndex};
9use crate::types::{AmemResult, EventType};
10
/// BM25 `k1` parameter: scales the term-frequency part of the per-term score
/// computed in the BM25 scoring paths of this file.
const BM25_K1: f32 = 1.2;
/// BM25 `b` parameter: weight given to the document-length ratio
/// (`dl / avgdl`) in the per-term score denominator.
const BM25_B: f32 = 0.75;
13
/// Input parameters for `QueryEngine::text_search`.
pub struct TextSearchParams {
    /// Raw query text; it is tokenized before matching, and a query that
    /// produces no tokens yields no results.
    pub query: String,
    /// Upper bound on the number of matches returned (results are truncated
    /// to this length after sorting).
    pub max_results: usize,
    /// Only nodes whose event type is in this list are considered.
    /// An empty list disables the filter.
    pub event_types: Vec<EventType>,
    /// Only nodes whose session id is in this list are considered.
    /// An empty list disables the filter.
    pub session_ids: Vec<u32>,
    /// Matches whose score is below this value are dropped.
    pub min_score: f32,
}
27
/// A single result of a BM25 text search.
pub struct TextMatch {
    /// Id of the matching node.
    pub node_id: u64,
    /// Sum of the per-term BM25 contributions of the query terms for this node.
    pub score: f32,
    /// Query terms that were found in the node, each listed once.
    pub matched_terms: Vec<String>,
}
35
/// Input parameters for `QueryEngine::hybrid_search`.
pub struct HybridSearchParams {
    /// Query text handed to the BM25 text search.
    pub query_text: String,
    /// Optional query vector; when `None` the vector-similarity side of the
    /// search is skipped entirely.
    pub query_vec: Option<Vec<f32>>,
    /// Upper bound on the number of fused matches returned.
    pub max_results: usize,
    /// Only nodes whose event type is in this list are considered.
    /// An empty list disables the filter.
    pub event_types: Vec<EventType>,
    /// Relative weight of the text ranking. Both weights are divided by their
    /// sum before use; if the sum is not positive, 0.5 / 0.5 is used instead.
    pub text_weight: f32,
    /// Relative weight of the vector ranking (normalized together with
    /// `text_weight`).
    pub vector_weight: f32,
    /// Constant added to each rank in the denominator of the rank-fusion
    /// score (`weight / (rrf_k + rank)`).
    pub rrf_k: u32,
}
53
/// A single result of a hybrid (text + vector) search.
pub struct HybridMatch {
    /// Id of the matching node.
    pub node_id: u64,
    /// Fused score: the weighted reciprocal-rank contribution of the text
    /// ranking plus that of the vector ranking (the latter is 0 when no
    /// vector search was performed).
    pub combined_score: f32,
    /// 1-based rank in the BM25 result list, or `list length + 1` when the
    /// node did not appear in it.
    pub text_rank: u32,
    /// 1-based rank in the vector-similarity list, or `list length + 1` when
    /// the node did not appear in it.
    pub vector_rank: u32,
    /// BM25 score of the node, or 0.0 when it had no text match.
    pub text_score: f32,
    /// Cosine similarity to the query vector, or 0.0 when the node was not in
    /// the vector result list.
    pub vector_similarity: f32,
}
68
69impl super::query::QueryEngine {
70 pub fn text_search(
73 &self,
74 graph: &MemoryGraph,
75 term_index: Option<&TermIndex>,
76 doc_lengths: Option<&DocLengths>,
77 params: TextSearchParams,
78 ) -> AmemResult<Vec<TextMatch>> {
79 let tokenizer = Tokenizer::new();
80 let query_terms = tokenizer.tokenize(¶ms.query);
81
82 if query_terms.is_empty() {
83 return Ok(Vec::new());
84 }
85
86 let type_filter: std::collections::HashSet<EventType> =
88 params.event_types.iter().copied().collect();
89 let session_filter: std::collections::HashSet<u32> =
90 params.session_ids.iter().copied().collect();
91
92 let matches = if let (Some(ti), Some(dl)) = (term_index, doc_lengths) {
93 self.bm25_fast_path(graph, ti, dl, &query_terms, &type_filter, &session_filter)
95 } else {
96 self.bm25_slow_path(
98 graph,
99 &tokenizer,
100 &query_terms,
101 &type_filter,
102 &session_filter,
103 )
104 };
105
106 let mut results: Vec<TextMatch> = matches
107 .into_iter()
108 .filter(|m| m.score >= params.min_score)
109 .collect();
110
111 results.sort_by(|a, b| {
112 b.score
113 .partial_cmp(&a.score)
114 .unwrap_or(std::cmp::Ordering::Equal)
115 });
116 results.truncate(params.max_results);
117
118 Ok(results)
119 }
120
121 fn bm25_fast_path(
122 &self,
123 graph: &MemoryGraph,
124 term_index: &TermIndex,
125 doc_lengths: &DocLengths,
126 query_terms: &[String],
127 type_filter: &std::collections::HashSet<EventType>,
128 session_filter: &std::collections::HashSet<u32>,
129 ) -> Vec<TextMatch> {
130 let n = term_index.doc_count() as f32;
131 let avgdl = term_index.avg_doc_length();
132
133 let mut scores: HashMap<u64, (f32, Vec<String>)> = HashMap::new();
135
136 for term in query_terms {
137 let postings = term_index.get(term);
138 let df = postings.len() as f32;
139 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
140
141 for &(node_id, tf) in postings {
142 if let Some(node) = graph.get_node(node_id) {
144 if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
145 continue;
146 }
147 if !session_filter.is_empty() && !session_filter.contains(&node.session_id) {
148 continue;
149 }
150 }
151
152 let dl = doc_lengths.get(node_id) as f32;
153 let tf_f = tf as f32;
154 let bm25_term = idf * (tf_f * (BM25_K1 + 1.0))
155 / (tf_f + BM25_K1 * (1.0 - BM25_B + BM25_B * dl / avgdl.max(1.0)));
156
157 let entry = scores.entry(node_id).or_insert((0.0, Vec::new()));
158 entry.0 += bm25_term;
159 if !entry.1.contains(term) {
160 entry.1.push(term.clone());
161 }
162 }
163 }
164
165 scores
166 .into_iter()
167 .map(|(node_id, (score, matched_terms))| TextMatch {
168 node_id,
169 score,
170 matched_terms,
171 })
172 .collect()
173 }
174
175 fn bm25_slow_path(
176 &self,
177 graph: &MemoryGraph,
178 tokenizer: &Tokenizer,
179 query_terms: &[String],
180 type_filter: &std::collections::HashSet<EventType>,
181 session_filter: &std::collections::HashSet<u32>,
182 ) -> Vec<TextMatch> {
183 let nodes = graph.nodes();
184 if nodes.is_empty() {
185 return Vec::new();
186 }
187
188 let n = nodes.len() as f32;
190 let mut doc_freqs: HashMap<String, usize> = HashMap::new();
191 let mut node_data: Vec<(u64, HashMap<String, u32>, u32)> = Vec::new();
192 let mut total_tokens: u64 = 0;
193
194 for node in nodes {
195 if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
196 continue;
197 }
198 if !session_filter.is_empty() && !session_filter.contains(&node.session_id) {
199 continue;
200 }
201
202 let freqs = tokenizer.term_frequencies(&node.content);
203 let doc_len: u32 = freqs.values().sum();
204 total_tokens += doc_len as u64;
205
206 for term in freqs.keys() {
207 *doc_freqs.entry(term.clone()).or_insert(0) += 1;
208 }
209
210 node_data.push((node.id, freqs, doc_len));
211 }
212
213 let avgdl = if node_data.is_empty() {
214 0.0
215 } else {
216 total_tokens as f32 / node_data.len() as f32
217 };
218
219 let mut results = Vec::new();
220
221 for (node_id, freqs, doc_len) in &node_data {
222 let mut score = 0.0f32;
223 let mut matched = Vec::new();
224
225 for term in query_terms {
226 if let Some(&tf) = freqs.get(term) {
227 let df = *doc_freqs.get(term).unwrap_or(&0) as f32;
228 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
229 let tf_f = tf as f32;
230 let dl = *doc_len as f32;
231 let bm25_term = idf * (tf_f * (BM25_K1 + 1.0))
232 / (tf_f + BM25_K1 * (1.0 - BM25_B + BM25_B * dl / avgdl.max(1.0)));
233 score += bm25_term;
234 if !matched.contains(term) {
235 matched.push(term.clone());
236 }
237 }
238 }
239
240 if score > 0.0 {
241 results.push(TextMatch {
242 node_id: *node_id,
243 score,
244 matched_terms: matched,
245 });
246 }
247 }
248
249 results
250 }
251
252 pub fn hybrid_search(
254 &self,
255 graph: &MemoryGraph,
256 term_index: Option<&TermIndex>,
257 doc_lengths: Option<&DocLengths>,
258 params: HybridSearchParams,
259 ) -> AmemResult<Vec<HybridMatch>> {
260 let overfetch = params.max_results * 3;
261
262 let total_weight = params.text_weight + params.vector_weight;
264 let (tw, vw) = if total_weight > 0.0 {
265 (
266 params.text_weight / total_weight,
267 params.vector_weight / total_weight,
268 )
269 } else {
270 (0.5, 0.5)
271 };
272
273 let bm25_results = self.text_search(
275 graph,
276 term_index,
277 doc_lengths,
278 TextSearchParams {
279 query: params.query_text.clone(),
280 max_results: overfetch,
281 event_types: params.event_types.clone(),
282 session_ids: Vec::new(),
283 min_score: 0.0,
284 },
285 )?;
286
287 let mut bm25_map: HashMap<u64, (u32, f32)> = HashMap::new();
289 for (rank, m) in bm25_results.iter().enumerate() {
290 bm25_map.insert(m.node_id, ((rank + 1) as u32, m.score));
291 }
292
293 let mut vec_map: HashMap<u64, (u32, f32)> = HashMap::new();
295 let has_vectors = params.query_vec.is_some()
296 && graph
297 .nodes()
298 .iter()
299 .any(|n| n.feature_vec.iter().any(|&x| x != 0.0));
300
301 if has_vectors {
302 if let Some(ref qvec) = params.query_vec {
303 let type_filter: std::collections::HashSet<EventType> =
304 params.event_types.iter().copied().collect();
305 let mut sim_results: Vec<(u64, f32)> = Vec::new();
306
307 for node in graph.nodes() {
308 if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
309 continue;
310 }
311 if node.feature_vec.iter().all(|&x| x == 0.0) {
312 continue;
313 }
314 let sim = cosine_similarity(qvec, &node.feature_vec);
315 if sim > 0.0 {
316 sim_results.push((node.id, sim));
317 }
318 }
319
320 sim_results
321 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
322 sim_results.truncate(overfetch);
323
324 for (rank, (node_id, sim)) in sim_results.iter().enumerate() {
325 vec_map.insert(*node_id, ((rank + 1) as u32, *sim));
326 }
327 }
328 }
329
330 let mut all_ids: std::collections::HashSet<u64> = std::collections::HashSet::new();
332 all_ids.extend(bm25_map.keys());
333 all_ids.extend(vec_map.keys());
334
335 let max_bm25_rank = (bm25_results.len() + 1) as u32;
336 let max_vec_rank = (vec_map.len() + 1) as u32;
337 let rrf_k = params.rrf_k as f32;
338
339 let mut hybrid_results: Vec<HybridMatch> = all_ids
340 .into_iter()
341 .map(|node_id| {
342 let (text_rank, text_score) = bm25_map
343 .get(&node_id)
344 .copied()
345 .unwrap_or((max_bm25_rank, 0.0));
346 let (vector_rank, vector_similarity) = vec_map
347 .get(&node_id)
348 .copied()
349 .unwrap_or((max_vec_rank, 0.0));
350
351 let rrf_text = tw / (rrf_k + text_rank as f32);
352 let rrf_vec = if has_vectors {
353 vw / (rrf_k + vector_rank as f32)
354 } else {
355 0.0
356 };
357 let combined_score = rrf_text + rrf_vec;
358
359 HybridMatch {
360 node_id,
361 combined_score,
362 text_rank,
363 vector_rank,
364 text_score,
365 vector_similarity,
366 }
367 })
368 .collect();
369
370 hybrid_results.sort_by(|a, b| {
371 b.combined_score
372 .partial_cmp(&a.combined_score)
373 .unwrap_or(std::cmp::Ordering::Equal)
374 });
375 hybrid_results.truncate(params.max_results);
376
377 Ok(hybrid_results)
378 }
379}