normalize_semantic/
search.rs1use crate::embedder::{cosine_similarity, decode_vector};

/// Strength of the staleness penalty applied by [`rerank`]: a hit whose
/// staleness is `s` keeps `1.0 - STALENESS_WEIGHT * s` of its cosine similarity.
const STALENESS_WEIGHT: f32 = 0.3_f32;

/// One ranked result produced by [`rerank`].
///
/// Metadata is carried over unchanged from the [`StoredEmbedding`] it was
/// built from; `similarity`, `score` are computed during reranking.
#[derive(Debug, Clone)]
pub struct SearchHit {
    /// Identifier copied from [`StoredEmbedding::id`].
    pub id: i64,
    /// Kind of source the chunk belongs to (the tests use `"symbol"`).
    pub source_type: String,
    /// Path of the source the chunk belongs to (the tests use `"src/lib.rs"`).
    pub source_path: String,
    /// Optional identifier of the originating source item, if one exists.
    pub source_id: Option<i64>,
    /// Raw cosine similarity between the query vector and the stored vector.
    pub similarity: f32,
    /// Staleness value copied from the stored embedding; feeds the score penalty.
    pub staleness: f32,
    /// Ranking key: `similarity * (1.0 - STALENESS_WEIGHT * staleness)`.
    pub score: f32,
    /// Text of the chunk that was embedded.
    pub chunk_text: String,
    /// Commit associated with the chunk, if known (presumably a commit hash — confirm).
    pub last_commit: Option<String>,
}

/// An embedding chunk plus its metadata, as consumed by [`rerank`].
///
/// Derives `Debug` and `Clone` for parity with [`SearchHit`], so stored rows
/// can be logged and duplicated (e.g. in tests) without hand-written impls.
#[derive(Debug, Clone)]
pub struct StoredEmbedding {
    /// Identifier that is copied into [`SearchHit::id`].
    pub id: i64,
    /// Kind of source the chunk belongs to (the tests use `"symbol"`).
    pub source_type: String,
    /// Path of the source the chunk belongs to.
    pub source_path: String,
    /// Optional identifier of the originating source item.
    pub source_id: Option<i64>,
    /// Staleness used to damp the similarity score in [`rerank`].
    /// NOTE(review): not clamped anywhere in this module; presumably in `[0, 1]` — confirm.
    pub staleness: f32,
    /// Text of the chunk that was embedded.
    pub chunk_text: String,
    /// Commit associated with the chunk, if known.
    pub last_commit: Option<String>,
    /// The embedding vector compared against the query with cosine similarity.
    pub vector: Vec<f32>,
}

59pub fn rerank(query_vec: &[f32], stored: Vec<StoredEmbedding>, top_k: usize) -> Vec<SearchHit> {
63 let mut hits: Vec<SearchHit> = stored
64 .into_iter()
65 .map(|e| {
66 let similarity = cosine_similarity(query_vec, &e.vector);
67 let score = similarity * (1.0 - STALENESS_WEIGHT * e.staleness);
68 SearchHit {
69 id: e.id,
70 source_type: e.source_type,
71 source_path: e.source_path,
72 source_id: e.source_id,
73 similarity,
74 staleness: e.staleness,
75 score,
76 chunk_text: e.chunk_text,
77 last_commit: e.last_commit,
78 }
79 })
80 .collect();
81
82 hits.sort_by(|a, b| {
84 b.score
85 .partial_cmp(&a.score)
86 .unwrap_or(std::cmp::Ordering::Equal)
87 });
88 hits.truncate(top_k);
89 hits
90}
91
92pub fn parse_blob(blob: Vec<u8>) -> Vec<f32> {
94 decode_vector(&blob)
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `StoredEmbedding` with fixed metadata so each test varies only
    /// the id, vector, and staleness.
    fn make_stored(id: i64, vec: Vec<f32>, staleness: f32) -> StoredEmbedding {
        StoredEmbedding {
            id,
            source_type: "symbol".to_string(),
            source_path: "src/lib.rs".to_string(),
            source_id: Some(id),
            staleness,
            chunk_text: "test chunk".to_string(),
            last_commit: None,
            vector: vec,
        }
    }

    /// Extracts hit ids in ranked order for compact assertions.
    fn ids(hits: &[SearchHit]) -> Vec<i64> {
        hits.iter().map(|h| h.id).collect()
    }

    #[test]
    fn test_rerank_orders_by_score() {
        let query = vec![1.0_f32, 0.0, 0.0];
        let stored = vec![
            make_stored(1, vec![1.0, 0.0, 0.0], 0.0), // identical, fresh
            make_stored(2, vec![0.0, 1.0, 0.0], 0.0), // orthogonal
            make_stored(3, vec![0.9, 0.4, 0.0], 0.5), // close, half-stale
        ];
        let hits = rerank(&query, stored, 3);
        assert_eq!(hits[0].id, 1, "most similar, no staleness should be first");
        assert_eq!(ids(&hits), vec![1, 3, 2]);
        assert!(hits.windows(2).all(|w| w[0].score >= w[1].score));
    }

    #[test]
    fn test_rerank_applies_staleness_penalty() {
        let query = vec![1.0_f32, 0.0];
        let stored = vec![
            make_stored(1, vec![1.0, 0.0], 1.0),
            make_stored(2, vec![1.0, 0.0], 0.0),
        ];
        let hits = rerank(&query, stored, 2);
        assert_eq!(ids(&hits), vec![2, 1], "fresh hit should outrank identical stale hit");
        let stale = &hits[1];
        let expected = stale.similarity * (1.0 - STALENESS_WEIGHT);
        assert!((stale.score - expected).abs() < 1e-6);
    }

    #[test]
    fn test_rerank_respects_top_k() {
        let query = vec![1.0_f32, 0.0];
        let stored = (0..10)
            .map(|i| make_stored(i, vec![1.0, 0.0], 0.0))
            .collect();
        let hits = rerank(&query, stored, 3);
        assert_eq!(hits.len(), 3);
        assert_eq!(ids(&hits), vec![0, 1, 2], "ties keep input order");
    }

    #[test]
    fn test_rerank_top_k_larger_than_input() {
        let query = vec![1.0_f32, 0.0];
        let stored = vec![
            make_stored(1, vec![1.0, 0.0], 0.0),
            make_stored(2, vec![0.0, 1.0], 0.0),
        ];
        assert_eq!(rerank(&query, stored, 10).len(), 2);
    }

    #[test]
    fn test_rerank_empty_input() {
        let query = vec![1.0_f32, 0.0];
        assert!(rerank(&query, Vec::new(), 5).is_empty());
    }
}