Skip to main content

noether_engine/index/
search.rs

1use super::embedding::Embedding;
2use noether_core::stage::StageId;
3
4/// A single entry in a sub-index.
5#[derive(Debug, Clone)]
6pub struct IndexEntry {
7    pub stage_id: StageId,
8    pub embedding: Embedding,
9}
10
11/// One of the three sub-indexes (signature, semantic, or example).
12#[derive(Debug, Clone, Default)]
13pub struct SubIndex {
14    entries: Vec<IndexEntry>,
15}
16
17/// A search result from a single sub-index.
18#[derive(Debug, Clone)]
19pub struct SubSearchResult {
20    pub stage_id: StageId,
21    pub score: f32,
22}
23
24impl SubIndex {
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    pub fn add(&mut self, stage_id: StageId, embedding: Embedding) {
30        self.entries.push(IndexEntry {
31            stage_id,
32            embedding,
33        });
34    }
35
36    pub fn remove(&mut self, stage_id: &StageId) {
37        self.entries.retain(|e| &e.stage_id != stage_id);
38    }
39
40    pub fn len(&self) -> usize {
41        self.entries.len()
42    }
43
44    pub fn is_empty(&self) -> bool {
45        self.entries.is_empty()
46    }
47
48    /// Read-only access to all entries (used by near-duplicate scanning).
49    pub fn entries(&self) -> &[IndexEntry] {
50        &self.entries
51    }
52
53    /// Brute-force search: compute cosine similarity against all entries,
54    /// return top-k results sorted by descending score.
55    pub fn search(&self, query: &Embedding, top_k: usize) -> Vec<SubSearchResult> {
56        let mut scored: Vec<SubSearchResult> = self
57            .entries
58            .iter()
59            .map(|entry| SubSearchResult {
60                stage_id: entry.stage_id.clone(),
61                score: cosine_similarity(query, &entry.embedding),
62            })
63            .collect();
64
65        // Sort descending by score
66        scored.sort_by(|a, b| {
67            b.score
68                .partial_cmp(&a.score)
69                .unwrap_or(std::cmp::Ordering::Equal)
70        });
71        scored.truncate(top_k);
72        scored
73    }
74}
75
/// Cosine similarity between two vectors. Returns value in [-1, 1].
/// For normalized vectors, this reduces to dot product.
///
/// Returns `0.0` when either vector has zero norm (including empty input).
/// If any component is NaN or infinite the result is NaN.
///
/// The slices are expected to have the same length; this is only checked in
/// debug builds. In release builds the dot product covers the common prefix
/// while each norm covers its whole slice, which is equivalent to padding the
/// shorter vector with zeros.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    // Accumulate in f64: the square of a finite f32 always fits in f64, so
    // neither very large nor very small components overflow to infinity or
    // underflow to zero the way f32 accumulation does.
    let norm = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum::<f64>()
            .sqrt()
    };

    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = norm(a);
    let norm_b = norm(b);
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio marginally past +/-1; clamp so the
    // documented range holds. The value is within [-1, 1] (or NaN), so the
    // narrowing cast to f32 cannot overflow.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0) as f32
}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing similarity scores against exact values.
    const EPS: f32 = 1e-6;

    /// Builds a `StageId` from a string slice.
    fn sid(name: &str) -> StageId {
        StageId(name.into())
    }

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < EPS);
    }

    /// Perpendicular vectors have similarity 0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let (x_axis, y_axis) = (vec![1.0, 0.0], vec![0.0, 1.0]);
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < EPS);
    }

    /// Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn cosine_opposite_vectors() {
        let (forward, backward) = (vec![1.0, 0.0], vec![-1.0, 0.0]);
        assert!((cosine_similarity(&forward, &backward) - (-1.0)).abs() < EPS);
    }

    /// Searching ten entries with `top_k = 3` yields exactly three results.
    #[test]
    fn subindex_search_returns_top_k() {
        let mut index = SubIndex::new();
        for i in 0..10 {
            // One-hot vector cycling through the four dimensions.
            let mut one_hot = vec![0.0; 4];
            one_hot[i % 4] = 1.0;
            index.add(StageId(format!("s{i}")), one_hot);
        }

        let hits = index.search(&vec![1.0, 0.0, 0.0, 0.0], 3);
        assert_eq!(hits.len(), 3);
    }

    /// Results come back ordered from highest to lowest score.
    #[test]
    fn subindex_search_sorted_by_score() {
        let mut index = SubIndex::new();
        for (name, embedding) in [
            ("a", vec![1.0, 0.0]),
            ("b", vec![0.5, 0.5]),
            ("c", vec![0.0, 1.0]),
        ] {
            index.add(sid(name), embedding);
        }

        let hits = index.search(&vec![1.0, 0.0], 3);
        // Check both adjacent pairs among the first three results.
        assert!(hits[..3].windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    /// Searching an empty index produces no results.
    #[test]
    fn subindex_empty_returns_empty() {
        let query = vec![1.0, 0.0];
        assert!(SubIndex::new().search(&query, 5).is_empty());
    }

    /// Removing an id drops its entry and leaves the others in place.
    #[test]
    fn subindex_remove() {
        let mut index = SubIndex::new();
        index.add(sid("a"), vec![1.0, 0.0]);
        index.add(sid("b"), vec![0.0, 1.0]);
        assert_eq!(index.len(), 2);

        index.remove(&sid("a"));
        assert_eq!(index.len(), 1);
    }
}