Skip to main content

chroma_types/
metadata.rs

1use chroma_error::{ChromaError, ErrorCodes};
2use itertools::Itertools;
3use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
4use serde_json::{Number, Value};
5use sprs::CsVec;
6use std::{
7    cmp::Ordering,
8    collections::{HashMap, HashSet},
9    mem::size_of_val,
10    ops::{BitAnd, BitOr},
11};
12use thiserror::Error;
13
14use crate::chroma_proto;
15
16#[cfg(feature = "pyo3")]
17use pyo3::types::{PyAnyMethods, PyDictMethods};
18
19#[cfg(feature = "testing")]
20use proptest::prelude::*;
21
/// Wire representation of [`SparseVector`], used by its hand-written `Serialize` and
/// `Deserialize` impls.
///
/// `type_tag` is serialized under the field name `"#type"`. It is optional on input so that
/// payloads written before the tag existed (`{"indices": [...], "values": [...]}`) still
/// deserialize. Field order here determines the serialized field order.
#[derive(Serialize, Deserialize)]
struct SparseVectorSerdeHelper {
    /// Type discriminator; must be `"sparse_vector"` when present.
    #[serde(rename = "#type")]
    type_tag: Option<String>,
    indices: Vec<u32>,
    values: Vec<f32>,
    tokens: Option<Vec<String>>,
}
30
/// Represents a sparse vector using parallel arrays for indices and values.
///
/// On deserialization: accepts both old format `{"indices": [...], "values": [...]}`
/// and new format `{"#type": "sparse_vector", "indices": [...], "values": [...]}`.
///
/// On serialization: always includes `#type` field with value `"sparse_vector"`.
/// `tokens` is always written too (as `null` in JSON when absent).
///
/// The fields are public, so the type itself does not enforce its invariants:
/// `indices`, `values` and (when present) `tokens` must have equal lengths, and `indices`
/// must be strictly ascending. `new` / `new_with_tokens` check the lengths only;
/// `validate` checks lengths and ordering.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SparseVector {
    /// Dimension indices (expected to be strictly ascending; see `validate`)
    pub indices: Vec<u32>,
    /// Values corresponding to each index
    pub values: Vec<f32>,
    /// Optional tokens corresponding to each index
    pub tokens: Option<Vec<String>>,
}
47
48// Custom deserializer: accept both old and new formats
49impl<'de> Deserialize<'de> for SparseVector {
50    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51    where
52        D: Deserializer<'de>,
53    {
54        let helper = SparseVectorSerdeHelper::deserialize(deserializer)?;
55
56        // If #type is present, validate it
57        if let Some(type_tag) = &helper.type_tag {
58            if type_tag != "sparse_vector" {
59                return Err(serde::de::Error::custom(format!(
60                    "Expected #type='sparse_vector', got '{}'",
61                    type_tag
62                )));
63            }
64        }
65
66        Ok(SparseVector {
67            indices: helper.indices,
68            values: helper.values,
69            tokens: helper.tokens,
70        })
71    }
72}
73
74// Custom serializer: always include #type field
75impl Serialize for SparseVector {
76    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77    where
78        S: Serializer,
79    {
80        let helper = SparseVectorSerdeHelper {
81            type_tag: Some("sparse_vector".to_string()),
82            indices: self.indices.clone(),
83            values: self.values.clone(),
84            tokens: self.tokens.clone(),
85        };
86        helper.serialize(serializer)
87    }
88}
89
/// Error for a [`SparseVector`] whose parallel arrays disagree in length: `indices` and
/// `values` must match, and so must `tokens` when it is present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SparseVectorLengthMismatch;

impl std::fmt::Display for SparseVectorLengthMismatch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(
            "Sparse vector indices, values, and tokens (when present) must have the same length",
        )
    }
}

impl std::error::Error for SparseVectorLengthMismatch {}
104
/// A length mismatch describes malformed input arrays, so it is reported as
/// `ErrorCodes::InvalidArgument`.
impl ChromaError for SparseVectorLengthMismatch {
    fn code(&self) -> ErrorCodes {
        ErrorCodes::InvalidArgument
    }
}
110
111impl SparseVector {
112    /// Create a new sparse vector from parallel arrays.
113    pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Result<Self, SparseVectorLengthMismatch> {
114        if indices.len() != values.len() {
115            return Err(SparseVectorLengthMismatch);
116        }
117        Ok(Self {
118            indices,
119            values,
120            tokens: None,
121        })
122    }
123
124    /// Create a new sparse vector from parallel arrays.
125    pub fn new_with_tokens(
126        indices: Vec<u32>,
127        values: Vec<f32>,
128        tokens: Vec<String>,
129    ) -> Result<Self, SparseVectorLengthMismatch> {
130        if indices.len() != values.len() {
131            return Err(SparseVectorLengthMismatch);
132        }
133        if tokens.len() != indices.len() {
134            return Err(SparseVectorLengthMismatch);
135        }
136        Ok(Self {
137            indices,
138            values,
139            tokens: Some(tokens),
140        })
141    }
142
143    /// Create a sparse vector from an iterator of (index, value) pairs.
144    pub fn from_pairs(pairs: impl IntoIterator<Item = (u32, f32)>) -> Self {
145        let mut indices = vec![];
146        let mut values = vec![];
147        for (index, value) in pairs {
148            indices.push(index);
149            values.push(value);
150        }
151        let tokens = None;
152        Self {
153            indices,
154            values,
155            tokens,
156        }
157    }
158
159    /// Create a sparse vector from an iterator of (string, index, value) pairs.
160    pub fn from_triples(triples: impl IntoIterator<Item = (String, u32, f32)>) -> Self {
161        let mut tokens = vec![];
162        let mut indices = vec![];
163        let mut values = vec![];
164        for (token, index, value) in triples {
165            tokens.push(token);
166            indices.push(index);
167            values.push(value);
168        }
169        let tokens = Some(tokens);
170        Self {
171            indices,
172            values,
173            tokens,
174        }
175    }
176
177    /// Iterate over (index, value) pairs.
178    pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
179        self.indices
180            .iter()
181            .copied()
182            .zip(self.values.iter().copied())
183    }
184
185    /// Validate the sparse vector
186    pub fn validate(&self) -> Result<(), MetadataValueConversionError> {
187        // Check that indices and values have the same length
188        if self.indices.len() != self.values.len() {
189            return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
190        }
191
192        // Check that tokens (if present) align with indices
193        if let Some(tokens) = self.tokens.as_ref() {
194            if tokens.len() != self.indices.len() {
195                return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
196            }
197        }
198
199        // Check that indices are sorted in strictly ascending order (no duplicates)
200        for i in 1..self.indices.len() {
201            if self.indices[i] <= self.indices[i - 1] {
202                return Err(MetadataValueConversionError::SparseVectorIndicesNotSorted);
203            }
204        }
205
206        Ok(())
207    }
208}
209
210impl Eq for SparseVector {}
211
212impl Ord for SparseVector {
213    fn cmp(&self, other: &Self) -> Ordering {
214        self.indices.cmp(&other.indices).then_with(|| {
215            for (a, b) in self.values.iter().zip(other.values.iter()) {
216                match a.total_cmp(b) {
217                    Ordering::Equal => continue,
218                    other => return other,
219                }
220            }
221            self.values.len().cmp(&other.values.len())
222        })
223    }
224}
225
226impl PartialOrd for SparseVector {
227    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
228        Some(self.cmp(other))
229    }
230}
231
232impl TryFrom<chroma_proto::SparseVector> for SparseVector {
233    type Error = SparseVectorLengthMismatch;
234
235    fn try_from(proto: chroma_proto::SparseVector) -> Result<Self, Self::Error> {
236        if proto.tokens.is_empty() {
237            SparseVector::new(proto.indices, proto.values)
238        } else {
239            SparseVector::new_with_tokens(proto.indices, proto.values, proto.tokens)
240        }
241    }
242}
243
244impl From<SparseVector> for chroma_proto::SparseVector {
245    fn from(sparse: SparseVector) -> Self {
246        chroma_proto::SparseVector {
247            indices: sparse.indices,
248            values: sparse.values,
249            tokens: sparse.tokens.unwrap_or_default(),
250        }
251    }
252}
253
254/// Convert SparseVector to sprs::CsVec for efficient sparse operations
255impl From<&SparseVector> for CsVec<f32> {
256    fn from(sparse: &SparseVector) -> Self {
257        let (indices, values) = sparse
258            .iter()
259            .map(|(index, value)| (index as usize, value))
260            .unzip();
261        CsVec::new(u32::MAX as usize, indices, values)
262    }
263}
264
265impl From<SparseVector> for CsVec<f32> {
266    fn from(sparse: SparseVector) -> Self {
267        (&sparse).into()
268    }
269}
270
#[cfg(feature = "pyo3")]
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    /// Convert into a Python `dict` with keys `indices`, `values` and `tokens`, inserted in
    /// that order (`tokens` becomes `None` when absent).
    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        let SparseVector {
            indices,
            values,
            tokens,
        } = self;

        let out = pyo3::types::PyDict::new(py);
        out.set_item("indices", indices)?;
        out.set_item("values", values)?;
        out.set_item("tokens", tokens)?;
        Ok(out.into_any())
    }
}
287
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for SparseVector {
    /// Extract a `SparseVector` from a Python `dict` with keys `indices`, `values` and
    /// optionally `tokens`.
    ///
    /// - A non-dict object fails with the downcast error.
    /// - A missing `indices` or `values` key raises `KeyError`.
    /// - A missing `tokens` key, or `tokens=None`, yields a vector without tokens.
    /// - Arrays of different lengths raise `ValueError`.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::exceptions::{PyKeyError, PyValueError};
        use pyo3::types::PyDict;

        let dict = ob.downcast::<PyDict>()?;

        let indices: Vec<u32> = dict
            .get_item("indices")?
            .ok_or_else(|| PyKeyError::new_err("missing 'indices' key"))?
            .extract()?;

        let values: Vec<f32> = dict
            .get_item("values")?
            .ok_or_else(|| PyKeyError::new_err("missing 'values' key"))?
            .extract()?;

        // Both an absent key and an explicit Python `None` mean "no tokens".
        let tokens: Option<Vec<String>> = match dict.get_item("tokens")? {
            Some(obj) if !obj.is_none() => Some(obj.extract()?),
            _ => None,
        };

        let result = match tokens {
            Some(tokens) => SparseVector::new_with_tokens(indices, values, tokens),
            None => SparseVector::new(indices, values),
        };

        result.map_err(|e| PyValueError::new_err(e.to_string()))
    }
}
325
/// A metadata value that may also be `None`.
///
/// Same shape as `MetadataValue`, plus the [`UpdateMetadataValue::None`] variant. Converting
/// `None` into a `MetadataValue` fails with `InvalidValue`. In the protobuf form `None` is an
/// absent value, which the proto decoding documents as a request to delete the key.
///
/// `#[serde(untagged)]`: serde tries the variants in declaration order and keeps the first
/// one that deserializes successfully, so the variant order below is significant.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[serde(untagged)]
pub enum UpdateMetadataValue {
    Bool(bool),
    Int(i64),
    // Proptest draws floats from a bounded f32 range and widens them to f64.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| UpdateMetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
    // Array types for multi-valued metadata fields
    // TODO: Add support for these in proptests
    #[cfg_attr(feature = "testing", proptest(skip))]
    BoolArray(Vec<bool>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    IntArray(Vec<i64>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    FloatArray(Vec<f64>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    StringArray(Vec<String>),
    // No value; see the type-level docs for how it is interpreted.
    None,
}
355
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for UpdateMetadataValue {
    /// Infer an `UpdateMetadataValue` from a Python object.
    ///
    /// Python `None` maps to [`UpdateMetadataValue::None`]. Otherwise candidates are tried in
    /// a fixed order: `bool`, `int`, `float`, `str`, a sparse-vector dict, then `list`. The
    /// first successful extraction wins. `bool` is tried before `int` because a Python `bool`
    /// is also an `int`.
    ///
    /// Lists must be non-empty and are typed as a whole, trying `bool`, `int`, `float` and
    /// then `str` elements, so a mixed-numeric list like `[1, 2.5, 3]` becomes a `FloatArray`.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::exceptions::{PyTypeError, PyValueError};
        use pyo3::types::PyList;

        if ob.is_none() {
            return Ok(UpdateMetadataValue::None);
        }
        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(UpdateMetadataValue::Bool(flag));
        }
        if let Ok(number) = ob.extract::<i64>() {
            return Ok(UpdateMetadataValue::Int(number));
        }
        if let Ok(number) = ob.extract::<f64>() {
            return Ok(UpdateMetadataValue::Float(number));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(UpdateMetadataValue::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(UpdateMetadataValue::SparseVector(sparse));
        }

        let list = match ob.downcast::<PyList>() {
            Ok(list) => list,
            Err(_) => {
                return Err(PyTypeError::new_err(
                    "Cannot convert Python object to UpdateMetadataValue",
                ))
            }
        };

        // Empty lists are not allowed
        if list.is_empty()? {
            return Err(PyValueError::new_err(
                "Empty lists are not allowed as metadata values",
            ));
        }

        if let Ok(items) = list.extract::<Vec<bool>>() {
            return Ok(UpdateMetadataValue::BoolArray(items));
        }
        if let Ok(items) = list.extract::<Vec<i64>>() {
            return Ok(UpdateMetadataValue::IntArray(items));
        }
        if let Ok(items) = list.extract::<Vec<f64>>() {
            return Ok(UpdateMetadataValue::FloatArray(items));
        }
        if let Ok(items) = list.extract::<Vec<String>>() {
            return Ok(UpdateMetadataValue::StringArray(items));
        }

        Err(PyTypeError::new_err(
            "Cannot convert Python list to UpdateMetadataValue: mixed or unsupported element types",
        ))
    }
}
403
404impl From<bool> for UpdateMetadataValue {
405    fn from(b: bool) -> Self {
406        Self::Bool(b)
407    }
408}
409
410impl From<i64> for UpdateMetadataValue {
411    fn from(v: i64) -> Self {
412        Self::Int(v)
413    }
414}
415
416impl From<i32> for UpdateMetadataValue {
417    fn from(v: i32) -> Self {
418        Self::Int(v as i64)
419    }
420}
421
422impl From<f64> for UpdateMetadataValue {
423    fn from(v: f64) -> Self {
424        Self::Float(v)
425    }
426}
427
428impl From<f32> for UpdateMetadataValue {
429    fn from(v: f32) -> Self {
430        Self::Float(v as f64)
431    }
432}
433
434impl From<String> for UpdateMetadataValue {
435    fn from(v: String) -> Self {
436        Self::Str(v)
437    }
438}
439
440impl From<&str> for UpdateMetadataValue {
441    fn from(v: &str) -> Self {
442        Self::Str(v.to_string())
443    }
444}
445
446impl From<SparseVector> for UpdateMetadataValue {
447    fn from(v: SparseVector) -> Self {
448        Self::SparseVector(v)
449    }
450}
451
452impl From<Vec<bool>> for UpdateMetadataValue {
453    fn from(v: Vec<bool>) -> Self {
454        Self::BoolArray(v)
455    }
456}
457
458impl From<Vec<i64>> for UpdateMetadataValue {
459    fn from(v: Vec<i64>) -> Self {
460        Self::IntArray(v)
461    }
462}
463
464impl From<Vec<f64>> for UpdateMetadataValue {
465    fn from(v: Vec<f64>) -> Self {
466        Self::FloatArray(v)
467    }
468}
469
470impl From<Vec<String>> for UpdateMetadataValue {
471    fn from(v: Vec<String>) -> Self {
472        Self::StringArray(v)
473    }
474}
475
476#[derive(Error, Debug)]
477pub enum UpdateMetadataValueConversionError {
478    #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")]
479    InvalidValue,
480}
481
482impl ChromaError for UpdateMetadataValueConversionError {
483    fn code(&self) -> ErrorCodes {
484        match self {
485            UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
486        }
487    }
488}
489
490impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue {
491    type Error = UpdateMetadataValueConversionError;
492
493    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
494        match &value.value {
495            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
496                Ok(UpdateMetadataValue::Bool(*value))
497            }
498            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
499                Ok(UpdateMetadataValue::Int(*value))
500            }
501            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
502                Ok(UpdateMetadataValue::Float(*value))
503            }
504            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
505                Ok(UpdateMetadataValue::Str(value.clone()))
506            }
507            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
508                let sparse = value
509                    .clone()
510                    .try_into()
511                    .map_err(|_| UpdateMetadataValueConversionError::InvalidValue)?;
512                Ok(UpdateMetadataValue::SparseVector(sparse))
513            }
514            Some(chroma_proto::update_metadata_value::Value::BoolListValue(value)) => {
515                Ok(UpdateMetadataValue::BoolArray(value.values.clone()))
516            }
517            Some(chroma_proto::update_metadata_value::Value::IntListValue(value)) => {
518                Ok(UpdateMetadataValue::IntArray(value.values.clone()))
519            }
520            Some(chroma_proto::update_metadata_value::Value::DoubleListValue(value)) => {
521                Ok(UpdateMetadataValue::FloatArray(value.values.clone()))
522            }
523            Some(chroma_proto::update_metadata_value::Value::StringListValue(value)) => {
524                Ok(UpdateMetadataValue::StringArray(value.values.clone()))
525            }
526            // Used to communicate that the user wants to delete this key.
527            None => Ok(UpdateMetadataValue::None),
528        }
529    }
530}
531
532impl From<UpdateMetadataValue> for chroma_proto::UpdateMetadataValue {
533    fn from(value: UpdateMetadataValue) -> Self {
534        match value {
535            UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
536                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
537            },
538            UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
539                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
540            },
541            UpdateMetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
542                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
543                    value,
544                )),
545            },
546            UpdateMetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
547                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
548                    value,
549                )),
550            },
551            UpdateMetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
552                value: Some(
553                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
554                        sparse_vec.into(),
555                    ),
556                ),
557            },
558            UpdateMetadataValue::BoolArray(values) => chroma_proto::UpdateMetadataValue {
559                value: Some(chroma_proto::update_metadata_value::Value::BoolListValue(
560                    chroma_proto::BoolListValue { values },
561                )),
562            },
563            UpdateMetadataValue::IntArray(values) => chroma_proto::UpdateMetadataValue {
564                value: Some(chroma_proto::update_metadata_value::Value::IntListValue(
565                    chroma_proto::IntListValue { values },
566                )),
567            },
568            UpdateMetadataValue::FloatArray(values) => chroma_proto::UpdateMetadataValue {
569                value: Some(chroma_proto::update_metadata_value::Value::DoubleListValue(
570                    chroma_proto::DoubleListValue { values },
571                )),
572            },
573            UpdateMetadataValue::StringArray(values) => chroma_proto::UpdateMetadataValue {
574                value: Some(chroma_proto::update_metadata_value::Value::StringListValue(
575                    chroma_proto::StringListValue { values },
576                )),
577            },
578            UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None },
579        }
580    }
581}
582
583impl TryFrom<&UpdateMetadataValue> for MetadataValue {
584    type Error = MetadataValueConversionError;
585
586    fn try_from(value: &UpdateMetadataValue) -> Result<Self, Self::Error> {
587        match value {
588            UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)),
589            UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)),
590            UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)),
591            UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())),
592            UpdateMetadataValue::SparseVector(value) => {
593                Ok(MetadataValue::SparseVector(value.clone()))
594            }
595            UpdateMetadataValue::BoolArray(value) => Ok(MetadataValue::BoolArray(value.clone())),
596            UpdateMetadataValue::IntArray(value) => Ok(MetadataValue::IntArray(value.clone())),
597            UpdateMetadataValue::FloatArray(value) => Ok(MetadataValue::FloatArray(value.clone())),
598            UpdateMetadataValue::StringArray(value) => {
599                Ok(MetadataValue::StringArray(value.clone()))
600            }
601            UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue),
602        }
603    }
604}
605
606/*
607===========================================
608MetadataValue
609===========================================
610*/
611
/// A single metadata value: a scalar (`bool`, `i64`, `f64`, `String`), a [`SparseVector`],
/// or an array of scalars of one type. Unlike `UpdateMetadataValue`, it has no `None`
/// variant.
///
/// `#[serde(untagged)]`: serde tries the variants in declaration order and keeps the first
/// one that deserializes successfully, so the variant order below is significant.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[cfg_attr(feature = "pyo3", derive(pyo3::IntoPyObject))]
#[serde(untagged)]
pub enum MetadataValue {
    Bool(bool),
    Int(i64),
    // Proptest draws floats from a bounded f32 range and widens them to f64.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| MetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
    // Array types for multi-valued metadata fields
    // TODO: Add support for these in proptests
    #[cfg_attr(feature = "testing", proptest(skip))]
    BoolArray(Vec<bool>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    IntArray(Vec<i64>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    FloatArray(Vec<f64>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    StringArray(Vec<String>),
}
641
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for MetadataValue {
    /// Infer a `MetadataValue` from a Python object.
    ///
    /// Candidates are tried in a fixed order: `bool`, `int`, `float`, `str`, a sparse-vector
    /// dict, then `list`. The first successful extraction wins. `bool` is tried before `int`
    /// because a Python `bool` is also an `int`.
    ///
    /// Lists must be non-empty and are typed as a whole, trying `bool`, `int`, `float` and
    /// then `str` elements, so a mixed-numeric list like `[1, 2.5, 3]` becomes a `FloatArray`.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::exceptions::{PyTypeError, PyValueError};
        use pyo3::types::PyList;

        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(MetadataValue::Bool(flag));
        }
        if let Ok(number) = ob.extract::<i64>() {
            return Ok(MetadataValue::Int(number));
        }
        if let Ok(number) = ob.extract::<f64>() {
            return Ok(MetadataValue::Float(number));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(MetadataValue::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(MetadataValue::SparseVector(sparse));
        }

        let list = match ob.downcast::<PyList>() {
            Ok(list) => list,
            Err(_) => {
                return Err(PyTypeError::new_err(
                    "Cannot convert Python object to MetadataValue",
                ))
            }
        };

        // Empty lists are not allowed
        if list.is_empty()? {
            return Err(PyValueError::new_err(
                "Empty lists are not allowed as metadata values",
            ));
        }

        if let Ok(items) = list.extract::<Vec<bool>>() {
            return Ok(MetadataValue::BoolArray(items));
        }
        if let Ok(items) = list.extract::<Vec<i64>>() {
            return Ok(MetadataValue::IntArray(items));
        }
        if let Ok(items) = list.extract::<Vec<f64>>() {
            return Ok(MetadataValue::FloatArray(items));
        }
        if let Ok(items) = list.extract::<Vec<String>>() {
            return Ok(MetadataValue::StringArray(items));
        }

        Err(PyTypeError::new_err(
            "Cannot convert Python list to MetadataValue: mixed or unsupported element types",
        ))
    }
}
687
688impl std::fmt::Display for MetadataValue {
689    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
690        match self {
691            MetadataValue::Bool(v) => write!(f, "{}", v),
692            MetadataValue::Int(v) => write!(f, "{}", v),
693            MetadataValue::Float(v) => write!(f, "{}", v),
694            MetadataValue::Str(v) => write!(f, "\"{}\"", v),
695            MetadataValue::SparseVector(v) => write!(f, "SparseVector(len={})", v.values.len()),
696            MetadataValue::BoolArray(v) => write!(f, "BoolArray(len={})", v.len()),
697            MetadataValue::IntArray(v) => write!(f, "IntArray(len={})", v.len()),
698            MetadataValue::FloatArray(v) => write!(f, "FloatArray(len={})", v.len()),
699            MetadataValue::StringArray(v) => write!(f, "StringArray(len={})", v.len()),
700        }
701    }
702}
703
// `Eq` is asserted manually because `f64` (in `Float` / `FloatArray`) is only `PartialEq`;
// it is required for the `Ord` impl, which lets `MetadataValue` be a `BTreeMap` key.
// NOTE(review): the derived `PartialEq` uses IEEE `==` (`NaN != NaN`, `0.0 == -0.0`) while
// `Ord` uses `total_cmp`, so equality and ordering disagree for those float values.
impl Eq for MetadataValue {}
705
/// Payload-free tag for a [`MetadataValue`]: one variant per `MetadataValue` variant, as
/// returned by `MetadataValue::value_type`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MetadataValueType {
    Bool,
    Int,
    Float,
    Str,
    SparseVector,
    BoolArray,
    IntArray,
    FloatArray,
    StringArray,
}
718
719impl MetadataValue {
720    pub fn value_type(&self) -> MetadataValueType {
721        match self {
722            MetadataValue::Bool(_) => MetadataValueType::Bool,
723            MetadataValue::Int(_) => MetadataValueType::Int,
724            MetadataValue::Float(_) => MetadataValueType::Float,
725            MetadataValue::Str(_) => MetadataValueType::Str,
726            MetadataValue::SparseVector(_) => MetadataValueType::SparseVector,
727            MetadataValue::BoolArray(_) => MetadataValueType::BoolArray,
728            MetadataValue::IntArray(_) => MetadataValueType::IntArray,
729            MetadataValue::FloatArray(_) => MetadataValueType::FloatArray,
730            MetadataValue::StringArray(_) => MetadataValueType::StringArray,
731        }
732    }
733}
734
735impl From<&MetadataValue> for MetadataValueType {
736    fn from(value: &MetadataValue) -> Self {
737        value.value_type()
738    }
739}
740
741impl From<bool> for MetadataValue {
742    fn from(v: bool) -> Self {
743        MetadataValue::Bool(v)
744    }
745}
746
747impl From<i64> for MetadataValue {
748    fn from(v: i64) -> Self {
749        MetadataValue::Int(v)
750    }
751}
752
753impl From<i32> for MetadataValue {
754    fn from(v: i32) -> Self {
755        MetadataValue::Int(v as i64)
756    }
757}
758
759impl From<f64> for MetadataValue {
760    fn from(v: f64) -> Self {
761        MetadataValue::Float(v)
762    }
763}
764
765impl From<f32> for MetadataValue {
766    fn from(v: f32) -> Self {
767        MetadataValue::Float(v as f64)
768    }
769}
770
771impl From<String> for MetadataValue {
772    fn from(v: String) -> Self {
773        MetadataValue::Str(v)
774    }
775}
776
777impl From<&str> for MetadataValue {
778    fn from(v: &str) -> Self {
779        MetadataValue::Str(v.to_string())
780    }
781}
782
783impl From<SparseVector> for MetadataValue {
784    fn from(v: SparseVector) -> Self {
785        MetadataValue::SparseVector(v)
786    }
787}
788
789impl From<Vec<bool>> for MetadataValue {
790    fn from(v: Vec<bool>) -> Self {
791        MetadataValue::BoolArray(v)
792    }
793}
794
795impl From<Vec<i64>> for MetadataValue {
796    fn from(v: Vec<i64>) -> Self {
797        MetadataValue::IntArray(v)
798    }
799}
800
801impl From<Vec<i32>> for MetadataValue {
802    fn from(v: Vec<i32>) -> Self {
803        MetadataValue::IntArray(v.into_iter().map(|x| x as i64).collect())
804    }
805}
806
807impl From<Vec<f64>> for MetadataValue {
808    fn from(v: Vec<f64>) -> Self {
809        MetadataValue::FloatArray(v)
810    }
811}
812
813impl From<Vec<f32>> for MetadataValue {
814    fn from(v: Vec<f32>) -> Self {
815        MetadataValue::FloatArray(v.into_iter().map(|x| x as f64).collect())
816    }
817}
818
819impl From<Vec<String>> for MetadataValue {
820    fn from(v: Vec<String>) -> Self {
821        MetadataValue::StringArray(v)
822    }
823}
824
825impl From<Vec<&str>> for MetadataValue {
826    fn from(v: Vec<&str>) -> Self {
827        MetadataValue::StringArray(v.into_iter().map(|s| s.to_string()).collect())
828    }
829}
830
831/// We need `Eq` and `Ord` since we want to use this as a key in `BTreeMap`
832///
833/// For cross-type comparisons, we define a consistent ordering based on variant position:
834/// Bool < Int < Float < Str < SparseVector < BoolArray < IntArray < FloatArray < StringArray
835#[allow(clippy::derive_ord_xor_partial_ord)]
836impl Ord for MetadataValue {
837    fn cmp(&self, other: &Self) -> Ordering {
838        // Define type ordering based on variant position
839        fn type_order(val: &MetadataValue) -> u8 {
840            match val {
841                MetadataValue::Bool(_) => 0,
842                MetadataValue::Int(_) => 1,
843                MetadataValue::Float(_) => 2,
844                MetadataValue::Str(_) => 3,
845                MetadataValue::SparseVector(_) => 4,
846                MetadataValue::BoolArray(_) => 5,
847                MetadataValue::IntArray(_) => 6,
848                MetadataValue::FloatArray(_) => 7,
849                MetadataValue::StringArray(_) => 8,
850            }
851        }
852
853        // Chain type ordering with value ordering
854        type_order(self).cmp(&type_order(other)).then_with(|| {
855            match (self, other) {
856                (MetadataValue::Bool(left), MetadataValue::Bool(right)) => left.cmp(right),
857                (MetadataValue::Int(left), MetadataValue::Int(right)) => left.cmp(right),
858                (MetadataValue::Float(left), MetadataValue::Float(right)) => left.total_cmp(right),
859                (MetadataValue::Str(left), MetadataValue::Str(right)) => left.cmp(right),
860                (MetadataValue::SparseVector(left), MetadataValue::SparseVector(right)) => {
861                    left.cmp(right)
862                }
863                (MetadataValue::BoolArray(left), MetadataValue::BoolArray(right)) => {
864                    left.cmp(right)
865                }
866                (MetadataValue::IntArray(left), MetadataValue::IntArray(right)) => left.cmp(right),
867                (MetadataValue::FloatArray(left), MetadataValue::FloatArray(right)) => {
868                    // Compare element by element using total_cmp for f64
869                    for (l, r) in left.iter().zip(right.iter()) {
870                        match l.total_cmp(r) {
871                            Ordering::Equal => continue,
872                            other => return other,
873                        }
874                    }
875                    left.len().cmp(&right.len())
876                }
877                (MetadataValue::StringArray(left), MetadataValue::StringArray(right)) => {
878                    left.cmp(right)
879                }
880                _ => Ordering::Equal, // Different types, but type_order already handled this
881            }
882        })
883    }
884}
885
886impl PartialOrd for MetadataValue {
887    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
888        Some(self.cmp(other))
889    }
890}
891
892impl TryFrom<&MetadataValue> for bool {
893    type Error = MetadataValueConversionError;
894
895    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
896        match value {
897            MetadataValue::Bool(value) => Ok(*value),
898            _ => Err(MetadataValueConversionError::InvalidValue),
899        }
900    }
901}
902
903impl TryFrom<&MetadataValue> for i64 {
904    type Error = MetadataValueConversionError;
905
906    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
907        match value {
908            MetadataValue::Int(value) => Ok(*value),
909            _ => Err(MetadataValueConversionError::InvalidValue),
910        }
911    }
912}
913
914impl TryFrom<&MetadataValue> for f64 {
915    type Error = MetadataValueConversionError;
916
917    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
918        match value {
919            MetadataValue::Float(value) => Ok(*value),
920            _ => Err(MetadataValueConversionError::InvalidValue),
921        }
922    }
923}
924
925impl TryFrom<&MetadataValue> for String {
926    type Error = MetadataValueConversionError;
927
928    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
929        match value {
930            MetadataValue::Str(value) => Ok(value.clone()),
931            _ => Err(MetadataValueConversionError::InvalidValue),
932        }
933    }
934}
935
936impl From<MetadataValue> for UpdateMetadataValue {
937    fn from(value: MetadataValue) -> Self {
938        match value {
939            MetadataValue::Bool(v) => UpdateMetadataValue::Bool(v),
940            MetadataValue::Int(v) => UpdateMetadataValue::Int(v),
941            MetadataValue::Float(v) => UpdateMetadataValue::Float(v),
942            MetadataValue::Str(v) => UpdateMetadataValue::Str(v),
943            MetadataValue::SparseVector(v) => UpdateMetadataValue::SparseVector(v),
944            MetadataValue::BoolArray(v) => UpdateMetadataValue::BoolArray(v),
945            MetadataValue::IntArray(v) => UpdateMetadataValue::IntArray(v),
946            MetadataValue::FloatArray(v) => UpdateMetadataValue::FloatArray(v),
947            MetadataValue::StringArray(v) => UpdateMetadataValue::StringArray(v),
948        }
949    }
950}
951
952impl From<MetadataValue> for Value {
953    fn from(value: MetadataValue) -> Self {
954        match value {
955            MetadataValue::Bool(val) => Self::Bool(val),
956            MetadataValue::Int(val) => Self::Number(
957                Number::from_i128(val as i128).expect("i64 should be representable in JSON"),
958            ),
959            MetadataValue::Float(val) => Self::Number(
960                Number::from_f64(val).expect("Inf and NaN should not be present in MetadataValue"),
961            ),
962            MetadataValue::Str(val) => Self::String(val),
963            MetadataValue::SparseVector(val) => {
964                let mut map = serde_json::Map::new();
965                map.insert(
966                    "indices".to_string(),
967                    Value::Array(
968                        val.indices
969                            .iter()
970                            .map(|&i| Value::Number(i.into()))
971                            .collect(),
972                    ),
973                );
974                map.insert(
975                    "values".to_string(),
976                    Value::Array(
977                        val.values
978                            .iter()
979                            .map(|&v| {
980                                Value::Number(
981                                    Number::from_f64(v as f64)
982                                        .expect("Float number should not be NaN or infinite"),
983                                )
984                            })
985                            .collect(),
986                    ),
987                );
988                Self::Object(map)
989            }
990            MetadataValue::BoolArray(vals) => {
991                Self::Array(vals.into_iter().map(Value::Bool).collect())
992            }
993            MetadataValue::IntArray(vals) => Self::Array(
994                vals.into_iter()
995                    .map(|v| {
996                        Value::Number(
997                            Number::from_i128(v as i128)
998                                .expect("i64 should be representable in JSON"),
999                        )
1000                    })
1001                    .collect(),
1002            ),
1003            MetadataValue::FloatArray(vals) => Self::Array(
1004                vals.into_iter()
1005                    .map(|v| {
1006                        Value::Number(
1007                            Number::from_f64(v)
1008                                .expect("Inf and NaN should not be present in MetadataValue"),
1009                        )
1010                    })
1011                    .collect(),
1012            ),
1013            MetadataValue::StringArray(vals) => {
1014                Self::Array(vals.into_iter().map(Value::String).collect())
1015            }
1016        }
1017    }
1018}
1019
1020#[derive(Error, Debug)]
1021pub enum MetadataValueConversionError {
1022    #[error("Invalid metadata value, valid values are: Int, Float, Str")]
1023    InvalidValue,
1024    #[error("Metadata key cannot start with '#' or '$': {0}")]
1025    InvalidKey(String),
1026    #[error("Sparse vector indices, values, and tokens (when present) must have the same length")]
1027    SparseVectorLengthMismatch,
1028    #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")]
1029    SparseVectorIndicesNotSorted,
1030}
1031
1032impl ChromaError for MetadataValueConversionError {
1033    fn code(&self) -> ErrorCodes {
1034        match self {
1035            MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
1036            MetadataValueConversionError::InvalidKey(_) => ErrorCodes::InvalidArgument,
1037            MetadataValueConversionError::SparseVectorLengthMismatch => ErrorCodes::InvalidArgument,
1038            MetadataValueConversionError::SparseVectorIndicesNotSorted => {
1039                ErrorCodes::InvalidArgument
1040            }
1041        }
1042    }
1043}
1044
1045impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue {
1046    type Error = MetadataValueConversionError;
1047
1048    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
1049        match &value.value {
1050            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
1051                Ok(MetadataValue::Bool(*value))
1052            }
1053            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
1054                Ok(MetadataValue::Int(*value))
1055            }
1056            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
1057                Ok(MetadataValue::Float(*value))
1058            }
1059            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
1060                Ok(MetadataValue::Str(value.clone()))
1061            }
1062            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
1063                let sparse = value
1064                    .clone()
1065                    .try_into()
1066                    .map_err(|_| MetadataValueConversionError::SparseVectorLengthMismatch)?;
1067                Ok(MetadataValue::SparseVector(sparse))
1068            }
1069            Some(chroma_proto::update_metadata_value::Value::BoolListValue(value)) => {
1070                Ok(MetadataValue::BoolArray(value.values.clone()))
1071            }
1072            Some(chroma_proto::update_metadata_value::Value::IntListValue(value)) => {
1073                Ok(MetadataValue::IntArray(value.values.clone()))
1074            }
1075            Some(chroma_proto::update_metadata_value::Value::DoubleListValue(value)) => {
1076                Ok(MetadataValue::FloatArray(value.values.clone()))
1077            }
1078            Some(chroma_proto::update_metadata_value::Value::StringListValue(value)) => {
1079                Ok(MetadataValue::StringArray(value.values.clone()))
1080            }
1081            _ => Err(MetadataValueConversionError::InvalidValue),
1082        }
1083    }
1084}
1085
1086impl From<MetadataValue> for chroma_proto::UpdateMetadataValue {
1087    fn from(value: MetadataValue) -> Self {
1088        match value {
1089            MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
1090                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
1091            },
1092            MetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
1093                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
1094                    value,
1095                )),
1096            },
1097            MetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
1098                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1099                    value,
1100                )),
1101            },
1102            MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
1103                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
1104            },
1105            MetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
1106                value: Some(
1107                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
1108                        sparse_vec.into(),
1109                    ),
1110                ),
1111            },
1112            MetadataValue::BoolArray(values) => chroma_proto::UpdateMetadataValue {
1113                value: Some(chroma_proto::update_metadata_value::Value::BoolListValue(
1114                    chroma_proto::BoolListValue { values },
1115                )),
1116            },
1117            MetadataValue::IntArray(values) => chroma_proto::UpdateMetadataValue {
1118                value: Some(chroma_proto::update_metadata_value::Value::IntListValue(
1119                    chroma_proto::IntListValue { values },
1120                )),
1121            },
1122            MetadataValue::FloatArray(values) => chroma_proto::UpdateMetadataValue {
1123                value: Some(chroma_proto::update_metadata_value::Value::DoubleListValue(
1124                    chroma_proto::DoubleListValue { values },
1125                )),
1126            },
1127            MetadataValue::StringArray(values) => chroma_proto::UpdateMetadataValue {
1128                value: Some(chroma_proto::update_metadata_value::Value::StringListValue(
1129                    chroma_proto::StringListValue { values },
1130                )),
1131            },
1132        }
1133    }
1134}
1135
1136/*
1137===========================================
1138UpdateMetadata
1139===========================================
1140*/
/// Metadata map from key to `UpdateMetadataValue`.
///
/// Every `MetadataValue` variant has a same-named `UpdateMetadataValue` counterpart
/// (see `From<MetadataValue> for UpdateMetadataValue`).
pub type UpdateMetadata = HashMap<String, UpdateMetadataValue>;
1142
1143/**
1144 * Check if two metadata are close to equal. Ignores small differences in float values.
1145 */
1146pub fn are_update_metadatas_close_to_equal(
1147    metadata1: &UpdateMetadata,
1148    metadata2: &UpdateMetadata,
1149) -> bool {
1150    assert_eq!(metadata1.len(), metadata2.len());
1151
1152    for (key, value) in metadata1.iter() {
1153        if !metadata2.contains_key(key) {
1154            return false;
1155        }
1156        let other_value = metadata2.get(key).unwrap();
1157
1158        if let (UpdateMetadataValue::Float(value), UpdateMetadataValue::Float(other_value)) =
1159            (value, other_value)
1160        {
1161            if (value - other_value).abs() > 1e-6 {
1162                return false;
1163            }
1164        } else if value != other_value {
1165            return false;
1166        }
1167    }
1168
1169    true
1170}
1171
1172pub fn are_metadatas_close_to_equal(metadata1: &Metadata, metadata2: &Metadata) -> bool {
1173    assert_eq!(metadata1.len(), metadata2.len());
1174
1175    for (key, value) in metadata1.iter() {
1176        if !metadata2.contains_key(key) {
1177            return false;
1178        }
1179        let other_value = metadata2.get(key).unwrap();
1180
1181        if let (MetadataValue::Float(value), MetadataValue::Float(other_value)) =
1182            (value, other_value)
1183        {
1184            if (value - other_value).abs() > 1e-6 {
1185                return false;
1186            }
1187        } else if value != other_value {
1188            return false;
1189        }
1190    }
1191
1192    true
1193}
1194
1195impl TryFrom<chroma_proto::UpdateMetadata> for UpdateMetadata {
1196    type Error = UpdateMetadataValueConversionError;
1197
1198    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
1199        let mut metadata = UpdateMetadata::with_capacity(proto_metadata.metadata.len());
1200        for (key, value) in proto_metadata.metadata.into_iter() {
1201            let value = match (&value).try_into() {
1202                Ok(value) => value,
1203                Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue),
1204            };
1205            metadata.insert(key, value);
1206        }
1207        Ok(metadata)
1208    }
1209}
1210
1211impl From<UpdateMetadata> for chroma_proto::UpdateMetadata {
1212    fn from(metadata: UpdateMetadata) -> Self {
1213        let mut proto_metadata = chroma_proto::UpdateMetadata {
1214            metadata: HashMap::with_capacity(metadata.len()),
1215        };
1216        for (key, value) in metadata.into_iter() {
1217            let proto_value = value.into();
1218            proto_metadata.metadata.insert(key, proto_value);
1219        }
1220        proto_metadata
1221    }
1222}
1223
1224/*
1225===========================================
1226Metadata
1227===========================================
1228*/
1229
/// Metadata map from key to `MetadataValue`.
pub type Metadata = HashMap<String, MetadataValue>;
/// A set of metadata keys. Presumably these are keys removed from a record's metadata;
/// confirm with callers.
pub type DeletedMetadata = HashSet<String>;
1232
1233pub fn logical_size_of_metadata(metadata: &Metadata) -> usize {
1234    metadata
1235        .iter()
1236        .map(|(k, v)| {
1237            k.len()
1238                + match v {
1239                    MetadataValue::Bool(b) => size_of_val(b),
1240                    MetadataValue::Int(i) => size_of_val(i),
1241                    MetadataValue::Float(f) => size_of_val(f),
1242                    MetadataValue::Str(s) => s.len(),
1243                    MetadataValue::SparseVector(v) => {
1244                        size_of_val(&v.indices[..]) + size_of_val(&v.values[..])
1245                    }
1246                    MetadataValue::BoolArray(arr) => size_of_val(&arr[..]),
1247                    MetadataValue::IntArray(arr) => size_of_val(&arr[..]),
1248                    MetadataValue::FloatArray(arr) => size_of_val(&arr[..]),
1249                    MetadataValue::StringArray(arr) => arr.iter().map(|s| s.len()).sum::<usize>(),
1250                }
1251        })
1252        .sum()
1253}
1254
1255pub fn get_metadata_value_as<'a, T>(
1256    metadata: &'a Metadata,
1257    key: &str,
1258) -> Result<T, Box<MetadataValueConversionError>>
1259where
1260    T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
1261{
1262    let res = match metadata.get(key) {
1263        Some(value) => T::try_from(value),
1264        None => return Err(Box::new(MetadataValueConversionError::InvalidValue)),
1265    };
1266    match res {
1267        Ok(value) => Ok(value),
1268        Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)),
1269    }
1270}
1271
1272impl TryFrom<chroma_proto::UpdateMetadata> for Metadata {
1273    type Error = MetadataValueConversionError;
1274
1275    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
1276        let mut metadata = Metadata::new();
1277        for (key, value) in proto_metadata.metadata.iter() {
1278            let maybe_value: Result<MetadataValue, Self::Error> = value.try_into();
1279            if maybe_value.is_err() {
1280                return Err(MetadataValueConversionError::InvalidValue);
1281            }
1282            let value = maybe_value.unwrap();
1283            metadata.insert(key.clone(), value);
1284        }
1285        Ok(metadata)
1286    }
1287}
1288
1289impl From<Metadata> for chroma_proto::UpdateMetadata {
1290    fn from(metadata: Metadata) -> Self {
1291        let mut metadata = metadata;
1292        let mut proto_metadata = chroma_proto::UpdateMetadata {
1293            metadata: HashMap::new(),
1294        };
1295        for (key, value) in metadata.drain() {
1296            let proto_value = value.into();
1297            proto_metadata.metadata.insert(key.clone(), proto_value);
1298        }
1299        proto_metadata
1300    }
1301}
1302
/// Per-key metadata differences. Keys and values are borrowed (lifetime `'referred_data`)
/// from metadata owned elsewhere.
#[derive(Debug, Default)]
pub struct MetadataDelta<'referred_data> {
    /// Keys whose value changed, mapped to a pair of values.
    /// NOTE(review): presumably `(old, new)`; confirm the tuple order against the producer.
    pub metadata_to_update: HashMap<
        &'referred_data str,
        (&'referred_data MetadataValue, &'referred_data MetadataValue),
    >,
    /// Keys to delete, mapped to a value (presumably the value being removed; confirm).
    pub metadata_to_delete: HashMap<&'referred_data str, &'referred_data MetadataValue>,
    /// Keys to insert, mapped to their value.
    pub metadata_to_insert: HashMap<&'referred_data str, &'referred_data MetadataValue>,
}
1312
1313impl MetadataDelta<'_> {
1314    pub fn new() -> Self {
1315        Self::default()
1316    }
1317}
1318
1319/*
1320===========================================
1321Metadata queries
1322===========================================
1323*/
1324
/// Error produced when converting `Where` filter expressions (e.g. to or from their protobuf
/// representation).
///
/// Errors form a chain. `Cause` is the innermost failure, and each `Trace` frame wraps an inner
/// error with a context label. `Display` renders the chain as `context -> ... -> Error: msg`.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum WhereConversionError {
    /// Root failure; displayed as `Error: <message>`.
    #[error("Error: {0}")]
    Cause(String),
    /// A context label wrapping an inner error; displayed as `<context> -> <inner>`.
    #[error("{0} -> {1}")]
    Trace(String, Box<Self>),
}
1332
1333impl WhereConversionError {
1334    pub fn cause(msg: impl ToString) -> Self {
1335        Self::Cause(msg.to_string())
1336    }
1337
1338    pub fn trace(self, context: impl ToString) -> Self {
1339        Self::Trace(context.to_string(), Box::new(self))
1340    }
1341}
1342
/// This `Where` enum serves as a unified representation for the `where` and `where_document` clauses.
/// Although this is not unified at the API level due to legacy design choices, in the future we will be
/// unifying them together, and the structure of the unified AST should be identical to the one here.
/// Currently both `where` and `where_document` clauses will be translated into `Where`, and if both are
/// present we simply create a conjunction of both clauses as the actual filter. This is consistent with
/// the semantics we used to have when the `where` and `where_document` clauses were treated separately.
// TODO: Remove this note once the `where` clause and `where_document` clause are unified at the API level.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Where {
    /// A boolean combination (AND / OR) of child clauses.
    Composite(CompositeExpression),
    /// A predicate on the document text (from `where_document`).
    Document(DocumentExpression),
    /// A predicate on a single metadata key (from `where`).
    Metadata(MetadataExpression),
}
1357
1358impl std::fmt::Display for Where {
1359    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1360        match self {
1361            Where::Composite(composite) => {
1362                let fragment = composite
1363                    .children
1364                    .iter()
1365                    .map(|child| format!("{}", child))
1366                    .collect::<Vec<_>>()
1367                    .join(match composite.operator {
1368                        BooleanOperator::And => " & ",
1369                        BooleanOperator::Or => " | ",
1370                    });
1371                write!(f, "({})", fragment)
1372            }
1373            Where::Metadata(expr) => write!(f, "{}", expr),
1374            Where::Document(expr) => write!(f, "{}", expr),
1375        }
1376    }
1377}
1378
1379impl serde::Serialize for Where {
1380    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1381    where
1382        S: Serializer,
1383    {
1384        match self {
1385            Where::Composite(composite) => {
1386                let mut map = serializer.serialize_map(Some(1))?;
1387                let op_key = match composite.operator {
1388                    BooleanOperator::And => "$and",
1389                    BooleanOperator::Or => "$or",
1390                };
1391                map.serialize_entry(op_key, &composite.children)?;
1392                map.end()
1393            }
1394            Where::Document(doc) => {
1395                let mut outer_map = serializer.serialize_map(Some(1))?;
1396                let mut inner_map = serde_json::Map::new();
1397                let op_key = match doc.operator {
1398                    DocumentOperator::Contains => "$contains",
1399                    DocumentOperator::NotContains => "$not_contains",
1400                    DocumentOperator::Regex => "$regex",
1401                    DocumentOperator::NotRegex => "$not_regex",
1402                };
1403                inner_map.insert(
1404                    op_key.to_string(),
1405                    serde_json::Value::String(doc.pattern.clone()),
1406                );
1407                outer_map.serialize_entry("#document", &inner_map)?;
1408                outer_map.end()
1409            }
1410            Where::Metadata(meta) => {
1411                let mut outer_map = serializer.serialize_map(Some(1))?;
1412                let mut inner_map = serde_json::Map::new();
1413
1414                match &meta.comparison {
1415                    MetadataComparison::Primitive(op, value) => {
1416                        let op_key = match op {
1417                            PrimitiveOperator::Equal => "$eq",
1418                            PrimitiveOperator::NotEqual => "$ne",
1419                            PrimitiveOperator::GreaterThan => "$gt",
1420                            PrimitiveOperator::GreaterThanOrEqual => "$gte",
1421                            PrimitiveOperator::LessThan => "$lt",
1422                            PrimitiveOperator::LessThanOrEqual => "$lte",
1423                        };
1424                        let value_json =
1425                            serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
1426                        inner_map.insert(op_key.to_string(), value_json);
1427                    }
1428                    MetadataComparison::Set(op, set_value) => {
1429                        let op_key = match op {
1430                            SetOperator::In => "$in",
1431                            SetOperator::NotIn => "$nin",
1432                        };
1433                        let values_json = match set_value {
1434                            MetadataSetValue::Bool(v) => serde_json::to_value(v),
1435                            MetadataSetValue::Int(v) => serde_json::to_value(v),
1436                            MetadataSetValue::Float(v) => serde_json::to_value(v),
1437                            MetadataSetValue::Str(v) => serde_json::to_value(v),
1438                        }
1439                        .map_err(serde::ser::Error::custom)?;
1440                        inner_map.insert(op_key.to_string(), values_json);
1441                    }
1442                    MetadataComparison::ArrayContains(op, value) => {
1443                        let op_key = match op {
1444                            ContainsOperator::Contains => "$contains",
1445                            ContainsOperator::NotContains => "$not_contains",
1446                        };
1447                        let value_json =
1448                            serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
1449                        inner_map.insert(op_key.to_string(), value_json);
1450                    }
1451                }
1452
1453                outer_map.serialize_entry(&meta.key, &inner_map)?;
1454                outer_map.end()
1455            }
1456        }
1457    }
1458}
1459
1460impl From<bool> for Where {
1461    fn from(value: bool) -> Self {
1462        if value {
1463            Where::conjunction(vec![])
1464        } else {
1465            Where::disjunction(vec![])
1466        }
1467    }
1468}
1469
1470impl Where {
1471    pub fn conjunction(children: impl IntoIterator<Item = Where>) -> Self {
1472        // If children.len() == 0, we will return a conjunction that is always true.
1473        // If children.len() == 1, we will return the single child.
1474        // Otherwise, we will return a conjunction of the children.
1475
1476        let mut children: Vec<_> = children
1477            .into_iter()
1478            .flat_map(|expr| {
1479                if let Where::Composite(CompositeExpression {
1480                    operator: BooleanOperator::And,
1481                    children,
1482                }) = expr
1483                {
1484                    return children;
1485                }
1486                vec![expr]
1487            })
1488            .dedup()
1489            .collect();
1490
1491        if children.len() == 1 {
1492            return children.pop().expect("just checked len is 1");
1493        }
1494
1495        Self::Composite(CompositeExpression {
1496            operator: BooleanOperator::And,
1497            children,
1498        })
1499    }
1500    pub fn disjunction(children: impl IntoIterator<Item = Where>) -> Self {
1501        // If children.len() == 0, we will return a disjunction that is always false.
1502        // If children.len() == 1, we will return the single child.
1503        // Otherwise, we will return a disjunction of the children.
1504
1505        let mut children: Vec<_> = children
1506            .into_iter()
1507            .flat_map(|expr| {
1508                if let Where::Composite(CompositeExpression {
1509                    operator: BooleanOperator::Or,
1510                    children,
1511                }) = expr
1512                {
1513                    return children;
1514                }
1515                vec![expr]
1516            })
1517            .dedup()
1518            .collect();
1519
1520        if children.len() == 1 {
1521            return children.pop().expect("just checked len is 1");
1522        }
1523
1524        Self::Composite(CompositeExpression {
1525            operator: BooleanOperator::Or,
1526            children,
1527        })
1528    }
1529
1530    pub fn fts_query_length(&self) -> u64 {
1531        match self {
1532            Where::Composite(composite_expression) => composite_expression
1533                .children
1534                .iter()
1535                .map(Where::fts_query_length)
1536                .sum(),
1537            // The query length is defined to be the number of trigram tokens
1538            Where::Document(document_expression) => {
1539                document_expression.pattern.len().max(3) as u64 - 2
1540            }
1541            Where::Metadata(_) => 0,
1542        }
1543    }
1544
1545    pub fn metadata_predicate_count(&self) -> u64 {
1546        match self {
1547            Where::Composite(composite_expression) => composite_expression
1548                .children
1549                .iter()
1550                .map(Where::metadata_predicate_count)
1551                .sum(),
1552            Where::Document(_) => 0,
1553            Where::Metadata(metadata_expression) => match &metadata_expression.comparison {
1554                MetadataComparison::Primitive(_, _) => 1,
1555                MetadataComparison::Set(_, metadata_set_value) => match metadata_set_value {
1556                    MetadataSetValue::Bool(items) => items.len() as u64,
1557                    MetadataSetValue::Int(items) => items.len() as u64,
1558                    MetadataSetValue::Float(items) => items.len() as u64,
1559                    MetadataSetValue::Str(items) => items.len() as u64,
1560                },
1561                MetadataComparison::ArrayContains(_, _) => 1,
1562            },
1563        }
1564    }
1565}
1566
1567impl BitAnd for Where {
1568    type Output = Where;
1569
1570    fn bitand(self, rhs: Self) -> Self::Output {
1571        Self::conjunction([self, rhs])
1572    }
1573}
1574
1575impl BitOr for Where {
1576    type Output = Where;
1577
1578    fn bitor(self, rhs: Self) -> Self::Output {
1579        Self::disjunction([self, rhs])
1580    }
1581}
1582
1583impl TryFrom<chroma_proto::Where> for Where {
1584    type Error = WhereConversionError;
1585
1586    fn try_from(proto_where: chroma_proto::Where) -> Result<Self, Self::Error> {
1587        let where_inner = proto_where
1588            .r#where
1589            .ok_or(WhereConversionError::cause("Invalid Where"))?;
1590        Ok(match where_inner {
1591            chroma_proto::r#where::Where::DirectComparison(direct_comparison) => {
1592                Self::Metadata(direct_comparison.try_into()?)
1593            }
1594            chroma_proto::r#where::Where::Children(where_children) => {
1595                Self::Composite(where_children.try_into()?)
1596            }
1597            chroma_proto::r#where::Where::DirectDocumentComparison(direct_where_document) => {
1598                Self::Document(direct_where_document.into())
1599            }
1600        })
1601    }
1602}
1603
1604impl TryFrom<Where> for chroma_proto::Where {
1605    type Error = WhereConversionError;
1606
1607    fn try_from(value: Where) -> Result<Self, Self::Error> {
1608        let proto_where = match value {
1609            Where::Composite(composite_expression) => {
1610                chroma_proto::r#where::Where::Children(composite_expression.try_into()?)
1611            }
1612            Where::Document(document_expression) => {
1613                chroma_proto::r#where::Where::DirectDocumentComparison(document_expression.into())
1614            }
1615            Where::Metadata(metadata_expression) => chroma_proto::r#where::Where::DirectComparison(
1616                chroma_proto::DirectComparison::try_from(metadata_expression)
1617                    .map_err(|err| err.trace("MetadataExpression"))?,
1618            ),
1619        };
1620        Ok(Self {
1621            r#where: Some(proto_where),
1622        })
1623    }
1624}
1625
/// A boolean combination of child `Where` clauses.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct CompositeExpression {
    /// How the children are combined (AND / OR).
    pub operator: BooleanOperator,
    /// The combined sub-clauses. An empty list is always true for AND and always false for OR.
    pub children: Vec<Where>,
}
1632
1633impl TryFrom<chroma_proto::WhereChildren> for CompositeExpression {
1634    type Error = WhereConversionError;
1635
1636    fn try_from(proto_children: chroma_proto::WhereChildren) -> Result<Self, Self::Error> {
1637        let operator = proto_children.operator().into();
1638        let children = proto_children
1639            .children
1640            .into_iter()
1641            .map(Where::try_from)
1642            .collect::<Result<Vec<_>, _>>()
1643            .map_err(|err| err.trace("Child Where of CompositeExpression"))?;
1644        Ok(Self { operator, children })
1645    }
1646}
1647
1648impl TryFrom<CompositeExpression> for chroma_proto::WhereChildren {
1649    type Error = WhereConversionError;
1650
1651    fn try_from(value: CompositeExpression) -> Result<Self, Self::Error> {
1652        Ok(Self {
1653            operator: chroma_proto::BooleanOperator::from(value.operator) as i32,
1654            children: value
1655                .children
1656                .into_iter()
1657                .map(chroma_proto::Where::try_from)
1658                .collect::<Result<_, _>>()?,
1659        })
1660    }
1661}
1662
1663#[derive(Clone, Debug, PartialEq)]
1664#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1665pub enum BooleanOperator {
1666    And,
1667    Or,
1668}
1669
1670impl From<chroma_proto::BooleanOperator> for BooleanOperator {
1671    fn from(value: chroma_proto::BooleanOperator) -> Self {
1672        match value {
1673            chroma_proto::BooleanOperator::And => Self::And,
1674            chroma_proto::BooleanOperator::Or => Self::Or,
1675        }
1676    }
1677}
1678
1679impl From<BooleanOperator> for chroma_proto::BooleanOperator {
1680    fn from(value: BooleanOperator) -> Self {
1681        match value {
1682            BooleanOperator::And => Self::And,
1683            BooleanOperator::Or => Self::Or,
1684        }
1685    }
1686}
1687
1688#[derive(Clone, Debug, PartialEq)]
1689#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1690pub struct DocumentExpression {
1691    pub operator: DocumentOperator,
1692    pub pattern: String,
1693}
1694
1695impl std::fmt::Display for DocumentExpression {
1696    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1697        let op_str = match self.operator {
1698            DocumentOperator::Contains => "CONTAINS",
1699            DocumentOperator::NotContains => "NOT CONTAINS",
1700            DocumentOperator::Regex => "REGEX",
1701            DocumentOperator::NotRegex => "NOT REGEX",
1702        };
1703        write!(f, "#document {} \"{}\"", op_str, self.pattern)
1704    }
1705}
1706
1707impl From<chroma_proto::DirectWhereDocument> for DocumentExpression {
1708    fn from(value: chroma_proto::DirectWhereDocument) -> Self {
1709        Self {
1710            operator: value.operator().into(),
1711            pattern: value.pattern,
1712        }
1713    }
1714}
1715
1716impl From<DocumentExpression> for chroma_proto::DirectWhereDocument {
1717    fn from(value: DocumentExpression) -> Self {
1718        Self {
1719            pattern: value.pattern,
1720            operator: chroma_proto::WhereDocumentOperator::from(value.operator) as i32,
1721        }
1722    }
1723}
1724
1725#[derive(Clone, Debug, PartialEq)]
1726#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1727pub enum DocumentOperator {
1728    Contains,
1729    NotContains,
1730    Regex,
1731    NotRegex,
1732}
1733impl From<chroma_proto::WhereDocumentOperator> for DocumentOperator {
1734    fn from(value: chroma_proto::WhereDocumentOperator) -> Self {
1735        match value {
1736            chroma_proto::WhereDocumentOperator::Contains => Self::Contains,
1737            chroma_proto::WhereDocumentOperator::NotContains => Self::NotContains,
1738            chroma_proto::WhereDocumentOperator::Regex => Self::Regex,
1739            chroma_proto::WhereDocumentOperator::NotRegex => Self::NotRegex,
1740        }
1741    }
1742}
1743
1744impl From<DocumentOperator> for chroma_proto::WhereDocumentOperator {
1745    fn from(value: DocumentOperator) -> Self {
1746        match value {
1747            DocumentOperator::Contains => Self::Contains,
1748            DocumentOperator::NotContains => Self::NotContains,
1749            DocumentOperator::Regex => Self::Regex,
1750            DocumentOperator::NotRegex => Self::NotRegex,
1751        }
1752    }
1753}
1754
1755#[derive(Clone, Debug, PartialEq)]
1756#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1757pub struct MetadataExpression {
1758    pub key: String,
1759    pub comparison: MetadataComparison,
1760}
1761
1762impl std::fmt::Display for MetadataExpression {
1763    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1764        match &self.comparison {
1765            MetadataComparison::Primitive(op, value) => {
1766                write!(f, "{} {} {}", self.key, op, value)
1767            }
1768            MetadataComparison::Set(op, set_value) => {
1769                write!(f, "{} {} {}", self.key, op, set_value)
1770            }
1771            MetadataComparison::ArrayContains(op, value) => {
1772                write!(f, "{} {} {}", self.key, op, value)
1773            }
1774        }
1775    }
1776}
1777
1778/// Helper to convert a `GenericComparator` and a `MetadataValue` into either a
1779/// `MetadataComparison::Primitive` (for EQ/NE) or `MetadataComparison::Contains`
1780/// (for CONTAINS/NOT_CONTAINS).
1781fn generic_comparator_to_metadata_comparison(
1782    comparator: chroma_proto::GenericComparator,
1783    value: MetadataValue,
1784) -> MetadataComparison {
1785    match comparator {
1786        chroma_proto::GenericComparator::Eq | chroma_proto::GenericComparator::Ne => {
1787            // SAFETY: We just matched Eq | Ne, so try_into() a
1788            // PrimitiveOperator will always succeed.
1789            MetadataComparison::Primitive(comparator.try_into().unwrap(), value)
1790        }
1791        chroma_proto::GenericComparator::ArrayContains => {
1792            MetadataComparison::ArrayContains(ContainsOperator::Contains, value)
1793        }
1794        chroma_proto::GenericComparator::ArrayNotContains => {
1795            MetadataComparison::ArrayContains(ContainsOperator::NotContains, value)
1796        }
1797    }
1798}
1799
1800impl TryFrom<chroma_proto::DirectComparison> for MetadataExpression {
1801    type Error = WhereConversionError;
1802
1803    fn try_from(value: chroma_proto::DirectComparison) -> Result<Self, Self::Error> {
1804        let proto_comparison = value
1805            .comparison
1806            .ok_or(WhereConversionError::cause("Invalid MetadataExpression"))?;
1807        let comparison = match proto_comparison {
1808            chroma_proto::direct_comparison::Comparison::SingleStringOperand(
1809                single_string_comparison,
1810            ) => generic_comparator_to_metadata_comparison(
1811                single_string_comparison.comparator(),
1812                MetadataValue::Str(single_string_comparison.value),
1813            ),
1814            chroma_proto::direct_comparison::Comparison::StringListOperand(
1815                string_list_comparison,
1816            ) => MetadataComparison::Set(
1817                string_list_comparison.list_operator().into(),
1818                MetadataSetValue::Str(string_list_comparison.values),
1819            ),
1820            chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1821                single_int_comparison,
1822            ) => {
1823                let comparator =
1824                    single_int_comparison
1825                        .comparator
1826                        .ok_or(WhereConversionError::cause(
1827                            "Invalid scalar integer operator",
1828                        ))?;
1829                let value = MetadataValue::Int(single_int_comparison.value);
1830                match comparator {
1831                    chroma_proto::single_int_comparison::Comparator::GenericComparator(op) => {
1832                        let generic = chroma_proto::GenericComparator::try_from(op)
1833                            .map_err(WhereConversionError::cause)?;
1834                        generic_comparator_to_metadata_comparison(generic, value)
1835                    }
1836                    chroma_proto::single_int_comparison::Comparator::NumberComparator(op) => {
1837                        MetadataComparison::Primitive(
1838                            chroma_proto::NumberComparator::try_from(op)
1839                                .map_err(WhereConversionError::cause)?
1840                                .into(),
1841                            value,
1842                        )
1843                    }
1844                }
1845            }
1846            chroma_proto::direct_comparison::Comparison::IntListOperand(int_list_comparison) => {
1847                MetadataComparison::Set(
1848                    int_list_comparison.list_operator().into(),
1849                    MetadataSetValue::Int(int_list_comparison.values),
1850                )
1851            }
1852            chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(
1853                single_double_comparison,
1854            ) => {
1855                let comparator = single_double_comparison
1856                    .comparator
1857                    .ok_or(WhereConversionError::cause("Invalid scalar float operator"))?;
1858                let value = MetadataValue::Float(single_double_comparison.value);
1859                match comparator {
1860                    chroma_proto::single_double_comparison::Comparator::GenericComparator(op) => {
1861                        let generic = chroma_proto::GenericComparator::try_from(op)
1862                            .map_err(WhereConversionError::cause)?;
1863                        generic_comparator_to_metadata_comparison(generic, value)
1864                    }
1865                    chroma_proto::single_double_comparison::Comparator::NumberComparator(op) => {
1866                        MetadataComparison::Primitive(
1867                            chroma_proto::NumberComparator::try_from(op)
1868                                .map_err(WhereConversionError::cause)?
1869                                .into(),
1870                            value,
1871                        )
1872                    }
1873                }
1874            }
1875            chroma_proto::direct_comparison::Comparison::DoubleListOperand(
1876                double_list_comparison,
1877            ) => MetadataComparison::Set(
1878                double_list_comparison.list_operator().into(),
1879                MetadataSetValue::Float(double_list_comparison.values),
1880            ),
1881            chroma_proto::direct_comparison::Comparison::BoolListOperand(bool_list_comparison) => {
1882                MetadataComparison::Set(
1883                    bool_list_comparison.list_operator().into(),
1884                    MetadataSetValue::Bool(bool_list_comparison.values),
1885                )
1886            }
1887            chroma_proto::direct_comparison::Comparison::SingleBoolOperand(
1888                single_bool_comparison,
1889            ) => generic_comparator_to_metadata_comparison(
1890                single_bool_comparison.comparator(),
1891                MetadataValue::Bool(single_bool_comparison.value),
1892            ),
1893        };
1894        Ok(Self {
1895            key: value.key,
1896            comparison,
1897        })
1898    }
1899}
1900
1901impl TryFrom<MetadataExpression> for chroma_proto::DirectComparison {
1902    type Error = WhereConversionError;
1903
1904    fn try_from(value: MetadataExpression) -> Result<Self, Self::Error> {
1905        let comparison = match value.comparison {
1906            MetadataComparison::Primitive(primitive_operator, metadata_value) => match metadata_value {
1907                MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1908                MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(match primitive_operator {
1909                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1910                                numeric => chroma_proto::single_int_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1911                            }),
1912                MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(match primitive_operator {
1913                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_double_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1914                                numeric => chroma_proto::single_double_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1915                            }),
1916                MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1917                MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Comparison with sparse vector is not supported".to_string())),
1918                MetadataValue::BoolArray(_) | MetadataValue::IntArray(_) | MetadataValue::FloatArray(_) | MetadataValue::StringArray(_) => {
1919                    return Err(WhereConversionError::Cause("Primitive comparison with array metadata values is not supported".to_string()))
1920                }
1921            },
1922            MetadataComparison::Set(set_operator, metadata_set_value) => match metadata_set_value {
1923                MetadataSetValue::Bool(vec) => chroma_proto::direct_comparison::Comparison::BoolListOperand(chroma_proto::BoolListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1924                MetadataSetValue::Int(vec) => chroma_proto::direct_comparison::Comparison::IntListOperand(chroma_proto::IntListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1925                MetadataSetValue::Float(vec) => chroma_proto::direct_comparison::Comparison::DoubleListOperand(chroma_proto::DoubleListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1926                MetadataSetValue::Str(vec) => chroma_proto::direct_comparison::Comparison::StringListOperand(chroma_proto::StringListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1927            },
1928            MetadataComparison::ArrayContains(contains_operator, metadata_value) => {
1929                let comparator = chroma_proto::GenericComparator::from(contains_operator) as i32;
1930                match metadata_value {
1931                    MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator }),
1932                    MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(comparator)) }),
1933                    MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(chroma_proto::single_double_comparison::Comparator::GenericComparator(comparator)) }),
1934                    MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator }),
1935                    MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Contains comparison with sparse vector is not supported".to_string())),
1936                    MetadataValue::BoolArray(_) | MetadataValue::IntArray(_) | MetadataValue::FloatArray(_) | MetadataValue::StringArray(_) => {
1937                        return Err(WhereConversionError::Cause("Contains comparison value must be a scalar, not an array".to_string()))
1938                    }
1939                }
1940            },
1941        };
1942        Ok(Self {
1943            key: value.key,
1944            comparison: Some(comparison),
1945        })
1946    }
1947}
1948
1949#[derive(Clone, Debug, PartialEq)]
1950#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1951pub enum MetadataComparison {
1952    Primitive(PrimitiveOperator, MetadataValue),
1953    Set(SetOperator, MetadataSetValue),
1954    /// Array contains: check if an array metadata field contains (or does not
1955    /// contain) a specific scalar value.
1956    ArrayContains(ContainsOperator, MetadataValue),
1957}
1958
1959impl std::fmt::Display for MetadataComparison {
1960    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1961        match self {
1962            MetadataComparison::Primitive(op, val) => {
1963                let type_name = match val {
1964                    MetadataValue::Bool(_) => "Bool",
1965                    MetadataValue::Int(_) => "Int",
1966                    MetadataValue::Float(_) => "Float",
1967                    MetadataValue::Str(_) => "Str",
1968                    MetadataValue::SparseVector(_) => "SparseVector",
1969                    MetadataValue::BoolArray(_) => "BoolArray",
1970                    MetadataValue::IntArray(_) => "IntArray",
1971                    MetadataValue::FloatArray(_) => "FloatArray",
1972                    MetadataValue::StringArray(_) => "StringArray",
1973                };
1974                write!(f, "Primitive({}, {})", op, type_name)
1975            }
1976            MetadataComparison::Set(op, val) => {
1977                let type_name = match val {
1978                    MetadataSetValue::Bool(_) => "Bool",
1979                    MetadataSetValue::Int(_) => "Int",
1980                    MetadataSetValue::Float(_) => "Float",
1981                    MetadataSetValue::Str(_) => "Str",
1982                };
1983                write!(f, "Set({}, {})", op, type_name)
1984            }
1985            MetadataComparison::ArrayContains(op, val) => {
1986                let type_name = match val {
1987                    MetadataValue::Bool(_) => "Bool",
1988                    MetadataValue::Int(_) => "Int",
1989                    MetadataValue::Float(_) => "Float",
1990                    MetadataValue::Str(_) => "Str",
1991                    MetadataValue::SparseVector(_) => "SparseVector",
1992                    MetadataValue::BoolArray(_) => "BoolArray",
1993                    MetadataValue::IntArray(_) => "IntArray",
1994                    MetadataValue::FloatArray(_) => "FloatArray",
1995                    MetadataValue::StringArray(_) => "StringArray",
1996                };
1997                write!(f, "ArrayContains({}, {})", op, type_name)
1998            }
1999        }
2000    }
2001}
2002
/// Scalar comparison operators usable in a metadata filter.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum PrimitiveOperator {
    Equal,
    NotEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LessThan,
    LessThanOrEqual,
}

impl std::fmt::Display for PrimitiveOperator {
    /// Writes the operator's mathematical symbol (`=`, `≠`, `>`, `≥`, `<`, `≤`).
    /// Width/fill flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Equal => "=",
            Self::NotEqual => "≠",
            Self::GreaterThan => ">",
            Self::GreaterThanOrEqual => "≥",
            Self::LessThan => "<",
            Self::LessThanOrEqual => "≤",
        })
    }
}
2027
2028impl TryFrom<chroma_proto::GenericComparator> for PrimitiveOperator {
2029    type Error = WhereConversionError;
2030
2031    fn try_from(value: chroma_proto::GenericComparator) -> Result<Self, Self::Error> {
2032        match value {
2033            chroma_proto::GenericComparator::Eq => Ok(Self::Equal),
2034            chroma_proto::GenericComparator::Ne => Ok(Self::NotEqual),
2035            chroma_proto::GenericComparator::ArrayContains
2036            | chroma_proto::GenericComparator::ArrayNotContains => {
2037                Err(WhereConversionError::cause(
2038                    "ArrayContains/ArrayNotContains cannot be converted to PrimitiveOperator",
2039                ))
2040            }
2041        }
2042    }
2043}
2044
2045impl TryFrom<PrimitiveOperator> for chroma_proto::GenericComparator {
2046    type Error = WhereConversionError;
2047
2048    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
2049        match value {
2050            PrimitiveOperator::Equal => Ok(Self::Eq),
2051            PrimitiveOperator::NotEqual => Ok(Self::Ne),
2052            op => Err(WhereConversionError::cause(format!("{op:?} ∉ [=, ≠]"))),
2053        }
2054    }
2055}
2056
2057impl From<chroma_proto::NumberComparator> for PrimitiveOperator {
2058    fn from(value: chroma_proto::NumberComparator) -> Self {
2059        match value {
2060            chroma_proto::NumberComparator::Gt => Self::GreaterThan,
2061            chroma_proto::NumberComparator::Gte => Self::GreaterThanOrEqual,
2062            chroma_proto::NumberComparator::Lt => Self::LessThan,
2063            chroma_proto::NumberComparator::Lte => Self::LessThanOrEqual,
2064        }
2065    }
2066}
2067
2068impl TryFrom<PrimitiveOperator> for chroma_proto::NumberComparator {
2069    type Error = WhereConversionError;
2070
2071    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
2072        match value {
2073            PrimitiveOperator::GreaterThan => Ok(Self::Gt),
2074            PrimitiveOperator::GreaterThanOrEqual => Ok(Self::Gte),
2075            PrimitiveOperator::LessThan => Ok(Self::Lt),
2076            PrimitiveOperator::LessThanOrEqual => Ok(Self::Lte),
2077            op => Err(WhereConversionError::cause(format!(
2078                "{op:?} ∉ [≤, <, >, ≥]"
2079            ))),
2080        }
2081    }
2082}
2083
/// Set-membership operators for [`MetadataComparison::Set`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SetOperator {
    In,
    NotIn,
}

impl std::fmt::Display for SetOperator {
    /// Writes `∈` for [`SetOperator::In`] and `∉` for [`SetOperator::NotIn`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = if *self == SetOperator::In { "∈" } else { "∉" };
        f.write_str(symbol)
    }
}
2100
2101impl From<chroma_proto::ListOperator> for SetOperator {
2102    fn from(value: chroma_proto::ListOperator) -> Self {
2103        match value {
2104            chroma_proto::ListOperator::In => Self::In,
2105            chroma_proto::ListOperator::Nin => Self::NotIn,
2106        }
2107    }
2108}
2109
2110impl From<SetOperator> for chroma_proto::ListOperator {
2111    fn from(value: SetOperator) -> Self {
2112        match value {
2113            SetOperator::In => Self::In,
2114            SetOperator::NotIn => Self::Nin,
2115        }
2116    }
2117}
2118
/// Membership test applied to an array-valued metadata field by
/// [`MetadataComparison::ArrayContains`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum ContainsOperator {
    /// The array contains the operand.
    Contains,
    /// The array does not contain the operand.
    NotContains,
}

impl std::fmt::Display for ContainsOperator {
    /// Writes the lowercase keyword `contains` or `not_contains`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = if matches!(self, ContainsOperator::NotContains) {
            "not_contains"
        } else {
            "contains"
        };
        f.write_str(keyword)
    }
}
2135
2136impl From<ContainsOperator> for chroma_proto::GenericComparator {
2137    fn from(value: ContainsOperator) -> Self {
2138        match value {
2139            ContainsOperator::Contains => Self::ArrayContains,
2140            ContainsOperator::NotContains => Self::ArrayNotContains,
2141        }
2142    }
2143}
2144
/// The operand of a set-membership comparison: a homogeneous list of scalars.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum MetadataSetValue {
    Bool(Vec<bool>),
    Int(Vec<i64>),
    Float(Vec<f64>),
    Str(Vec<String>),
}

impl std::fmt::Display for MetadataSetValue {
    /// Renders the set as a bracketed, comma-separated list, e.g. `[1, 2]`,
    /// `[true, false]` or `["a", "b"]`.
    ///
    /// Only string elements are wrapped in double quotes (without escaping). Bool
    /// elements used to be quoted as well, which made a `Bool` set print exactly
    /// like a `Str` set holding `"true"` / `"false"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Writes `items` as `[a, b, ...]`, rendering each element with `render`.
        fn write_list<T>(
            f: &mut std::fmt::Formatter<'_>,
            items: &[T],
            render: impl Fn(&mut std::fmt::Formatter<'_>, &T) -> std::fmt::Result,
        ) -> std::fmt::Result {
            f.write_str("[")?;
            for (index, item) in items.iter().enumerate() {
                if index > 0 {
                    f.write_str(", ")?;
                }
                render(f, item)?;
            }
            f.write_str("]")
        }

        match self {
            MetadataSetValue::Bool(values) => write_list(f, values, |f, v| write!(f, "{v}")),
            MetadataSetValue::Int(values) => write_list(f, values, |f, v| write!(f, "{v}")),
            MetadataSetValue::Float(values) => write_list(f, values, |f, v| write!(f, "{v}")),
            MetadataSetValue::Str(values) => write_list(f, values, |f, v| write!(f, "\"{v}\"")),
        }
    }
}
2192
2193impl MetadataSetValue {
2194    pub fn value_type(&self) -> MetadataValueType {
2195        match self {
2196            MetadataSetValue::Bool(_) => MetadataValueType::Bool,
2197            MetadataSetValue::Int(_) => MetadataValueType::Int,
2198            MetadataSetValue::Float(_) => MetadataValueType::Float,
2199            MetadataSetValue::Str(_) => MetadataValueType::Str,
2200        }
2201    }
2202}
2203
2204impl From<Vec<bool>> for MetadataSetValue {
2205    fn from(values: Vec<bool>) -> Self {
2206        MetadataSetValue::Bool(values)
2207    }
2208}
2209
2210impl From<Vec<i64>> for MetadataSetValue {
2211    fn from(values: Vec<i64>) -> Self {
2212        MetadataSetValue::Int(values)
2213    }
2214}
2215
2216impl From<Vec<i32>> for MetadataSetValue {
2217    fn from(values: Vec<i32>) -> Self {
2218        MetadataSetValue::Int(values.into_iter().map(|v| v as i64).collect())
2219    }
2220}
2221
2222impl From<Vec<f64>> for MetadataSetValue {
2223    fn from(values: Vec<f64>) -> Self {
2224        MetadataSetValue::Float(values)
2225    }
2226}
2227
2228impl From<Vec<f32>> for MetadataSetValue {
2229    fn from(values: Vec<f32>) -> Self {
2230        MetadataSetValue::Float(values.into_iter().map(|v| v as f64).collect())
2231    }
2232}
2233
2234impl From<Vec<String>> for MetadataSetValue {
2235    fn from(values: Vec<String>) -> Self {
2236        MetadataSetValue::Str(values)
2237    }
2238}
2239
2240impl From<Vec<&str>> for MetadataSetValue {
2241    fn from(values: Vec<&str>) -> Self {
2242        MetadataSetValue::Str(values.into_iter().map(|s| s.to_string()).collect())
2243    }
2244}
2245
2246// TODO: Deprecate where_document
2247impl TryFrom<chroma_proto::WhereDocument> for Where {
2248    type Error = WhereConversionError;
2249
2250    fn try_from(proto_document: chroma_proto::WhereDocument) -> Result<Self, Self::Error> {
2251        match proto_document.r#where_document {
2252            Some(chroma_proto::where_document::WhereDocument::Direct(proto_comparison)) => {
2253                let operator = match TryInto::<chroma_proto::WhereDocumentOperator>::try_into(
2254                    proto_comparison.operator,
2255                ) {
2256                    Ok(operator) => operator,
2257                    Err(_) => {
2258                        return Err(WhereConversionError::cause(
2259                            "[Deprecated] Invalid where document operator",
2260                        ))
2261                    }
2262                };
2263                let comparison = DocumentExpression {
2264                    pattern: proto_comparison.pattern,
2265                    operator: operator.into(),
2266                };
2267                Ok(Where::Document(comparison))
2268            }
2269            Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => {
2270                let operator = match TryInto::<chroma_proto::BooleanOperator>::try_into(
2271                    proto_children.operator,
2272                ) {
2273                    Ok(operator) => operator,
2274                    Err(_) => {
2275                        return Err(WhereConversionError::cause(
2276                            "[Deprecated] Invalid boolean operator",
2277                        ))
2278                    }
2279                };
2280                let children = CompositeExpression {
2281                    children: proto_children
2282                        .children
2283                        .into_iter()
2284                        .map(|child| child.try_into())
2285                        .collect::<Result<_, _>>()?,
2286                    operator: operator.into(),
2287                };
2288                Ok(Where::Composite(children))
2289            }
2290            None => Err(WhereConversionError::cause("[Deprecated] Invalid where")),
2291        }
2292    }
2293}
2294
2295#[cfg(test)]
2296mod tests {
2297    use crate::operator::Key;
2298
2299    use super::*;
2300
2301    // This is needed for the tests that round trip to the python world.
2302    #[cfg(feature = "pyo3")]
2303    fn ensure_python_interpreter() {
2304        static PYTHON_INIT: std::sync::Once = std::sync::Once::new();
2305        PYTHON_INIT.call_once(|| {
2306            pyo3::prepare_freethreaded_python();
2307        });
2308    }
2309
2310    #[test]
2311    fn test_update_metadata_try_from() {
2312        let mut proto_metadata = chroma_proto::UpdateMetadata {
2313            metadata: HashMap::new(),
2314        };
2315        proto_metadata.metadata.insert(
2316            "foo".to_string(),
2317            chroma_proto::UpdateMetadataValue {
2318                value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
2319            },
2320        );
2321        proto_metadata.metadata.insert(
2322            "bar".to_string(),
2323            chroma_proto::UpdateMetadataValue {
2324                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
2325            },
2326        );
2327        proto_metadata.metadata.insert(
2328            "baz".to_string(),
2329            chroma_proto::UpdateMetadataValue {
2330                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
2331                    "42".to_string(),
2332                )),
2333            },
2334        );
2335        // Add sparse vector test
2336        proto_metadata.metadata.insert(
2337            "sparse".to_string(),
2338            chroma_proto::UpdateMetadataValue {
2339                value: Some(
2340                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
2341                        chroma_proto::SparseVector {
2342                            indices: vec![0, 5, 10],
2343                            values: vec![0.1, 0.5, 0.9],
2344                            tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
2345                        },
2346                    ),
2347                ),
2348            },
2349        );
2350        let converted_metadata: UpdateMetadata = proto_metadata.try_into().unwrap();
2351        assert_eq!(converted_metadata.len(), 4);
2352        assert_eq!(
2353            converted_metadata.get("foo").unwrap(),
2354            &UpdateMetadataValue::Int(42)
2355        );
2356        assert_eq!(
2357            converted_metadata.get("bar").unwrap(),
2358            &UpdateMetadataValue::Float(42.0)
2359        );
2360        assert_eq!(
2361            converted_metadata.get("baz").unwrap(),
2362            &UpdateMetadataValue::Str("42".to_string())
2363        );
2364        assert_eq!(
2365            converted_metadata.get("sparse").unwrap(),
2366            &UpdateMetadataValue::SparseVector(
2367                SparseVector::new_with_tokens(
2368                    vec![0, 5, 10],
2369                    vec![0.1, 0.5, 0.9],
2370                    vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),],
2371                )
2372                .unwrap()
2373            )
2374        );
2375    }
2376
2377    #[test]
2378    fn test_metadata_try_from() {
2379        let mut proto_metadata = chroma_proto::UpdateMetadata {
2380            metadata: HashMap::new(),
2381        };
2382        proto_metadata.metadata.insert(
2383            "foo".to_string(),
2384            chroma_proto::UpdateMetadataValue {
2385                value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
2386            },
2387        );
2388        proto_metadata.metadata.insert(
2389            "bar".to_string(),
2390            chroma_proto::UpdateMetadataValue {
2391                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
2392            },
2393        );
2394        proto_metadata.metadata.insert(
2395            "baz".to_string(),
2396            chroma_proto::UpdateMetadataValue {
2397                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
2398                    "42".to_string(),
2399                )),
2400            },
2401        );
2402        // Add sparse vector test
2403        proto_metadata.metadata.insert(
2404            "sparse".to_string(),
2405            chroma_proto::UpdateMetadataValue {
2406                value: Some(
2407                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
2408                        chroma_proto::SparseVector {
2409                            indices: vec![1, 10, 100],
2410                            values: vec![0.2, 0.4, 0.6],
2411                            tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
2412                        },
2413                    ),
2414                ),
2415            },
2416        );
2417        let converted_metadata: Metadata = proto_metadata.try_into().unwrap();
2418        assert_eq!(converted_metadata.len(), 4);
2419        assert_eq!(
2420            converted_metadata.get("foo").unwrap(),
2421            &MetadataValue::Int(42)
2422        );
2423        assert_eq!(
2424            converted_metadata.get("bar").unwrap(),
2425            &MetadataValue::Float(42.0)
2426        );
2427        assert_eq!(
2428            converted_metadata.get("baz").unwrap(),
2429            &MetadataValue::Str("42".to_string())
2430        );
2431        assert_eq!(
2432            converted_metadata.get("sparse").unwrap(),
2433            &MetadataValue::SparseVector(
2434                SparseVector::new_with_tokens(
2435                    vec![1, 10, 100],
2436                    vec![0.2, 0.4, 0.6],
2437                    vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),],
2438                )
2439                .unwrap()
2440            )
2441        );
2442    }
2443
2444    #[test]
2445    fn test_where_clause_simple_from() {
2446        let proto_where = chroma_proto::Where {
2447            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2448                chroma_proto::DirectComparison {
2449                    key: "foo".to_string(),
2450                    comparison: Some(
2451                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2452                            chroma_proto::SingleIntComparison {
2453                                value: 42,
2454                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2455                            },
2456                        ),
2457                    ),
2458                },
2459            )),
2460        };
2461        let where_clause: Where = proto_where.try_into().unwrap();
2462        match where_clause {
2463            Where::Metadata(comparison) => {
2464                assert_eq!(comparison.key, "foo");
2465                match comparison.comparison {
2466                    MetadataComparison::Primitive(_, value) => {
2467                        assert_eq!(value, MetadataValue::Int(42));
2468                    }
2469                    _ => panic!("Invalid comparison type"),
2470                }
2471            }
2472            _ => panic!("Invalid where type"),
2473        }
2474    }
2475
2476    #[test]
2477    fn test_where_clause_with_children() {
2478        let proto_where = chroma_proto::Where {
2479            r#where: Some(chroma_proto::r#where::Where::Children(
2480                chroma_proto::WhereChildren {
2481                    children: vec![
2482                        chroma_proto::Where {
2483                            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2484                                chroma_proto::DirectComparison {
2485                                    key: "foo".to_string(),
2486                                    comparison: Some(
2487                                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2488                                            chroma_proto::SingleIntComparison {
2489                                                value: 42,
2490                                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2491                                            },
2492                                        ),
2493                                    ),
2494                                },
2495                            )),
2496                        },
2497                        chroma_proto::Where {
2498                            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2499                                chroma_proto::DirectComparison {
2500                                    key: "bar".to_string(),
2501                                    comparison: Some(
2502                                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2503                                            chroma_proto::SingleIntComparison {
2504                                                value: 42,
2505                                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2506                                            },
2507                                        ),
2508                                    ),
2509                                },
2510                            )),
2511                        },
2512                    ],
2513                    operator: chroma_proto::BooleanOperator::And.into(),
2514                },
2515            )),
2516        };
2517        let where_clause: Where = proto_where.try_into().unwrap();
2518        match where_clause {
2519            Where::Composite(children) => {
2520                assert_eq!(children.children.len(), 2);
2521                assert_eq!(children.operator, BooleanOperator::And);
2522            }
2523            _ => panic!("Invalid where type"),
2524        }
2525    }
2526
2527    #[test]
2528    fn test_where_document_simple() {
2529        let proto_where = chroma_proto::WhereDocument {
2530            r#where_document: Some(chroma_proto::where_document::WhereDocument::Direct(
2531                chroma_proto::DirectWhereDocument {
2532                    pattern: "foo".to_string(),
2533                    operator: chroma_proto::WhereDocumentOperator::Contains.into(),
2534                },
2535            )),
2536        };
2537        let where_document: Where = proto_where.try_into().unwrap();
2538        match where_document {
2539            Where::Document(comparison) => {
2540                assert_eq!(comparison.pattern, "foo");
2541                assert_eq!(comparison.operator, DocumentOperator::Contains);
2542            }
2543            _ => panic!("Invalid where document type"),
2544        }
2545    }
2546
2547    #[test]
2548    fn test_where_document_with_children() {
2549        let proto_where = chroma_proto::WhereDocument {
2550            r#where_document: Some(chroma_proto::where_document::WhereDocument::Children(
2551                chroma_proto::WhereDocumentChildren {
2552                    children: vec![
2553                        chroma_proto::WhereDocument {
2554                            r#where_document: Some(
2555                                chroma_proto::where_document::WhereDocument::Direct(
2556                                    chroma_proto::DirectWhereDocument {
2557                                        pattern: "foo".to_string(),
2558                                        operator: chroma_proto::WhereDocumentOperator::Contains
2559                                            .into(),
2560                                    },
2561                                ),
2562                            ),
2563                        },
2564                        chroma_proto::WhereDocument {
2565                            r#where_document: Some(
2566                                chroma_proto::where_document::WhereDocument::Direct(
2567                                    chroma_proto::DirectWhereDocument {
2568                                        pattern: "bar".to_string(),
2569                                        operator: chroma_proto::WhereDocumentOperator::Contains
2570                                            .into(),
2571                                    },
2572                                ),
2573                            ),
2574                        },
2575                    ],
2576                    operator: chroma_proto::BooleanOperator::And.into(),
2577                },
2578            )),
2579        };
2580        let where_document: Where = proto_where.try_into().unwrap();
2581        match where_document {
2582            Where::Composite(children) => {
2583                assert_eq!(children.children.len(), 2);
2584                assert_eq!(children.operator, BooleanOperator::And);
2585            }
2586            _ => panic!("Invalid where document type"),
2587        }
2588    }
2589
2590    #[test]
2591    fn test_sparse_vector_new() {
2592        let indices = vec![0, 5, 10];
2593        let values = vec![0.1, 0.5, 0.9];
2594        let sparse = SparseVector::new(indices.clone(), values.clone()).unwrap();
2595        assert_eq!(sparse.indices, indices);
2596        assert_eq!(sparse.values, values);
2597    }
2598
2599    #[test]
2600    fn test_sparse_vector_from_pairs() {
2601        let pairs = vec![(0, 0.1), (5, 0.5), (10, 0.9)];
2602        let sparse = SparseVector::from_pairs(pairs.clone());
2603        assert_eq!(sparse.indices, vec![0, 5, 10]);
2604        assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
2605    }
2606
2607    #[test]
2608    fn test_sparse_vector_from_triples() {
2609        let triples = vec![
2610            ("foo".to_string(), 0, 0.1),
2611            ("bar".to_string(), 5, 0.5),
2612            ("baz".to_string(), 10, 0.9),
2613        ];
2614        let sparse = SparseVector::from_triples(triples.clone());
2615        assert_eq!(sparse.indices, vec![0, 5, 10]);
2616        assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
2617    }
2618
2619    #[test]
2620    fn test_sparse_vector_iter() {
2621        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2622        let collected: Vec<(u32, f32)> = sparse.iter().collect();
2623        assert_eq!(collected, vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
2624    }
2625
2626    #[test]
2627    fn test_sparse_vector_ordering() {
2628        let sparse1 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]).unwrap();
2629        let sparse2 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]).unwrap();
2630        let sparse3 = SparseVector::new(vec![0, 6], vec![0.1, 0.5]).unwrap();
2631        let sparse4 = SparseVector::new(vec![0, 5], vec![0.1, 0.6]).unwrap();
2632
2633        assert_eq!(sparse1, sparse2);
2634        assert!(sparse1 < sparse3);
2635        assert!(sparse1 < sparse4);
2636    }
2637
2638    #[test]
2639    fn test_sparse_vector_proto_conversion() {
2640        let tokens = vec![
2641            "token1".to_string(),
2642            "token2".to_string(),
2643            "token3".to_string(),
2644        ];
2645        let sparse =
2646            SparseVector::new_with_tokens(vec![1, 10, 100], vec![0.2, 0.4, 0.6], tokens.clone())
2647                .unwrap();
2648        let proto: chroma_proto::SparseVector = sparse.clone().into();
2649        assert_eq!(proto.indices, vec![1, 10, 100]);
2650        assert_eq!(proto.values, vec![0.2, 0.4, 0.6]);
2651        assert_eq!(proto.tokens, tokens.clone());
2652
2653        let converted: SparseVector = proto.try_into().unwrap();
2654        assert_eq!(converted, sparse);
2655        assert_eq!(converted.tokens, Some(tokens));
2656    }
2657
2658    #[test]
2659    fn test_sparse_vector_proto_conversion_empty_tokens() {
2660        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2661        let proto: chroma_proto::SparseVector = sparse.clone().into();
2662        assert_eq!(proto.indices, vec![0, 5, 10]);
2663        assert_eq!(proto.values, vec![0.1, 0.5, 0.9]);
2664        assert_eq!(proto.tokens, Vec::<String>::new());
2665
2666        let converted: SparseVector = proto.try_into().unwrap();
2667        assert_eq!(converted, sparse);
2668        assert_eq!(converted.tokens, None);
2669    }
2670
2671    #[test]
2672    fn test_sparse_vector_logical_size() {
2673        let metadata = Metadata::from([(
2674            "sparse".to_string(),
2675            MetadataValue::SparseVector(
2676                SparseVector::new(vec![0, 1, 2, 3, 4], vec![0.1, 0.2, 0.3, 0.4, 0.5]).unwrap(),
2677            ),
2678        )]);
2679
2680        let size = logical_size_of_metadata(&metadata);
2681        // Size should include the key string length and the sparse vector data
2682        // "sparse" = 6 bytes + 5 * 4 bytes (u32 indices) + 5 * 4 bytes (f32 values) = 46 bytes
2683        assert_eq!(size, 46);
2684    }
2685
2686    #[test]
2687    fn test_sparse_vector_validation() {
2688        // Valid sparse vector
2689        let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap();
2690        assert!(sparse.validate().is_ok());
2691
2692        // Length mismatch
2693        let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]);
2694        assert!(sparse.is_err());
2695        let result = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3])
2696            .unwrap()
2697            .validate();
2698        assert!(result.is_ok());
2699
2700        // Tokens length mismatch with indices/values
2701        let sparse = SparseVector::new_with_tokens(
2702            vec![1, 2, 3],
2703            vec![0.1, 0.2, 0.3],
2704            vec!["a".to_string(), "b".to_string()],
2705        );
2706        assert!(sparse.is_err());
2707
2708        // Unsorted indices (descending order)
2709        let sparse = SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap();
2710        let result = sparse.validate();
2711        assert!(result.is_err());
2712        assert!(matches!(
2713            result.unwrap_err(),
2714            MetadataValueConversionError::SparseVectorIndicesNotSorted
2715        ));
2716
2717        // Duplicate indices (not strictly ascending)
2718        let sparse = SparseVector::new(vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]).unwrap();
2719        let result = sparse.validate();
2720        assert!(result.is_err());
2721        assert!(matches!(
2722            result.unwrap_err(),
2723            MetadataValueConversionError::SparseVectorIndicesNotSorted
2724        ));
2725
2726        // Descending at one point
2727        let sparse = SparseVector::new(vec![1, 3, 2], vec![0.1, 0.3, 0.2]).unwrap();
2728        let result = sparse.validate();
2729        assert!(result.is_err());
2730        assert!(matches!(
2731            result.unwrap_err(),
2732            MetadataValueConversionError::SparseVectorIndicesNotSorted
2733        ));
2734    }
2735
2736    #[test]
2737    fn test_sparse_vector_deserialize_old_format() {
2738        // Old format without #type field (backward compatibility)
2739        let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
2740        let sv: SparseVector = serde_json::from_str(json).unwrap();
2741        assert_eq!(sv.indices, vec![0, 1, 2]);
2742        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2743    }
2744
2745    #[test]
2746    fn test_sparse_vector_deserialize_new_format() {
2747        // New format with #type field
2748        let json =
2749            "{\"#type\": \"sparse_vector\", \"indices\": [0, 1, 2], \"values\": [1.0, 2.0, 3.0]}";
2750        let sv: SparseVector = serde_json::from_str(json).unwrap();
2751        assert_eq!(sv.indices, vec![0, 1, 2]);
2752        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2753    }
2754
2755    #[test]
2756    fn test_sparse_vector_deserialize_new_format_field_order() {
2757        // New format with different field order (should still work)
2758        let json = "{\"indices\": [5, 10], \"#type\": \"sparse_vector\", \"values\": [0.5, 1.0]}";
2759        let sv: SparseVector = serde_json::from_str(json).unwrap();
2760        assert_eq!(sv.indices, vec![5, 10]);
2761        assert_eq!(sv.values, vec![0.5, 1.0]);
2762    }
2763
2764    #[test]
2765    fn test_sparse_vector_deserialize_wrong_type_tag() {
2766        // Wrong #type field value should fail
2767        let json = "{\"#type\": \"dense_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}";
2768        let result: Result<SparseVector, _> = serde_json::from_str(json);
2769        assert!(result.is_err());
2770        let err_msg = result.unwrap_err().to_string();
2771        assert!(err_msg.contains("sparse_vector"));
2772    }
2773
2774    #[test]
2775    fn test_sparse_vector_serialize_always_has_type() {
2776        // Serialization should always include #type field
2777        let sv = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]).unwrap();
2778        let json = serde_json::to_value(&sv).unwrap();
2779
2780        assert_eq!(json["#type"], "sparse_vector");
2781        assert_eq!(json["indices"], serde_json::json!([0, 1, 2]));
2782        assert_eq!(json["values"], serde_json::json!([1.0, 2.0, 3.0]));
2783    }
2784
2785    #[test]
2786    fn test_sparse_vector_roundtrip_with_type() {
2787        // Test that serialize -> deserialize preserves the data
2788        let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]).unwrap();
2789        let json = serde_json::to_string(&original).unwrap();
2790
2791        // Verify the serialized JSON contains #type
2792        assert!(json.contains("\"#type\":\"sparse_vector\""));
2793
2794        let deserialized: SparseVector = serde_json::from_str(&json).unwrap();
2795        assert_eq!(original, deserialized);
2796    }
2797
2798    #[test]
2799    fn test_sparse_vector_in_metadata_old_format() {
2800        // Test that old format works when sparse vector is in metadata
2801        let json = r#"{"key": "value", "sparse": {"indices": [0, 1], "values": [1.0, 2.0]}}"#;
2802        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();
2803
2804        let sparse_value = &map["sparse"];
2805        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
2806        assert_eq!(sv.indices, vec![0, 1]);
2807        assert_eq!(sv.values, vec![1.0, 2.0]);
2808    }
2809
2810    #[test]
2811    fn test_sparse_vector_in_metadata_new_format() {
2812        // Test that new format works when sparse vector is in metadata
2813        let json = "{\"key\": \"value\", \"sparse\": {\"#type\": \"sparse_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}}";
2814        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();
2815
2816        let sparse_value = &map["sparse"];
2817        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
2818        assert_eq!(sv.indices, vec![0, 1]);
2819        assert_eq!(sv.values, vec![1.0, 2.0]);
2820    }
2821
2822    #[test]
2823    fn test_sparse_vector_tokens_roundtrip_old_to_new() {
2824        // Old format without tokens field should deserialize with tokens=None
2825        let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
2826        let sv: SparseVector = serde_json::from_str(json).unwrap();
2827        assert_eq!(sv.indices, vec![0, 1, 2]);
2828        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2829        assert_eq!(sv.tokens, None);
2830
2831        // Serialize and verify it includes #type but no tokens field when None
2832        let serialized = serde_json::to_value(&sv).unwrap();
2833        assert_eq!(serialized["#type"], "sparse_vector");
2834        assert_eq!(serialized["indices"], serde_json::json!([0, 1, 2]));
2835        assert_eq!(serialized["values"], serde_json::json!([1.0, 2.0, 3.0]));
2836        assert_eq!(serialized["tokens"], serde_json::Value::Null);
2837    }
2838
2839    #[test]
2840    fn test_sparse_vector_tokens_roundtrip_new_to_new() {
2841        // New format with tokens field
2842        let sv_with_tokens = SparseVector::new_with_tokens(
2843            vec![0, 1, 2],
2844            vec![1.0, 2.0, 3.0],
2845            vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
2846        )
2847        .unwrap();
2848
2849        // Serialize
2850        let serialized = serde_json::to_string(&sv_with_tokens).unwrap();
2851        assert!(serialized.contains("\"#type\":\"sparse_vector\""));
2852        assert!(serialized.contains("\"tokens\""));
2853
2854        // Deserialize and verify tokens are preserved
2855        let deserialized: SparseVector = serde_json::from_str(&serialized).unwrap();
2856        assert_eq!(deserialized.indices, vec![0, 1, 2]);
2857        assert_eq!(deserialized.values, vec![1.0, 2.0, 3.0]);
2858        assert_eq!(
2859            deserialized.tokens,
2860            Some(vec![
2861                "foo".to_string(),
2862                "bar".to_string(),
2863                "baz".to_string()
2864            ])
2865        );
2866    }
2867
2868    #[test]
2869    fn test_sparse_vector_tokens_deserialize_with_tokens_field() {
2870        // Test deserializing JSON that explicitly includes tokens field
2871        let json = r##"{"#type": "sparse_vector", "indices": [5, 10], "values": [0.5, 1.0], "tokens": ["token1", "token2"]}"##;
2872        let sv: SparseVector = serde_json::from_str(json).unwrap();
2873        assert_eq!(sv.indices, vec![5, 10]);
2874        assert_eq!(sv.values, vec![0.5, 1.0]);
2875        assert_eq!(
2876            sv.tokens,
2877            Some(vec!["token1".to_string(), "token2".to_string()])
2878        );
2879    }
2880
2881    #[test]
2882    fn test_sparse_vector_tokens_backward_compatibility() {
2883        // Verify old format (no tokens, no #type) deserializes correctly
2884        let old_json = r#"{"indices": [1, 2], "values": [0.1, 0.2]}"#;
2885        let old_sv: SparseVector = serde_json::from_str(old_json).unwrap();
2886
2887        // Verify new format (with #type, with tokens) deserializes correctly
2888        let new_json = r##"{"#type": "sparse_vector", "indices": [1, 2], "values": [0.1, 0.2], "tokens": ["a", "b"]}"##;
2889        let new_sv: SparseVector = serde_json::from_str(new_json).unwrap();
2890
2891        // Both should have same indices and values
2892        assert_eq!(old_sv.indices, new_sv.indices);
2893        assert_eq!(old_sv.values, new_sv.values);
2894
2895        // Old should have None tokens, new should have Some tokens
2896        assert_eq!(old_sv.tokens, None);
2897        assert_eq!(new_sv.tokens, Some(vec!["a".to_string(), "b".to_string()]));
2898    }
2899
2900    #[test]
2901    fn test_sparse_vector_from_triples_preserves_tokens() {
2902        let triples = vec![
2903            ("apple".to_string(), 10, 0.5),
2904            ("banana".to_string(), 20, 0.7),
2905            ("cherry".to_string(), 30, 0.9),
2906        ];
2907        let sv = SparseVector::from_triples(triples.clone());
2908
2909        assert_eq!(sv.indices, vec![10, 20, 30]);
2910        assert_eq!(sv.values, vec![0.5, 0.7, 0.9]);
2911        assert_eq!(
2912            sv.tokens,
2913            Some(vec![
2914                "apple".to_string(),
2915                "banana".to_string(),
2916                "cherry".to_string()
2917            ])
2918        );
2919
2920        // Roundtrip through serialization
2921        let serialized = serde_json::to_string(&sv).unwrap();
2922        let deserialized: SparseVector = serde_json::from_str(&serialized).unwrap();
2923
2924        assert_eq!(deserialized.indices, sv.indices);
2925        assert_eq!(deserialized.values, sv.values);
2926        assert_eq!(deserialized.tokens, sv.tokens);
2927    }
2928
2929    #[cfg(feature = "pyo3")]
2930    #[test]
2931    fn test_sparse_vector_pyo3_roundtrip_with_tokens() {
2932        ensure_python_interpreter();
2933
2934        pyo3::Python::with_gil(|py| {
2935            use pyo3::types::PyDict;
2936            use pyo3::IntoPyObject;
2937
2938            let dict_in = PyDict::new(py);
2939            dict_in.set_item("indices", vec![0u32, 1, 2]).unwrap();
2940            dict_in
2941                .set_item("values", vec![0.1f32, 0.2f32, 0.3f32])
2942                .unwrap();
2943            dict_in
2944                .set_item("tokens", vec!["foo", "bar", "baz"])
2945                .unwrap();
2946
2947            let sparse: SparseVector = dict_in.clone().into_any().extract().unwrap();
2948            assert_eq!(sparse.indices, vec![0, 1, 2]);
2949            assert_eq!(sparse.values, vec![0.1, 0.2, 0.3]);
2950            assert_eq!(
2951                sparse.tokens,
2952                Some(vec![
2953                    "foo".to_string(),
2954                    "bar".to_string(),
2955                    "baz".to_string()
2956                ])
2957            );
2958
2959            let py_obj = sparse.clone().into_pyobject(py).unwrap();
2960            let dict_out = py_obj.downcast::<PyDict>().unwrap();
2961            let tokens_obj = dict_out.get_item("tokens").unwrap();
2962            let tokens: Vec<String> = tokens_obj
2963                .expect("expected tokens key in Python dict")
2964                .extract()
2965                .unwrap();
2966            assert_eq!(
2967                tokens,
2968                vec!["foo".to_string(), "bar".to_string(), "baz".to_string()]
2969            );
2970        });
2971    }
2972
2973    #[cfg(feature = "pyo3")]
2974    #[test]
2975    fn test_sparse_vector_pyo3_roundtrip_without_tokens() {
2976        ensure_python_interpreter();
2977
2978        pyo3::Python::with_gil(|py| {
2979            use pyo3::types::PyDict;
2980            use pyo3::IntoPyObject;
2981
2982            let dict_in = PyDict::new(py);
2983            dict_in.set_item("indices", vec![5u32]).unwrap();
2984            dict_in.set_item("values", vec![1.5f32]).unwrap();
2985
2986            let sparse: SparseVector = dict_in.clone().into_any().extract().unwrap();
2987            assert_eq!(sparse.indices, vec![5]);
2988            assert_eq!(sparse.values, vec![1.5]);
2989            assert!(sparse.tokens.is_none());
2990
2991            let py_obj = sparse.into_pyobject(py).unwrap();
2992            let dict_out = py_obj.downcast::<PyDict>().unwrap();
2993            let tokens_obj = dict_out.get_item("tokens").unwrap();
2994            let tokens_value = tokens_obj.expect("expected tokens key in Python dict");
2995            assert!(
2996                tokens_value.is_none(),
2997                "expected tokens value in Python dict to be None"
2998            );
2999        });
3000    }
3001
3002    #[test]
3003    fn test_simplifies_identities() {
3004        let all: Where = true.into();
3005        assert_eq!(all.clone() & all.clone(), true.into());
3006        assert_eq!(all.clone() | all.clone(), true.into());
3007
3008        let foo = Key::field("foo").eq("bar");
3009        assert_eq!(foo.clone() & all.clone(), foo.clone());
3010        assert_eq!(all.clone() & foo.clone(), foo.clone());
3011
3012        let none: Where = false.into();
3013        assert_eq!(foo.clone() | none.clone(), foo.clone());
3014        assert_eq!(none | foo.clone(), foo);
3015    }
3016
3017    #[test]
3018    fn test_flattens() {
3019        let foo = Key::field("foo").eq("bar");
3020        let baz = Key::field("baz").eq("quux");
3021
3022        let and_nested = foo.clone() & (baz.clone() & foo.clone());
3023        assert_eq!(
3024            and_nested,
3025            Where::Composite(CompositeExpression {
3026                operator: BooleanOperator::And,
3027                children: vec![foo.clone(), baz.clone(), foo.clone()]
3028            })
3029        );
3030
3031        let or_nested = foo.clone() | (baz.clone() | foo.clone());
3032        assert_eq!(
3033            or_nested,
3034            Where::Composite(CompositeExpression {
3035                operator: BooleanOperator::Or,
3036                children: vec![foo.clone(), baz.clone(), foo.clone()]
3037            })
3038        );
3039    }
3040}