Skip to main content

infigraph_docs/
search.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use infigraph_core::embed::{cosine_similarity, doc_embedder, load_embeddings_cached, search_hnsw};
8
9use crate::store::DocStore;
10
/// One ranked hit produced by [`hybrid_doc_search`].
///
/// The three score fields are filled from the hybrid ranking; every other field is
/// copied from the chunk detail record returned by the document store.
#[derive(Debug, Clone)]
pub struct DocSearchResult {
    /// Identifier of the matched chunk.
    pub chunk_id: String,
    /// Document file the chunk belongs to, as reported by the store.
    pub doc_file: String,
    /// Heading associated with the chunk, if the store recorded one.
    pub heading: Option<String>,
    /// Chunk text as returned by the store.
    pub text: String,
    /// Final blended score: `(1 - alpha) * bm25_score + alpha * vector_score`.
    pub score: f32,
    /// BM25 score divided by the best BM25 score of the query (0.0 when the chunk
    /// was not a BM25 candidate).
    pub bm25_score: f32,
    /// Vector score divided by the best vector score of the query (0.0 when the
    /// chunk was not a vector candidate).
    pub vector_score: f32,
    /// Start of the chunk within its document.
    /// NOTE(review): unit (bytes vs. chars) is defined by the store — confirm there.
    pub start_offset: usize,
    /// End of the chunk within its document (same unit as `start_offset`).
    pub end_offset: usize,
    /// Page of the chunk, when the store provides one.
    pub page: Option<usize>,
}
24
25const K1: f32 = 1.2;
26const B: f32 = 0.75;
27
28pub struct DocBM25Index {
29    docs: Vec<(String, String)>,
30    inverted: HashMap<String, Vec<(usize, f32)>>,
31    avg_doc_len: f32,
32}
33
34impl DocBM25Index {
35    pub fn build(docs: Vec<(String, String)>) -> Self {
36        let n = docs.len();
37        let mut inverted: HashMap<String, Vec<(usize, f32)>> = HashMap::new();
38        let mut total_len = 0usize;
39
40        for (i, (_id, text)) in docs.iter().enumerate() {
41            let tokens = tokenize(text);
42            total_len += tokens.len();
43
44            let mut tf_map: HashMap<&str, f32> = HashMap::new();
45            for t in &tokens {
46                *tf_map.entry(t.as_str()).or_default() += 1.0;
47            }
48
49            for (term, tf) in tf_map {
50                inverted.entry(term.to_string()).or_default().push((i, tf));
51            }
52        }
53
54        let avg_doc_len = if n > 0 {
55            total_len as f32 / n as f32
56        } else {
57            1.0
58        };
59
60        Self {
61            docs,
62            inverted,
63            avg_doc_len,
64        }
65    }
66
67    pub fn search(&self, query: &str, limit: usize) -> Vec<(usize, f32)> {
68        let query_tokens = tokenize(query);
69        let n = self.docs.len() as f32;
70        let mut scores = vec![0.0f32; self.docs.len()];
71
72        for token in &query_tokens {
73            if let Some(postings) = self.inverted.get(token.as_str()) {
74                let df = postings.len() as f32;
75                let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
76
77                for &(doc_idx, tf) in postings {
78                    let doc_len = tokenize(&self.docs[doc_idx].1).len() as f32;
79                    let tf_norm =
80                        (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * doc_len / self.avg_doc_len));
81                    scores[doc_idx] += idf * tf_norm;
82                }
83            }
84        }
85
86        let mut results: Vec<(usize, f32)> = scores
87            .into_iter()
88            .enumerate()
89            .filter(|(_, s)| *s > 0.0)
90            .collect();
91        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
92        results.truncate(limit);
93        results
94    }
95}
96
97pub fn hybrid_doc_search(
98    query: &str,
99    store: &DocStore,
100    root: &Path,
101    limit: usize,
102    alpha: f32,
103) -> Result<Vec<DocSearchResult>> {
104    let chunks = store.get_all_chunks()?;
105
106    if chunks.is_empty() {
107        return Ok(Vec::new());
108    }
109
110    let bm25_index = DocBM25Index::build(chunks.clone());
111    let bm25_results = bm25_index.search(query, limit * 3);
112
113    // Normalize BM25
114    let max_bm25 = bm25_results
115        .first()
116        .map(|(_, s)| *s)
117        .unwrap_or(1.0)
118        .max(0.001);
119    let bm25_scores: HashMap<usize, f32> = bm25_results
120        .iter()
121        .map(|(idx, s)| (*idx, s / max_bm25))
122        .collect();
123
124    // Vector search
125    let tg_dir = root.join(".infigraph");
126    let emb_path = tg_dir.join("docs_embeddings.bin");
127    let hnsw_path = tg_dir.join("docs_hnsw_index.usearch");
128
129    let embedder = doc_embedder();
130    let query_vec = embedder.embed(query)?;
131
132    let vector_scores: HashMap<usize, f32> = if hnsw_path.exists() {
133        // HNSW path
134        if let Ok(Some(hnsw_results)) = search_hnsw(&hnsw_path, &emb_path, &query_vec, limit * 3) {
135            let id_to_idx: HashMap<&str, usize> = chunks
136                .iter()
137                .enumerate()
138                .map(|(i, (id, _))| (id.as_str(), i))
139                .collect();
140            hnsw_results
141                .into_iter()
142                .filter_map(|r| id_to_idx.get(r.id.as_str()).map(|&idx| (idx, r.score)))
143                .collect()
144        } else {
145            brute_force_vector(&chunks, &emb_path, &query_vec, limit * 3)?
146        }
147    } else {
148        brute_force_vector(&chunks, &emb_path, &query_vec, limit * 3)?
149    };
150
151    // Normalize vector
152    let max_vec = vector_scores.values().cloned().fold(0.001f32, f32::max);
153
154    // Combine
155    let mut all_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
156    all_indices.extend(bm25_scores.keys());
157    all_indices.extend(vector_scores.keys());
158
159    let mut combined: Vec<(usize, f32, f32, f32)> = all_indices
160        .into_iter()
161        .map(|idx| {
162            let b = bm25_scores.get(&idx).copied().unwrap_or(0.0);
163            let v = vector_scores.get(&idx).copied().unwrap_or(0.0) / max_vec;
164            let score = (1.0 - alpha) * b + alpha * v;
165            (idx, score, b, v)
166        })
167        .collect();
168    combined.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
169    combined.truncate(limit);
170
171    // Fetch chunk details
172    let chunk_ids: Vec<&str> = combined
173        .iter()
174        .map(|(idx, _, _, _)| chunks[*idx].0.as_str())
175        .collect();
176    let details = store.get_chunk_details(&chunk_ids)?;
177    let detail_map: HashMap<&str, &crate::store::ChunkDetail> =
178        details.iter().map(|d| (d.id.as_str(), d)).collect();
179
180    let results = combined
181        .into_iter()
182        .filter_map(|(idx, score, bm25, vec_s)| {
183            let chunk_id = &chunks[idx].0;
184            let detail = detail_map.get(chunk_id.as_str())?;
185            Some(DocSearchResult {
186                chunk_id: chunk_id.clone(),
187                doc_file: detail.doc_file.clone(),
188                heading: detail.heading.clone(),
189                text: detail.text.clone(),
190                score,
191                bm25_score: bm25,
192                vector_score: vec_s,
193                start_offset: detail.start_offset,
194                end_offset: detail.end_offset,
195                page: detail.page,
196            })
197        })
198        .collect();
199
200    Ok(results)
201}
202
203fn brute_force_vector(
204    chunks: &[(String, String)],
205    emb_path: &Path,
206    query_vec: &[f32],
207    limit: usize,
208) -> Result<HashMap<usize, f32>> {
209    let embeddings = load_embeddings_cached(emb_path).unwrap_or_default();
210    let emb_map: HashMap<&str, &Vec<f32>> =
211        embeddings.iter().map(|(id, v)| (id.as_str(), v)).collect();
212
213    let id_to_idx: HashMap<&str, usize> = chunks
214        .iter()
215        .enumerate()
216        .map(|(i, (id, _))| (id.as_str(), i))
217        .collect();
218
219    let mut scores: Vec<(usize, f32)> = emb_map
220        .par_iter()
221        .filter_map(|(id, vec)| {
222            let idx = id_to_idx.get(id)?;
223            let sim = cosine_similarity(query_vec, vec);
224            Some((*idx, sim))
225        })
226        .collect();
227
228    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
229    scores.truncate(limit);
230    Ok(scores.into_iter().collect())
231}
232
/// Splits `text` into lowercase search tokens.
///
/// The text is lowercased, then split on every character that is neither
/// alphanumeric nor `_`. Pieces whose UTF-8 byte length is 0 or 1 are discarded,
/// which removes empty pieces and single ASCII characters; a lone multi-byte
/// character is kept because the check is on bytes, not characters.
fn tokenize(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|token| token.len() > 1)
        .map(str::to_owned)
        .collect()
}