1use std::collections::HashMap;
2use std::path::Path;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use infigraph_core::embed::{cosine_similarity, doc_embedder, load_embeddings_cached, search_hnsw};
8
9use crate::store::DocStore;
10
/// One ranked document chunk produced by [`hybrid_doc_search`].
#[derive(Clone, Debug)]
pub struct DocSearchResult {
    /// Identifier of the matched chunk, as stored in the `DocStore`.
    pub chunk_id: String,
    /// Document file the chunk belongs to.
    pub doc_file: String,
    /// Heading the chunk sits under, when the store recorded one.
    pub heading: Option<String>,
    /// Full text of the chunk.
    pub text: String,
    /// Blended ranking score: `(1 - alpha) * bm25_score + alpha * vector_score`.
    pub score: f32,
    /// BM25 score divided by the best BM25 score among the keyword candidates
    /// (0 when the chunk was not a keyword candidate).
    pub bm25_score: f32,
    /// Vector similarity divided by the best similarity among the vector candidates
    /// (0 when the chunk was not a vector candidate).
    pub vector_score: f32,
    /// Start offset of the chunk within its document.
    /// NOTE(review): units (bytes vs chars) are defined by `DocStore` — confirm there.
    pub start_offset: usize,
    /// End offset of the chunk within its document (same units as `start_offset`).
    pub end_offset: usize,
    /// Page the chunk came from, for sources that have pages.
    /// NOTE(review): presumably set for paginated formats such as PDF; 0- vs 1-based
    /// is not visible here — confirm in `DocStore`.
    pub page: Option<usize>,
}
24
25const K1: f32 = 1.2;
26const B: f32 = 0.75;
27
28pub struct DocBM25Index {
29 docs: Vec<(String, String)>,
30 inverted: HashMap<String, Vec<(usize, f32)>>,
31 avg_doc_len: f32,
32}
33
34impl DocBM25Index {
35 pub fn build(docs: Vec<(String, String)>) -> Self {
36 let n = docs.len();
37 let mut inverted: HashMap<String, Vec<(usize, f32)>> = HashMap::new();
38 let mut total_len = 0usize;
39
40 for (i, (_id, text)) in docs.iter().enumerate() {
41 let tokens = tokenize(text);
42 total_len += tokens.len();
43
44 let mut tf_map: HashMap<&str, f32> = HashMap::new();
45 for t in &tokens {
46 *tf_map.entry(t.as_str()).or_default() += 1.0;
47 }
48
49 for (term, tf) in tf_map {
50 inverted.entry(term.to_string()).or_default().push((i, tf));
51 }
52 }
53
54 let avg_doc_len = if n > 0 {
55 total_len as f32 / n as f32
56 } else {
57 1.0
58 };
59
60 Self {
61 docs,
62 inverted,
63 avg_doc_len,
64 }
65 }
66
67 pub fn search(&self, query: &str, limit: usize) -> Vec<(usize, f32)> {
68 let query_tokens = tokenize(query);
69 let n = self.docs.len() as f32;
70 let mut scores = vec![0.0f32; self.docs.len()];
71
72 for token in &query_tokens {
73 if let Some(postings) = self.inverted.get(token.as_str()) {
74 let df = postings.len() as f32;
75 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
76
77 for &(doc_idx, tf) in postings {
78 let doc_len = tokenize(&self.docs[doc_idx].1).len() as f32;
79 let tf_norm =
80 (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * doc_len / self.avg_doc_len));
81 scores[doc_idx] += idf * tf_norm;
82 }
83 }
84 }
85
86 let mut results: Vec<(usize, f32)> = scores
87 .into_iter()
88 .enumerate()
89 .filter(|(_, s)| *s > 0.0)
90 .collect();
91 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
92 results.truncate(limit);
93 results
94 }
95}
96
97pub fn hybrid_doc_search(
98 query: &str,
99 store: &DocStore,
100 root: &Path,
101 limit: usize,
102 alpha: f32,
103) -> Result<Vec<DocSearchResult>> {
104 let chunks = store.get_all_chunks()?;
105
106 if chunks.is_empty() {
107 return Ok(Vec::new());
108 }
109
110 let bm25_index = DocBM25Index::build(chunks.clone());
111 let bm25_results = bm25_index.search(query, limit * 3);
112
113 let max_bm25 = bm25_results
115 .first()
116 .map(|(_, s)| *s)
117 .unwrap_or(1.0)
118 .max(0.001);
119 let bm25_scores: HashMap<usize, f32> = bm25_results
120 .iter()
121 .map(|(idx, s)| (*idx, s / max_bm25))
122 .collect();
123
124 let tg_dir = root.join(".infigraph");
126 let emb_path = tg_dir.join("docs_embeddings.bin");
127 let hnsw_path = tg_dir.join("docs_hnsw_index.usearch");
128
129 let embedder = doc_embedder();
130 let query_vec = embedder.embed(query)?;
131
132 let vector_scores: HashMap<usize, f32> = if hnsw_path.exists() {
133 if let Ok(Some(hnsw_results)) = search_hnsw(&hnsw_path, &emb_path, &query_vec, limit * 3) {
135 let id_to_idx: HashMap<&str, usize> = chunks
136 .iter()
137 .enumerate()
138 .map(|(i, (id, _))| (id.as_str(), i))
139 .collect();
140 hnsw_results
141 .into_iter()
142 .filter_map(|r| id_to_idx.get(r.id.as_str()).map(|&idx| (idx, r.score)))
143 .collect()
144 } else {
145 brute_force_vector(&chunks, &emb_path, &query_vec, limit * 3)?
146 }
147 } else {
148 brute_force_vector(&chunks, &emb_path, &query_vec, limit * 3)?
149 };
150
151 let max_vec = vector_scores.values().cloned().fold(0.001f32, f32::max);
153
154 let mut all_indices: std::collections::HashSet<usize> = std::collections::HashSet::new();
156 all_indices.extend(bm25_scores.keys());
157 all_indices.extend(vector_scores.keys());
158
159 let mut combined: Vec<(usize, f32, f32, f32)> = all_indices
160 .into_iter()
161 .map(|idx| {
162 let b = bm25_scores.get(&idx).copied().unwrap_or(0.0);
163 let v = vector_scores.get(&idx).copied().unwrap_or(0.0) / max_vec;
164 let score = (1.0 - alpha) * b + alpha * v;
165 (idx, score, b, v)
166 })
167 .collect();
168 combined.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
169 combined.truncate(limit);
170
171 let chunk_ids: Vec<&str> = combined
173 .iter()
174 .map(|(idx, _, _, _)| chunks[*idx].0.as_str())
175 .collect();
176 let details = store.get_chunk_details(&chunk_ids)?;
177 let detail_map: HashMap<&str, &crate::store::ChunkDetail> =
178 details.iter().map(|d| (d.id.as_str(), d)).collect();
179
180 let results = combined
181 .into_iter()
182 .filter_map(|(idx, score, bm25, vec_s)| {
183 let chunk_id = &chunks[idx].0;
184 let detail = detail_map.get(chunk_id.as_str())?;
185 Some(DocSearchResult {
186 chunk_id: chunk_id.clone(),
187 doc_file: detail.doc_file.clone(),
188 heading: detail.heading.clone(),
189 text: detail.text.clone(),
190 score,
191 bm25_score: bm25,
192 vector_score: vec_s,
193 start_offset: detail.start_offset,
194 end_offset: detail.end_offset,
195 page: detail.page,
196 })
197 })
198 .collect();
199
200 Ok(results)
201}
202
203fn brute_force_vector(
204 chunks: &[(String, String)],
205 emb_path: &Path,
206 query_vec: &[f32],
207 limit: usize,
208) -> Result<HashMap<usize, f32>> {
209 let embeddings = load_embeddings_cached(emb_path).unwrap_or_default();
210 let emb_map: HashMap<&str, &Vec<f32>> =
211 embeddings.iter().map(|(id, v)| (id.as_str(), v)).collect();
212
213 let id_to_idx: HashMap<&str, usize> = chunks
214 .iter()
215 .enumerate()
216 .map(|(i, (id, _))| (id.as_str(), i))
217 .collect();
218
219 let mut scores: Vec<(usize, f32)> = emb_map
220 .par_iter()
221 .filter_map(|(id, vec)| {
222 let idx = id_to_idx.get(id)?;
223 let sim = cosine_similarity(query_vec, vec);
224 Some((*idx, sim))
225 })
226 .collect();
227
228 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
229 scores.truncate(limit);
230 Ok(scores.into_iter().collect())
231}
232
/// Splits `text` into lowercase search terms.
///
/// A term is a maximal run of alphanumeric characters or `_`; every other
/// character is a separator. Terms whose UTF-8 encoding is a single byte (one
/// ASCII character) are discarded. The length check is on bytes, not characters,
/// so a lone non-ASCII character such as `é` is kept.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '_');

    let mut terms = Vec::new();
    for piece in lowered.split(is_separator) {
        if piece.len() >= 2 {
            terms.push(piece.to_owned());
        }
    }
    terms
}
239}