use crate::index::dense::DenseIndex;
use crate::index::sparse::Bm25Index;
use crate::ranking::{
apply_query_boost, boost_multi_chunk_files, is_symbol_query, rerank_topk, resolve_alpha,
};
use crate::tokenizer::tokenize;
use crate::types::{Chunk, SearchMode, SearchResult};
/// Additive constant in the reciprocal-rank term used by
/// `add_rrf_with_weight`: the entry at 1-based rank `r` contributes
/// `weight / (RRF_K + r)`.
const RRF_K: f64 = 60.0;
/// Factor by which `search_semantic` and `search_bm25` multiply `top_k`
/// (saturating) to obtain the candidate count requested from their index.
/// Presumably this gives the boost/rerank stages more than `top_k`
/// candidates to reorder — confirm against the ranking module.
const PURE_MODE_CANDIDATE_OVERSHOOT: usize = 5;
/// Reports whether a search can be answered with an empty result without
/// consulting any index.
///
/// Returns `true` when `top_k` is zero, when `chunks` is empty, or when
/// `query` consists solely of whitespace (including the empty string).
#[inline]
fn should_skip_search(query: &str, top_k: usize, chunks: &[Chunk]) -> bool {
    let nothing_requested = top_k == 0;
    let nothing_indexed = chunks.is_empty();
    // Equivalent to `query.trim().is_empty()`: trimming strips exactly the
    // characters for which `char::is_whitespace` holds.
    let blank_query = query.chars().all(char::is_whitespace);

    nothing_indexed || nothing_requested || blank_query
}
/// Accumulates weighted reciprocal-rank contributions into `out`.
///
/// The entry at 1-based position `r` of `ranked` adds
/// `weight / (RRF_K + r)` to `out[idx]`, where `idx` is the first element of
/// that entry. Only the ordering of `ranked` matters: the score stored next
/// to each index is ignored. Entries whose index falls outside `out` are
/// skipped, and repeated indices accumulate.
fn add_rrf_with_weight(out: &mut [f64], ranked: &[(usize, f64)], weight: f64) {
    for (&(idx, _), position) in ranked.iter().zip(1usize..) {
        if let Some(slot) = out.get_mut(idx) {
            *slot += weight / (RRF_K + position as f64);
        }
    }
}
/// Dense-only search: scores chunks by the similarity values that
/// `dense_index` reports for the encoded `query`.
///
/// Returns an empty vector when `should_skip_search` says there is nothing
/// to do. Otherwise the query is encoded with `model`, `dense_index` is
/// queried for `top_k * PURE_MODE_CANDIDATE_OVERSHOOT` candidates
/// (saturating), the returned similarities are scattered into a per-chunk
/// score vector (chunks that were not returned keep a score of `0.0`), and
/// `finalize_pure_mode` produces the final results tagged
/// `SearchMode::Semantic`.
///
/// `selector` is forwarded unchanged to `dense_index.query`; presumably it
/// restricts the search to a subset of chunk indices — confirm against
/// `DenseIndex`.
pub fn search_semantic(
    query: &str,
    model: &model2vec_rs::model::StaticModel,
    dense_index: &DenseIndex,
    chunks: &[Chunk],
    top_k: usize,
    selector: Option<&[usize]>,
) -> Vec<SearchResult> {
    let _span = tracing::trace_span!("search.semantic", top_k, n_chunks = chunks.len()).entered();
    if should_skip_search(query, top_k, chunks) {
        return Vec::new();
    }

    let encoded = {
        let _s = tracing::trace_span!("search.encode_query").entered();
        model.encode(&[query.to_string()])
    };
    // Exactly one input was encoded, so the first embedding is the query's.
    let query_vec = &encoded[0];

    let candidate_count = top_k.saturating_mul(PURE_MODE_CANDIDATE_OVERSHOOT);
    let (hit_ids, hit_sims) = {
        let _s = tracing::trace_span!("search.dense_query", candidate_count).entered();
        dense_index.query(query_vec, candidate_count, selector)
    };

    // Scatter the sparse hit list into a dense per-chunk score vector,
    // ignoring any index that does not address a chunk.
    let mut scores = vec![0.0f64; chunks.len()];
    hit_ids.iter().zip(hit_sims.iter()).for_each(|(idx, sim)| {
        if let Some(slot) = scores.get_mut(*idx) {
            *slot = *sim as f64;
        }
    });

    finalize_pure_mode(scores, query, chunks, top_k, SearchMode::Semantic)
}
/// Lexical-only search: scores chunks by the BM25 scores that `bm25_index`
/// reports for the query's tokens.
///
/// Returns an empty vector when `should_skip_search` says there is nothing
/// to do, or when `bm25_query_tokens` yields no tokens. Otherwise
/// `bm25_index` is asked for `top_k * PURE_MODE_CANDIDATE_OVERSHOOT`
/// candidates (saturating), their scores are scattered into a per-chunk
/// score vector (chunks that were not returned keep a score of `0.0`), and
/// `finalize_pure_mode` produces the final results tagged
/// `SearchMode::Bm25`.
///
/// `selector` is forwarded unchanged to `bm25_index.top_k`; presumably it
/// restricts the search to a subset of chunk indices — confirm against
/// `Bm25Index`.
pub fn search_bm25(
    query: &str,
    bm25_index: &Bm25Index,
    chunks: &[Chunk],
    top_k: usize,
    selector: Option<&[usize]>,
) -> Vec<SearchResult> {
    let _span = tracing::trace_span!("search.bm25", top_k, n_chunks = chunks.len()).entered();
    if should_skip_search(query, top_k, chunks) {
        return Vec::new();
    }

    let terms = bm25_query_tokens(query);
    if terms.is_empty() {
        return Vec::new();
    }

    let candidate_count = top_k.saturating_mul(PURE_MODE_CANDIDATE_OVERSHOOT);
    let hits = {
        let _s = tracing::trace_span!("search.bm25_topk", n_tokens = terms.len()).entered();
        bm25_index.top_k(&terms, candidate_count, selector)
    };

    // Scatter the sparse hit list into a dense per-chunk score vector,
    // ignoring any index that does not address a chunk.
    let mut scores = vec![0.0f64; chunks.len()];
    for &(idx, score) in &hits {
        if let Some(slot) = scores.get_mut(idx) {
            *slot = score;
        }
    }

    finalize_pure_mode(scores, query, chunks, top_k, SearchMode::Bm25)
}
/// Produces the token list used to query the BM25 index in pure BM25 mode.
///
/// When `is_symbol_query` accepts the query, the whole query — trimmed and
/// lowercased — is used as a single token (or no tokens at all if nothing
/// is left after trimming). Every other query goes through `tokenize`.
fn bm25_query_tokens(query: &str) -> Vec<String> {
    if !is_symbol_query(query) {
        return tokenize(query);
    }

    let symbol = query.trim().to_lowercase();
    match symbol.is_empty() {
        true => Vec::new(),
        false => vec![symbol],
    }
}
/// Shared tail of the single-source search modes: boosts, reranks and
/// materialises results from a per-chunk score vector.
///
/// `scores` is expected to hold one entry per element of `chunks`. It is
/// adjusted in place by `boost_multi_chunk_files` and then by
/// `apply_query_boost`; `rerank_topk` (called with its final flag set to
/// `true`) then yields the `(chunk index, score)` pairs to return. Each pair
/// becomes a `SearchResult` holding a clone of the chunk, the score, and
/// `source` as its origin.
///
/// Panics if `rerank_topk` yields an index outside `chunks`.
fn finalize_pure_mode(
    mut scores: Vec<f64>,
    query: &str,
    chunks: &[Chunk],
    top_k: usize,
    source: SearchMode,
) -> Vec<SearchResult> {
    {
        let _s = tracing::trace_span!("search.boost_multi_chunk").entered();
        boost_multi_chunk_files(&mut scores, chunks);
    }
    {
        let _s = tracing::trace_span!("search.apply_query_boost").entered();
        apply_query_boost(&mut scores, query, chunks);
    }

    let winners = {
        let _s = tracing::trace_span!("search.rerank_topk").entered();
        rerank_topk(&scores, chunks, top_k, true)
    };

    let mut results = Vec::new();
    for (idx, score) in winners {
        results.push(SearchResult {
            chunk: chunks[idx].clone(),
            score,
            source,
        });
    }
    results
}
/// Hybrid search: fuses the dense (embedding) ranking and the BM25 ranking
/// of `query` with weighted reciprocal-rank fusion.
///
/// Returns an empty vector when `should_skip_search` says there is nothing
/// to do. Otherwise:
///
/// 1. `resolve_alpha(query, alpha)` picks the weight given to the dense
///    ranking; the BM25 ranking receives `1.0 - alpha_weight`.
/// 2. Both indexes are asked for `top_k * 5` candidates (saturating, so an
///    enormous `top_k` can neither overflow nor wrap to a small count).
///    BM25 is skipped when `tokenize(query)` yields no tokens.
/// 3. Each ranking contributes `weight / (RRF_K + rank)` per candidate to a
///    per-chunk fused score (see `add_rrf_with_weight`).
/// 4. `boost_multi_chunk_files` and `apply_query_boost` adjust the fused
///    scores in place, and `rerank_topk` yields the final
///    `(chunk index, score)` pairs.
///
/// Every result is tagged `SearchMode::Hybrid`. `selector` is forwarded
/// unchanged to both indexes; presumably it restricts the search to a subset
/// of chunk indices — confirm against the index implementations.
///
/// Panics if `rerank_topk` yields an index outside `chunks`.
// The argument list mirrors the single-source entry points plus the second
// index and `alpha`; bundling them into a struct would change the public API.
#[allow(clippy::too_many_arguments)]
pub fn search_hybrid(
    query: &str,
    model: &model2vec_rs::model::StaticModel,
    dense_index: &DenseIndex,
    bm25_index: &Bm25Index,
    chunks: &[Chunk],
    top_k: usize,
    alpha: Option<f64>,
    selector: Option<&[usize]>,
) -> Vec<SearchResult> {
    // Number of candidates fetched from each index per requested result.
    const HYBRID_CANDIDATE_OVERSHOOT: usize = 5;

    let _span = tracing::trace_span!("search.hybrid", top_k, n_chunks = chunks.len()).entered();
    if should_skip_search(query, top_k, chunks) {
        return Vec::new();
    }
    let alpha_weight = resolve_alpha(query, alpha);
    // Saturating: a plain `top_k * 5` panics in debug builds and wraps in
    // release builds when `top_k` is very large, which could silently shrink
    // the candidate pool.
    let candidate_count = top_k.saturating_mul(HYBRID_CANDIDATE_OVERSHOOT);
    let n = chunks.len();
    let query_emb = {
        let _s = tracing::trace_span!("search.encode_query").entered();
        model.encode(&[query.to_string()])
    };
    let (sem_idx, sem_sim) = {
        let _s = tracing::trace_span!("search.dense_query", candidate_count).entered();
        dense_index.query(&query_emb[0], candidate_count, selector)
    };
    // Pair each dense hit with its similarity, widened to f64 to match the
    // BM25 hit list's element type.
    let sem_topk: Vec<(usize, f64)> = sem_idx
        .into_iter()
        .zip(sem_sim)
        .map(|(i, s)| (i, s as f64))
        .collect();
    let tokens = tokenize(query);
    let bm25_topk = if tokens.is_empty() {
        Vec::new()
    } else {
        let _s = tracing::trace_span!("search.bm25_topk", n_tokens = tokens.len()).entered();
        bm25_index.top_k(&tokens, candidate_count, selector)
    };
    // Fuse by rank only: the raw similarity / BM25 scores are on different
    // scales and are ignored by `add_rrf_with_weight`.
    let mut combined: Vec<f64> = vec![0.0f64; n];
    add_rrf_with_weight(&mut combined, &sem_topk, alpha_weight);
    if !bm25_topk.is_empty() {
        add_rrf_with_weight(&mut combined, &bm25_topk, 1.0 - alpha_weight);
    }
    {
        let _s = tracing::trace_span!("search.boost_multi_chunk").entered();
        boost_multi_chunk_files(&mut combined, chunks);
    }
    {
        let _s = tracing::trace_span!("search.apply_query_boost").entered();
        apply_query_boost(&mut combined, query, chunks);
    }
    let ranked = {
        let _s = tracing::trace_span!("search.rerank_topk").entered();
        // The final flag is on whenever BM25 carries any weight; the pure
        // modes always pass `true`. NOTE(review): confirm the flag's meaning
        // against `rerank_topk`.
        rerank_topk(&combined, chunks, top_k, alpha_weight < 1.0)
    };
    ranked
        .into_iter()
        .map(|(idx, score)| SearchResult {
            chunk: chunks[idx].clone(),
            score,
            source: SearchMode::Hybrid,
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal fixture: a single one-line Rust chunk.
    fn dummy_chunk() -> Chunk {
        Chunk {
            content: String::from("fn foo() {}"),
            file_path: String::from("test.rs"),
            start_line: 1,
            end_line: 1,
            language: Some(String::from("rust")),
        }
    }

    /// Earlier ranks must score higher, and untouched slots must stay zero.
    #[test]
    fn add_rrf_with_weight_basic() {
        let mut fused = [0.0f64; 4];
        let order = [(2, 10.0), (0, 5.0)];

        add_rrf_with_weight(&mut fused, &order, 1.0);

        assert!(fused[2] > fused[0]);
        assert_eq!(fused[1], 0.0);
        assert_eq!(fused[3], 0.0);
    }

    /// An index present in both rankings must beat indices present in one.
    #[test]
    fn add_rrf_with_weight_scales_and_accumulates() {
        let mut fused = [0.0f64; 3];
        let dense_order = [(0, 1.0), (1, 0.5)];
        let lexical_order = [(1, 1.0), (2, 0.5)];

        add_rrf_with_weight(&mut fused, &dense_order, 0.5);
        add_rrf_with_weight(&mut fused, &lexical_order, 0.5);

        assert!(fused[0] > 0.0);
        assert!(fused[2] > 0.0);
        assert!(fused[1] > fused[0]);
        assert!(fused[1] > fused[2]);
    }

    #[test]
    fn skip_when_chunks_empty() {
        let no_chunks: Vec<Chunk> = Vec::new();
        assert!(should_skip_search("anything", 5, &no_chunks));
    }

    #[test]
    fn skip_when_top_k_zero() {
        let corpus = vec![dummy_chunk()];
        assert!(should_skip_search("anything", 0, &corpus));
    }

    #[test]
    fn skip_when_query_blank() {
        let corpus = vec![dummy_chunk()];
        for blank in ["", " ", "\t\n"] {
            assert!(should_skip_search(blank, 5, &corpus));
        }
    }

    #[test]
    fn proceed_with_real_inputs() {
        let corpus = vec![dummy_chunk()];
        assert!(!should_skip_search("hello", 5, &corpus));
    }
}