use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
use crate::bm25::Bm25Index;
use crate::chunk::CodeChunk;
use crate::index::SearchIndex;
/// Retrieval strategy selected for [`HybridIndex::search`].
///
/// The textual form (see the `Display` and `FromStr` impls) is the lowercase
/// variant name: `hybrid`, `semantic`, or `keyword`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SearchMode {
    /// Fuse the semantic and keyword rankings (the default mode).
    #[default]
    Hybrid,
    /// Rank with the semantic index only.
    Semantic,
    /// Rank with the BM25 keyword index only.
    Keyword,
}

impl fmt::Display for SearchMode {
    /// Writes the lowercase mode name straight to the formatter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Hybrid => "hybrid",
            Self::Semantic => "semantic",
            Self::Keyword => "keyword",
        };
        f.write_str(label)
    }
}
/// Error produced when a string does not name a [`SearchMode`].
///
/// Wraps the rejected input so it can be echoed back in the message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseSearchModeError(String);

impl fmt::Display for ParseSearchModeError {
    /// Renders the rejected input (debug-quoted) followed by the accepted names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!(
            "unknown search mode {:?}; expected hybrid, semantic, or keyword",
            self.0
        ))
    }
}

impl std::error::Error for ParseSearchModeError {}
impl FromStr for SearchMode {
    type Err = ParseSearchModeError;

    /// Parses one of the exact strings `hybrid`, `semantic`, or `keyword`.
    ///
    /// Matching is exact: no trimming or case folding is applied.
    ///
    /// # Errors
    ///
    /// Returns a [`ParseSearchModeError`] carrying a copy of `s` for any
    /// other input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mode = match s {
            "hybrid" => Self::Hybrid,
            "semantic" => Self::Semantic,
            "keyword" => Self::Keyword,
            _ => return Err(ParseSearchModeError(s.to_string())),
        };
        Ok(mode)
    }
}
/// Pairs a semantic [`SearchIndex`] with a [`Bm25Index`] so a query can be
/// answered by either index alone or by fusing both rankings.
pub struct HybridIndex {
    /// Semantic index; its `chunks` field backs [`HybridIndex::chunks`].
    pub semantic: SearchIndex,
    /// Keyword index, queried with the raw query text.
    bm25: Bm25Index,
}

impl HybridIndex {
    /// Builds both indices from one chunk list.
    ///
    /// The BM25 index is built first from a borrow of `chunks`; the vector is
    /// then moved into `SearchIndex::new` together with `embeddings` and
    /// `cascade_dim`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Bm25Index::build`.
    pub fn new(
        chunks: Vec<CodeChunk>,
        embeddings: &[Vec<f32>],
        cascade_dim: Option<usize>,
    ) -> crate::Result<Self> {
        // Must run before `chunks` is moved into the semantic index below.
        let keyword_index = Bm25Index::build(&chunks)?;
        let semantic_index = SearchIndex::new(chunks, embeddings, cascade_dim);
        Ok(Self {
            semantic: semantic_index,
            bm25: keyword_index,
        })
    }

    /// Assembles a hybrid index from two already-built indices.
    #[must_use]
    pub fn from_parts(semantic: SearchIndex, bm25: Bm25Index) -> Self {
        Self { semantic, bm25 }
    }

    /// Runs a query in the requested `mode` and returns `(chunk index, score)`
    /// pairs with scores rescaled to `[0, 1]`.
    ///
    /// * `query_embedding` is passed to the semantic index
    ///   (`Semantic` and `Hybrid` modes).
    /// * `query_text` is passed to the BM25 index (`Keyword` and `Hybrid` modes).
    /// * `top_k` caps the number of returned pairs; each underlying index is
    ///   asked for at least 100 candidates regardless.
    /// * `threshold` is applied to the rescaled score, not the raw one.
    ///
    /// In `Hybrid` mode the two rankings are merged with [`rrf_fuse`] using
    /// `k = 60`.
    #[must_use]
    pub fn search(
        &self,
        query_embedding: &[f32],
        query_text: &str,
        top_k: usize,
        threshold: f32,
        mode: SearchMode,
    ) -> Vec<(usize, f32)> {
        // Over-fetch so normalisation and fusion see a reasonable pool even
        // when the caller asks for only a handful of results.
        let pool = top_k.max(100);

        let mut hits = match mode {
            SearchMode::Semantic => self.semantic.rank_turboquant(query_embedding, pool, 0.0),
            SearchMode::Keyword => self.bm25.search(query_text, pool),
            SearchMode::Hybrid => {
                let by_embedding = self.semantic.rank_turboquant(query_embedding, pool, 0.0);
                let by_keyword = self.bm25.search(query_text, pool);
                rrf_fuse(&by_embedding, &by_keyword, 60.0)
            }
        };

        // Min-max rescale. The first and last entries are taken as the
        // maximum and minimum, i.e. the list is treated as sorted best-first.
        let ends = hits
            .first()
            .zip(hits.last())
            .map(|(head, tail)| (head.1, tail.1));
        if let Some((hi, lo)) = ends {
            let span = hi - lo;
            let has_spread = span > f32::EPSILON;
            for (_, score) in &mut hits {
                // With no usable spread every hit is treated as equally good.
                *score = if has_spread { (*score - lo) / span } else { 1.0 };
            }
        }

        hits.retain(|&(_, score)| score >= threshold);
        hits.truncate(top_k);
        hits
    }

    /// Returns the chunk list held by the semantic index.
    #[must_use]
    pub fn chunks(&self) -> &[CodeChunk] {
        &self.semantic.chunks
    }
}
/// Merges two rankings with Reciprocal Rank Fusion.
///
/// Every entry contributes `1 / (k + rank + 1)` to its chunk index, where
/// `rank` is the entry's zero-based position in its own list. The incoming
/// scores are ignored; only positions matter. A chunk appearing in both lists
/// (or repeatedly in one) accumulates all of its contributions.
///
/// The result holds one pair per distinct chunk index, ordered by fused score
/// descending with ties broken by ascending chunk index.
#[must_use]
pub fn rrf_fuse(semantic: &[(usize, f32)], bm25: &[(usize, f32)], k: f32) -> Vec<(usize, f32)> {
    let mut fused: HashMap<usize, f32> = HashMap::new();
    // Semantic entries are accumulated before keyword entries.
    for ranking in [semantic, bm25] {
        for (position, entry) in ranking.iter().enumerate() {
            let contribution = 1.0 / (k + position as f32 + 1.0);
            *fused.entry(entry.0).or_default() += contribution;
        }
    }

    let mut ordered: Vec<(usize, f32)> = fused.into_iter().collect();
    ordered.sort_unstable_by(|lhs, rhs| {
        rhs.1
            .total_cmp(&lhs.1)
            .then_with(|| lhs.0.cmp(&rhs.0))
    });
    ordered
}
/// Width of the sigmoid transition used by [`pagerank_boost_factor`].
///
/// The centred percentile is divided by this value, so a smaller constant
/// produces a sharper step around the 50th percentile.
const PAGERANK_SIGMOID_STEEPNESS: f32 = 0.15;

/// Converts a PageRank percentile into a multiplicative score boost.
///
/// `percentile` is clamped to `[0, 1]`, centred on `0.5`, and passed through
/// a logistic sigmoid; the result is `1 + alpha * sigmoid`, so the factor lies
/// in `(1, 1 + alpha)` for positive inputs.
///
/// Returns exactly `1.0` (no boost) when `percentile <= 0`, when
/// `alpha <= 0`, or when either argument is NaN.
#[must_use]
pub fn pagerank_boost_factor(percentile: f32, alpha: f32) -> f32 {
    // NaN fails every ordered comparison, so it would slip past the `<= 0.0`
    // guards and turn the factor (and every score multiplied by it) into NaN.
    // A descending `total_cmp` sort would then put those results first.
    if percentile.is_nan() || alpha.is_nan() || percentile <= 0.0 || alpha <= 0.0 {
        return 1.0;
    }
    let z = (percentile.clamp(0.0, 1.0) - 0.5) / PAGERANK_SIGMOID_STEEPNESS;
    let sigmoid = 1.0 / (1.0 + (-z).exp());
    1.0 + alpha * sigmoid
}
/// Scales each `(chunk index, score)` pair by the PageRank boost of its chunk
/// and re-sorts the slice in place.
///
/// The chunk's rank is resolved from `pagerank_by_file` using its `file_path`
/// and `name`; the boost itself comes from [`pagerank_boost_factor`] with the
/// given `alpha`. Pairs whose index is out of range for `chunks` keep their
/// score. Afterwards the slice is ordered by score descending, ties broken by
/// ascending chunk index.
pub fn boost_with_pagerank<S: std::hash::BuildHasher>(
    results: &mut [(usize, f32)],
    chunks: &[CodeChunk],
    pagerank_by_file: &HashMap<String, f32, S>,
    alpha: f32,
) {
    for entry in results.iter_mut() {
        let Some(chunk) = chunks.get(entry.0) else {
            continue;
        };
        let percentile = lookup_rank(pagerank_by_file, &chunk.file_path, &chunk.name);
        entry.1 *= pagerank_boost_factor(percentile, alpha);
    }
    results.sort_unstable_by(|lhs, rhs| {
        rhs.1
            .total_cmp(&lhs.1)
            .then_with(|| lhs.0.cmp(&rhs.0))
    });
}
/// Scales the `similarity` of every search result by the PageRank boost of
/// its chunk and re-sorts the slice by similarity, highest first.
///
/// The rank is resolved from `pagerank_by_file` using the chunk's `file_path`
/// and `name`; the boost comes from [`pagerank_boost_factor`] with the given
/// `alpha`. Results with equal boosted similarity keep their incoming
/// relative order.
pub fn boost_with_pagerank_results<S: std::hash::BuildHasher>(
    results: &mut [crate::embed::SearchResult],
    pagerank_by_file: &HashMap<String, f32, S>,
    alpha: f32,
) {
    for r in results.iter_mut() {
        let rank = lookup_rank(pagerank_by_file, &r.chunk.file_path, &r.chunk.name);
        r.similarity *= pagerank_boost_factor(rank, alpha);
    }
    // Stable sort: there is no index to tie-break on here (unlike the tuple
    // variant), so preserving the incoming order keeps ties deterministic
    // instead of leaving them to the unstable sort's internal shuffling.
    results.sort_by(|a, b| b.similarity.total_cmp(&a.similarity));
}
/// Crate-visible entry point to the private `lookup_rank` helper.
///
/// Forwards `pr`, `file_path`, and `name` unchanged and returns whatever
/// `lookup_rank` returns (`0.0` when no key matches).
pub(crate) fn lookup_rank_for_chunk<S: std::hash::BuildHasher>(
pr: &HashMap<String, f32, S>,
file_path: &str,
name: &str,
) -> f32 {
lookup_rank(pr, file_path, name)
}
/// Resolves the rank stored for a definition, trying progressively shorter
/// path suffixes.
///
/// For each candidate path the definition key `"{path}::{name}"` is tried
/// first, then the bare path. Candidates are `file_path` itself, followed by
/// what remains after dropping everything up to and including each successive
/// `/`. The search stops at the first hit, at an empty remainder, or when no
/// `/` is left.
///
/// Returns `0.0` when nothing matches.
fn lookup_rank<S: std::hash::BuildHasher>(
    pr: &HashMap<String, f32, S>,
    file_path: &str,
    name: &str,
) -> f32 {
    // Definition-level entry wins over the file-level entry for the same path.
    let probe = |path: &str| -> Option<f32> {
        pr.get(&format!("{path}::{name}"))
            .or_else(|| pr.get(path))
            .copied()
    };

    if let Some(rank) = probe(file_path) {
        return rank;
    }

    let mut remaining = file_path;
    while let Some((_, suffix)) = remaining.split_once('/') {
        if suffix.is_empty() {
            break;
        }
        if let Some(rank) = probe(suffix) {
            return rank;
        }
        remaining = suffix;
    }
    0.0
}
/// Flattens a repository graph into a map from lookup key to rank percentile.
///
/// Two kinds of keys are produced:
///
/// * `"{file.path}::{def.name}"` maps to the percentile of that definition's
///   rank within `graph.def_ranks`. The rank is read at
///   `def_offsets[file_idx] + def_idx`; definitions whose flat index falls
///   outside `def_ranks` are skipped.
/// * `file.path` maps to the percentile of the file's rank within
///   `graph.base_ranks`; files with no entry in `base_ranks` are skipped.
///
/// # Panics
///
/// Panics if a file that has definitions has no entry in `graph.def_offsets`.
#[must_use]
pub fn pagerank_lookup(graph: &crate::repo_map::RepoGraph) -> HashMap<String, f32> {
    let def_percentile = make_percentile_fn(&graph.def_ranks);
    let file_percentile = make_percentile_fn(&graph.base_ranks);

    let mut lookup = HashMap::new();
    for (file_idx, file) in graph.files.iter().enumerate() {
        for (def_idx, def) in file.defs.iter().enumerate() {
            // Offset is only read for files that actually have definitions.
            let flat_idx = graph.def_offsets[file_idx] + def_idx;
            let Some(&rank) = graph.def_ranks.get(flat_idx) else {
                continue;
            };
            lookup.insert(format!("{}::{}", file.path, def.name), def_percentile(rank));
        }
        if let Some(&rank) = graph.base_ranks.get(file_idx) {
            lookup.insert(file.path.clone(), file_percentile(rank));
        }
    }
    lookup
}
/// Builds a closure that maps a value to its percentile within `values`.
///
/// Non-finite entries (NaN, ±infinity) are discarded and the rest sorted once
/// up front. The returned closure yields the fraction of retained entries
/// that are strictly less than its argument, or `0.0` when nothing was
/// retained.
fn make_percentile_fn(values: &[f32]) -> impl Fn(f32) -> f32 + '_ {
    let mut ranked: Vec<f32> = values
        .iter()
        .filter(|v| v.is_finite())
        .copied()
        .collect();
    ranked.sort_unstable_by(|a, b| a.total_cmp(b));

    move |value: f32| match ranked.len() {
        0 => 0.0,
        total => {
            // Binary search for the first entry that is not below `value`.
            let below = ranked.partition_point(|&candidate| candidate < value);
            #[expect(
                clippy::cast_precision_loss,
                reason = "rank counts well below f32 precision threshold"
            )]
            let pct = below as f32 / total as f32;
            pct
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the minimal function chunk shared by the PageRank tests; only
    /// `file_path` and `name` vary between cases.
    fn function_chunk(file_path: &'static str, name: &'static str) -> CodeChunk {
        CodeChunk {
            file_path: file_path.into(),
            name: name.into(),
            kind: "function".into(),
            start_line: 1,
            end_line: 10,
            content: String::new(),
            enriched_content: String::new(),
        }
    }

    /// Fusion keeps every chunk from either list and ranks the shared one first.
    #[test]
    fn rrf_union_semantics() {
        let sem: [(usize, f32); 3] = [(0, 0.9), (1, 0.8), (2, 0.7)];
        let bm25: [(usize, f32); 3] = [(3, 10.0), (0, 8.0), (4, 6.0)];
        let fused = rrf_fuse(&sem, &bm25, 60.0);
        let indices: Vec<usize> = fused.iter().map(|entry| entry.0).collect();
        for expected in 0..5 {
            assert!(
                indices.contains(&expected),
                "chunk {expected} missing from fused results"
            );
        }
        assert_eq!(fused.len(), 5);
        assert_eq!(indices[0], 0, "chunk 0 should rank first");
    }

    /// With one empty list the other list's order is preserved.
    #[test]
    fn rrf_single_list() {
        let sem: [(usize, f32); 2] = [(0, 0.9), (1, 0.8)];
        let fused = rrf_fuse(&sem, &[], 60.0);
        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].0, 0);
        assert_eq!(fused[1].0, 1);
        assert!(fused[0].1 > fused[1].1);
    }

    /// Every mode name parses; an unknown name is rejected and echoed back.
    #[test]
    fn search_mode_roundtrip() {
        let cases = [
            ("hybrid", SearchMode::Hybrid),
            ("semantic", SearchMode::Semantic),
            ("keyword", SearchMode::Keyword),
        ];
        for (text, mode) in cases {
            assert_eq!(text.parse::<SearchMode>().unwrap(), mode);
        }

        let err = "invalid".parse::<SearchMode>();
        assert!(err.is_err(), "expected parse error for 'invalid'");
        let msg = err.unwrap_err().to_string();
        assert!(
            msg.contains("invalid"),
            "error message should echo the bad input"
        );
    }

    /// `Display` yields the lowercase mode name.
    #[test]
    fn search_mode_display() {
        let cases = [
            (SearchMode::Hybrid, "hybrid"),
            (SearchMode::Semantic, "semantic"),
            (SearchMode::Keyword, "keyword"),
        ];
        for (mode, text) in cases {
            assert_eq!(mode.to_string(), text);
        }
    }

    /// Equal relevance: the chunk from the higher-ranked file ends up first.
    #[test]
    fn pagerank_boost_amplifies_relevant() {
        let chunks = [
            function_chunk("important.rs", "a"),
            function_chunk("obscure.rs", "b"),
        ];
        let mut results = vec![(0, 0.8_f32), (1, 0.8)];
        let pr = HashMap::from([
            ("important.rs".to_string(), 1.0),
            ("obscure.rs".to_string(), 0.1),
        ]);

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert_eq!(
            results[0].0, 0,
            "important.rs should rank first after boost"
        );
        assert!(results[0].1 > results[1].1);
        assert!(
            (results[0].1 - 1.032).abs() < 0.01,
            "rank=1.0 boost: expected ~1.032, got {}",
            results[0].1
        );
        assert!(
            (results[1].1 - 0.816).abs() < 0.01,
            "rank=0.1 boost: expected ~0.816, got {}",
            results[1].1
        );
    }

    /// The boost is multiplicative, so a zero score stays zero.
    #[test]
    fn pagerank_boost_zero_relevance_stays_zero() {
        let chunks = [function_chunk("important.rs", "a")];
        let mut results = vec![(0, 0.0_f32)];
        let pr = HashMap::from([("important.rs".to_string(), 1.0)]);

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert!(results[0].1.abs() < f32::EPSILON);
    }

    /// A file absent from the rank map leaves the score untouched.
    #[test]
    fn pagerank_boost_unknown_file_no_effect() {
        let chunks = [function_chunk("unknown.rs", "a")];
        let mut results = vec![(0, 0.5_f32)];
        let pr = HashMap::new();

        boost_with_pagerank(&mut results, &chunks, &pr, 0.3);

        assert!((results[0].1 - 0.5).abs() < f32::EPSILON);
    }
}