use crate::embedder::{cosine_similarity, decode_vector};
/// Multiplier applied to an embedding's `staleness` when `rerank` discounts its
/// cosine similarity: `score = similarity * (1.0 - STALENESS_WEIGHT * staleness)`.
const STALENESS_WEIGHT: f32 = 0.3;
/// One scored result produced by `rerank`.
///
/// Most fields are carried over unchanged from the `StoredEmbedding` the hit was
/// built from; `similarity` and `score` are computed during re-ranking.
#[derive(Debug, Clone)]
pub struct SearchHit {
    /// Identifier copied from the originating `StoredEmbedding`.
    pub id: i64,
    /// Kind of source the chunk came from (the tests in this file use `"symbol"`).
    pub source_type: String,
    /// Path of the source the chunk came from (the tests use `"src/lib.rs"`).
    pub source_path: String,
    /// Optional identifier of the source record.
    // NOTE(review): presumably a row id in the table named by `source_type` — confirm.
    pub source_id: Option<i64>,
    /// Cosine similarity between the query vector and the stored vector.
    pub similarity: f32,
    /// Staleness value copied from the stored embedding.
    pub staleness: f32,
    /// Final ranking score: `similarity` discounted by `staleness`.
    pub score: f32,
    /// Text of the embedded chunk.
    pub chunk_text: String,
    /// Optional last-commit marker copied from the stored embedding.
    // NOTE(review): presumably a commit hash — confirm against the writer.
    pub last_commit: Option<String>,
}
/// An embedding record as loaded from storage, ready to be scored by `rerank`.
///
/// Derives `Debug` and `Clone` so it matches `SearchHit` and can be logged or
/// duplicated by callers.
#[derive(Debug, Clone)]
pub struct StoredEmbedding {
    /// Identifier of this embedding; copied into the resulting `SearchHit`.
    pub id: i64,
    /// Kind of source the chunk came from (the tests in this file use `"symbol"`).
    pub source_type: String,
    /// Path of the source the chunk came from (the tests use `"src/lib.rs"`).
    pub source_path: String,
    /// Optional identifier of the source record.
    // NOTE(review): presumably a row id in the table named by `source_type` — confirm.
    pub source_id: Option<i64>,
    /// Staleness value used by `rerank` to discount the similarity.
    // NOTE(review): the scoring formula suggests a 0.0..=1.0 range — confirm with the producer.
    pub staleness: f32,
    /// Text of the embedded chunk.
    pub chunk_text: String,
    /// Optional last-commit marker for the chunk.
    // NOTE(review): presumably a commit hash — confirm against the writer.
    pub last_commit: Option<String>,
    /// The embedding vector compared against the query vector.
    pub vector: Vec<f32>,
}
/// Scores every stored embedding against `query_vec` and returns the best `top_k` hits.
///
/// For each candidate, `similarity` is `cosine_similarity(query_vec, &vector)` and the
/// ranking score is `similarity * (1.0 - STALENESS_WEIGHT * staleness)`.
///
/// Hits are returned in descending score order. The sort is stable, so candidates with
/// equal scores keep their input order. A hit whose score is NaN is placed after every
/// hit with a real score, so an undefined score can never outrank a real one.
///
/// Returns at most `top_k` hits; the result is empty when `top_k` is 0 or `stored` is
/// empty.
pub fn rerank(query_vec: &[f32], stored: Vec<StoredEmbedding>, top_k: usize) -> Vec<SearchHit> {
    use std::cmp::Ordering;

    // Nothing can be returned, so skip the similarity computations entirely.
    if top_k == 0 {
        return Vec::new();
    }

    let mut hits: Vec<SearchHit> = stored
        .into_iter()
        .map(|e| {
            let similarity = cosine_similarity(query_vec, &e.vector);
            let score = similarity * (1.0 - STALENESS_WEIGHT * e.staleness);
            SearchHit {
                id: e.id,
                source_type: e.source_type,
                source_path: e.source_path,
                source_id: e.source_id,
                similarity,
                staleness: e.staleness,
                score,
                chunk_text: e.chunk_text,
                last_commit: e.last_commit,
            }
        })
        .collect();

    // `partial_cmp` alone is not a total order when NaN is present: treating NaN as
    // "equal to everything" makes the result input-order dependent and violates the
    // total-order contract of `sort_by`. Rank NaN explicitly below all real scores.
    hits.sort_by(|a, b| match (a.score.is_nan(), b.score.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN here, so `partial_cmp` always yields `Some`.
        (false, false) => b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal),
    });
    hits.truncate(top_k);
    hits
}
/// Decodes a raw byte blob into an `f32` vector by delegating to `decode_vector`.
///
/// Takes ownership of `blob`, but only borrows it for the duration of the call.
// NOTE(review): the byte layout (endianness, element width) is defined by
// `embedder::decode_vector`, which is not visible here — confirm before relying on it.
pub fn parse_blob(blob: Vec<u8>) -> Vec<f32> {
decode_vector(&blob)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `StoredEmbedding` fixture with fixed metadata; only the id, vector and
    /// staleness vary between test cases.
    fn make_stored(id: i64, vector: Vec<f32>, staleness: f32) -> StoredEmbedding {
        StoredEmbedding {
            id,
            source_type: String::from("symbol"),
            source_path: String::from("src/lib.rs"),
            source_id: Some(id),
            staleness,
            chunk_text: String::from("test chunk"),
            last_commit: None,
            vector,
        }
    }

    /// The candidate identical to the query with zero staleness must rank first.
    #[test]
    fn test_rerank_orders_by_score() {
        let query = vec![1.0_f32, 0.0, 0.0];

        // Same direction as the query, fresh.
        let exact_match = make_stored(1, vec![1.0, 0.0, 0.0], 0.0);
        // Orthogonal to the query, fresh.
        let orthogonal = make_stored(2, vec![0.0, 1.0, 0.0], 0.0);
        // Close to the query, but with a non-zero staleness.
        let close_but_stale = make_stored(3, vec![0.9, 0.4, 0.0], 0.5);

        let candidates = vec![exact_match, orthogonal, close_but_stale];
        let ranked = rerank(&query, candidates, 3);

        assert_eq!(ranked[0].id, 1, "most similar, no staleness should be first");
        assert!(ranked[0].score > ranked[1].score);
    }

    /// No more than `top_k` hits are returned even when more candidates exist.
    #[test]
    fn test_rerank_respects_top_k() {
        let query = vec![1.0_f32, 0.0];

        let mut candidates: Vec<StoredEmbedding> = Vec::new();
        for id in 0..10 {
            candidates.push(make_stored(id, vec![1.0, 0.0], 0.0));
        }

        let ranked = rerank(&query, candidates, 3);
        assert_eq!(ranked.len(), 3);
    }
}