chroma_types/
metadata.rs

1use chroma_error::{ChromaError, ErrorCodes};
2use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
3use serde_json::{Number, Value};
4use sprs::CsVec;
5use std::{
6    cmp::Ordering,
7    collections::{HashMap, HashSet},
8    mem::size_of_val,
9    ops::{BitAnd, BitOr},
10};
11use thiserror::Error;
12
13use crate::chroma_proto;
14
15#[cfg(feature = "pyo3")]
16use pyo3::types::PyAnyMethods;
17
18#[cfg(feature = "testing")]
19use proptest::prelude::*;
20
21#[derive(Serialize, Deserialize)]
22struct SparseVectorSerdeHelper {
23    #[serde(rename = "#type")]
24    type_tag: Option<String>,
25    indices: Vec<u32>,
26    values: Vec<f32>,
27}
28
/// Represents a sparse vector using parallel arrays for indices and values.
///
/// On deserialization: accepts both old format `{"indices": [...], "values": [...]}`
/// and new format `{"#type": "sparse_vector", "indices": [...], "values": [...]}`.
///
/// On serialization: always includes `#type` field with value `"sparse_vector"`.
//
// Neither `SparseVector::new` nor deserialization enforces the invariants (equal lengths,
// strictly ascending indices); call `SparseVector::validate` to check them.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SparseVector {
    /// Dimension indices
    pub indices: Vec<u32>,
    /// Values corresponding to each index
    pub values: Vec<f32>,
}
43
44// Custom deserializer: accept both old and new formats
45impl<'de> Deserialize<'de> for SparseVector {
46    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47    where
48        D: Deserializer<'de>,
49    {
50        let helper = SparseVectorSerdeHelper::deserialize(deserializer)?;
51
52        // If #type is present, validate it
53        if let Some(type_tag) = &helper.type_tag {
54            if type_tag != "sparse_vector" {
55                return Err(serde::de::Error::custom(format!(
56                    "Expected #type='sparse_vector', got '{}'",
57                    type_tag
58                )));
59            }
60        }
61
62        Ok(SparseVector {
63            indices: helper.indices,
64            values: helper.values,
65        })
66    }
67}
68
69// Custom serializer: always include #type field
70impl Serialize for SparseVector {
71    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
72    where
73        S: Serializer,
74    {
75        let helper = SparseVectorSerdeHelper {
76            type_tag: Some("sparse_vector".to_string()),
77            indices: self.indices.clone(),
78            values: self.values.clone(),
79        };
80        helper.serialize(serializer)
81    }
82}
83
84impl SparseVector {
85    /// Create a new sparse vector from parallel arrays.
86    pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Self {
87        Self { indices, values }
88    }
89
90    /// Create a sparse vector from an iterator of (index, value) pairs.
91    pub fn from_pairs(pairs: impl IntoIterator<Item = (u32, f32)>) -> Self {
92        let (indices, values) = pairs.into_iter().unzip();
93        Self { indices, values }
94    }
95
96    /// Iterate over (index, value) pairs.
97    pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
98        self.indices
99            .iter()
100            .copied()
101            .zip(self.values.iter().copied())
102    }
103
104    /// Validate the sparse vector
105    pub fn validate(&self) -> Result<(), MetadataValueConversionError> {
106        // Check that indices and values have the same length
107        if self.indices.len() != self.values.len() {
108            return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
109        }
110
111        // Check that indices are sorted in strictly ascending order (no duplicates)
112        for i in 1..self.indices.len() {
113            if self.indices[i] <= self.indices[i - 1] {
114                return Err(MetadataValueConversionError::SparseVectorIndicesNotSorted);
115            }
116        }
117
118        Ok(())
119    }
120}
121
122impl Eq for SparseVector {}
123
124impl Ord for SparseVector {
125    fn cmp(&self, other: &Self) -> Ordering {
126        self.indices.cmp(&other.indices).then_with(|| {
127            for (a, b) in self.values.iter().zip(other.values.iter()) {
128                match a.total_cmp(b) {
129                    Ordering::Equal => continue,
130                    other => return other,
131                }
132            }
133            self.values.len().cmp(&other.values.len())
134        })
135    }
136}
137
138impl PartialOrd for SparseVector {
139    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
140        Some(self.cmp(other))
141    }
142}
143
144impl From<chroma_proto::SparseVector> for SparseVector {
145    fn from(proto: chroma_proto::SparseVector) -> Self {
146        SparseVector::new(proto.indices, proto.values)
147    }
148}
149
150impl From<SparseVector> for chroma_proto::SparseVector {
151    fn from(sparse: SparseVector) -> Self {
152        chroma_proto::SparseVector {
153            indices: sparse.indices,
154            values: sparse.values,
155        }
156    }
157}
158
159/// Convert SparseVector to sprs::CsVec for efficient sparse operations
160impl From<&SparseVector> for CsVec<f32> {
161    fn from(sparse: &SparseVector) -> Self {
162        let (indices, values) = sparse
163            .iter()
164            .map(|(index, value)| (index as usize, value))
165            .unzip();
166        CsVec::new(u32::MAX as usize, indices, values)
167    }
168}
169
170impl From<SparseVector> for CsVec<f32> {
171    fn from(sparse: SparseVector) -> Self {
172        (&sparse).into()
173    }
174}
175
#[cfg(feature = "pyo3")]
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    /// Builds a Python `dict` with keys `"indices"` and `"values"`.
    ///
    /// Unlike the serde representation, no `"#type"` key is emitted.
    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        let SparseVector { indices, values } = self;
        let out = pyo3::types::PyDict::new(py);
        out.set_item("indices", indices)?;
        out.set_item("values", values)?;
        Ok(out.into_any())
    }
}
190
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for SparseVector {
    /// Reads a sparse vector from a Python `dict` holding `"indices"` and `"values"`.
    ///
    /// Other keys (such as `"#type"`) are ignored, and the result is not validated.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyDict;

        let mapping = ob.downcast::<PyDict>()?;
        // Both keys are looked up before either value is extracted.
        let raw_indices = mapping.get_item("indices")?;
        let raw_values = mapping.get_item("values")?;

        Ok(Self {
            indices: raw_indices.extract::<Vec<u32>>()?,
            values: raw_values.extract::<Vec<f32>>()?,
        })
    }
}
206
207#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
208#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
209#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
210#[serde(untagged)]
211pub enum UpdateMetadataValue {
212    Bool(bool),
213    Int(i64),
214    #[cfg_attr(
215        feature = "testing",
216        proptest(
217            strategy = "(-1e6..=1e6f32).prop_map(|v| UpdateMetadataValue::Float(v as f64)).boxed()"
218        )
219    )]
220    Float(f64),
221    Str(String),
222    #[cfg_attr(feature = "testing", proptest(skip))]
223    SparseVector(SparseVector),
224    None,
225}
226
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for UpdateMetadataValue {
    /// Converts a Python object by trying candidate types in a fixed order.
    ///
    /// Python `None` becomes [`UpdateMetadataValue::None`]. Otherwise the candidates are tried
    /// in this order: `bool` (first, since Python's `bool` is a subclass of `int`), `int`,
    /// `float`, `str`, and finally a sparse-vector `dict`. The first successful extraction wins.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        if ob.is_none() {
            return Ok(Self::None);
        }
        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(Self::Bool(flag));
        }
        if let Ok(int) = ob.extract::<i64>() {
            return Ok(Self::Int(int));
        }
        if let Ok(float) = ob.extract::<f64>() {
            return Ok(Self::Float(float));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(Self::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(Self::SparseVector(sparse));
        }
        Err(pyo3::exceptions::PyTypeError::new_err(
            "Cannot convert Python object to UpdateMetadataValue",
        ))
    }
}
249
250impl From<bool> for UpdateMetadataValue {
251    fn from(b: bool) -> Self {
252        Self::Bool(b)
253    }
254}
255
256impl From<i64> for UpdateMetadataValue {
257    fn from(v: i64) -> Self {
258        Self::Int(v)
259    }
260}
261
262impl From<i32> for UpdateMetadataValue {
263    fn from(v: i32) -> Self {
264        Self::Int(v as i64)
265    }
266}
267
268impl From<f64> for UpdateMetadataValue {
269    fn from(v: f64) -> Self {
270        Self::Float(v)
271    }
272}
273
274impl From<f32> for UpdateMetadataValue {
275    fn from(v: f32) -> Self {
276        Self::Float(v as f64)
277    }
278}
279
280impl From<String> for UpdateMetadataValue {
281    fn from(v: String) -> Self {
282        Self::Str(v)
283    }
284}
285
286impl From<&str> for UpdateMetadataValue {
287    fn from(v: &str) -> Self {
288        Self::Str(v.to_string())
289    }
290}
291
292impl From<SparseVector> for UpdateMetadataValue {
293    fn from(v: SparseVector) -> Self {
294        Self::SparseVector(v)
295    }
296}
297
298#[derive(Error, Debug)]
299pub enum UpdateMetadataValueConversionError {
300    #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")]
301    InvalidValue,
302}
303
304impl ChromaError for UpdateMetadataValueConversionError {
305    fn code(&self) -> ErrorCodes {
306        match self {
307            UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
308        }
309    }
310}
311
312impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue {
313    type Error = UpdateMetadataValueConversionError;
314
315    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
316        match &value.value {
317            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
318                Ok(UpdateMetadataValue::Bool(*value))
319            }
320            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
321                Ok(UpdateMetadataValue::Int(*value))
322            }
323            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
324                Ok(UpdateMetadataValue::Float(*value))
325            }
326            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
327                Ok(UpdateMetadataValue::Str(value.clone()))
328            }
329            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
330                Ok(UpdateMetadataValue::SparseVector(value.clone().into()))
331            }
332            // Used to communicate that the user wants to delete this key.
333            None => Ok(UpdateMetadataValue::None),
334        }
335    }
336}
337
338impl From<UpdateMetadataValue> for chroma_proto::UpdateMetadataValue {
339    fn from(value: UpdateMetadataValue) -> Self {
340        match value {
341            UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
342                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
343            },
344            UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
345                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
346            },
347            UpdateMetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
348                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
349                    value,
350                )),
351            },
352            UpdateMetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
353                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
354                    value,
355                )),
356            },
357            UpdateMetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
358                value: Some(
359                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
360                        sparse_vec.into(),
361                    ),
362                ),
363            },
364            UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None },
365        }
366    }
367}
368
369impl TryFrom<&UpdateMetadataValue> for MetadataValue {
370    type Error = MetadataValueConversionError;
371
372    fn try_from(value: &UpdateMetadataValue) -> Result<Self, Self::Error> {
373        match value {
374            UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)),
375            UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)),
376            UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)),
377            UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())),
378            UpdateMetadataValue::SparseVector(value) => {
379                Ok(MetadataValue::SparseVector(value.clone()))
380            }
381            UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue),
382        }
383    }
384}
385
386/*
387===========================================
388MetadataValue
389===========================================
390*/
391
392#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
393#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
394#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
395#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
396#[serde(untagged)]
397pub enum MetadataValue {
398    Bool(bool),
399    Int(i64),
400    #[cfg_attr(
401        feature = "testing",
402        proptest(
403            strategy = "(-1e6..=1e6f32).prop_map(|v| MetadataValue::Float(v as f64)).boxed()"
404        )
405    )]
406    Float(f64),
407    Str(String),
408    #[cfg_attr(feature = "testing", proptest(skip))]
409    SparseVector(SparseVector),
410}
411
412impl Eq for MetadataValue {}
413
414#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
415pub enum MetadataValueType {
416    Bool,
417    Int,
418    Float,
419    Str,
420    SparseVector,
421}
422
423impl MetadataValue {
424    pub fn value_type(&self) -> MetadataValueType {
425        match self {
426            MetadataValue::Bool(_) => MetadataValueType::Bool,
427            MetadataValue::Int(_) => MetadataValueType::Int,
428            MetadataValue::Float(_) => MetadataValueType::Float,
429            MetadataValue::Str(_) => MetadataValueType::Str,
430            MetadataValue::SparseVector(_) => MetadataValueType::SparseVector,
431        }
432    }
433}
434
435impl From<&MetadataValue> for MetadataValueType {
436    fn from(value: &MetadataValue) -> Self {
437        value.value_type()
438    }
439}
440
441impl From<bool> for MetadataValue {
442    fn from(v: bool) -> Self {
443        MetadataValue::Bool(v)
444    }
445}
446
447impl From<i64> for MetadataValue {
448    fn from(v: i64) -> Self {
449        MetadataValue::Int(v)
450    }
451}
452
453impl From<i32> for MetadataValue {
454    fn from(v: i32) -> Self {
455        MetadataValue::Int(v as i64)
456    }
457}
458
459impl From<f64> for MetadataValue {
460    fn from(v: f64) -> Self {
461        MetadataValue::Float(v)
462    }
463}
464
465impl From<f32> for MetadataValue {
466    fn from(v: f32) -> Self {
467        MetadataValue::Float(v as f64)
468    }
469}
470
471impl From<String> for MetadataValue {
472    fn from(v: String) -> Self {
473        MetadataValue::Str(v)
474    }
475}
476
477impl From<&str> for MetadataValue {
478    fn from(v: &str) -> Self {
479        MetadataValue::Str(v.to_string())
480    }
481}
482
483impl From<SparseVector> for MetadataValue {
484    fn from(v: SparseVector) -> Self {
485        MetadataValue::SparseVector(v)
486    }
487}
488
489/// We need `Eq` and `Ord` since we want to use this as a key in `BTreeMap`
490///
491/// For cross-type comparisons, we define a consistent ordering based on variant position:
492/// Bool < Int < Float < Str < SparseVector
493#[allow(clippy::derive_ord_xor_partial_ord)]
494impl Ord for MetadataValue {
495    fn cmp(&self, other: &Self) -> Ordering {
496        // Define type ordering based on variant position
497        fn type_order(val: &MetadataValue) -> u8 {
498            match val {
499                MetadataValue::Bool(_) => 0,
500                MetadataValue::Int(_) => 1,
501                MetadataValue::Float(_) => 2,
502                MetadataValue::Str(_) => 3,
503                MetadataValue::SparseVector(_) => 4,
504            }
505        }
506
507        // Chain type ordering with value ordering
508        type_order(self).cmp(&type_order(other)).then_with(|| {
509            match (self, other) {
510                (MetadataValue::Bool(left), MetadataValue::Bool(right)) => left.cmp(right),
511                (MetadataValue::Int(left), MetadataValue::Int(right)) => left.cmp(right),
512                (MetadataValue::Float(left), MetadataValue::Float(right)) => left.total_cmp(right),
513                (MetadataValue::Str(left), MetadataValue::Str(right)) => left.cmp(right),
514                (MetadataValue::SparseVector(left), MetadataValue::SparseVector(right)) => {
515                    left.cmp(right)
516                }
517                _ => Ordering::Equal, // Different types, but type_order already handled this
518            }
519        })
520    }
521}
522
523impl PartialOrd for MetadataValue {
524    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
525        Some(self.cmp(other))
526    }
527}
528
529impl TryFrom<&MetadataValue> for bool {
530    type Error = MetadataValueConversionError;
531
532    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
533        match value {
534            MetadataValue::Bool(value) => Ok(*value),
535            _ => Err(MetadataValueConversionError::InvalidValue),
536        }
537    }
538}
539
540impl TryFrom<&MetadataValue> for i64 {
541    type Error = MetadataValueConversionError;
542
543    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
544        match value {
545            MetadataValue::Int(value) => Ok(*value),
546            _ => Err(MetadataValueConversionError::InvalidValue),
547        }
548    }
549}
550
551impl TryFrom<&MetadataValue> for f64 {
552    type Error = MetadataValueConversionError;
553
554    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
555        match value {
556            MetadataValue::Float(value) => Ok(*value),
557            _ => Err(MetadataValueConversionError::InvalidValue),
558        }
559    }
560}
561
562impl TryFrom<&MetadataValue> for String {
563    type Error = MetadataValueConversionError;
564
565    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
566        match value {
567            MetadataValue::Str(value) => Ok(value.clone()),
568            _ => Err(MetadataValueConversionError::InvalidValue),
569        }
570    }
571}
572
573impl From<MetadataValue> for UpdateMetadataValue {
574    fn from(value: MetadataValue) -> Self {
575        match value {
576            MetadataValue::Bool(v) => UpdateMetadataValue::Bool(v),
577            MetadataValue::Int(v) => UpdateMetadataValue::Int(v),
578            MetadataValue::Float(v) => UpdateMetadataValue::Float(v),
579            MetadataValue::Str(v) => UpdateMetadataValue::Str(v),
580            MetadataValue::SparseVector(v) => UpdateMetadataValue::SparseVector(v),
581        }
582    }
583}
584
585impl From<MetadataValue> for Value {
586    fn from(value: MetadataValue) -> Self {
587        match value {
588            MetadataValue::Bool(val) => Self::Bool(val),
589            MetadataValue::Int(val) => Self::Number(
590                Number::from_i128(val as i128).expect("i64 should be representable in JSON"),
591            ),
592            MetadataValue::Float(val) => Self::Number(
593                Number::from_f64(val).expect("Inf and NaN should not be present in MetadataValue"),
594            ),
595            MetadataValue::Str(val) => Self::String(val),
596            MetadataValue::SparseVector(val) => {
597                let mut map = serde_json::Map::new();
598                map.insert(
599                    "indices".to_string(),
600                    Value::Array(
601                        val.indices
602                            .iter()
603                            .map(|&i| Value::Number(i.into()))
604                            .collect(),
605                    ),
606                );
607                map.insert(
608                    "values".to_string(),
609                    Value::Array(
610                        val.values
611                            .iter()
612                            .map(|&v| {
613                                Value::Number(
614                                    Number::from_f64(v as f64)
615                                        .expect("Float number should not be NaN or infinite"),
616                                )
617                            })
618                            .collect(),
619                    ),
620                );
621                Self::Object(map)
622            }
623        }
624    }
625}
626
627#[derive(Error, Debug)]
628pub enum MetadataValueConversionError {
629    #[error("Invalid metadata value, valid values are: Int, Float, Str")]
630    InvalidValue,
631    #[error("Metadata key cannot start with '#' or '$': {0}")]
632    InvalidKey(String),
633    #[error("Sparse vector indices and values must have the same length")]
634    SparseVectorLengthMismatch,
635    #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")]
636    SparseVectorIndicesNotSorted,
637}
638
639impl ChromaError for MetadataValueConversionError {
640    fn code(&self) -> ErrorCodes {
641        match self {
642            MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
643            MetadataValueConversionError::InvalidKey(_) => ErrorCodes::InvalidArgument,
644            MetadataValueConversionError::SparseVectorLengthMismatch => ErrorCodes::InvalidArgument,
645            MetadataValueConversionError::SparseVectorIndicesNotSorted => {
646                ErrorCodes::InvalidArgument
647            }
648        }
649    }
650}
651
652impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue {
653    type Error = MetadataValueConversionError;
654
655    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
656        match &value.value {
657            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
658                Ok(MetadataValue::Bool(*value))
659            }
660            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
661                Ok(MetadataValue::Int(*value))
662            }
663            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
664                Ok(MetadataValue::Float(*value))
665            }
666            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
667                Ok(MetadataValue::Str(value.clone()))
668            }
669            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
670                Ok(MetadataValue::SparseVector(value.clone().into()))
671            }
672            _ => Err(MetadataValueConversionError::InvalidValue),
673        }
674    }
675}
676
677impl From<MetadataValue> for chroma_proto::UpdateMetadataValue {
678    fn from(value: MetadataValue) -> Self {
679        match value {
680            MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
681                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
682            },
683            MetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
684                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
685                    value,
686                )),
687            },
688            MetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
689                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
690                    value,
691                )),
692            },
693            MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
694                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
695            },
696            MetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
697                value: Some(
698                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
699                        sparse_vec.into(),
700                    ),
701                ),
702            },
703        }
704    }
705}
706
707/*
708===========================================
709UpdateMetadata
710===========================================
711*/
712pub type UpdateMetadata = HashMap<String, UpdateMetadataValue>;
713
714/**
715 * Check if two metadata are close to equal. Ignores small differences in float values.
716 */
717pub fn are_update_metadatas_close_to_equal(
718    metadata1: &UpdateMetadata,
719    metadata2: &UpdateMetadata,
720) -> bool {
721    assert_eq!(metadata1.len(), metadata2.len());
722
723    for (key, value) in metadata1.iter() {
724        if !metadata2.contains_key(key) {
725            return false;
726        }
727        let other_value = metadata2.get(key).unwrap();
728
729        if let (UpdateMetadataValue::Float(value), UpdateMetadataValue::Float(other_value)) =
730            (value, other_value)
731        {
732            if (value - other_value).abs() > 1e-6 {
733                return false;
734            }
735        } else if value != other_value {
736            return false;
737        }
738    }
739
740    true
741}
742
743pub fn are_metadatas_close_to_equal(metadata1: &Metadata, metadata2: &Metadata) -> bool {
744    assert_eq!(metadata1.len(), metadata2.len());
745
746    for (key, value) in metadata1.iter() {
747        if !metadata2.contains_key(key) {
748            return false;
749        }
750        let other_value = metadata2.get(key).unwrap();
751
752        if let (MetadataValue::Float(value), MetadataValue::Float(other_value)) =
753            (value, other_value)
754        {
755            if (value - other_value).abs() > 1e-6 {
756                return false;
757            }
758        } else if value != other_value {
759            return false;
760        }
761    }
762
763    true
764}
765
766impl TryFrom<chroma_proto::UpdateMetadata> for UpdateMetadata {
767    type Error = UpdateMetadataValueConversionError;
768
769    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
770        let mut metadata = UpdateMetadata::new();
771        for (key, value) in proto_metadata.metadata.iter() {
772            let value = match value.try_into() {
773                Ok(value) => value,
774                Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue),
775            };
776            metadata.insert(key.clone(), value);
777        }
778        Ok(metadata)
779    }
780}
781
782impl From<UpdateMetadata> for chroma_proto::UpdateMetadata {
783    fn from(metadata: UpdateMetadata) -> Self {
784        let mut metadata = metadata;
785        let mut proto_metadata = chroma_proto::UpdateMetadata {
786            metadata: HashMap::new(),
787        };
788        for (key, value) in metadata.drain() {
789            let proto_value = value.into();
790            proto_metadata.metadata.insert(key.clone(), proto_value);
791        }
792        proto_metadata
793    }
794}
795
796/*
797===========================================
798Metadata
799===========================================
800*/
801
802pub type Metadata = HashMap<String, MetadataValue>;
803pub type DeletedMetadata = HashSet<String>;
804
805pub fn logical_size_of_metadata(metadata: &Metadata) -> usize {
806    metadata
807        .iter()
808        .map(|(k, v)| {
809            k.len()
810                + match v {
811                    MetadataValue::Bool(b) => size_of_val(b),
812                    MetadataValue::Int(i) => size_of_val(i),
813                    MetadataValue::Float(f) => size_of_val(f),
814                    MetadataValue::Str(s) => s.len(),
815                    MetadataValue::SparseVector(v) => {
816                        size_of_val(&v.indices[..]) + size_of_val(&v.values[..])
817                    }
818                }
819        })
820        .sum()
821}
822
823pub fn get_metadata_value_as<'a, T>(
824    metadata: &'a Metadata,
825    key: &str,
826) -> Result<T, Box<MetadataValueConversionError>>
827where
828    T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
829{
830    let res = match metadata.get(key) {
831        Some(value) => T::try_from(value),
832        None => return Err(Box::new(MetadataValueConversionError::InvalidValue)),
833    };
834    match res {
835        Ok(value) => Ok(value),
836        Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)),
837    }
838}
839
840impl TryFrom<chroma_proto::UpdateMetadata> for Metadata {
841    type Error = MetadataValueConversionError;
842
843    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
844        let mut metadata = Metadata::new();
845        for (key, value) in proto_metadata.metadata.iter() {
846            let maybe_value: Result<MetadataValue, Self::Error> = value.try_into();
847            if maybe_value.is_err() {
848                return Err(MetadataValueConversionError::InvalidValue);
849            }
850            let value = maybe_value.unwrap();
851            metadata.insert(key.clone(), value);
852        }
853        Ok(metadata)
854    }
855}
856
857impl From<Metadata> for chroma_proto::UpdateMetadata {
858    fn from(metadata: Metadata) -> Self {
859        let mut metadata = metadata;
860        let mut proto_metadata = chroma_proto::UpdateMetadata {
861            metadata: HashMap::new(),
862        };
863        for (key, value) in metadata.drain() {
864            let proto_value = value.into();
865            proto_metadata.metadata.insert(key.clone(), proto_value);
866        }
867        proto_metadata
868    }
869}
870
871#[derive(Debug, Default)]
872pub struct MetadataDelta<'referred_data> {
873    pub metadata_to_update: HashMap<
874        &'referred_data str,
875        (&'referred_data MetadataValue, &'referred_data MetadataValue),
876    >,
877    pub metadata_to_delete: HashMap<&'referred_data str, &'referred_data MetadataValue>,
878    pub metadata_to_insert: HashMap<&'referred_data str, &'referred_data MetadataValue>,
879}
880
881impl MetadataDelta<'_> {
882    pub fn new() -> Self {
883        Self::default()
884    }
885}
886
887/*
888===========================================
889Metadata queries
890===========================================
891*/
892
893#[derive(Clone, Debug, Error, PartialEq)]
894pub enum WhereConversionError {
895    #[error("Error: {0}")]
896    Cause(String),
897    #[error("{0} -> {1}")]
898    Trace(String, Box<Self>),
899}
900
901impl WhereConversionError {
902    pub fn cause(msg: impl ToString) -> Self {
903        Self::Cause(msg.to_string())
904    }
905
906    pub fn trace(self, context: impl ToString) -> Self {
907        Self::Trace(context.to_string(), Box::new(self))
908    }
909}
910
/// This `Where` enum serves as a unified representation for the `where` and `where_document` clauses.
/// Although the two are not unified at the API level due to legacy design choices, in the future we will
/// unify them, and the structure of the unified AST should be identical to the one here.
/// Currently both `where` and `where_document` clauses are translated into `Where`, and if both are
/// present we simply create a conjunction of both clauses as the actual filter. This is consistent with
/// the semantics we used to have when the `where` and `where_document` clauses were treated separately.
// TODO: Remove this note once the `where` clause and `where_document` clause are unified at the API level.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Where {
    /// Boolean combination (`And` / `Or`) of nested clauses.
    Composite(CompositeExpression),
    /// Predicate on the document text (contains / regex and their negations).
    Document(DocumentExpression),
    /// Predicate on a single metadata key.
    Metadata(MetadataExpression),
}
925
926impl serde::Serialize for Where {
927    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
928    where
929        S: Serializer,
930    {
931        match self {
932            Where::Composite(composite) => {
933                let mut map = serializer.serialize_map(Some(1))?;
934                let op_key = match composite.operator {
935                    BooleanOperator::And => "$and",
936                    BooleanOperator::Or => "$or",
937                };
938                map.serialize_entry(op_key, &composite.children)?;
939                map.end()
940            }
941            Where::Document(doc) => {
942                let mut outer_map = serializer.serialize_map(Some(1))?;
943                let mut inner_map = serde_json::Map::new();
944                let op_key = match doc.operator {
945                    DocumentOperator::Contains => "$contains",
946                    DocumentOperator::NotContains => "$not_contains",
947                    DocumentOperator::Regex => "$regex",
948                    DocumentOperator::NotRegex => "$not_regex",
949                };
950                inner_map.insert(
951                    op_key.to_string(),
952                    serde_json::Value::String(doc.pattern.clone()),
953                );
954                outer_map.serialize_entry("#document", &inner_map)?;
955                outer_map.end()
956            }
957            Where::Metadata(meta) => {
958                let mut outer_map = serializer.serialize_map(Some(1))?;
959                let mut inner_map = serde_json::Map::new();
960
961                match &meta.comparison {
962                    MetadataComparison::Primitive(op, value) => {
963                        let op_key = match op {
964                            PrimitiveOperator::Equal => "$eq",
965                            PrimitiveOperator::NotEqual => "$ne",
966                            PrimitiveOperator::GreaterThan => "$gt",
967                            PrimitiveOperator::GreaterThanOrEqual => "$gte",
968                            PrimitiveOperator::LessThan => "$lt",
969                            PrimitiveOperator::LessThanOrEqual => "$lte",
970                        };
971                        let value_json =
972                            serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
973                        inner_map.insert(op_key.to_string(), value_json);
974                    }
975                    MetadataComparison::Set(op, set_value) => {
976                        let op_key = match op {
977                            SetOperator::In => "$in",
978                            SetOperator::NotIn => "$nin",
979                        };
980                        let values_json = match set_value {
981                            MetadataSetValue::Bool(v) => serde_json::to_value(v),
982                            MetadataSetValue::Int(v) => serde_json::to_value(v),
983                            MetadataSetValue::Float(v) => serde_json::to_value(v),
984                            MetadataSetValue::Str(v) => serde_json::to_value(v),
985                        }
986                        .map_err(serde::ser::Error::custom)?;
987                        inner_map.insert(op_key.to_string(), values_json);
988                    }
989                }
990
991                outer_map.serialize_entry(&meta.key, &inner_map)?;
992                outer_map.end()
993            }
994        }
995    }
996}
997
998impl Where {
999    pub fn conjunction(children: Vec<Where>) -> Self {
1000        Self::Composite(CompositeExpression {
1001            operator: BooleanOperator::And,
1002            children,
1003        })
1004    }
1005    pub fn disjunction(children: Vec<Where>) -> Self {
1006        Self::Composite(CompositeExpression {
1007            operator: BooleanOperator::Or,
1008            children,
1009        })
1010    }
1011
1012    pub fn fts_query_length(&self) -> u64 {
1013        match self {
1014            Where::Composite(composite_expression) => composite_expression
1015                .children
1016                .iter()
1017                .map(Where::fts_query_length)
1018                .sum(),
1019            // The query length is defined to be the number of trigram tokens
1020            Where::Document(document_expression) => {
1021                document_expression.pattern.len().max(3) as u64 - 2
1022            }
1023            Where::Metadata(_) => 0,
1024        }
1025    }
1026
1027    pub fn metadata_predicate_count(&self) -> u64 {
1028        match self {
1029            Where::Composite(composite_expression) => composite_expression
1030                .children
1031                .iter()
1032                .map(Where::metadata_predicate_count)
1033                .sum(),
1034            Where::Document(_) => 0,
1035            Where::Metadata(metadata_expression) => match &metadata_expression.comparison {
1036                MetadataComparison::Primitive(_, _) => 1,
1037                MetadataComparison::Set(_, metadata_set_value) => match metadata_set_value {
1038                    MetadataSetValue::Bool(items) => items.len() as u64,
1039                    MetadataSetValue::Int(items) => items.len() as u64,
1040                    MetadataSetValue::Float(items) => items.len() as u64,
1041                    MetadataSetValue::Str(items) => items.len() as u64,
1042                },
1043            },
1044        }
1045    }
1046}
1047
1048impl BitAnd for Where {
1049    type Output = Where;
1050
1051    fn bitand(self, rhs: Self) -> Self::Output {
1052        match self {
1053            Where::Composite(CompositeExpression {
1054                operator: BooleanOperator::And,
1055                mut children,
1056            }) => match rhs {
1057                Where::Composite(CompositeExpression {
1058                    operator: BooleanOperator::And,
1059                    children: rhs_children,
1060                }) => {
1061                    children.extend(rhs_children);
1062                    Where::Composite(CompositeExpression {
1063                        operator: BooleanOperator::And,
1064                        children,
1065                    })
1066                }
1067                _ => {
1068                    children.push(rhs);
1069                    Where::Composite(CompositeExpression {
1070                        operator: BooleanOperator::And,
1071                        children,
1072                    })
1073                }
1074            },
1075            _ => match rhs {
1076                Where::Composite(CompositeExpression {
1077                    operator: BooleanOperator::And,
1078                    mut children,
1079                }) => {
1080                    children.insert(0, self);
1081                    Where::Composite(CompositeExpression {
1082                        operator: BooleanOperator::And,
1083                        children,
1084                    })
1085                }
1086                _ => Where::conjunction(vec![self, rhs]),
1087            },
1088        }
1089    }
1090}
1091
1092impl BitOr for Where {
1093    type Output = Where;
1094
1095    fn bitor(self, rhs: Self) -> Self::Output {
1096        match self {
1097            Where::Composite(CompositeExpression {
1098                operator: BooleanOperator::Or,
1099                mut children,
1100            }) => match rhs {
1101                Where::Composite(CompositeExpression {
1102                    operator: BooleanOperator::Or,
1103                    children: rhs_children,
1104                }) => {
1105                    children.extend(rhs_children);
1106                    Where::Composite(CompositeExpression {
1107                        operator: BooleanOperator::Or,
1108                        children,
1109                    })
1110                }
1111                _ => {
1112                    children.push(rhs);
1113                    Where::Composite(CompositeExpression {
1114                        operator: BooleanOperator::Or,
1115                        children,
1116                    })
1117                }
1118            },
1119            _ => match rhs {
1120                Where::Composite(CompositeExpression {
1121                    operator: BooleanOperator::Or,
1122                    mut children,
1123                }) => {
1124                    children.insert(0, self);
1125                    Where::Composite(CompositeExpression {
1126                        operator: BooleanOperator::Or,
1127                        children,
1128                    })
1129                }
1130                _ => Where::disjunction(vec![self, rhs]),
1131            },
1132        }
1133    }
1134}
1135
1136impl TryFrom<chroma_proto::Where> for Where {
1137    type Error = WhereConversionError;
1138
1139    fn try_from(proto_where: chroma_proto::Where) -> Result<Self, Self::Error> {
1140        let where_inner = proto_where
1141            .r#where
1142            .ok_or(WhereConversionError::cause("Invalid Where"))?;
1143        Ok(match where_inner {
1144            chroma_proto::r#where::Where::DirectComparison(direct_comparison) => {
1145                Self::Metadata(direct_comparison.try_into()?)
1146            }
1147            chroma_proto::r#where::Where::Children(where_children) => {
1148                Self::Composite(where_children.try_into()?)
1149            }
1150            chroma_proto::r#where::Where::DirectDocumentComparison(direct_where_document) => {
1151                Self::Document(direct_where_document.into())
1152            }
1153        })
1154    }
1155}
1156
1157impl TryFrom<Where> for chroma_proto::Where {
1158    type Error = WhereConversionError;
1159
1160    fn try_from(value: Where) -> Result<Self, Self::Error> {
1161        let proto_where = match value {
1162            Where::Composite(composite_expression) => {
1163                chroma_proto::r#where::Where::Children(composite_expression.try_into()?)
1164            }
1165            Where::Document(document_expression) => {
1166                chroma_proto::r#where::Where::DirectDocumentComparison(document_expression.into())
1167            }
1168            Where::Metadata(metadata_expression) => chroma_proto::r#where::Where::DirectComparison(
1169                chroma_proto::DirectComparison::try_from(metadata_expression)
1170                    .map_err(|err| err.trace("MetadataExpression"))?,
1171            ),
1172        };
1173        Ok(Self {
1174            r#where: Some(proto_where),
1175        })
1176    }
1177}
1178
1179#[derive(Clone, Debug, PartialEq)]
1180#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1181pub struct CompositeExpression {
1182    pub operator: BooleanOperator,
1183    pub children: Vec<Where>,
1184}
1185
1186impl TryFrom<chroma_proto::WhereChildren> for CompositeExpression {
1187    type Error = WhereConversionError;
1188
1189    fn try_from(proto_children: chroma_proto::WhereChildren) -> Result<Self, Self::Error> {
1190        let operator = proto_children.operator().into();
1191        let children = proto_children
1192            .children
1193            .into_iter()
1194            .map(Where::try_from)
1195            .collect::<Result<Vec<_>, _>>()
1196            .map_err(|err| err.trace("Child Where of CompositeExpression"))?;
1197        Ok(Self { operator, children })
1198    }
1199}
1200
1201impl TryFrom<CompositeExpression> for chroma_proto::WhereChildren {
1202    type Error = WhereConversionError;
1203
1204    fn try_from(value: CompositeExpression) -> Result<Self, Self::Error> {
1205        Ok(Self {
1206            operator: chroma_proto::BooleanOperator::from(value.operator) as i32,
1207            children: value
1208                .children
1209                .into_iter()
1210                .map(chroma_proto::Where::try_from)
1211                .collect::<Result<_, _>>()?,
1212        })
1213    }
1214}
1215
1216#[derive(Clone, Debug, PartialEq)]
1217#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1218pub enum BooleanOperator {
1219    And,
1220    Or,
1221}
1222
1223impl From<chroma_proto::BooleanOperator> for BooleanOperator {
1224    fn from(value: chroma_proto::BooleanOperator) -> Self {
1225        match value {
1226            chroma_proto::BooleanOperator::And => Self::And,
1227            chroma_proto::BooleanOperator::Or => Self::Or,
1228        }
1229    }
1230}
1231
1232impl From<BooleanOperator> for chroma_proto::BooleanOperator {
1233    fn from(value: BooleanOperator) -> Self {
1234        match value {
1235            BooleanOperator::And => Self::And,
1236            BooleanOperator::Or => Self::Or,
1237        }
1238    }
1239}
1240
1241#[derive(Clone, Debug, PartialEq)]
1242#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1243pub struct DocumentExpression {
1244    pub operator: DocumentOperator,
1245    pub pattern: String,
1246}
1247
1248impl From<chroma_proto::DirectWhereDocument> for DocumentExpression {
1249    fn from(value: chroma_proto::DirectWhereDocument) -> Self {
1250        Self {
1251            operator: value.operator().into(),
1252            pattern: value.pattern,
1253        }
1254    }
1255}
1256
1257impl From<DocumentExpression> for chroma_proto::DirectWhereDocument {
1258    fn from(value: DocumentExpression) -> Self {
1259        Self {
1260            pattern: value.pattern,
1261            operator: chroma_proto::WhereDocumentOperator::from(value.operator) as i32,
1262        }
1263    }
1264}
1265
1266#[derive(Clone, Debug, PartialEq)]
1267#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1268pub enum DocumentOperator {
1269    Contains,
1270    NotContains,
1271    Regex,
1272    NotRegex,
1273}
1274impl From<chroma_proto::WhereDocumentOperator> for DocumentOperator {
1275    fn from(value: chroma_proto::WhereDocumentOperator) -> Self {
1276        match value {
1277            chroma_proto::WhereDocumentOperator::Contains => Self::Contains,
1278            chroma_proto::WhereDocumentOperator::NotContains => Self::NotContains,
1279            chroma_proto::WhereDocumentOperator::Regex => Self::Regex,
1280            chroma_proto::WhereDocumentOperator::NotRegex => Self::NotRegex,
1281        }
1282    }
1283}
1284
1285impl From<DocumentOperator> for chroma_proto::WhereDocumentOperator {
1286    fn from(value: DocumentOperator) -> Self {
1287        match value {
1288            DocumentOperator::Contains => Self::Contains,
1289            DocumentOperator::NotContains => Self::NotContains,
1290            DocumentOperator::Regex => Self::Regex,
1291            DocumentOperator::NotRegex => Self::NotRegex,
1292        }
1293    }
1294}
1295
1296#[derive(Clone, Debug, PartialEq)]
1297#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1298pub struct MetadataExpression {
1299    pub key: String,
1300    pub comparison: MetadataComparison,
1301}
1302
1303impl TryFrom<chroma_proto::DirectComparison> for MetadataExpression {
1304    type Error = WhereConversionError;
1305
1306    fn try_from(value: chroma_proto::DirectComparison) -> Result<Self, Self::Error> {
1307        let proto_comparison = value
1308            .comparison
1309            .ok_or(WhereConversionError::cause("Invalid MetadataExpression"))?;
1310        let comparison = match proto_comparison {
1311            chroma_proto::direct_comparison::Comparison::SingleStringOperand(
1312                single_string_comparison,
1313            ) => MetadataComparison::Primitive(
1314                single_string_comparison.comparator().into(),
1315                MetadataValue::Str(single_string_comparison.value),
1316            ),
1317            chroma_proto::direct_comparison::Comparison::StringListOperand(
1318                string_list_comparison,
1319            ) => MetadataComparison::Set(
1320                string_list_comparison.list_operator().into(),
1321                MetadataSetValue::Str(string_list_comparison.values),
1322            ),
1323            chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1324                single_int_comparison,
1325            ) => MetadataComparison::Primitive(
1326                match single_int_comparison
1327                    .comparator
1328                    .ok_or(WhereConversionError::cause(
1329                        "Invalid scalar integer operator",
1330                    ))? {
1331                    chroma_proto::single_int_comparison::Comparator::GenericComparator(op) => {
1332                        chroma_proto::GenericComparator::try_from(op)
1333                            .map_err(WhereConversionError::cause)?
1334                            .into()
1335                    }
1336                    chroma_proto::single_int_comparison::Comparator::NumberComparator(op) => {
1337                        chroma_proto::NumberComparator::try_from(op)
1338                            .map_err(WhereConversionError::cause)?
1339                            .into()
1340                    }
1341                },
1342                MetadataValue::Int(single_int_comparison.value),
1343            ),
1344            chroma_proto::direct_comparison::Comparison::IntListOperand(int_list_comparison) => {
1345                MetadataComparison::Set(
1346                    int_list_comparison.list_operator().into(),
1347                    MetadataSetValue::Int(int_list_comparison.values),
1348                )
1349            }
1350            chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(
1351                single_double_comparison,
1352            ) => MetadataComparison::Primitive(
1353                match single_double_comparison
1354                    .comparator
1355                    .ok_or(WhereConversionError::cause("Invalid scalar float operator"))?
1356                {
1357                    chroma_proto::single_double_comparison::Comparator::GenericComparator(op) => {
1358                        chroma_proto::GenericComparator::try_from(op)
1359                            .map_err(WhereConversionError::cause)?
1360                            .into()
1361                    }
1362                    chroma_proto::single_double_comparison::Comparator::NumberComparator(op) => {
1363                        chroma_proto::NumberComparator::try_from(op)
1364                            .map_err(WhereConversionError::cause)?
1365                            .into()
1366                    }
1367                },
1368                MetadataValue::Float(single_double_comparison.value),
1369            ),
1370            chroma_proto::direct_comparison::Comparison::DoubleListOperand(
1371                double_list_comparison,
1372            ) => MetadataComparison::Set(
1373                double_list_comparison.list_operator().into(),
1374                MetadataSetValue::Float(double_list_comparison.values),
1375            ),
1376            chroma_proto::direct_comparison::Comparison::BoolListOperand(bool_list_comparison) => {
1377                MetadataComparison::Set(
1378                    bool_list_comparison.list_operator().into(),
1379                    MetadataSetValue::Bool(bool_list_comparison.values),
1380                )
1381            }
1382            chroma_proto::direct_comparison::Comparison::SingleBoolOperand(
1383                single_bool_comparison,
1384            ) => MetadataComparison::Primitive(
1385                single_bool_comparison.comparator().into(),
1386                MetadataValue::Bool(single_bool_comparison.value),
1387            ),
1388        };
1389        Ok(Self {
1390            key: value.key,
1391            comparison,
1392        })
1393    }
1394}
1395
1396impl TryFrom<MetadataExpression> for chroma_proto::DirectComparison {
1397    type Error = WhereConversionError;
1398
1399    fn try_from(value: MetadataExpression) -> Result<Self, Self::Error> {
1400        let comparison = match value.comparison {
1401            MetadataComparison::Primitive(primitive_operator, metadata_value) => match metadata_value {
1402                MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1403                MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(match primitive_operator {
1404                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1405                                numeric => chroma_proto::single_int_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1406                            }),
1407                MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(match primitive_operator {
1408                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_double_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1409                                numeric => chroma_proto::single_double_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1410                            }),
1411                MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1412                MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Comparison with sparse vector is not supported".to_string())),
1413            },
1414            MetadataComparison::Set(set_operator, metadata_set_value) => match metadata_set_value {
1415                MetadataSetValue::Bool(vec) => chroma_proto::direct_comparison::Comparison::BoolListOperand(chroma_proto::BoolListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1416                MetadataSetValue::Int(vec) => chroma_proto::direct_comparison::Comparison::IntListOperand(chroma_proto::IntListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1417                MetadataSetValue::Float(vec) => chroma_proto::direct_comparison::Comparison::DoubleListOperand(chroma_proto::DoubleListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1418                MetadataSetValue::Str(vec) => chroma_proto::direct_comparison::Comparison::StringListOperand(chroma_proto::StringListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1419            },
1420        };
1421        Ok(Self {
1422            key: value.key,
1423            comparison: Some(comparison),
1424        })
1425    }
1426}
1427
/// The comparison applied to a metadata key: either a scalar operator against one value,
/// or set membership (`$in` / `$nin`) against a list whose elements all share one type.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum MetadataComparison {
    /// Scalar comparison such as `$eq` or `$gte` against a single value.
    Primitive(PrimitiveOperator, MetadataValue),
    /// Membership test against a list of values.
    Set(SetOperator, MetadataSetValue),
}
1434
1435#[derive(Clone, Debug, PartialEq)]
1436#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
1437pub enum PrimitiveOperator {
1438    Equal,
1439    NotEqual,
1440    GreaterThan,
1441    GreaterThanOrEqual,
1442    LessThan,
1443    LessThanOrEqual,
1444}
1445
1446impl From<chroma_proto::GenericComparator> for PrimitiveOperator {
1447    fn from(value: chroma_proto::GenericComparator) -> Self {
1448        match value {
1449            chroma_proto::GenericComparator::Eq => Self::Equal,
1450            chroma_proto::GenericComparator::Ne => Self::NotEqual,
1451        }
1452    }
1453}
1454
1455impl TryFrom<PrimitiveOperator> for chroma_proto::GenericComparator {
1456    type Error = WhereConversionError;
1457
1458    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1459        match value {
1460            PrimitiveOperator::Equal => Ok(Self::Eq),
1461            PrimitiveOperator::NotEqual => Ok(Self::Ne),
1462            op => Err(WhereConversionError::cause(format!("{op:?} ∉ [=, ≠]"))),
1463        }
1464    }
1465}
1466
1467impl From<chroma_proto::NumberComparator> for PrimitiveOperator {
1468    fn from(value: chroma_proto::NumberComparator) -> Self {
1469        match value {
1470            chroma_proto::NumberComparator::Gt => Self::GreaterThan,
1471            chroma_proto::NumberComparator::Gte => Self::GreaterThanOrEqual,
1472            chroma_proto::NumberComparator::Lt => Self::LessThan,
1473            chroma_proto::NumberComparator::Lte => Self::LessThanOrEqual,
1474        }
1475    }
1476}
1477
1478impl TryFrom<PrimitiveOperator> for chroma_proto::NumberComparator {
1479    type Error = WhereConversionError;
1480
1481    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1482        match value {
1483            PrimitiveOperator::GreaterThan => Ok(Self::Gt),
1484            PrimitiveOperator::GreaterThanOrEqual => Ok(Self::Gte),
1485            PrimitiveOperator::LessThan => Ok(Self::Lt),
1486            PrimitiveOperator::LessThanOrEqual => Ok(Self::Lte),
1487            op => Err(WhereConversionError::cause(format!(
1488                "{op:?} ∉ [≤, <, >, ≥]"
1489            ))),
1490        }
1491    }
1492}
1493
1494#[derive(Clone, Debug, PartialEq)]
1495#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
1496pub enum SetOperator {
1497    In,
1498    NotIn,
1499}
1500
1501impl From<chroma_proto::ListOperator> for SetOperator {
1502    fn from(value: chroma_proto::ListOperator) -> Self {
1503        match value {
1504            chroma_proto::ListOperator::In => Self::In,
1505            chroma_proto::ListOperator::Nin => Self::NotIn,
1506        }
1507    }
1508}
1509
1510impl From<SetOperator> for chroma_proto::ListOperator {
1511    fn from(value: SetOperator) -> Self {
1512        match value {
1513            SetOperator::In => Self::In,
1514            SetOperator::NotIn => Self::Nin,
1515        }
1516    }
1517}
1518
1519#[derive(Clone, Debug, PartialEq)]
1520#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
1521pub enum MetadataSetValue {
1522    Bool(Vec<bool>),
1523    Int(Vec<i64>),
1524    Float(Vec<f64>),
1525    Str(Vec<String>),
1526}
1527
1528impl MetadataSetValue {
1529    pub fn value_type(&self) -> MetadataValueType {
1530        match self {
1531            MetadataSetValue::Bool(_) => MetadataValueType::Bool,
1532            MetadataSetValue::Int(_) => MetadataValueType::Int,
1533            MetadataSetValue::Float(_) => MetadataValueType::Float,
1534            MetadataSetValue::Str(_) => MetadataValueType::Str,
1535        }
1536    }
1537}
1538
1539impl From<Vec<bool>> for MetadataSetValue {
1540    fn from(values: Vec<bool>) -> Self {
1541        MetadataSetValue::Bool(values)
1542    }
1543}
1544
1545impl From<Vec<i64>> for MetadataSetValue {
1546    fn from(values: Vec<i64>) -> Self {
1547        MetadataSetValue::Int(values)
1548    }
1549}
1550
1551impl From<Vec<i32>> for MetadataSetValue {
1552    fn from(values: Vec<i32>) -> Self {
1553        MetadataSetValue::Int(values.into_iter().map(|v| v as i64).collect())
1554    }
1555}
1556
1557impl From<Vec<f64>> for MetadataSetValue {
1558    fn from(values: Vec<f64>) -> Self {
1559        MetadataSetValue::Float(values)
1560    }
1561}
1562
1563impl From<Vec<f32>> for MetadataSetValue {
1564    fn from(values: Vec<f32>) -> Self {
1565        MetadataSetValue::Float(values.into_iter().map(|v| v as f64).collect())
1566    }
1567}
1568
1569impl From<Vec<String>> for MetadataSetValue {
1570    fn from(values: Vec<String>) -> Self {
1571        MetadataSetValue::Str(values)
1572    }
1573}
1574
1575impl From<Vec<&str>> for MetadataSetValue {
1576    fn from(values: Vec<&str>) -> Self {
1577        MetadataSetValue::Str(values.into_iter().map(|s| s.to_string()).collect())
1578    }
1579}
1580
1581// TODO: Deprecate where_document
1582impl TryFrom<chroma_proto::WhereDocument> for Where {
1583    type Error = WhereConversionError;
1584
1585    fn try_from(proto_document: chroma_proto::WhereDocument) -> Result<Self, Self::Error> {
1586        match proto_document.r#where_document {
1587            Some(chroma_proto::where_document::WhereDocument::Direct(proto_comparison)) => {
1588                let operator = match TryInto::<chroma_proto::WhereDocumentOperator>::try_into(
1589                    proto_comparison.operator,
1590                ) {
1591                    Ok(operator) => operator,
1592                    Err(_) => {
1593                        return Err(WhereConversionError::cause(
1594                            "[Deprecated] Invalid where document operator",
1595                        ))
1596                    }
1597                };
1598                let comparison = DocumentExpression {
1599                    pattern: proto_comparison.pattern,
1600                    operator: operator.into(),
1601                };
1602                Ok(Where::Document(comparison))
1603            }
1604            Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => {
1605                let operator = match TryInto::<chroma_proto::BooleanOperator>::try_into(
1606                    proto_children.operator,
1607                ) {
1608                    Ok(operator) => operator,
1609                    Err(_) => {
1610                        return Err(WhereConversionError::cause(
1611                            "[Deprecated] Invalid boolean operator",
1612                        ))
1613                    }
1614                };
1615                let children = CompositeExpression {
1616                    children: proto_children
1617                        .children
1618                        .into_iter()
1619                        .map(|child| child.try_into())
1620                        .collect::<Result<_, _>>()?,
1621                    operator: operator.into(),
1622                };
1623                Ok(Where::Composite(children))
1624            }
1625            None => Err(WhereConversionError::cause("[Deprecated] Invalid where")),
1626        }
1627    }
1628}
1629
#[cfg(test)]
mod tests {
    // Unit tests for proto <-> domain conversions of metadata and where clauses,
    // and for `SparseVector` construction, validation and (de)serialization.
    use super::*;

    /// Builds a proto `UpdateMetadata` holding one entry of each value kind:
    /// `foo` = int 42, `bar` = float 42.0, `baz` = string "42", and `sparse` =
    /// a sparse vector made from `indices` / `values`.
    fn sample_proto_metadata(indices: Vec<u32>, values: Vec<f32>) -> chroma_proto::UpdateMetadata {
        use crate::chroma_proto::update_metadata_value::Value as ProtoValue;

        let entries = [
            ("foo", ProtoValue::IntValue(42)),
            ("bar", ProtoValue::FloatValue(42.0)),
            ("baz", ProtoValue::StringValue("42".to_string())),
            (
                "sparse",
                ProtoValue::SparseVectorValue(chroma_proto::SparseVector { indices, values }),
            ),
        ];
        chroma_proto::UpdateMetadata {
            metadata: entries
                .into_iter()
                .map(|(key, value)| {
                    (
                        key.to_string(),
                        chroma_proto::UpdateMetadataValue { value: Some(value) },
                    )
                })
                .collect(),
        }
    }

    /// Builds a proto `Where` expressing the integer metadata comparison `key == value`.
    fn proto_int_eq(key: &str, value: i64) -> chroma_proto::Where {
        use crate::chroma_proto::{direct_comparison, single_int_comparison};

        chroma_proto::Where {
            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
                chroma_proto::DirectComparison {
                    key: key.to_string(),
                    comparison: Some(direct_comparison::Comparison::SingleIntOperand(
                        chroma_proto::SingleIntComparison {
                            value,
                            comparator: Some(
                                single_int_comparison::Comparator::GenericComparator(
                                    chroma_proto::GenericComparator::Eq as i32,
                                ),
                            ),
                        },
                    )),
                },
            )),
        }
    }

    /// Builds a proto `WhereDocument` matching documents that contain `pattern`.
    fn proto_contains(pattern: &str) -> chroma_proto::WhereDocument {
        chroma_proto::WhereDocument {
            where_document: Some(chroma_proto::where_document::WhereDocument::Direct(
                chroma_proto::DirectWhereDocument {
                    pattern: pattern.to_string(),
                    operator: chroma_proto::WhereDocumentOperator::Contains.into(),
                },
            )),
        }
    }

    #[test]
    fn test_update_metadata_try_from() {
        let proto_metadata = sample_proto_metadata(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
        let converted_metadata: UpdateMetadata = proto_metadata.try_into().unwrap();
        assert_eq!(converted_metadata.len(), 4);
        assert_eq!(
            converted_metadata.get("foo").unwrap(),
            &UpdateMetadataValue::Int(42)
        );
        assert_eq!(
            converted_metadata.get("bar").unwrap(),
            &UpdateMetadataValue::Float(42.0)
        );
        assert_eq!(
            converted_metadata.get("baz").unwrap(),
            &UpdateMetadataValue::Str("42".to_string())
        );
        assert_eq!(
            converted_metadata.get("sparse").unwrap(),
            &UpdateMetadataValue::SparseVector(SparseVector::new(
                vec![0, 5, 10],
                vec![0.1, 0.5, 0.9]
            ))
        );
    }

    #[test]
    fn test_metadata_try_from() {
        let proto_metadata = sample_proto_metadata(vec![1, 10, 100], vec![0.2, 0.4, 0.6]);
        let converted_metadata: Metadata = proto_metadata.try_into().unwrap();
        assert_eq!(converted_metadata.len(), 4);
        assert_eq!(
            converted_metadata.get("foo").unwrap(),
            &MetadataValue::Int(42)
        );
        assert_eq!(
            converted_metadata.get("bar").unwrap(),
            &MetadataValue::Float(42.0)
        );
        assert_eq!(
            converted_metadata.get("baz").unwrap(),
            &MetadataValue::Str("42".to_string())
        );
        assert_eq!(
            converted_metadata.get("sparse").unwrap(),
            &MetadataValue::SparseVector(SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]))
        );
    }

    #[test]
    fn test_where_clause_simple_from() {
        let where_clause: Where = proto_int_eq("foo", 42).try_into().unwrap();
        match where_clause {
            Where::Metadata(comparison) => {
                assert_eq!(comparison.key, "foo");
                match comparison.comparison {
                    MetadataComparison::Primitive(_, value) => {
                        assert_eq!(value, MetadataValue::Int(42));
                    }
                    _ => panic!("Invalid comparison type"),
                }
            }
            _ => panic!("Invalid where type"),
        }
    }

    #[test]
    fn test_where_clause_with_children() {
        let proto_where = chroma_proto::Where {
            r#where: Some(chroma_proto::r#where::Where::Children(
                chroma_proto::WhereChildren {
                    children: vec![proto_int_eq("foo", 42), proto_int_eq("bar", 42)],
                    operator: chroma_proto::BooleanOperator::And.into(),
                },
            )),
        };
        let where_clause: Where = proto_where.try_into().unwrap();
        match where_clause {
            Where::Composite(children) => {
                assert_eq!(children.children.len(), 2);
                assert_eq!(children.operator, BooleanOperator::And);
                // Every child must itself have been converted into a metadata expression;
                // compare keys as a set so the check does not depend on child order.
                let keys: HashSet<&str> = children
                    .children
                    .iter()
                    .map(|child| match child {
                        Where::Metadata(comparison) => comparison.key.as_str(),
                        _ => panic!("Invalid child where type"),
                    })
                    .collect();
                assert_eq!(keys, HashSet::from(["foo", "bar"]));
            }
            _ => panic!("Invalid where type"),
        }
    }

    #[test]
    fn test_where_document_simple() {
        let where_document: Where = proto_contains("foo").try_into().unwrap();
        match where_document {
            Where::Document(comparison) => {
                assert_eq!(comparison.pattern, "foo");
                assert_eq!(comparison.operator, DocumentOperator::Contains);
            }
            _ => panic!("Invalid where document type"),
        }
    }

    #[test]
    fn test_where_document_with_children() {
        let proto_where = chroma_proto::WhereDocument {
            where_document: Some(chroma_proto::where_document::WhereDocument::Children(
                chroma_proto::WhereDocumentChildren {
                    children: vec![proto_contains("foo"), proto_contains("bar")],
                    operator: chroma_proto::BooleanOperator::And.into(),
                },
            )),
        };
        let where_document: Where = proto_where.try_into().unwrap();
        match where_document {
            Where::Composite(children) => {
                assert_eq!(children.children.len(), 2);
                assert_eq!(children.operator, BooleanOperator::And);
                // The conversion maps children in order, so patterns can be pinned exactly.
                let patterns: Vec<&str> = children
                    .children
                    .iter()
                    .map(|child| match child {
                        Where::Document(comparison) => {
                            assert_eq!(comparison.operator, DocumentOperator::Contains);
                            comparison.pattern.as_str()
                        }
                        _ => panic!("Invalid child where document type"),
                    })
                    .collect();
                assert_eq!(patterns, vec!["foo", "bar"]);
            }
            _ => panic!("Invalid where document type"),
        }
    }

    #[test]
    fn test_sparse_vector_new() {
        let indices = vec![0, 5, 10];
        let values = vec![0.1, 0.5, 0.9];
        let sparse = SparseVector::new(indices.clone(), values.clone());
        assert_eq!(sparse.indices, indices);
        assert_eq!(sparse.values, values);
    }

    #[test]
    fn test_sparse_vector_from_pairs() {
        let pairs = vec![(0, 0.1), (5, 0.5), (10, 0.9)];
        let sparse = SparseVector::from_pairs(pairs);
        assert_eq!(sparse.indices, vec![0, 5, 10]);
        assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
    }

    #[test]
    fn test_sparse_vector_iter() {
        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
        let collected: Vec<(u32, f32)> = sparse.iter().collect();
        assert_eq!(collected, vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
    }

    #[test]
    fn test_sparse_vector_ordering() {
        let sparse1 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
        let sparse2 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
        let sparse3 = SparseVector::new(vec![0, 6], vec![0.1, 0.5]);
        let sparse4 = SparseVector::new(vec![0, 5], vec![0.1, 0.6]);

        assert_eq!(sparse1, sparse2);
        assert!(sparse1 < sparse3);
        assert!(sparse1 < sparse4);
    }

    #[test]
    fn test_sparse_vector_proto_conversion() {
        let sparse = SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]);
        let proto: chroma_proto::SparseVector = sparse.clone().into();
        assert_eq!(proto.indices, vec![1, 10, 100]);
        assert_eq!(proto.values, vec![0.2, 0.4, 0.6]);

        let converted: SparseVector = proto.into();
        assert_eq!(converted, sparse);
    }

    #[test]
    fn test_sparse_vector_logical_size() {
        let metadata = Metadata::from([(
            "sparse".to_string(),
            MetadataValue::SparseVector(SparseVector::new(
                vec![0, 1, 2, 3, 4],
                vec![0.1, 0.2, 0.3, 0.4, 0.5],
            )),
        )]);

        let size = logical_size_of_metadata(&metadata);
        // Size should include the key string length and the sparse vector data
        // "sparse" = 6 bytes + 5 * 4 bytes (u32 indices) + 5 * 4 bytes (f32 values) = 46 bytes
        assert_eq!(size, 46);
    }

    #[test]
    fn test_sparse_vector_validation() {
        // Valid: equal lengths and strictly ascending indices.
        assert!(SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3])
            .validate()
            .is_ok());

        // Length mismatch between indices and values.
        assert!(matches!(
            SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]).validate(),
            Err(MetadataValueConversionError::SparseVectorLengthMismatch)
        ));

        // Every case below violates "strictly ascending indices".
        let not_strictly_ascending = [
            // Out of order from the very first element.
            (vec![3, 1, 2], vec![0.3, 0.1, 0.2]),
            // Duplicate index.
            (vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]),
            // Descending at a single point.
            (vec![1, 3, 2], vec![0.1, 0.3, 0.2]),
        ];
        for (indices, values) in not_strictly_ascending {
            assert!(
                matches!(
                    SparseVector::new(indices.clone(), values).validate(),
                    Err(MetadataValueConversionError::SparseVectorIndicesNotSorted)
                ),
                "expected SparseVectorIndicesNotSorted for indices {indices:?}"
            );
        }
    }

    #[test]
    fn test_sparse_vector_deserialize_old_format() {
        // Old format without #type field (backward compatibility)
        let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
        let sv: SparseVector = serde_json::from_str(json).unwrap();
        assert_eq!(sv.indices, vec![0, 1, 2]);
        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sparse_vector_deserialize_new_format() {
        // New format with #type field
        let json = r##"{"#type": "sparse_vector", "indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"##;
        let sv: SparseVector = serde_json::from_str(json).unwrap();
        assert_eq!(sv.indices, vec![0, 1, 2]);
        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sparse_vector_deserialize_new_format_field_order() {
        // New format with different field order (should still work)
        let json = r##"{"indices": [5, 10], "#type": "sparse_vector", "values": [0.5, 1.0]}"##;
        let sv: SparseVector = serde_json::from_str(json).unwrap();
        assert_eq!(sv.indices, vec![5, 10]);
        assert_eq!(sv.values, vec![0.5, 1.0]);
    }

    #[test]
    fn test_sparse_vector_deserialize_wrong_type_tag() {
        // Wrong #type field value should fail, naming both the expected and the actual tag.
        let json = r##"{"#type": "dense_vector", "indices": [0, 1], "values": [1.0, 2.0]}"##;
        let result: Result<SparseVector, _> = serde_json::from_str(json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("sparse_vector"));
        assert!(err_msg.contains("dense_vector"));
    }

    #[test]
    fn test_sparse_vector_serialize_always_has_type() {
        // Serialization should always include #type field
        let sv = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]);
        let json = serde_json::to_value(&sv).unwrap();

        assert_eq!(json["#type"], "sparse_vector");
        assert_eq!(json["indices"], serde_json::json!([0, 1, 2]));
        assert_eq!(json["values"], serde_json::json!([1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_sparse_vector_roundtrip_with_type() {
        // Test that serialize -> deserialize preserves the data
        let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]);
        let json = serde_json::to_string(&original).unwrap();

        // Verify the serialized JSON contains #type
        assert!(json.contains("\"#type\":\"sparse_vector\""));

        let deserialized: SparseVector = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_sparse_vector_in_metadata_old_format() {
        // Test that old format works when sparse vector is in metadata
        let json = r#"{"key": "value", "sparse": {"indices": [0, 1], "values": [1.0, 2.0]}}"#;
        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();

        let sparse_value = &map["sparse"];
        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
        assert_eq!(sv.indices, vec![0, 1]);
        assert_eq!(sv.values, vec![1.0, 2.0]);
    }

    #[test]
    fn test_sparse_vector_in_metadata_new_format() {
        // Test that new format works when sparse vector is in metadata
        let json = r##"{"key": "value", "sparse": {"#type": "sparse_vector", "indices": [0, 1], "values": [1.0, 2.0]}}"##;
        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();

        let sparse_value = &map["sparse"];
        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
        assert_eq!(sv.indices, vec![0, 1]);
        assert_eq!(sv.values, vec![1.0, 2.0]);
    }
}