// noether_engine/index/search.rs
use super::embedding::Embedding;
use noether_core::stage::StageId;

4#[derive(Debug, Clone)]
6pub struct IndexEntry {
7 pub stage_id: StageId,
8 pub embedding: Embedding,
9}
10
11#[derive(Debug, Clone, Default)]
13pub struct SubIndex {
14 entries: Vec<IndexEntry>,
15}
16
17#[derive(Debug, Clone)]
19pub struct SubSearchResult {
20 pub stage_id: StageId,
21 pub score: f32,
22}
23
24impl SubIndex {
25 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn add(&mut self, stage_id: StageId, embedding: Embedding) {
30 self.entries.push(IndexEntry {
31 stage_id,
32 embedding,
33 });
34 }
35
36 pub fn remove(&mut self, stage_id: &StageId) {
37 self.entries.retain(|e| &e.stage_id != stage_id);
38 }
39
40 pub fn len(&self) -> usize {
41 self.entries.len()
42 }
43
44 pub fn is_empty(&self) -> bool {
45 self.entries.is_empty()
46 }
47
48 pub fn entries(&self) -> &[IndexEntry] {
50 &self.entries
51 }
52
53 pub fn search(&self, query: &Embedding, top_k: usize) -> Vec<SubSearchResult> {
56 let mut scored: Vec<SubSearchResult> = self
57 .entries
58 .iter()
59 .map(|entry| SubSearchResult {
60 stage_id: entry.stage_id.clone(),
61 score: cosine_similarity(query, &entry.embedding),
62 })
63 .collect();
64
65 scored.sort_by(|a, b| {
67 b.score
68 .partial_cmp(&a.score)
69 .unwrap_or(std::cmp::Ordering::Equal)
70 });
71 scored.truncate(top_k);
72 scored
73 }
74}
75
/// Computes the cosine similarity of two vectors.
///
/// The result lies in `[-1.0, 1.0]`: `1.0` for vectors pointing the same way,
/// `0.0` for orthogonal vectors and `-1.0` for opposite ones. If either vector
/// has zero magnitude (including the empty slice) the function returns `0.0`.
/// NaN components propagate to a NaN result.
///
/// Sums are accumulated in `f64`: squaring an `f32` component around `1e20`
/// overflows `f32` (producing NaN), and squaring one around `1e-30`
/// underflows to zero (producing a spurious zero norm). Neither can happen
/// in `f64` for any finite `f32` input.
///
/// # Panics
///
/// In debug builds, panics if `a` and `b` differ in length. In release builds
/// the dot product covers only the common prefix while each norm still uses
/// the full slice.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_sq_a: f64 = a.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
    let norm_sq_b: f64 = b.iter().map(|&x| f64::from(x) * f64::from(x)).sum();

    if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
        return 0.0;
    }

    let cosine = dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt());
    // Rounding can push the ratio a hair outside the mathematical range;
    // clamping leaves NaN untouched. The narrowing cast is safe because the
    // value is within [-1, 1].
    cosine.clamp(-1.0, 1.0) as f32
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A vector compared with itself has similarity 1.
    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    /// Perpendicular vectors have similarity 0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    /// Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    /// `search` returns at most `top_k` hits, and they are the best ones.
    #[test]
    fn subindex_search_returns_top_k() {
        let mut idx = SubIndex::new();
        for i in 0..10 {
            let mut emb = vec![0.0; 4];
            emb[i % 4] = 1.0;
            idx.add(StageId(format!("s{i}")), emb);
        }
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = idx.search(&query, 3);
        assert_eq!(results.len(), 3);
        // Exactly three entries (s0, s4, s8) lie along the query axis, so
        // every returned hit must be a perfect match.
        assert!(results.iter().all(|r| (r.score - 1.0).abs() < 1e-6));
    }

    /// Hits come back ordered from most to least similar.
    #[test]
    fn subindex_search_sorted_by_score() {
        let mut idx = SubIndex::new();
        idx.add(StageId("a".into()), vec![1.0, 0.0]);
        idx.add(StageId("b".into()), vec![0.5, 0.5]);
        idx.add(StageId("c".into()), vec![0.0, 1.0]);
        let query = vec![1.0, 0.0];
        let results = idx.search(&query, 3);
        assert!(results[0].score >= results[1].score);
        assert!(results[1].score >= results[2].score);
        // Pin the actual ranking, not just the monotonicity of the scores.
        assert_eq!(results[0].stage_id, StageId("a".into()));
        assert_eq!(results[1].stage_id, StageId("b".into()));
        assert_eq!(results[2].stage_id, StageId("c".into()));
    }

    /// Searching an empty index yields no hits.
    #[test]
    fn subindex_empty_returns_empty() {
        let idx = SubIndex::new();
        let results = idx.search(&vec![1.0, 0.0], 5);
        assert!(results.is_empty());
    }

    /// `remove` drops the requested entry and leaves the others in place.
    #[test]
    fn subindex_remove() {
        let mut idx = SubIndex::new();
        idx.add(StageId("a".into()), vec![1.0, 0.0]);
        idx.add(StageId("b".into()), vec![0.0, 1.0]);
        assert_eq!(idx.len(), 2);
        idx.remove(&StageId("a".into()));
        assert_eq!(idx.len(), 1);
        // The surviving entry must be the one that was not removed.
        assert_eq!(idx.entries()[0].stage_id, StageId("b".into()));
    }
}