chroma_types/
metadata.rs

1use chroma_error::{ChromaError, ErrorCodes};
2use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
3use serde_json::{Number, Value};
4use sprs::CsVec;
5use std::{
6    cmp::Ordering,
7    collections::{HashMap, HashSet},
8    mem::size_of_val,
9    ops::{BitAnd, BitOr},
10};
11use thiserror::Error;
12
13use crate::chroma_proto;
14
15#[cfg(feature = "pyo3")]
16use pyo3::types::PyAnyMethods;
17
18#[cfg(feature = "testing")]
19use proptest::prelude::*;
20
/// Wire-format mirror of [`SparseVector`] used by its hand-written serde impls.
///
/// `type_tag` is mapped to the key `#type`. It is an `Option` so that payloads
/// without the tag (the old format) can still be read.
#[derive(Serialize, Deserialize)]
struct SparseVectorSerdeHelper {
    // Discriminator; when present it is expected to be "sparse_vector".
    #[serde(rename = "#type")]
    type_tag: Option<String>,
    // Dimension indices, parallel to `values`.
    indices: Vec<u32>,
    // Value stored at the dimension given by the same position in `indices`.
    values: Vec<f32>,
}
28
/// Represents a sparse vector using parallel arrays for indices and values.
///
/// The element at position `i` of `values` belongs to the dimension stored at
/// position `i` of `indices`. Constructing a `SparseVector` does not check
/// anything; call [`SparseVector::validate`] to verify that both arrays have
/// the same length and that the indices are strictly ascending.
///
/// On deserialization: accepts both old format `{"indices": [...], "values": [...]}`
/// and new format `{"#type": "sparse_vector", "indices": [...], "values": [...]}`.
///
/// On serialization: always includes `#type` field with value `"sparse_vector"`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SparseVector {
    /// Dimension indices
    pub indices: Vec<u32>,
    /// Values corresponding to each index
    pub values: Vec<f32>,
}
43
44// Custom deserializer: accept both old and new formats
45impl<'de> Deserialize<'de> for SparseVector {
46    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47    where
48        D: Deserializer<'de>,
49    {
50        let helper = SparseVectorSerdeHelper::deserialize(deserializer)?;
51
52        // If #type is present, validate it
53        if let Some(type_tag) = &helper.type_tag {
54            if type_tag != "sparse_vector" {
55                return Err(serde::de::Error::custom(format!(
56                    "Expected #type='sparse_vector', got '{}'",
57                    type_tag
58                )));
59            }
60        }
61
62        Ok(SparseVector {
63            indices: helper.indices,
64            values: helper.values,
65        })
66    }
67}
68
69// Custom serializer: always include #type field
70impl Serialize for SparseVector {
71    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
72    where
73        S: Serializer,
74    {
75        let helper = SparseVectorSerdeHelper {
76            type_tag: Some("sparse_vector".to_string()),
77            indices: self.indices.clone(),
78            values: self.values.clone(),
79        };
80        helper.serialize(serializer)
81    }
82}
83
84impl SparseVector {
85    /// Create a new sparse vector from parallel arrays.
86    pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Self {
87        Self { indices, values }
88    }
89
90    /// Create a sparse vector from an iterator of (index, value) pairs.
91    pub fn from_pairs(pairs: impl IntoIterator<Item = (u32, f32)>) -> Self {
92        let (indices, values) = pairs.into_iter().unzip();
93        Self { indices, values }
94    }
95
96    /// Iterate over (index, value) pairs.
97    pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
98        self.indices
99            .iter()
100            .copied()
101            .zip(self.values.iter().copied())
102    }
103
104    /// Validate the sparse vector
105    pub fn validate(&self) -> Result<(), MetadataValueConversionError> {
106        // Check that indices and values have the same length
107        if self.indices.len() != self.values.len() {
108            return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
109        }
110
111        // Check that indices are sorted in strictly ascending order (no duplicates)
112        for i in 1..self.indices.len() {
113            if self.indices[i] <= self.indices[i - 1] {
114                return Err(MetadataValueConversionError::SparseVectorIndicesNotSorted);
115            }
116        }
117
118        Ok(())
119    }
120}
121
// `Eq` is declared on top of the derived `PartialEq`; it is required by the
// `Ord` implementation for this type.
// NOTE(review): `values` holds `f32` and the derived `PartialEq` compares with
// `==`, so a vector containing NaN does not compare equal to itself — confirm
// that NaN values are rejected before they reach code relying on `Eq`.
impl Eq for SparseVector {}
123
124impl Ord for SparseVector {
125    fn cmp(&self, other: &Self) -> Ordering {
126        self.indices.cmp(&other.indices).then_with(|| {
127            for (a, b) in self.values.iter().zip(other.values.iter()) {
128                match a.total_cmp(b) {
129                    Ordering::Equal => continue,
130                    other => return other,
131                }
132            }
133            self.values.len().cmp(&other.values.len())
134        })
135    }
136}
137
138impl PartialOrd for SparseVector {
139    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
140        Some(self.cmp(other))
141    }
142}
143
144impl From<chroma_proto::SparseVector> for SparseVector {
145    fn from(proto: chroma_proto::SparseVector) -> Self {
146        SparseVector::new(proto.indices, proto.values)
147    }
148}
149
150impl From<SparseVector> for chroma_proto::SparseVector {
151    fn from(sparse: SparseVector) -> Self {
152        chroma_proto::SparseVector {
153            indices: sparse.indices,
154            values: sparse.values,
155        }
156    }
157}
158
159/// Convert SparseVector to sprs::CsVec for efficient sparse operations
160impl From<&SparseVector> for CsVec<f32> {
161    fn from(sparse: &SparseVector) -> Self {
162        let (indices, values) = sparse
163            .iter()
164            .map(|(index, value)| (index as usize, value))
165            .unzip();
166        CsVec::new(u32::MAX as usize, indices, values)
167    }
168}
169
170impl From<SparseVector> for CsVec<f32> {
171    fn from(sparse: SparseVector) -> Self {
172        (&sparse).into()
173    }
174}
175
#[cfg(feature = "pyo3")]
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    /// Converts the vector into a Python dict with the keys `"indices"` and
    /// `"values"`.
    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        use pyo3::types::PyDict;

        let Self { indices, values } = self;
        let py_dict = PyDict::new(py);
        py_dict.set_item("indices", indices)?;
        py_dict.set_item("values", values)?;
        Ok(py_dict.into_any())
    }
}
190
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for SparseVector {
    /// Extracts a sparse vector from a Python dict holding `"indices"` and
    /// `"values"`. The result is not validated.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyDict;

        let mapping = ob.downcast::<PyDict>()?;

        // Both lookups happen before either conversion, so a missing key is
        // reported ahead of an element type error.
        let raw_indices = mapping.get_item("indices")?;
        let raw_values = mapping.get_item("values")?;

        let indices = raw_indices.extract::<Vec<u32>>()?;
        let values = raw_values.extract::<Vec<f32>>()?;

        Ok(Self { indices, values })
    }
}
206
/// A metadata value as it appears in an update request.
///
/// Carries the same payloads as `MetadataValue` plus [`UpdateMetadataValue::None`],
/// which signals that the key should be deleted.
///
/// The enum is `#[serde(untagged)]`, so values are (de)serialized as their bare
/// payload; serde tries untagged variants in declaration order, which makes the
/// order of the variants below significant.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[serde(untagged)]
pub enum UpdateMetadataValue {
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer value.
    Int(i64),
    // Under the `testing` feature, floats are generated from an `f32` in
    // [-1e6, 1e6] and widened to `f64`.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| UpdateMetadataValue::Float(v as f64)).boxed()"
        )
    )]
    /// 64-bit floating point value.
    Float(f64),
    /// String value.
    Str(String),
    /// Sparse vector value (never generated by the proptest strategy).
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
    /// No value: used to communicate that the user wants to delete this key.
    None,
}
226
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for UpdateMetadataValue {
    /// Converts a Python object by trying, in order: `None`, `bool`, `i64`,
    /// `f64`, `String` and finally a sparse-vector dict. The first conversion
    /// that succeeds wins; if none does, a `TypeError` is returned.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        if ob.is_none() {
            return Ok(Self::None);
        }
        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(Self::Bool(flag));
        }
        if let Ok(int) = ob.extract::<i64>() {
            return Ok(Self::Int(int));
        }
        if let Ok(float) = ob.extract::<f64>() {
            return Ok(Self::Float(float));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(Self::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(Self::SparseVector(sparse));
        }

        Err(pyo3::exceptions::PyTypeError::new_err(
            "Cannot convert Python object to UpdateMetadataValue",
        ))
    }
}
249
250#[derive(Error, Debug)]
251pub enum UpdateMetadataValueConversionError {
252    #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")]
253    InvalidValue,
254}
255
256impl ChromaError for UpdateMetadataValueConversionError {
257    fn code(&self) -> ErrorCodes {
258        match self {
259            UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
260        }
261    }
262}
263
264impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue {
265    type Error = UpdateMetadataValueConversionError;
266
267    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
268        match &value.value {
269            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
270                Ok(UpdateMetadataValue::Bool(*value))
271            }
272            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
273                Ok(UpdateMetadataValue::Int(*value))
274            }
275            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
276                Ok(UpdateMetadataValue::Float(*value))
277            }
278            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
279                Ok(UpdateMetadataValue::Str(value.clone()))
280            }
281            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
282                Ok(UpdateMetadataValue::SparseVector(value.clone().into()))
283            }
284            // Used to communicate that the user wants to delete this key.
285            None => Ok(UpdateMetadataValue::None),
286        }
287    }
288}
289
290impl From<UpdateMetadataValue> for chroma_proto::UpdateMetadataValue {
291    fn from(value: UpdateMetadataValue) -> Self {
292        match value {
293            UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
294                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
295            },
296            UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
297                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
298            },
299            UpdateMetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
300                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
301                    value,
302                )),
303            },
304            UpdateMetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
305                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
306                    value,
307                )),
308            },
309            UpdateMetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
310                value: Some(
311                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
312                        sparse_vec.into(),
313                    ),
314                ),
315            },
316            UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None },
317        }
318    }
319}
320
321impl TryFrom<&UpdateMetadataValue> for MetadataValue {
322    type Error = MetadataValueConversionError;
323
324    fn try_from(value: &UpdateMetadataValue) -> Result<Self, Self::Error> {
325        match value {
326            UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)),
327            UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)),
328            UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)),
329            UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())),
330            UpdateMetadataValue::SparseVector(value) => {
331                Ok(MetadataValue::SparseVector(value.clone()))
332            }
333            UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue),
334        }
335    }
336}
337
338/*
339===========================================
340MetadataValue
341===========================================
342*/
343
/// A single stored metadata value.
///
/// The enum is `#[serde(untagged)]`, so values are (de)serialized as their bare
/// payload; serde tries untagged variants in declaration order, which makes the
/// order of the variants below significant.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
#[serde(untagged)]
pub enum MetadataValue {
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer value.
    Int(i64),
    // Under the `testing` feature, floats are generated from an `f32` in
    // [-1e6, 1e6] and widened to `f64`.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| MetadataValue::Float(v as f64)).boxed()"
        )
    )]
    /// 64-bit floating point value.
    Float(f64),
    /// String value.
    Str(String),
    /// Sparse vector value (never generated by the proptest strategy).
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
}
363
// `Eq` is declared on top of the derived `PartialEq`; it is required by the
// `Ord` implementation for this type.
// NOTE(review): the derived `PartialEq` compares `Float` payloads with `==`, so
// `Float(NaN)` does not compare equal to itself — confirm NaN is rejected
// before values reach code relying on `Eq`.
impl Eq for MetadataValue {}
365
/// Payload-free discriminant of `MetadataValue`: one variant per value kind.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MetadataValueType {
    /// Kind of `MetadataValue::Bool`.
    Bool,
    /// Kind of `MetadataValue::Int`.
    Int,
    /// Kind of `MetadataValue::Float`.
    Float,
    /// Kind of `MetadataValue::Str`.
    Str,
    /// Kind of `MetadataValue::SparseVector`.
    SparseVector,
}
374
375impl MetadataValue {
376    pub fn value_type(&self) -> MetadataValueType {
377        match self {
378            MetadataValue::Bool(_) => MetadataValueType::Bool,
379            MetadataValue::Int(_) => MetadataValueType::Int,
380            MetadataValue::Float(_) => MetadataValueType::Float,
381            MetadataValue::Str(_) => MetadataValueType::Str,
382            MetadataValue::SparseVector(_) => MetadataValueType::SparseVector,
383        }
384    }
385}
386
387impl From<&MetadataValue> for MetadataValueType {
388    fn from(value: &MetadataValue) -> Self {
389        value.value_type()
390    }
391}
392
393impl From<bool> for MetadataValue {
394    fn from(v: bool) -> Self {
395        MetadataValue::Bool(v)
396    }
397}
398
399impl From<i64> for MetadataValue {
400    fn from(v: i64) -> Self {
401        MetadataValue::Int(v)
402    }
403}
404
405impl From<i32> for MetadataValue {
406    fn from(v: i32) -> Self {
407        MetadataValue::Int(v as i64)
408    }
409}
410
411impl From<f64> for MetadataValue {
412    fn from(v: f64) -> Self {
413        MetadataValue::Float(v)
414    }
415}
416
417impl From<f32> for MetadataValue {
418    fn from(v: f32) -> Self {
419        MetadataValue::Float(v as f64)
420    }
421}
422
423impl From<String> for MetadataValue {
424    fn from(v: String) -> Self {
425        MetadataValue::Str(v)
426    }
427}
428
429impl From<&str> for MetadataValue {
430    fn from(v: &str) -> Self {
431        MetadataValue::Str(v.to_string())
432    }
433}
434
435impl From<SparseVector> for MetadataValue {
436    fn from(v: SparseVector) -> Self {
437        MetadataValue::SparseVector(v)
438    }
439}
440
441/// We need `Eq` and `Ord` since we want to use this as a key in `BTreeMap`
442///
443/// For cross-type comparisons, we define a consistent ordering based on variant position:
444/// Bool < Int < Float < Str < SparseVector
445#[allow(clippy::derive_ord_xor_partial_ord)]
446impl Ord for MetadataValue {
447    fn cmp(&self, other: &Self) -> Ordering {
448        // Define type ordering based on variant position
449        fn type_order(val: &MetadataValue) -> u8 {
450            match val {
451                MetadataValue::Bool(_) => 0,
452                MetadataValue::Int(_) => 1,
453                MetadataValue::Float(_) => 2,
454                MetadataValue::Str(_) => 3,
455                MetadataValue::SparseVector(_) => 4,
456            }
457        }
458
459        // Chain type ordering with value ordering
460        type_order(self).cmp(&type_order(other)).then_with(|| {
461            match (self, other) {
462                (MetadataValue::Bool(left), MetadataValue::Bool(right)) => left.cmp(right),
463                (MetadataValue::Int(left), MetadataValue::Int(right)) => left.cmp(right),
464                (MetadataValue::Float(left), MetadataValue::Float(right)) => left.total_cmp(right),
465                (MetadataValue::Str(left), MetadataValue::Str(right)) => left.cmp(right),
466                (MetadataValue::SparseVector(left), MetadataValue::SparseVector(right)) => {
467                    left.cmp(right)
468                }
469                _ => Ordering::Equal, // Different types, but type_order already handled this
470            }
471        })
472    }
473}
474
475impl PartialOrd for MetadataValue {
476    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
477        Some(self.cmp(other))
478    }
479}
480
481impl TryFrom<&MetadataValue> for bool {
482    type Error = MetadataValueConversionError;
483
484    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
485        match value {
486            MetadataValue::Bool(value) => Ok(*value),
487            _ => Err(MetadataValueConversionError::InvalidValue),
488        }
489    }
490}
491
492impl TryFrom<&MetadataValue> for i64 {
493    type Error = MetadataValueConversionError;
494
495    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
496        match value {
497            MetadataValue::Int(value) => Ok(*value),
498            _ => Err(MetadataValueConversionError::InvalidValue),
499        }
500    }
501}
502
503impl TryFrom<&MetadataValue> for f64 {
504    type Error = MetadataValueConversionError;
505
506    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
507        match value {
508            MetadataValue::Float(value) => Ok(*value),
509            _ => Err(MetadataValueConversionError::InvalidValue),
510        }
511    }
512}
513
514impl TryFrom<&MetadataValue> for String {
515    type Error = MetadataValueConversionError;
516
517    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
518        match value {
519            MetadataValue::Str(value) => Ok(value.clone()),
520            _ => Err(MetadataValueConversionError::InvalidValue),
521        }
522    }
523}
524
525impl From<MetadataValue> for UpdateMetadataValue {
526    fn from(value: MetadataValue) -> Self {
527        match value {
528            MetadataValue::Bool(v) => UpdateMetadataValue::Bool(v),
529            MetadataValue::Int(v) => UpdateMetadataValue::Int(v),
530            MetadataValue::Float(v) => UpdateMetadataValue::Float(v),
531            MetadataValue::Str(v) => UpdateMetadataValue::Str(v),
532            MetadataValue::SparseVector(v) => UpdateMetadataValue::SparseVector(v),
533        }
534    }
535}
536
537impl From<MetadataValue> for Value {
538    fn from(value: MetadataValue) -> Self {
539        match value {
540            MetadataValue::Bool(val) => Self::Bool(val),
541            MetadataValue::Int(val) => Self::Number(
542                Number::from_i128(val as i128).expect("i64 should be representable in JSON"),
543            ),
544            MetadataValue::Float(val) => Self::Number(
545                Number::from_f64(val).expect("Inf and NaN should not be present in MetadataValue"),
546            ),
547            MetadataValue::Str(val) => Self::String(val),
548            MetadataValue::SparseVector(val) => {
549                let mut map = serde_json::Map::new();
550                map.insert(
551                    "indices".to_string(),
552                    Value::Array(
553                        val.indices
554                            .iter()
555                            .map(|&i| Value::Number(i.into()))
556                            .collect(),
557                    ),
558                );
559                map.insert(
560                    "values".to_string(),
561                    Value::Array(
562                        val.values
563                            .iter()
564                            .map(|&v| {
565                                Value::Number(
566                                    Number::from_f64(v as f64)
567                                        .expect("Float number should not be NaN or infinite"),
568                                )
569                            })
570                            .collect(),
571                    ),
572                );
573                Self::Object(map)
574            }
575        }
576    }
577}
578
579#[derive(Error, Debug)]
580pub enum MetadataValueConversionError {
581    #[error("Invalid metadata value, valid values are: Int, Float, Str")]
582    InvalidValue,
583    #[error("Metadata key cannot start with '#' or '$': {0}")]
584    InvalidKey(String),
585    #[error("Sparse vector indices and values must have the same length")]
586    SparseVectorLengthMismatch,
587    #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")]
588    SparseVectorIndicesNotSorted,
589}
590
591impl ChromaError for MetadataValueConversionError {
592    fn code(&self) -> ErrorCodes {
593        match self {
594            MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
595            MetadataValueConversionError::InvalidKey(_) => ErrorCodes::InvalidArgument,
596            MetadataValueConversionError::SparseVectorLengthMismatch => ErrorCodes::InvalidArgument,
597            MetadataValueConversionError::SparseVectorIndicesNotSorted => {
598                ErrorCodes::InvalidArgument
599            }
600        }
601    }
602}
603
604impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue {
605    type Error = MetadataValueConversionError;
606
607    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
608        match &value.value {
609            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
610                Ok(MetadataValue::Bool(*value))
611            }
612            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
613                Ok(MetadataValue::Int(*value))
614            }
615            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
616                Ok(MetadataValue::Float(*value))
617            }
618            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
619                Ok(MetadataValue::Str(value.clone()))
620            }
621            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
622                Ok(MetadataValue::SparseVector(value.clone().into()))
623            }
624            _ => Err(MetadataValueConversionError::InvalidValue),
625        }
626    }
627}
628
629impl From<MetadataValue> for chroma_proto::UpdateMetadataValue {
630    fn from(value: MetadataValue) -> Self {
631        match value {
632            MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
633                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
634            },
635            MetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
636                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
637                    value,
638                )),
639            },
640            MetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
641                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
642                    value,
643                )),
644            },
645            MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
646                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
647            },
648            MetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
649                value: Some(
650                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
651                        sparse_vec.into(),
652                    ),
653                ),
654            },
655        }
656    }
657}
658
659/*
660===========================================
661UpdateMetadata
662===========================================
663*/
/// Metadata as sent in an update: maps each key to its new value, where
/// `UpdateMetadataValue::None` marks the key for deletion.
pub type UpdateMetadata = HashMap<String, UpdateMetadataValue>;
665
666/**
667 * Check if two metadata are close to equal. Ignores small differences in float values.
668 */
669pub fn are_update_metadatas_close_to_equal(
670    metadata1: &UpdateMetadata,
671    metadata2: &UpdateMetadata,
672) -> bool {
673    assert_eq!(metadata1.len(), metadata2.len());
674
675    for (key, value) in metadata1.iter() {
676        if !metadata2.contains_key(key) {
677            return false;
678        }
679        let other_value = metadata2.get(key).unwrap();
680
681        if let (UpdateMetadataValue::Float(value), UpdateMetadataValue::Float(other_value)) =
682            (value, other_value)
683        {
684            if (value - other_value).abs() > 1e-6 {
685                return false;
686            }
687        } else if value != other_value {
688            return false;
689        }
690    }
691
692    true
693}
694
695pub fn are_metadatas_close_to_equal(metadata1: &Metadata, metadata2: &Metadata) -> bool {
696    assert_eq!(metadata1.len(), metadata2.len());
697
698    for (key, value) in metadata1.iter() {
699        if !metadata2.contains_key(key) {
700            return false;
701        }
702        let other_value = metadata2.get(key).unwrap();
703
704        if let (MetadataValue::Float(value), MetadataValue::Float(other_value)) =
705            (value, other_value)
706        {
707            if (value - other_value).abs() > 1e-6 {
708                return false;
709            }
710        } else if value != other_value {
711            return false;
712        }
713    }
714
715    true
716}
717
718impl TryFrom<chroma_proto::UpdateMetadata> for UpdateMetadata {
719    type Error = UpdateMetadataValueConversionError;
720
721    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
722        let mut metadata = UpdateMetadata::new();
723        for (key, value) in proto_metadata.metadata.iter() {
724            let value = match value.try_into() {
725                Ok(value) => value,
726                Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue),
727            };
728            metadata.insert(key.clone(), value);
729        }
730        Ok(metadata)
731    }
732}
733
734impl From<UpdateMetadata> for chroma_proto::UpdateMetadata {
735    fn from(metadata: UpdateMetadata) -> Self {
736        let mut metadata = metadata;
737        let mut proto_metadata = chroma_proto::UpdateMetadata {
738            metadata: HashMap::new(),
739        };
740        for (key, value) in metadata.drain() {
741            let proto_value = value.into();
742            proto_metadata.metadata.insert(key.clone(), proto_value);
743        }
744        proto_metadata
745    }
746}
747
748/*
749===========================================
750Metadata
751===========================================
752*/
753
/// Stored metadata of a record: maps each key to its value.
pub type Metadata = HashMap<String, MetadataValue>;
/// Set of metadata keys.
// NOTE(review): presumably the keys that were removed by an update — confirm against usage.
pub type DeletedMetadata = HashSet<String>;
756
757pub fn logical_size_of_metadata(metadata: &Metadata) -> usize {
758    metadata
759        .iter()
760        .map(|(k, v)| {
761            k.len()
762                + match v {
763                    MetadataValue::Bool(b) => size_of_val(b),
764                    MetadataValue::Int(i) => size_of_val(i),
765                    MetadataValue::Float(f) => size_of_val(f),
766                    MetadataValue::Str(s) => s.len(),
767                    MetadataValue::SparseVector(v) => {
768                        size_of_val(&v.indices[..]) + size_of_val(&v.values[..])
769                    }
770                }
771        })
772        .sum()
773}
774
775pub fn get_metadata_value_as<'a, T>(
776    metadata: &'a Metadata,
777    key: &str,
778) -> Result<T, Box<MetadataValueConversionError>>
779where
780    T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
781{
782    let res = match metadata.get(key) {
783        Some(value) => T::try_from(value),
784        None => return Err(Box::new(MetadataValueConversionError::InvalidValue)),
785    };
786    match res {
787        Ok(value) => Ok(value),
788        Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)),
789    }
790}
791
792impl TryFrom<chroma_proto::UpdateMetadata> for Metadata {
793    type Error = MetadataValueConversionError;
794
795    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
796        let mut metadata = Metadata::new();
797        for (key, value) in proto_metadata.metadata.iter() {
798            let maybe_value: Result<MetadataValue, Self::Error> = value.try_into();
799            if maybe_value.is_err() {
800                return Err(MetadataValueConversionError::InvalidValue);
801            }
802            let value = maybe_value.unwrap();
803            metadata.insert(key.clone(), value);
804        }
805        Ok(metadata)
806    }
807}
808
809impl From<Metadata> for chroma_proto::UpdateMetadata {
810    fn from(metadata: Metadata) -> Self {
811        let mut metadata = metadata;
812        let mut proto_metadata = chroma_proto::UpdateMetadata {
813            metadata: HashMap::new(),
814        };
815        for (key, value) in metadata.drain() {
816            let proto_value = value.into();
817            proto_metadata.metadata.insert(key.clone(), proto_value);
818        }
819        proto_metadata
820    }
821}
822
/// Borrowed description of how a metadata map changes, split into updated,
/// deleted and inserted keys. All keys and values are references with the
/// lifetime `'referred_data`; nothing is owned by the delta itself.
#[derive(Debug, Default)]
pub struct MetadataDelta<'referred_data> {
    /// Keys whose value changes, each mapped to a pair of values.
    // NOTE(review): presumably the pair is (old value, new value) — confirm against the producer.
    pub metadata_to_update: HashMap<
        &'referred_data str,
        (&'referred_data MetadataValue, &'referred_data MetadataValue),
    >,
    /// Keys to delete, each mapped to a value.
    // NOTE(review): presumably the value being removed — confirm against the producer.
    pub metadata_to_delete: HashMap<&'referred_data str, &'referred_data MetadataValue>,
    /// Keys to insert, each mapped to the value to insert.
    pub metadata_to_insert: HashMap<&'referred_data str, &'referred_data MetadataValue>,
}
832
833impl MetadataDelta<'_> {
834    pub fn new() -> Self {
835        Self::default()
836    }
837}
838
839/*
840===========================================
841Metadata queries
842===========================================
843*/
844
/// Error raised while converting a `where` clause, kept as a chain of
/// context frames ending in a root cause.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum WhereConversionError {
    /// Root cause; displayed as `Error: <message>`.
    #[error("Error: {0}")]
    Cause(String),
    /// A context label wrapped around an inner error; displayed as
    /// `<context> -> <inner>`.
    #[error("{0} -> {1}")]
    Trace(String, Box<Self>),
}
852
853impl WhereConversionError {
854    pub fn cause(msg: impl ToString) -> Self {
855        Self::Cause(msg.to_string())
856    }
857
858    pub fn trace(self, context: impl ToString) -> Self {
859        Self::Trace(context.to_string(), Box::new(self))
860    }
861}
862
/// This `Where` enum serves as a unified representation for the `where` and `where_document` clauses.
/// Although this is not unified at the API level due to legacy design choices, in the future we will be
/// unifying them together, and the structure of the unified AST should be identical to the one here.
/// Currently both `where` and `where_document` clauses will be translated into `Where`, and if both are
/// present we simply create a conjunction of both clauses as the actual filter. This is consistent with
/// the semantics we used to have when the `where` and `where_document` clauses were treated separately.
// TODO: Remove this note once the `where` clause and `where_document` clause are unified at the API level.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Where {
    /// Boolean combination of child expressions.
    Composite(CompositeExpression),
    /// Filter on the document text.
    Document(DocumentExpression),
    /// Filter on a metadata entry.
    Metadata(MetadataExpression),
}
877
878impl serde::Serialize for Where {
879    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
880    where
881        S: Serializer,
882    {
883        match self {
884            Where::Composite(composite) => {
885                let mut map = serializer.serialize_map(Some(1))?;
886                let op_key = match composite.operator {
887                    BooleanOperator::And => "$and",
888                    BooleanOperator::Or => "$or",
889                };
890                map.serialize_entry(op_key, &composite.children)?;
891                map.end()
892            }
893            Where::Document(doc) => {
894                let mut outer_map = serializer.serialize_map(Some(1))?;
895                let mut inner_map = serde_json::Map::new();
896                let op_key = match doc.operator {
897                    DocumentOperator::Contains => "$contains",
898                    DocumentOperator::NotContains => "$not_contains",
899                    DocumentOperator::Regex => "$regex",
900                    DocumentOperator::NotRegex => "$not_regex",
901                };
902                inner_map.insert(
903                    op_key.to_string(),
904                    serde_json::Value::String(doc.pattern.clone()),
905                );
906                outer_map.serialize_entry("#document", &inner_map)?;
907                outer_map.end()
908            }
909            Where::Metadata(meta) => {
910                let mut outer_map = serializer.serialize_map(Some(1))?;
911                let mut inner_map = serde_json::Map::new();
912
913                match &meta.comparison {
914                    MetadataComparison::Primitive(op, value) => {
915                        let op_key = match op {
916                            PrimitiveOperator::Equal => "$eq",
917                            PrimitiveOperator::NotEqual => "$ne",
918                            PrimitiveOperator::GreaterThan => "$gt",
919                            PrimitiveOperator::GreaterThanOrEqual => "$gte",
920                            PrimitiveOperator::LessThan => "$lt",
921                            PrimitiveOperator::LessThanOrEqual => "$lte",
922                        };
923                        let value_json =
924                            serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
925                        inner_map.insert(op_key.to_string(), value_json);
926                    }
927                    MetadataComparison::Set(op, set_value) => {
928                        let op_key = match op {
929                            SetOperator::In => "$in",
930                            SetOperator::NotIn => "$nin",
931                        };
932                        let values_json = match set_value {
933                            MetadataSetValue::Bool(v) => serde_json::to_value(v),
934                            MetadataSetValue::Int(v) => serde_json::to_value(v),
935                            MetadataSetValue::Float(v) => serde_json::to_value(v),
936                            MetadataSetValue::Str(v) => serde_json::to_value(v),
937                        }
938                        .map_err(serde::ser::Error::custom)?;
939                        inner_map.insert(op_key.to_string(), values_json);
940                    }
941                }
942
943                outer_map.serialize_entry(&meta.key, &inner_map)?;
944                outer_map.end()
945            }
946        }
947    }
948}
949
950impl Where {
951    pub fn conjunction(children: Vec<Where>) -> Self {
952        Self::Composite(CompositeExpression {
953            operator: BooleanOperator::And,
954            children,
955        })
956    }
957    pub fn disjunction(children: Vec<Where>) -> Self {
958        Self::Composite(CompositeExpression {
959            operator: BooleanOperator::Or,
960            children,
961        })
962    }
963
964    pub fn fts_query_length(&self) -> u64 {
965        match self {
966            Where::Composite(composite_expression) => composite_expression
967                .children
968                .iter()
969                .map(Where::fts_query_length)
970                .sum(),
971            // The query length is defined to be the number of trigram tokens
972            Where::Document(document_expression) => {
973                document_expression.pattern.len().max(3) as u64 - 2
974            }
975            Where::Metadata(_) => 0,
976        }
977    }
978
979    pub fn metadata_predicate_count(&self) -> u64 {
980        match self {
981            Where::Composite(composite_expression) => composite_expression
982                .children
983                .iter()
984                .map(Where::metadata_predicate_count)
985                .sum(),
986            Where::Document(_) => 0,
987            Where::Metadata(metadata_expression) => match &metadata_expression.comparison {
988                MetadataComparison::Primitive(_, _) => 1,
989                MetadataComparison::Set(_, metadata_set_value) => match metadata_set_value {
990                    MetadataSetValue::Bool(items) => items.len() as u64,
991                    MetadataSetValue::Int(items) => items.len() as u64,
992                    MetadataSetValue::Float(items) => items.len() as u64,
993                    MetadataSetValue::Str(items) => items.len() as u64,
994                },
995            },
996        }
997    }
998}
999
1000impl BitAnd for Where {
1001    type Output = Where;
1002
1003    fn bitand(self, rhs: Self) -> Self::Output {
1004        match self {
1005            Where::Composite(CompositeExpression {
1006                operator: BooleanOperator::And,
1007                mut children,
1008            }) => match rhs {
1009                Where::Composite(CompositeExpression {
1010                    operator: BooleanOperator::And,
1011                    children: rhs_children,
1012                }) => {
1013                    children.extend(rhs_children);
1014                    Where::Composite(CompositeExpression {
1015                        operator: BooleanOperator::And,
1016                        children,
1017                    })
1018                }
1019                _ => {
1020                    children.push(rhs);
1021                    Where::Composite(CompositeExpression {
1022                        operator: BooleanOperator::And,
1023                        children,
1024                    })
1025                }
1026            },
1027            _ => match rhs {
1028                Where::Composite(CompositeExpression {
1029                    operator: BooleanOperator::And,
1030                    mut children,
1031                }) => {
1032                    children.insert(0, self);
1033                    Where::Composite(CompositeExpression {
1034                        operator: BooleanOperator::And,
1035                        children,
1036                    })
1037                }
1038                _ => Where::conjunction(vec![self, rhs]),
1039            },
1040        }
1041    }
1042}
1043
1044impl BitOr for Where {
1045    type Output = Where;
1046
1047    fn bitor(self, rhs: Self) -> Self::Output {
1048        match self {
1049            Where::Composite(CompositeExpression {
1050                operator: BooleanOperator::Or,
1051                mut children,
1052            }) => match rhs {
1053                Where::Composite(CompositeExpression {
1054                    operator: BooleanOperator::Or,
1055                    children: rhs_children,
1056                }) => {
1057                    children.extend(rhs_children);
1058                    Where::Composite(CompositeExpression {
1059                        operator: BooleanOperator::Or,
1060                        children,
1061                    })
1062                }
1063                _ => {
1064                    children.push(rhs);
1065                    Where::Composite(CompositeExpression {
1066                        operator: BooleanOperator::Or,
1067                        children,
1068                    })
1069                }
1070            },
1071            _ => match rhs {
1072                Where::Composite(CompositeExpression {
1073                    operator: BooleanOperator::Or,
1074                    mut children,
1075                }) => {
1076                    children.insert(0, self);
1077                    Where::Composite(CompositeExpression {
1078                        operator: BooleanOperator::Or,
1079                        children,
1080                    })
1081                }
1082                _ => Where::disjunction(vec![self, rhs]),
1083            },
1084        }
1085    }
1086}
1087
1088impl TryFrom<chroma_proto::Where> for Where {
1089    type Error = WhereConversionError;
1090
1091    fn try_from(proto_where: chroma_proto::Where) -> Result<Self, Self::Error> {
1092        let where_inner = proto_where
1093            .r#where
1094            .ok_or(WhereConversionError::cause("Invalid Where"))?;
1095        Ok(match where_inner {
1096            chroma_proto::r#where::Where::DirectComparison(direct_comparison) => {
1097                Self::Metadata(direct_comparison.try_into()?)
1098            }
1099            chroma_proto::r#where::Where::Children(where_children) => {
1100                Self::Composite(where_children.try_into()?)
1101            }
1102            chroma_proto::r#where::Where::DirectDocumentComparison(direct_where_document) => {
1103                Self::Document(direct_where_document.into())
1104            }
1105        })
1106    }
1107}
1108
1109impl TryFrom<Where> for chroma_proto::Where {
1110    type Error = WhereConversionError;
1111
1112    fn try_from(value: Where) -> Result<Self, Self::Error> {
1113        let proto_where = match value {
1114            Where::Composite(composite_expression) => {
1115                chroma_proto::r#where::Where::Children(composite_expression.try_into()?)
1116            }
1117            Where::Document(document_expression) => {
1118                chroma_proto::r#where::Where::DirectDocumentComparison(document_expression.into())
1119            }
1120            Where::Metadata(metadata_expression) => chroma_proto::r#where::Where::DirectComparison(
1121                chroma_proto::DirectComparison::try_from(metadata_expression)
1122                    .map_err(|err| err.trace("MetadataExpression"))?,
1123            ),
1124        };
1125        Ok(Self {
1126            r#where: Some(proto_where),
1127        })
1128    }
1129}
1130
1131#[derive(Clone, Debug, PartialEq)]
1132#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1133pub struct CompositeExpression {
1134    pub operator: BooleanOperator,
1135    pub children: Vec<Where>,
1136}
1137
1138impl TryFrom<chroma_proto::WhereChildren> for CompositeExpression {
1139    type Error = WhereConversionError;
1140
1141    fn try_from(proto_children: chroma_proto::WhereChildren) -> Result<Self, Self::Error> {
1142        let operator = proto_children.operator().into();
1143        let children = proto_children
1144            .children
1145            .into_iter()
1146            .map(Where::try_from)
1147            .collect::<Result<Vec<_>, _>>()
1148            .map_err(|err| err.trace("Child Where of CompositeExpression"))?;
1149        Ok(Self { operator, children })
1150    }
1151}
1152
1153impl TryFrom<CompositeExpression> for chroma_proto::WhereChildren {
1154    type Error = WhereConversionError;
1155
1156    fn try_from(value: CompositeExpression) -> Result<Self, Self::Error> {
1157        Ok(Self {
1158            operator: chroma_proto::BooleanOperator::from(value.operator) as i32,
1159            children: value
1160                .children
1161                .into_iter()
1162                .map(chroma_proto::Where::try_from)
1163                .collect::<Result<_, _>>()?,
1164        })
1165    }
1166}
1167
/// Logical connective used by a [`CompositeExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum BooleanOperator {
    /// Logical conjunction.
    And,
    /// Logical disjunction.
    Or,
}
1174
1175impl From<chroma_proto::BooleanOperator> for BooleanOperator {
1176    fn from(value: chroma_proto::BooleanOperator) -> Self {
1177        match value {
1178            chroma_proto::BooleanOperator::And => Self::And,
1179            chroma_proto::BooleanOperator::Or => Self::Or,
1180        }
1181    }
1182}
1183
1184impl From<BooleanOperator> for chroma_proto::BooleanOperator {
1185    fn from(value: BooleanOperator) -> Self {
1186        match value {
1187            BooleanOperator::And => Self::And,
1188            BooleanOperator::Or => Self::Or,
1189        }
1190    }
1191}
1192
1193#[derive(Clone, Debug, PartialEq)]
1194#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1195pub struct DocumentExpression {
1196    pub operator: DocumentOperator,
1197    pub pattern: String,
1198}
1199
1200impl From<chroma_proto::DirectWhereDocument> for DocumentExpression {
1201    fn from(value: chroma_proto::DirectWhereDocument) -> Self {
1202        Self {
1203            operator: value.operator().into(),
1204            pattern: value.pattern,
1205        }
1206    }
1207}
1208
1209impl From<DocumentExpression> for chroma_proto::DirectWhereDocument {
1210    fn from(value: DocumentExpression) -> Self {
1211        Self {
1212            pattern: value.pattern,
1213            operator: chroma_proto::WhereDocumentOperator::from(value.operator) as i32,
1214        }
1215    }
1216}
1217
/// Matching mode of a [`DocumentExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum DocumentOperator {
    /// Serialized as `$contains`.
    Contains,
    /// Serialized as `$not_contains`.
    NotContains,
    /// Serialized as `$regex`.
    Regex,
    /// Serialized as `$not_regex`.
    NotRegex,
}
1226impl From<chroma_proto::WhereDocumentOperator> for DocumentOperator {
1227    fn from(value: chroma_proto::WhereDocumentOperator) -> Self {
1228        match value {
1229            chroma_proto::WhereDocumentOperator::Contains => Self::Contains,
1230            chroma_proto::WhereDocumentOperator::NotContains => Self::NotContains,
1231            chroma_proto::WhereDocumentOperator::Regex => Self::Regex,
1232            chroma_proto::WhereDocumentOperator::NotRegex => Self::NotRegex,
1233        }
1234    }
1235}
1236
1237impl From<DocumentOperator> for chroma_proto::WhereDocumentOperator {
1238    fn from(value: DocumentOperator) -> Self {
1239        match value {
1240            DocumentOperator::Contains => Self::Contains,
1241            DocumentOperator::NotContains => Self::NotContains,
1242            DocumentOperator::Regex => Self::Regex,
1243            DocumentOperator::NotRegex => Self::NotRegex,
1244        }
1245    }
1246}
1247
1248#[derive(Clone, Debug, PartialEq)]
1249#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1250pub struct MetadataExpression {
1251    pub key: String,
1252    pub comparison: MetadataComparison,
1253}
1254
1255impl TryFrom<chroma_proto::DirectComparison> for MetadataExpression {
1256    type Error = WhereConversionError;
1257
1258    fn try_from(value: chroma_proto::DirectComparison) -> Result<Self, Self::Error> {
1259        let proto_comparison = value
1260            .comparison
1261            .ok_or(WhereConversionError::cause("Invalid MetadataExpression"))?;
1262        let comparison = match proto_comparison {
1263            chroma_proto::direct_comparison::Comparison::SingleStringOperand(
1264                single_string_comparison,
1265            ) => MetadataComparison::Primitive(
1266                single_string_comparison.comparator().into(),
1267                MetadataValue::Str(single_string_comparison.value),
1268            ),
1269            chroma_proto::direct_comparison::Comparison::StringListOperand(
1270                string_list_comparison,
1271            ) => MetadataComparison::Set(
1272                string_list_comparison.list_operator().into(),
1273                MetadataSetValue::Str(string_list_comparison.values),
1274            ),
1275            chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1276                single_int_comparison,
1277            ) => MetadataComparison::Primitive(
1278                match single_int_comparison
1279                    .comparator
1280                    .ok_or(WhereConversionError::cause(
1281                        "Invalid scalar integer operator",
1282                    ))? {
1283                    chroma_proto::single_int_comparison::Comparator::GenericComparator(op) => {
1284                        chroma_proto::GenericComparator::try_from(op)
1285                            .map_err(WhereConversionError::cause)?
1286                            .into()
1287                    }
1288                    chroma_proto::single_int_comparison::Comparator::NumberComparator(op) => {
1289                        chroma_proto::NumberComparator::try_from(op)
1290                            .map_err(WhereConversionError::cause)?
1291                            .into()
1292                    }
1293                },
1294                MetadataValue::Int(single_int_comparison.value),
1295            ),
1296            chroma_proto::direct_comparison::Comparison::IntListOperand(int_list_comparison) => {
1297                MetadataComparison::Set(
1298                    int_list_comparison.list_operator().into(),
1299                    MetadataSetValue::Int(int_list_comparison.values),
1300                )
1301            }
1302            chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(
1303                single_double_comparison,
1304            ) => MetadataComparison::Primitive(
1305                match single_double_comparison
1306                    .comparator
1307                    .ok_or(WhereConversionError::cause("Invalid scalar float operator"))?
1308                {
1309                    chroma_proto::single_double_comparison::Comparator::GenericComparator(op) => {
1310                        chroma_proto::GenericComparator::try_from(op)
1311                            .map_err(WhereConversionError::cause)?
1312                            .into()
1313                    }
1314                    chroma_proto::single_double_comparison::Comparator::NumberComparator(op) => {
1315                        chroma_proto::NumberComparator::try_from(op)
1316                            .map_err(WhereConversionError::cause)?
1317                            .into()
1318                    }
1319                },
1320                MetadataValue::Float(single_double_comparison.value),
1321            ),
1322            chroma_proto::direct_comparison::Comparison::DoubleListOperand(
1323                double_list_comparison,
1324            ) => MetadataComparison::Set(
1325                double_list_comparison.list_operator().into(),
1326                MetadataSetValue::Float(double_list_comparison.values),
1327            ),
1328            chroma_proto::direct_comparison::Comparison::BoolListOperand(bool_list_comparison) => {
1329                MetadataComparison::Set(
1330                    bool_list_comparison.list_operator().into(),
1331                    MetadataSetValue::Bool(bool_list_comparison.values),
1332                )
1333            }
1334            chroma_proto::direct_comparison::Comparison::SingleBoolOperand(
1335                single_bool_comparison,
1336            ) => MetadataComparison::Primitive(
1337                single_bool_comparison.comparator().into(),
1338                MetadataValue::Bool(single_bool_comparison.value),
1339            ),
1340        };
1341        Ok(Self {
1342            key: value.key,
1343            comparison,
1344        })
1345    }
1346}
1347
1348impl TryFrom<MetadataExpression> for chroma_proto::DirectComparison {
1349    type Error = WhereConversionError;
1350
1351    fn try_from(value: MetadataExpression) -> Result<Self, Self::Error> {
1352        let comparison = match value.comparison {
1353            MetadataComparison::Primitive(primitive_operator, metadata_value) => match metadata_value {
1354                MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1355                MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(match primitive_operator {
1356                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1357                                numeric => chroma_proto::single_int_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1358                            }),
1359                MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(match primitive_operator {
1360                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_double_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1361                                numeric => chroma_proto::single_double_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1362                            }),
1363                MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1364                MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Comparison with sparse vector is not supported".to_string())),
1365            },
1366            MetadataComparison::Set(set_operator, metadata_set_value) => match metadata_set_value {
1367                MetadataSetValue::Bool(vec) => chroma_proto::direct_comparison::Comparison::BoolListOperand(chroma_proto::BoolListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1368                MetadataSetValue::Int(vec) => chroma_proto::direct_comparison::Comparison::IntListOperand(chroma_proto::IntListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1369                MetadataSetValue::Float(vec) => chroma_proto::direct_comparison::Comparison::DoubleListOperand(chroma_proto::DoubleListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1370                MetadataSetValue::Str(vec) => chroma_proto::direct_comparison::Comparison::StringListOperand(chroma_proto::StringListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1371            },
1372        };
1373        Ok(Self {
1374            key: value.key,
1375            comparison: Some(comparison),
1376        })
1377    }
1378}
1379
1380#[derive(Clone, Debug, PartialEq)]
1381#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1382pub enum MetadataComparison {
1383    Primitive(PrimitiveOperator, MetadataValue),
1384    Set(SetOperator, MetadataSetValue),
1385}
1386
/// Operator of a single-value metadata comparison.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum PrimitiveOperator {
    /// Serialized as `$eq`.
    Equal,
    /// Serialized as `$ne`.
    NotEqual,
    /// Serialized as `$gt`.
    GreaterThan,
    /// Serialized as `$gte`.
    GreaterThanOrEqual,
    /// Serialized as `$lt`.
    LessThan,
    /// Serialized as `$lte`.
    LessThanOrEqual,
}
1397
1398impl From<chroma_proto::GenericComparator> for PrimitiveOperator {
1399    fn from(value: chroma_proto::GenericComparator) -> Self {
1400        match value {
1401            chroma_proto::GenericComparator::Eq => Self::Equal,
1402            chroma_proto::GenericComparator::Ne => Self::NotEqual,
1403        }
1404    }
1405}
1406
1407impl TryFrom<PrimitiveOperator> for chroma_proto::GenericComparator {
1408    type Error = WhereConversionError;
1409
1410    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1411        match value {
1412            PrimitiveOperator::Equal => Ok(Self::Eq),
1413            PrimitiveOperator::NotEqual => Ok(Self::Ne),
1414            op => Err(WhereConversionError::cause(format!("{op:?} ∉ [=, ≠]"))),
1415        }
1416    }
1417}
1418
1419impl From<chroma_proto::NumberComparator> for PrimitiveOperator {
1420    fn from(value: chroma_proto::NumberComparator) -> Self {
1421        match value {
1422            chroma_proto::NumberComparator::Gt => Self::GreaterThan,
1423            chroma_proto::NumberComparator::Gte => Self::GreaterThanOrEqual,
1424            chroma_proto::NumberComparator::Lt => Self::LessThan,
1425            chroma_proto::NumberComparator::Lte => Self::LessThanOrEqual,
1426        }
1427    }
1428}
1429
1430impl TryFrom<PrimitiveOperator> for chroma_proto::NumberComparator {
1431    type Error = WhereConversionError;
1432
1433    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1434        match value {
1435            PrimitiveOperator::GreaterThan => Ok(Self::Gt),
1436            PrimitiveOperator::GreaterThanOrEqual => Ok(Self::Gte),
1437            PrimitiveOperator::LessThan => Ok(Self::Lt),
1438            PrimitiveOperator::LessThanOrEqual => Ok(Self::Lte),
1439            op => Err(WhereConversionError::cause(format!(
1440                "{op:?} ∉ [≤, <, >, ≥]"
1441            ))),
1442        }
1443    }
1444}
1445
/// Operator of a list-valued metadata comparison.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SetOperator {
    /// Serialized as `$in`.
    In,
    /// Serialized as `$nin`.
    NotIn,
}
1452
1453impl From<chroma_proto::ListOperator> for SetOperator {
1454    fn from(value: chroma_proto::ListOperator) -> Self {
1455        match value {
1456            chroma_proto::ListOperator::In => Self::In,
1457            chroma_proto::ListOperator::Nin => Self::NotIn,
1458        }
1459    }
1460}
1461
1462impl From<SetOperator> for chroma_proto::ListOperator {
1463    fn from(value: SetOperator) -> Self {
1464        match value {
1465            SetOperator::In => Self::In,
1466            SetOperator::NotIn => Self::Nin,
1467        }
1468    }
1469}
1470
/// Homogeneous list operand of a set comparison.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum MetadataSetValue {
    /// List of booleans.
    Bool(Vec<bool>),
    /// List of 64-bit signed integers.
    Int(Vec<i64>),
    /// List of 64-bit floats.
    Float(Vec<f64>),
    /// List of strings.
    Str(Vec<String>),
}
1479
1480impl MetadataSetValue {
1481    pub fn value_type(&self) -> MetadataValueType {
1482        match self {
1483            MetadataSetValue::Bool(_) => MetadataValueType::Bool,
1484            MetadataSetValue::Int(_) => MetadataValueType::Int,
1485            MetadataSetValue::Float(_) => MetadataValueType::Float,
1486            MetadataSetValue::Str(_) => MetadataValueType::Str,
1487        }
1488    }
1489}
1490
1491impl From<Vec<bool>> for MetadataSetValue {
1492    fn from(values: Vec<bool>) -> Self {
1493        MetadataSetValue::Bool(values)
1494    }
1495}
1496
1497impl From<Vec<i64>> for MetadataSetValue {
1498    fn from(values: Vec<i64>) -> Self {
1499        MetadataSetValue::Int(values)
1500    }
1501}
1502
1503impl From<Vec<i32>> for MetadataSetValue {
1504    fn from(values: Vec<i32>) -> Self {
1505        MetadataSetValue::Int(values.into_iter().map(|v| v as i64).collect())
1506    }
1507}
1508
1509impl From<Vec<f64>> for MetadataSetValue {
1510    fn from(values: Vec<f64>) -> Self {
1511        MetadataSetValue::Float(values)
1512    }
1513}
1514
1515impl From<Vec<f32>> for MetadataSetValue {
1516    fn from(values: Vec<f32>) -> Self {
1517        MetadataSetValue::Float(values.into_iter().map(|v| v as f64).collect())
1518    }
1519}
1520
1521impl From<Vec<String>> for MetadataSetValue {
1522    fn from(values: Vec<String>) -> Self {
1523        MetadataSetValue::Str(values)
1524    }
1525}
1526
1527impl From<Vec<&str>> for MetadataSetValue {
1528    fn from(values: Vec<&str>) -> Self {
1529        MetadataSetValue::Str(values.into_iter().map(|s| s.to_string()).collect())
1530    }
1531}
1532
1533// TODO: Deprecate where_document
1534impl TryFrom<chroma_proto::WhereDocument> for Where {
1535    type Error = WhereConversionError;
1536
1537    fn try_from(proto_document: chroma_proto::WhereDocument) -> Result<Self, Self::Error> {
1538        match proto_document.r#where_document {
1539            Some(chroma_proto::where_document::WhereDocument::Direct(proto_comparison)) => {
1540                let operator = match TryInto::<chroma_proto::WhereDocumentOperator>::try_into(
1541                    proto_comparison.operator,
1542                ) {
1543                    Ok(operator) => operator,
1544                    Err(_) => {
1545                        return Err(WhereConversionError::cause(
1546                            "[Deprecated] Invalid where document operator",
1547                        ))
1548                    }
1549                };
1550                let comparison = DocumentExpression {
1551                    pattern: proto_comparison.pattern,
1552                    operator: operator.into(),
1553                };
1554                Ok(Where::Document(comparison))
1555            }
1556            Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => {
1557                let operator = match TryInto::<chroma_proto::BooleanOperator>::try_into(
1558                    proto_children.operator,
1559                ) {
1560                    Ok(operator) => operator,
1561                    Err(_) => {
1562                        return Err(WhereConversionError::cause(
1563                            "[Deprecated] Invalid boolean operator",
1564                        ))
1565                    }
1566                };
1567                let children = CompositeExpression {
1568                    children: proto_children
1569                        .children
1570                        .into_iter()
1571                        .map(|child| child.try_into())
1572                        .collect::<Result<_, _>>()?,
1573                    operator: operator.into(),
1574                };
1575                Ok(Where::Composite(children))
1576            }
1577            None => Err(WhereConversionError::cause("[Deprecated] Invalid where")),
1578        }
1579    }
1580}
1581
1582#[cfg(test)]
1583mod tests {
1584    use super::*;
1585
1586    #[test]
1587    fn test_update_metadata_try_from() {
1588        let mut proto_metadata = chroma_proto::UpdateMetadata {
1589            metadata: HashMap::new(),
1590        };
1591        proto_metadata.metadata.insert(
1592            "foo".to_string(),
1593            chroma_proto::UpdateMetadataValue {
1594                value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
1595            },
1596        );
1597        proto_metadata.metadata.insert(
1598            "bar".to_string(),
1599            chroma_proto::UpdateMetadataValue {
1600                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
1601            },
1602        );
1603        proto_metadata.metadata.insert(
1604            "baz".to_string(),
1605            chroma_proto::UpdateMetadataValue {
1606                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1607                    "42".to_string(),
1608                )),
1609            },
1610        );
1611        // Add sparse vector test
1612        proto_metadata.metadata.insert(
1613            "sparse".to_string(),
1614            chroma_proto::UpdateMetadataValue {
1615                value: Some(
1616                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
1617                        chroma_proto::SparseVector {
1618                            indices: vec![0, 5, 10],
1619                            values: vec![0.1, 0.5, 0.9],
1620                        },
1621                    ),
1622                ),
1623            },
1624        );
1625        let converted_metadata: UpdateMetadata = proto_metadata.try_into().unwrap();
1626        assert_eq!(converted_metadata.len(), 4);
1627        assert_eq!(
1628            converted_metadata.get("foo").unwrap(),
1629            &UpdateMetadataValue::Int(42)
1630        );
1631        assert_eq!(
1632            converted_metadata.get("bar").unwrap(),
1633            &UpdateMetadataValue::Float(42.0)
1634        );
1635        assert_eq!(
1636            converted_metadata.get("baz").unwrap(),
1637            &UpdateMetadataValue::Str("42".to_string())
1638        );
1639        assert_eq!(
1640            converted_metadata.get("sparse").unwrap(),
1641            &UpdateMetadataValue::SparseVector(SparseVector::new(
1642                vec![0, 5, 10],
1643                vec![0.1, 0.5, 0.9]
1644            ))
1645        );
1646    }
1647
1648    #[test]
1649    fn test_metadata_try_from() {
1650        let mut proto_metadata = chroma_proto::UpdateMetadata {
1651            metadata: HashMap::new(),
1652        };
1653        proto_metadata.metadata.insert(
1654            "foo".to_string(),
1655            chroma_proto::UpdateMetadataValue {
1656                value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
1657            },
1658        );
1659        proto_metadata.metadata.insert(
1660            "bar".to_string(),
1661            chroma_proto::UpdateMetadataValue {
1662                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
1663            },
1664        );
1665        proto_metadata.metadata.insert(
1666            "baz".to_string(),
1667            chroma_proto::UpdateMetadataValue {
1668                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1669                    "42".to_string(),
1670                )),
1671            },
1672        );
1673        // Add sparse vector test
1674        proto_metadata.metadata.insert(
1675            "sparse".to_string(),
1676            chroma_proto::UpdateMetadataValue {
1677                value: Some(
1678                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
1679                        chroma_proto::SparseVector {
1680                            indices: vec![1, 10, 100],
1681                            values: vec![0.2, 0.4, 0.6],
1682                        },
1683                    ),
1684                ),
1685            },
1686        );
1687        let converted_metadata: Metadata = proto_metadata.try_into().unwrap();
1688        assert_eq!(converted_metadata.len(), 4);
1689        assert_eq!(
1690            converted_metadata.get("foo").unwrap(),
1691            &MetadataValue::Int(42)
1692        );
1693        assert_eq!(
1694            converted_metadata.get("bar").unwrap(),
1695            &MetadataValue::Float(42.0)
1696        );
1697        assert_eq!(
1698            converted_metadata.get("baz").unwrap(),
1699            &MetadataValue::Str("42".to_string())
1700        );
1701        assert_eq!(
1702            converted_metadata.get("sparse").unwrap(),
1703            &MetadataValue::SparseVector(SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]))
1704        );
1705    }
1706
1707    #[test]
1708    fn test_where_clause_simple_from() {
1709        let proto_where = chroma_proto::Where {
1710            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
1711                chroma_proto::DirectComparison {
1712                    key: "foo".to_string(),
1713                    comparison: Some(
1714                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1715                            chroma_proto::SingleIntComparison {
1716                                value: 42,
1717                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
1718                            },
1719                        ),
1720                    ),
1721                },
1722            )),
1723        };
1724        let where_clause: Where = proto_where.try_into().unwrap();
1725        match where_clause {
1726            Where::Metadata(comparison) => {
1727                assert_eq!(comparison.key, "foo");
1728                match comparison.comparison {
1729                    MetadataComparison::Primitive(_, value) => {
1730                        assert_eq!(value, MetadataValue::Int(42));
1731                    }
1732                    _ => panic!("Invalid comparison type"),
1733                }
1734            }
1735            _ => panic!("Invalid where type"),
1736        }
1737    }
1738
1739    #[test]
1740    fn test_where_clause_with_children() {
1741        let proto_where = chroma_proto::Where {
1742            r#where: Some(chroma_proto::r#where::Where::Children(
1743                chroma_proto::WhereChildren {
1744                    children: vec![
1745                        chroma_proto::Where {
1746                            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
1747                                chroma_proto::DirectComparison {
1748                                    key: "foo".to_string(),
1749                                    comparison: Some(
1750                                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1751                                            chroma_proto::SingleIntComparison {
1752                                                value: 42,
1753                                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
1754                                            },
1755                                        ),
1756                                    ),
1757                                },
1758                            )),
1759                        },
1760                        chroma_proto::Where {
1761                            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
1762                                chroma_proto::DirectComparison {
1763                                    key: "bar".to_string(),
1764                                    comparison: Some(
1765                                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1766                                            chroma_proto::SingleIntComparison {
1767                                                value: 42,
1768                                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
1769                                            },
1770                                        ),
1771                                    ),
1772                                },
1773                            )),
1774                        },
1775                    ],
1776                    operator: chroma_proto::BooleanOperator::And.into(),
1777                },
1778            )),
1779        };
1780        let where_clause: Where = proto_where.try_into().unwrap();
1781        match where_clause {
1782            Where::Composite(children) => {
1783                assert_eq!(children.children.len(), 2);
1784                assert_eq!(children.operator, BooleanOperator::And);
1785            }
1786            _ => panic!("Invalid where type"),
1787        }
1788    }
1789
1790    #[test]
1791    fn test_where_document_simple() {
1792        let proto_where = chroma_proto::WhereDocument {
1793            r#where_document: Some(chroma_proto::where_document::WhereDocument::Direct(
1794                chroma_proto::DirectWhereDocument {
1795                    pattern: "foo".to_string(),
1796                    operator: chroma_proto::WhereDocumentOperator::Contains.into(),
1797                },
1798            )),
1799        };
1800        let where_document: Where = proto_where.try_into().unwrap();
1801        match where_document {
1802            Where::Document(comparison) => {
1803                assert_eq!(comparison.pattern, "foo");
1804                assert_eq!(comparison.operator, DocumentOperator::Contains);
1805            }
1806            _ => panic!("Invalid where document type"),
1807        }
1808    }
1809
1810    #[test]
1811    fn test_where_document_with_children() {
1812        let proto_where = chroma_proto::WhereDocument {
1813            r#where_document: Some(chroma_proto::where_document::WhereDocument::Children(
1814                chroma_proto::WhereDocumentChildren {
1815                    children: vec![
1816                        chroma_proto::WhereDocument {
1817                            r#where_document: Some(
1818                                chroma_proto::where_document::WhereDocument::Direct(
1819                                    chroma_proto::DirectWhereDocument {
1820                                        pattern: "foo".to_string(),
1821                                        operator: chroma_proto::WhereDocumentOperator::Contains
1822                                            .into(),
1823                                    },
1824                                ),
1825                            ),
1826                        },
1827                        chroma_proto::WhereDocument {
1828                            r#where_document: Some(
1829                                chroma_proto::where_document::WhereDocument::Direct(
1830                                    chroma_proto::DirectWhereDocument {
1831                                        pattern: "bar".to_string(),
1832                                        operator: chroma_proto::WhereDocumentOperator::Contains
1833                                            .into(),
1834                                    },
1835                                ),
1836                            ),
1837                        },
1838                    ],
1839                    operator: chroma_proto::BooleanOperator::And.into(),
1840                },
1841            )),
1842        };
1843        let where_document: Where = proto_where.try_into().unwrap();
1844        match where_document {
1845            Where::Composite(children) => {
1846                assert_eq!(children.children.len(), 2);
1847                assert_eq!(children.operator, BooleanOperator::And);
1848            }
1849            _ => panic!("Invalid where document type"),
1850        }
1851    }
1852
1853    #[test]
1854    fn test_sparse_vector_new() {
1855        let indices = vec![0, 5, 10];
1856        let values = vec![0.1, 0.5, 0.9];
1857        let sparse = SparseVector::new(indices.clone(), values.clone());
1858        assert_eq!(sparse.indices, indices);
1859        assert_eq!(sparse.values, values);
1860    }
1861
1862    #[test]
1863    fn test_sparse_vector_from_pairs() {
1864        let pairs = vec![(0, 0.1), (5, 0.5), (10, 0.9)];
1865        let sparse = SparseVector::from_pairs(pairs.clone());
1866        assert_eq!(sparse.indices, vec![0, 5, 10]);
1867        assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
1868    }
1869
1870    #[test]
1871    fn test_sparse_vector_iter() {
1872        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
1873        let collected: Vec<(u32, f32)> = sparse.iter().collect();
1874        assert_eq!(collected, vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
1875    }
1876
1877    #[test]
1878    fn test_sparse_vector_ordering() {
1879        let sparse1 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
1880        let sparse2 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
1881        let sparse3 = SparseVector::new(vec![0, 6], vec![0.1, 0.5]);
1882        let sparse4 = SparseVector::new(vec![0, 5], vec![0.1, 0.6]);
1883
1884        assert_eq!(sparse1, sparse2);
1885        assert!(sparse1 < sparse3);
1886        assert!(sparse1 < sparse4);
1887    }
1888
1889    #[test]
1890    fn test_sparse_vector_proto_conversion() {
1891        let sparse = SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]);
1892        let proto: chroma_proto::SparseVector = sparse.clone().into();
1893        assert_eq!(proto.indices, vec![1, 10, 100]);
1894        assert_eq!(proto.values, vec![0.2, 0.4, 0.6]);
1895
1896        let converted: SparseVector = proto.into();
1897        assert_eq!(converted, sparse);
1898    }
1899
1900    #[test]
1901    fn test_sparse_vector_logical_size() {
1902        let metadata = Metadata::from([(
1903            "sparse".to_string(),
1904            MetadataValue::SparseVector(SparseVector::new(
1905                vec![0, 1, 2, 3, 4],
1906                vec![0.1, 0.2, 0.3, 0.4, 0.5],
1907            )),
1908        )]);
1909
1910        let size = logical_size_of_metadata(&metadata);
1911        // Size should include the key string length and the sparse vector data
1912        // "sparse" = 6 bytes + 5 * 4 bytes (u32 indices) + 5 * 4 bytes (f32 values) = 46 bytes
1913        assert_eq!(size, 46);
1914    }
1915
1916    #[test]
1917    fn test_sparse_vector_validation() {
1918        // Valid sparse vector
1919        let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]);
1920        assert!(sparse.validate().is_ok());
1921
1922        // Length mismatch
1923        let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]);
1924        let result = sparse.validate();
1925        assert!(result.is_err());
1926        assert!(matches!(
1927            result.unwrap_err(),
1928            MetadataValueConversionError::SparseVectorLengthMismatch
1929        ));
1930
1931        // Unsorted indices (descending order)
1932        let sparse = SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]);
1933        let result = sparse.validate();
1934        assert!(result.is_err());
1935        assert!(matches!(
1936            result.unwrap_err(),
1937            MetadataValueConversionError::SparseVectorIndicesNotSorted
1938        ));
1939
1940        // Duplicate indices (not strictly ascending)
1941        let sparse = SparseVector::new(vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]);
1942        let result = sparse.validate();
1943        assert!(result.is_err());
1944        assert!(matches!(
1945            result.unwrap_err(),
1946            MetadataValueConversionError::SparseVectorIndicesNotSorted
1947        ));
1948
1949        // Descending at one point
1950        let sparse = SparseVector::new(vec![1, 3, 2], vec![0.1, 0.3, 0.2]);
1951        let result = sparse.validate();
1952        assert!(result.is_err());
1953        assert!(matches!(
1954            result.unwrap_err(),
1955            MetadataValueConversionError::SparseVectorIndicesNotSorted
1956        ));
1957    }
1958
1959    #[test]
1960    fn test_sparse_vector_deserialize_old_format() {
1961        // Old format without #type field (backward compatibility)
1962        let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
1963        let sv: SparseVector = serde_json::from_str(json).unwrap();
1964        assert_eq!(sv.indices, vec![0, 1, 2]);
1965        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
1966    }
1967
1968    #[test]
1969    fn test_sparse_vector_deserialize_new_format() {
1970        // New format with #type field
1971        let json =
1972            "{\"#type\": \"sparse_vector\", \"indices\": [0, 1, 2], \"values\": [1.0, 2.0, 3.0]}";
1973        let sv: SparseVector = serde_json::from_str(json).unwrap();
1974        assert_eq!(sv.indices, vec![0, 1, 2]);
1975        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
1976    }
1977
1978    #[test]
1979    fn test_sparse_vector_deserialize_new_format_field_order() {
1980        // New format with different field order (should still work)
1981        let json = "{\"indices\": [5, 10], \"#type\": \"sparse_vector\", \"values\": [0.5, 1.0]}";
1982        let sv: SparseVector = serde_json::from_str(json).unwrap();
1983        assert_eq!(sv.indices, vec![5, 10]);
1984        assert_eq!(sv.values, vec![0.5, 1.0]);
1985    }
1986
1987    #[test]
1988    fn test_sparse_vector_deserialize_wrong_type_tag() {
1989        // Wrong #type field value should fail
1990        let json = "{\"#type\": \"dense_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}";
1991        let result: Result<SparseVector, _> = serde_json::from_str(json);
1992        assert!(result.is_err());
1993        let err_msg = result.unwrap_err().to_string();
1994        assert!(err_msg.contains("sparse_vector"));
1995    }
1996
1997    #[test]
1998    fn test_sparse_vector_serialize_always_has_type() {
1999        // Serialization should always include #type field
2000        let sv = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]);
2001        let json = serde_json::to_value(&sv).unwrap();
2002
2003        assert_eq!(json["#type"], "sparse_vector");
2004        assert_eq!(json["indices"], serde_json::json!([0, 1, 2]));
2005        assert_eq!(json["values"], serde_json::json!([1.0, 2.0, 3.0]));
2006    }
2007
2008    #[test]
2009    fn test_sparse_vector_roundtrip_with_type() {
2010        // Test that serialize -> deserialize preserves the data
2011        let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]);
2012        let json = serde_json::to_string(&original).unwrap();
2013
2014        // Verify the serialized JSON contains #type
2015        assert!(json.contains("\"#type\":\"sparse_vector\""));
2016
2017        let deserialized: SparseVector = serde_json::from_str(&json).unwrap();
2018        assert_eq!(original, deserialized);
2019    }
2020
2021    #[test]
2022    fn test_sparse_vector_in_metadata_old_format() {
2023        // Test that old format works when sparse vector is in metadata
2024        let json = r#"{"key": "value", "sparse": {"indices": [0, 1], "values": [1.0, 2.0]}}"#;
2025        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();
2026
2027        let sparse_value = &map["sparse"];
2028        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
2029        assert_eq!(sv.indices, vec![0, 1]);
2030        assert_eq!(sv.values, vec![1.0, 2.0]);
2031    }
2032
2033    #[test]
2034    fn test_sparse_vector_in_metadata_new_format() {
2035        // Test that new format works when sparse vector is in metadata
2036        let json = "{\"key\": \"value\", \"sparse\": {\"#type\": \"sparse_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}}";
2037        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();
2038
2039        let sparse_value = &map["sparse"];
2040        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
2041        assert_eq!(sv.indices, vec![0, 1]);
2042        assert_eq!(sv.values, vec![1.0, 2.0]);
2043    }
2044}