// Source file: keradb_sdk/vector.rs

//! Vector search types mirroring the Python SDK's `vector.py`.

// @group VectorConfig     : Configuration types for vector collections
// @group VectorDocuments  : Document and search result types
// @group VectorFilter     : Metadata filter for filtered vector search
// @group UnitTests        : Vector type construction, serialisation and parsing
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10// ---------------------------------------------------------------------------
11// @group VectorConfig : Configuration types for vector collections
12// ---------------------------------------------------------------------------
13
14/// Distance metric used for vector similarity comparisons.
15#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17pub enum Distance {
18    /// Cosine similarity – range \[0, 2\], 0 = identical.
19    Cosine,
20    /// Euclidean (L2) distance.
21    Euclidean,
22    /// Negative dot product for similarity ranking.
23    DotProduct,
24    /// Manhattan (L1) distance.
25    Manhattan,
26}
27
28impl Distance {
29    /// Return the wire string used in JSON configuration.
30    pub fn as_str(&self) -> &'static str {
31        match self {
32            Distance::Cosine => "cosine",
33            Distance::Euclidean => "euclidean",
34            Distance::DotProduct => "dot_product",
35            Distance::Manhattan => "manhattan",
36        }
37    }
38}
39
40impl std::fmt::Display for Distance {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.write_str(self.as_str())
43    }
44}
45
46/// Vector storage compression mode.
47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48#[serde(rename_all = "snake_case")]
49pub enum CompressionMode {
50    /// No compression – full precision vectors.
51    None,
52    /// Delta compression – sparse differences from neighbours.
53    Delta,
54    /// Aggressive quantised deltas.
55    QuantizedDelta,
56}
57
58impl CompressionMode {
59    /// Return the wire string.
60    pub fn as_str(&self) -> &'static str {
61        match self {
62            CompressionMode::None => "none",
63            CompressionMode::Delta => "delta",
64            CompressionMode::QuantizedDelta => "quantized_delta",
65        }
66    }
67}
68
69impl std::fmt::Display for CompressionMode {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.write_str(self.as_str())
72    }
73}
74
75/// Fine-grained configuration for vector compression.
76#[derive(Debug, Clone, Default, Serialize)]
77pub struct CompressionConfig {
78    /// Compression mode.
79    pub mode: Option<CompressionMode>,
80    /// Threshold for considering a vector sparse.
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub sparsity_threshold: Option<f64>,
83    /// Maximum allowed density for delta compression.
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub max_density: Option<f64>,
86    /// Frequency (in insertions) at which anchor vectors are written.
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub anchor_frequency: Option<u32>,
89    /// Bit-width used during quantisation.
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub quantization_bits: Option<u32>,
92}
93
94impl CompressionConfig {
95    /// Create a config with a specific mode and no extra parameters.
96    pub fn new(mode: CompressionMode) -> Self {
97        Self {
98            mode: Some(mode),
99            ..Default::default()
100        }
101    }
102
103    /// Delta compression with default parameters.
104    pub fn delta() -> Self {
105        Self::new(CompressionMode::Delta)
106    }
107
108    /// Quantised-delta compression with default parameters.
109    pub fn quantized_delta() -> Self {
110        Self::new(CompressionMode::QuantizedDelta)
111    }
112
113    /// Serialise to a `serde_json::Value` map.
114    pub fn to_value(&self) -> Value {
115        let mut map = serde_json::Map::new();
116        if let Some(mode) = &self.mode {
117            map.insert("mode".into(), Value::String(mode.as_str().to_owned()));
118        }
119        if let Some(v) = self.sparsity_threshold {
120            map.insert("sparsity_threshold".into(), Value::from(v));
121        }
122        if let Some(v) = self.max_density {
123            map.insert("max_density".into(), Value::from(v));
124        }
125        if let Some(v) = self.anchor_frequency {
126            map.insert("anchor_frequency".into(), Value::from(v));
127        }
128        if let Some(v) = self.quantization_bits {
129            map.insert("quantization_bits".into(), Value::from(v));
130        }
131        Value::Object(map)
132    }
133}
134
135/// Configuration for a vector collection, using a builder-style API that
136/// mirrors `VectorConfig` in the Python SDK.
137#[derive(Debug, Clone)]
138pub struct VectorConfig {
139    /// Number of dimensions in each embedding.
140    pub dimensions: usize,
141    /// Distance metric (default: [`Distance::Cosine`]).
142    pub distance: Distance,
143    /// HNSW `M` parameter – connections per node (default: library default).
144    pub m: Option<u32>,
145    /// HNSW `ef_construction` – build quality (default: library default).
146    pub ef_construction: Option<u32>,
147    /// HNSW `ef_search` – query quality (default: library default).
148    pub ef_search: Option<u32>,
149    /// Enable lazy embedding mode (store text, embed on demand).
150    pub lazy_embedding: bool,
151    /// Model name to use for lazy embedding.
152    pub embedding_model: Option<String>,
153    /// Optional vector compression settings.
154    pub compression: Option<CompressionConfig>,
155}
156
157impl VectorConfig {
158    /// Create a minimal config with the given number of dimensions.
159    pub fn new(dimensions: usize) -> Self {
160        Self {
161            dimensions,
162            distance: Distance::Cosine,
163            m: None,
164            ef_construction: None,
165            ef_search: None,
166            lazy_embedding: false,
167            embedding_model: None,
168            compression: None,
169        }
170    }
171
172    // Builder methods
173
174    /// Set the distance metric.
175    pub fn with_distance(mut self, distance: Distance) -> Self {
176        self.distance = distance;
177        self
178    }
179
180    /// Set the HNSW `M` parameter.
181    pub fn with_m(mut self, m: u32) -> Self {
182        self.m = Some(m);
183        self
184    }
185
186    /// Set the `ef_construction` parameter.
187    pub fn with_ef_construction(mut self, ef: u32) -> Self {
188        self.ef_construction = Some(ef);
189        self
190    }
191
192    /// Set the `ef_search` parameter.
193    pub fn with_ef_search(mut self, ef: u32) -> Self {
194        self.ef_search = Some(ef);
195        self
196    }
197
198    /// Enable lazy embedding using an embedding model name.
199    pub fn with_lazy_embedding(mut self, model: impl Into<String>) -> Self {
200        self.lazy_embedding = true;
201        self.embedding_model = Some(model.into());
202        self
203    }
204
205    /// Set a custom compression configuration.
206    pub fn with_compression(mut self, config: CompressionConfig) -> Self {
207        self.compression = Some(config);
208        self
209    }
210
211    /// Enable delta compression with default settings.
212    pub fn with_delta_compression(self) -> Self {
213        self.with_compression(CompressionConfig::delta())
214    }
215
216    /// Enable quantised-delta compression with default settings.
217    pub fn with_quantized_compression(self) -> Self {
218        self.with_compression(CompressionConfig::quantized_delta())
219    }
220
221    /// Serialise to JSON for passing to the native C API.
222    pub fn to_json(&self) -> String {
223        let mut map = serde_json::Map::new();
224        map.insert("dimensions".into(), Value::from(self.dimensions));
225        map.insert(
226            "distance".into(),
227            Value::String(self.distance.as_str().to_owned()),
228        );
229        if let Some(m) = self.m {
230            map.insert("m".into(), Value::from(m));
231        }
232        if let Some(ef) = self.ef_construction {
233            map.insert("ef_construction".into(), Value::from(ef));
234        }
235        if let Some(ef) = self.ef_search {
236            map.insert("ef_search".into(), Value::from(ef));
237        }
238        if self.lazy_embedding {
239            map.insert("lazy_embedding".into(), Value::Bool(true));
240            if let Some(model) = &self.embedding_model {
241                map.insert("embedding_model".into(), Value::String(model.clone()));
242            }
243        }
244        if let Some(comp) = &self.compression {
245            map.insert("compression".into(), comp.to_value());
246        }
247        Value::Object(map).to_string()
248    }
249}
250
251// ---------------------------------------------------------------------------
252// @group VectorDocuments : Document and search result types
253// ---------------------------------------------------------------------------
254
255/// A document stored in a vector collection.
256#[derive(Debug, Clone)]
257pub struct VectorDocument {
258    /// Numeric ID assigned by KeraDB.
259    pub id: u64,
260    /// The stored embedding (may be absent if not requested).
261    pub embedding: Option<Vec<f32>>,
262    /// Original text if stored via lazy embedding.
263    pub text: Option<String>,
264    /// User-supplied metadata as arbitrary JSON.
265    pub metadata: Value,
266}
267
268impl VectorDocument {
269    /// Deserialise from the JSON object returned by the native library.
270    pub fn from_value(v: &Value) -> Option<Self> {
271        let id = v.get("id")?.as_u64()?;
272        let embedding = v.get("embedding").and_then(|e| {
273            e.as_array().map(|arr| {
274                arr.iter()
275                    .filter_map(|x| x.as_f64().map(|f| f as f32))
276                    .collect()
277            })
278        });
279        let text = v.get("text").and_then(|t| t.as_str()).map(String::from);
280        let metadata = v
281            .get("metadata")
282            .cloned()
283            .unwrap_or(Value::Object(Default::default()));
284
285        Some(Self {
286            id,
287            embedding,
288            text,
289            metadata,
290        })
291    }
292}
293
294impl std::fmt::Display for VectorDocument {
295    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296        write!(f, "VectorDocument(id={}, metadata={})", self.id, self.metadata)
297    }
298}
299
300/// A single result from a vector similarity search.
301#[derive(Debug, Clone)]
302pub struct VectorSearchResult {
303    /// The matched document.
304    pub document: VectorDocument,
305    /// Similarity or distance score (lower is closer for distance metrics).
306    pub score: f64,
307    /// 1-based rank in the result set.
308    pub rank: usize,
309}
310
311impl VectorSearchResult {
312    /// Deserialise from the JSON object in the search results array.
313    pub fn from_value(v: &Value) -> Option<Self> {
314        let doc = VectorDocument::from_value(v.get("document")?)?;
315        let score = v.get("score")?.as_f64()?;
316        let rank = v.get("rank")?.as_u64()? as usize;
317        Some(Self {
318            document: doc,
319            score,
320            rank,
321        })
322    }
323}
324
325impl std::fmt::Display for VectorSearchResult {
326    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        write!(
328            f,
329            "VectorSearchResult(rank={}, score={:.4}, id={})",
330            self.rank, self.score, self.document.id
331        )
332    }
333}
334
335/// Runtime statistics for a vector collection.
336#[derive(Debug, Clone)]
337pub struct VectorCollectionStats {
338    /// Total number of vectors stored.
339    pub vector_count: usize,
340    /// Per-vector dimensionality.
341    pub dimensions: usize,
342    /// Distance metric in use.
343    pub distance: Distance,
344    /// Approximate memory usage in bytes.
345    pub memory_usage: usize,
346    /// Number of HNSW layers in the index.
347    pub layer_count: usize,
348    /// Whether lazy embedding is enabled.
349    pub lazy_embedding: bool,
350    /// Active compression mode (if any).
351    pub compression: Option<CompressionMode>,
352    /// Number of anchor vectors (delta/quantised modes only).
353    pub anchor_count: Option<usize>,
354    /// Number of delta vectors (delta/quantised modes only).
355    pub delta_count: Option<usize>,
356}
357
358impl VectorCollectionStats {
359    /// Deserialise from the JSON object returned by `keradb_vector_stats`.
360    pub fn from_value(v: &Value) -> Option<Self> {
361        let distance_str = v.get("distance")?.as_str()?;
362        let distance = match distance_str {
363            "cosine" => Distance::Cosine,
364            "euclidean" => Distance::Euclidean,
365            "dot_product" => Distance::DotProduct,
366            "manhattan" => Distance::Manhattan,
367            _ => Distance::Cosine,
368        };
369        let compression = v.get("compression").and_then(|c| c.as_str()).map(|s| {
370            match s {
371                "delta" => CompressionMode::Delta,
372                "quantized_delta" => CompressionMode::QuantizedDelta,
373                _ => CompressionMode::None,
374            }
375        });
376
377        Some(Self {
378            vector_count: v.get("vector_count")?.as_u64()? as usize,
379            dimensions: v.get("dimensions")?.as_u64()? as usize,
380            distance,
381            memory_usage: v.get("memory_usage").and_then(|x| x.as_u64()).unwrap_or(0) as usize,
382            layer_count: v.get("layer_count").and_then(|x| x.as_u64()).unwrap_or(0) as usize,
383            lazy_embedding: v
384                .get("lazy_embedding")
385                .and_then(|x| x.as_bool())
386                .unwrap_or(false),
387            compression,
388            anchor_count: v
389                .get("anchor_count")
390                .and_then(|x| x.as_u64())
391                .map(|x| x as usize),
392            delta_count: v
393                .get("delta_count")
394                .and_then(|x| x.as_u64())
395                .map(|x| x as usize),
396        })
397    }
398}
399
400impl std::fmt::Display for VectorCollectionStats {
401    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402        write!(
403            f,
404            "VectorCollectionStats(vectors={}, dimensions={}, distance={}, memory={} bytes)",
405            self.vector_count, self.dimensions, self.distance, self.memory_usage
406        )
407    }
408}
409
410// ---------------------------------------------------------------------------
411// @group VectorFilter : Metadata filter for filtered vector search
412// ---------------------------------------------------------------------------
413
414/// A filter on a metadata field, used in [`Client::vector_search_filtered`].
415///
416/// Supported conditions match the Python SDK:
417/// `eq`, `ne`, `gt`, `gte`, `lt`, `lte`, `in`, `not_in`, `contains`,
418/// `starts_with`, `ends_with`.
419#[derive(Debug, Clone, Serialize)]
420pub struct MetadataFilter {
421    /// The metadata field name to filter on.
422    pub field: String,
423    /// The condition type (e.g. `"eq"`, `"gt"`, `"in"`).
424    pub condition: String,
425    /// The comparison value as a JSON value.
426    pub value: Value,
427}
428
429impl MetadataFilter {
430    /// Create a new metadata filter.
431    pub fn new(field: impl Into<String>, condition: impl Into<String>, value: Value) -> Self {
432        Self {
433            field: field.into(),
434            condition: condition.into(),
435            value,
436        }
437    }
438
439    /// Equality filter shorthand.
440    pub fn eq(field: impl Into<String>, value: Value) -> Self {
441        Self::new(field, "eq", value)
442    }
443
444    /// Greater-than filter shorthand.
445    pub fn gt(field: impl Into<String>, value: Value) -> Self {
446        Self::new(field, "gt", value)
447    }
448
449    /// Less-than filter shorthand.
450    pub fn lt(field: impl Into<String>, value: Value) -> Self {
451        Self::new(field, "lt", value)
452    }
453
454    /// Serialise to a JSON string for passing to the C API.
455    pub fn to_json(&self) -> String {
456        serde_json::to_string(self).unwrap_or_else(|_| "{}".to_owned())
457    }
458}
459
/// Slim info struct returned by `Client::list_vector_collections`.
///
/// Plain data: it can be cloned, compared for equality and hashed, so
/// listings are easy to assert on or deduplicate.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct VectorCollectionInfo {
    /// Collection name.
    pub name: String,
    /// Number of vectors in the collection.
    pub count: usize,
}
468
469// ---------------------------------------------------------------------------
470// @group UnitTests : Vector type construction, serialisation and parsing
471// ---------------------------------------------------------------------------
472
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // --- Distance ---

    /// Every metric maps to its documented wire string.
    #[test]
    fn distance_as_str() {
        let expected = [
            (Distance::Cosine, "cosine"),
            (Distance::Euclidean, "euclidean"),
            (Distance::DotProduct, "dot_product"),
            (Distance::Manhattan, "manhattan"),
        ];
        for (metric, wire) in expected.iter() {
            assert_eq!(metric.as_str(), *wire);
        }
    }

    /// `Display` prints the wire string.
    #[test]
    fn distance_display() {
        assert_eq!(format!("{}", Distance::Cosine), "cosine");
    }

    // --- CompressionMode ---

    /// Every compression mode maps to its documented wire string.
    #[test]
    fn compression_mode_as_str() {
        let expected = [
            (CompressionMode::None, "none"),
            (CompressionMode::Delta, "delta"),
            (CompressionMode::QuantizedDelta, "quantized_delta"),
        ];
        for (mode, wire) in expected.iter() {
            assert_eq!(mode.as_str(), *wire);
        }
    }

    // --- VectorConfig builder ---

    /// A fresh config uses cosine distance and leaves the tuning knobs unset.
    #[test]
    fn vector_config_defaults() {
        let VectorConfig {
            dimensions,
            distance,
            m,
            ef_construction,
            ef_search,
            lazy_embedding,
            ..
        } = VectorConfig::new(128);

        assert_eq!(dimensions, 128);
        assert_eq!(distance, Distance::Cosine);
        assert_eq!(m, None);
        assert_eq!(ef_construction, None);
        assert_eq!(ef_search, None);
        assert!(!lazy_embedding);
    }

    /// Chained builders each record their own setting.
    #[test]
    fn vector_config_builder_chain() {
        let built = VectorConfig::new(256)
            .with_distance(Distance::Euclidean)
            .with_m(32)
            .with_ef_construction(400)
            .with_ef_search(80);

        assert_eq!(built.dimensions, 256);
        assert_eq!(built.distance, Distance::Euclidean);
        assert_eq!(built.m, Some(32));
        assert_eq!(built.ef_construction, Some(400));
        assert_eq!(built.ef_search, Some(80));
    }

    /// Enabling lazy embedding sets the flag and stores the model name.
    #[test]
    fn vector_config_lazy_embedding() {
        let built = VectorConfig::new(768).with_lazy_embedding("my-model");
        assert!(built.lazy_embedding);
        assert_eq!(built.embedding_model, Some(String::from("my-model")));
    }

    /// The delta shortcut installs a compression config in delta mode.
    #[test]
    fn vector_config_delta_compression() {
        let built = VectorConfig::new(64).with_delta_compression();
        let compression = built
            .compression
            .as_ref()
            .expect("delta compression should be configured");
        assert_eq!(compression.mode, Some(CompressionMode::Delta));
    }

    /// The JSON form parses back and carries the configured values.
    #[test]
    fn vector_config_to_json_roundtrip() {
        let encoded = VectorConfig::new(128)
            .with_distance(Distance::DotProduct)
            .with_m(24)
            .to_json();
        let parsed: Value = serde_json::from_str(&encoded).unwrap();

        assert_eq!(parsed["dimensions"], json!(128));
        assert_eq!(parsed["distance"], json!("dot_product"));
        assert_eq!(parsed["m"], json!(24));
    }

    // --- MetadataFilter ---

    /// `eq` fills in field, condition and value.
    #[test]
    fn metadata_filter_eq_shorthand() {
        let MetadataFilter {
            field,
            condition,
            value,
        } = MetadataFilter::eq("category", json!("news"));

        assert_eq!(field, "category");
        assert_eq!(condition, "eq");
        assert_eq!(value, json!("news"));
    }

    /// `gt` uses the "gt" condition keyword.
    #[test]
    fn metadata_filter_gt_shorthand() {
        let filter = MetadataFilter::gt("score", json!(0.8));
        assert_eq!(filter.condition, "gt");
    }

    /// `lt` uses the "lt" condition keyword.
    #[test]
    fn metadata_filter_lt_shorthand() {
        let filter = MetadataFilter::lt("score", json!(0.5));
        assert_eq!(filter.condition, "lt");
    }

    /// The JSON form exposes all three members under their field names.
    #[test]
    fn metadata_filter_to_json() {
        let encoded = MetadataFilter::eq("lang", json!("en")).to_json();
        let parsed: Value = serde_json::from_str(&encoded).unwrap();

        assert_eq!(parsed["field"], json!("lang"));
        assert_eq!(parsed["condition"], json!("eq"));
        assert_eq!(parsed["value"], json!("en"));
    }

    // --- VectorSearchResult ---

    /// A complete search hit parses into document, score and rank.
    #[test]
    fn vector_search_result_from_value() {
        let payload = json!({
            "document": {"id": 7, "metadata": {"label": "A"}},
            "score": 0.95,
            "rank": 1
        });
        let hit = VectorSearchResult::from_value(&payload).unwrap();

        assert_eq!(hit.document.id, 7);
        assert!((hit.score - 0.95).abs() < 1e-6);
        assert_eq!(hit.rank, 1);
        assert_eq!(hit.document.metadata["label"], json!("A"));
    }

    /// The display string mentions the document id and the score.
    #[test]
    fn vector_search_result_display() {
        let payload = json!({
            "document": {"id": 1, "metadata": {}},
            "score": 0.5,
            "rank": 1
        });
        let rendered = VectorSearchResult::from_value(&payload)
            .unwrap()
            .to_string();

        assert!(rendered.contains("id=1"));
        assert!(rendered.contains("score="));
    }

    /// A hit without a `document` member is rejected.
    #[test]
    fn vector_search_result_missing_document_returns_none() {
        let payload = json!({"score": 0.9, "rank": 1});
        assert!(VectorSearchResult::from_value(&payload).is_none());
    }

    // --- VectorDocument ---

    /// Id, embedding and metadata are all picked up from the payload.
    #[test]
    fn vector_document_from_value() {
        let payload = json!({"id": 3, "embedding": [0.1, 0.2, 0.3], "metadata": {"tag": "x"}});
        let doc = VectorDocument::from_value(&payload).unwrap();

        assert_eq!(doc.id, 3);
        let components = doc.embedding.unwrap();
        assert!((components[0] - 0.1f32).abs() < 1e-6);
        assert_eq!(doc.metadata["tag"], json!("x"));
    }

    /// A document without an `id` is rejected.
    #[test]
    fn vector_document_missing_id_returns_none() {
        let payload = json!({"embedding": [0.1], "metadata": {}});
        assert!(VectorDocument::from_value(&payload).is_none());
    }

    /// The embedding is optional and absent when not supplied.
    #[test]
    fn vector_document_optional_embedding() {
        let payload = json!({"id": 5, "metadata": {}});
        let doc = VectorDocument::from_value(&payload).unwrap();
        assert!(doc.embedding.is_none());
    }

    // --- CompressionConfig ---

    /// The delta preset sets only the mode.
    #[test]
    fn compression_config_delta() {
        let preset = CompressionConfig::delta();
        assert_eq!(preset.mode, Some(CompressionMode::Delta));
        assert!(preset.quantization_bits.is_none());
    }

    /// The quantised-delta preset selects the quantised-delta mode.
    #[test]
    fn compression_config_quantized_delta() {
        let preset = CompressionConfig::quantized_delta();
        assert_eq!(preset.mode, Some(CompressionMode::QuantizedDelta));
    }

    /// Struct-update syntax can set the bit width alongside the mode.
    #[test]
    fn compression_config_with_quantization_bits() {
        let custom = CompressionConfig {
            mode: Some(CompressionMode::QuantizedDelta),
            quantization_bits: Some(8),
            ..Default::default()
        };
        assert_eq!(custom.mode, Some(CompressionMode::QuantizedDelta));
        assert_eq!(custom.quantization_bits, Some(8));
    }
}