use anyhow::Result;
use rusqlite::ToSql;
use serde::Serialize;
use std::fmt::Write as _;
use crate::embed::{
DEFAULT_EMBED_MODEL, EmbedRole, EmbeddingBackend, OllamaClient, VEC_MIRROR_DIM, blob_to_f32s,
cosine_similarity, embedding_stats, f32s_to_blob, prepare_embedding_text,
};
use crate::inspect::now_unix;
use crate::store::{Store, VEC_MIRROR_TABLE};
/// A single search result: one chunk plus its source metadata, ranking
/// score, and usage-derived confidence.
///
/// `score` semantics depend on the search path that produced the hit:
/// bm25 rank for keyword search (lower is better), cosine similarity for
/// semantic search, `1.0 - distance` for the vec mirror, and an RRF sum
/// for hybrid blending.
#[derive(Debug, Clone, Serialize)]
pub struct SearchHit {
    pub chunk_id: String,
    pub source_id: String,
    pub uri: String,
    // Filesystem path when the source has one; URI is the fallback display.
    pub path: Option<String>,
    pub kind: String,
    // Position of the chunk within its source.
    pub ordinal: i64,
    pub byte_start: i64,
    pub byte_end: i64,
    // Conversation metadata; only present for chat-derived chunks.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub turn_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timestamp_unix: Option<i64>,
    // Usage stats feeding compute_confidence().
    pub access_count: i64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_accessed_at: Option<i64>,
    pub feedback_score: i64,
    pub score: f64,
    // Derived [0, 1] estimate from freshness, access count and feedback.
    pub confidence: f64,
    pub snippet: String,
    pub text: String,
}
/// Filters and limit for keyword (FTS) search.
#[derive(Debug, Clone)]
pub struct SearchOptions {
    // Maximum number of hits to return.
    pub limit: usize,
    // Exact match on sources.kind when set.
    pub kind: Option<String>,
    // Substring filter applied to sources.path OR sources.uri when set.
    pub path_contains: Option<String>,
}
impl Default for SearchOptions {
    /// Ten unfiltered hits.
    fn default() -> Self {
        Self {
            limit: 10,
            kind: None,
            path_contains: None,
        }
    }
}
/// Estimate how trustworthy/relevant a chunk is, in `[0, 1]`.
///
/// Three signals are combined:
/// 1. *Freshness*: exponential decay (30-day scale) from the most recent of
///    `last_accessed_at` / `timestamp_unix`, floored at 0.25. With neither
///    timestamp the freshness is the 0.25 floor.
/// 2. *Access lift*: repeated accesses pull the base toward 1.0 with a
///    saturating `1 - e^(-n/5)` boost; negative counts count as zero.
/// 3. *Feedback*: `tanh(score / 5)` interpolates the base toward 1.0 when
///    positive and scales it toward 0.0 when negative; zero is a no-op.
pub(crate) fn compute_confidence(
    now_unix: i64,
    last_accessed_at: Option<i64>,
    timestamp_unix: Option<i64>,
    access_count: i64,
    feedback_score: i64,
) -> f64 {
    const DECAY_SECS: f64 = 30.0 * 24.0 * 3600.0;
    const ACCESS_SCALE: f64 = 5.0;
    const FEEDBACK_SCALE: f64 = 5.0;
    // Last access wins over creation time; future timestamps clamp to fresh.
    let freshness = last_accessed_at
        .or(timestamp_unix)
        .map(|ts| {
            let age_secs = (now_unix - ts).max(0) as f64;
            (0.25 + 0.75 * (-age_secs / DECAY_SECS).exp()).clamp(0.0, 1.0)
        })
        .unwrap_or(0.25);
    let visits = access_count.max(0) as f64;
    let access_boost = 1.0 - (-visits / ACCESS_SCALE).exp();
    let base = (freshness + (1.0 - freshness) * access_boost).clamp(0.0, 1.0);
    let factor = (feedback_score as f64 / FEEDBACK_SCALE).tanh();
    let adjusted = if factor < 0.0 {
        base * (1.0 + factor)
    } else {
        base + (1.0 - base) * factor
    };
    adjusted.clamp(0.0, 1.0)
}
/// Convenience wrapper: `compute_confidence` evaluated at the current
/// wall-clock time (`now_unix()`).
fn hit_confidence(
    last_accessed_at: Option<i64>,
    timestamp_unix: Option<i64>,
    access_count: i64,
    feedback_score: i64,
) -> f64 {
    let now = now_unix();
    compute_confidence(
        now,
        last_accessed_at,
        timestamp_unix,
        access_count,
        feedback_score,
    )
}
/// Record that `hits` were just served: bump `access_count` and stamp
/// `last_accessed_at` / `access_decay_at` with the current time so future
/// confidence scores reflect the access.
fn bump_access_metadata(store: &Store, hits: &[SearchHit]) -> Result<()> {
    if hits.is_empty() {
        return Ok(());
    }
    let now = now_unix();
    let conn = store.conn();
    // One UPDATE per hit; result sets are bounded by the search limit, so
    // this stays cheap without an explicit transaction.
    for hit in hits {
        conn.execute(
            "UPDATE chunks
SET access_count = access_count + 1,
last_accessed_at = ?1,
access_decay_at = ?1
WHERE id = ?2",
            rusqlite::params![now, &hit.chunk_id],
        )?;
    }
    Ok(())
}
/// Keyword search over chunk text via the FTS5 index.
///
/// The user query is turned into a prefix-match FTS expression
/// (`build_fts_query`); optional `kind` / `path_contains` filters are
/// appended as extra WHERE clauses. Rows are ranked by `bm25()` — FTS5
/// returns lower (more negative) values for better matches, so the
/// ascending `ORDER BY score` puts the best hits first. Access metadata
/// is bumped for every returned hit.
///
/// Returns an empty vec when the query has no indexable tokens.
pub fn search(store: &Store, query: &str, opts: SearchOptions) -> Result<Vec<SearchHit>> {
    let fts_query = build_fts_query(query);
    if fts_query.is_empty() {
        return Ok(Vec::new());
    }
    // NOTE: the SELECT column order is load-bearing — the positional
    // row.get(N) calls below must stay in sync with it.
    let mut sql = String::from(
        "SELECT
c.id, c.source_id, c.ordinal, c.byte_start, c.byte_end, c.text,
c.role, c.session_id, c.turn_id, c.tool_name, c.timestamp_unix,
c.access_count, c.last_accessed_at, c.feedback_score,
s.uri, s.path, s.kind,
bm25(chunks_fts) AS score,
snippet(chunks_fts, 0, '<<', '>>', '…', 16) AS snippet
FROM chunks_fts
JOIN chunks c ON c.rowid = chunks_fts.rowid
JOIN sources s ON s.id = c.source_id
WHERE chunks_fts MATCH ?",
    );
    // Parameters are collected in the same order the placeholders appear.
    let mut args: Vec<Box<dyn ToSql>> = vec![Box::new(fts_query)];
    if let Some(kind) = &opts.kind {
        sql.push_str(" AND s.kind = ?");
        args.push(Box::new(kind.clone()));
    }
    if let Some(path) = &opts.path_contains {
        sql.push_str(" AND (s.path LIKE ? OR s.uri LIKE ?)");
        let like = format!("%{path}%");
        args.push(Box::new(like.clone()));
        args.push(Box::new(like));
    }
    sql.push_str(" ORDER BY score LIMIT ?");
    args.push(Box::new(opts.limit as i64));
    let conn = store.conn();
    let mut stmt = conn.prepare(&sql)?;
    let rows = stmt.query_map(rusqlite::params_from_iter(args.iter()), |row| {
        // Pulled out first because they feed both the struct and the
        // confidence computation.
        let timestamp_unix: Option<i64> = row.get(10)?;
        let access_count: i64 = row.get(11)?;
        let last_accessed_at: Option<i64> = row.get(12)?;
        let feedback_score: i64 = row.get(13)?;
        Ok(SearchHit {
            chunk_id: row.get(0)?,
            source_id: row.get(1)?,
            ordinal: row.get(2)?,
            byte_start: row.get(3)?,
            byte_end: row.get(4)?,
            text: row.get(5)?,
            role: row.get(6)?,
            session_id: row.get(7)?,
            turn_id: row.get(8)?,
            tool_name: row.get(9)?,
            timestamp_unix,
            access_count,
            last_accessed_at,
            feedback_score,
            uri: row.get(14)?,
            path: row.get(15)?,
            kind: row.get(16)?,
            score: row.get(17)?,
            confidence: hit_confidence(
                last_accessed_at,
                timestamp_unix,
                access_count,
                feedback_score,
            ),
            snippet: row.get(18)?,
        })
    })?;
    let hits = rows.collect::<Result<Vec<_>, _>>()?;
    bump_access_metadata(store, &hits)?;
    Ok(hits)
}
/// Filters, limit, and embedding configuration for semantic / hybrid search.
#[derive(Debug, Clone)]
pub struct SemanticOptions {
    // Maximum number of hits to return.
    pub limit: usize,
    // Exact match on sources.kind when set.
    pub kind: Option<String>,
    // Substring filter applied to sources.path OR sources.uri when set.
    pub path_contains: Option<String>,
    // Embedding model name; must match a model with stored embeddings.
    pub model: String,
    // Base URL of the Ollama server used to embed the query.
    pub ollama_url: String,
    // Optional task instruction prepended when the model supports it.
    pub instruction: Option<String>,
}
impl Default for SemanticOptions {
fn default() -> Self {
Self {
limit: 10,
kind: None,
path_contains: None,
model: crate::embed::DEFAULT_EMBED_MODEL.to_string(),
ollama_url: crate::embed::DEFAULT_OLLAMA_URL.to_string(),
instruction: None,
}
}
}
/// Semantic search using a live Ollama client built from `opts`.
///
/// Thin wrapper over [`semantic_search_with`]; exists so callers that do
/// not need to inject an embedding backend (tests do) get the simple API.
pub fn semantic_search(
    store: &Store,
    query: &str,
    opts: &SemanticOptions,
) -> Result<Vec<SearchHit>> {
    let client = OllamaClient::new(&opts.ollama_url, &opts.model)?;
    semantic_search_with(store, query, opts, &client)
}
/// Semantic search with an injectable embedding backend.
///
/// When the options qualify for the sqlite-vec fast path (default model,
/// no filters — see [`vec_eligible`]) the query is delegated to
/// [`vec_semantic_search_with`]. Otherwise all stored embeddings for the
/// model are loaded and ranked by cosine similarity in process.
///
/// Errors if no embeddings exist for `opts.model` (see
/// `preflight_embeddings`); returns empty for a blank query.
pub fn semantic_search_with(
    store: &Store,
    query: &str,
    opts: &SemanticOptions,
    backend: &dyn EmbeddingBackend,
) -> Result<Vec<SearchHit>> {
    if query.trim().is_empty() {
        return Ok(Vec::new());
    }
    // Fast path delegates before preflight; the vec path runs its own.
    if vec_eligible(opts) {
        return vec_semantic_search_with(store, query, opts, backend);
    }
    preflight_embeddings(store, &opts.model)?;
    // Apply model-specific query-role formatting before embedding.
    let query = prepare_embedding_text(
        &opts.model,
        EmbedRole::Query,
        query,
        opts.instruction.as_deref(),
    );
    let query_vec = backend.embed(&query)?;
    let candidates = load_embedded_chunks(store, &opts.model, &opts.kind, &opts.path_contains)?;
    let hits = rank_by_cosine(&query_vec, candidates, opts.limit)?;
    bump_access_metadata(store, &hits)?;
    Ok(hits)
}
/// True when the request can be served by the sqlite-vec mirror: the mirror
/// only holds default-model embeddings and supports no metadata filters.
pub(crate) fn vec_eligible(opts: &SemanticOptions) -> bool {
    let default_model = opts.model == DEFAULT_EMBED_MODEL;
    let unfiltered = opts.kind.is_none() && opts.path_contains.is_none();
    default_model && unfiltered
}
/// Semantic search served from the sqlite-vec mirror table.
///
/// Embeds the query, runs a KNN `MATCH` against the mirror (inner
/// subquery, ordered by distance, limited), joins the surviving rowids
/// back to `chunks` / `sources`, and converts each distance to a
/// similarity-style score via `1.0 - distance`.
/// NOTE(review): that conversion presumes the mirror stores cosine
/// distances in [0, 2] — confirm against the table's declared metric.
///
/// Only the default model is supported (the mirror holds nothing else);
/// other models must go through [`semantic_search_with`].
pub fn vec_semantic_search_with(
    store: &Store,
    query: &str,
    opts: &SemanticOptions,
    backend: &dyn EmbeddingBackend,
) -> Result<Vec<SearchHit>> {
    if opts.model != DEFAULT_EMBED_MODEL {
        anyhow::bail!(
            "vec-backed semantic search only supports the default model '{}' (got '{}'); \
use semantic_search_with for other models",
            DEFAULT_EMBED_MODEL,
            opts.model,
        );
    }
    if query.trim().is_empty() {
        return Ok(Vec::new());
    }
    preflight_embeddings(store, &opts.model)?;
    // Apply model-specific query-role formatting before embedding.
    let query = prepare_embedding_text(
        &opts.model,
        EmbedRole::Query,
        query,
        opts.instruction.as_deref(),
    );
    let query_vec = backend.embed(&query)?;
    // Guard against a backend returning the wrong dimensionality; the vec
    // MATCH would otherwise fail (or silently misbehave).
    if query_vec.len() != VEC_MIRROR_DIM {
        anyhow::bail!(
            "query embedding has {} dims but vec mirror is {} dims",
            query_vec.len(),
            VEC_MIRROR_DIM,
        );
    }
    let blob = f32s_to_blob(&query_vec);
    // Column order is load-bearing for the positional row.get(N) calls.
    let sql = format!(
        "SELECT c.id, c.source_id, c.ordinal, c.byte_start, c.byte_end, c.text,
c.role, c.session_id, c.turn_id, c.tool_name, c.timestamp_unix,
c.access_count, c.last_accessed_at, c.feedback_score,
s.uri, s.path, s.kind, v.distance
FROM (
SELECT rowid, distance
FROM {VEC_MIRROR_TABLE}
WHERE embedding MATCH ?1
ORDER BY distance
LIMIT ?2
) v
JOIN chunks c ON c.rowid = v.rowid
JOIN sources s ON s.id = c.source_id
ORDER BY v.distance"
    );
    let conn = store.conn();
    let mut stmt = conn.prepare(&sql)?;
    let rows = stmt.query_map(rusqlite::params![blob, opts.limit as i64], |row| {
        let text: String = row.get(5)?;
        let timestamp_unix: Option<i64> = row.get(10)?;
        let access_count: i64 = row.get(11)?;
        let last_accessed_at: Option<i64> = row.get(12)?;
        let feedback_score: i64 = row.get(13)?;
        let distance: f64 = row.get(17)?;
        Ok(SearchHit {
            chunk_id: row.get(0)?,
            source_id: row.get(1)?,
            ordinal: row.get(2)?,
            byte_start: row.get(3)?,
            byte_end: row.get(4)?,
            // No FTS snippet() here, so synthesize one from the text.
            snippet: truncate_snippet(&text, 160),
            text,
            role: row.get(6)?,
            session_id: row.get(7)?,
            turn_id: row.get(8)?,
            tool_name: row.get(9)?,
            timestamp_unix,
            access_count,
            last_accessed_at,
            feedback_score,
            uri: row.get(14)?,
            path: row.get(15)?,
            kind: row.get(16)?,
            score: 1.0 - distance,
            confidence: hit_confidence(
                last_accessed_at,
                timestamp_unix,
                access_count,
                feedback_score,
            ),
        })
    })?;
    let hits = rows.collect::<Result<Vec<_>, _>>()?;
    bump_access_metadata(store, &hits)?;
    Ok(hits)
}
/// Hybrid (keyword + semantic) search using a live Ollama client built
/// from `opts`. Thin wrapper over [`hybrid_search_with`].
pub fn hybrid_search(store: &Store, query: &str, opts: &SemanticOptions) -> Result<Vec<SearchHit>> {
    let client = OllamaClient::new(&opts.ollama_url, &opts.model)?;
    hybrid_search_with(store, query, opts, &client)
}
/// Hybrid search: run keyword and semantic searches independently, then
/// blend with reciprocal-rank fusion (`blend_hits`).
///
/// Both branches over-fetch (4x the requested limit, floor 10) so the
/// fusion has enough overlap to work with; `blend_hits` truncates back to
/// `opts.limit`.
pub fn hybrid_search_with(
    store: &Store,
    query: &str,
    opts: &SemanticOptions,
    backend: &dyn EmbeddingBackend,
) -> Result<Vec<SearchHit>> {
    if query.trim().is_empty() {
        return Ok(Vec::new());
    }
    // Fail fast if the model has no stored embeddings; semantic_search_with
    // may preflight again on the non-vec path, which is harmless.
    preflight_embeddings(store, &opts.model)?;
    let kw_hits = search(
        store,
        query,
        SearchOptions {
            limit: opts.limit.max(10) * 4,
            kind: opts.kind.clone(),
            path_contains: opts.path_contains.clone(),
        },
    )?;
    let sem_opts = SemanticOptions {
        limit: opts.limit.max(10) * 4,
        ..opts.clone()
    };
    let sem_hits = semantic_search_with(store, query, &sem_opts, backend)?;
    let hits = blend_hits(kw_hits, sem_hits, opts.limit);
    // Blended hits get a second access bump on top of the ones recorded by
    // the branch searches above.
    bump_access_metadata(store, &hits)?;
    Ok(hits)
}
/// One embedded chunk loaded for in-process cosine ranking: all the fields
/// needed to build a [`SearchHit`] plus the decoded embedding vector.
struct CandidateRow {
    chunk_id: String,
    source_id: String,
    ordinal: i64,
    byte_start: i64,
    byte_end: i64,
    text: String,
    role: Option<String>,
    session_id: Option<String>,
    turn_id: Option<String>,
    tool_name: Option<String>,
    timestamp_unix: Option<i64>,
    access_count: i64,
    last_accessed_at: Option<i64>,
    feedback_score: i64,
    uri: String,
    path: Option<String>,
    kind: String,
    // Decoded from the stored blob via blob_to_f32s.
    embedding: Vec<f32>,
}
/// Load every chunk that has a stored embedding for `model`, applying the
/// optional kind / path filters, and decode each embedding blob into f32s.
///
/// Loads ALL matching rows into memory — acceptable for the in-process
/// cosine fallback path, which only runs for non-default models.
fn load_embedded_chunks(
    store: &Store,
    model: &str,
    kind: &Option<String>,
    path_contains: &Option<String>,
) -> Result<Vec<CandidateRow>> {
    // Column order is load-bearing for the positional row.get(N) calls.
    let mut sql = String::from(
        "SELECT c.id, c.source_id, c.ordinal, c.byte_start, c.byte_end, c.text,
c.role, c.session_id, c.turn_id, c.tool_name, c.timestamp_unix,
c.access_count, c.last_accessed_at, c.feedback_score,
s.uri, s.path, s.kind, e.embedding
FROM embeddings e
JOIN chunks c ON c.id = e.chunk_id
JOIN sources s ON s.id = c.source_id
WHERE e.model = ?",
    );
    let mut args: Vec<Box<dyn ToSql>> = vec![Box::new(model.to_string())];
    if let Some(k) = kind {
        sql.push_str(" AND s.kind = ?");
        args.push(Box::new(k.clone()));
    }
    if let Some(p) = path_contains {
        sql.push_str(" AND (s.path LIKE ? OR s.uri LIKE ?)");
        let like = format!("%{p}%");
        args.push(Box::new(like.clone()));
        args.push(Box::new(like));
    }
    let conn = store.conn();
    let mut stmt = conn.prepare(&sql)?;
    // The closure returns a tuple rather than a CandidateRow because
    // blob_to_f32s returns a non-rusqlite error type; the blob is decoded
    // outside the query_map so `?` can propagate it as anyhow::Error.
    let rows = stmt.query_map(rusqlite::params_from_iter(args.iter()), |row| {
        let blob: Vec<u8> = row.get(17)?;
        Ok((
            row.get::<_, String>(0)?,
            row.get::<_, String>(1)?,
            row.get::<_, i64>(2)?,
            row.get::<_, i64>(3)?,
            row.get::<_, i64>(4)?,
            row.get::<_, String>(5)?,
            row.get::<_, Option<String>>(6)?,
            row.get::<_, Option<String>>(7)?,
            row.get::<_, Option<String>>(8)?,
            row.get::<_, Option<String>>(9)?,
            row.get::<_, Option<i64>>(10)?,
            row.get::<_, i64>(11)?,
            row.get::<_, Option<i64>>(12)?,
            row.get::<_, i64>(13)?,
            row.get::<_, String>(14)?,
            row.get::<_, Option<String>>(15)?,
            row.get::<_, String>(16)?,
            blob,
        ))
    })?;
    let mut out = Vec::new();
    for r in rows {
        let (
            chunk_id,
            source_id,
            ordinal,
            byte_start,
            byte_end,
            text,
            role,
            session_id,
            turn_id,
            tool_name,
            timestamp_unix,
            access_count,
            last_accessed_at,
            feedback_score,
            uri,
            path,
            kind,
            blob,
        ) = r?;
        let embedding = blob_to_f32s(&blob)?;
        out.push(CandidateRow {
            chunk_id,
            source_id,
            ordinal,
            byte_start,
            byte_end,
            text,
            role,
            session_id,
            turn_id,
            tool_name,
            timestamp_unix,
            access_count,
            last_accessed_at,
            feedback_score,
            uri,
            path,
            kind,
            embedding,
        });
    }
    Ok(out)
}
/// Verify that at least one embedding is stored for `model`; otherwise
/// fail with a message listing what models (with dims and counts) exist.
fn preflight_embeddings(store: &Store, model: &str) -> Result<()> {
    let stats = embedding_stats(store)?;
    if stats.iter().any(|s| s.model == model) {
        return Ok(());
    }
    // Not found: build a human-readable inventory for the error message.
    let available = if stats.is_empty() {
        "none exist".to_string()
    } else {
        let described: Vec<String> = stats
            .into_iter()
            .map(|stat| format!("{} (dim {}, count {})", stat.model, stat.dim, stat.count))
            .collect();
        described.join(", ")
    };
    anyhow::bail!("no stored embeddings for model '{model}'; available models: {available}");
}
/// Rank candidate chunks by cosine similarity to `query_vec`, keeping the
/// top `limit`, and convert them into [`SearchHit`]s.
///
/// Fix: sorting previously used `partial_cmp(..).unwrap_or(Equal)`, which
/// makes the order nondeterministic if any similarity is NaN (possible for
/// a zero-norm embedding). `f32::total_cmp` gives a deterministic total
/// order (NaN sorts consistently) and is the idiomatic replacement.
fn rank_by_cosine(
    query_vec: &[f32],
    candidates: Vec<CandidateRow>,
    limit: usize,
) -> Result<Vec<SearchHit>> {
    let mut scored: Vec<(f32, CandidateRow)> = candidates
        .into_iter()
        .map(|c| (cosine_similarity(query_vec, &c.embedding), c))
        .collect();
    // Descending by similarity; total_cmp keeps the order deterministic.
    scored.sort_by(|a, b| b.0.total_cmp(&a.0));
    scored.truncate(limit);
    Ok(scored
        .into_iter()
        .map(|(score, c)| SearchHit {
            chunk_id: c.chunk_id,
            source_id: c.source_id,
            uri: c.uri,
            path: c.path,
            kind: c.kind,
            ordinal: c.ordinal,
            byte_start: c.byte_start,
            byte_end: c.byte_end,
            role: c.role,
            session_id: c.session_id,
            turn_id: c.turn_id,
            tool_name: c.tool_name,
            timestamp_unix: c.timestamp_unix,
            access_count: c.access_count,
            last_accessed_at: c.last_accessed_at,
            feedback_score: c.feedback_score,
            score: score as f64,
            confidence: hit_confidence(
                c.last_accessed_at,
                c.timestamp_unix,
                c.access_count,
                c.feedback_score,
            ),
            // No FTS snippet on this path; synthesize one from the text.
            snippet: truncate_snippet(&c.text, 160),
            text: c.text,
        })
        .collect())
}
/// Take at most `max_chars` characters of `text`, appending a single '…'
/// when anything was cut off.
///
/// Fix: the original walked the string twice (`chars().take(..)` then a
/// full `chars().count()` over the whole text just to detect truncation).
/// A single pass over one shared iterator gets the same answer: after
/// taking `max_chars` chars, any remaining char means we truncated.
fn truncate_snippet(text: &str, max_chars: usize) -> String {
    let mut chars = text.chars();
    let mut out: String = chars.by_ref().take(max_chars).collect();
    if chars.next().is_some() {
        out.push('…');
    }
    out
}
/// Merge keyword and semantic result lists with reciprocal-rank fusion.
///
/// Each list contributes `1 / (60 + rank)` per hit (1-based rank); a chunk
/// appearing in both lists sums both contributions. The fused score
/// replaces each hit's original score, and the top `limit` survive.
fn blend_hits(kw: Vec<SearchHit>, sem: Vec<SearchHit>, limit: usize) -> Vec<SearchHit> {
    use std::collections::HashMap;
    // Standard RRF damping constant.
    const K: f64 = 60.0;
    let weight = |rank: usize| 1.0 / (K + (rank + 1) as f64);
    let mut fused: HashMap<String, (f64, SearchHit)> = HashMap::new();
    // Keyword hits seed the map; on a chunk_id collision within the list
    // the later entry wins, matching plain insert semantics.
    for (rank, hit) in kw.into_iter().enumerate() {
        let w = weight(rank);
        fused.insert(hit.chunk_id.clone(), (w, hit));
    }
    // Semantic hits add their weight to existing entries or insert fresh.
    for (rank, hit) in sem.into_iter().enumerate() {
        let w = weight(rank);
        fused
            .entry(hit.chunk_id.clone())
            .and_modify(|slot| slot.0 += w)
            .or_insert((w, hit));
    }
    let mut ranked: Vec<SearchHit> = fused
        .into_values()
        .map(|(score, mut hit)| {
            hit.score = score;
            hit
        })
        .collect();
    // Weights are finite and positive, so total_cmp orders exactly like
    // the partial comparison would.
    ranked.sort_by(|a, b| b.score.total_cmp(&a.score));
    ranked.truncate(limit);
    ranked
}
/// Turn a free-text user query into an FTS5 MATCH expression.
///
/// Each whitespace-separated token is stripped down to alphanumerics and
/// underscores (dropping FTS syntax characters like quotes and operators)
/// and turned into a prefix term (`tok*`). Tokens that clean away to
/// nothing are dropped; terms are space-joined, which FTS5 treats as AND.
pub(crate) fn build_fts_query(q: &str) -> String {
    let mut terms: Vec<String> = Vec::new();
    for raw in q.split_whitespace() {
        let cleaned: String = raw
            .chars()
            .filter(|c| c.is_alphanumeric() || *c == '_')
            .collect();
        if !cleaned.is_empty() {
            terms.push(format!("{cleaned}*"));
        }
    }
    terms.join(" ")
}
/// Render a compact human-readable summary: a header line with the query
/// and hit count, then per hit a ranked line (score, confidence, source,
/// ordinal), an optional one-line snippet preview, and an optional
/// metadata line.
pub fn format_summary(query: &str, hits: &[SearchHit]) -> String {
    let mut out = String::new();
    // writeln! into a String cannot fail; errors are deliberately ignored.
    let _ = writeln!(out, "query: {query:?} hits: {}", hits.len());
    for (i, hit) in hits.iter().enumerate() {
        // Prefer the filesystem path; fall back to the URI.
        let source = hit.path.as_deref().unwrap_or(hit.uri.as_str());
        let _ = writeln!(
            out,
            " {rank}. [{score:.3} conf={confidence:.2}] {source} #{ord}",
            rank = i + 1,
            score = hit.score,
            confidence = hit.confidence,
            source = source,
            ord = hit.ordinal,
        );
        let preview = snippet_preview(&hit.snippet, 96);
        if !preview.is_empty() {
            let _ = writeln!(out, " {preview}");
        }
        if let Some(meta) = hit_metadata_line(hit) {
            let _ = writeln!(out, " {meta}");
        }
    }
    out
}
/// Print the compact summary (see [`format_summary`]) to stdout.
pub fn print_summary(query: &str, hits: &[SearchHit]) {
    let rendered = format_summary(query, hits);
    print!("{}", rendered);
}
/// First line of `snippet`, trimmed, capped at `max_chars` characters with
/// a '…' appended when the line was longer.
///
/// Fix: the original counted the whole first line's chars up front and
/// then walked them again with `take`. One shared iterator does both jobs
/// in a single pass: after taking `max_chars` chars, any leftover char
/// proves truncation happened.
fn snippet_preview(snippet: &str, max_chars: usize) -> String {
    let first = snippet.lines().next().unwrap_or("").trim();
    let mut chars = first.chars();
    let mut out: String = chars.by_ref().take(max_chars).collect();
    if chars.next().is_some() {
        out.push('…');
    }
    out
}
/// Build a space-separated `key=value` line from the hit's optional
/// conversation metadata (role, session, turn, tool, timestamp), in that
/// fixed order. Returns `None` when no metadata is present.
fn hit_metadata_line(hit: &SearchHit) -> Option<String> {
    let parts: Vec<String> = [
        hit.role.as_ref().map(|v| format!("role={v}")),
        hit.session_id.as_ref().map(|v| format!("session={v}")),
        hit.turn_id.as_ref().map(|v| format!("turn={v}")),
        hit.tool_name.as_ref().map(|v| format!("tool={v}")),
        hit.timestamp_unix.map(|v| format!("ts={v}")),
    ]
    .into_iter()
    .flatten()
    .collect();
    if parts.is_empty() {
        None
    } else {
        Some(parts.join(" "))
    }
}
/// Render a detailed plain-text report: one block per hit (rank, score,
/// confidence, chunk ordinal, byte range, URI, snippet, optional metadata
/// line) followed by a trailing summary line. Empty results get a single
/// "no results" line instead.
pub fn format_text(query: &str, hits: &[SearchHit]) -> String {
    let mut out = String::new();
    // writeln! into a String cannot fail; errors are deliberately ignored.
    if hits.is_empty() {
        let _ = writeln!(out, "no results for {query:?}");
        return out;
    }
    for (i, hit) in hits.iter().enumerate() {
        let _ = writeln!(
            out,
            "[{rank}] score={score:.4} conf={confidence:.2} chunk={ordinal} bytes={start}-{end} uri={uri}",
            rank = i + 1,
            score = hit.score,
            confidence = hit.confidence,
            ordinal = hit.ordinal,
            start = hit.byte_start,
            end = hit.byte_end,
            uri = hit.uri,
        );
        let _ = writeln!(out, " {}", hit.snippet);
        if let Some(meta) = hit_metadata_line(hit) {
            let _ = writeln!(out, " {meta}");
        }
    }
    let _ = writeln!(out, "summary query={query:?} results={}", hits.len());
    out
}
/// Print the detailed report (see [`format_text`]) to stdout.
pub fn print_text(query: &str, hits: &[SearchHit]) {
    let rendered = format_text(query, hits);
    print!("{}", rendered);
}
/// Serialize the query, optional model name, and hits as pretty-printed
/// JSON. The `model` key is omitted entirely when `None`.
pub fn format_json(query: &str, model: Option<&str>, hits: &[SearchHit]) -> Result<String> {
    // Wire shape for the JSON output; field names are the JSON keys.
    #[derive(Serialize)]
    struct Payload<'a> {
        query: &'a str,
        #[serde(skip_serializing_if = "Option::is_none")]
        model: Option<&'a str>,
        results: &'a [SearchHit],
    }
    let payload = Payload {
        query,
        model,
        results: hits,
    };
    let rendered = serde_json::to_string_pretty(&payload)?;
    Ok(rendered)
}
/// Print the JSON envelope (see [`format_json`]) to stdout, propagating
/// serialization errors.
pub fn print_json(query: &str, model: Option<&str>, hits: &[SearchHit]) -> Result<()> {
    let rendered = format_json(query, model, hits)?;
    print!("{}", rendered);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::{
        SearchHit, build_fts_query, compute_confidence, format_json, format_summary, format_text,
        snippet_preview,
    };
    // Minimal hit: no conversation metadata, neutral usage stats.
    fn sample_hit() -> SearchHit {
        SearchHit {
            chunk_id: "chunk-1".into(),
            source_id: "source-1".into(),
            uri: "file:///tmp/needle.txt".into(),
            path: Some("/tmp/needle.txt".into()),
            kind: "text/plain".into(),
            ordinal: 0,
            byte_start: 0,
            byte_end: 12,
            role: None,
            session_id: None,
            turn_id: None,
            tool_name: None,
            timestamp_unix: None,
            access_count: 0,
            last_accessed_at: None,
            feedback_score: 0,
            score: 0.99,
            confidence: 0.5,
            snippet: "needle snippet".into(),
            text: "needle text".into(),
        }
    }
    // Same hit with every optional metadata field populated.
    fn sample_hit_with_metadata() -> SearchHit {
        SearchHit {
            role: Some("assistant".into()),
            session_id: Some("sess-7".into()),
            turn_id: Some("turn-9".into()),
            tool_name: Some("search".into()),
            timestamp_unix: Some(1_700_000_003),
            ..sample_hit()
        }
    }
    // --- build_fts_query ---
    #[test]
    fn empty_query_produces_empty_match() {
        assert_eq!(build_fts_query(""), "");
        assert_eq!(build_fts_query(" "), "");
    }
    #[test]
    fn single_token_is_prefix_matched() {
        assert_eq!(build_fts_query("hello"), "hello*");
    }
    #[test]
    fn multiple_tokens_are_and_joined() {
        assert_eq!(build_fts_query("foo bar"), "foo* bar*");
    }
    // --- snippet_preview ---
    #[test]
    fn preview_takes_first_line_and_truncates_with_ellipsis() {
        assert_eq!(snippet_preview("", 10), "");
        assert_eq!(snippet_preview("short line", 40), "short line");
        assert_eq!(snippet_preview("line one\nline two", 40), "line one");
        assert_eq!(snippet_preview("abcdefghij", 5), "abcde…");
        // Char-based cap, so multi-byte chars are not split.
        let preview = snippet_preview("世界世界世界世界", 4);
        assert_eq!(preview, "世界世界…");
    }
    #[test]
    fn punctuation_and_special_chars_are_stripped() {
        assert_eq!(build_fts_query("hello, world!"), "hello* world*");
        assert_eq!(build_fts_query(r#"he"llo"#), "hello*");
        assert_eq!(build_fts_query("!!!"), "");
    }
    // --- formatters ---
    #[test]
    fn summary_formatter_includes_query_hit_and_metadata() {
        let output = format_summary("needle", &[sample_hit_with_metadata()]);
        assert!(output.contains("query: \"needle\" hits: 1"), "{output}");
        assert!(
            output.contains("[0.990 conf=0.50] /tmp/needle.txt #0"),
            "{output}"
        );
        assert!(output.contains("needle snippet"), "{output}");
        assert!(output.contains("role=assistant"), "{output}");
        assert!(output.contains("session=sess-7"), "{output}");
        assert!(output.contains("turn=turn-9"), "{output}");
        assert!(output.contains("tool=search"), "{output}");
        assert!(output.contains("ts=1700000003"), "{output}");
    }
    #[test]
    fn text_formatter_includes_detailed_hit_block() {
        let output = format_text("needle", &[sample_hit_with_metadata()]);
        assert!(
            output.contains(
                "[1] score=0.9900 conf=0.50 chunk=0 bytes=0-12 uri=file:///tmp/needle.txt"
            ),
            "{output}"
        );
        assert!(output.contains("needle snippet"), "{output}");
        assert!(output.contains("role=assistant"), "{output}");
        assert!(
            output.contains("summary query=\"needle\" results=1"),
            "{output}"
        );
    }
    #[test]
    fn json_formatter_includes_model_when_present() {
        let json = format_json("needle", Some("nomic-embed-text"), &[sample_hit()]).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["query"], "needle");
        assert_eq!(value["model"], "nomic-embed-text");
        assert_eq!(value["results"].as_array().unwrap().len(), 1);
    }
    #[test]
    fn json_formatter_omits_model_when_absent() {
        let json = format_json("needle", None, &[sample_hit()]).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["query"], "needle");
        assert!(value.get("model").is_none());
    }
    // --- compute_confidence ---
    // Fixed "now" keeps the confidence tests deterministic.
    const NOW: i64 = 1_800_000_000;
    const DAY: i64 = 24 * 3600;
    #[test]
    fn confidence_floor_when_no_timestamps_or_access() {
        let c = compute_confidence(NOW, None, None, 0, 0);
        assert!((c - 0.25).abs() < 1e-9, "got {c}");
    }
    #[test]
    fn confidence_fresh_chunk_near_one() {
        let c = compute_confidence(NOW, None, Some(NOW), 0, 0);
        assert!((c - 1.0).abs() < 1e-9, "got {c}");
    }
    #[test]
    fn confidence_decays_with_age_toward_floor() {
        let recent = compute_confidence(NOW, None, Some(NOW - DAY), 0, 0);
        let week_old = compute_confidence(NOW, None, Some(NOW - 7 * DAY), 0, 0);
        let year_old = compute_confidence(NOW, None, Some(NOW - 365 * DAY), 0, 0);
        assert!(recent > week_old, "recent {recent} vs week {week_old}");
        assert!(week_old > year_old, "week {week_old} vs year {year_old}");
        assert!(
            year_old >= 0.25 && year_old < 0.26,
            "year-old should hit the floor, got {year_old}"
        );
    }
    #[test]
    fn confidence_future_timestamps_clamp_to_fresh() {
        let c = compute_confidence(NOW, None, Some(NOW + 10 * DAY), 0, 0);
        assert!((c - 1.0).abs() < 1e-9, "got {c}");
    }
    #[test]
    fn confidence_last_accessed_preferred_over_timestamp() {
        let c = compute_confidence(NOW, Some(NOW), Some(NOW - 365 * DAY), 0, 0);
        assert!((c - 1.0).abs() < 1e-9, "got {c}");
    }
    #[test]
    fn confidence_access_lifts_above_freshness() {
        let no_access = compute_confidence(NOW, None, Some(NOW - 365 * DAY), 0, 0);
        let some_access = compute_confidence(NOW, None, Some(NOW - 365 * DAY), 3, 0);
        let many_access = compute_confidence(NOW, None, Some(NOW - 365 * DAY), 50, 0);
        assert!(some_access > no_access, "{some_access} vs {no_access}");
        assert!(many_access > some_access, "{many_access} vs {some_access}");
        assert!(many_access <= 1.0, "must stay in [0,1]: {many_access}");
    }
    #[test]
    fn confidence_access_count_zero_preserves_legacy_behavior() {
        let age = 10 * DAY;
        let c = compute_confidence(NOW, None, Some(NOW - age), 0, 0);
        // Pure freshness formula: floor + decayed remainder.
        let expected = 0.25 + 0.75 * (-(age as f64) / (30.0 * DAY as f64)).exp();
        assert!((c - expected).abs() < 1e-9, "got {c}, expected {expected}");
    }
    #[test]
    fn confidence_negative_access_count_treated_as_zero() {
        let c = compute_confidence(NOW, None, Some(NOW - 365 * DAY), -7, 0);
        let baseline = compute_confidence(NOW, None, Some(NOW - 365 * DAY), 0, 0);
        assert!((c - baseline).abs() < 1e-9, "got {c}");
    }
    #[test]
    fn confidence_feedback_neutral_matches_legacy() {
        for &(last, ts, access) in &[
            (None, None, 0),
            (None, Some(NOW - 10 * DAY), 0),
            (Some(NOW - DAY), Some(NOW - 365 * DAY), 3),
            (None, Some(NOW - 30 * DAY), 12),
        ] {
            let with_feedback = compute_confidence(NOW, last, ts, access, 0);
            // Re-derive the pre-feedback formula inline to catch drift.
            let legacy_base = {
                const DECAY_SECS: f64 = 30.0 * 24.0 * 3600.0;
                const ACCESS_SCALE: f64 = 5.0;
                let reference = last.or(ts);
                let freshness = match reference {
                    Some(t) => {
                        let age_secs = (NOW - t).max(0) as f64;
                        (0.25 + 0.75 * (-age_secs / DECAY_SECS).exp()).clamp(0.0, 1.0)
                    }
                    None => 0.25,
                };
                let n = access.max(0) as f64;
                let access_boost = 1.0 - (-n / ACCESS_SCALE).exp();
                (freshness + (1.0 - freshness) * access_boost).clamp(0.0, 1.0)
            };
            assert!(
                (with_feedback - legacy_base).abs() < 1e-12,
                "feedback=0 must be a no-op (last={last:?}, ts={ts:?}, access={access}): \
got {with_feedback}, expected {legacy_base}"
            );
        }
    }
    #[test]
    fn confidence_positive_feedback_lifts_toward_one() {
        let ts = Some(NOW - 365 * DAY);
        let base = compute_confidence(NOW, None, ts, 0, 0);
        let one_up = compute_confidence(NOW, None, ts, 0, 1);
        let many_up = compute_confidence(NOW, None, ts, 0, 100);
        assert!(one_up > base, "one up {one_up} vs base {base}");
        assert!(many_up > one_up, "many up {many_up} vs one {one_up}");
        assert!(many_up <= 1.0, "must stay <= 1: {many_up}");
        assert!((many_up - 1.0).abs() < 1e-6, "saturation: {many_up}");
    }
    #[test]
    fn confidence_negative_feedback_pulls_toward_zero() {
        let ts = Some(NOW);
        let base = compute_confidence(NOW, None, ts, 0, 0);
        let one_down = compute_confidence(NOW, None, ts, 0, -1);
        let many_down = compute_confidence(NOW, None, ts, 0, -100);
        assert!(one_down < base, "one down {one_down} vs base {base}");
        assert!(
            many_down < one_down,
            "many down {many_down} vs one {one_down}"
        );
        assert!(many_down >= 0.0, "must stay >= 0: {many_down}");
        assert!(many_down < 1e-6, "saturation: {many_down}");
    }
    #[test]
    fn confidence_feedback_cancels_out() {
        let ts = Some(NOW - 30 * DAY);
        let base = compute_confidence(NOW, None, ts, 0, 0);
        let up = compute_confidence(NOW, None, ts, 0, 3);
        let down = compute_confidence(NOW, None, ts, 0, -3);
        // tanh(3/5) is the interpolation factor in both directions.
        let factor = (3.0_f64 / 5.0).tanh();
        assert!((up - (base + (1.0 - base) * factor)).abs() < 1e-9);
        assert!((down - (base * (1.0 - factor))).abs() < 1e-9);
    }
}