chroma_types/
metadata.rs

1use chroma_error::{ChromaError, ErrorCodes};
2use itertools::Itertools;
3use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
4use serde_json::{Number, Value};
5use sprs::CsVec;
6use std::{
7    cmp::Ordering,
8    collections::{HashMap, HashSet},
9    mem::size_of_val,
10    ops::{BitAnd, BitOr},
11};
12use thiserror::Error;
13
14use crate::chroma_proto;
15
16#[cfg(feature = "pyo3")]
17use pyo3::types::PyAnyMethods;
18
19#[cfg(feature = "testing")]
20use proptest::prelude::*;
21
/// Serde-facing mirror of [`SparseVector`] used by its hand-written
/// `Serialize` / `Deserialize` impls.
#[derive(Serialize, Deserialize)]
struct SparseVectorSerdeHelper {
    /// Value of the `#type` key; `None` when the payload carries no tag.
    #[serde(rename = "#type")]
    type_tag: Option<String>,
    /// Dimension indices.
    indices: Vec<u32>,
    /// Values corresponding to each index.
    values: Vec<f32>,
}
29
/// Represents a sparse vector using parallel arrays for indices and values.
///
/// On deserialization: accepts both old format `{"indices": [...], "values": [...]}`
/// and new format `{"#type": "sparse_vector", "indices": [...], "values": [...]}`.
///
/// On serialization: always includes `#type` field with value `"sparse_vector"`.
///
/// Construction does not enforce any invariant on the two arrays; call
/// [`SparseVector::validate`] to check that they have the same length and that
/// the indices are strictly ascending.
///
/// NOTE(review): the derived `PartialEq` compares `values` with `f32` `==`,
/// whereas the `Ord` impl in this file uses `f32::total_cmp`; the two disagree
/// for NaN and for `0.0` vs `-0.0`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SparseVector {
    /// Dimension indices
    pub indices: Vec<u32>,
    /// Values corresponding to each index
    pub values: Vec<f32>,
}
44
45// Custom deserializer: accept both old and new formats
46impl<'de> Deserialize<'de> for SparseVector {
47    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
48    where
49        D: Deserializer<'de>,
50    {
51        let helper = SparseVectorSerdeHelper::deserialize(deserializer)?;
52
53        // If #type is present, validate it
54        if let Some(type_tag) = &helper.type_tag {
55            if type_tag != "sparse_vector" {
56                return Err(serde::de::Error::custom(format!(
57                    "Expected #type='sparse_vector', got '{}'",
58                    type_tag
59                )));
60            }
61        }
62
63        Ok(SparseVector {
64            indices: helper.indices,
65            values: helper.values,
66        })
67    }
68}
69
70// Custom serializer: always include #type field
71impl Serialize for SparseVector {
72    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73    where
74        S: Serializer,
75    {
76        let helper = SparseVectorSerdeHelper {
77            type_tag: Some("sparse_vector".to_string()),
78            indices: self.indices.clone(),
79            values: self.values.clone(),
80        };
81        helper.serialize(serializer)
82    }
83}
84
85impl SparseVector {
86    /// Create a new sparse vector from parallel arrays.
87    pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Self {
88        Self { indices, values }
89    }
90
91    /// Create a sparse vector from an iterator of (index, value) pairs.
92    pub fn from_pairs(pairs: impl IntoIterator<Item = (u32, f32)>) -> Self {
93        let (indices, values) = pairs.into_iter().unzip();
94        Self { indices, values }
95    }
96
97    /// Iterate over (index, value) pairs.
98    pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
99        self.indices
100            .iter()
101            .copied()
102            .zip(self.values.iter().copied())
103    }
104
105    /// Validate the sparse vector
106    pub fn validate(&self) -> Result<(), MetadataValueConversionError> {
107        // Check that indices and values have the same length
108        if self.indices.len() != self.values.len() {
109            return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
110        }
111
112        // Check that indices are sorted in strictly ascending order (no duplicates)
113        for i in 1..self.indices.len() {
114            if self.indices[i] <= self.indices[i - 1] {
115                return Err(MetadataValueConversionError::SparseVectorIndicesNotSorted);
116            }
117        }
118
119        Ok(())
120    }
121}
122
// Marker impl: `Eq` is a supertrait of `Ord`, which is implemented for `SparseVector`.
impl Eq for SparseVector {}
124
125impl Ord for SparseVector {
126    fn cmp(&self, other: &Self) -> Ordering {
127        self.indices.cmp(&other.indices).then_with(|| {
128            for (a, b) in self.values.iter().zip(other.values.iter()) {
129                match a.total_cmp(b) {
130                    Ordering::Equal => continue,
131                    other => return other,
132                }
133            }
134            self.values.len().cmp(&other.values.len())
135        })
136    }
137}
138
impl PartialOrd for SparseVector {
    /// Delegates to [`Ord::cmp`], so the result is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
144
145impl From<chroma_proto::SparseVector> for SparseVector {
146    fn from(proto: chroma_proto::SparseVector) -> Self {
147        SparseVector::new(proto.indices, proto.values)
148    }
149}
150
151impl From<SparseVector> for chroma_proto::SparseVector {
152    fn from(sparse: SparseVector) -> Self {
153        chroma_proto::SparseVector {
154            indices: sparse.indices,
155            values: sparse.values,
156        }
157    }
158}
159
160/// Convert SparseVector to sprs::CsVec for efficient sparse operations
161impl From<&SparseVector> for CsVec<f32> {
162    fn from(sparse: &SparseVector) -> Self {
163        let (indices, values) = sparse
164            .iter()
165            .map(|(index, value)| (index as usize, value))
166            .unzip();
167        CsVec::new(u32::MAX as usize, indices, values)
168    }
169}
170
171impl From<SparseVector> for CsVec<f32> {
172    fn from(sparse: SparseVector) -> Self {
173        (&sparse).into()
174    }
175}
176
/// Converts a [`SparseVector`] into a Python `dict` with the keys
/// `"indices"` and `"values"`.
///
/// Only these two keys are set; no `#type` entry is added here.
#[cfg(feature = "pyo3")]
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        use pyo3::types::PyDict;
        let dict = PyDict::new(py);
        dict.set_item("indices", self.indices)?;
        dict.set_item("values", self.values)?;
        // Erase the concrete dict type to match `Target = PyAny`.
        Ok(dict.into_any())
    }
}
191
/// Extracts a [`SparseVector`] from a Python `dict` holding `"indices"` and
/// `"values"` entries.
///
/// Fails if the object is not a `dict`, if either lookup fails, or if an entry
/// cannot be extracted as `Vec<u32>` / `Vec<f32>`. The result is not passed
/// through [`SparseVector::validate`].
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for SparseVector {
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyDict;

        let dict = ob.downcast::<PyDict>()?;
        let indices_obj = dict.get_item("indices")?;
        let values_obj = dict.get_item("values")?;

        let indices: Vec<u32> = indices_obj.extract()?;
        let values: Vec<f32> = values_obj.extract()?;

        Ok(SparseVector::new(indices, values))
    }
}
207
/// A metadata value as it appears in an update request.
///
/// Mirrors [`MetadataValue`] with an extra `None` variant; converting `None`
/// to a [`MetadataValue`] is an error.
///
/// Serialized with `#[serde(untagged)]`, i.e. without a variant tag.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[serde(untagged)]
pub enum UpdateMetadataValue {
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer value.
    Int(i64),
    /// 64-bit float value. Under the `testing` feature, generated values are
    /// limited to `f32` values in `-1e6..=1e6`, widened to `f64`.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| UpdateMetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    /// String value.
    Str(String),
    /// Sparse vector value; skipped by the proptest generator.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
    /// No value. The proto conversion in this file maps an absent value to
    /// this variant to signal that the user wants to delete the key.
    None,
}
227
/// Extracts an [`UpdateMetadataValue`] from a Python object by trying each
/// supported representation in a fixed order.
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for UpdateMetadataValue {
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        if ob.is_none() {
            return Ok(UpdateMetadataValue::None);
        }

        // The order of attempts is significant and must be kept: bool, i64,
        // f64, String, SparseVector. NOTE(review): bool presumably comes first
        // because a Python bool would otherwise also extract as an integer.
        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(UpdateMetadataValue::Bool(flag));
        }
        if let Ok(int) = ob.extract::<i64>() {
            return Ok(UpdateMetadataValue::Int(int));
        }
        if let Ok(float) = ob.extract::<f64>() {
            return Ok(UpdateMetadataValue::Float(float));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(UpdateMetadataValue::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(UpdateMetadataValue::SparseVector(sparse));
        }

        Err(pyo3::exceptions::PyTypeError::new_err(
            "Cannot convert Python object to UpdateMetadataValue",
        ))
    }
}
250
251impl From<bool> for UpdateMetadataValue {
252    fn from(b: bool) -> Self {
253        Self::Bool(b)
254    }
255}
256
257impl From<i64> for UpdateMetadataValue {
258    fn from(v: i64) -> Self {
259        Self::Int(v)
260    }
261}
262
263impl From<i32> for UpdateMetadataValue {
264    fn from(v: i32) -> Self {
265        Self::Int(v as i64)
266    }
267}
268
269impl From<f64> for UpdateMetadataValue {
270    fn from(v: f64) -> Self {
271        Self::Float(v)
272    }
273}
274
275impl From<f32> for UpdateMetadataValue {
276    fn from(v: f32) -> Self {
277        Self::Float(v as f64)
278    }
279}
280
281impl From<String> for UpdateMetadataValue {
282    fn from(v: String) -> Self {
283        Self::Str(v)
284    }
285}
286
287impl From<&str> for UpdateMetadataValue {
288    fn from(v: &str) -> Self {
289        Self::Str(v.to_string())
290    }
291}
292
293impl From<SparseVector> for UpdateMetadataValue {
294    fn from(v: SparseVector) -> Self {
295        Self::SparseVector(v)
296    }
297}
298
/// Error produced when converting into an [`UpdateMetadataValue`] or
/// [`UpdateMetadata`].
#[derive(Error, Debug)]
pub enum UpdateMetadataValueConversionError {
    /// The value could not be represented as an [`UpdateMetadataValue`].
    // NOTE(review): the message does not list SparseVector although the enum
    // has that variant; left unchanged because callers may match on the text.
    #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")]
    InvalidValue,
}
304
305impl ChromaError for UpdateMetadataValueConversionError {
306    fn code(&self) -> ErrorCodes {
307        match self {
308            UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
309        }
310    }
311}
312
313impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue {
314    type Error = UpdateMetadataValueConversionError;
315
316    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
317        match &value.value {
318            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
319                Ok(UpdateMetadataValue::Bool(*value))
320            }
321            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
322                Ok(UpdateMetadataValue::Int(*value))
323            }
324            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
325                Ok(UpdateMetadataValue::Float(*value))
326            }
327            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
328                Ok(UpdateMetadataValue::Str(value.clone()))
329            }
330            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
331                Ok(UpdateMetadataValue::SparseVector(value.clone().into()))
332            }
333            // Used to communicate that the user wants to delete this key.
334            None => Ok(UpdateMetadataValue::None),
335        }
336    }
337}
338
339impl From<UpdateMetadataValue> for chroma_proto::UpdateMetadataValue {
340    fn from(value: UpdateMetadataValue) -> Self {
341        match value {
342            UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
343                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
344            },
345            UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
346                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
347            },
348            UpdateMetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
349                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
350                    value,
351                )),
352            },
353            UpdateMetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
354                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
355                    value,
356                )),
357            },
358            UpdateMetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
359                value: Some(
360                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
361                        sparse_vec.into(),
362                    ),
363                ),
364            },
365            UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None },
366        }
367    }
368}
369
370impl TryFrom<&UpdateMetadataValue> for MetadataValue {
371    type Error = MetadataValueConversionError;
372
373    fn try_from(value: &UpdateMetadataValue) -> Result<Self, Self::Error> {
374        match value {
375            UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)),
376            UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)),
377            UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)),
378            UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())),
379            UpdateMetadataValue::SparseVector(value) => {
380                Ok(MetadataValue::SparseVector(value.clone()))
381            }
382            UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue),
383        }
384    }
385}
386
387/*
388===========================================
389MetadataValue
390===========================================
391*/
392
/// A single metadata value.
///
/// Serialized with `#[serde(untagged)]`, i.e. without a variant tag.
///
/// NOTE(review): the derived `PartialEq` compares `Float` with `f64` `==`,
/// whereas the `Ord` impl in this file uses `f64::total_cmp`; the two disagree
/// for NaN and for `0.0` vs `-0.0`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
#[serde(untagged)]
pub enum MetadataValue {
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer value.
    Int(i64),
    /// 64-bit float value. Under the `testing` feature, generated values are
    /// limited to `f32` values in `-1e6..=1e6`, widened to `f64`.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| MetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    /// String value.
    Str(String),
    /// Sparse vector value; skipped by the proptest generator.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
}
412
413impl std::fmt::Display for MetadataValue {
414    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415        match self {
416            MetadataValue::Bool(v) => write!(f, "{}", v),
417            MetadataValue::Int(v) => write!(f, "{}", v),
418            MetadataValue::Float(v) => write!(f, "{}", v),
419            MetadataValue::Str(v) => write!(f, "\"{}\"", v),
420            MetadataValue::SparseVector(v) => write!(f, "SparseVector(len={})", v.values.len()),
421        }
422    }
423}
424
// Marker impl: `Eq` is a supertrait of `Ord`, which is implemented for `MetadataValue`.
impl Eq for MetadataValue {}
426
/// Data-free discriminant of [`MetadataValue`]: one variant per value kind.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MetadataValueType {
    /// Kind of [`MetadataValue::Bool`].
    Bool,
    /// Kind of [`MetadataValue::Int`].
    Int,
    /// Kind of [`MetadataValue::Float`].
    Float,
    /// Kind of [`MetadataValue::Str`].
    Str,
    /// Kind of [`MetadataValue::SparseVector`].
    SparseVector,
}
435
436impl MetadataValue {
437    pub fn value_type(&self) -> MetadataValueType {
438        match self {
439            MetadataValue::Bool(_) => MetadataValueType::Bool,
440            MetadataValue::Int(_) => MetadataValueType::Int,
441            MetadataValue::Float(_) => MetadataValueType::Float,
442            MetadataValue::Str(_) => MetadataValueType::Str,
443            MetadataValue::SparseVector(_) => MetadataValueType::SparseVector,
444        }
445    }
446}
447
448impl From<&MetadataValue> for MetadataValueType {
449    fn from(value: &MetadataValue) -> Self {
450        value.value_type()
451    }
452}
453
454impl From<bool> for MetadataValue {
455    fn from(v: bool) -> Self {
456        MetadataValue::Bool(v)
457    }
458}
459
460impl From<i64> for MetadataValue {
461    fn from(v: i64) -> Self {
462        MetadataValue::Int(v)
463    }
464}
465
466impl From<i32> for MetadataValue {
467    fn from(v: i32) -> Self {
468        MetadataValue::Int(v as i64)
469    }
470}
471
472impl From<f64> for MetadataValue {
473    fn from(v: f64) -> Self {
474        MetadataValue::Float(v)
475    }
476}
477
478impl From<f32> for MetadataValue {
479    fn from(v: f32) -> Self {
480        MetadataValue::Float(v as f64)
481    }
482}
483
484impl From<String> for MetadataValue {
485    fn from(v: String) -> Self {
486        MetadataValue::Str(v)
487    }
488}
489
490impl From<&str> for MetadataValue {
491    fn from(v: &str) -> Self {
492        MetadataValue::Str(v.to_string())
493    }
494}
495
496impl From<SparseVector> for MetadataValue {
497    fn from(v: SparseVector) -> Self {
498        MetadataValue::SparseVector(v)
499    }
500}
501
502/// We need `Eq` and `Ord` since we want to use this as a key in `BTreeMap`
503///
504/// For cross-type comparisons, we define a consistent ordering based on variant position:
505/// Bool < Int < Float < Str < SparseVector
506#[allow(clippy::derive_ord_xor_partial_ord)]
507impl Ord for MetadataValue {
508    fn cmp(&self, other: &Self) -> Ordering {
509        // Define type ordering based on variant position
510        fn type_order(val: &MetadataValue) -> u8 {
511            match val {
512                MetadataValue::Bool(_) => 0,
513                MetadataValue::Int(_) => 1,
514                MetadataValue::Float(_) => 2,
515                MetadataValue::Str(_) => 3,
516                MetadataValue::SparseVector(_) => 4,
517            }
518        }
519
520        // Chain type ordering with value ordering
521        type_order(self).cmp(&type_order(other)).then_with(|| {
522            match (self, other) {
523                (MetadataValue::Bool(left), MetadataValue::Bool(right)) => left.cmp(right),
524                (MetadataValue::Int(left), MetadataValue::Int(right)) => left.cmp(right),
525                (MetadataValue::Float(left), MetadataValue::Float(right)) => left.total_cmp(right),
526                (MetadataValue::Str(left), MetadataValue::Str(right)) => left.cmp(right),
527                (MetadataValue::SparseVector(left), MetadataValue::SparseVector(right)) => {
528                    left.cmp(right)
529                }
530                _ => Ordering::Equal, // Different types, but type_order already handled this
531            }
532        })
533    }
534}
535
impl PartialOrd for MetadataValue {
    /// Delegates to [`Ord::cmp`], so the result is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
541
542impl TryFrom<&MetadataValue> for bool {
543    type Error = MetadataValueConversionError;
544
545    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
546        match value {
547            MetadataValue::Bool(value) => Ok(*value),
548            _ => Err(MetadataValueConversionError::InvalidValue),
549        }
550    }
551}
552
553impl TryFrom<&MetadataValue> for i64 {
554    type Error = MetadataValueConversionError;
555
556    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
557        match value {
558            MetadataValue::Int(value) => Ok(*value),
559            _ => Err(MetadataValueConversionError::InvalidValue),
560        }
561    }
562}
563
564impl TryFrom<&MetadataValue> for f64 {
565    type Error = MetadataValueConversionError;
566
567    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
568        match value {
569            MetadataValue::Float(value) => Ok(*value),
570            _ => Err(MetadataValueConversionError::InvalidValue),
571        }
572    }
573}
574
575impl TryFrom<&MetadataValue> for String {
576    type Error = MetadataValueConversionError;
577
578    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
579        match value {
580            MetadataValue::Str(value) => Ok(value.clone()),
581            _ => Err(MetadataValueConversionError::InvalidValue),
582        }
583    }
584}
585
586impl From<MetadataValue> for UpdateMetadataValue {
587    fn from(value: MetadataValue) -> Self {
588        match value {
589            MetadataValue::Bool(v) => UpdateMetadataValue::Bool(v),
590            MetadataValue::Int(v) => UpdateMetadataValue::Int(v),
591            MetadataValue::Float(v) => UpdateMetadataValue::Float(v),
592            MetadataValue::Str(v) => UpdateMetadataValue::Str(v),
593            MetadataValue::SparseVector(v) => UpdateMetadataValue::SparseVector(v),
594        }
595    }
596}
597
598impl From<MetadataValue> for Value {
599    fn from(value: MetadataValue) -> Self {
600        match value {
601            MetadataValue::Bool(val) => Self::Bool(val),
602            MetadataValue::Int(val) => Self::Number(
603                Number::from_i128(val as i128).expect("i64 should be representable in JSON"),
604            ),
605            MetadataValue::Float(val) => Self::Number(
606                Number::from_f64(val).expect("Inf and NaN should not be present in MetadataValue"),
607            ),
608            MetadataValue::Str(val) => Self::String(val),
609            MetadataValue::SparseVector(val) => {
610                let mut map = serde_json::Map::new();
611                map.insert(
612                    "indices".to_string(),
613                    Value::Array(
614                        val.indices
615                            .iter()
616                            .map(|&i| Value::Number(i.into()))
617                            .collect(),
618                    ),
619                );
620                map.insert(
621                    "values".to_string(),
622                    Value::Array(
623                        val.values
624                            .iter()
625                            .map(|&v| {
626                                Value::Number(
627                                    Number::from_f64(v as f64)
628                                        .expect("Float number should not be NaN or infinite"),
629                                )
630                            })
631                            .collect(),
632                    ),
633                );
634                Self::Object(map)
635            }
636        }
637    }
638}
639
/// Errors raised while converting or validating metadata values.
#[derive(Error, Debug)]
pub enum MetadataValueConversionError {
    /// The value is missing or is not of the requested/representable kind.
    // NOTE(review): the message lists only Int, Float and Str although
    // `MetadataValue` also has Bool and SparseVector; left unchanged because
    // callers may match on the text.
    #[error("Invalid metadata value, valid values are: Int, Float, Str")]
    InvalidValue,
    /// A metadata key uses a reserved prefix; carries the offending key.
    #[error("Metadata key cannot start with '#' or '$': {0}")]
    InvalidKey(String),
    /// Returned by `SparseVector::validate` when the two arrays differ in length.
    #[error("Sparse vector indices and values must have the same length")]
    SparseVectorLengthMismatch,
    /// Returned by `SparseVector::validate` when indices are not strictly ascending.
    #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")]
    SparseVectorIndicesNotSorted,
}
651
652impl ChromaError for MetadataValueConversionError {
653    fn code(&self) -> ErrorCodes {
654        match self {
655            MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
656            MetadataValueConversionError::InvalidKey(_) => ErrorCodes::InvalidArgument,
657            MetadataValueConversionError::SparseVectorLengthMismatch => ErrorCodes::InvalidArgument,
658            MetadataValueConversionError::SparseVectorIndicesNotSorted => {
659                ErrorCodes::InvalidArgument
660            }
661        }
662    }
663}
664
665impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue {
666    type Error = MetadataValueConversionError;
667
668    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
669        match &value.value {
670            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
671                Ok(MetadataValue::Bool(*value))
672            }
673            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
674                Ok(MetadataValue::Int(*value))
675            }
676            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
677                Ok(MetadataValue::Float(*value))
678            }
679            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
680                Ok(MetadataValue::Str(value.clone()))
681            }
682            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
683                Ok(MetadataValue::SparseVector(value.clone().into()))
684            }
685            _ => Err(MetadataValueConversionError::InvalidValue),
686        }
687    }
688}
689
690impl From<MetadataValue> for chroma_proto::UpdateMetadataValue {
691    fn from(value: MetadataValue) -> Self {
692        match value {
693            MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
694                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
695            },
696            MetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
697                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
698                    value,
699                )),
700            },
701            MetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
702                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
703                    value,
704                )),
705            },
706            MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
707                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
708            },
709            MetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
710                value: Some(
711                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
712                        sparse_vec.into(),
713                    ),
714                ),
715            },
716        }
717    }
718}
719
720/*
721===========================================
722UpdateMetadata
723===========================================
724*/
/// Metadata update payload: maps each key to its [`UpdateMetadataValue`].
pub type UpdateMetadata = HashMap<String, UpdateMetadataValue>;
726
727/**
728 * Check if two metadata are close to equal. Ignores small differences in float values.
729 */
730pub fn are_update_metadatas_close_to_equal(
731    metadata1: &UpdateMetadata,
732    metadata2: &UpdateMetadata,
733) -> bool {
734    assert_eq!(metadata1.len(), metadata2.len());
735
736    for (key, value) in metadata1.iter() {
737        if !metadata2.contains_key(key) {
738            return false;
739        }
740        let other_value = metadata2.get(key).unwrap();
741
742        if let (UpdateMetadataValue::Float(value), UpdateMetadataValue::Float(other_value)) =
743            (value, other_value)
744        {
745            if (value - other_value).abs() > 1e-6 {
746                return false;
747            }
748        } else if value != other_value {
749            return false;
750        }
751    }
752
753    true
754}
755
756pub fn are_metadatas_close_to_equal(metadata1: &Metadata, metadata2: &Metadata) -> bool {
757    assert_eq!(metadata1.len(), metadata2.len());
758
759    for (key, value) in metadata1.iter() {
760        if !metadata2.contains_key(key) {
761            return false;
762        }
763        let other_value = metadata2.get(key).unwrap();
764
765        if let (MetadataValue::Float(value), MetadataValue::Float(other_value)) =
766            (value, other_value)
767        {
768            if (value - other_value).abs() > 1e-6 {
769                return false;
770            }
771        } else if value != other_value {
772            return false;
773        }
774    }
775
776    true
777}
778
779impl TryFrom<chroma_proto::UpdateMetadata> for UpdateMetadata {
780    type Error = UpdateMetadataValueConversionError;
781
782    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
783        let mut metadata = UpdateMetadata::new();
784        for (key, value) in proto_metadata.metadata.iter() {
785            let value = match value.try_into() {
786                Ok(value) => value,
787                Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue),
788            };
789            metadata.insert(key.clone(), value);
790        }
791        Ok(metadata)
792    }
793}
794
795impl From<UpdateMetadata> for chroma_proto::UpdateMetadata {
796    fn from(metadata: UpdateMetadata) -> Self {
797        let mut metadata = metadata;
798        let mut proto_metadata = chroma_proto::UpdateMetadata {
799            metadata: HashMap::new(),
800        };
801        for (key, value) in metadata.drain() {
802            let proto_value = value.into();
803            proto_metadata.metadata.insert(key.clone(), proto_value);
804        }
805        proto_metadata
806    }
807}
808
809/*
810===========================================
811Metadata
812===========================================
813*/
814
/// Metadata of a record: maps each key to its [`MetadataValue`].
pub type Metadata = HashMap<String, MetadataValue>;
/// Set of metadata keys. NOTE(review): presumably the keys that were deleted — confirm at the use sites.
pub type DeletedMetadata = HashSet<String>;
817
818pub fn logical_size_of_metadata(metadata: &Metadata) -> usize {
819    metadata
820        .iter()
821        .map(|(k, v)| {
822            k.len()
823                + match v {
824                    MetadataValue::Bool(b) => size_of_val(b),
825                    MetadataValue::Int(i) => size_of_val(i),
826                    MetadataValue::Float(f) => size_of_val(f),
827                    MetadataValue::Str(s) => s.len(),
828                    MetadataValue::SparseVector(v) => {
829                        size_of_val(&v.indices[..]) + size_of_val(&v.values[..])
830                    }
831                }
832        })
833        .sum()
834}
835
836pub fn get_metadata_value_as<'a, T>(
837    metadata: &'a Metadata,
838    key: &str,
839) -> Result<T, Box<MetadataValueConversionError>>
840where
841    T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
842{
843    let res = match metadata.get(key) {
844        Some(value) => T::try_from(value),
845        None => return Err(Box::new(MetadataValueConversionError::InvalidValue)),
846    };
847    match res {
848        Ok(value) => Ok(value),
849        Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)),
850    }
851}
852
853impl TryFrom<chroma_proto::UpdateMetadata> for Metadata {
854    type Error = MetadataValueConversionError;
855
856    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
857        let mut metadata = Metadata::new();
858        for (key, value) in proto_metadata.metadata.iter() {
859            let maybe_value: Result<MetadataValue, Self::Error> = value.try_into();
860            if maybe_value.is_err() {
861                return Err(MetadataValueConversionError::InvalidValue);
862            }
863            let value = maybe_value.unwrap();
864            metadata.insert(key.clone(), value);
865        }
866        Ok(metadata)
867    }
868}
869
870impl From<Metadata> for chroma_proto::UpdateMetadata {
871    fn from(metadata: Metadata) -> Self {
872        let mut metadata = metadata;
873        let mut proto_metadata = chroma_proto::UpdateMetadata {
874            metadata: HashMap::new(),
875        };
876        for (key, value) in metadata.drain() {
877            let proto_value = value.into();
878            proto_metadata.metadata.insert(key.clone(), proto_value);
879        }
880        proto_metadata
881    }
882}
883
/// Borrowed description of a metadata change, split into updated, deleted and
/// inserted keys. All keys and values are references with the
/// `'referred_data` lifetime; nothing is owned here.
#[derive(Debug, Default)]
pub struct MetadataDelta<'referred_data> {
    /// Keys whose value changes, each mapped to a pair of values.
    /// NOTE(review): presumably `(old value, new value)` — confirm against the
    /// code that fills this map.
    pub metadata_to_update: HashMap<
        &'referred_data str,
        (&'referred_data MetadataValue, &'referred_data MetadataValue),
    >,
    /// Keys to delete, each mapped to a value (presumably the value being removed).
    pub metadata_to_delete: HashMap<&'referred_data str, &'referred_data MetadataValue>,
    /// Keys to insert, each mapped to a value (presumably the value being added).
    pub metadata_to_insert: HashMap<&'referred_data str, &'referred_data MetadataValue>,
}
893
894impl MetadataDelta<'_> {
895    pub fn new() -> Self {
896        Self::default()
897    }
898}
899
900/*
901===========================================
902Metadata queries
903===========================================
904*/
905
/// Error produced while converting into a `Where` clause: a root cause,
/// optionally wrapped in layers of context.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum WhereConversionError {
    /// Root cause; displayed as `Error: <message>`.
    #[error("Error: {0}")]
    Cause(String),
    /// A context string wrapping an inner error; displayed as
    /// `<context> -> <inner>`.
    #[error("{0} -> {1}")]
    Trace(String, Box<Self>),
}
913
914impl WhereConversionError {
915    pub fn cause(msg: impl ToString) -> Self {
916        Self::Cause(msg.to_string())
917    }
918
919    pub fn trace(self, context: impl ToString) -> Self {
920        Self::Trace(context.to_string(), Box::new(self))
921    }
922}
923
/// This `Where` enum serves as a unified representation for the `where` and `where_document` clauses.
/// Although this is not unified at the API level due to legacy design choices, in the future we will be
/// unifying them together, and the structure of the unified AST should be identical to the one here.
/// Currently both `where` and `where_document` clauses will be translated into `Where`, and if both are
/// present we simply create a conjunction of both clauses as the actual filter. This is consistent with
/// the semantics we used to have when the `where` and `where_document` clauses are treated separately.
// TODO: Remove this note once the `where` clause and `where_document` clause is unified at the API level.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Where {
    /// A boolean combination (`And` / `Or`) of child `Where` clauses.
    Composite(CompositeExpression),
    /// A predicate on the document text (operator plus pattern).
    Document(DocumentExpression),
    /// A predicate on a single metadata key.
    Metadata(MetadataExpression),
}
938
939impl std::fmt::Display for Where {
940    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
941        match self {
942            Where::Composite(composite) => {
943                let fragment = composite
944                    .children
945                    .iter()
946                    .map(|child| format!("{}", child))
947                    .collect::<Vec<_>>()
948                    .join(match composite.operator {
949                        BooleanOperator::And => " & ",
950                        BooleanOperator::Or => " | ",
951                    });
952                write!(f, "({})", fragment)
953            }
954            Where::Metadata(expr) => write!(f, "{}", expr),
955            Where::Document(expr) => write!(f, "{}", expr),
956        }
957    }
958}
959
960impl serde::Serialize for Where {
961    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
962    where
963        S: Serializer,
964    {
965        match self {
966            Where::Composite(composite) => {
967                let mut map = serializer.serialize_map(Some(1))?;
968                let op_key = match composite.operator {
969                    BooleanOperator::And => "$and",
970                    BooleanOperator::Or => "$or",
971                };
972                map.serialize_entry(op_key, &composite.children)?;
973                map.end()
974            }
975            Where::Document(doc) => {
976                let mut outer_map = serializer.serialize_map(Some(1))?;
977                let mut inner_map = serde_json::Map::new();
978                let op_key = match doc.operator {
979                    DocumentOperator::Contains => "$contains",
980                    DocumentOperator::NotContains => "$not_contains",
981                    DocumentOperator::Regex => "$regex",
982                    DocumentOperator::NotRegex => "$not_regex",
983                };
984                inner_map.insert(
985                    op_key.to_string(),
986                    serde_json::Value::String(doc.pattern.clone()),
987                );
988                outer_map.serialize_entry("#document", &inner_map)?;
989                outer_map.end()
990            }
991            Where::Metadata(meta) => {
992                let mut outer_map = serializer.serialize_map(Some(1))?;
993                let mut inner_map = serde_json::Map::new();
994
995                match &meta.comparison {
996                    MetadataComparison::Primitive(op, value) => {
997                        let op_key = match op {
998                            PrimitiveOperator::Equal => "$eq",
999                            PrimitiveOperator::NotEqual => "$ne",
1000                            PrimitiveOperator::GreaterThan => "$gt",
1001                            PrimitiveOperator::GreaterThanOrEqual => "$gte",
1002                            PrimitiveOperator::LessThan => "$lt",
1003                            PrimitiveOperator::LessThanOrEqual => "$lte",
1004                        };
1005                        let value_json =
1006                            serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
1007                        inner_map.insert(op_key.to_string(), value_json);
1008                    }
1009                    MetadataComparison::Set(op, set_value) => {
1010                        let op_key = match op {
1011                            SetOperator::In => "$in",
1012                            SetOperator::NotIn => "$nin",
1013                        };
1014                        let values_json = match set_value {
1015                            MetadataSetValue::Bool(v) => serde_json::to_value(v),
1016                            MetadataSetValue::Int(v) => serde_json::to_value(v),
1017                            MetadataSetValue::Float(v) => serde_json::to_value(v),
1018                            MetadataSetValue::Str(v) => serde_json::to_value(v),
1019                        }
1020                        .map_err(serde::ser::Error::custom)?;
1021                        inner_map.insert(op_key.to_string(), values_json);
1022                    }
1023                }
1024
1025                outer_map.serialize_entry(&meta.key, &inner_map)?;
1026                outer_map.end()
1027            }
1028        }
1029    }
1030}
1031
1032impl From<bool> for Where {
1033    fn from(value: bool) -> Self {
1034        if value {
1035            Where::conjunction(vec![])
1036        } else {
1037            Where::disjunction(vec![])
1038        }
1039    }
1040}
1041
1042impl Where {
1043    pub fn conjunction(children: impl IntoIterator<Item = Where>) -> Self {
1044        // If children.len() == 0, we will return a conjunction that is always true.
1045        // If children.len() == 1, we will return the single child.
1046        // Otherwise, we will return a conjunction of the children.
1047
1048        let mut children: Vec<_> = children
1049            .into_iter()
1050            .flat_map(|expr| {
1051                if let Where::Composite(CompositeExpression {
1052                    operator: BooleanOperator::And,
1053                    children,
1054                }) = expr
1055                {
1056                    return children;
1057                }
1058                vec![expr]
1059            })
1060            .dedup()
1061            .collect();
1062
1063        if children.len() == 1 {
1064            return children.pop().expect("just checked len is 1");
1065        }
1066
1067        Self::Composite(CompositeExpression {
1068            operator: BooleanOperator::And,
1069            children,
1070        })
1071    }
1072    pub fn disjunction(children: impl IntoIterator<Item = Where>) -> Self {
1073        // If children.len() == 0, we will return a disjunction that is always false.
1074        // If children.len() == 1, we will return the single child.
1075        // Otherwise, we will return a disjunction of the children.
1076
1077        let mut children: Vec<_> = children
1078            .into_iter()
1079            .flat_map(|expr| {
1080                if let Where::Composite(CompositeExpression {
1081                    operator: BooleanOperator::Or,
1082                    children,
1083                }) = expr
1084                {
1085                    return children;
1086                }
1087                vec![expr]
1088            })
1089            .dedup()
1090            .collect();
1091
1092        if children.len() == 1 {
1093            return children.pop().expect("just checked len is 1");
1094        }
1095
1096        Self::Composite(CompositeExpression {
1097            operator: BooleanOperator::Or,
1098            children,
1099        })
1100    }
1101
1102    pub fn fts_query_length(&self) -> u64 {
1103        match self {
1104            Where::Composite(composite_expression) => composite_expression
1105                .children
1106                .iter()
1107                .map(Where::fts_query_length)
1108                .sum(),
1109            // The query length is defined to be the number of trigram tokens
1110            Where::Document(document_expression) => {
1111                document_expression.pattern.len().max(3) as u64 - 2
1112            }
1113            Where::Metadata(_) => 0,
1114        }
1115    }
1116
1117    pub fn metadata_predicate_count(&self) -> u64 {
1118        match self {
1119            Where::Composite(composite_expression) => composite_expression
1120                .children
1121                .iter()
1122                .map(Where::metadata_predicate_count)
1123                .sum(),
1124            Where::Document(_) => 0,
1125            Where::Metadata(metadata_expression) => match &metadata_expression.comparison {
1126                MetadataComparison::Primitive(_, _) => 1,
1127                MetadataComparison::Set(_, metadata_set_value) => match metadata_set_value {
1128                    MetadataSetValue::Bool(items) => items.len() as u64,
1129                    MetadataSetValue::Int(items) => items.len() as u64,
1130                    MetadataSetValue::Float(items) => items.len() as u64,
1131                    MetadataSetValue::Str(items) => items.len() as u64,
1132                },
1133            },
1134        }
1135    }
1136}
1137
1138impl BitAnd for Where {
1139    type Output = Where;
1140
1141    fn bitand(self, rhs: Self) -> Self::Output {
1142        Self::conjunction([self, rhs])
1143    }
1144}
1145
1146impl BitOr for Where {
1147    type Output = Where;
1148
1149    fn bitor(self, rhs: Self) -> Self::Output {
1150        Self::disjunction([self, rhs])
1151    }
1152}
1153
1154impl TryFrom<chroma_proto::Where> for Where {
1155    type Error = WhereConversionError;
1156
1157    fn try_from(proto_where: chroma_proto::Where) -> Result<Self, Self::Error> {
1158        let where_inner = proto_where
1159            .r#where
1160            .ok_or(WhereConversionError::cause("Invalid Where"))?;
1161        Ok(match where_inner {
1162            chroma_proto::r#where::Where::DirectComparison(direct_comparison) => {
1163                Self::Metadata(direct_comparison.try_into()?)
1164            }
1165            chroma_proto::r#where::Where::Children(where_children) => {
1166                Self::Composite(where_children.try_into()?)
1167            }
1168            chroma_proto::r#where::Where::DirectDocumentComparison(direct_where_document) => {
1169                Self::Document(direct_where_document.into())
1170            }
1171        })
1172    }
1173}
1174
1175impl TryFrom<Where> for chroma_proto::Where {
1176    type Error = WhereConversionError;
1177
1178    fn try_from(value: Where) -> Result<Self, Self::Error> {
1179        let proto_where = match value {
1180            Where::Composite(composite_expression) => {
1181                chroma_proto::r#where::Where::Children(composite_expression.try_into()?)
1182            }
1183            Where::Document(document_expression) => {
1184                chroma_proto::r#where::Where::DirectDocumentComparison(document_expression.into())
1185            }
1186            Where::Metadata(metadata_expression) => chroma_proto::r#where::Where::DirectComparison(
1187                chroma_proto::DirectComparison::try_from(metadata_expression)
1188                    .map_err(|err| err.trace("MetadataExpression"))?,
1189            ),
1190        };
1191        Ok(Self {
1192            r#where: Some(proto_where),
1193        })
1194    }
1195}
1196
/// A boolean combination of child [`Where`] clauses.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct CompositeExpression {
    /// How the children are combined (`And` or `Or`).
    pub operator: BooleanOperator,
    /// The clauses being combined; may be empty.
    pub children: Vec<Where>,
}
1203
1204impl TryFrom<chroma_proto::WhereChildren> for CompositeExpression {
1205    type Error = WhereConversionError;
1206
1207    fn try_from(proto_children: chroma_proto::WhereChildren) -> Result<Self, Self::Error> {
1208        let operator = proto_children.operator().into();
1209        let children = proto_children
1210            .children
1211            .into_iter()
1212            .map(Where::try_from)
1213            .collect::<Result<Vec<_>, _>>()
1214            .map_err(|err| err.trace("Child Where of CompositeExpression"))?;
1215        Ok(Self { operator, children })
1216    }
1217}
1218
1219impl TryFrom<CompositeExpression> for chroma_proto::WhereChildren {
1220    type Error = WhereConversionError;
1221
1222    fn try_from(value: CompositeExpression) -> Result<Self, Self::Error> {
1223        Ok(Self {
1224            operator: chroma_proto::BooleanOperator::from(value.operator) as i32,
1225            children: value
1226                .children
1227                .into_iter()
1228                .map(chroma_proto::Where::try_from)
1229                .collect::<Result<_, _>>()?,
1230        })
1231    }
1232}
1233
/// Boolean connective used by [`CompositeExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum BooleanOperator {
    /// Logical AND; serialized as `$and` by `Where`'s `Serialize` impl.
    And,
    /// Logical OR; serialized as `$or` by `Where`'s `Serialize` impl.
    Or,
}
1240
1241impl From<chroma_proto::BooleanOperator> for BooleanOperator {
1242    fn from(value: chroma_proto::BooleanOperator) -> Self {
1243        match value {
1244            chroma_proto::BooleanOperator::And => Self::And,
1245            chroma_proto::BooleanOperator::Or => Self::Or,
1246        }
1247    }
1248}
1249
1250impl From<BooleanOperator> for chroma_proto::BooleanOperator {
1251    fn from(value: BooleanOperator) -> Self {
1252        match value {
1253            BooleanOperator::And => Self::And,
1254            BooleanOperator::Or => Self::Or,
1255        }
1256    }
1257}
1258
/// A predicate on the document: an operator applied with a text pattern.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct DocumentExpression {
    /// Which document test to apply (contains / regex and their negations).
    pub operator: DocumentOperator,
    /// The text the operator is applied with.
    pub pattern: String,
}
1265
1266impl std::fmt::Display for DocumentExpression {
1267    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1268        let op_str = match self.operator {
1269            DocumentOperator::Contains => "CONTAINS",
1270            DocumentOperator::NotContains => "NOT CONTAINS",
1271            DocumentOperator::Regex => "REGEX",
1272            DocumentOperator::NotRegex => "NOT REGEX",
1273        };
1274        write!(f, "#document {} \"{}\"", op_str, self.pattern)
1275    }
1276}
1277
1278impl From<chroma_proto::DirectWhereDocument> for DocumentExpression {
1279    fn from(value: chroma_proto::DirectWhereDocument) -> Self {
1280        Self {
1281            operator: value.operator().into(),
1282            pattern: value.pattern,
1283        }
1284    }
1285}
1286
1287impl From<DocumentExpression> for chroma_proto::DirectWhereDocument {
1288    fn from(value: DocumentExpression) -> Self {
1289        Self {
1290            pattern: value.pattern,
1291            operator: chroma_proto::WhereDocumentOperator::from(value.operator) as i32,
1292        }
1293    }
1294}
1295
/// Operators available in a [`DocumentExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum DocumentOperator {
    /// Serialized as `$contains`; displayed as `CONTAINS`.
    Contains,
    /// Serialized as `$not_contains`; displayed as `NOT CONTAINS`.
    NotContains,
    /// Serialized as `$regex`; displayed as `REGEX`.
    Regex,
    /// Serialized as `$not_regex`; displayed as `NOT REGEX`.
    NotRegex,
}
1304impl From<chroma_proto::WhereDocumentOperator> for DocumentOperator {
1305    fn from(value: chroma_proto::WhereDocumentOperator) -> Self {
1306        match value {
1307            chroma_proto::WhereDocumentOperator::Contains => Self::Contains,
1308            chroma_proto::WhereDocumentOperator::NotContains => Self::NotContains,
1309            chroma_proto::WhereDocumentOperator::Regex => Self::Regex,
1310            chroma_proto::WhereDocumentOperator::NotRegex => Self::NotRegex,
1311        }
1312    }
1313}
1314
1315impl From<DocumentOperator> for chroma_proto::WhereDocumentOperator {
1316    fn from(value: DocumentOperator) -> Self {
1317        match value {
1318            DocumentOperator::Contains => Self::Contains,
1319            DocumentOperator::NotContains => Self::NotContains,
1320            DocumentOperator::Regex => Self::Regex,
1321            DocumentOperator::NotRegex => Self::NotRegex,
1322        }
1323    }
1324}
1325
/// A predicate on a single metadata key.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct MetadataExpression {
    /// The metadata key the comparison applies to.
    pub key: String,
    /// The operator and operand(s) to compare the key's value against.
    pub comparison: MetadataComparison,
}
1332
1333impl std::fmt::Display for MetadataExpression {
1334    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1335        match &self.comparison {
1336            MetadataComparison::Primitive(op, value) => {
1337                write!(f, "{} {} {}", self.key, op, value)
1338            }
1339            MetadataComparison::Set(op, set_value) => {
1340                write!(f, "{} {} {}", self.key, op, set_value)
1341            }
1342        }
1343    }
1344}
1345
1346impl TryFrom<chroma_proto::DirectComparison> for MetadataExpression {
1347    type Error = WhereConversionError;
1348
1349    fn try_from(value: chroma_proto::DirectComparison) -> Result<Self, Self::Error> {
1350        let proto_comparison = value
1351            .comparison
1352            .ok_or(WhereConversionError::cause("Invalid MetadataExpression"))?;
1353        let comparison = match proto_comparison {
1354            chroma_proto::direct_comparison::Comparison::SingleStringOperand(
1355                single_string_comparison,
1356            ) => MetadataComparison::Primitive(
1357                single_string_comparison.comparator().into(),
1358                MetadataValue::Str(single_string_comparison.value),
1359            ),
1360            chroma_proto::direct_comparison::Comparison::StringListOperand(
1361                string_list_comparison,
1362            ) => MetadataComparison::Set(
1363                string_list_comparison.list_operator().into(),
1364                MetadataSetValue::Str(string_list_comparison.values),
1365            ),
1366            chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1367                single_int_comparison,
1368            ) => MetadataComparison::Primitive(
1369                match single_int_comparison
1370                    .comparator
1371                    .ok_or(WhereConversionError::cause(
1372                        "Invalid scalar integer operator",
1373                    ))? {
1374                    chroma_proto::single_int_comparison::Comparator::GenericComparator(op) => {
1375                        chroma_proto::GenericComparator::try_from(op)
1376                            .map_err(WhereConversionError::cause)?
1377                            .into()
1378                    }
1379                    chroma_proto::single_int_comparison::Comparator::NumberComparator(op) => {
1380                        chroma_proto::NumberComparator::try_from(op)
1381                            .map_err(WhereConversionError::cause)?
1382                            .into()
1383                    }
1384                },
1385                MetadataValue::Int(single_int_comparison.value),
1386            ),
1387            chroma_proto::direct_comparison::Comparison::IntListOperand(int_list_comparison) => {
1388                MetadataComparison::Set(
1389                    int_list_comparison.list_operator().into(),
1390                    MetadataSetValue::Int(int_list_comparison.values),
1391                )
1392            }
1393            chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(
1394                single_double_comparison,
1395            ) => MetadataComparison::Primitive(
1396                match single_double_comparison
1397                    .comparator
1398                    .ok_or(WhereConversionError::cause("Invalid scalar float operator"))?
1399                {
1400                    chroma_proto::single_double_comparison::Comparator::GenericComparator(op) => {
1401                        chroma_proto::GenericComparator::try_from(op)
1402                            .map_err(WhereConversionError::cause)?
1403                            .into()
1404                    }
1405                    chroma_proto::single_double_comparison::Comparator::NumberComparator(op) => {
1406                        chroma_proto::NumberComparator::try_from(op)
1407                            .map_err(WhereConversionError::cause)?
1408                            .into()
1409                    }
1410                },
1411                MetadataValue::Float(single_double_comparison.value),
1412            ),
1413            chroma_proto::direct_comparison::Comparison::DoubleListOperand(
1414                double_list_comparison,
1415            ) => MetadataComparison::Set(
1416                double_list_comparison.list_operator().into(),
1417                MetadataSetValue::Float(double_list_comparison.values),
1418            ),
1419            chroma_proto::direct_comparison::Comparison::BoolListOperand(bool_list_comparison) => {
1420                MetadataComparison::Set(
1421                    bool_list_comparison.list_operator().into(),
1422                    MetadataSetValue::Bool(bool_list_comparison.values),
1423                )
1424            }
1425            chroma_proto::direct_comparison::Comparison::SingleBoolOperand(
1426                single_bool_comparison,
1427            ) => MetadataComparison::Primitive(
1428                single_bool_comparison.comparator().into(),
1429                MetadataValue::Bool(single_bool_comparison.value),
1430            ),
1431        };
1432        Ok(Self {
1433            key: value.key,
1434            comparison,
1435        })
1436    }
1437}
1438
1439impl TryFrom<MetadataExpression> for chroma_proto::DirectComparison {
1440    type Error = WhereConversionError;
1441
1442    fn try_from(value: MetadataExpression) -> Result<Self, Self::Error> {
1443        let comparison = match value.comparison {
1444            MetadataComparison::Primitive(primitive_operator, metadata_value) => match metadata_value {
1445                MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1446                MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(match primitive_operator {
1447                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1448                                numeric => chroma_proto::single_int_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1449                            }),
1450                MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(match primitive_operator {
1451                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_double_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1452                                numeric => chroma_proto::single_double_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1453                            }),
1454                MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1455                MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Comparison with sparse vector is not supported".to_string())),
1456            },
1457            MetadataComparison::Set(set_operator, metadata_set_value) => match metadata_set_value {
1458                MetadataSetValue::Bool(vec) => chroma_proto::direct_comparison::Comparison::BoolListOperand(chroma_proto::BoolListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1459                MetadataSetValue::Int(vec) => chroma_proto::direct_comparison::Comparison::IntListOperand(chroma_proto::IntListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1460                MetadataSetValue::Float(vec) => chroma_proto::direct_comparison::Comparison::DoubleListOperand(chroma_proto::DoubleListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1461                MetadataSetValue::Str(vec) => chroma_proto::direct_comparison::Comparison::StringListOperand(chroma_proto::StringListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1462            },
1463        };
1464        Ok(Self {
1465            key: value.key,
1466            comparison: Some(comparison),
1467        })
1468    }
1469}
1470
/// The comparison part of a [`MetadataExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum MetadataComparison {
    /// Comparison against a single value (e.g. `$eq`, `$gt`).
    Primitive(PrimitiveOperator, MetadataValue),
    /// Membership test against a list of values (`$in` / `$nin`).
    Set(SetOperator, MetadataSetValue),
}
1477
/// Comparison operators usable against a single metadata value.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum PrimitiveOperator {
    /// Displayed as `=`.
    Equal,
    /// Displayed as `≠`.
    NotEqual,
    /// Displayed as `>`.
    GreaterThan,
    /// Displayed as `≥`.
    GreaterThanOrEqual,
    /// Displayed as `<`.
    LessThan,
    /// Displayed as `≤`.
    LessThanOrEqual,
}

impl std::fmt::Display for PrimitiveOperator {
    /// Writes the operator's mathematical symbol.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Equal => "=",
            Self::NotEqual => "≠",
            Self::LessThan => "<",
            Self::LessThanOrEqual => "≤",
            Self::GreaterThan => ">",
            Self::GreaterThanOrEqual => "≥",
        };
        f.write_str(symbol)
    }
}
1502
1503impl From<chroma_proto::GenericComparator> for PrimitiveOperator {
1504    fn from(value: chroma_proto::GenericComparator) -> Self {
1505        match value {
1506            chroma_proto::GenericComparator::Eq => Self::Equal,
1507            chroma_proto::GenericComparator::Ne => Self::NotEqual,
1508        }
1509    }
1510}
1511
1512impl TryFrom<PrimitiveOperator> for chroma_proto::GenericComparator {
1513    type Error = WhereConversionError;
1514
1515    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1516        match value {
1517            PrimitiveOperator::Equal => Ok(Self::Eq),
1518            PrimitiveOperator::NotEqual => Ok(Self::Ne),
1519            op => Err(WhereConversionError::cause(format!("{op:?} ∉ [=, ≠]"))),
1520        }
1521    }
1522}
1523
1524impl From<chroma_proto::NumberComparator> for PrimitiveOperator {
1525    fn from(value: chroma_proto::NumberComparator) -> Self {
1526        match value {
1527            chroma_proto::NumberComparator::Gt => Self::GreaterThan,
1528            chroma_proto::NumberComparator::Gte => Self::GreaterThanOrEqual,
1529            chroma_proto::NumberComparator::Lt => Self::LessThan,
1530            chroma_proto::NumberComparator::Lte => Self::LessThanOrEqual,
1531        }
1532    }
1533}
1534
1535impl TryFrom<PrimitiveOperator> for chroma_proto::NumberComparator {
1536    type Error = WhereConversionError;
1537
1538    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1539        match value {
1540            PrimitiveOperator::GreaterThan => Ok(Self::Gt),
1541            PrimitiveOperator::GreaterThanOrEqual => Ok(Self::Gte),
1542            PrimitiveOperator::LessThan => Ok(Self::Lt),
1543            PrimitiveOperator::LessThanOrEqual => Ok(Self::Lte),
1544            op => Err(WhereConversionError::cause(format!(
1545                "{op:?} ∉ [≤, <, >, ≥]"
1546            ))),
1547        }
1548    }
1549}
1550
/// Membership operators usable against a list of metadata values.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SetOperator {
    /// Displayed as `∈`.
    In,
    /// Displayed as `∉`.
    NotIn,
}

impl std::fmt::Display for SetOperator {
    /// Writes the operator's set-membership symbol.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::NotIn => "∉",
            Self::In => "∈",
        };
        f.write_str(symbol)
    }
}
1567
1568impl From<chroma_proto::ListOperator> for SetOperator {
1569    fn from(value: chroma_proto::ListOperator) -> Self {
1570        match value {
1571            chroma_proto::ListOperator::In => Self::In,
1572            chroma_proto::ListOperator::Nin => Self::NotIn,
1573        }
1574    }
1575}
1576
1577impl From<SetOperator> for chroma_proto::ListOperator {
1578    fn from(value: SetOperator) -> Self {
1579        match value {
1580            SetOperator::In => Self::In,
1581            SetOperator::NotIn => Self::Nin,
1582        }
1583    }
1584}
1585
/// A homogeneous list of metadata values used as the operand of a set
/// comparison.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum MetadataSetValue {
    /// A list of booleans.
    Bool(Vec<bool>),
    /// A list of 64-bit signed integers.
    Int(Vec<i64>),
    /// A list of 64-bit floats.
    Float(Vec<f64>),
    /// A list of strings.
    Str(Vec<String>),
}

impl std::fmt::Display for MetadataSetValue {
    /// Renders the list as `[a, b, ...]`.
    ///
    /// Booleans and strings are wrapped in double quotes; integers and floats
    /// use their plain `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render each element first, then join them once for every variant.
        let rendered: Vec<String> = match self {
            Self::Bool(values) => values.iter().map(|v| format!("\"{}\"", v)).collect(),
            Self::Int(values) => values.iter().map(ToString::to_string).collect(),
            Self::Float(values) => values.iter().map(ToString::to_string).collect(),
            Self::Str(values) => values.iter().map(|v| format!("\"{}\"", v)).collect(),
        };
        write!(f, "[{}]", rendered.join(", "))
    }
}
1633
1634impl MetadataSetValue {
1635    pub fn value_type(&self) -> MetadataValueType {
1636        match self {
1637            MetadataSetValue::Bool(_) => MetadataValueType::Bool,
1638            MetadataSetValue::Int(_) => MetadataValueType::Int,
1639            MetadataSetValue::Float(_) => MetadataValueType::Float,
1640            MetadataSetValue::Str(_) => MetadataValueType::Str,
1641        }
1642    }
1643}
1644
1645impl From<Vec<bool>> for MetadataSetValue {
1646    fn from(values: Vec<bool>) -> Self {
1647        MetadataSetValue::Bool(values)
1648    }
1649}
1650
1651impl From<Vec<i64>> for MetadataSetValue {
1652    fn from(values: Vec<i64>) -> Self {
1653        MetadataSetValue::Int(values)
1654    }
1655}
1656
1657impl From<Vec<i32>> for MetadataSetValue {
1658    fn from(values: Vec<i32>) -> Self {
1659        MetadataSetValue::Int(values.into_iter().map(|v| v as i64).collect())
1660    }
1661}
1662
1663impl From<Vec<f64>> for MetadataSetValue {
1664    fn from(values: Vec<f64>) -> Self {
1665        MetadataSetValue::Float(values)
1666    }
1667}
1668
1669impl From<Vec<f32>> for MetadataSetValue {
1670    fn from(values: Vec<f32>) -> Self {
1671        MetadataSetValue::Float(values.into_iter().map(|v| v as f64).collect())
1672    }
1673}
1674
1675impl From<Vec<String>> for MetadataSetValue {
1676    fn from(values: Vec<String>) -> Self {
1677        MetadataSetValue::Str(values)
1678    }
1679}
1680
1681impl From<Vec<&str>> for MetadataSetValue {
1682    fn from(values: Vec<&str>) -> Self {
1683        MetadataSetValue::Str(values.into_iter().map(|s| s.to_string()).collect())
1684    }
1685}
1686
1687// TODO: Deprecate where_document
1688impl TryFrom<chroma_proto::WhereDocument> for Where {
1689    type Error = WhereConversionError;
1690
1691    fn try_from(proto_document: chroma_proto::WhereDocument) -> Result<Self, Self::Error> {
1692        match proto_document.r#where_document {
1693            Some(chroma_proto::where_document::WhereDocument::Direct(proto_comparison)) => {
1694                let operator = match TryInto::<chroma_proto::WhereDocumentOperator>::try_into(
1695                    proto_comparison.operator,
1696                ) {
1697                    Ok(operator) => operator,
1698                    Err(_) => {
1699                        return Err(WhereConversionError::cause(
1700                            "[Deprecated] Invalid where document operator",
1701                        ))
1702                    }
1703                };
1704                let comparison = DocumentExpression {
1705                    pattern: proto_comparison.pattern,
1706                    operator: operator.into(),
1707                };
1708                Ok(Where::Document(comparison))
1709            }
1710            Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => {
1711                let operator = match TryInto::<chroma_proto::BooleanOperator>::try_into(
1712                    proto_children.operator,
1713                ) {
1714                    Ok(operator) => operator,
1715                    Err(_) => {
1716                        return Err(WhereConversionError::cause(
1717                            "[Deprecated] Invalid boolean operator",
1718                        ))
1719                    }
1720                };
1721                let children = CompositeExpression {
1722                    children: proto_children
1723                        .children
1724                        .into_iter()
1725                        .map(|child| child.try_into())
1726                        .collect::<Result<_, _>>()?,
1727                    operator: operator.into(),
1728                };
1729                Ok(Where::Composite(children))
1730            }
1731            None => Err(WhereConversionError::cause("[Deprecated] Invalid where")),
1732        }
1733    }
1734}
1735
#[cfg(test)]
mod tests {
    use crate::operator::Key;

    use super::*;

    /// Wraps a protobuf value variant in the `UpdateMetadataValue` message envelope.
    fn proto_value(
        value: chroma_proto::update_metadata_value::Value,
    ) -> chroma_proto::UpdateMetadataValue {
        chroma_proto::UpdateMetadataValue { value: Some(value) }
    }

    /// Builds the protobuf metadata shared by the conversion tests: `foo` (int 42),
    /// `bar` (float 42.0), `baz` (string "42") and `sparse` (the given sparse vector).
    fn sample_proto_metadata(sparse: chroma_proto::SparseVector) -> chroma_proto::UpdateMetadata {
        chroma_proto::UpdateMetadata {
            metadata: HashMap::from([
                (
                    "foo".to_string(),
                    proto_value(chroma_proto::update_metadata_value::Value::IntValue(42)),
                ),
                (
                    "bar".to_string(),
                    proto_value(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
                ),
                (
                    "baz".to_string(),
                    proto_value(chroma_proto::update_metadata_value::Value::StringValue(
                        "42".to_string(),
                    )),
                ),
                (
                    "sparse".to_string(),
                    proto_value(chroma_proto::update_metadata_value::Value::SparseVectorValue(
                        sparse,
                    )),
                ),
            ]),
        }
    }

    /// Builds a protobuf `Where` leaf asserting that the integer field `key` equals 42.
    fn int_equals_42(key: &str) -> chroma_proto::Where {
        let comparator = chroma_proto::single_int_comparison::Comparator::GenericComparator(
            chroma_proto::GenericComparator::Eq as i32,
        );
        let operand = chroma_proto::SingleIntComparison {
            value: 42,
            comparator: Some(comparator),
        };
        chroma_proto::Where {
            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
                chroma_proto::DirectComparison {
                    key: key.to_string(),
                    comparison: Some(
                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(operand),
                    ),
                },
            )),
        }
    }

    /// Builds a protobuf `WhereDocument` leaf using the `Contains` operator on `pattern`.
    fn contains_document(pattern: &str) -> chroma_proto::WhereDocument {
        chroma_proto::WhereDocument {
            where_document: Some(chroma_proto::where_document::WhereDocument::Direct(
                chroma_proto::DirectWhereDocument {
                    pattern: pattern.to_string(),
                    operator: chroma_proto::WhereDocumentOperator::Contains.into(),
                },
            )),
        }
    }

    #[test]
    fn test_update_metadata_try_from() {
        let proto_metadata = sample_proto_metadata(chroma_proto::SparseVector {
            indices: vec![0, 5, 10],
            values: vec![0.1, 0.5, 0.9],
        });

        let converted: UpdateMetadata = proto_metadata.try_into().unwrap();

        assert_eq!(converted.len(), 4);
        assert_eq!(converted.get("foo"), Some(&UpdateMetadataValue::Int(42)));
        assert_eq!(converted.get("bar"), Some(&UpdateMetadataValue::Float(42.0)));
        assert_eq!(
            converted.get("baz"),
            Some(&UpdateMetadataValue::Str("42".to_string()))
        );
        assert_eq!(
            converted.get("sparse"),
            Some(&UpdateMetadataValue::SparseVector(SparseVector::new(
                vec![0, 5, 10],
                vec![0.1, 0.5, 0.9]
            )))
        );
    }

    #[test]
    fn test_metadata_try_from() {
        let proto_metadata = sample_proto_metadata(chroma_proto::SparseVector {
            indices: vec![1, 10, 100],
            values: vec![0.2, 0.4, 0.6],
        });

        let converted: Metadata = proto_metadata.try_into().unwrap();

        assert_eq!(converted.len(), 4);
        assert_eq!(converted.get("foo"), Some(&MetadataValue::Int(42)));
        assert_eq!(converted.get("bar"), Some(&MetadataValue::Float(42.0)));
        assert_eq!(
            converted.get("baz"),
            Some(&MetadataValue::Str("42".to_string()))
        );
        assert_eq!(
            converted.get("sparse"),
            Some(&MetadataValue::SparseVector(SparseVector::new(
                vec![1, 10, 100],
                vec![0.2, 0.4, 0.6]
            )))
        );
    }

    #[test]
    fn test_where_clause_simple_from() {
        let where_clause: Where = int_equals_42("foo").try_into().unwrap();

        let expression = match where_clause {
            Where::Metadata(expression) => expression,
            _ => panic!("Invalid where type"),
        };
        assert_eq!(expression.key, "foo");
        match expression.comparison {
            MetadataComparison::Primitive(_, value) => assert_eq!(value, MetadataValue::Int(42)),
            _ => panic!("Invalid comparison type"),
        }
    }

    #[test]
    fn test_where_clause_with_children() {
        let proto_where = chroma_proto::Where {
            r#where: Some(chroma_proto::r#where::Where::Children(
                chroma_proto::WhereChildren {
                    children: vec![int_equals_42("foo"), int_equals_42("bar")],
                    operator: chroma_proto::BooleanOperator::And.into(),
                },
            )),
        };

        let where_clause: Where = proto_where.try_into().unwrap();

        let composite = match where_clause {
            Where::Composite(composite) => composite,
            _ => panic!("Invalid where type"),
        };
        assert_eq!(composite.children.len(), 2);
        assert_eq!(composite.operator, BooleanOperator::And);
    }

    #[test]
    fn test_where_document_simple() {
        let where_document: Where = contains_document("foo").try_into().unwrap();

        let expression = match where_document {
            Where::Document(expression) => expression,
            _ => panic!("Invalid where document type"),
        };
        assert_eq!(expression.pattern, "foo");
        assert_eq!(expression.operator, DocumentOperator::Contains);
    }

    #[test]
    fn test_where_document_with_children() {
        let proto_where = chroma_proto::WhereDocument {
            where_document: Some(chroma_proto::where_document::WhereDocument::Children(
                chroma_proto::WhereDocumentChildren {
                    children: vec![contains_document("foo"), contains_document("bar")],
                    operator: chroma_proto::BooleanOperator::And.into(),
                },
            )),
        };

        let where_document: Where = proto_where.try_into().unwrap();

        let composite = match where_document {
            Where::Composite(composite) => composite,
            _ => panic!("Invalid where document type"),
        };
        assert_eq!(composite.children.len(), 2);
        assert_eq!(composite.operator, BooleanOperator::And);
    }

    #[test]
    fn test_sparse_vector_new() {
        let expected_indices = vec![0, 5, 10];
        let expected_values = vec![0.1, 0.5, 0.9];

        let vector = SparseVector::new(expected_indices.clone(), expected_values.clone());

        assert_eq!(vector.indices, expected_indices);
        assert_eq!(vector.values, expected_values);
    }

    #[test]
    fn test_sparse_vector_from_pairs() {
        let pairs = vec![(0, 0.1), (5, 0.5), (10, 0.9)];

        let vector = SparseVector::from_pairs(pairs.clone());

        assert_eq!(vector.indices, vec![0, 5, 10]);
        assert_eq!(vector.values, vec![0.1, 0.5, 0.9]);
    }

    #[test]
    fn test_sparse_vector_iter() {
        let vector = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);

        let entries = vector.iter().collect::<Vec<(u32, f32)>>();

        assert_eq!(entries, vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
    }

    #[test]
    fn test_sparse_vector_ordering() {
        let baseline = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
        let identical = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
        let larger_index = SparseVector::new(vec![0, 6], vec![0.1, 0.5]);
        let larger_value = SparseVector::new(vec![0, 5], vec![0.1, 0.6]);

        assert_eq!(baseline, identical);
        assert!(baseline < larger_index);
        assert!(baseline < larger_value);
    }

    #[test]
    fn test_sparse_vector_proto_conversion() {
        let original = SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]);

        // Domain type -> protobuf message.
        let as_proto: chroma_proto::SparseVector = original.clone().into();
        assert_eq!(as_proto.indices, vec![1, 10, 100]);
        assert_eq!(as_proto.values, vec![0.2, 0.4, 0.6]);

        // Protobuf message -> domain type must give back the original.
        let round_tripped: SparseVector = as_proto.into();
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_sparse_vector_logical_size() {
        let vector = SparseVector::new(vec![0, 1, 2, 3, 4], vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        let metadata = Metadata::from([(
            "sparse".to_string(),
            MetadataValue::SparseVector(vector),
        )]);

        // Expected size: key "sparse" (6 bytes) + 5 u32 indices (20 bytes)
        // + 5 f32 values (20 bytes) = 46 bytes.
        assert_eq!(logical_size_of_metadata(&metadata), 46);
    }

    #[test]
    fn test_sparse_vector_validation() {
        // Strictly ascending indices with matching lengths pass validation.
        let well_formed = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]);
        assert!(well_formed.validate().is_ok());

        // One more index than there are values.
        let mismatched = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]);
        assert!(matches!(
            mismatched.validate(),
            Err(MetadataValueConversionError::SparseVectorLengthMismatch)
        ));

        // Every case below breaks the strictly-ascending index requirement:
        // out of order from the start, a duplicated index, and a single descent.
        let unsorted_cases: [(Vec<u32>, Vec<f32>); 3] = [
            (vec![3, 1, 2], vec![0.3, 0.1, 0.2]),
            (vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]),
            (vec![1, 3, 2], vec![0.1, 0.3, 0.2]),
        ];
        for (indices, values) in unsorted_cases {
            let outcome = SparseVector::new(indices, values).validate();
            assert!(matches!(
                outcome,
                Err(MetadataValueConversionError::SparseVectorIndicesNotSorted)
            ));
        }
    }

    #[test]
    fn test_sparse_vector_deserialize_old_format() {
        // Legacy payloads carry no #type tag and must still be accepted.
        let payload = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;

        let vector = serde_json::from_str::<SparseVector>(payload).unwrap();

        assert_eq!(vector.indices, vec![0, 1, 2]);
        assert_eq!(vector.values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sparse_vector_deserialize_new_format() {
        // Current payloads carry the #type tag.
        let payload =
            r##"{"#type": "sparse_vector", "indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"##;

        let vector = serde_json::from_str::<SparseVector>(payload).unwrap();

        assert_eq!(vector.indices, vec![0, 1, 2]);
        assert_eq!(vector.values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sparse_vector_deserialize_new_format_field_order() {
        // The #type tag may appear anywhere among the fields.
        let payload = r##"{"indices": [5, 10], "#type": "sparse_vector", "values": [0.5, 1.0]}"##;

        let vector = serde_json::from_str::<SparseVector>(payload).unwrap();

        assert_eq!(vector.indices, vec![5, 10]);
        assert_eq!(vector.values, vec![0.5, 1.0]);
    }

    #[test]
    fn test_sparse_vector_deserialize_wrong_type_tag() {
        // A #type tag naming anything other than a sparse vector is rejected.
        let payload = r##"{"#type": "dense_vector", "indices": [0, 1], "values": [1.0, 2.0]}"##;

        let error = serde_json::from_str::<SparseVector>(payload).unwrap_err();

        assert!(error.to_string().contains("sparse_vector"));
    }

    #[test]
    fn test_sparse_vector_serialize_always_has_type() {
        // Serialized output must always carry the #type tag.
        let vector = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]);

        let serialized = serde_json::to_value(&vector).unwrap();

        assert_eq!(serialized["#type"], "sparse_vector");
        assert_eq!(serialized["indices"], serde_json::json!([0, 1, 2]));
        assert_eq!(serialized["values"], serde_json::json!([1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_sparse_vector_roundtrip_with_type() {
        // Serializing and then deserializing must give back the same vector.
        let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]);

        let serialized = serde_json::to_string(&original).unwrap();
        // The serialized text itself must contain the #type tag.
        assert!(serialized.contains("\"#type\":\"sparse_vector\""));

        let restored = serde_json::from_str::<SparseVector>(&serialized).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_sparse_vector_in_metadata_old_format() {
        // A legacy (untagged) sparse vector nested inside a metadata object still parses.
        let payload = r#"{"key": "value", "sparse": {"indices": [0, 1], "values": [1.0, 2.0]}}"#;
        let fields: HashMap<String, serde_json::Value> = serde_json::from_str(payload).unwrap();

        let vector: SparseVector = serde_json::from_value(fields["sparse"].clone()).unwrap();

        assert_eq!(vector.indices, vec![0, 1]);
        assert_eq!(vector.values, vec![1.0, 2.0]);
    }

    #[test]
    fn test_sparse_vector_in_metadata_new_format() {
        // A tagged sparse vector nested inside a metadata object parses as well.
        let payload = r##"{"key": "value", "sparse": {"#type": "sparse_vector", "indices": [0, 1], "values": [1.0, 2.0]}}"##;
        let fields: HashMap<String, serde_json::Value> = serde_json::from_str(payload).unwrap();

        let vector: SparseVector = serde_json::from_value(fields["sparse"].clone()).unwrap();

        assert_eq!(vector.indices, vec![0, 1]);
        assert_eq!(vector.values, vec![1.0, 2.0]);
    }

    #[test]
    fn test_simplifies_identities() {
        // `true` combined with itself stays `true` under both operators.
        let match_all: Where = true.into();
        assert_eq!(match_all.clone() & match_all.clone(), true.into());
        assert_eq!(match_all.clone() | match_all.clone(), true.into());

        // AND-ing with `true` leaves the other operand untouched, on either side.
        let predicate = Key::field("foo").eq("bar");
        assert_eq!(predicate.clone() & match_all.clone(), predicate.clone());
        assert_eq!(match_all.clone() & predicate.clone(), predicate.clone());

        // OR-ing with `false` leaves the other operand untouched, on either side.
        let match_none: Where = false.into();
        assert_eq!(predicate.clone() | match_none.clone(), predicate.clone());
        assert_eq!(match_none | predicate.clone(), predicate);
    }

    #[test]
    fn test_flattens() {
        let foo = Key::field("foo").eq("bar");
        let baz = Key::field("baz").eq("quux");
        // Nested same-operator composites are expected to collapse into one flat list.
        let flat_children = vec![foo.clone(), baz.clone(), foo.clone()];

        let and_nested = foo.clone() & (baz.clone() & foo.clone());
        assert_eq!(
            and_nested,
            Where::Composite(CompositeExpression {
                operator: BooleanOperator::And,
                children: flat_children.clone()
            })
        );

        let or_nested = foo.clone() | (baz.clone() | foo.clone());
        assert_eq!(
            or_nested,
            Where::Composite(CompositeExpression {
                operator: BooleanOperator::Or,
                children: flat_children
            })
        );
    }
}