use std::cmp::Ordering;
use std::collections::HashMap;
use rusqlite::{params, Connection};
use crate::db::LexaDb;
use crate::embed::{matryoshka_truncate, vector_blob, PREVIEW_DIMS};
use crate::query::{fts_query, tokenize};
use crate::types::{LexaError, SearchHit, SearchTier, TierBreakdown};
use crate::Result;
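// Retrieval tuning constants: RRF_K dampens reciprocal-rank-fusion scores,
// the *_TOP_K values bound per-source candidate lists, RERANK_* control the
// deep-tier rerank pass, and the *_CHARS values cap excerpt/highlight length.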
const RRF_K: f32 = 60.0;
const SPARSE_TOP_K: usize = 50;
const DENSE_TOP_K: usize = 50;
const PREVIEW_TOP_K: usize = DENSE_TOP_K * 8;
const RERANK_CANDIDATES: usize = 15;
const RERANK_BLEND: f32 = 0.7;
const EXCERPT_MAX_CHARS: usize = 500;
const HIGHLIGHT_TARGET_CHARS: usize = 220;
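/// Options for a single search request. `tier: Auto` defers to
/// `classify_query`, which routes based on the shape of the query text.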
#[derive(Debug, Clone)]
pub struct SearchOptions {
pub query: String,
pub tier: SearchTier,
pub limit: usize,
pub additional_queries: Vec<String>,
}
impl SearchOptions {
pub fn new(query: impl Into<String>) -> Self {
Self {
query: query.into(),
tier: SearchTier::Auto,
limit: 10,
additional_queries: Vec::new(),
}
}
}
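/// Runs a search at the requested tier: Instant is BM25-only, Dense is
/// vector-only, Fast fuses both with reciprocal rank fusion, and Deep adds
/// optional expansion queries plus a rerank pass over the top candidates.
/// `Auto` is resolved to a concrete tier first and recorded in the breakdown.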
pub fn search_impl(db: &LexaDb, options: &SearchOptions) -> Result<Vec<SearchHit>> {
let conn = db.conn();
let limit = options.limit.max(1);
let (effective_tier, routed_to) = if options.tier == SearchTier::Auto {
let routed = classify_query(&options.query);
(routed, Some(routed))
} else {
(options.tier, None)
};
let mut hits = match effective_tier {
SearchTier::Auto => unreachable!("Auto resolves to a concrete tier above"),
SearchTier::Instant => {
let bm25 = bm25_search(conn, &options.query, SPARSE_TOP_K)?;
hydrate(conn, &options.query, &rank_to_rrf(&bm25), &bm25, &[], limit)?.0
}
SearchTier::Dense => {
let vector = vector_search(db, &options.query, DENSE_TOP_K)?;
hydrate(
conn,
&options.query,
&rank_to_rrf(&vector),
&[],
&vector,
limit,
)?
.0
}
SearchTier::Fast | SearchTier::Deep => {
let embedder_lock = db.embedder()?;
let query_str = options.query.as_str();
let (bm25, embedding) = std::thread::scope(|scope| -> Result<_> {
let embed_handle = scope.spawn(|| -> Result<Vec<f32>> {
let mut guard = embedder_lock
.lock()
.map_err(|err| LexaError::Embedding(err.to_string()))?;
guard.embed_query(query_str)
});
let bm25 = bm25_search(conn, query_str, SPARSE_TOP_K)?;
let embedding = embed_handle
.join()
.map_err(|_| LexaError::Embedding("embed worker panicked".into()))??;
Ok((bm25, embedding))
})?;
let vector = vector_knn(conn, &embedding, DENSE_TOP_K)?;
let fused =
if effective_tier == SearchTier::Deep && !options.additional_queries.is_empty() {
let mut all_lists: Vec<Vec<(i64, f32)>> =
Vec::with_capacity(2 + options.additional_queries.len() * 2);
all_lists.push(bm25.clone());
all_lists.push(vector.clone());
for extra in &options.additional_queries {
let extra_str = extra.as_str();
let (extra_bm25, extra_emb) = std::thread::scope(|scope| -> Result<_> {
let h = scope.spawn(|| -> Result<Vec<f32>> {
let mut guard = embedder_lock
.lock()
.map_err(|err| LexaError::Embedding(err.to_string()))?;
guard.embed_query(extra_str)
});
let b = bm25_search(conn, extra_str, SPARSE_TOP_K)?;
let e = h.join().map_err(|_| {
LexaError::Embedding("embed worker panicked".into())
})??;
Ok((b, e))
})?;
let extra_vec = vector_knn(conn, &extra_emb, DENSE_TOP_K)?;
all_lists.push(extra_bm25);
all_lists.push(extra_vec);
}
let refs: Vec<&[(i64, f32)]> = all_lists.iter().map(Vec::as_slice).collect();
fuse_many(&refs)
} else {
fuse(&bm25, &vector)
};
let candidate_count = if effective_tier == SearchTier::Deep {
RERANK_CANDIDATES
} else {
limit
};
let (mut hits, full_texts) = hydrate(
conn,
&options.query,
&fused,
&bm25,
&vector,
candidate_count,
)?;
if effective_tier == SearchTier::Deep && !hits.is_empty() {
rerank(db, &options.query, &mut hits, &full_texts)?;
}
hits.truncate(limit);
hits
}
};
if let Some(tier) = routed_to {
for hit in &mut hits {
hit.breakdown.routed_to = Some(tier);
}
}
Ok(hits)
}
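/// Heuristic router for `SearchTier::Auto`: an explicit "[deep]" prefix or a
/// long question goes to Deep, a single identifier-like token (snake_case,
/// mixedCase, or a path such as `a::b` / `a.b`) goes to Instant, and
/// everything else defaults to Fast.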
fn classify_query(query: &str) -> SearchTier {
let trimmed = query.trim();
    if trimmed.starts_with("[deep]") {
        return SearchTier::Deep;
    }
let tokens: Vec<&str> = trimmed
.split_whitespace()
.filter(|tok| tok.chars().any(char::is_alphanumeric))
.collect();
if tokens.is_empty() {
return SearchTier::Fast;
}
if tokens.len() == 1 {
let tok = tokens[0];
let snake_case = tok.contains('_') && tok.chars().any(|c| c.is_ascii_alphanumeric());
let mixed_case = tok.chars().any(|c| c.is_ascii_uppercase())
&& tok.chars().any(|c| c.is_ascii_lowercase());
let path_like = tok.contains("::") || (tok.contains('.') && !tok.ends_with('.'));
if snake_case || mixed_case || path_like {
return SearchTier::Instant;
}
}
if tokens.len() >= 6 && trimmed.ends_with('?') {
return SearchTier::Deep;
}
SearchTier::Fast
}
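/// Two-stage approximate KNN over binary-quantized vectors: a coarse pass on
/// the PREVIEW_DIMS-truncated preview index fetches PREVIEW_TOP_K candidates,
/// which are then re-scored by Hamming distance against the full-width binary
/// vectors. Distances are mapped to similarities via 1 / (1 + distance).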
fn vector_knn(conn: &Connection, embedding: &[f32], limit: usize) -> Result<Vec<(i64, f32)>> {
let preview_blob = vector_blob(&matryoshka_truncate(embedding, PREVIEW_DIMS));
let mut preview_stmt = conn.prepare_cached(
"SELECT rowid
FROM vectors_bin_preview
WHERE embedding MATCH vec_quantize_binary(?1) AND k = ?2
ORDER BY distance",
)?;
let preview_ids: Vec<i64> = preview_stmt
.query_map(params![preview_blob, PREVIEW_TOP_K as i64], |row| {
row.get::<_, i64>(0)
})?
.collect::<std::result::Result<Vec<_>, _>>()?;
if preview_ids.is_empty() {
return Ok(Vec::new());
}
let full_blob = vector_blob(embedding);
let preview_ids_json = serde_json::to_string(&preview_ids)?;
let mut rescore_stmt = conn.prepare_cached(
"SELECT v.rowid,
vec_distance_hamming(v.embedding, vec_quantize_binary(?1)) AS distance
FROM vectors_bin AS v
WHERE v.rowid IN (SELECT value FROM json_each(?2))
ORDER BY distance
LIMIT ?3",
)?;
let rows =
rescore_stmt.query_map(params![full_blob, preview_ids_json, limit as i64], |row| {
let id: i64 = row.get(0)?;
let distance: f64 = row.get(1)?;
Ok((id, (1.0 / (1.0 + distance)) as f32))
})?;
rows.collect::<std::result::Result<Vec<_>, _>>()
.map_err(Into::into)
}
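/// FTS5 lexical search. SQLite's bm25() rank is lower for better matches, so
/// it is folded into a positive similarity via 1 / (1 + |rank|).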
fn bm25_search(conn: &Connection, query: &str, limit: usize) -> Result<Vec<(i64, f32)>> {
let fts_query = fts_query(query);
if fts_query.is_empty() {
return Ok(Vec::new());
}
let mut stmt = conn.prepare_cached(
"SELECT rowid, bm25(chunks_fts) AS rank
FROM chunks_fts
WHERE chunks_fts MATCH ?1
ORDER BY rank
LIMIT ?2",
)?;
let rows = stmt.query_map(params![fts_query, limit as i64], |row| {
let id: i64 = row.get(0)?;
let rank: f64 = row.get(1)?;
Ok((id, (1.0 / (1.0 + rank.abs())) as f32))
})?;
rows.collect::<std::result::Result<Vec<_>, _>>()
.map_err(Into::into)
}
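/// Embeds the query under the embedder lock, then runs the binary KNN.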
fn vector_search(db: &LexaDb, query: &str, limit: usize) -> Result<Vec<(i64, f32)>> {
let embedding = {
let lock = db.embedder()?;
let mut guard = lock
.lock()
.map_err(|err| LexaError::Embedding(err.to_string()))?;
guard.embed_query(query)?
};
vector_knn(db.conn(), &embedding, limit)
}
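/// Reciprocal rank fusion: each candidate accumulates 1 / (RRF_K + rank + 1)
/// for every list it appears in, and the totals are sorted descending.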
fn fuse_many(lists: &[&[(i64, f32)]]) -> Vec<(i64, f32)> {
let mut scores = HashMap::<i64, f32>::new();
for list in lists {
for (rank, (id, _)) in list.iter().enumerate() {
*scores.entry(*id).or_default() += 1.0 / (RRF_K + rank as f32 + 1.0);
}
}
let mut fused: Vec<_> = scores.into_iter().collect();
fused.sort_by(score_desc);
fused
}
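/// Convenience wrapper fusing the usual BM25 + vector pair.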
fn fuse(bm25: &[(i64, f32)], vector: &[(i64, f32)]) -> Vec<(i64, f32)> {
fuse_many(&[bm25, vector])
}
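/// Converts a single ranked list into RRF-style scores so single-source tiers
/// report scores on the same scale as the fused tiers.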
fn rank_to_rrf(items: &[(i64, f32)]) -> Vec<(i64, f32)> {
items
.iter()
.enumerate()
.map(|(rank, (id, _))| (*id, 1.0 / (RRF_K + rank as f32 + 1.0)))
.collect()
}
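/// Loads path, line range, text, and heading for the top `limit` candidates,
/// building each `SearchHit` with a query-aware excerpt and a per-source
/// `TierBreakdown`. The full chunk texts are returned alongside the hits so
/// the deep tier can feed them to the reranker.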
fn hydrate(
conn: &Connection,
query: &str,
ranked: &[(i64, f32)],
bm25: &[(i64, f32)],
vector: &[(i64, f32)],
limit: usize,
) -> Result<(Vec<SearchHit>, Vec<String>)> {
let bm25_rank = ranks(bm25);
let vector_rank = ranks(vector);
let bm25_scores = score_map(bm25);
let vector_scores = score_map(vector);
let mut hits = Vec::new();
let mut full_texts = Vec::new();
let mut stmt = conn.prepare_cached(
"SELECT d.path, c.line_start, c.line_end, c.text, c.context
FROM chunks c JOIN documents d ON d.id = c.doc_id
WHERE c.id = ?1",
)?;
for (id, score) in ranked.iter().take(limit) {
let (hit, text) = stmt.query_row(params![id], |row| {
let text: String = row.get(3)?;
let heading: Option<String> = row.get(4)?;
let hit = SearchHit {
path: row.get(0)?,
line_start: row.get(1)?,
line_end: row.get(2)?,
score: *score,
excerpt: highlight(query, &text),
heading,
breakdown: TierBreakdown {
bm25_rank: bm25_rank.get(id).copied(),
vector_rank: vector_rank.get(id).copied(),
bm25_score: bm25_scores.get(id).copied().unwrap_or_default(),
vector_score: vector_scores.get(id).copied().unwrap_or_default(),
rerank_score: None,
routed_to: None,
},
};
Ok((hit, text))
})?;
hits.push(hit);
full_texts.push(text);
}
Ok((hits, full_texts))
}
fn sigmoid(x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
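/// Reranks the candidate hits: each raw reranker score is squashed through a
/// sigmoid and blended with the existing RRF score using RERANK_BLEND, then
/// the hits are re-sorted by the blended score.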
fn rerank(db: &LexaDb, query: &str, hits: &mut [SearchHit], full_texts: &[String]) -> Result<()> {
let docs: Vec<String> = full_texts.to_vec();
let scores = {
let lock = db.reranker()?;
let mut guard = lock
.lock()
.map_err(|err| LexaError::Embedding(err.to_string()))?;
guard.rerank(query, &docs)?
};
for (idx, raw_score) in scores {
if let Some(hit) = hits.get_mut(idx) {
let rrf = hit.score;
let squashed = sigmoid(raw_score);
hit.score = RERANK_BLEND * squashed + (1.0 - RERANK_BLEND) * rrf;
hit.breakdown.rerank_score = Some(raw_score);
}
}
hits.sort_by(|left, right| {
right
.score
.partial_cmp(&left.score)
.unwrap_or(Ordering::Equal)
});
Ok(())
}
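/// Maps each chunk id to its 1-based rank within a result list.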
fn ranks(items: &[(i64, f32)]) -> HashMap<i64, usize> {
items
.iter()
.enumerate()
.map(|(idx, (id, _))| (*id, idx + 1))
.collect()
}
fn score_map(items: &[(i64, f32)]) -> HashMap<i64, f32> {
items.iter().copied().collect()
}
fn score_desc(left: &(i64, f32), right: &(i64, f32)) -> Ordering {
right.1.partial_cmp(&left.1).unwrap_or(Ordering::Equal)
}
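/// Builds an excerpt centered on the sentence with the highest token overlap
/// with the query, grows it toward HIGHLIGHT_TARGET_CHARS, and caps it at 1.5x
/// the target. Falls back to a plain excerpt when nothing overlaps.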
fn highlight(query: &str, text: &str) -> String {
let query_tokens: std::collections::HashSet<String> = tokenize(query).collect();
if query_tokens.is_empty() {
return excerpt(text);
}
let compact = text.split_whitespace().collect::<Vec<_>>().join(" ");
if compact.is_empty() {
return String::new();
}
let sentences = split_sentences(&compact);
if sentences.is_empty() {
return excerpt(&compact);
}
let scores: Vec<(usize, usize)> = sentences
.iter()
.enumerate()
.map(|(idx, sentence)| {
let tokens: std::collections::HashSet<String> = tokenize(sentence).collect();
let overlap = query_tokens.intersection(&tokens).count();
(idx, overlap)
})
.collect();
let best = scores.iter().max_by_key(|(_, score)| *score).copied();
let Some((best_idx, best_score)) = best else {
return excerpt(&compact);
};
if best_score == 0 {
return excerpt(&compact);
}
let mut start = best_idx;
let mut end = best_idx;
let mut span_len = sentences[best_idx].len();
    // Grow the span outward from the best-matching sentence until it reaches
    // the target length, preferring whichever side keeps the span balanced.
    while span_len < HIGHLIGHT_TARGET_CHARS {
        let grew = if start > 0
            && (end + 1 == sentences.len() || best_idx - start <= end + 1 - best_idx)
        {
start -= 1;
span_len += sentences[start].len() + 1;
true
} else if end + 1 < sentences.len() {
end += 1;
span_len += sentences[end].len() + 1;
true
} else {
false
};
if !grew {
break;
}
}
let span: String = sentences[start..=end].join(" ");
let cap = HIGHLIGHT_TARGET_CHARS * 3 / 2;
if span.len() <= cap {
span
} else {
let mut cut = cap;
while cut > 0 && !span.is_char_boundary(cut) {
cut -= 1;
}
format!("{}...", &span[..cut])
}
}
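/// Splits text on '.', '!', '?', ';', or '\n' when the terminator is followed
/// by whitespace, trimming each segment and dropping empty ones.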
fn split_sentences(text: &str) -> Vec<&str> {
let bytes = text.as_bytes();
let mut starts = vec![0];
let mut i = 0;
while i < bytes.len() {
let b = bytes[i];
if matches!(b, b'.' | b'!' | b'?' | b';' | b'\n')
&& i + 1 < bytes.len()
&& (bytes[i + 1] == b' ' || bytes[i + 1] == b'\n' || bytes[i + 1] == b'\t')
{
let mut j = i + 1;
while j < bytes.len() && (bytes[j] == b' ' || bytes[j] == b'\n' || bytes[j] == b'\t') {
j += 1;
}
if j < bytes.len() && text.is_char_boundary(j) {
starts.push(j);
}
i = j;
continue;
}
i += 1;
}
starts.push(text.len());
starts
.windows(2)
.map(|w| text[w[0]..w[1]].trim())
.filter(|s| !s.is_empty())
.collect()
}
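/// Whitespace-normalizes the text and truncates it to EXCERPT_MAX_CHARS on a
/// char boundary, appending an ellipsis when truncated.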
fn excerpt(text: &str) -> String {
let compact = text.split_whitespace().collect::<Vec<_>>().join(" ");
if compact.len() <= EXCERPT_MAX_CHARS {
return compact;
}
let mut end = EXCERPT_MAX_CHARS;
while end > 0 && !compact.is_char_boundary(end) {
end -= 1;
}
format!("{}...", &compact[..end])
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rrf_boosts_overlap() {
let bm25 = vec![(1, 1.0), (2, 0.8)];
let vector = vec![(3, 1.0), (1, 0.7)];
let fused = fuse(&bm25, &vector);
assert_eq!(fused[0].0, 1);
}
#[test]
fn classify_routes_single_identifier_to_instant() {
assert_eq!(classify_query("vec_quantize_binary"), SearchTier::Instant);
assert_eq!(classify_query("LexaDb::open"), SearchTier::Instant);
assert_eq!(classify_query("Embedder::embed_query"), SearchTier::Instant);
}
#[test]
fn classify_keeps_natural_language_with_identifiers_on_fast() {
assert_eq!(
classify_query("matryoshka_truncate helper that re-normalizes"),
SearchTier::Fast
);
assert_eq!(
classify_query("the BGE cross encoder reranker"),
SearchTier::Fast
);
}
#[test]
fn classify_routes_explicit_deep_prefix() {
assert_eq!(
classify_query("[deep] explain the rerank pipeline"),
SearchTier::Deep
);
}
#[test]
fn classify_routes_long_questions_to_deep() {
assert_eq!(
classify_query("how does the reranker score truncated excerpts in deep tier?"),
SearchTier::Deep
);
}
#[test]
fn classify_defaults_to_fast() {
assert_eq!(
classify_query("hybrid lexical dense retrieval"),
SearchTier::Fast
);
assert_eq!(
classify_query("binary quantized vector search"),
SearchTier::Fast
);
}
#[test]
fn highlight_picks_query_relevant_sentence() {
let filler: String = "alpha beta gamma delta. ".repeat(20);
let text = format!(
"{filler}\
The reranker scores candidates by cross encoder logits. \
{filler}"
);
let span = highlight("reranker cross encoder logits", &text);
assert!(span.contains("reranker"));
assert!(span.contains("cross encoder"));
assert!(span.len() <= HIGHLIGHT_TARGET_CHARS * 3 / 2 + 4);
}
#[test]
fn highlight_falls_back_when_no_overlap() {
let text = "Some prose without any of the query's words.";
let span = highlight("matryoshka quantization", text);
assert_eq!(span, excerpt(text));
}
#[test]
fn highlight_caps_at_soft_target() {
let prefix: String = "ipsum dolor sit amet. ".repeat(50);
let suffix: String = "vivamus sed lacus. ".repeat(50);
let text = format!("{}TARGET token here. {}", prefix, suffix);
let span = highlight("target token", &text);
        assert!(span.len() <= HIGHLIGHT_TARGET_CHARS * 3 / 2 + 4);
}
#[test]
fn hash_fallback_can_score() {
let query = crate::embed::hash_embedding("config validation");
let doc = crate::embed::hash_embedding("configuration validation function");
assert!(crate::embed::cosine(&query, &doc) > -1.0);
}
}