Skip to main content

normalize_semantic/
search.rs

1//! ANN search with staleness-based re-ranking.
2//!
3//! ## Search flow (ANN path — preferred)
4//! 1. Register the sqlite-vec extension via [`crate::vec_ext::register_vec_extension`].
5//! 2. Embed the query with the same model used at index time.
6//! 3. Ask the `vec_embeddings` virtual table for the top-[`crate::store::ANN_CANDIDATE_COUNT`]
7//!    nearest vectors.
8//! 4. Re-rank candidates: `score = cosine_sim * (1 - staleness_weight * staleness)`.
9//! 5. Return top-K results sorted by final score descending.
10//!
11//! ## Search flow (brute-force fallback)
12//! Steps 2–5 above, but step 3 is replaced by loading all vectors from SQLite
13//! into memory.  Used when sqlite-vec is not available or `vec_embeddings`
14//! does not yet exist (e.g. on the first rebuild before the schema migration).
15
16use crate::embedder::{cosine_similarity, decode_vector};
17
/// Strength of the staleness penalty used when re-ranking candidates.
///
/// Each candidate's cosine similarity is multiplied by
/// `1 - STALENESS_WEIGHT * staleness`, so a larger weight pushes stale chunks
/// further down the result list. Tunable.
const STALENESS_WEIGHT: f32 = 0.3_f32;
20
/// A single ranked match returned by a semantic search.
///
/// Carries both the raw cosine similarity and the staleness-adjusted score,
/// so callers can see how much the staleness penalty moved each hit.
#[derive(Debug, Clone)]
pub struct SearchHit {
    /// Primary key of the row in the embeddings table.
    pub id: i64,
    /// Kind of source the chunk came from ("symbol", "doc", …).
    pub source_type: String,
    /// Relative path of the file the chunk was taken from.
    pub source_path: String,
    /// Foreign key into the symbols table when the chunk is a symbol chunk.
    pub source_id: Option<i64>,
    /// Cosine similarity between query and chunk, before any penalty.
    pub similarity: f32,
    /// Staleness value recorded for the chunk at index time.
    pub staleness: f32,
    /// Ranking score after the staleness penalty has been applied.
    pub score: f32,
    /// Exact text that was fed to the embedder for this chunk.
    pub chunk_text: String,
    /// Git commit SHA at the time this chunk was last embedded, if known.
    pub last_commit: Option<String>,
}
43
/// In-memory representation of a stored embedding row.
///
/// Used for both the ANN candidate set (post-`vec_search`) and the brute-force
/// fallback path. Consumed by value by `rerank`, which moves the metadata
/// fields into the resulting `SearchHit`s.
///
/// Derives `Debug` and `Clone` like its sibling `SearchHit`, so candidate sets
/// can be logged and duplicated in diagnostics and tests.
#[derive(Debug, Clone)]
pub struct StoredEmbedding {
    /// Row id in the embeddings table.
    pub id: i64,
    /// Source type tag ("symbol", "doc", …).
    pub source_type: String,
    /// Relative file path.
    pub source_path: String,
    /// FK into symbols table (if a symbol chunk).
    pub source_id: Option<i64>,
    /// Staleness score stored at index time.
    pub staleness: f32,
    /// The chunk text that was embedded.
    pub chunk_text: String,
    /// Git commit SHA when this chunk was last embedded.
    pub last_commit: Option<String>,
    /// Embedding vector compared against the query vector.
    pub vector: Vec<f32>,
}
58
59/// Re-rank a list of stored embeddings against a query vector.
60///
61/// Returns hits sorted by final score descending, limited to `top_k`.
62pub fn rerank(query_vec: &[f32], stored: Vec<StoredEmbedding>, top_k: usize) -> Vec<SearchHit> {
63    let mut hits: Vec<SearchHit> = stored
64        .into_iter()
65        .map(|e| {
66            let similarity = cosine_similarity(query_vec, &e.vector);
67            let score = similarity * (1.0 - STALENESS_WEIGHT * e.staleness);
68            SearchHit {
69                id: e.id,
70                source_type: e.source_type,
71                source_path: e.source_path,
72                source_id: e.source_id,
73                similarity,
74                staleness: e.staleness,
75                score,
76                chunk_text: e.chunk_text,
77                last_commit: e.last_commit,
78            }
79        })
80        .collect();
81
82    // Sort descending by final score
83    hits.sort_by(|a, b| {
84        b.score
85            .partial_cmp(&a.score)
86            .unwrap_or(std::cmp::Ordering::Equal)
87    });
88    hits.truncate(top_k);
89    hits
90}
91
92/// Parse a raw BLOB from the database into a f32 vector via `decode_vector`.
93pub fn parse_blob(blob: Vec<u8>) -> Vec<f32> {
94    decode_vector(&blob)
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `StoredEmbedding` with fixed metadata and the given id, vector and staleness.
    fn make_stored(id: i64, vec: Vec<f32>, staleness: f32) -> StoredEmbedding {
        StoredEmbedding {
            id,
            source_type: "symbol".to_string(),
            source_path: "src/lib.rs".to_string(),
            source_id: Some(id),
            staleness,
            chunk_text: "test chunk".to_string(),
            last_commit: None,
            vector: vec,
        }
    }

    #[test]
    fn test_rerank_orders_by_score() {
        let query = vec![1.0_f32, 0.0, 0.0];
        let stored = vec![
            make_stored(1, vec![1.0, 0.0, 0.0], 0.0), // sim=1.0, staleness=0 → score=1.0
            make_stored(2, vec![0.0, 1.0, 0.0], 0.0), // sim=0.0 → score=0.0
            make_stored(3, vec![0.9, 0.4, 0.0], 0.5), // sim≈0.91, penalised → score≈0.78
        ];
        let hits = rerank(&query, stored, 3);
        let ids: Vec<i64> = hits.iter().map(|h| h.id).collect();
        assert_eq!(ids, vec![1, 3, 2], "hits must be sorted by final score descending");
        assert!(hits.windows(2).all(|w| w[0].score >= w[1].score));
    }

    #[test]
    fn test_rerank_staleness_demotes_identical_vectors() {
        let query = vec![1.0_f32, 0.0];
        let stored = vec![
            make_stored(1, vec![1.0, 0.0], 1.0), // same similarity, fully stale
            make_stored(2, vec![1.0, 0.0], 0.0), // same similarity, fresh
        ];
        let hits = rerank(&query, stored, 2);
        assert_eq!(hits[0].id, 2, "fresh chunk should outrank an identical stale one");
        assert_eq!(hits[1].id, 1);
        // `similarity` is reported before the penalty, `score` after it.
        assert!((hits[0].similarity - hits[1].similarity).abs() < 1e-6);
        assert!(
            (hits[1].score - hits[1].similarity * (1.0 - STALENESS_WEIGHT)).abs() < 1e-6,
            "score must equal similarity * (1 - STALENESS_WEIGHT * staleness)"
        );
    }

    #[test]
    fn test_rerank_respects_top_k() {
        let query = vec![1.0_f32, 0.0];
        let stored = (0..10)
            .map(|i| make_stored(i, vec![1.0, 0.0], 0.0))
            .collect();
        let hits = rerank(&query, stored, 3);
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_rerank_top_k_edge_cases() {
        let query = vec![1.0_f32, 0.0];
        assert!(rerank(&query, Vec::new(), 5).is_empty(), "empty input yields no hits");

        let stored = vec![make_stored(1, vec![1.0, 0.0], 0.0)];
        assert_eq!(rerank(&query, stored, 5).len(), 1, "top_k larger than input keeps all");

        let stored = vec![make_stored(1, vec![1.0, 0.0], 0.0)];
        assert!(rerank(&query, stored, 0).is_empty(), "top_k of zero yields no hits");
    }
}