use std::time::Instant;
use crate::cache::rerank as rerank_cache;
use crate::inference::RerankClient;
use rusqlite::Connection;
use super::chunk_excerpt::{
CHUNK_INTENT_TERM_WEIGHT, CHUNK_QUERY_TERM_WEIGHT, focused_chunk_excerpt,
};
use super::fuse::blend_rerank_probabilities;
use super::hooks::SearchHooks;
use super::intent;
use super::types::RawSearchResult;
/// Inputs for [`rerank_candidates_with_intent`], the crate-internal rerank
/// entry point that always has a database connection and may carry an intent.
pub(crate) struct IntentRerankOptions<'a> {
    /// Connection used to pick each candidate's best-matching chunk before
    /// the texts are sent to the reranker.
    pub(crate) conn: &'a Connection,
    /// Client that scores the (intent-prefixed) query against candidate texts.
    pub(crate) rerank: &'a RerankClient,
    /// Receives the rerank start/end notifications.
    pub(crate) hooks: &'a SearchHooks,
    /// The search query as supplied by the caller.
    pub(crate) query: &'a str,
    /// Optional intent; it is combined with the query via
    /// `intent::prefix_query`, and its terms also steer chunk selection.
    pub(crate) intent: Option<&'a str>,
    /// Candidates in their incoming order; only the first `top_k` are
    /// reranked and returned.
    pub(crate) candidates: Vec<RawSearchResult>,
    /// Maximum number of candidates to rerank and return.
    pub(crate) top_k: u32,
    /// Database version used as part of the rerank score cache key.
    pub(crate) db_version: u64,
}
/// Test-only helper that maps a 0-based rank to a pair of weights, bucketed
/// as ranks 0-9, 10-19, and 20+.
///
/// NOTE(review): the pair presumably is (retrieval weight, reranker weight);
/// confirm against the tests that use it.
#[cfg(test)]
const fn position_weights(rank_index: usize) -> (f64, f64) {
    match rank_index {
        0..=9 => (0.75, 0.25),
        10..=19 => (0.60, 0.40),
        _ => (0.40, 0.60),
    }
}
/// Text handed to the reranker for one candidate: the title, a blank line,
/// then the snippet.
fn rerank_text(candidate: &RawSearchResult) -> String {
    let mut text = candidate.title.to_string();
    text.reserve(2 + candidate.snippet.len());
    text.push_str("\n\n");
    text.push_str(&candidate.snippet);
    text
}
/// Reranks `candidates` against `query` with no database connection and no
/// score cache.
///
/// Only the first `top_k` candidates are kept. Their scores come from the
/// rerank client and are merged via `blend_rerank_probabilities`. If the
/// rerank call fails, the truncated candidates are returned unchanged.
/// `hooks` receives a start and an end notification.
#[must_use]
pub fn rerank_candidates(
    rerank: &RerankClient,
    query: &str,
    candidates: Vec<RawSearchResult>,
    top_k: u32,
    hooks: &SearchHooks,
) -> Vec<RawSearchResult> {
    rerank_candidates_inner(RerankOptions {
        // No connection: keep the incoming snippets as-is.
        conn: None,
        intent: None,
        // No version: skip the score cache entirely.
        db_version: None,
        rerank,
        query,
        candidates,
        top_k,
        hooks,
    })
}
/// Like [`rerank_candidates`], but scores are read from and written to the
/// rerank cache under `db_version`.
///
/// Cached texts are not sent to the rerank client again. No database
/// connection is used, so candidate snippets are left untouched.
#[must_use]
pub fn rerank_candidates_with_db_version(
    rerank: &RerankClient,
    query: &str,
    candidates: Vec<RawSearchResult>,
    top_k: u32,
    hooks: &SearchHooks,
    db_version: u64,
) -> Vec<RawSearchResult> {
    rerank_candidates_inner(RerankOptions {
        conn: None,
        intent: None,
        db_version: Some(db_version),
        rerank,
        query,
        candidates,
        top_k,
        hooks,
    })
}
#[must_use]
pub(crate) fn rerank_candidates_with_intent(
options: IntentRerankOptions<'_>,
) -> Vec<RawSearchResult> {
rerank_candidates_inner(RerankOptions {
conn: Some(options.conn),
rerank: options.rerank,
query: options.query,
intent: options.intent,
candidates: options.candidates,
top_k: options.top_k,
hooks: options.hooks,
db_version: Some(options.db_version),
})
}
/// Normalized inputs shared by every rerank entry point.
struct RerankOptions<'a> {
    /// Client that scores the query against candidate texts.
    rerank: &'a RerankClient,
    /// Receives the rerank start/end notifications.
    hooks: &'a SearchHooks,
    /// When present, each kept candidate's snippet is swapped for its
    /// best-matching chunk excerpt before scoring; `None` skips that step.
    conn: Option<&'a Connection>,
    /// The search query as supplied by the caller.
    query: &'a str,
    /// Optional intent, combined with the query via `intent::prefix_query`.
    intent: Option<&'a str>,
    /// Candidates in incoming order; only the first `top_k` are kept.
    candidates: Vec<RawSearchResult>,
    /// Maximum number of candidates to rerank and return.
    top_k: u32,
    /// Cache key component; `None` disables cache lookups and stores.
    db_version: Option<u64>,
}
/// Shared implementation behind every rerank entry point.
///
/// Steps, in order:
/// 1. Return an empty input untouched; otherwise keep only the first
///    `top_k` candidates.
/// 2. With a connection, swap each kept candidate's snippet for its
///    best-matching chunk excerpt.
/// 3. Look up cached scores (only when `db_version` is set). Send every
///    uncached text to the rerank client in one call, clamping scores to
///    `[0, 1]` and caching them.
/// 4. Merge the scores into the candidates with `blend_rerank_probabilities`.
///
/// `hooks` gets one start event (with the kept count) and one end event
/// (with elapsed milliseconds). If the rerank call fails, the kept candidates
/// are returned without blending, and any cached scores are not applied.
fn rerank_candidates_inner(options: RerankOptions<'_>) -> Vec<RawSearchResult> {
    let RerankOptions {
        conn,
        rerank,
        query,
        intent,
        candidates,
        top_k,
        hooks,
        db_version,
    } = options;
    if candidates.is_empty() {
        return candidates;
    }

    let keep = candidates.len().min(top_k as usize);
    let mut active: Vec<RawSearchResult> = candidates.into_iter().take(keep).collect();
    if let Some(conn) = conn {
        select_best_chunks_for_rerank(conn, query, intent, &mut active);
    }

    let rerank_query = intent::prefix_query(intent, query);
    hooks.emit_rerank_start(active.len());
    let started = Instant::now();
    // Reports elapsed time to the hooks; shared by both exit paths below.
    let finish = || {
        let elapsed_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
        hooks.emit_rerank_end(elapsed_ms);
    };

    let texts: Vec<String> = active.iter().map(rerank_text).collect();
    // Cache lookups run in candidate order; `None` marks a text still to score.
    let mut scores: Vec<Option<f64>> = texts
        .iter()
        .map(|text| {
            db_version.and_then(|version| rerank_cache::lookup(text, &rerank_query, version))
        })
        .collect();
    let (missing_indices, missing_texts): (Vec<usize>, Vec<String>) = texts
        .iter()
        .zip(&scores)
        .enumerate()
        .filter(|(_, (_, score))| score.is_none())
        .map(|(index, (text, _))| (index, text.clone()))
        .unzip();

    if !missing_texts.is_empty() {
        let Ok(rerank_results) = rerank.rerank(&rerank_query, &missing_texts, false) else {
            finish();
            return active;
        };
        for result in rerank_results {
            // `result.index` addresses `missing_texts`; translate it back to
            // the candidate's position.
            let Some(&position) = missing_indices.get(result.index as usize) else {
                continue;
            };
            let score = f64::from(result.score).clamp(0.0, 1.0);
            if let Some(slot) = scores.get_mut(position) {
                *slot = Some(score);
            }
            if let (Some(text), Some(version)) = (texts.get(position), db_version) {
                rerank_cache::store(text, &rerank_query, score, version);
            }
        }
    }

    let blended = blend_rerank_probabilities(&active, &scores);
    finish();
    blended
}
fn select_best_chunks_for_rerank(
conn: &Connection,
query: &str,
intent: Option<&str>,
candidates: &mut [RawSearchResult],
) {
let query_terms = intent::extract_terms(query);
let intent_terms = intent.map(intent::extract_terms).unwrap_or_default();
if query_terms.is_empty() && intent_terms.is_empty() {
return;
}
for candidate in candidates {
let Some(chunk) =
best_chunk_for_candidate(conn, &candidate.path, &query_terms, &intent_terms)
else {
continue;
};
candidate.snippet = chunk.text;
candidate.semantic_heading = chunk.heading_path;
candidate.semantic_char_start = chunk.char_start;
candidate.semantic_char_end = chunk.char_end;
}
}
/// The chunk chosen to represent a candidate during reranking.
struct RerankChunk {
    /// Focused excerpt of the chunk text (not necessarily the whole chunk).
    text: String,
    /// The chunk's `heading_path` column, if any.
    heading_path: Option<String>,
    /// Start offset from `chunks.char_start`; `None` when NULL or not
    /// representable as `u32`.
    char_start: Option<u32>,
    /// End offset from `chunks.char_end`; `None` when NULL or not
    /// representable as `u32`.
    char_end: Option<u32>,
}
/// Picks the chunk of the note stored at `path` (matched on
/// `notes.vault_path`) that best matches the search terms.
///
/// Every chunk row is scored with [`chunk_term_score_units`]. The highest
/// score wins, and ties go to the lowest `chunk_index`. Only the winning row
/// is turned into a focused excerpt and has its offsets converted. Notes with
/// many chunks therefore do not pay for excerpts that would be discarded.
///
/// Returns `None` in three cases:
/// - the statement cannot be prepared;
/// - the statement cannot be executed;
/// - no chunk row decodes successfully.
///
/// Rows that fail to decode are skipped. Offsets that are NULL or do not fit
/// in `u32` become `None`.
fn best_chunk_for_candidate(
    conn: &Connection,
    path: &str,
    query_terms: &[String],
    intent_terms: &[String],
) -> Option<RerankChunk> {
    // Raw (text, heading_path, char_start, char_end) of a row, held until the
    // winner is known.
    type ChunkRow = (String, Option<String>, Option<i64>, Option<i64>);

    let mut stmt = conn
        .prepare_cached(
            "SELECT c.text, c.heading_path, c.char_start, c.char_end, c.chunk_index
FROM chunks c
JOIN notes n ON n.id = c.note_id
WHERE n.vault_path = ?1
ORDER BY c.chunk_index",
        )
        .ok()?;
    let rows = stmt
        .query_map([path], |row| {
            Ok((
                row.get::<_, String>(0)?,
                row.get::<_, Option<String>>(1)?,
                row.get::<_, Option<i64>>(2)?,
                row.get::<_, Option<i64>>(3)?,
                row.get::<_, i64>(4)?,
            ))
        })
        .ok()?;

    let mut best: Option<(u32, i64, ChunkRow)> = None;
    for row in rows {
        let Ok((text, heading_path, char_start, char_end, chunk_index)) = row else {
            continue;
        };
        let score = chunk_term_score_units(&text, query_terms, intent_terms);
        // Higher score wins; on equal scores the earlier chunk wins.
        let is_better = match &best {
            None => true,
            Some((best_score, best_index, _)) => {
                score > *best_score || (score == *best_score && chunk_index < *best_index)
            }
        };
        if is_better {
            best = Some((score, chunk_index, (text, heading_path, char_start, char_end)));
        }
    }

    let (_, _, (text, heading_path, char_start, char_end)) = best?;
    Some(RerankChunk {
        text: focused_chunk_excerpt(&text, query_terms, intent_terms),
        heading_path,
        char_start: char_start.and_then(|value| u32::try_from(value).ok()),
        char_end: char_end.and_then(|value| u32::try_from(value).ok()),
    })
}
/// Weighted count of distinct search terms found in `text`.
///
/// The text is NFD-normalized and lowercased, then each term is checked with
/// a substring match. Query-term hits are weighted by
/// `CHUNK_QUERY_TERM_WEIGHT` and intent-term hits by
/// `CHUNK_INTENT_TERM_WEIGHT`. All arithmetic saturates instead of
/// overflowing.
fn chunk_term_score_units(text: &str, query_terms: &[String], intent_terms: &[String]) -> u32 {
    let haystack = crate::text::nfd::normalize(text).to_lowercase();
    let hits = |terms: &[String]| -> u32 {
        let mut found = 0_u32;
        for term in terms {
            if haystack.contains(term.as_str()) {
                found = found.saturating_add(1);
            }
        }
        found
    };
    let weighted_query = hits(query_terms).saturating_mul(CHUNK_QUERY_TERM_WEIGHT);
    let weighted_intent = hits(intent_terms).saturating_mul(CHUNK_INTENT_TERM_WEIGHT);
    weighted_query.saturating_add(weighted_intent)
}
// Unit tests live in sibling files selected via `#[path]`. Both modules allow
// `unwrap`/`expect`, which is acceptable in tests where a panic is the
// desired failure signal.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[path = "rerank_pipeline_intent_tests.rs"]
mod intent_tests;
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[path = "rerank_pipeline_tests.rs"]
mod tests;