// noether_engine/index/search.rs
#![warn(clippy::unwrap_used)]
#![cfg_attr(test, allow(clippy::unwrap_used))]

use super::embedding::Embedding;
use noether_core::stage::StageId;
6
7#[derive(Debug, Clone)]
9pub struct IndexEntry {
10 pub stage_id: StageId,
11 pub embedding: Embedding,
12}
13
14#[derive(Debug, Clone, Default)]
16pub struct SubIndex {
17 entries: Vec<IndexEntry>,
18}
19
20#[derive(Debug, Clone)]
22pub struct SubSearchResult {
23 pub stage_id: StageId,
24 pub score: f32,
25}
26
27impl SubIndex {
28 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn add(&mut self, stage_id: StageId, embedding: Embedding) {
33 self.entries.push(IndexEntry {
34 stage_id,
35 embedding,
36 });
37 }
38
39 pub fn remove(&mut self, stage_id: &StageId) {
40 self.entries.retain(|e| &e.stage_id != stage_id);
41 }
42
43 pub fn len(&self) -> usize {
44 self.entries.len()
45 }
46
47 pub fn is_empty(&self) -> bool {
48 self.entries.is_empty()
49 }
50
51 pub fn entries(&self) -> &[IndexEntry] {
53 &self.entries
54 }
55
56 pub fn search(&self, query: &Embedding, top_k: usize) -> Vec<SubSearchResult> {
59 let mut scored: Vec<SubSearchResult> = self
60 .entries
61 .iter()
62 .map(|entry| SubSearchResult {
63 stage_id: entry.stage_id.clone(),
64 score: cosine_similarity(query, &entry.embedding),
65 })
66 .collect();
67
68 scored.sort_by(|a, b| {
70 b.score
71 .partial_cmp(&a.score)
72 .unwrap_or(std::cmp::Ordering::Equal)
73 });
74 scored.truncate(top_k);
75 scored
76 }
77}
78
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when either vector has zero norm (including empty slices).
/// For non-degenerate inputs the result is clamped to `[-1.0, 1.0]`, so
/// floating-point rounding cannot push it slightly outside the valid range.
/// A NaN result (e.g. from NaN components) is returned unchanged.
///
/// The slices are expected to have the same length. This is checked only by a
/// debug assertion; without it, the dot product covers just the common prefix
/// while each norm is taken over the full slice.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    // Rounding in the norms can yield values such as 1.0000001; clamp them.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
91
#[cfg(test)]
mod tests {
    use super::*;

    // --- cosine_similarity -------------------------------------------------

    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    // --- SubIndex ----------------------------------------------------------

    #[test]
    fn subindex_search_returns_top_k() {
        let mut idx = SubIndex::new();
        for i in 0..10 {
            // One-hot embeddings cycling through the four axes.
            let mut emb = vec![0.0; 4];
            emb[i % 4] = 1.0;
            idx.add(StageId(format!("s{i}")), emb);
        }
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = idx.search(&query, 3);
        assert_eq!(results.len(), 3);
        // Exactly three entries (s0, s4, s8) align with the query axis, so
        // every returned hit must be a perfect match.
        assert!(results.iter().all(|r| (r.score - 1.0).abs() < 1e-6));
    }

    #[test]
    fn subindex_search_sorted_by_score() {
        let mut idx = SubIndex::new();
        idx.add(StageId("a".into()), vec![1.0, 0.0]);
        idx.add(StageId("b".into()), vec![0.5, 0.5]);
        idx.add(StageId("c".into()), vec![0.0, 1.0]);
        let query = vec![1.0, 0.0];
        let results = idx.search(&query, 3);
        assert!(results[0].score >= results[1].score);
        assert!(results[1].score >= results[2].score);
        // The best and worst matches are unambiguous for this query.
        assert_eq!(results[0].stage_id, StageId("a".into()));
        assert_eq!(results[2].stage_id, StageId("c".into()));
    }

    #[test]
    fn subindex_empty_returns_empty() {
        let idx = SubIndex::new();
        assert!(idx.is_empty());
        let results = idx.search(&vec![1.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn subindex_remove() {
        let mut idx = SubIndex::new();
        idx.add(StageId("a".into()), vec![1.0, 0.0]);
        idx.add(StageId("b".into()), vec![0.0, 1.0]);
        assert_eq!(idx.len(), 2);
        idx.remove(&StageId("a".into()));
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.entries()[0].stage_id, StageId("b".into()));
        // Removing an id that is not present leaves the index untouched.
        idx.remove(&StageId("missing".into()));
        assert_eq!(idx.len(), 1);
    }
}