1use std::collections::HashMap;
2use std::path::Path;
3
4use anyhow::Result;
5use rayon::prelude::*;
6use regex::Regex;
7
8use crate::embed::{self, EmbedProvider};
9
/// One ranked hit produced by hybrid search.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Identifier of the matched symbol (the key used by both retrievers).
    pub symbol_id: String,
    /// Symbol name. `combine_scores` leaves this empty; presumably the caller fills it
    /// from symbol metadata (TODO confirm).
    pub name: String,
    /// Symbol kind. Left empty by `combine_scores` (see `name`).
    pub kind: String,
    /// Source file of the symbol. Left empty by `combine_scores` (see `name`).
    pub file: String,
    /// Fused score as set by `combine_scores`: `(1 - alpha) * bm25_score + alpha * vector_score`.
    pub score: f32,
    /// BM25 score divided by the query's top BM25 score (0.0 when BM25 did not return it).
    pub bm25_score: f32,
    /// Vector similarity divided by the query's top similarity (0.0 when vector search
    /// did not return it).
    pub vector_score: f32,
    /// Optional docstring of the symbol. Left `None` by `combine_scores`.
    pub docstring: Option<String>,
}
22
/// BM25 term-frequency saturation parameter: larger values let repeated
/// occurrences of a term keep increasing the score for longer.
const K1: f32 = 1.2;
/// BM25 document-length normalization strength: 0.0 ignores document length,
/// 1.0 normalizes fully by `doc_len / avg_doc_len`.
const B: f32 = 0.75;
26
27pub struct BM25Index {
29 docs: Vec<(String, String)>,
31 inverted: HashMap<String, Vec<(usize, f32)>>,
33 avg_doc_len: f32,
34}
35
36impl BM25Index {
37 pub fn build(docs: Vec<(String, String)>) -> Self {
39 let n = docs.len();
40 let mut inverted: HashMap<String, Vec<(usize, f32)>> = HashMap::new();
41 let mut total_len = 0usize;
42
43 for (i, (_id, text)) in docs.iter().enumerate() {
44 let tokens = tokenize(text);
45 total_len += tokens.len();
46
47 let mut tf_map: HashMap<&str, f32> = HashMap::new();
48 for t in &tokens {
49 *tf_map.entry(t.as_str()).or_default() += 1.0;
50 }
51
52 for (term, tf) in tf_map {
53 inverted.entry(term.to_string()).or_default().push((i, tf));
54 }
55 }
56
57 let avg_doc_len = if n > 0 {
58 total_len as f32 / n as f32
59 } else {
60 1.0
61 };
62
63 Self {
64 docs,
65 inverted,
66 avg_doc_len,
67 }
68 }
69
70 pub fn search(&self, query: &str, limit: usize) -> Vec<(usize, f32)> {
72 let query_tokens = tokenize(query);
73 let n = self.docs.len() as f32;
74 let mut scores = vec![0.0f32; self.docs.len()];
75
76 for token in &query_tokens {
77 if let Some(postings) = self.inverted.get(token.as_str()) {
78 let df = postings.len() as f32;
79 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
80
81 for &(doc_idx, tf) in postings {
82 let doc_len = tokenize(&self.docs[doc_idx].1).len() as f32;
83 let tf_norm =
84 (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * doc_len / self.avg_doc_len));
85 scores[doc_idx] += idf * tf_norm;
86 }
87 }
88 }
89
90 let mut results: Vec<(usize, f32)> = scores
91 .into_iter()
92 .enumerate()
93 .filter(|(_, s)| *s > 0.0)
94 .collect();
95 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
96 results.truncate(limit);
97 results
98 }
99
100 pub fn doc_id(&self, idx: usize) -> &str {
101 &self.docs[idx].0
102 }
103
104 pub fn doc_text(&self, idx: usize) -> &str {
105 &self.docs[idx].1
106 }
107}
108
/// Per-retriever scores for one query, keyed by symbol id, before fusion.
///
/// As produced by `compute_raw_scores`, each map's scores are divided by that
/// retriever's top score, so the best hit is 1.0.
#[derive(Debug, Clone, Default)]
pub struct RawScores {
    /// Symbol id -> normalized BM25 score.
    pub bm25: HashMap<String, f32>,
    /// Symbol id -> normalized cosine similarity to the query embedding.
    pub vector: HashMap<String, f32>,
}
116
117pub fn compute_raw_scores(
124 query: &str,
125 bm25_index: &BM25Index,
126 embedder: &dyn EmbedProvider,
127 symbol_embeddings: &[(String, Vec<f32>)],
128 oversample: usize,
129 hnsw_index_path: Option<&Path>,
130 embeddings_path: Option<&Path>,
131) -> Result<RawScores> {
132 let bm25_results = bm25_index.search(query, oversample);
133 let bm25_max = bm25_results
134 .first()
135 .map(|(_, s)| *s)
136 .unwrap_or(1.0)
137 .max(0.001);
138
139 let mut bm25_map: HashMap<String, f32> = HashMap::new();
140 for (idx, score) in &bm25_results {
141 let id = bm25_index.doc_id(*idx).to_string();
142 bm25_map.insert(id, score / bm25_max);
143 }
144
145 let query_embedding = embedder.embed(query)?;
146
147 const HNSW_THRESHOLD: usize = 200_000;
150 let use_hnsw = symbol_embeddings.len() >= HNSW_THRESHOLD;
151 let vec_scores = if use_hnsw {
152 if let (Some(idx_path), Some(emb_path)) = (hnsw_index_path, embeddings_path) {
153 match embed::search_hnsw(idx_path, emb_path, &query_embedding, oversample) {
154 Ok(Some(candidates)) => {
155 let emb_lookup: HashMap<&str, &[f32]> = symbol_embeddings
156 .iter()
157 .map(|(id, v)| (id.as_str(), v.as_slice()))
158 .collect();
159 let mut reranked: Vec<(String, f32)> = candidates
160 .into_iter()
161 .filter_map(|r| {
162 emb_lookup
163 .get(r.id.as_str())
164 .map(|emb| (r.id, embed::cosine_similarity(&query_embedding, emb)))
165 })
166 .collect();
167 reranked.sort_unstable_by(|a, b| {
168 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
169 });
170 reranked.truncate(oversample);
171 reranked
172 }
173 _ => brute_force_vector_scores(&query_embedding, symbol_embeddings, oversample),
174 }
175 } else {
176 brute_force_vector_scores(&query_embedding, symbol_embeddings, oversample)
177 }
178 } else {
179 brute_force_vector_scores(&query_embedding, symbol_embeddings, oversample)
180 };
181
182 let vec_max = vec_scores
183 .first()
184 .map(|(_, s)| *s)
185 .unwrap_or(1.0)
186 .max(0.001);
187
188 let mut vector_map: HashMap<String, f32> = HashMap::new();
189 for (id, score) in &vec_scores {
190 vector_map.insert(id.clone(), score / vec_max);
191 }
192
193 Ok(RawScores {
194 bm25: bm25_map,
195 vector: vector_map,
196 })
197}
198
199fn brute_force_vector_scores(
200 query_embedding: &[f32],
201 symbol_embeddings: &[(String, Vec<f32>)],
202 oversample: usize,
203) -> Vec<(String, f32)> {
204 let mut vec_scores: Vec<(String, f32)> = symbol_embeddings
205 .par_iter()
206 .map(|(id, emb)| {
207 let sim = embed::cosine_similarity(query_embedding, emb);
208 (id.clone(), sim)
209 })
210 .collect();
211 vec_scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212 vec_scores.truncate(oversample);
213 vec_scores
214}
215
216pub fn combine_scores(raw: &RawScores, alpha: f32, limit: usize) -> Vec<SearchResult> {
218 let all_ids: std::collections::HashSet<&String> =
219 raw.bm25.keys().chain(raw.vector.keys()).collect();
220
221 let mut results: Vec<SearchResult> = all_ids
222 .into_iter()
223 .map(|id| {
224 let bm25 = raw.bm25.get(id).copied().unwrap_or(0.0);
225 let vec = raw.vector.get(id).copied().unwrap_or(0.0);
226 let score = (1.0 - alpha) * bm25 + alpha * vec;
227 SearchResult {
228 symbol_id: id.clone(),
229 name: String::new(),
230 kind: String::new(),
231 file: String::new(),
232 score,
233 bm25_score: bm25,
234 vector_score: vec,
235 docstring: None,
236 }
237 })
238 .collect();
239
240 results.sort_by(|a, b| {
241 b.score
242 .partial_cmp(&a.score)
243 .unwrap_or(std::cmp::Ordering::Equal)
244 });
245 results.truncate(limit);
246 results
247}
248
249#[allow(clippy::too_many_arguments)]
251pub fn hybrid_search(
252 query: &str,
253 bm25_index: &BM25Index,
254 embedder: &dyn EmbedProvider,
255 symbol_embeddings: &[(String, Vec<f32>)],
256 limit: usize,
257 alpha: f32, hnsw_index_path: Option<&Path>,
259 embeddings_path: Option<&Path>,
260) -> Result<Vec<SearchResult>> {
261 let raw = compute_raw_scores(
262 query,
263 bm25_index,
264 embedder,
265 symbol_embeddings,
266 limit * 2,
267 hnsw_index_path,
268 embeddings_path,
269 )?;
270 Ok(combine_scores(&raw, alpha, limit))
271}
272
/// Splits `text` into lowercase word tokens for BM25.
///
/// Any character that is neither alphanumeric nor `_` separates tokens.
/// Tokens shorter than 2 *bytes* are dropped. This removes single ASCII
/// characters, but a single multi-byte character (e.g. `é`) is kept.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '_');
    lowered
        .split(is_separator)
        .filter(|piece| piece.len() >= 2)
        .map(str::to_owned)
        .collect()
}
281
/// A single line matched by `grep_search`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GrepMatch {
    /// Path of the file relative to the search root, with `/` separators.
    pub file: String,
    /// 1-based line number of the match within `file`.
    pub line_number: usize,
    /// The full matching line, without its line terminator.
    pub line_text: String,
}
296
297pub fn grep_search(
300 root: &Path,
301 pattern: &str,
302 file_pattern: Option<&str>,
303 limit: usize,
304) -> Result<Vec<GrepMatch>> {
305 let re =
306 Regex::new(pattern).map_err(|e| anyhow::anyhow!("invalid regex '{}': {}", pattern, e))?;
307
308 let glob_pat = file_pattern
309 .map(glob::Pattern::new)
310 .transpose()
311 .map_err(|e| anyhow::anyhow!("invalid file pattern: {}", e))?;
312
313 let mut matches = Vec::new();
314 walk_and_search(root, root, &re, &glob_pat, limit, &mut matches)?;
315 Ok(matches)
316}
317
/// Directory names that `walk_and_search` never descends into.
///
/// The walker also skips every directory whose name starts with `.`. The
/// dot-prefixed names listed here are therefore covered twice, but they are
/// kept so the list stays self-explanatory.
const IGNORE_DIRS: &[&str] = &[
    ".infigraph", ".git",
    "node_modules", "__pycache__", ".venv", "venv",
    "target", "build", "dist",
    ".tox",
];
331
332fn walk_and_search(
333 base: &Path,
334 dir: &Path,
335 re: &Regex,
336 glob_pat: &Option<glob::Pattern>,
337 limit: usize,
338 matches: &mut Vec<GrepMatch>,
339) -> Result<()> {
340 if matches.len() >= limit {
341 return Ok(());
342 }
343
344 let entries = match std::fs::read_dir(dir) {
345 Ok(e) => e,
346 Err(_) => return Ok(()), };
348
349 for entry in entries {
350 if matches.len() >= limit {
351 return Ok(());
352 }
353 let entry = entry?;
354 let path = entry.path();
355 let name = entry.file_name();
356 let name_str = name.to_string_lossy();
357
358 if path.is_dir() {
359 if !IGNORE_DIRS.contains(&name_str.as_ref()) && !name_str.starts_with('.') {
360 walk_and_search(base, &path, re, glob_pat, limit, matches)?;
361 }
362 } else if path.is_file() {
363 let rel = path
364 .strip_prefix(base)
365 .unwrap_or(&path)
366 .to_string_lossy()
367 .replace('\\', "/");
368
369 if let Some(ref gp) = glob_pat {
371 if !gp.matches(&rel) {
372 continue;
373 }
374 }
375
376 let content = match std::fs::read_to_string(&path) {
378 Ok(c) => c,
379 Err(_) => continue,
380 };
381
382 for (idx, line) in content.lines().enumerate() {
383 if matches.len() >= limit {
384 return Ok(());
385 }
386 if re.is_match(line) {
387 matches.push(GrepMatch {
388 file: rel.clone(),
389 line_number: idx + 1,
390 line_text: line.to_string(),
391 });
392 }
393 }
394 }
395 }
396 Ok(())
397}
398
399pub fn read_lines_from_file(path: &Path, start_line: u32, end_line: u32) -> Result<String> {
402 let content = std::fs::read_to_string(path)
403 .map_err(|e| anyhow::anyhow!("cannot read {}: {}", path.display(), e))?;
404 let lines: Vec<&str> = content.lines().collect();
405 let start = (start_line as usize).saturating_sub(1);
406 let end = (end_line as usize).min(lines.len());
407 if start >= lines.len() {
408 return Ok(String::new());
409 }
410 Ok(lines[start..end].join("\n"))
411}