1use std::path::Path;
8
9use anyhow::{Context, Result};
10use usearch::ffi::{IndexOptions, MetricKind, ScalarKind};
11use usearch::Index;
12
/// Vector dimensionality used by `HnswIndex::with_default_dims`.
// NOTE(review): 1536 presumably matches the embedding model in use — confirm.
const DEFAULT_DIMENSIONS: usize = 1536;

/// Value passed as `connectivity` in the `IndexOptions` built by `HnswIndex::new`.
const DEFAULT_CONNECTIVITY: usize = 16;

/// Value passed as `expansion_search` in the `IndexOptions` built by `HnswIndex::new`.
const DEFAULT_EXPANSION_SEARCH: usize = 128;

/// Value passed as `expansion_add` in the `IndexOptions` built by `HnswIndex::new`.
const DEFAULT_EXPANSION_ADD: usize = 128;
24
/// Wrapper around a usearch [`Index`] that also remembers the vector dimensionality.
///
/// The stored dimensionality is used to validate the length of every vector passed to
/// `add` and `search` before the call is forwarded to the underlying index.
pub struct HnswIndex {
    /// Underlying usearch index handle that all operations are forwarded to.
    index: Index,
    /// Number of components each stored vector and each query must have.
    dimensions: usize,
}
40
41impl std::fmt::Debug for HnswIndex {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("HnswIndex")
44 .field("dimensions", &self.dimensions)
45 .field("size", &self.len())
46 .finish()
47 }
48}
49
50impl HnswIndex {
51 pub fn new(dimensions: usize, capacity: usize) -> Result<Self> {
57 let options = IndexOptions {
58 dimensions,
59 metric: MetricKind::Cos,
60 quantization: ScalarKind::F32,
61 connectivity: DEFAULT_CONNECTIVITY,
62 expansion_add: DEFAULT_EXPANSION_ADD,
63 expansion_search: DEFAULT_EXPANSION_SEARCH,
64 multi: false,
65 };
66
67 let index = Index::new(&options).context("Failed to create HNSW index")?;
68 if capacity > 0 {
69 index
70 .reserve(capacity)
71 .map_err(|e| anyhow::anyhow!("Failed to reserve HNSW capacity: {}", e))?;
72 }
73
74 Ok(Self { index, dimensions })
75 }
76
77 pub fn with_default_dims(capacity: usize) -> Result<Self> {
79 Self::new(DEFAULT_DIMENSIONS, capacity)
80 }
81
82 pub fn add(&self, key: u64, vector: &[f32]) -> Result<()> {
87 anyhow::ensure!(
88 vector.len() == self.dimensions,
89 "Vector dimension mismatch: expected {}, got {}",
90 self.dimensions,
91 vector.len()
92 );
93 self.index
94 .add(key, vector)
95 .map_err(|e| anyhow::anyhow!("HNSW add failed for key {}: {}", key, e))?;
96 Ok(())
97 }
98
99 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
104 anyhow::ensure!(
105 query.len() == self.dimensions,
106 "Query dimension mismatch: expected {}, got {}",
107 self.dimensions,
108 query.len()
109 );
110 if k == 0 {
111 return Ok(Vec::new());
112 }
113
114 let results = self
115 .index
116 .search(query, k)
117 .map_err(|e| anyhow::anyhow!("HNSW search failed: {}", e))?;
118
119 Ok(results
120 .keys
121 .into_iter()
122 .zip(results.distances)
123 .filter(|(k, _)| *k != 0)
124 .collect())
125 }
126
127 pub fn remove(&self, key: u64) -> Result<()> {
129 self.index
130 .remove(key)
131 .map(|_| ())
132 .map_err(|e| anyhow::anyhow!("HNSW remove failed for key {}: {}", key, e))
133 }
134
135 pub fn contains(&self, key: u64) -> bool {
137 self.index.contains(key)
138 }
139
140 pub fn get(&self, key: u64) -> Option<Vec<f32>> {
142 let mut buffer = vec![0.0f32; self.dimensions];
143 match self.index.get(key, &mut buffer) {
144 Ok(count) if count > 0 => Some(buffer),
145 _ => None,
146 }
147 }
148
149 pub fn len(&self) -> usize {
151 self.index.size()
152 }
153
154 pub fn is_empty(&self) -> bool {
156 self.len() == 0
157 }
158
159 pub fn dimensions(&self) -> usize {
161 self.dimensions
162 }
163
164 pub fn save(&self, path: &Path) -> Result<()> {
166 let path_str = path.to_str().ok_or_else(|| {
167 anyhow::anyhow!("HNSW save path is not valid UTF-8: {}", path.display())
168 })?;
169 self.index
170 .save(path_str)
171 .map_err(|e| anyhow::anyhow!("HNSW save failed: {}", e))?;
172 Ok(())
173 }
174
175 pub fn load(path: &Path) -> Result<Self> {
179 let path_str = path.to_str().ok_or_else(|| {
180 anyhow::anyhow!("HNSW load path is not valid UTF-8: {}", path.display())
181 })?;
182 let index =
183 Index::restore(path_str).map_err(|e| anyhow::anyhow!("HNSW load failed: {}", e))?;
184 let dimensions = index.dimensions();
185 Ok(Self { index, dimensions })
186 }
187
188 pub fn reserve(&self, capacity: usize) -> Result<()> {
190 self.index
191 .reserve(capacity)
192 .map_err(|e| anyhow::anyhow!("HNSW reserve failed: {}", e))?;
193 Ok(())
194 }
195
196 pub fn rename(&self, from: u64, to: u64) -> Result<()> {
198 self.index
199 .rename(from, to)
200 .map(|_| ())
201 .map_err(|e| anyhow::anyhow!("HNSW rename failed: {} -> {}: {}", from, to, e))
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Adding three orthogonal unit vectors and querying with the first one must return
    /// that same vector at (near) zero distance.
    #[test]
    fn test_hnsw_add_and_search() {
        let index = HnswIndex::new(3, 100).unwrap();

        // Mutually orthogonal unit vectors, stored under keys 1, 2 and 3.
        let axes: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        for (key, axis) in (1u64..).zip(axes.iter()) {
            index.add(key, axis).unwrap();
        }

        assert_eq!(index.len(), 3);

        let nearest = index.search(&axes[0], 1).unwrap();
        assert_eq!(nearest.len(), 1);

        let (key, distance) = nearest[0];
        assert_eq!(key, 1);
        assert!(distance < 0.01, "Distance should be ~0, got {}", distance);
    }

    /// A top-2 query must return the exact match first and its closest neighbour second.
    #[test]
    fn test_hnsw_search_multiple() {
        let index = HnswIndex::new(4, 100).unwrap();

        let entries: [(u64, [f32; 4]); 4] = [
            (1, [1.0, 0.0, 0.0, 0.0]),
            (2, [0.9, 0.1, 0.0, 0.0]),
            (3, [0.0, 1.0, 0.0, 0.0]),
            (4, [0.0, 0.9, 0.1, 0.0]),
        ];
        for (key, vector) in entries.iter() {
            index.add(*key, vector).unwrap();
        }

        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);

        let (first_key, _) = results[0];
        let (second_key, _) = results[1];
        assert_eq!(first_key, 1);
        assert_eq!(second_key, 2);
    }

    /// A two-component vector must be rejected by a three-dimensional index.
    #[test]
    fn test_hnsw_dimension_mismatch() {
        let index = HnswIndex::new(3, 10).unwrap();
        assert!(index.add(1, &[1.0, 0.0]).is_err());
    }

    /// An index written with `save` must come back from `load` with the same size,
    /// dimensionality and search results.
    #[test]
    fn test_hnsw_save_and_load() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.usearch");

        // Build and persist inside a scope so the original index is dropped before reloading.
        {
            let original = HnswIndex::new(3, 100).unwrap();
            original.add(1, &[1.0, 0.0, 0.0]).unwrap();
            original.add(2, &[0.0, 1.0, 0.0]).unwrap();
            original.save(&path).unwrap();
        }

        let loaded = HnswIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.dimensions(), 3);

        let results = loaded.search(&[1.0, 0.0, 0.0], 1).unwrap();
        let (nearest_key, _) = results[0];
        assert_eq!(nearest_key, 1);
    }

    /// `contains` must be false before insertion, true for the inserted key, and false
    /// for a key that was never inserted.
    #[test]
    fn test_hnsw_contains() {
        let index = HnswIndex::new(3, 10).unwrap();
        assert!(!index.contains(1));

        index.add(1, &[1.0, 0.0, 0.0]).unwrap();

        assert!(index.contains(1));
        assert!(!index.contains(2));
    }

    /// Removing the only entry must bring the reported size back to zero.
    #[test]
    fn test_hnsw_remove() {
        let index = HnswIndex::new(3, 100).unwrap();

        index.add(1, &[1.0, 0.0, 0.0]).unwrap();
        assert_eq!(index.len(), 1);

        index.remove(1).unwrap();
        assert_eq!(index.len(), 0);
    }

    /// Searching an index with no entries must succeed and return nothing.
    #[test]
    fn test_hnsw_empty_search() {
        let index = HnswIndex::new(3, 10).unwrap();
        let hits = index.search(&[1.0, 0.0, 0.0], 5).unwrap();
        assert!(hits.is_empty());
    }

    /// The default-dimension constructor must produce a 1536-dimensional index.
    #[test]
    fn test_hnsw_with_default_dims() {
        let index = HnswIndex::with_default_dims(100).unwrap();
        assert_eq!(index.dimensions(), 1536);
    }
}