Skip to main content

noether_engine/index/
search.rs

1#![warn(clippy::unwrap_used)]
2#![cfg_attr(test, allow(clippy::unwrap_used))]
3
4use super::embedding::Embedding;
5use noether_core::stage::StageId;
6
7/// A single entry in a sub-index.
8#[derive(Debug, Clone)]
9pub struct IndexEntry {
10    pub stage_id: StageId,
11    pub embedding: Embedding,
12}
13
14/// One of the three sub-indexes (signature, semantic, or example).
15#[derive(Debug, Clone, Default)]
16pub struct SubIndex {
17    entries: Vec<IndexEntry>,
18}
19
20/// A search result from a single sub-index.
21#[derive(Debug, Clone)]
22pub struct SubSearchResult {
23    pub stage_id: StageId,
24    pub score: f32,
25}
26
27impl SubIndex {
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    pub fn add(&mut self, stage_id: StageId, embedding: Embedding) {
33        self.entries.push(IndexEntry {
34            stage_id,
35            embedding,
36        });
37    }
38
39    pub fn remove(&mut self, stage_id: &StageId) {
40        self.entries.retain(|e| &e.stage_id != stage_id);
41    }
42
43    pub fn len(&self) -> usize {
44        self.entries.len()
45    }
46
47    pub fn is_empty(&self) -> bool {
48        self.entries.is_empty()
49    }
50
51    /// Read-only access to all entries (used by near-duplicate scanning).
52    pub fn entries(&self) -> &[IndexEntry] {
53        &self.entries
54    }
55
56    /// Brute-force search: compute cosine similarity against all entries,
57    /// return top-k results sorted by descending score.
58    pub fn search(&self, query: &Embedding, top_k: usize) -> Vec<SubSearchResult> {
59        let mut scored: Vec<SubSearchResult> = self
60            .entries
61            .iter()
62            .map(|entry| SubSearchResult {
63                stage_id: entry.stage_id.clone(),
64                score: cosine_similarity(query, &entry.embedding),
65            })
66            .collect();
67
68        // Sort descending by score
69        scored.sort_by(|a, b| {
70            b.score
71                .partial_cmp(&a.score)
72                .unwrap_or(std::cmp::Ordering::Equal)
73        });
74        scored.truncate(top_k);
75        scored
76    }
77}
78
/// Cosine similarity between two vectors of equal length.
///
/// Returns a value in `[-1, 1]`; the result is clamped so floating-point
/// rounding cannot push it slightly outside that range. For normalized
/// vectors, this reduces to the dot product.
///
/// Returns `0.0` when either vector has zero norm (including empty slices).
/// If the inputs contain non-finite values the result may be NaN.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if `a` and `b` differ in
/// length.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "cosine_similarity requires vectors of equal length"
    );
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // A zero vector has no direction; define its similarity as 0 rather than
    // dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    // Rounding in the sums and square roots can yield e.g. 1.0000001 for
    // identical vectors; clamp to honour the documented range.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing a similarity with its expected value.
    const EPS: f32 = 1e-6;

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < EPS);
    }

    /// Perpendicular unit vectors have similarity 0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let x_axis = vec![1.0, 0.0];
        let y_axis = vec![0.0, 1.0];
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < EPS);
    }

    /// Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn cosine_opposite_vectors() {
        let forward = vec![1.0, 0.0];
        let backward = vec![-1.0, 0.0];
        let sim = cosine_similarity(&forward, &backward);
        assert!((sim + 1.0).abs() < EPS);
    }

    /// Searching ten one-hot entries with `top_k = 3` yields three results.
    #[test]
    fn subindex_search_returns_top_k() {
        let mut index = SubIndex::new();
        for i in 0..10usize {
            // One-hot vector with the hot dimension cycling through 0..4.
            let hot = i % 4;
            let emb = (0..4usize)
                .map(|dim| if dim == hot { 1.0 } else { 0.0 })
                .collect::<Vec<_>>();
            index.add(StageId(format!("s{i}")), emb);
        }

        let hits = index.search(&vec![1.0, 0.0, 0.0, 0.0], 3);
        assert_eq!(hits.len(), 3);
    }

    /// Results come back ordered from highest to lowest score.
    #[test]
    fn subindex_search_sorted_by_score() {
        let mut index = SubIndex::new();
        index.add(StageId("a".into()), vec![1.0, 0.0]);
        index.add(StageId("b".into()), vec![0.5, 0.5]);
        index.add(StageId("c".into()), vec![0.0, 1.0]);

        let hits = index.search(&vec![1.0, 0.0], 3);
        let scores: Vec<f32> = hits.iter().map(|hit| hit.score).collect();
        assert!(scores[0] >= scores[1]);
        assert!(scores[1] >= scores[2]);
    }

    /// Searching an empty sub-index produces no results.
    #[test]
    fn subindex_empty_returns_empty() {
        let index = SubIndex::new();
        let query = vec![1.0, 0.0];
        assert!(index.search(&query, 5).is_empty());
    }

    /// Removing an id drops its entry and leaves the others in place.
    #[test]
    fn subindex_remove() {
        let mut index = SubIndex::new();
        index.add(StageId("a".into()), vec![1.0, 0.0]);
        index.add(StageId("b".into()), vec![0.0, 1.0]);
        assert_eq!(index.len(), 2);

        let target = StageId("a".into());
        index.remove(&target);
        assert_eq!(index.len(), 1);
    }
}