1use codemem_core::{CodememError, VectorBackend, VectorConfig, VectorStats};
7use std::collections::HashMap;
8use std::path::Path;
9use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
10
11pub struct HnswIndex {
13 index: Index,
14 config: VectorConfig,
15 id_to_key: HashMap<String, u64>,
17 key_to_id: HashMap<u64, String>,
19 next_key: u64,
21}
22
23impl HnswIndex {
24 pub fn new(config: VectorConfig) -> Result<Self, CodememError> {
26 let metric = match config.metric {
27 codemem_core::DistanceMetric::Cosine => MetricKind::Cos,
28 codemem_core::DistanceMetric::L2 => MetricKind::L2sq,
29 codemem_core::DistanceMetric::InnerProduct => MetricKind::IP,
30 };
31
32 let options = IndexOptions {
33 dimensions: config.dimensions,
34 metric,
35 quantization: ScalarKind::F32,
36 connectivity: config.m,
37 expansion_add: config.ef_construction,
38 expansion_search: config.ef_search,
39 multi: false,
40 };
41
42 let index = Index::new(&options).map_err(|e| CodememError::Vector(e.to_string()))?;
43
44 index
46 .reserve(10_000)
47 .map_err(|e| CodememError::Vector(e.to_string()))?;
48
49 Ok(Self {
50 index,
51 config,
52 id_to_key: HashMap::new(),
53 key_to_id: HashMap::new(),
54 next_key: 0,
55 })
56 }
57
58 pub fn with_defaults() -> Result<Self, CodememError> {
60 Self::new(VectorConfig::default())
61 }
62
63 pub fn len(&self) -> usize {
65 self.index.size()
66 }
67
68 pub fn is_empty(&self) -> bool {
70 self.len() == 0
71 }
72
73 fn allocate_key(&mut self) -> u64 {
74 let key = self.next_key;
75 self.next_key += 1;
76 key
77 }
78}
79
80impl VectorBackend for HnswIndex {
81 fn insert(&mut self, id: &str, embedding: &[f32]) -> Result<(), CodememError> {
82 if embedding.len() != self.config.dimensions {
83 return Err(CodememError::Vector(format!(
84 "Expected {} dimensions, got {}",
85 self.config.dimensions,
86 embedding.len()
87 )));
88 }
89
90 if let Some(&old_key) = self.id_to_key.get(id) {
92 self.index
93 .remove(old_key)
94 .map_err(|e| CodememError::Vector(e.to_string()))?;
95 self.key_to_id.remove(&old_key);
96 }
97
98 let key = self.allocate_key();
99
100 if self.index.size() >= self.index.capacity() {
102 let new_cap = self.index.capacity() * 2;
103 self.index
104 .reserve(new_cap)
105 .map_err(|e| CodememError::Vector(e.to_string()))?;
106 }
107
108 self.index
109 .add(key, embedding)
110 .map_err(|e| CodememError::Vector(e.to_string()))?;
111
112 self.id_to_key.insert(id.to_string(), key);
113 self.key_to_id.insert(key, id.to_string());
114
115 Ok(())
116 }
117
118 fn insert_batch(&mut self, items: &[(String, Vec<f32>)]) -> Result<(), CodememError> {
119 for (id, embedding) in items {
120 self.insert(id, embedding)?;
121 }
122 Ok(())
123 }
124
125 fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>, CodememError> {
126 if self.is_empty() {
127 return Ok(vec![]);
128 }
129
130 let results = self
131 .index
132 .search(query, k)
133 .map_err(|e| CodememError::Vector(e.to_string()))?;
134
135 let mut output = Vec::with_capacity(results.keys.len());
136 for (key, distance) in results.keys.iter().zip(results.distances.iter()) {
137 if let Some(id) = self.key_to_id.get(key) {
138 let similarity = 1.0 - distance;
140 output.push((id.clone(), similarity));
141 }
142 }
143
144 Ok(output)
145 }
146
147 fn remove(&mut self, id: &str) -> Result<bool, CodememError> {
148 if let Some(key) = self.id_to_key.remove(id) {
149 self.index
150 .remove(key)
151 .map_err(|e| CodememError::Vector(e.to_string()))?;
152 self.key_to_id.remove(&key);
153 Ok(true)
154 } else {
155 Ok(false)
156 }
157 }
158
159 fn save(&self, path: &Path) -> Result<(), CodememError> {
160 self.index
161 .save(path.to_str().unwrap_or("hnsw.index"))
162 .map_err(|e| CodememError::Vector(e.to_string()))?;
163
164 let map_path = path.with_extension("idmap");
166 let map_data = serde_json::to_string(&IdMapping {
167 id_to_key: &self.id_to_key,
168 next_key: self.next_key,
169 })
170 .map_err(|e| CodememError::Vector(e.to_string()))?;
171
172 std::fs::write(map_path, map_data)?;
173 Ok(())
174 }
175
176 fn load(&mut self, path: &Path) -> Result<(), CodememError> {
177 self.index
178 .load(path.to_str().unwrap_or("hnsw.index"))
179 .map_err(|e| CodememError::Vector(e.to_string()))?;
180
181 let map_path = path.with_extension("idmap");
183 if map_path.exists() {
184 let map_data = std::fs::read_to_string(map_path)?;
185 let mapping: IdMappingOwned =
186 serde_json::from_str(&map_data).map_err(|e| CodememError::Vector(e.to_string()))?;
187
188 self.id_to_key = mapping.id_to_key;
189 self.key_to_id = self
190 .id_to_key
191 .iter()
192 .map(|(id, key)| (*key, id.clone()))
193 .collect();
194 self.next_key = mapping.next_key;
195 }
196
197 Ok(())
198 }
199
200 fn stats(&self) -> VectorStats {
201 VectorStats {
202 count: self.len(),
203 dimensions: self.config.dimensions,
204 metric: format!("{:?}", self.config.metric),
205 memory_bytes: self.index.memory_usage(),
206 }
207 }
208}
209
210use serde::{Deserialize, Serialize};
211
212#[derive(Serialize)]
213struct IdMapping<'a> {
214 id_to_key: &'a HashMap<String, u64>,
215 next_key: u64,
216}
217
218#[derive(Deserialize)]
219struct IdMappingOwned {
220 id_to_key: HashMap<String, u64>,
221 next_key: u64,
222}
223
224#[cfg(test)]
225#[path = "tests/lib_tests.rs"]
226mod tests;