1use codemem_core::{CodememError, VectorBackend, VectorConfig, VectorStats};
7use std::collections::HashMap;
8use std::path::Path;
9use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
10
/// HNSW vector index built on usearch and addressed by string ids.
///
/// usearch stores vectors under `u64` keys, so this type keeps a bidirectional
/// id <-> key mapping next to the index and counts removed/replaced entries so
/// callers can decide when to compact.
pub struct HnswIndex {
    /// The underlying usearch index.
    index: Index,
    /// Settings (dimensions, metric, graph parameters) the index was built with.
    config: VectorConfig,
    /// External id -> usearch key.
    id_to_key: HashMap<String, u64>,
    /// usearch key -> external id; maintained as the inverse of `id_to_key`.
    key_to_id: HashMap<u64, String>,
    /// Next usearch key to hand out; incremented on every allocation and reset
    /// to 0 only by a rebuild.
    next_key: u64,
    /// Number of entries removed or replaced since the last rebuild or load.
    ghost_count: usize,
}
25
26impl HnswIndex {
27 pub fn new(config: VectorConfig) -> Result<Self, CodememError> {
29 let metric = match config.metric {
30 codemem_core::DistanceMetric::Cosine => MetricKind::Cos,
31 codemem_core::DistanceMetric::L2 => MetricKind::L2sq,
32 codemem_core::DistanceMetric::InnerProduct => MetricKind::IP,
33 };
34
35 let options = IndexOptions {
36 dimensions: config.dimensions,
37 metric,
38 quantization: ScalarKind::F32,
39 connectivity: config.m,
40 expansion_add: config.ef_construction,
41 expansion_search: config.ef_search,
42 multi: false,
43 };
44
45 let index = Index::new(&options).map_err(|e| CodememError::Vector(e.to_string()))?;
46
47 index
49 .reserve(10_000)
50 .map_err(|e| CodememError::Vector(e.to_string()))?;
51
52 Ok(Self {
53 index,
54 config,
55 id_to_key: HashMap::new(),
56 key_to_id: HashMap::new(),
57 next_key: 0,
58 ghost_count: 0,
59 })
60 }
61
62 pub fn with_defaults() -> Result<Self, CodememError> {
64 Self::new(VectorConfig::default())
65 }
66
67 pub fn len(&self) -> usize {
69 self.index.size()
70 }
71
72 pub fn is_empty(&self) -> bool {
74 self.len() == 0
75 }
76
77 fn allocate_key(&mut self) -> u64 {
78 let key = self.next_key;
79 self.next_key += 1;
80 key
81 }
82
83 pub fn rebuild_from_entries(
90 &mut self,
91 entries: &[(String, Vec<f32>)],
92 ) -> Result<(), CodememError> {
93 let new_index = Index::new(&IndexOptions {
94 dimensions: self.config.dimensions,
95 metric: match self.config.metric {
96 codemem_core::DistanceMetric::Cosine => MetricKind::Cos,
97 codemem_core::DistanceMetric::L2 => MetricKind::L2sq,
98 codemem_core::DistanceMetric::InnerProduct => MetricKind::IP,
99 },
100 quantization: ScalarKind::F32,
101 connectivity: self.config.m,
102 expansion_add: self.config.ef_construction,
103 expansion_search: self.config.ef_search,
104 multi: false,
105 })
106 .map_err(|e| CodememError::Vector(e.to_string()))?;
107
108 let capacity = entries.len().max(1024);
109 new_index
110 .reserve(capacity)
111 .map_err(|e| CodememError::Vector(e.to_string()))?;
112
113 self.index = new_index;
114 self.id_to_key.clear();
115 self.key_to_id.clear();
116 self.next_key = 0;
117 self.ghost_count = 0;
118
119 for (id, embedding) in entries {
120 self.insert(id, embedding)?;
121 }
122
123 Ok(())
124 }
125
126 pub fn needs_compaction(&self) -> bool {
128 let live = self.id_to_key.len();
129 live > 0 && self.ghost_count > live / 5
130 }
131
132 pub fn ghost_count(&self) -> usize {
134 self.ghost_count
135 }
136}
137
138impl VectorBackend for HnswIndex {
139 fn insert(&mut self, id: &str, embedding: &[f32]) -> Result<(), CodememError> {
140 if embedding.len() != self.config.dimensions {
141 return Err(CodememError::Vector(format!(
142 "Expected {} dimensions, got {}",
143 self.config.dimensions,
144 embedding.len()
145 )));
146 }
147
148 if let Some(&old_key) = self.id_to_key.get(id) {
150 self.index
151 .remove(old_key)
152 .map_err(|e| CodememError::Vector(e.to_string()))?;
153 self.key_to_id.remove(&old_key);
154 self.ghost_count += 1;
155 }
156
157 let key = self.allocate_key();
158
159 if self.index.size() >= self.index.capacity() {
161 let cap = self.index.capacity();
162 let new_cap = cap + 1024.max(cap / 4);
163 self.index
164 .reserve(new_cap)
165 .map_err(|e| CodememError::Vector(e.to_string()))?;
166 }
167
168 self.index
169 .add(key, embedding)
170 .map_err(|e| CodememError::Vector(e.to_string()))?;
171
172 self.id_to_key.insert(id.to_string(), key);
173 self.key_to_id.insert(key, id.to_string());
174
175 Ok(())
176 }
177
178 fn insert_batch(&mut self, items: &[(String, Vec<f32>)]) -> Result<(), CodememError> {
179 let needed = self.index.size() + items.len();
181 if needed > self.index.capacity() {
182 self.index
183 .reserve(needed)
184 .map_err(|e| CodememError::Vector(e.to_string()))?;
185 }
186 for (id, embedding) in items {
187 self.insert(id, embedding)?;
188 }
189 Ok(())
190 }
191
192 fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>, CodememError> {
193 if self.is_empty() {
194 return Ok(vec![]);
195 }
196
197 let results = self
198 .index
199 .search(query, k)
200 .map_err(|e| CodememError::Vector(e.to_string()))?;
201
202 let mut output = Vec::with_capacity(results.keys.len());
203 for (key, distance) in results.keys.iter().zip(results.distances.iter()) {
204 if let Some(id) = self.key_to_id.get(key) {
205 let similarity = 1.0 - distance;
207 output.push((id.clone(), similarity));
208 }
209 }
210
211 Ok(output)
212 }
213
214 fn remove(&mut self, id: &str) -> Result<bool, CodememError> {
215 if let Some(key) = self.id_to_key.remove(id) {
216 self.index
217 .remove(key)
218 .map_err(|e| CodememError::Vector(e.to_string()))?;
219 self.key_to_id.remove(&key);
220 self.ghost_count += 1;
221 Ok(true)
222 } else {
223 Ok(false)
224 }
225 }
226
227 fn save(&self, path: &Path) -> Result<(), CodememError> {
228 let path_str = path
229 .to_str()
230 .ok_or_else(|| CodememError::Vector("Path contains non-UTF-8 characters".into()))?;
231
232 let idmap_path = path.with_extension("idmap");
233
234 let map_data = serde_json::to_string(&IdMapping {
236 id_to_key: &self.id_to_key,
237 next_key: self.next_key,
238 })
239 .map_err(|e| CodememError::Vector(e.to_string()))?;
240
241 let tmp_idmap = path.with_extension("idmap.tmp");
243 std::fs::write(&tmp_idmap, map_data)?;
244
245 let tmp_idx = path.with_extension("idx.tmp");
246 let tmp_idx_str = tmp_idx.to_str().ok_or_else(|| {
247 CodememError::Vector("Temp path contains non-UTF-8 characters".into())
248 })?;
249 self.index
250 .save(tmp_idx_str)
251 .map_err(|e| CodememError::Vector(e.to_string()))?;
252
253 std::fs::rename(&tmp_idmap, &idmap_path)?;
255 std::fs::rename(&tmp_idx, path_str)?;
256
257 Ok(())
258 }
259
260 fn load(&mut self, path: &Path) -> Result<(), CodememError> {
261 let path_str = path
262 .to_str()
263 .ok_or_else(|| CodememError::Vector("Path contains non-UTF-8 characters".into()))?;
264 self.index
265 .load(path_str)
266 .map_err(|e| CodememError::Vector(e.to_string()))?;
267
268 let map_path = path.with_extension("idmap");
270 if map_path.exists() {
271 let map_data = std::fs::read_to_string(map_path)?;
272 let mapping: IdMappingOwned =
273 serde_json::from_str(&map_data).map_err(|e| CodememError::Vector(e.to_string()))?;
274
275 self.id_to_key = mapping.id_to_key;
276 self.key_to_id = self
277 .id_to_key
278 .iter()
279 .map(|(id, key)| (*key, id.clone()))
280 .collect();
281 self.next_key = mapping.next_key;
282 self.ghost_count = 0; }
284
285 Ok(())
286 }
287
288 fn stats(&self) -> VectorStats {
289 VectorStats {
290 count: self.len(),
291 dimensions: self.config.dimensions,
292 metric: format!("{:?}", self.config.metric),
293 memory_bytes: self.index.memory_usage(),
294 }
295 }
296}
297
298use serde::{Deserialize, Serialize};
299
/// Borrowed view of the id map, serialized to JSON when saving.
///
/// Field names must match [`IdMappingOwned`], which reads the same file back.
#[derive(Serialize)]
struct IdMapping<'a> {
    /// External id -> usearch key.
    id_to_key: &'a HashMap<String, u64>,
    /// Next usearch key to allocate.
    next_key: u64,
}

/// Owned form of the id map, deserialized from JSON when loading.
///
/// Field names must match [`IdMapping`], which writes the file.
#[derive(Deserialize)]
struct IdMappingOwned {
    /// External id -> usearch key.
    id_to_key: HashMap<String, u64>,
    /// Next usearch key to allocate.
    next_key: u64,
}
311
/// Cosine similarity of `a` and `b`, with all accumulation done in `f64`.
///
/// Returns `0.0` when the slices differ in length, are empty, or when the
/// product of their magnitudes is below `1e-12`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| (f64::from(x), f64::from(y)))
        .fold((0.0f64, 0.0f64, 0.0f64), |(d, na, nb), (x, y)| {
            (d + x * y, na + x * x, nb + y * y)
        });

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude < 1e-12 {
        0.0
    } else {
        dot / magnitude
    }
}
334
335#[cfg(test)]
336#[path = "tests/vector_tests.rs"]
337mod tests;