1use serde::{Deserialize, Serialize};
19
20use manifoldb_core::EntityId;
21use manifoldb_storage::{Cursor, Transaction};
22
23use crate::distance::DistanceMetric;
24use crate::error::VectorError;
25use crate::types::Embedding;
26
27use super::config::HnswConfig;
28
/// Name of the storage table in which serialized [`HnswIndexEntry`] records are kept.
pub const HNSW_REGISTRY_TABLE: &str = "hnsw_registry";
31
/// Persisted description of a single HNSW index: what it indexes and the
/// parameters needed to rebuild its [`HnswConfig`].
///
/// NOTE(review): entries are encoded with bincode (see `to_bytes`), which is
/// not a self-describing format. Field order is therefore presumably part of
/// the stored layout, and it should be confirmed that the `#[serde(default)]`
/// attributes below really let older, shorter records decode.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswIndexEntry {
    /// Index name; its bytes are used as the registry key.
    pub name: String,
    /// Table the index belongs to (the collection name for named-vector indexes).
    pub table: String,
    /// Column the index covers (the vector name for named-vector indexes).
    pub column: String,
    /// Dimension recorded for the indexed vectors.
    pub dimension: usize,
    /// Distance metric, stored in its serializable form.
    pub distance_metric: DistanceMetricSerde,
    /// Copy of `HnswConfig::m`.
    pub m: usize,
    /// Copy of `HnswConfig::m_max0`.
    pub m_max0: usize,
    /// Copy of `HnswConfig::ef_construction`.
    pub ef_construction: usize,
    /// Copy of `HnswConfig::ef_search`.
    pub ef_search: usize,
    /// Raw IEEE-754 bits of `HnswConfig::ml` (`f64::to_bits`), restored with
    /// `f64::from_bits` in [`HnswIndexEntry::config`].
    pub ml_bits: u64,
    /// Copy of `HnswConfig::pq_segments`; serde falls back to `0` when absent.
    #[serde(default)]
    pub pq_segments: usize,
    /// Copy of `HnswConfig::pq_centroids`; serde falls back to
    /// `default_pq_centroids()` (256) when absent.
    #[serde(default = "default_pq_centroids")]
    pub pq_centroids: usize,
    /// Collection name, set only for entries built with
    /// [`HnswIndexEntry::for_named_vector`]; `None` otherwise.
    #[serde(default)]
    pub collection_name: Option<String>,
    /// Vector name, set only for entries built with
    /// [`HnswIndexEntry::for_named_vector`]; `None` otherwise.
    #[serde(default)]
    pub vector_name: Option<String>,
}
70
/// Serde fallback for the `pq_centroids` field of a registry entry.
///
/// Returns the centroid count (256) used when the field is missing.
fn default_pq_centroids() -> usize {
    const FALLBACK_PQ_CENTROIDS: usize = 256;
    FALLBACK_PQ_CENTROIDS
}
74
/// Serializable mirror of [`DistanceMetric`], stored inside [`HnswIndexEntry`].
///
/// Converts to and from [`DistanceMetric`] variant-for-variant via `From`.
///
/// NOTE(review): entries are written with bincode, which presumably encodes a
/// variant by its position; confirm before reordering or removing variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetricSerde {
    /// Mirrors `DistanceMetric::Euclidean`.
    Euclidean,
    /// Mirrors `DistanceMetric::Cosine`.
    Cosine,
    /// Mirrors `DistanceMetric::DotProduct`.
    DotProduct,
    /// Mirrors `DistanceMetric::Manhattan`.
    Manhattan,
    /// Mirrors `DistanceMetric::Chebyshev`.
    Chebyshev,
}
89
90impl From<DistanceMetric> for DistanceMetricSerde {
91 fn from(metric: DistanceMetric) -> Self {
92 match metric {
93 DistanceMetric::Euclidean => Self::Euclidean,
94 DistanceMetric::Cosine => Self::Cosine,
95 DistanceMetric::DotProduct => Self::DotProduct,
96 DistanceMetric::Manhattan => Self::Manhattan,
97 DistanceMetric::Chebyshev => Self::Chebyshev,
98 }
99 }
100}
101
102impl From<DistanceMetricSerde> for DistanceMetric {
103 fn from(metric: DistanceMetricSerde) -> Self {
104 match metric {
105 DistanceMetricSerde::Euclidean => Self::Euclidean,
106 DistanceMetricSerde::Cosine => Self::Cosine,
107 DistanceMetricSerde::DotProduct => Self::DotProduct,
108 DistanceMetricSerde::Manhattan => Self::Manhattan,
109 DistanceMetricSerde::Chebyshev => Self::Chebyshev,
110 }
111 }
112}
113
114impl HnswIndexEntry {
115 #[must_use]
117 pub fn new(
118 name: impl Into<String>,
119 table: impl Into<String>,
120 column: impl Into<String>,
121 dimension: usize,
122 distance_metric: DistanceMetric,
123 config: &HnswConfig,
124 ) -> Self {
125 Self {
126 name: name.into(),
127 table: table.into(),
128 column: column.into(),
129 dimension,
130 distance_metric: distance_metric.into(),
131 m: config.m,
132 m_max0: config.m_max0,
133 ef_construction: config.ef_construction,
134 ef_search: config.ef_search,
135 ml_bits: config.ml.to_bits(),
136 pq_segments: config.pq_segments,
137 pq_centroids: config.pq_centroids,
138 collection_name: None,
139 vector_name: None,
140 }
141 }
142
143 #[must_use]
147 pub fn for_named_vector(
148 collection: impl Into<String>,
149 vector: impl Into<String>,
150 dimension: usize,
151 distance_metric: DistanceMetric,
152 config: &HnswConfig,
153 ) -> Self {
154 let collection_name = collection.into();
155 let vector_name = vector.into();
156 let index_name = format!("{}_{}_hnsw", collection_name, vector_name);
157
158 Self {
159 name: index_name,
160 table: collection_name.clone(),
161 column: vector_name.clone(),
162 dimension,
163 distance_metric: distance_metric.into(),
164 m: config.m,
165 m_max0: config.m_max0,
166 ef_construction: config.ef_construction,
167 ef_search: config.ef_search,
168 ml_bits: config.ml.to_bits(),
169 pq_segments: config.pq_segments,
170 pq_centroids: config.pq_centroids,
171 collection_name: Some(collection_name),
172 vector_name: Some(vector_name),
173 }
174 }
175
176 #[must_use]
178 pub fn is_named_vector_index(&self) -> bool {
179 self.collection_name.is_some() && self.vector_name.is_some()
180 }
181
182 #[must_use]
184 pub fn collection(&self) -> Option<&str> {
185 self.collection_name.as_deref()
186 }
187
188 #[must_use]
190 pub fn vector(&self) -> Option<&str> {
191 self.vector_name.as_deref()
192 }
193
194 #[must_use]
196 pub fn config(&self) -> HnswConfig {
197 HnswConfig {
198 m: self.m,
199 m_max0: self.m_max0,
200 ef_construction: self.ef_construction,
201 ef_search: self.ef_search,
202 ml: f64::from_bits(self.ml_bits),
203 pq_segments: self.pq_segments,
204 pq_centroids: self.pq_centroids,
205 pq_training_samples: 1000, }
207 }
208
209 #[must_use]
211 pub fn distance_metric(&self) -> DistanceMetric {
212 self.distance_metric.into()
213 }
214
215 pub fn to_bytes(&self) -> Result<Vec<u8>, VectorError> {
217 bincode::serde::encode_to_vec(self, bincode::config::standard())
218 .map_err(|e| VectorError::Encoding(format!("failed to serialize index entry: {e}")))
219 }
220
221 pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
223 bincode::serde::decode_from_slice(bytes, bincode::config::standard())
224 .map(|(entry, _)| entry)
225 .map_err(|e| VectorError::Encoding(format!("failed to deserialize index entry: {e}")))
226 }
227}
228
/// Stateless namespace for the functions that read and write
/// [`HnswIndexEntry`] records in [`HNSW_REGISTRY_TABLE`] through a transaction.
pub struct HnswRegistry;
237
238impl HnswRegistry {
239 pub fn register<T: Transaction>(tx: &mut T, entry: &HnswIndexEntry) -> Result<(), VectorError> {
244 let key = Self::entry_key(&entry.name);
245 let value = entry.to_bytes()?;
246 tx.put(HNSW_REGISTRY_TABLE, &key, &value)?;
247 Ok(())
248 }
249
250 pub fn get<T: Transaction>(tx: &T, name: &str) -> Result<Option<HnswIndexEntry>, VectorError> {
252 let key = Self::entry_key(name);
253 match tx.get(HNSW_REGISTRY_TABLE, &key)? {
254 Some(bytes) => Ok(Some(HnswIndexEntry::from_bytes(&bytes)?)),
255 None => Ok(None),
256 }
257 }
258
259 pub fn exists<T: Transaction>(tx: &T, name: &str) -> Result<bool, VectorError> {
261 let key = Self::entry_key(name);
262 Ok(tx.get(HNSW_REGISTRY_TABLE, &key)?.is_some())
263 }
264
265 pub fn drop<T: Transaction>(tx: &mut T, name: &str) -> Result<bool, VectorError> {
270 let key = Self::entry_key(name);
271 Ok(tx.delete(HNSW_REGISTRY_TABLE, &key)?)
272 }
273
274 pub fn list_for_table<T: Transaction>(
276 tx: &T,
277 table: &str,
278 ) -> Result<Vec<HnswIndexEntry>, VectorError> {
279 use std::ops::Bound;
280
281 let mut entries = Vec::new();
282
283 let mut cursor = tx.range(HNSW_REGISTRY_TABLE, Bound::Unbounded, Bound::Unbounded)?;
285
286 while let Some((_, value)) = cursor.next()? {
287 if let Ok(entry) = HnswIndexEntry::from_bytes(&value) {
288 if entry.table == table {
289 entries.push(entry);
290 }
291 }
292 }
293
294 Ok(entries)
295 }
296
297 pub fn list_for_column<T: Transaction>(
299 tx: &T,
300 table: &str,
301 column: &str,
302 ) -> Result<Vec<HnswIndexEntry>, VectorError> {
303 let table_entries = Self::list_for_table(tx, table)?;
304 Ok(table_entries.into_iter().filter(|e| e.column == column).collect())
305 }
306
307 pub fn list_all<T: Transaction>(tx: &T) -> Result<Vec<HnswIndexEntry>, VectorError> {
309 use std::ops::Bound;
310
311 let mut entries = Vec::new();
312
313 let mut cursor = tx.range(HNSW_REGISTRY_TABLE, Bound::Unbounded, Bound::Unbounded)?;
315
316 while let Some((_, value)) = cursor.next()? {
317 if let Ok(entry) = HnswIndexEntry::from_bytes(&value) {
318 entries.push(entry);
319 }
320 }
321
322 Ok(entries)
323 }
324
325 pub fn get_for_named_vector<T: Transaction>(
329 tx: &T,
330 collection: &str,
331 vector_name: &str,
332 ) -> Result<Option<HnswIndexEntry>, VectorError> {
333 let expected_name = format!("{}_{}_hnsw", collection, vector_name);
335 if let Some(entry) = Self::get(tx, &expected_name)? {
336 return Ok(Some(entry));
337 }
338
339 use std::ops::Bound;
341 let mut cursor = tx.range(HNSW_REGISTRY_TABLE, Bound::Unbounded, Bound::Unbounded)?;
342
343 while let Some((_, value)) = cursor.next()? {
344 if let Ok(entry) = HnswIndexEntry::from_bytes(&value) {
345 if entry.collection_name.as_deref() == Some(collection)
346 && entry.vector_name.as_deref() == Some(vector_name)
347 {
348 return Ok(Some(entry));
349 }
350 }
351 }
352
353 Ok(None)
354 }
355
356 pub fn list_for_collection<T: Transaction>(
358 tx: &T,
359 collection: &str,
360 ) -> Result<Vec<HnswIndexEntry>, VectorError> {
361 use std::ops::Bound;
362
363 let mut entries = Vec::new();
364
365 let mut cursor = tx.range(HNSW_REGISTRY_TABLE, Bound::Unbounded, Bound::Unbounded)?;
367
368 while let Some((_, value)) = cursor.next()? {
369 if let Ok(entry) = HnswIndexEntry::from_bytes(&value) {
370 if entry.collection_name.as_deref() == Some(collection) {
371 entries.push(entry);
372 }
373 }
374 }
375
376 Ok(entries)
377 }
378
379 pub fn exists_for_named_vector<T: Transaction>(
381 tx: &T,
382 collection: &str,
383 vector_name: &str,
384 ) -> Result<bool, VectorError> {
385 Ok(Self::get_for_named_vector(tx, collection, vector_name)?.is_some())
386 }
387
388 #[must_use]
390 pub fn index_name_for_vector(collection: &str, vector_name: &str) -> String {
391 format!("{}_{}_hnsw", collection, vector_name)
392 }
393
394 fn entry_key(name: &str) -> Vec<u8> {
396 name.as_bytes().to_vec()
397 }
398}
399
/// Source of embeddings addressed by entity and column.
pub trait EmbeddingLookup {
    /// Fetches the embedding stored for `entity_id` in `column`.
    ///
    /// NOTE(review): `Ok(None)` presumably means the entity has no embedding
    /// in that column; confirm against the implementors.
    ///
    /// # Errors
    ///
    /// Returns a [`VectorError`] when the lookup itself fails.
    fn get_embedding(
        &self,
        entity_id: EntityId,
        column: &str,
    ) -> Result<Option<Embedding>, VectorError>;
}
412
// Unit tests for entry serialization and for registry operations, run against
// an in-memory `RedbEngine`.
#[cfg(test)]
mod tests {
    use super::*;
    use manifoldb_storage::backends::RedbEngine;
    use manifoldb_storage::StorageEngine;

    // An entry built with `new` keeps its identifying fields across a
    // to_bytes/from_bytes round trip.
    #[test]
    fn test_index_entry_roundtrip() {
        let config = HnswConfig::default();
        let entry = HnswIndexEntry::new(
            "test_index",
            "documents",
            "embedding",
            384,
            DistanceMetric::Cosine,
            &config,
        );

        let bytes = entry.to_bytes().unwrap();
        let decoded = HnswIndexEntry::from_bytes(&bytes).unwrap();

        assert_eq!(decoded.name, "test_index");
        assert_eq!(decoded.table, "documents");
        assert_eq!(decoded.column, "embedding");
        assert_eq!(decoded.dimension, 384);
        assert_eq!(decoded.distance_metric, DistanceMetricSerde::Cosine);
    }

    // Register, read back, list, and drop a single entry. Each step runs in
    // its own scoped transaction.
    #[test]
    fn test_registry_crud() {
        let engine = RedbEngine::in_memory().unwrap();
        let config = HnswConfig::default();

        let entry = HnswIndexEntry::new(
            "test_index",
            "documents",
            "embedding",
            384,
            DistanceMetric::Cosine,
            &config,
        );

        {
            // Register the entry and commit.
            let mut tx = engine.begin_write().unwrap();
            HnswRegistry::register(&mut tx, &entry).unwrap();
            tx.commit().unwrap();
        }

        {
            // Read it back by name; an unknown name must not exist.
            let tx = engine.begin_read().unwrap();
            let retrieved = HnswRegistry::get(&tx, "test_index").unwrap().unwrap();
            assert_eq!(retrieved.name, "test_index");
            assert!(HnswRegistry::exists(&tx, "test_index").unwrap());
            assert!(!HnswRegistry::exists(&tx, "nonexistent").unwrap());
        }

        {
            // Both listing functions see exactly the one entry.
            let tx = engine.begin_read().unwrap();
            let entries = HnswRegistry::list_for_table(&tx, "documents").unwrap();
            assert_eq!(entries.len(), 1);
            let all = HnswRegistry::list_all(&tx).unwrap();
            assert_eq!(all.len(), 1);
        }

        {
            // Drop reports `true` for the entry that was registered.
            let mut tx = engine.begin_write().unwrap();
            assert!(HnswRegistry::drop(&mut tx, "test_index").unwrap());
            tx.commit().unwrap();
        }

        {
            // After the committed drop the entry is gone.
            let tx = engine.begin_read().unwrap();
            assert!(!HnswRegistry::exists(&tx, "test_index").unwrap());
        }
    }

    // `for_named_vector` derives name, table and column from the collection
    // and vector names and marks the entry as a named-vector index.
    #[test]
    fn test_named_vector_entry() {
        let config = HnswConfig::default();
        let entry = HnswIndexEntry::for_named_vector(
            "documents",
            "embedding",
            384,
            DistanceMetric::Cosine,
            &config,
        );

        assert_eq!(entry.name, "documents_embedding_hnsw");
        assert_eq!(entry.table, "documents");
        assert_eq!(entry.column, "embedding");
        assert_eq!(entry.dimension, 384);
        assert!(entry.is_named_vector_index());
        assert_eq!(entry.collection(), Some("documents"));
        assert_eq!(entry.vector(), Some("embedding"));
    }

    // A registered named-vector entry is found through the named-vector
    // lookup, existence check and per-collection listing.
    #[test]
    fn test_named_vector_registry_lookup() {
        let engine = RedbEngine::in_memory().unwrap();
        let config = HnswConfig::default();

        let entry = HnswIndexEntry::for_named_vector(
            "documents",
            "dense_embedding",
            768,
            DistanceMetric::Cosine,
            &config,
        );

        {
            // Register the named-vector entry and commit.
            let mut tx = engine.begin_write().unwrap();
            HnswRegistry::register(&mut tx, &entry).unwrap();
            tx.commit().unwrap();
        }

        {
            // Lookup by (collection, vector) returns the registered entry.
            let tx = engine.begin_read().unwrap();
            let found =
                HnswRegistry::get_for_named_vector(&tx, "documents", "dense_embedding").unwrap();
            assert!(found.is_some());
            let found = found.unwrap();
            assert_eq!(found.name, "documents_dense_embedding_hnsw");
            assert!(found.is_named_vector_index());
        }

        {
            // Existence check: true for the registered vector, false for another.
            let tx = engine.begin_read().unwrap();
            assert!(
                HnswRegistry::exists_for_named_vector(&tx, "documents", "dense_embedding").unwrap()
            );
            assert!(
                !HnswRegistry::exists_for_named_vector(&tx, "documents", "other_vector").unwrap()
            );
        }

        {
            // Listing by collection returns the single entry.
            let tx = engine.begin_read().unwrap();
            let entries = HnswRegistry::list_for_collection(&tx, "documents").unwrap();
            assert_eq!(entries.len(), 1);
            assert_eq!(entries[0].vector(), Some("dense_embedding"));
        }
    }

    // The conventional index name is "{collection}_{vector}_hnsw".
    #[test]
    fn test_index_name_generation() {
        assert_eq!(
            HnswRegistry::index_name_for_vector("documents", "embedding"),
            "documents_embedding_hnsw"
        );
        assert_eq!(
            HnswRegistry::index_name_for_vector("my_collection", "dense"),
            "my_collection_dense_hnsw"
        );
    }

    // The named-vector fields survive a to_bytes/from_bytes round trip.
    #[test]
    fn test_named_vector_entry_serialization() {
        let config = HnswConfig::default();
        let entry = HnswIndexEntry::for_named_vector(
            "my_collection",
            "text_vector",
            512,
            DistanceMetric::DotProduct,
            &config,
        );

        let bytes = entry.to_bytes().unwrap();
        let decoded = HnswIndexEntry::from_bytes(&bytes).unwrap();

        assert_eq!(decoded.name, "my_collection_text_vector_hnsw");
        assert_eq!(decoded.collection_name, Some("my_collection".to_string()));
        assert_eq!(decoded.vector_name, Some("text_vector".to_string()));
        assert!(decoded.is_named_vector_index());
    }
}