Skip to main content

infigraph_core/search/
mod.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use anyhow::Result;
5use rayon::prelude::*;
6use regex::Regex;
7
8use crate::embed::{self, EmbedProvider};
9
/// One ranked hit of the hybrid (BM25 + vector) search.
///
/// `score` is the blended value used for ranking; the two component scores
/// are kept next to it so callers can see how each signal contributed.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Identifier of the matched symbol.
    pub symbol_id: String,
    /// Symbol name (`combine_scores` leaves this empty).
    pub name: String,
    /// Symbol kind (`combine_scores` leaves this empty).
    pub kind: String,
    /// File the symbol belongs to (`combine_scores` leaves this empty).
    pub file: String,
    /// Blended score: `(1 - alpha) * bm25_score + alpha * vector_score`.
    pub score: f32,
    /// Normalized BM25 component, `0.0` when the symbol had no BM25 hit.
    pub bm25_score: f32,
    /// Normalized vector component, `0.0` when the symbol had no vector hit.
    pub vector_score: f32,
    /// Optional docstring of the symbol (`combine_scores` sets it to `None`).
    pub docstring: Option<String>,
}
22
23/// BM25 parameters.
24const K1: f32 = 1.2;
25const B: f32 = 0.75;
26
27/// Simple BM25 scorer over symbol text (name + docstring).
28pub struct BM25Index {
29    /// symbol_id -> text
30    docs: Vec<(String, String)>,
31    /// term -> list of (doc_index, term_frequency)
32    inverted: HashMap<String, Vec<(usize, f32)>>,
33    avg_doc_len: f32,
34}
35
36impl BM25Index {
37    /// Build a BM25 index from symbol (id, text) pairs.
38    pub fn build(docs: Vec<(String, String)>) -> Self {
39        let n = docs.len();
40        let mut inverted: HashMap<String, Vec<(usize, f32)>> = HashMap::new();
41        let mut total_len = 0usize;
42
43        for (i, (_id, text)) in docs.iter().enumerate() {
44            let tokens = tokenize(text);
45            total_len += tokens.len();
46
47            let mut tf_map: HashMap<&str, f32> = HashMap::new();
48            for t in &tokens {
49                *tf_map.entry(t.as_str()).or_default() += 1.0;
50            }
51
52            for (term, tf) in tf_map {
53                inverted.entry(term.to_string()).or_default().push((i, tf));
54            }
55        }
56
57        let avg_doc_len = if n > 0 {
58            total_len as f32 / n as f32
59        } else {
60            1.0
61        };
62
63        Self {
64            docs,
65            inverted,
66            avg_doc_len,
67        }
68    }
69
70    /// Score all documents against a query. Returns (doc_index, score) sorted descending.
71    pub fn search(&self, query: &str, limit: usize) -> Vec<(usize, f32)> {
72        let query_tokens = tokenize(query);
73        let n = self.docs.len() as f32;
74        let mut scores = vec![0.0f32; self.docs.len()];
75
76        for token in &query_tokens {
77            if let Some(postings) = self.inverted.get(token.as_str()) {
78                let df = postings.len() as f32;
79                let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
80
81                for &(doc_idx, tf) in postings {
82                    let doc_len = tokenize(&self.docs[doc_idx].1).len() as f32;
83                    let tf_norm =
84                        (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * doc_len / self.avg_doc_len));
85                    scores[doc_idx] += idf * tf_norm;
86                }
87            }
88        }
89
90        let mut results: Vec<(usize, f32)> = scores
91            .into_iter()
92            .enumerate()
93            .filter(|(_, s)| *s > 0.0)
94            .collect();
95        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
96        results.truncate(limit);
97        results
98    }
99
100    pub fn doc_id(&self, idx: usize) -> &str {
101        &self.docs[idx].0
102    }
103
104    pub fn doc_text(&self, idx: usize) -> &str {
105        &self.docs[idx].1
106    }
107}
108
/// Per-signal scores for one query, kept separate so they can be blended
/// with several alpha values without re-running either search.
pub struct RawScores {
    /// symbol_id -> BM25 score normalized by the query's top BM25 score
    pub bm25: HashMap<String, f32>,
    /// symbol_id -> vector similarity normalized by the query's top similarity
    pub vector: HashMap<String, f32>,
}
116
117/// Compute BM25 and vector scores separately. Call once, then blend with
118/// multiple alpha values via `combine_scores`.
119///
120/// When `hnsw_index_path` and `embeddings_path` are provided and a valid HNSW
121/// index exists on disk, vector scoring uses the index (~1ms) instead of
122/// brute-force scanning all embeddings (~20-30ms).
123pub fn compute_raw_scores(
124    query: &str,
125    bm25_index: &BM25Index,
126    embedder: &dyn EmbedProvider,
127    symbol_embeddings: &[(String, Vec<f32>)],
128    oversample: usize,
129    hnsw_index_path: Option<&Path>,
130    embeddings_path: Option<&Path>,
131) -> Result<RawScores> {
132    let bm25_results = bm25_index.search(query, oversample);
133    let bm25_max = bm25_results
134        .first()
135        .map(|(_, s)| *s)
136        .unwrap_or(1.0)
137        .max(0.001);
138
139    let mut bm25_map: HashMap<String, f32> = HashMap::new();
140    for (idx, score) in &bm25_results {
141        let id = bm25_index.doc_id(*idx).to_string();
142        bm25_map.insert(id, score / bm25_max);
143    }
144
145    let query_embedding = embedder.embed(query)?;
146
147    // HNSW only pays off above ~200K embeddings where brute-force exceeds index
148    // load + search time. Below that, rayon dot-product is faster.
149    const HNSW_THRESHOLD: usize = 200_000;
150    let use_hnsw = symbol_embeddings.len() >= HNSW_THRESHOLD;
151    let vec_scores = if use_hnsw {
152        if let (Some(idx_path), Some(emb_path)) = (hnsw_index_path, embeddings_path) {
153            match embed::search_hnsw(idx_path, emb_path, &query_embedding, oversample) {
154                Ok(Some(candidates)) => {
155                    let emb_lookup: HashMap<&str, &[f32]> = symbol_embeddings
156                        .iter()
157                        .map(|(id, v)| (id.as_str(), v.as_slice()))
158                        .collect();
159                    let mut reranked: Vec<(String, f32)> = candidates
160                        .into_iter()
161                        .filter_map(|r| {
162                            emb_lookup
163                                .get(r.id.as_str())
164                                .map(|emb| (r.id, embed::cosine_similarity(&query_embedding, emb)))
165                        })
166                        .collect();
167                    reranked.sort_unstable_by(|a, b| {
168                        b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
169                    });
170                    reranked.truncate(oversample);
171                    reranked
172                }
173                _ => brute_force_vector_scores(&query_embedding, symbol_embeddings, oversample),
174            }
175        } else {
176            brute_force_vector_scores(&query_embedding, symbol_embeddings, oversample)
177        }
178    } else {
179        brute_force_vector_scores(&query_embedding, symbol_embeddings, oversample)
180    };
181
182    let vec_max = vec_scores
183        .first()
184        .map(|(_, s)| *s)
185        .unwrap_or(1.0)
186        .max(0.001);
187
188    let mut vector_map: HashMap<String, f32> = HashMap::new();
189    for (id, score) in &vec_scores {
190        vector_map.insert(id.clone(), score / vec_max);
191    }
192
193    Ok(RawScores {
194        bm25: bm25_map,
195        vector: vector_map,
196    })
197}
198
199fn brute_force_vector_scores(
200    query_embedding: &[f32],
201    symbol_embeddings: &[(String, Vec<f32>)],
202    oversample: usize,
203) -> Vec<(String, f32)> {
204    let mut vec_scores: Vec<(String, f32)> = symbol_embeddings
205        .par_iter()
206        .map(|(id, emb)| {
207            let sim = embed::cosine_similarity(query_embedding, emb);
208            (id.clone(), sim)
209        })
210        .collect();
211    vec_scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212    vec_scores.truncate(oversample);
213    vec_scores
214}
215
216/// Blend pre-computed raw scores with a given alpha. Returns sorted results.
217pub fn combine_scores(raw: &RawScores, alpha: f32, limit: usize) -> Vec<SearchResult> {
218    let all_ids: std::collections::HashSet<&String> =
219        raw.bm25.keys().chain(raw.vector.keys()).collect();
220
221    let mut results: Vec<SearchResult> = all_ids
222        .into_iter()
223        .map(|id| {
224            let bm25 = raw.bm25.get(id).copied().unwrap_or(0.0);
225            let vec = raw.vector.get(id).copied().unwrap_or(0.0);
226            let score = (1.0 - alpha) * bm25 + alpha * vec;
227            SearchResult {
228                symbol_id: id.clone(),
229                name: String::new(),
230                kind: String::new(),
231                file: String::new(),
232                score,
233                bm25_score: bm25,
234                vector_score: vec,
235                docstring: None,
236            }
237        })
238        .collect();
239
240    results.sort_by(|a, b| {
241        b.score
242            .partial_cmp(&a.score)
243            .unwrap_or(std::cmp::Ordering::Equal)
244    });
245    results.truncate(limit);
246    results
247}
248
249/// Hybrid search combining BM25 text relevance with vector similarity.
250#[allow(clippy::too_many_arguments)]
251pub fn hybrid_search(
252    query: &str,
253    bm25_index: &BM25Index,
254    embedder: &dyn EmbedProvider,
255    symbol_embeddings: &[(String, Vec<f32>)],
256    limit: usize,
257    alpha: f32, // 0.0 = pure BM25, 1.0 = pure vector
258    hnsw_index_path: Option<&Path>,
259    embeddings_path: Option<&Path>,
260) -> Result<Vec<SearchResult>> {
261    let raw = compute_raw_scores(
262        query,
263        bm25_index,
264        embedder,
265        symbol_embeddings,
266        limit * 2,
267        hnsw_index_path,
268        embeddings_path,
269    )?;
270    Ok(combine_scores(&raw, alpha, limit))
271}
272
/// Lowercase `text` and cut it into tokens at every character that is neither
/// alphanumeric nor `_`. Pieces shorter than two bytes are discarded.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut tokens = Vec::new();
    for piece in lowered.split(|c: char| !(c.is_alphanumeric() || c == '_')) {
        // Length is measured in bytes, which also drops the empty pieces
        // produced by adjacent separators.
        if piece.len() >= 2 {
            tokens.push(piece.to_owned());
        }
    }
    tokens
}
281
282// ---------------------------------------------------------------------------
283// grep-like text search
284// ---------------------------------------------------------------------------
285
/// One line that matched the pattern during a grep search.
#[derive(Debug, Clone)]
pub struct GrepMatch {
    /// Path of the file relative to the search root, with `/` separators.
    pub file: String,
    /// Line number within the file, starting at 1.
    pub line_number: usize,
    /// Full text of the matching line, without its line terminator.
    pub line_text: String,
}
296
297/// Walk `root`, optionally filtering by a glob `file_pattern`, and search every
298/// file for lines matching `pattern` (a regex).  Returns up to `limit` matches.
299pub fn grep_search(
300    root: &Path,
301    pattern: &str,
302    file_pattern: Option<&str>,
303    limit: usize,
304) -> Result<Vec<GrepMatch>> {
305    let re =
306        Regex::new(pattern).map_err(|e| anyhow::anyhow!("invalid regex '{}': {}", pattern, e))?;
307
308    let glob_pat = file_pattern
309        .map(glob::Pattern::new)
310        .transpose()
311        .map_err(|e| anyhow::anyhow!("invalid file pattern: {}", e))?;
312
313    let mut matches = Vec::new();
314    walk_and_search(root, root, &re, &glob_pat, limit, &mut matches)?;
315    Ok(matches)
316}
317
318/// Directories to skip during the grep walk (same set as Infigraph::walk_dir).
319const IGNORE_DIRS: &[&str] = &[
320    ".infigraph",
321    ".git",
322    "node_modules",
323    "__pycache__",
324    ".venv",
325    "venv",
326    "target",
327    "build",
328    "dist",
329    ".tox",
330];
331
332fn walk_and_search(
333    base: &Path,
334    dir: &Path,
335    re: &Regex,
336    glob_pat: &Option<glob::Pattern>,
337    limit: usize,
338    matches: &mut Vec<GrepMatch>,
339) -> Result<()> {
340    if matches.len() >= limit {
341        return Ok(());
342    }
343
344    let entries = match std::fs::read_dir(dir) {
345        Ok(e) => e,
346        Err(_) => return Ok(()), // skip unreadable dirs
347    };
348
349    for entry in entries {
350        if matches.len() >= limit {
351            return Ok(());
352        }
353        let entry = entry?;
354        let path = entry.path();
355        let name = entry.file_name();
356        let name_str = name.to_string_lossy();
357
358        if path.is_dir() {
359            if !IGNORE_DIRS.contains(&name_str.as_ref()) && !name_str.starts_with('.') {
360                walk_and_search(base, &path, re, glob_pat, limit, matches)?;
361            }
362        } else if path.is_file() {
363            let rel = path
364                .strip_prefix(base)
365                .unwrap_or(&path)
366                .to_string_lossy()
367                .replace('\\', "/");
368
369            // Apply optional file-name glob filter
370            if let Some(ref gp) = glob_pat {
371                if !gp.matches(&rel) {
372                    continue;
373                }
374            }
375
376            // Skip binary files — try to read as UTF-8
377            let content = match std::fs::read_to_string(&path) {
378                Ok(c) => c,
379                Err(_) => continue,
380            };
381
382            for (idx, line) in content.lines().enumerate() {
383                if matches.len() >= limit {
384                    return Ok(());
385                }
386                if re.is_match(line) {
387                    matches.push(GrepMatch {
388                        file: rel.clone(),
389                        line_number: idx + 1,
390                        line_text: line.to_string(),
391                    });
392                }
393            }
394        }
395    }
396    Ok(())
397}
398
399/// Read a range of lines [start_line..=end_line] (1-based) from a file.
400/// Returns the source text of those lines concatenated.
401pub fn read_lines_from_file(path: &Path, start_line: u32, end_line: u32) -> Result<String> {
402    let content = std::fs::read_to_string(path)
403        .map_err(|e| anyhow::anyhow!("cannot read {}: {}", path.display(), e))?;
404    let lines: Vec<&str> = content.lines().collect();
405    let start = (start_line as usize).saturating_sub(1);
406    let end = (end_line as usize).min(lines.len());
407    if start >= lines.len() {
408        return Ok(String::new());
409    }
410    Ok(lines[start..end].join("\n"))
411}