1use std::collections::HashMap;
9use std::fmt;
10use std::str::FromStr;
11
12use crate::bm25::Bm25Index;
13use crate::chunk::CodeChunk;
14use crate::index::SearchIndex;
15
/// Ranking strategy selected for a query.
///
/// The textual names are `hybrid`, `semantic` and `keyword`; `Hybrid` is the
/// default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SearchMode {
    /// Fuse the semantic and keyword rankings.
    #[default]
    Hybrid,
    /// Rank with the embedding-based index only.
    Semantic,
    /// Rank with the BM25 keyword index only.
    Keyword,
}

impl fmt::Display for SearchMode {
    /// Writes the lowercase mode name.
    ///
    /// Width, fill, alignment and precision flags are honoured, so the mode
    /// can be used directly in column-aligned output such as `{:<10}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Hybrid => "hybrid",
            Self::Semantic => "semantic",
            Self::Keyword => "keyword",
        };
        // `pad` applies the formatter's layout flags; `write_str` would
        // silently ignore them.
        f.pad(name)
    }
}
37
/// Error produced when a string does not name a [`SearchMode`].
///
/// The wrapped `String` is the rejected input, echoed back in the message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseSearchModeError(String);

impl fmt::Display for ParseSearchModeError {
    /// Formats the rejected input (debug-quoted) followed by the list of
    /// accepted mode names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(rejected) = self;
        write!(
            f,
            "unknown search mode {rejected:?}; expected hybrid, semantic, or keyword"
        )
    }
}

// No underlying cause to expose, so the default `Error` methods suffice.
impl std::error::Error for ParseSearchModeError {}
53
54impl FromStr for SearchMode {
55 type Err = ParseSearchModeError;
56
57 fn from_str(s: &str) -> Result<Self, Self::Err> {
58 match s {
59 "hybrid" => Ok(Self::Hybrid),
60 "semantic" => Ok(Self::Semantic),
61 "keyword" => Ok(Self::Keyword),
62 other => Err(ParseSearchModeError(other.to_string())),
63 }
64 }
65}
66
67pub struct HybridIndex {
72 pub semantic: SearchIndex,
74 bm25: Bm25Index,
76}
77
78impl HybridIndex {
79 pub fn new(
90 chunks: Vec<CodeChunk>,
91 embeddings: &[Vec<f32>],
92 cascade_dim: Option<usize>,
93 ) -> crate::Result<Self> {
94 let bm25 = Bm25Index::build(&chunks)?;
95 let semantic = SearchIndex::new(chunks, embeddings, cascade_dim);
96 Ok(Self { semantic, bm25 })
97 }
98
99 #[must_use]
104 pub fn from_parts(semantic: SearchIndex, bm25: Bm25Index) -> Self {
105 Self { semantic, bm25 }
106 }
107
108 #[must_use]
122 pub fn search(
123 &self,
124 query_embedding: &[f32],
125 query_text: &str,
126 top_k: usize,
127 threshold: f32,
128 mode: SearchMode,
129 ) -> Vec<(usize, f32)> {
130 let mut raw = match mode {
131 SearchMode::Semantic => {
132 self.semantic
134 .rank_turboquant(query_embedding, top_k.max(100), 0.0)
135 }
136 SearchMode::Keyword => self.bm25.search(query_text, top_k.max(100)),
137 SearchMode::Hybrid => {
138 let sem = self
139 .semantic
140 .rank_turboquant(query_embedding, top_k.max(100), 0.0);
141 let kw = self.bm25.search(query_text, top_k.max(100));
142 rrf_fuse(&sem, &kw, 60.0)
143 }
144 };
145
146 if let (Some(max), Some(min)) = (raw.first().map(|(_, s)| *s), raw.last().map(|(_, s)| *s))
148 {
149 let range = max - min;
150 if range > f32::EPSILON {
151 for (_, score) in &mut raw {
152 *score = (*score - min) / range;
153 }
154 } else {
155 for (_, score) in &mut raw {
157 *score = 1.0;
158 }
159 }
160 }
161
162 raw.retain(|(_, score)| *score >= threshold);
164 raw.truncate(top_k);
165 raw
166 }
167
168 #[must_use]
170 pub fn chunks(&self) -> &[CodeChunk] {
171 &self.semantic.chunks
172 }
173}
174
/// Merges two ranked lists with reciprocal rank fusion.
///
/// Every entry contributes `1 / (k + rank + 1)` to its index, where `rank` is
/// the entry's zero-based position in its list; the incoming scores are
/// ignored. An index present in both lists receives both contributions.
///
/// Returns the union of indices sorted by fused score, highest first, with
/// ties broken by ascending index.
#[must_use]
pub fn rrf_fuse(semantic: &[(usize, f32)], bm25: &[(usize, f32)], k: f32) -> Vec<(usize, f32)> {
    let mut fused: HashMap<usize, f32> = HashMap::new();

    // Walk the semantic list first, then the keyword list; each restarts its
    // own position counter at zero.
    let positioned = semantic.iter().enumerate().chain(bm25.iter().enumerate());
    for (position, &(chunk_idx, _)) in positioned {
        let contribution = 1.0 / (k + position as f32 + 1.0);
        *fused.entry(chunk_idx).or_default() += contribution;
    }

    let mut ordered: Vec<(usize, f32)> = fused.into_iter().collect();
    ordered.sort_unstable_by(|&(left_idx, left), &(right_idx, right)| {
        right.total_cmp(&left).then(left_idx.cmp(&right_idx))
    });
    ordered
}
203
204const PAGERANK_BETA: f32 = 10.0;
215
216pub fn boost_with_pagerank<S: std::hash::BuildHasher>(
234 results: &mut [(usize, f32)],
235 chunks: &[CodeChunk],
236 pagerank_by_file: &HashMap<String, f32, S>,
237 alpha: f32,
238) {
239 let log_denom = (1.0 + PAGERANK_BETA).ln();
240 if log_denom <= f32::EPSILON {
241 return;
242 }
243
244 for (idx, score) in results.iter_mut() {
245 if let Some(chunk) = chunks.get(*idx) {
246 let def_key = format!("{}::{}", chunk.file_path, chunk.name);
248 let rank = pagerank_by_file
249 .get(&def_key)
250 .or_else(|| pagerank_by_file.get(&chunk.file_path))
251 .copied()
252 .unwrap_or(0.0);
253 let saturated = (1.0 + PAGERANK_BETA * rank).ln() / log_denom;
254 *score *= 1.0 + alpha * saturated;
255 }
256 }
257 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
259}
260
261#[must_use]
267pub fn pagerank_lookup(graph: &crate::repo_map::RepoGraph) -> HashMap<String, f32> {
268 let max_rank = graph.def_ranks.iter().copied().fold(0.0_f32, f32::max);
269 if max_rank <= f32::EPSILON {
270 let file_max = graph.base_ranks.iter().copied().fold(0.0_f32, f32::max);
272 if file_max <= f32::EPSILON {
273 return HashMap::new();
274 }
275 return graph
276 .files
277 .iter()
278 .zip(graph.base_ranks.iter())
279 .map(|(file, &rank)| (file.path.clone(), rank / file_max))
280 .collect();
281 }
282
283 let mut map = HashMap::new();
284
285 for (file_idx, file) in graph.files.iter().enumerate() {
287 for (def_idx, def) in file.defs.iter().enumerate() {
288 let flat = graph.def_offsets[file_idx] + def_idx;
289 if let Some(&rank) = graph.def_ranks.get(flat) {
290 let key = format!("{}::{}", file.path, def.name);
291 map.insert(key, rank / max_rank);
292 }
293 }
294 if file_idx < graph.base_ranks.len() {
296 map.insert(file.path.clone(), graph.base_ranks[file_idx] / max_rank);
297 }
298 }
299
300 map
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal function chunk; only the path and name vary per test.
    fn chunk(file_path: &str, name: &str) -> CodeChunk {
        CodeChunk {
            file_path: file_path.into(),
            name: name.into(),
            kind: "function".into(),
            start_line: 1,
            end_line: 10,
            content: String::new(),
            enriched_content: String::new(),
        }
    }

    /// Fusion keeps every index from either list and ranks the shared one first.
    #[test]
    fn rrf_union_semantics() {
        let sem = [(0, 0.9), (1, 0.8), (2, 0.7)];
        let bm25 = [(3, 10.0), (0, 8.0), (4, 6.0)];

        let fused = rrf_fuse(&sem, &bm25, 60.0);
        let indices: Vec<usize> = fused.iter().map(|&(i, _)| i).collect();

        for expected in 0..=4_usize {
            assert!(
                indices.contains(&expected),
                "chunk {expected} missing from fused results"
            );
        }
        assert_eq!(fused.len(), 5);

        // Chunk 0 is the only one present in both lists.
        assert_eq!(indices[0], 0, "chunk 0 should rank first");
    }

    /// With one empty list the other list's order is preserved.
    #[test]
    fn rrf_single_list() {
        let sem = [(0, 0.9), (1, 0.8)];
        let bm25: [(usize, f32); 0] = [];

        let fused = rrf_fuse(&sem, &bm25, 60.0);

        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].0, 0);
        assert_eq!(fused[1].0, 1);
        assert!(fused[0].1 > fused[1].1);
    }

    /// Every mode name parses; an unknown name is rejected and echoed back.
    #[test]
    fn search_mode_roundtrip() {
        let cases = [
            ("hybrid", SearchMode::Hybrid),
            ("semantic", SearchMode::Semantic),
            ("keyword", SearchMode::Keyword),
        ];
        for (text, mode) in cases {
            assert_eq!(text.parse::<SearchMode>().unwrap(), mode);
        }

        let err = "invalid".parse::<SearchMode>();
        assert!(err.is_err(), "expected parse error for 'invalid'");
        let msg = err.unwrap_err().to_string();
        assert!(
            msg.contains("invalid"),
            "error message should echo the bad input"
        );
    }

    /// `Display` yields the lowercase mode names.
    #[test]
    fn search_mode_display() {
        assert_eq!(SearchMode::Hybrid.to_string(), "hybrid");
        assert_eq!(SearchMode::Semantic.to_string(), "semantic");
        assert_eq!(SearchMode::Keyword.to_string(), "keyword");
    }

    /// Equal scores are separated by the file-level rank.
    #[test]
    fn pagerank_boost_amplifies_relevant() {
        let chunks = vec![chunk("important.rs", "a"), chunk("obscure.rs", "b")];
        let mut results = vec![(0, 0.8_f32), (1, 0.8)];
        let pr = HashMap::from([
            ("important.rs".to_string(), 1.0),
            ("obscure.rs".to_string(), 0.1),
        ]);

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        let (top, runner_up) = (results[0], results[1]);
        assert_eq!(top.0, 0, "important.rs should rank first after boost");
        assert!(top.1 > runner_up.1);

        // rank 1.0 saturates to 1.0, so 0.8 * (1 + 0.3) = 1.04.
        assert!(
            (top.1 - 1.04).abs() < 0.01,
            "rank=1.0 boost: expected ~1.04, got {}",
            top.1
        );
        // rank 0.1 saturates to ln(2) / ln(11), giving roughly 0.869.
        assert!(
            (runner_up.1 - 0.869).abs() < 0.01,
            "rank=0.1 boost: expected ~0.869, got {}",
            runner_up.1
        );
    }

    /// The boost is multiplicative, so a zero score stays zero.
    #[test]
    fn pagerank_boost_zero_relevance_stays_zero() {
        let chunks = vec![chunk("important.rs", "a")];
        let mut results = vec![(0, 0.0_f32)];
        let pr = HashMap::from([("important.rs".to_string(), 1.0)]);

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert_eq!(results[0].1, 0.0);
    }

    /// A file absent from the rank table leaves the score unchanged.
    #[test]
    fn pagerank_boost_unknown_file_no_effect() {
        let chunks = vec![chunk("unknown.rs", "a")];
        let mut results = vec![(0, 0.5_f32)];
        let pr: HashMap<String, f32> = HashMap::new();

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert_eq!(results[0].1, 0.5);
    }
}