chroma_types/
metadata.rs

1use chroma_error::{ChromaError, ErrorCodes};
2use itertools::Itertools;
3use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
4use serde_json::{Number, Value};
5use sprs::CsVec;
6use std::{
7    cmp::Ordering,
8    collections::{HashMap, HashSet},
9    mem::size_of_val,
10    ops::{BitAnd, BitOr},
11};
12use thiserror::Error;
13
14use crate::chroma_proto;
15
16#[cfg(feature = "pyo3")]
17use pyo3::types::{PyAnyMethods, PyDictMethods};
18
19#[cfg(feature = "testing")]
20use proptest::prelude::*;
21
/// Wire-format mirror of [`SparseVector`] used by its hand-written serde impls.
///
/// `#type` and `tokens` are both `Option`s, so payloads that omit them still
/// deserialize (serde's derive treats a missing `Option` field as `None`).
#[derive(Serialize, Deserialize)]
struct SparseVectorSerdeHelper {
    /// Optional type discriminator, stored under the `#type` key.
    #[serde(rename = "#type")]
    type_tag: Option<String>,
    /// Dimension indices.
    indices: Vec<u32>,
    /// Values, parallel to `indices`.
    values: Vec<f32>,
    /// Optional tokens, parallel to `indices`.
    tokens: Option<Vec<String>>,
}
30
/// Represents a sparse vector using parallel arrays for indices and values.
///
/// On deserialization: accepts both old format `{"indices": [...], "values": [...]}`
/// and new format `{"#type": "sparse_vector", "indices": [...], "values": [...]}`.
///
/// On serialization: always includes `#type` field with value `"sparse_vector"`.
// NOTE: every field is public, so the type itself does not guarantee that the
// arrays have matching lengths or that `indices` is sorted; `validate` performs
// those checks on demand.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SparseVector {
    /// Dimension indices
    pub indices: Vec<u32>,
    /// Values corresponding to each index
    pub values: Vec<f32>,
    /// Tokens corresponding to each index
    // `None` means the vector carries no token information at all.
    pub tokens: Option<Vec<String>>,
}
47
48// Custom deserializer: accept both old and new formats
49impl<'de> Deserialize<'de> for SparseVector {
50    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51    where
52        D: Deserializer<'de>,
53    {
54        let helper = SparseVectorSerdeHelper::deserialize(deserializer)?;
55
56        // If #type is present, validate it
57        if let Some(type_tag) = &helper.type_tag {
58            if type_tag != "sparse_vector" {
59                return Err(serde::de::Error::custom(format!(
60                    "Expected #type='sparse_vector', got '{}'",
61                    type_tag
62                )));
63            }
64        }
65
66        Ok(SparseVector {
67            indices: helper.indices,
68            values: helper.values,
69            tokens: helper.tokens,
70        })
71    }
72}
73
74// Custom serializer: always include #type field
75impl Serialize for SparseVector {
76    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77    where
78        S: Serializer,
79    {
80        let helper = SparseVectorSerdeHelper {
81            type_tag: Some("sparse_vector".to_string()),
82            indices: self.indices.clone(),
83            values: self.values.clone(),
84            tokens: self.tokens.clone(),
85        };
86        helper.serialize(serializer)
87    }
88}
89
/// Error raised when the `indices`, `values` and (optional) `tokens` arrays of
/// a sparse vector do not all have the same number of elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SparseVectorLengthMismatch;

impl std::fmt::Display for SparseVectorLengthMismatch {
    /// Writes the fixed, human-readable description of the mismatch.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(
            "Sparse vector indices, values, and tokens (when present) must have the same length",
        )
    }
}

// No underlying cause to expose, so the default `Error` methods suffice.
impl std::error::Error for SparseVectorLengthMismatch {}
104
impl ChromaError for SparseVectorLengthMismatch {
    /// A length mismatch is always reported as an invalid-argument error.
    fn code(&self) -> ErrorCodes {
        ErrorCodes::InvalidArgument
    }
}
110
111impl SparseVector {
112    /// Create a new sparse vector from parallel arrays.
113    pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Result<Self, SparseVectorLengthMismatch> {
114        if indices.len() != values.len() {
115            return Err(SparseVectorLengthMismatch);
116        }
117        Ok(Self {
118            indices,
119            values,
120            tokens: None,
121        })
122    }
123
124    /// Create a new sparse vector from parallel arrays.
125    pub fn new_with_tokens(
126        indices: Vec<u32>,
127        values: Vec<f32>,
128        tokens: Vec<String>,
129    ) -> Result<Self, SparseVectorLengthMismatch> {
130        if indices.len() != values.len() {
131            return Err(SparseVectorLengthMismatch);
132        }
133        if tokens.len() != indices.len() {
134            return Err(SparseVectorLengthMismatch);
135        }
136        Ok(Self {
137            indices,
138            values,
139            tokens: Some(tokens),
140        })
141    }
142
143    /// Create a sparse vector from an iterator of (index, value) pairs.
144    pub fn from_pairs(pairs: impl IntoIterator<Item = (u32, f32)>) -> Self {
145        let mut indices = vec![];
146        let mut values = vec![];
147        for (index, value) in pairs {
148            indices.push(index);
149            values.push(value);
150        }
151        let tokens = None;
152        Self {
153            indices,
154            values,
155            tokens,
156        }
157    }
158
159    /// Create a sparse vector from an iterator of (string, index, value) pairs.
160    pub fn from_triples(triples: impl IntoIterator<Item = (String, u32, f32)>) -> Self {
161        let mut tokens = vec![];
162        let mut indices = vec![];
163        let mut values = vec![];
164        for (token, index, value) in triples {
165            tokens.push(token);
166            indices.push(index);
167            values.push(value);
168        }
169        let tokens = Some(tokens);
170        Self {
171            indices,
172            values,
173            tokens,
174        }
175    }
176
177    /// Iterate over (index, value) pairs.
178    pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
179        self.indices
180            .iter()
181            .copied()
182            .zip(self.values.iter().copied())
183    }
184
185    /// Validate the sparse vector
186    pub fn validate(&self) -> Result<(), MetadataValueConversionError> {
187        // Check that indices and values have the same length
188        if self.indices.len() != self.values.len() {
189            return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
190        }
191
192        // Check that tokens (if present) align with indices
193        if let Some(tokens) = self.tokens.as_ref() {
194            if tokens.len() != self.indices.len() {
195                return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
196            }
197        }
198
199        // Check that indices are sorted in strictly ascending order (no duplicates)
200        for i in 1..self.indices.len() {
201            if self.indices[i] <= self.indices[i - 1] {
202                return Err(MetadataValueConversionError::SparseVectorIndicesNotSorted);
203            }
204        }
205
206        Ok(())
207    }
208}
209
// Marker impl; the comparison itself is the derived `PartialEq`.
// NOTE(review): `values` holds `f32`s, so a vector containing NaN does not
// compare equal to itself under the derived `PartialEq` — confirm that NaN
// never reaches this type before relying on `Eq` semantics.
impl Eq for SparseVector {}
211
212impl Ord for SparseVector {
213    fn cmp(&self, other: &Self) -> Ordering {
214        self.indices.cmp(&other.indices).then_with(|| {
215            for (a, b) in self.values.iter().zip(other.values.iter()) {
216                match a.total_cmp(b) {
217                    Ordering::Equal => continue,
218                    other => return other,
219                }
220            }
221            self.values.len().cmp(&other.values.len())
222        })
223    }
224}
225
impl PartialOrd for SparseVector {
    /// Delegates to [`Ord::cmp`], so the result is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
231
232impl TryFrom<chroma_proto::SparseVector> for SparseVector {
233    type Error = SparseVectorLengthMismatch;
234
235    fn try_from(proto: chroma_proto::SparseVector) -> Result<Self, Self::Error> {
236        if proto.tokens.is_empty() {
237            SparseVector::new(proto.indices, proto.values)
238        } else {
239            SparseVector::new_with_tokens(proto.indices, proto.values, proto.tokens)
240        }
241    }
242}
243
244impl From<SparseVector> for chroma_proto::SparseVector {
245    fn from(sparse: SparseVector) -> Self {
246        chroma_proto::SparseVector {
247            indices: sparse.indices,
248            values: sparse.values,
249            tokens: sparse.tokens.unwrap_or_default(),
250        }
251    }
252}
253
254/// Convert SparseVector to sprs::CsVec for efficient sparse operations
255impl From<&SparseVector> for CsVec<f32> {
256    fn from(sparse: &SparseVector) -> Self {
257        let (indices, values) = sparse
258            .iter()
259            .map(|(index, value)| (index as usize, value))
260            .unzip();
261        CsVec::new(u32::MAX as usize, indices, values)
262    }
263}
264
265impl From<SparseVector> for CsVec<f32> {
266    fn from(sparse: SparseVector) -> Self {
267        (&sparse).into()
268    }
269}
270
#[cfg(feature = "pyo3")]
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    /// Builds a Python `dict` with the keys `indices`, `values` and `tokens`.
    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        let SparseVector {
            indices,
            values,
            tokens,
        } = self;

        let py_dict = pyo3::types::PyDict::new(py);
        py_dict.set_item("indices", indices)?;
        py_dict.set_item("values", values)?;
        py_dict.set_item("tokens", tokens)?;
        Ok(py_dict.into_any())
    }
}
287
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for SparseVector {
    /// Extracts a sparse vector from a Python `dict`.
    ///
    /// `indices` and `values` are required keys (`KeyError` if absent);
    /// `tokens` may be absent or `None`. A length mismatch between the arrays
    /// is surfaced as a `ValueError`.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        let dict = ob.downcast::<pyo3::types::PyDict>()?;

        let indices: Vec<u32> = dict
            .get_item("indices")?
            .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("missing 'indices' key"))?
            .extract()?;

        let values: Vec<f32> = dict
            .get_item("values")?
            .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("missing 'values' key"))?
            .extract()?;

        // A missing key and an explicit Python `None` both mean "no tokens".
        let tokens: Option<Vec<String>> = match dict.get_item("tokens")? {
            Some(obj) if !obj.is_none() => Some(obj.extract::<Vec<String>>()?),
            _ => None,
        };

        let built = if let Some(tokens) = tokens {
            SparseVector::new_with_tokens(indices, values, tokens)
        } else {
            SparseVector::new(indices, values)
        };

        built.map_err(|err| pyo3::exceptions::PyValueError::new_err(err.to_string()))
    }
}
325
// A metadata value as supplied in an update: any storable value, or `None`.
//
// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration
// order when deserializing; do not reorder them without checking that effect.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[serde(untagged)]
pub enum UpdateMetadataValue {
    Bool(bool),
    Int(i64),
    // Under the `testing` feature, generated floats are drawn from the f32
    // range [-1e6, 1e6] and then widened to f64.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| UpdateMetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    // Never produced by the proptest-derived generator.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
    // Communicates that the user wants to delete this key.
    None,
}
345
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for UpdateMetadataValue {
    /// Probes the Python object against each supported type in a fixed order
    /// and returns the first match; `TypeError` if none fits.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        if ob.is_none() {
            return Ok(UpdateMetadataValue::None);
        }
        // `bool` is probed before `i64` — presumably so that Python booleans are
        // not picked up by the integer probe; keep this order.
        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(UpdateMetadataValue::Bool(flag));
        }
        if let Ok(int) = ob.extract::<i64>() {
            return Ok(UpdateMetadataValue::Int(int));
        }
        if let Ok(float) = ob.extract::<f64>() {
            return Ok(UpdateMetadataValue::Float(float));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(UpdateMetadataValue::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(UpdateMetadataValue::SparseVector(sparse));
        }
        Err(pyo3::exceptions::PyTypeError::new_err(
            "Cannot convert Python object to UpdateMetadataValue",
        ))
    }
}
368
369impl From<bool> for UpdateMetadataValue {
370    fn from(b: bool) -> Self {
371        Self::Bool(b)
372    }
373}
374
375impl From<i64> for UpdateMetadataValue {
376    fn from(v: i64) -> Self {
377        Self::Int(v)
378    }
379}
380
381impl From<i32> for UpdateMetadataValue {
382    fn from(v: i32) -> Self {
383        Self::Int(v as i64)
384    }
385}
386
387impl From<f64> for UpdateMetadataValue {
388    fn from(v: f64) -> Self {
389        Self::Float(v)
390    }
391}
392
393impl From<f32> for UpdateMetadataValue {
394    fn from(v: f32) -> Self {
395        Self::Float(v as f64)
396    }
397}
398
399impl From<String> for UpdateMetadataValue {
400    fn from(v: String) -> Self {
401        Self::Str(v)
402    }
403}
404
405impl From<&str> for UpdateMetadataValue {
406    fn from(v: &str) -> Self {
407        Self::Str(v.to_string())
408    }
409}
410
411impl From<SparseVector> for UpdateMetadataValue {
412    fn from(v: SparseVector) -> Self {
413        Self::SparseVector(v)
414    }
415}
416
/// Failure to convert a protobuf value into an [`UpdateMetadataValue`].
#[derive(Error, Debug)]
pub enum UpdateMetadataValueConversionError {
    /// The incoming value could not be converted.
    // NOTE(review): the message does not list SparseVector, although a
    // malformed sparse vector is also reported through this variant.
    #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")]
    InvalidValue,
}
422
impl ChromaError for UpdateMetadataValueConversionError {
    /// Maps each variant to its error code; the only variant is an
    /// invalid-argument error.
    fn code(&self) -> ErrorCodes {
        match self {
            UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
        }
    }
}
430
431impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue {
432    type Error = UpdateMetadataValueConversionError;
433
434    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
435        match &value.value {
436            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
437                Ok(UpdateMetadataValue::Bool(*value))
438            }
439            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
440                Ok(UpdateMetadataValue::Int(*value))
441            }
442            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
443                Ok(UpdateMetadataValue::Float(*value))
444            }
445            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
446                Ok(UpdateMetadataValue::Str(value.clone()))
447            }
448            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
449                let sparse = value
450                    .clone()
451                    .try_into()
452                    .map_err(|_| UpdateMetadataValueConversionError::InvalidValue)?;
453                Ok(UpdateMetadataValue::SparseVector(sparse))
454            }
455            // Used to communicate that the user wants to delete this key.
456            None => Ok(UpdateMetadataValue::None),
457        }
458    }
459}
460
461impl From<UpdateMetadataValue> for chroma_proto::UpdateMetadataValue {
462    fn from(value: UpdateMetadataValue) -> Self {
463        match value {
464            UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
465                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
466            },
467            UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
468                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
469            },
470            UpdateMetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
471                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
472                    value,
473                )),
474            },
475            UpdateMetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
476                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
477                    value,
478                )),
479            },
480            UpdateMetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
481                value: Some(
482                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
483                        sparse_vec.into(),
484                    ),
485                ),
486            },
487            UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None },
488        }
489    }
490}
491
492impl TryFrom<&UpdateMetadataValue> for MetadataValue {
493    type Error = MetadataValueConversionError;
494
495    fn try_from(value: &UpdateMetadataValue) -> Result<Self, Self::Error> {
496        match value {
497            UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)),
498            UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)),
499            UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)),
500            UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())),
501            UpdateMetadataValue::SparseVector(value) => {
502                Ok(MetadataValue::SparseVector(value.clone()))
503            }
504            UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue),
505        }
506    }
507}
508
509/*
510===========================================
511MetadataValue
512===========================================
513*/
514
// A stored metadata value.
//
// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration
// order when deserializing (the derived pyo3 extractor presumably does the
// same); do not reorder them without checking that effect.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
#[serde(untagged)]
pub enum MetadataValue {
    Bool(bool),
    Int(i64),
    // Under the `testing` feature, generated floats are drawn from the f32
    // range [-1e6, 1e6] and then widened to f64.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| MetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    // Never produced by the proptest-derived generator.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
}
534
535impl std::fmt::Display for MetadataValue {
536    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537        match self {
538            MetadataValue::Bool(v) => write!(f, "{}", v),
539            MetadataValue::Int(v) => write!(f, "{}", v),
540            MetadataValue::Float(v) => write!(f, "{}", v),
541            MetadataValue::Str(v) => write!(f, "\"{}\"", v),
542            MetadataValue::SparseVector(v) => write!(f, "SparseVector(len={})", v.values.len()),
543        }
544    }
545}
546
// Marker impl; the comparison itself is the derived `PartialEq`.
// NOTE(review): the `Float` and `SparseVector` variants hold floats, so a value
// containing NaN does not compare equal to itself — confirm NaN is rejected
// upstream before relying on `Eq` semantics.
impl Eq for MetadataValue {}
548
/// Payload-free discriminant of [`MetadataValue`]: one variant per value kind.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MetadataValueType {
    /// Corresponds to `MetadataValue::Bool`.
    Bool,
    /// Corresponds to `MetadataValue::Int`.
    Int,
    /// Corresponds to `MetadataValue::Float`.
    Float,
    /// Corresponds to `MetadataValue::Str`.
    Str,
    /// Corresponds to `MetadataValue::SparseVector`.
    SparseVector,
}
557
558impl MetadataValue {
559    pub fn value_type(&self) -> MetadataValueType {
560        match self {
561            MetadataValue::Bool(_) => MetadataValueType::Bool,
562            MetadataValue::Int(_) => MetadataValueType::Int,
563            MetadataValue::Float(_) => MetadataValueType::Float,
564            MetadataValue::Str(_) => MetadataValueType::Str,
565            MetadataValue::SparseVector(_) => MetadataValueType::SparseVector,
566        }
567    }
568}
569
impl From<&MetadataValue> for MetadataValueType {
    /// Equivalent to calling [`MetadataValue::value_type`].
    fn from(value: &MetadataValue) -> Self {
        value.value_type()
    }
}
575
576impl From<bool> for MetadataValue {
577    fn from(v: bool) -> Self {
578        MetadataValue::Bool(v)
579    }
580}
581
582impl From<i64> for MetadataValue {
583    fn from(v: i64) -> Self {
584        MetadataValue::Int(v)
585    }
586}
587
588impl From<i32> for MetadataValue {
589    fn from(v: i32) -> Self {
590        MetadataValue::Int(v as i64)
591    }
592}
593
594impl From<f64> for MetadataValue {
595    fn from(v: f64) -> Self {
596        MetadataValue::Float(v)
597    }
598}
599
600impl From<f32> for MetadataValue {
601    fn from(v: f32) -> Self {
602        MetadataValue::Float(v as f64)
603    }
604}
605
606impl From<String> for MetadataValue {
607    fn from(v: String) -> Self {
608        MetadataValue::Str(v)
609    }
610}
611
612impl From<&str> for MetadataValue {
613    fn from(v: &str) -> Self {
614        MetadataValue::Str(v.to_string())
615    }
616}
617
618impl From<SparseVector> for MetadataValue {
619    fn from(v: SparseVector) -> Self {
620        MetadataValue::SparseVector(v)
621    }
622}
623
624/// We need `Eq` and `Ord` since we want to use this as a key in `BTreeMap`
625///
626/// For cross-type comparisons, we define a consistent ordering based on variant position:
627/// Bool < Int < Float < Str < SparseVector
628#[allow(clippy::derive_ord_xor_partial_ord)]
629impl Ord for MetadataValue {
630    fn cmp(&self, other: &Self) -> Ordering {
631        // Define type ordering based on variant position
632        fn type_order(val: &MetadataValue) -> u8 {
633            match val {
634                MetadataValue::Bool(_) => 0,
635                MetadataValue::Int(_) => 1,
636                MetadataValue::Float(_) => 2,
637                MetadataValue::Str(_) => 3,
638                MetadataValue::SparseVector(_) => 4,
639            }
640        }
641
642        // Chain type ordering with value ordering
643        type_order(self).cmp(&type_order(other)).then_with(|| {
644            match (self, other) {
645                (MetadataValue::Bool(left), MetadataValue::Bool(right)) => left.cmp(right),
646                (MetadataValue::Int(left), MetadataValue::Int(right)) => left.cmp(right),
647                (MetadataValue::Float(left), MetadataValue::Float(right)) => left.total_cmp(right),
648                (MetadataValue::Str(left), MetadataValue::Str(right)) => left.cmp(right),
649                (MetadataValue::SparseVector(left), MetadataValue::SparseVector(right)) => {
650                    left.cmp(right)
651                }
652                _ => Ordering::Equal, // Different types, but type_order already handled this
653            }
654        })
655    }
656}
657
impl PartialOrd for MetadataValue {
    /// Delegates to [`Ord::cmp`], so the result is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
663
664impl TryFrom<&MetadataValue> for bool {
665    type Error = MetadataValueConversionError;
666
667    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
668        match value {
669            MetadataValue::Bool(value) => Ok(*value),
670            _ => Err(MetadataValueConversionError::InvalidValue),
671        }
672    }
673}
674
675impl TryFrom<&MetadataValue> for i64 {
676    type Error = MetadataValueConversionError;
677
678    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
679        match value {
680            MetadataValue::Int(value) => Ok(*value),
681            _ => Err(MetadataValueConversionError::InvalidValue),
682        }
683    }
684}
685
686impl TryFrom<&MetadataValue> for f64 {
687    type Error = MetadataValueConversionError;
688
689    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
690        match value {
691            MetadataValue::Float(value) => Ok(*value),
692            _ => Err(MetadataValueConversionError::InvalidValue),
693        }
694    }
695}
696
697impl TryFrom<&MetadataValue> for String {
698    type Error = MetadataValueConversionError;
699
700    fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
701        match value {
702            MetadataValue::Str(value) => Ok(value.clone()),
703            _ => Err(MetadataValueConversionError::InvalidValue),
704        }
705    }
706}
707
708impl From<MetadataValue> for UpdateMetadataValue {
709    fn from(value: MetadataValue) -> Self {
710        match value {
711            MetadataValue::Bool(v) => UpdateMetadataValue::Bool(v),
712            MetadataValue::Int(v) => UpdateMetadataValue::Int(v),
713            MetadataValue::Float(v) => UpdateMetadataValue::Float(v),
714            MetadataValue::Str(v) => UpdateMetadataValue::Str(v),
715            MetadataValue::SparseVector(v) => UpdateMetadataValue::SparseVector(v),
716        }
717    }
718}
719
720impl From<MetadataValue> for Value {
721    fn from(value: MetadataValue) -> Self {
722        match value {
723            MetadataValue::Bool(val) => Self::Bool(val),
724            MetadataValue::Int(val) => Self::Number(
725                Number::from_i128(val as i128).expect("i64 should be representable in JSON"),
726            ),
727            MetadataValue::Float(val) => Self::Number(
728                Number::from_f64(val).expect("Inf and NaN should not be present in MetadataValue"),
729            ),
730            MetadataValue::Str(val) => Self::String(val),
731            MetadataValue::SparseVector(val) => {
732                let mut map = serde_json::Map::new();
733                map.insert(
734                    "indices".to_string(),
735                    Value::Array(
736                        val.indices
737                            .iter()
738                            .map(|&i| Value::Number(i.into()))
739                            .collect(),
740                    ),
741                );
742                map.insert(
743                    "values".to_string(),
744                    Value::Array(
745                        val.values
746                            .iter()
747                            .map(|&v| {
748                                Value::Number(
749                                    Number::from_f64(v as f64)
750                                        .expect("Float number should not be NaN or infinite"),
751                                )
752                            })
753                            .collect(),
754                    ),
755                );
756                Self::Object(map)
757            }
758        }
759    }
760}
761
/// Failures that can occur while converting or validating metadata values.
#[derive(Error, Debug)]
pub enum MetadataValueConversionError {
    /// The value has no valid representation in the requested target type.
    // NOTE(review): the message lists only Int, Float and Str, although Bool
    // and SparseVector are also accepted variants of `MetadataValue`.
    #[error("Invalid metadata value, valid values are: Int, Float, Str")]
    InvalidValue,
    /// The metadata key uses a reserved prefix; the payload is the key.
    #[error("Metadata key cannot start with '#' or '$': {0}")]
    InvalidKey(String),
    /// The parallel arrays of a sparse vector differ in length.
    #[error("Sparse vector indices, values, and tokens (when present) must have the same length")]
    SparseVectorLengthMismatch,
    /// The indices of a sparse vector are unsorted or contain a duplicate.
    #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")]
    SparseVectorIndicesNotSorted,
}
773
774impl ChromaError for MetadataValueConversionError {
775    fn code(&self) -> ErrorCodes {
776        match self {
777            MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
778            MetadataValueConversionError::InvalidKey(_) => ErrorCodes::InvalidArgument,
779            MetadataValueConversionError::SparseVectorLengthMismatch => ErrorCodes::InvalidArgument,
780            MetadataValueConversionError::SparseVectorIndicesNotSorted => {
781                ErrorCodes::InvalidArgument
782            }
783        }
784    }
785}
786
787impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue {
788    type Error = MetadataValueConversionError;
789
790    fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
791        match &value.value {
792            Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
793                Ok(MetadataValue::Bool(*value))
794            }
795            Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
796                Ok(MetadataValue::Int(*value))
797            }
798            Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
799                Ok(MetadataValue::Float(*value))
800            }
801            Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
802                Ok(MetadataValue::Str(value.clone()))
803            }
804            Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
805                let sparse = value
806                    .clone()
807                    .try_into()
808                    .map_err(|_| MetadataValueConversionError::SparseVectorLengthMismatch)?;
809                Ok(MetadataValue::SparseVector(sparse))
810            }
811            _ => Err(MetadataValueConversionError::InvalidValue),
812        }
813    }
814}
815
816impl From<MetadataValue> for chroma_proto::UpdateMetadataValue {
817    fn from(value: MetadataValue) -> Self {
818        match value {
819            MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
820                value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
821            },
822            MetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
823                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
824                    value,
825                )),
826            },
827            MetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
828                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
829                    value,
830                )),
831            },
832            MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
833                value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
834            },
835            MetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
836                value: Some(
837                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
838                        sparse_vec.into(),
839                    ),
840                ),
841            },
842        }
843    }
844}
845
846/*
847===========================================
848UpdateMetadata
849===========================================
850*/
/// Metadata update: maps each key to its new value, or to
/// `UpdateMetadataValue::None` to request deletion of the key.
pub type UpdateMetadata = HashMap<String, UpdateMetadataValue>;
852
853/**
854 * Check if two metadata are close to equal. Ignores small differences in float values.
855 */
856pub fn are_update_metadatas_close_to_equal(
857    metadata1: &UpdateMetadata,
858    metadata2: &UpdateMetadata,
859) -> bool {
860    assert_eq!(metadata1.len(), metadata2.len());
861
862    for (key, value) in metadata1.iter() {
863        if !metadata2.contains_key(key) {
864            return false;
865        }
866        let other_value = metadata2.get(key).unwrap();
867
868        if let (UpdateMetadataValue::Float(value), UpdateMetadataValue::Float(other_value)) =
869            (value, other_value)
870        {
871            if (value - other_value).abs() > 1e-6 {
872                return false;
873            }
874        } else if value != other_value {
875            return false;
876        }
877    }
878
879    true
880}
881
882pub fn are_metadatas_close_to_equal(metadata1: &Metadata, metadata2: &Metadata) -> bool {
883    assert_eq!(metadata1.len(), metadata2.len());
884
885    for (key, value) in metadata1.iter() {
886        if !metadata2.contains_key(key) {
887            return false;
888        }
889        let other_value = metadata2.get(key).unwrap();
890
891        if let (MetadataValue::Float(value), MetadataValue::Float(other_value)) =
892            (value, other_value)
893        {
894            if (value - other_value).abs() > 1e-6 {
895                return false;
896            }
897        } else if value != other_value {
898            return false;
899        }
900    }
901
902    true
903}
904
905impl TryFrom<chroma_proto::UpdateMetadata> for UpdateMetadata {
906    type Error = UpdateMetadataValueConversionError;
907
908    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
909        let mut metadata = UpdateMetadata::new();
910        for (key, value) in proto_metadata.metadata.iter() {
911            let value = match value.try_into() {
912                Ok(value) => value,
913                Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue),
914            };
915            metadata.insert(key.clone(), value);
916        }
917        Ok(metadata)
918    }
919}
920
921impl From<UpdateMetadata> for chroma_proto::UpdateMetadata {
922    fn from(metadata: UpdateMetadata) -> Self {
923        let mut metadata = metadata;
924        let mut proto_metadata = chroma_proto::UpdateMetadata {
925            metadata: HashMap::new(),
926        };
927        for (key, value) in metadata.drain() {
928            let proto_value = value.into();
929            proto_metadata.metadata.insert(key.clone(), proto_value);
930        }
931        proto_metadata
932    }
933}
934
935/*
936===========================================
937Metadata
938===========================================
939*/
940
/// Map from a metadata key to the [`MetadataValue`] stored under it.
pub type Metadata = HashMap<String, MetadataValue>;
/// Set of metadata keys; presumably the keys that were deleted — confirm against callers.
pub type DeletedMetadata = HashSet<String>;
943
944pub fn logical_size_of_metadata(metadata: &Metadata) -> usize {
945    metadata
946        .iter()
947        .map(|(k, v)| {
948            k.len()
949                + match v {
950                    MetadataValue::Bool(b) => size_of_val(b),
951                    MetadataValue::Int(i) => size_of_val(i),
952                    MetadataValue::Float(f) => size_of_val(f),
953                    MetadataValue::Str(s) => s.len(),
954                    MetadataValue::SparseVector(v) => {
955                        size_of_val(&v.indices[..]) + size_of_val(&v.values[..])
956                    }
957                }
958        })
959        .sum()
960}
961
962pub fn get_metadata_value_as<'a, T>(
963    metadata: &'a Metadata,
964    key: &str,
965) -> Result<T, Box<MetadataValueConversionError>>
966where
967    T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
968{
969    let res = match metadata.get(key) {
970        Some(value) => T::try_from(value),
971        None => return Err(Box::new(MetadataValueConversionError::InvalidValue)),
972    };
973    match res {
974        Ok(value) => Ok(value),
975        Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)),
976    }
977}
978
979impl TryFrom<chroma_proto::UpdateMetadata> for Metadata {
980    type Error = MetadataValueConversionError;
981
982    fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
983        let mut metadata = Metadata::new();
984        for (key, value) in proto_metadata.metadata.iter() {
985            let maybe_value: Result<MetadataValue, Self::Error> = value.try_into();
986            if maybe_value.is_err() {
987                return Err(MetadataValueConversionError::InvalidValue);
988            }
989            let value = maybe_value.unwrap();
990            metadata.insert(key.clone(), value);
991        }
992        Ok(metadata)
993    }
994}
995
996impl From<Metadata> for chroma_proto::UpdateMetadata {
997    fn from(metadata: Metadata) -> Self {
998        let mut metadata = metadata;
999        let mut proto_metadata = chroma_proto::UpdateMetadata {
1000            metadata: HashMap::new(),
1001        };
1002        for (key, value) in metadata.drain() {
1003            let proto_value = value.into();
1004            proto_metadata.metadata.insert(key.clone(), proto_value);
1005        }
1006        proto_metadata
1007    }
1008}
1009
/// Borrowed view of a metadata change, split into updates, deletions and
/// insertions. All keys and values are references with the lifetime
/// `'referred_data`; nothing is owned by the delta itself.
#[derive(Debug, Default)]
pub struct MetadataDelta<'referred_data> {
    /// Per key, a pair of value references.
    // NOTE(review): presumably (old value, new value) — confirm against the code that builds the delta.
    pub metadata_to_update: HashMap<
        &'referred_data str,
        (&'referred_data MetadataValue, &'referred_data MetadataValue),
    >,
    /// Keys to delete, each with a reference to a metadata value.
    pub metadata_to_delete: HashMap<&'referred_data str, &'referred_data MetadataValue>,
    /// Keys to insert, each with a reference to a metadata value.
    pub metadata_to_insert: HashMap<&'referred_data str, &'referred_data MetadataValue>,
}
1019
impl MetadataDelta<'_> {
    /// Creates a delta via `Default`, i.e. with all three maps empty.
    pub fn new() -> Self {
        Self::default()
    }
}
1025
1026/*
1027===========================================
1028Metadata queries
1029===========================================
1030*/
1031
/// Error produced while converting `Where` clauses to or from their protobuf
/// representation.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum WhereConversionError {
    /// Root cause, carrying a message. Displayed as `Error: <message>`.
    #[error("Error: {0}")]
    Cause(String),
    /// A context label wrapped around an inner error.
    /// Displayed as `<context> -> <inner error>`.
    #[error("{0} -> {1}")]
    Trace(String, Box<Self>),
}
1039
impl WhereConversionError {
    /// Builds a [`WhereConversionError::Cause`] from anything convertible to a string.
    pub fn cause(msg: impl ToString) -> Self {
        Self::Cause(msg.to_string())
    }

    /// Wraps `self` in a [`WhereConversionError::Trace`] labelled with `context`.
    pub fn trace(self, context: impl ToString) -> Self {
        Self::Trace(context.to_string(), Box::new(self))
    }
}
1049
/// This `Where` enum serves as a unified representation for the `where` and `where_document` clauses.
/// Although this is not unified at the API level due to legacy design choices, in the future we will be
/// unifying them together, and the structure of the unified AST should be identical to the one here.
/// Currently both `where` and `where_document` clauses will be translated into `Where`, and if both are
/// present we simply create a conjunction of both clauses as the actual filter. This is consistent with
/// the semantics we used to have when the `where` and `where_document` clauses are treated separately.
// TODO: Remove this note once the `where` clause and `where_document` clause are unified at the API level.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Where {
    /// Boolean combination (`And` / `Or`) of child clauses.
    Composite(CompositeExpression),
    /// Predicate on the document text (operator plus pattern).
    Document(DocumentExpression),
    /// Predicate on a single metadata key.
    Metadata(MetadataExpression),
}
1064
1065impl std::fmt::Display for Where {
1066    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1067        match self {
1068            Where::Composite(composite) => {
1069                let fragment = composite
1070                    .children
1071                    .iter()
1072                    .map(|child| format!("{}", child))
1073                    .collect::<Vec<_>>()
1074                    .join(match composite.operator {
1075                        BooleanOperator::And => " & ",
1076                        BooleanOperator::Or => " | ",
1077                    });
1078                write!(f, "({})", fragment)
1079            }
1080            Where::Metadata(expr) => write!(f, "{}", expr),
1081            Where::Document(expr) => write!(f, "{}", expr),
1082        }
1083    }
1084}
1085
impl serde::Serialize for Where {
    /// Serializes the clause as a single-entry map:
    ///
    /// * composite: `{"$and": [children...]}` or `{"$or": [children...]}`
    /// * document: `{"#document": {"<op>": "<pattern>"}}` where `<op>` is one of
    ///   `$contains`, `$not_contains`, `$regex`, `$not_regex`
    /// * metadata: `{"<key>": {"<op>": <value or list of values>}}` where `<op>` is
    ///   one of `$eq`, `$ne`, `$gt`, `$gte`, `$lt`, `$lte`, `$in`, `$nin`
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            Where::Composite(composite) => {
                let mut map = serializer.serialize_map(Some(1))?;
                let op_key = match composite.operator {
                    BooleanOperator::And => "$and",
                    BooleanOperator::Or => "$or",
                };
                // Children are serialized through this same impl.
                map.serialize_entry(op_key, &composite.children)?;
                map.end()
            }
            Where::Document(doc) => {
                let mut outer_map = serializer.serialize_map(Some(1))?;
                let mut inner_map = serde_json::Map::new();
                let op_key = match doc.operator {
                    DocumentOperator::Contains => "$contains",
                    DocumentOperator::NotContains => "$not_contains",
                    DocumentOperator::Regex => "$regex",
                    DocumentOperator::NotRegex => "$not_regex",
                };
                inner_map.insert(
                    op_key.to_string(),
                    serde_json::Value::String(doc.pattern.clone()),
                );
                // Document predicates are keyed under the fixed name "#document".
                outer_map.serialize_entry("#document", &inner_map)?;
                outer_map.end()
            }
            Where::Metadata(meta) => {
                let mut outer_map = serializer.serialize_map(Some(1))?;
                let mut inner_map = serde_json::Map::new();

                match &meta.comparison {
                    MetadataComparison::Primitive(op, value) => {
                        let op_key = match op {
                            PrimitiveOperator::Equal => "$eq",
                            PrimitiveOperator::NotEqual => "$ne",
                            PrimitiveOperator::GreaterThan => "$gt",
                            PrimitiveOperator::GreaterThanOrEqual => "$gte",
                            PrimitiveOperator::LessThan => "$lt",
                            PrimitiveOperator::LessThanOrEqual => "$lte",
                        };
                        // The operand goes through serde_json first; a failure there is
                        // reported as a custom error of the target serializer.
                        let value_json =
                            serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
                        inner_map.insert(op_key.to_string(), value_json);
                    }
                    MetadataComparison::Set(op, set_value) => {
                        let op_key = match op {
                            SetOperator::In => "$in",
                            SetOperator::NotIn => "$nin",
                        };
                        let values_json = match set_value {
                            MetadataSetValue::Bool(v) => serde_json::to_value(v),
                            MetadataSetValue::Int(v) => serde_json::to_value(v),
                            MetadataSetValue::Float(v) => serde_json::to_value(v),
                            MetadataSetValue::Str(v) => serde_json::to_value(v),
                        }
                        .map_err(serde::ser::Error::custom)?;
                        inner_map.insert(op_key.to_string(), values_json);
                    }
                }

                // The metadata key itself is the outer map key.
                outer_map.serialize_entry(&meta.key, &inner_map)?;
                outer_map.end()
            }
        }
    }
}
1157
1158impl From<bool> for Where {
1159    fn from(value: bool) -> Self {
1160        if value {
1161            Where::conjunction(vec![])
1162        } else {
1163            Where::disjunction(vec![])
1164        }
1165    }
1166}
1167
1168impl Where {
1169    pub fn conjunction(children: impl IntoIterator<Item = Where>) -> Self {
1170        // If children.len() == 0, we will return a conjunction that is always true.
1171        // If children.len() == 1, we will return the single child.
1172        // Otherwise, we will return a conjunction of the children.
1173
1174        let mut children: Vec<_> = children
1175            .into_iter()
1176            .flat_map(|expr| {
1177                if let Where::Composite(CompositeExpression {
1178                    operator: BooleanOperator::And,
1179                    children,
1180                }) = expr
1181                {
1182                    return children;
1183                }
1184                vec![expr]
1185            })
1186            .dedup()
1187            .collect();
1188
1189        if children.len() == 1 {
1190            return children.pop().expect("just checked len is 1");
1191        }
1192
1193        Self::Composite(CompositeExpression {
1194            operator: BooleanOperator::And,
1195            children,
1196        })
1197    }
1198    pub fn disjunction(children: impl IntoIterator<Item = Where>) -> Self {
1199        // If children.len() == 0, we will return a disjunction that is always false.
1200        // If children.len() == 1, we will return the single child.
1201        // Otherwise, we will return a disjunction of the children.
1202
1203        let mut children: Vec<_> = children
1204            .into_iter()
1205            .flat_map(|expr| {
1206                if let Where::Composite(CompositeExpression {
1207                    operator: BooleanOperator::Or,
1208                    children,
1209                }) = expr
1210                {
1211                    return children;
1212                }
1213                vec![expr]
1214            })
1215            .dedup()
1216            .collect();
1217
1218        if children.len() == 1 {
1219            return children.pop().expect("just checked len is 1");
1220        }
1221
1222        Self::Composite(CompositeExpression {
1223            operator: BooleanOperator::Or,
1224            children,
1225        })
1226    }
1227
1228    pub fn fts_query_length(&self) -> u64 {
1229        match self {
1230            Where::Composite(composite_expression) => composite_expression
1231                .children
1232                .iter()
1233                .map(Where::fts_query_length)
1234                .sum(),
1235            // The query length is defined to be the number of trigram tokens
1236            Where::Document(document_expression) => {
1237                document_expression.pattern.len().max(3) as u64 - 2
1238            }
1239            Where::Metadata(_) => 0,
1240        }
1241    }
1242
1243    pub fn metadata_predicate_count(&self) -> u64 {
1244        match self {
1245            Where::Composite(composite_expression) => composite_expression
1246                .children
1247                .iter()
1248                .map(Where::metadata_predicate_count)
1249                .sum(),
1250            Where::Document(_) => 0,
1251            Where::Metadata(metadata_expression) => match &metadata_expression.comparison {
1252                MetadataComparison::Primitive(_, _) => 1,
1253                MetadataComparison::Set(_, metadata_set_value) => match metadata_set_value {
1254                    MetadataSetValue::Bool(items) => items.len() as u64,
1255                    MetadataSetValue::Int(items) => items.len() as u64,
1256                    MetadataSetValue::Float(items) => items.len() as u64,
1257                    MetadataSetValue::Str(items) => items.len() as u64,
1258                },
1259            },
1260        }
1261    }
1262}
1263
/// `a & b` is shorthand for `Where::conjunction([a, b])`.
impl BitAnd for Where {
    type Output = Where;

    fn bitand(self, rhs: Self) -> Self::Output {
        Self::conjunction([self, rhs])
    }
}
1271
/// `a | b` is shorthand for `Where::disjunction([a, b])`.
impl BitOr for Where {
    type Output = Where;

    fn bitor(self, rhs: Self) -> Self::Output {
        Self::disjunction([self, rhs])
    }
}
1279
1280impl TryFrom<chroma_proto::Where> for Where {
1281    type Error = WhereConversionError;
1282
1283    fn try_from(proto_where: chroma_proto::Where) -> Result<Self, Self::Error> {
1284        let where_inner = proto_where
1285            .r#where
1286            .ok_or(WhereConversionError::cause("Invalid Where"))?;
1287        Ok(match where_inner {
1288            chroma_proto::r#where::Where::DirectComparison(direct_comparison) => {
1289                Self::Metadata(direct_comparison.try_into()?)
1290            }
1291            chroma_proto::r#where::Where::Children(where_children) => {
1292                Self::Composite(where_children.try_into()?)
1293            }
1294            chroma_proto::r#where::Where::DirectDocumentComparison(direct_where_document) => {
1295                Self::Document(direct_where_document.into())
1296            }
1297        })
1298    }
1299}
1300
1301impl TryFrom<Where> for chroma_proto::Where {
1302    type Error = WhereConversionError;
1303
1304    fn try_from(value: Where) -> Result<Self, Self::Error> {
1305        let proto_where = match value {
1306            Where::Composite(composite_expression) => {
1307                chroma_proto::r#where::Where::Children(composite_expression.try_into()?)
1308            }
1309            Where::Document(document_expression) => {
1310                chroma_proto::r#where::Where::DirectDocumentComparison(document_expression.into())
1311            }
1312            Where::Metadata(metadata_expression) => chroma_proto::r#where::Where::DirectComparison(
1313                chroma_proto::DirectComparison::try_from(metadata_expression)
1314                    .map_err(|err| err.trace("MetadataExpression"))?,
1315            ),
1316        };
1317        Ok(Self {
1318            r#where: Some(proto_where),
1319        })
1320    }
1321}
1322
/// A boolean combination of `Where` clauses.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct CompositeExpression {
    /// Boolean operator combining the children.
    pub operator: BooleanOperator,
    /// The combined sub-clauses.
    pub children: Vec<Where>,
}
1329
1330impl TryFrom<chroma_proto::WhereChildren> for CompositeExpression {
1331    type Error = WhereConversionError;
1332
1333    fn try_from(proto_children: chroma_proto::WhereChildren) -> Result<Self, Self::Error> {
1334        let operator = proto_children.operator().into();
1335        let children = proto_children
1336            .children
1337            .into_iter()
1338            .map(Where::try_from)
1339            .collect::<Result<Vec<_>, _>>()
1340            .map_err(|err| err.trace("Child Where of CompositeExpression"))?;
1341        Ok(Self { operator, children })
1342    }
1343}
1344
1345impl TryFrom<CompositeExpression> for chroma_proto::WhereChildren {
1346    type Error = WhereConversionError;
1347
1348    fn try_from(value: CompositeExpression) -> Result<Self, Self::Error> {
1349        Ok(Self {
1350            operator: chroma_proto::BooleanOperator::from(value.operator) as i32,
1351            children: value
1352                .children
1353                .into_iter()
1354                .map(chroma_proto::Where::try_from)
1355                .collect::<Result<_, _>>()?,
1356        })
1357    }
1358}
1359
/// Boolean operator used by [`CompositeExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum BooleanOperator {
    /// Logical AND (serialized as `$and`, displayed as `&`).
    And,
    /// Logical OR (serialized as `$or`, displayed as `|`).
    Or,
}
1366
/// One-to-one mapping from the protobuf boolean operator.
impl From<chroma_proto::BooleanOperator> for BooleanOperator {
    fn from(value: chroma_proto::BooleanOperator) -> Self {
        match value {
            chroma_proto::BooleanOperator::And => Self::And,
            chroma_proto::BooleanOperator::Or => Self::Or,
        }
    }
}
1375
/// One-to-one mapping to the protobuf boolean operator.
impl From<BooleanOperator> for chroma_proto::BooleanOperator {
    fn from(value: BooleanOperator) -> Self {
        match value {
            BooleanOperator::And => Self::And,
            BooleanOperator::Or => Self::Or,
        }
    }
}
1384
/// A predicate on the document text.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct DocumentExpression {
    /// How `pattern` is matched (contains / regex, or their negations).
    pub operator: DocumentOperator,
    /// The text or regex pattern to match.
    pub pattern: String,
}
1391
1392impl std::fmt::Display for DocumentExpression {
1393    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1394        let op_str = match self.operator {
1395            DocumentOperator::Contains => "CONTAINS",
1396            DocumentOperator::NotContains => "NOT CONTAINS",
1397            DocumentOperator::Regex => "REGEX",
1398            DocumentOperator::NotRegex => "NOT REGEX",
1399        };
1400        write!(f, "#document {} \"{}\"", op_str, self.pattern)
1401    }
1402}
1403
1404impl From<chroma_proto::DirectWhereDocument> for DocumentExpression {
1405    fn from(value: chroma_proto::DirectWhereDocument) -> Self {
1406        Self {
1407            operator: value.operator().into(),
1408            pattern: value.pattern,
1409        }
1410    }
1411}
1412
1413impl From<DocumentExpression> for chroma_proto::DirectWhereDocument {
1414    fn from(value: DocumentExpression) -> Self {
1415        Self {
1416            pattern: value.pattern,
1417            operator: chroma_proto::WhereDocumentOperator::from(value.operator) as i32,
1418        }
1419    }
1420}
1421
/// Matching mode of a [`DocumentExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum DocumentOperator {
    /// Serialized as `$contains`.
    Contains,
    /// Serialized as `$not_contains`.
    NotContains,
    /// Serialized as `$regex`.
    Regex,
    /// Serialized as `$not_regex`.
    NotRegex,
}
/// One-to-one mapping from the protobuf document operator.
impl From<chroma_proto::WhereDocumentOperator> for DocumentOperator {
    fn from(value: chroma_proto::WhereDocumentOperator) -> Self {
        match value {
            chroma_proto::WhereDocumentOperator::Contains => Self::Contains,
            chroma_proto::WhereDocumentOperator::NotContains => Self::NotContains,
            chroma_proto::WhereDocumentOperator::Regex => Self::Regex,
            chroma_proto::WhereDocumentOperator::NotRegex => Self::NotRegex,
        }
    }
}
1440
/// One-to-one mapping to the protobuf document operator.
impl From<DocumentOperator> for chroma_proto::WhereDocumentOperator {
    fn from(value: DocumentOperator) -> Self {
        match value {
            DocumentOperator::Contains => Self::Contains,
            DocumentOperator::NotContains => Self::NotContains,
            DocumentOperator::Regex => Self::Regex,
            DocumentOperator::NotRegex => Self::NotRegex,
        }
    }
}
1451
/// A predicate on a single metadata key.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct MetadataExpression {
    /// The metadata key the comparison applies to.
    pub key: String,
    /// The comparison (operator plus operand) applied to that key.
    pub comparison: MetadataComparison,
}
1458
1459impl std::fmt::Display for MetadataExpression {
1460    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1461        match &self.comparison {
1462            MetadataComparison::Primitive(op, value) => {
1463                write!(f, "{} {} {}", self.key, op, value)
1464            }
1465            MetadataComparison::Set(op, set_value) => {
1466                write!(f, "{} {} {}", self.key, op, set_value)
1467            }
1468        }
1469    }
1470}
1471
1472impl TryFrom<chroma_proto::DirectComparison> for MetadataExpression {
1473    type Error = WhereConversionError;
1474
1475    fn try_from(value: chroma_proto::DirectComparison) -> Result<Self, Self::Error> {
1476        let proto_comparison = value
1477            .comparison
1478            .ok_or(WhereConversionError::cause("Invalid MetadataExpression"))?;
1479        let comparison = match proto_comparison {
1480            chroma_proto::direct_comparison::Comparison::SingleStringOperand(
1481                single_string_comparison,
1482            ) => MetadataComparison::Primitive(
1483                single_string_comparison.comparator().into(),
1484                MetadataValue::Str(single_string_comparison.value),
1485            ),
1486            chroma_proto::direct_comparison::Comparison::StringListOperand(
1487                string_list_comparison,
1488            ) => MetadataComparison::Set(
1489                string_list_comparison.list_operator().into(),
1490                MetadataSetValue::Str(string_list_comparison.values),
1491            ),
1492            chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1493                single_int_comparison,
1494            ) => MetadataComparison::Primitive(
1495                match single_int_comparison
1496                    .comparator
1497                    .ok_or(WhereConversionError::cause(
1498                        "Invalid scalar integer operator",
1499                    ))? {
1500                    chroma_proto::single_int_comparison::Comparator::GenericComparator(op) => {
1501                        chroma_proto::GenericComparator::try_from(op)
1502                            .map_err(WhereConversionError::cause)?
1503                            .into()
1504                    }
1505                    chroma_proto::single_int_comparison::Comparator::NumberComparator(op) => {
1506                        chroma_proto::NumberComparator::try_from(op)
1507                            .map_err(WhereConversionError::cause)?
1508                            .into()
1509                    }
1510                },
1511                MetadataValue::Int(single_int_comparison.value),
1512            ),
1513            chroma_proto::direct_comparison::Comparison::IntListOperand(int_list_comparison) => {
1514                MetadataComparison::Set(
1515                    int_list_comparison.list_operator().into(),
1516                    MetadataSetValue::Int(int_list_comparison.values),
1517                )
1518            }
1519            chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(
1520                single_double_comparison,
1521            ) => MetadataComparison::Primitive(
1522                match single_double_comparison
1523                    .comparator
1524                    .ok_or(WhereConversionError::cause("Invalid scalar float operator"))?
1525                {
1526                    chroma_proto::single_double_comparison::Comparator::GenericComparator(op) => {
1527                        chroma_proto::GenericComparator::try_from(op)
1528                            .map_err(WhereConversionError::cause)?
1529                            .into()
1530                    }
1531                    chroma_proto::single_double_comparison::Comparator::NumberComparator(op) => {
1532                        chroma_proto::NumberComparator::try_from(op)
1533                            .map_err(WhereConversionError::cause)?
1534                            .into()
1535                    }
1536                },
1537                MetadataValue::Float(single_double_comparison.value),
1538            ),
1539            chroma_proto::direct_comparison::Comparison::DoubleListOperand(
1540                double_list_comparison,
1541            ) => MetadataComparison::Set(
1542                double_list_comparison.list_operator().into(),
1543                MetadataSetValue::Float(double_list_comparison.values),
1544            ),
1545            chroma_proto::direct_comparison::Comparison::BoolListOperand(bool_list_comparison) => {
1546                MetadataComparison::Set(
1547                    bool_list_comparison.list_operator().into(),
1548                    MetadataSetValue::Bool(bool_list_comparison.values),
1549                )
1550            }
1551            chroma_proto::direct_comparison::Comparison::SingleBoolOperand(
1552                single_bool_comparison,
1553            ) => MetadataComparison::Primitive(
1554                single_bool_comparison.comparator().into(),
1555                MetadataValue::Bool(single_bool_comparison.value),
1556            ),
1557        };
1558        Ok(Self {
1559            key: value.key,
1560            comparison,
1561        })
1562    }
1563}
1564
1565impl TryFrom<MetadataExpression> for chroma_proto::DirectComparison {
1566    type Error = WhereConversionError;
1567
1568    fn try_from(value: MetadataExpression) -> Result<Self, Self::Error> {
1569        let comparison = match value.comparison {
1570            MetadataComparison::Primitive(primitive_operator, metadata_value) => match metadata_value {
1571                MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1572                MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(match primitive_operator {
1573                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1574                                numeric => chroma_proto::single_int_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1575                            }),
1576                MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(match primitive_operator {
1577                                generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_double_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1578                                numeric => chroma_proto::single_double_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1579                            }),
1580                MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1581                MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Comparison with sparse vector is not supported".to_string())),
1582            },
1583            MetadataComparison::Set(set_operator, metadata_set_value) => match metadata_set_value {
1584                MetadataSetValue::Bool(vec) => chroma_proto::direct_comparison::Comparison::BoolListOperand(chroma_proto::BoolListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1585                MetadataSetValue::Int(vec) => chroma_proto::direct_comparison::Comparison::IntListOperand(chroma_proto::IntListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1586                MetadataSetValue::Float(vec) => chroma_proto::direct_comparison::Comparison::DoubleListOperand(chroma_proto::DoubleListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1587                MetadataSetValue::Str(vec) => chroma_proto::direct_comparison::Comparison::StringListOperand(chroma_proto::StringListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1588            },
1589        };
1590        Ok(Self {
1591            key: value.key,
1592            comparison: Some(comparison),
1593        })
1594    }
1595}
1596
/// The comparison part of a [`MetadataExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum MetadataComparison {
    /// Comparison against a single value.
    Primitive(PrimitiveOperator, MetadataValue),
    /// Membership test against a list of values.
    Set(SetOperator, MetadataSetValue),
}
1603
/// Operator of a single-value metadata comparison.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum PrimitiveOperator {
    /// `=` (serialized as `$eq`).
    Equal,
    /// `≠` (serialized as `$ne`).
    NotEqual,
    /// `>` (serialized as `$gt`).
    GreaterThan,
    /// `≥` (serialized as `$gte`).
    GreaterThanOrEqual,
    /// `<` (serialized as `$lt`).
    LessThan,
    /// `≤` (serialized as `$lte`).
    LessThanOrEqual,
}
1614
1615impl std::fmt::Display for PrimitiveOperator {
1616    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1617        let op_str = match self {
1618            PrimitiveOperator::Equal => "=",
1619            PrimitiveOperator::NotEqual => "≠",
1620            PrimitiveOperator::GreaterThan => ">",
1621            PrimitiveOperator::GreaterThanOrEqual => "≥",
1622            PrimitiveOperator::LessThan => "<",
1623            PrimitiveOperator::LessThanOrEqual => "≤",
1624        };
1625        write!(f, "{}", op_str)
1626    }
1627}
1628
/// Maps the protobuf generic comparator (`Eq` / `Ne`) to the matching operator.
impl From<chroma_proto::GenericComparator> for PrimitiveOperator {
    fn from(value: chroma_proto::GenericComparator) -> Self {
        match value {
            chroma_proto::GenericComparator::Eq => Self::Equal,
            chroma_proto::GenericComparator::Ne => Self::NotEqual,
        }
    }
}
1637
/// Only `Equal` and `NotEqual` have a generic-comparator counterpart.
impl TryFrom<PrimitiveOperator> for chroma_proto::GenericComparator {
    type Error = WhereConversionError;

    /// Returns a `Cause` error naming the operator for anything other than
    /// `Equal` / `NotEqual`.
    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
        match value {
            PrimitiveOperator::Equal => Ok(Self::Eq),
            PrimitiveOperator::NotEqual => Ok(Self::Ne),
            op => Err(WhereConversionError::cause(format!("{op:?} ∉ [=, ≠]"))),
        }
    }
}
1649
1650impl From<chroma_proto::NumberComparator> for PrimitiveOperator {
1651    fn from(value: chroma_proto::NumberComparator) -> Self {
1652        match value {
1653            chroma_proto::NumberComparator::Gt => Self::GreaterThan,
1654            chroma_proto::NumberComparator::Gte => Self::GreaterThanOrEqual,
1655            chroma_proto::NumberComparator::Lt => Self::LessThan,
1656            chroma_proto::NumberComparator::Lte => Self::LessThanOrEqual,
1657        }
1658    }
1659}
1660
1661impl TryFrom<PrimitiveOperator> for chroma_proto::NumberComparator {
1662    type Error = WhereConversionError;
1663
1664    fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1665        match value {
1666            PrimitiveOperator::GreaterThan => Ok(Self::Gt),
1667            PrimitiveOperator::GreaterThanOrEqual => Ok(Self::Gte),
1668            PrimitiveOperator::LessThan => Ok(Self::Lt),
1669            PrimitiveOperator::LessThanOrEqual => Ok(Self::Lte),
1670            op => Err(WhereConversionError::cause(format!(
1671                "{op:?} ∉ [≤, <, >, ≥]"
1672            ))),
1673        }
1674    }
1675}
1676
/// Membership operators that compare a value against a list of operands.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SetOperator {
    /// Displayed as `∈`.
    In,
    /// Displayed as `∉`.
    NotIn,
}

impl std::fmt::Display for SetOperator {
    /// Writes the set-membership symbol for the operator.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::In => "∈",
            Self::NotIn => "∉",
        })
    }
}
1693
1694impl From<chroma_proto::ListOperator> for SetOperator {
1695    fn from(value: chroma_proto::ListOperator) -> Self {
1696        match value {
1697            chroma_proto::ListOperator::In => Self::In,
1698            chroma_proto::ListOperator::Nin => Self::NotIn,
1699        }
1700    }
1701}
1702
1703impl From<SetOperator> for chroma_proto::ListOperator {
1704    fn from(value: SetOperator) -> Self {
1705        match value {
1706            SetOperator::In => Self::In,
1707            SetOperator::NotIn => Self::Nin,
1708        }
1709    }
1710}
1711
/// A homogeneous list of operands for a [`SetOperator`] comparison.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum MetadataSetValue {
    /// Boolean operands.
    Bool(Vec<bool>),
    /// 64-bit signed integer operands.
    Int(Vec<i64>),
    /// 64-bit floating point operands.
    Float(Vec<f64>),
    /// String operands.
    Str(Vec<String>),
}

impl std::fmt::Display for MetadataSetValue {
    /// Renders the operands as a bracketed, comma-separated list.
    ///
    /// Boolean and string elements are wrapped in double quotes; integer and
    /// float elements use their plain `to_string` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered: Vec<String> = match self {
            Self::Bool(items) => items.iter().map(|item| format!("\"{}\"", item)).collect(),
            Self::Int(items) => items.iter().map(ToString::to_string).collect(),
            Self::Float(items) => items.iter().map(ToString::to_string).collect(),
            Self::Str(items) => items.iter().map(|item| format!("\"{}\"", item)).collect(),
        };
        write!(f, "[{}]", rendered.join(", "))
    }
}
1759
1760impl MetadataSetValue {
1761    pub fn value_type(&self) -> MetadataValueType {
1762        match self {
1763            MetadataSetValue::Bool(_) => MetadataValueType::Bool,
1764            MetadataSetValue::Int(_) => MetadataValueType::Int,
1765            MetadataSetValue::Float(_) => MetadataValueType::Float,
1766            MetadataSetValue::Str(_) => MetadataValueType::Str,
1767        }
1768    }
1769}
1770
1771impl From<Vec<bool>> for MetadataSetValue {
1772    fn from(values: Vec<bool>) -> Self {
1773        MetadataSetValue::Bool(values)
1774    }
1775}
1776
1777impl From<Vec<i64>> for MetadataSetValue {
1778    fn from(values: Vec<i64>) -> Self {
1779        MetadataSetValue::Int(values)
1780    }
1781}
1782
1783impl From<Vec<i32>> for MetadataSetValue {
1784    fn from(values: Vec<i32>) -> Self {
1785        MetadataSetValue::Int(values.into_iter().map(|v| v as i64).collect())
1786    }
1787}
1788
1789impl From<Vec<f64>> for MetadataSetValue {
1790    fn from(values: Vec<f64>) -> Self {
1791        MetadataSetValue::Float(values)
1792    }
1793}
1794
1795impl From<Vec<f32>> for MetadataSetValue {
1796    fn from(values: Vec<f32>) -> Self {
1797        MetadataSetValue::Float(values.into_iter().map(|v| v as f64).collect())
1798    }
1799}
1800
1801impl From<Vec<String>> for MetadataSetValue {
1802    fn from(values: Vec<String>) -> Self {
1803        MetadataSetValue::Str(values)
1804    }
1805}
1806
1807impl From<Vec<&str>> for MetadataSetValue {
1808    fn from(values: Vec<&str>) -> Self {
1809        MetadataSetValue::Str(values.into_iter().map(|s| s.to_string()).collect())
1810    }
1811}
1812
1813// TODO: Deprecate where_document
1814impl TryFrom<chroma_proto::WhereDocument> for Where {
1815    type Error = WhereConversionError;
1816
1817    fn try_from(proto_document: chroma_proto::WhereDocument) -> Result<Self, Self::Error> {
1818        match proto_document.r#where_document {
1819            Some(chroma_proto::where_document::WhereDocument::Direct(proto_comparison)) => {
1820                let operator = match TryInto::<chroma_proto::WhereDocumentOperator>::try_into(
1821                    proto_comparison.operator,
1822                ) {
1823                    Ok(operator) => operator,
1824                    Err(_) => {
1825                        return Err(WhereConversionError::cause(
1826                            "[Deprecated] Invalid where document operator",
1827                        ))
1828                    }
1829                };
1830                let comparison = DocumentExpression {
1831                    pattern: proto_comparison.pattern,
1832                    operator: operator.into(),
1833                };
1834                Ok(Where::Document(comparison))
1835            }
1836            Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => {
1837                let operator = match TryInto::<chroma_proto::BooleanOperator>::try_into(
1838                    proto_children.operator,
1839                ) {
1840                    Ok(operator) => operator,
1841                    Err(_) => {
1842                        return Err(WhereConversionError::cause(
1843                            "[Deprecated] Invalid boolean operator",
1844                        ))
1845                    }
1846                };
1847                let children = CompositeExpression {
1848                    children: proto_children
1849                        .children
1850                        .into_iter()
1851                        .map(|child| child.try_into())
1852                        .collect::<Result<_, _>>()?,
1853                    operator: operator.into(),
1854                };
1855                Ok(Where::Composite(children))
1856            }
1857            None => Err(WhereConversionError::cause("[Deprecated] Invalid where")),
1858        }
1859    }
1860}
1861
1862#[cfg(test)]
1863mod tests {
1864    use crate::operator::Key;
1865
1866    use super::*;
1867
1868    // This is needed for the tests that round trip to the python world.
1869    #[cfg(feature = "pyo3")]
1870    fn ensure_python_interpreter() {
1871        static PYTHON_INIT: std::sync::Once = std::sync::Once::new();
1872        PYTHON_INIT.call_once(|| {
1873            pyo3::prepare_freethreaded_python();
1874        });
1875    }
1876
1877    #[test]
1878    fn test_update_metadata_try_from() {
1879        let mut proto_metadata = chroma_proto::UpdateMetadata {
1880            metadata: HashMap::new(),
1881        };
1882        proto_metadata.metadata.insert(
1883            "foo".to_string(),
1884            chroma_proto::UpdateMetadataValue {
1885                value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
1886            },
1887        );
1888        proto_metadata.metadata.insert(
1889            "bar".to_string(),
1890            chroma_proto::UpdateMetadataValue {
1891                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
1892            },
1893        );
1894        proto_metadata.metadata.insert(
1895            "baz".to_string(),
1896            chroma_proto::UpdateMetadataValue {
1897                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1898                    "42".to_string(),
1899                )),
1900            },
1901        );
1902        // Add sparse vector test
1903        proto_metadata.metadata.insert(
1904            "sparse".to_string(),
1905            chroma_proto::UpdateMetadataValue {
1906                value: Some(
1907                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
1908                        chroma_proto::SparseVector {
1909                            indices: vec![0, 5, 10],
1910                            values: vec![0.1, 0.5, 0.9],
1911                            tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
1912                        },
1913                    ),
1914                ),
1915            },
1916        );
1917        let converted_metadata: UpdateMetadata = proto_metadata.try_into().unwrap();
1918        assert_eq!(converted_metadata.len(), 4);
1919        assert_eq!(
1920            converted_metadata.get("foo").unwrap(),
1921            &UpdateMetadataValue::Int(42)
1922        );
1923        assert_eq!(
1924            converted_metadata.get("bar").unwrap(),
1925            &UpdateMetadataValue::Float(42.0)
1926        );
1927        assert_eq!(
1928            converted_metadata.get("baz").unwrap(),
1929            &UpdateMetadataValue::Str("42".to_string())
1930        );
1931        assert_eq!(
1932            converted_metadata.get("sparse").unwrap(),
1933            &UpdateMetadataValue::SparseVector(
1934                SparseVector::new_with_tokens(
1935                    vec![0, 5, 10],
1936                    vec![0.1, 0.5, 0.9],
1937                    vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),],
1938                )
1939                .unwrap()
1940            )
1941        );
1942    }
1943
1944    #[test]
1945    fn test_metadata_try_from() {
1946        let mut proto_metadata = chroma_proto::UpdateMetadata {
1947            metadata: HashMap::new(),
1948        };
1949        proto_metadata.metadata.insert(
1950            "foo".to_string(),
1951            chroma_proto::UpdateMetadataValue {
1952                value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
1953            },
1954        );
1955        proto_metadata.metadata.insert(
1956            "bar".to_string(),
1957            chroma_proto::UpdateMetadataValue {
1958                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
1959            },
1960        );
1961        proto_metadata.metadata.insert(
1962            "baz".to_string(),
1963            chroma_proto::UpdateMetadataValue {
1964                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1965                    "42".to_string(),
1966                )),
1967            },
1968        );
1969        // Add sparse vector test
1970        proto_metadata.metadata.insert(
1971            "sparse".to_string(),
1972            chroma_proto::UpdateMetadataValue {
1973                value: Some(
1974                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
1975                        chroma_proto::SparseVector {
1976                            indices: vec![1, 10, 100],
1977                            values: vec![0.2, 0.4, 0.6],
1978                            tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
1979                        },
1980                    ),
1981                ),
1982            },
1983        );
1984        let converted_metadata: Metadata = proto_metadata.try_into().unwrap();
1985        assert_eq!(converted_metadata.len(), 4);
1986        assert_eq!(
1987            converted_metadata.get("foo").unwrap(),
1988            &MetadataValue::Int(42)
1989        );
1990        assert_eq!(
1991            converted_metadata.get("bar").unwrap(),
1992            &MetadataValue::Float(42.0)
1993        );
1994        assert_eq!(
1995            converted_metadata.get("baz").unwrap(),
1996            &MetadataValue::Str("42".to_string())
1997        );
1998        assert_eq!(
1999            converted_metadata.get("sparse").unwrap(),
2000            &MetadataValue::SparseVector(
2001                SparseVector::new_with_tokens(
2002                    vec![1, 10, 100],
2003                    vec![0.2, 0.4, 0.6],
2004                    vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),],
2005                )
2006                .unwrap()
2007            )
2008        );
2009    }
2010
2011    #[test]
2012    fn test_where_clause_simple_from() {
2013        let proto_where = chroma_proto::Where {
2014            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2015                chroma_proto::DirectComparison {
2016                    key: "foo".to_string(),
2017                    comparison: Some(
2018                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2019                            chroma_proto::SingleIntComparison {
2020                                value: 42,
2021                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2022                            },
2023                        ),
2024                    ),
2025                },
2026            )),
2027        };
2028        let where_clause: Where = proto_where.try_into().unwrap();
2029        match where_clause {
2030            Where::Metadata(comparison) => {
2031                assert_eq!(comparison.key, "foo");
2032                match comparison.comparison {
2033                    MetadataComparison::Primitive(_, value) => {
2034                        assert_eq!(value, MetadataValue::Int(42));
2035                    }
2036                    _ => panic!("Invalid comparison type"),
2037                }
2038            }
2039            _ => panic!("Invalid where type"),
2040        }
2041    }
2042
2043    #[test]
2044    fn test_where_clause_with_children() {
2045        let proto_where = chroma_proto::Where {
2046            r#where: Some(chroma_proto::r#where::Where::Children(
2047                chroma_proto::WhereChildren {
2048                    children: vec![
2049                        chroma_proto::Where {
2050                            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2051                                chroma_proto::DirectComparison {
2052                                    key: "foo".to_string(),
2053                                    comparison: Some(
2054                                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2055                                            chroma_proto::SingleIntComparison {
2056                                                value: 42,
2057                                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2058                                            },
2059                                        ),
2060                                    ),
2061                                },
2062                            )),
2063                        },
2064                        chroma_proto::Where {
2065                            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2066                                chroma_proto::DirectComparison {
2067                                    key: "bar".to_string(),
2068                                    comparison: Some(
2069                                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2070                                            chroma_proto::SingleIntComparison {
2071                                                value: 42,
2072                                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2073                                            },
2074                                        ),
2075                                    ),
2076                                },
2077                            )),
2078                        },
2079                    ],
2080                    operator: chroma_proto::BooleanOperator::And.into(),
2081                },
2082            )),
2083        };
2084        let where_clause: Where = proto_where.try_into().unwrap();
2085        match where_clause {
2086            Where::Composite(children) => {
2087                assert_eq!(children.children.len(), 2);
2088                assert_eq!(children.operator, BooleanOperator::And);
2089            }
2090            _ => panic!("Invalid where type"),
2091        }
2092    }
2093
2094    #[test]
2095    fn test_where_document_simple() {
2096        let proto_where = chroma_proto::WhereDocument {
2097            r#where_document: Some(chroma_proto::where_document::WhereDocument::Direct(
2098                chroma_proto::DirectWhereDocument {
2099                    pattern: "foo".to_string(),
2100                    operator: chroma_proto::WhereDocumentOperator::Contains.into(),
2101                },
2102            )),
2103        };
2104        let where_document: Where = proto_where.try_into().unwrap();
2105        match where_document {
2106            Where::Document(comparison) => {
2107                assert_eq!(comparison.pattern, "foo");
2108                assert_eq!(comparison.operator, DocumentOperator::Contains);
2109            }
2110            _ => panic!("Invalid where document type"),
2111        }
2112    }
2113
2114    #[test]
2115    fn test_where_document_with_children() {
2116        let proto_where = chroma_proto::WhereDocument {
2117            r#where_document: Some(chroma_proto::where_document::WhereDocument::Children(
2118                chroma_proto::WhereDocumentChildren {
2119                    children: vec![
2120                        chroma_proto::WhereDocument {
2121                            r#where_document: Some(
2122                                chroma_proto::where_document::WhereDocument::Direct(
2123                                    chroma_proto::DirectWhereDocument {
2124                                        pattern: "foo".to_string(),
2125                                        operator: chroma_proto::WhereDocumentOperator::Contains
2126                                            .into(),
2127                                    },
2128                                ),
2129                            ),
2130                        },
2131                        chroma_proto::WhereDocument {
2132                            r#where_document: Some(
2133                                chroma_proto::where_document::WhereDocument::Direct(
2134                                    chroma_proto::DirectWhereDocument {
2135                                        pattern: "bar".to_string(),
2136                                        operator: chroma_proto::WhereDocumentOperator::Contains
2137                                            .into(),
2138                                    },
2139                                ),
2140                            ),
2141                        },
2142                    ],
2143                    operator: chroma_proto::BooleanOperator::And.into(),
2144                },
2145            )),
2146        };
2147        let where_document: Where = proto_where.try_into().unwrap();
2148        match where_document {
2149            Where::Composite(children) => {
2150                assert_eq!(children.children.len(), 2);
2151                assert_eq!(children.operator, BooleanOperator::And);
2152            }
2153            _ => panic!("Invalid where document type"),
2154        }
2155    }
2156
2157    #[test]
2158    fn test_sparse_vector_new() {
2159        let indices = vec![0, 5, 10];
2160        let values = vec![0.1, 0.5, 0.9];
2161        let sparse = SparseVector::new(indices.clone(), values.clone()).unwrap();
2162        assert_eq!(sparse.indices, indices);
2163        assert_eq!(sparse.values, values);
2164    }
2165
2166    #[test]
2167    fn test_sparse_vector_from_pairs() {
2168        let pairs = vec![(0, 0.1), (5, 0.5), (10, 0.9)];
2169        let sparse = SparseVector::from_pairs(pairs.clone());
2170        assert_eq!(sparse.indices, vec![0, 5, 10]);
2171        assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
2172    }
2173
2174    #[test]
2175    fn test_sparse_vector_from_triples() {
2176        let triples = vec![
2177            ("foo".to_string(), 0, 0.1),
2178            ("bar".to_string(), 5, 0.5),
2179            ("baz".to_string(), 10, 0.9),
2180        ];
2181        let sparse = SparseVector::from_triples(triples.clone());
2182        assert_eq!(sparse.indices, vec![0, 5, 10]);
2183        assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
2184    }
2185
2186    #[test]
2187    fn test_sparse_vector_iter() {
2188        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2189        let collected: Vec<(u32, f32)> = sparse.iter().collect();
2190        assert_eq!(collected, vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
2191    }
2192
2193    #[test]
2194    fn test_sparse_vector_ordering() {
2195        let sparse1 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]).unwrap();
2196        let sparse2 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]).unwrap();
2197        let sparse3 = SparseVector::new(vec![0, 6], vec![0.1, 0.5]).unwrap();
2198        let sparse4 = SparseVector::new(vec![0, 5], vec![0.1, 0.6]).unwrap();
2199
2200        assert_eq!(sparse1, sparse2);
2201        assert!(sparse1 < sparse3);
2202        assert!(sparse1 < sparse4);
2203    }
2204
2205    #[test]
2206    fn test_sparse_vector_proto_conversion() {
2207        let tokens = vec![
2208            "token1".to_string(),
2209            "token2".to_string(),
2210            "token3".to_string(),
2211        ];
2212        let sparse =
2213            SparseVector::new_with_tokens(vec![1, 10, 100], vec![0.2, 0.4, 0.6], tokens.clone())
2214                .unwrap();
2215        let proto: chroma_proto::SparseVector = sparse.clone().into();
2216        assert_eq!(proto.indices, vec![1, 10, 100]);
2217        assert_eq!(proto.values, vec![0.2, 0.4, 0.6]);
2218        assert_eq!(proto.tokens, tokens.clone());
2219
2220        let converted: SparseVector = proto.try_into().unwrap();
2221        assert_eq!(converted, sparse);
2222        assert_eq!(converted.tokens, Some(tokens));
2223    }
2224
2225    #[test]
2226    fn test_sparse_vector_proto_conversion_empty_tokens() {
2227        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2228        let proto: chroma_proto::SparseVector = sparse.clone().into();
2229        assert_eq!(proto.indices, vec![0, 5, 10]);
2230        assert_eq!(proto.values, vec![0.1, 0.5, 0.9]);
2231        assert_eq!(proto.tokens, Vec::<String>::new());
2232
2233        let converted: SparseVector = proto.try_into().unwrap();
2234        assert_eq!(converted, sparse);
2235        assert_eq!(converted.tokens, None);
2236    }
2237
2238    #[test]
2239    fn test_sparse_vector_logical_size() {
2240        let metadata = Metadata::from([(
2241            "sparse".to_string(),
2242            MetadataValue::SparseVector(
2243                SparseVector::new(vec![0, 1, 2, 3, 4], vec![0.1, 0.2, 0.3, 0.4, 0.5]).unwrap(),
2244            ),
2245        )]);
2246
2247        let size = logical_size_of_metadata(&metadata);
2248        // Size should include the key string length and the sparse vector data
2249        // "sparse" = 6 bytes + 5 * 4 bytes (u32 indices) + 5 * 4 bytes (f32 values) = 46 bytes
2250        assert_eq!(size, 46);
2251    }
2252
2253    #[test]
2254    fn test_sparse_vector_validation() {
2255        // Valid sparse vector
2256        let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap();
2257        assert!(sparse.validate().is_ok());
2258
2259        // Length mismatch
2260        let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]);
2261        assert!(sparse.is_err());
2262        let result = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3])
2263            .unwrap()
2264            .validate();
2265        assert!(result.is_ok());
2266
2267        // Tokens length mismatch with indices/values
2268        let sparse = SparseVector::new_with_tokens(
2269            vec![1, 2, 3],
2270            vec![0.1, 0.2, 0.3],
2271            vec!["a".to_string(), "b".to_string()],
2272        );
2273        assert!(sparse.is_err());
2274
2275        // Unsorted indices (descending order)
2276        let sparse = SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap();
2277        let result = sparse.validate();
2278        assert!(result.is_err());
2279        assert!(matches!(
2280            result.unwrap_err(),
2281            MetadataValueConversionError::SparseVectorIndicesNotSorted
2282        ));
2283
2284        // Duplicate indices (not strictly ascending)
2285        let sparse = SparseVector::new(vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]).unwrap();
2286        let result = sparse.validate();
2287        assert!(result.is_err());
2288        assert!(matches!(
2289            result.unwrap_err(),
2290            MetadataValueConversionError::SparseVectorIndicesNotSorted
2291        ));
2292
2293        // Descending at one point
2294        let sparse = SparseVector::new(vec![1, 3, 2], vec![0.1, 0.3, 0.2]).unwrap();
2295        let result = sparse.validate();
2296        assert!(result.is_err());
2297        assert!(matches!(
2298            result.unwrap_err(),
2299            MetadataValueConversionError::SparseVectorIndicesNotSorted
2300        ));
2301    }
2302
2303    #[test]
2304    fn test_sparse_vector_deserialize_old_format() {
2305        // Old format without #type field (backward compatibility)
2306        let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
2307        let sv: SparseVector = serde_json::from_str(json).unwrap();
2308        assert_eq!(sv.indices, vec![0, 1, 2]);
2309        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2310    }
2311
2312    #[test]
2313    fn test_sparse_vector_deserialize_new_format() {
2314        // New format with #type field
2315        let json =
2316            "{\"#type\": \"sparse_vector\", \"indices\": [0, 1, 2], \"values\": [1.0, 2.0, 3.0]}";
2317        let sv: SparseVector = serde_json::from_str(json).unwrap();
2318        assert_eq!(sv.indices, vec![0, 1, 2]);
2319        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2320    }
2321
2322    #[test]
2323    fn test_sparse_vector_deserialize_new_format_field_order() {
2324        // New format with different field order (should still work)
2325        let json = "{\"indices\": [5, 10], \"#type\": \"sparse_vector\", \"values\": [0.5, 1.0]}";
2326        let sv: SparseVector = serde_json::from_str(json).unwrap();
2327        assert_eq!(sv.indices, vec![5, 10]);
2328        assert_eq!(sv.values, vec![0.5, 1.0]);
2329    }
2330
2331    #[test]
2332    fn test_sparse_vector_deserialize_wrong_type_tag() {
2333        // Wrong #type field value should fail
2334        let json = "{\"#type\": \"dense_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}";
2335        let result: Result<SparseVector, _> = serde_json::from_str(json);
2336        assert!(result.is_err());
2337        let err_msg = result.unwrap_err().to_string();
2338        assert!(err_msg.contains("sparse_vector"));
2339    }
2340
2341    #[test]
2342    fn test_sparse_vector_serialize_always_has_type() {
2343        // Serialization should always include #type field
2344        let sv = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]).unwrap();
2345        let json = serde_json::to_value(&sv).unwrap();
2346
2347        assert_eq!(json["#type"], "sparse_vector");
2348        assert_eq!(json["indices"], serde_json::json!([0, 1, 2]));
2349        assert_eq!(json["values"], serde_json::json!([1.0, 2.0, 3.0]));
2350    }
2351
2352    #[test]
2353    fn test_sparse_vector_roundtrip_with_type() {
2354        // Test that serialize -> deserialize preserves the data
2355        let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]).unwrap();
2356        let json = serde_json::to_string(&original).unwrap();
2357
2358        // Verify the serialized JSON contains #type
2359        assert!(json.contains("\"#type\":\"sparse_vector\""));
2360
2361        let deserialized: SparseVector = serde_json::from_str(&json).unwrap();
2362        assert_eq!(original, deserialized);
2363    }
2364
2365    #[test]
2366    fn test_sparse_vector_in_metadata_old_format() {
2367        // Test that old format works when sparse vector is in metadata
2368        let json = r#"{"key": "value", "sparse": {"indices": [0, 1], "values": [1.0, 2.0]}}"#;
2369        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();
2370
2371        let sparse_value = &map["sparse"];
2372        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
2373        assert_eq!(sv.indices, vec![0, 1]);
2374        assert_eq!(sv.values, vec![1.0, 2.0]);
2375    }
2376
2377    #[test]
2378    fn test_sparse_vector_in_metadata_new_format() {
2379        // Test that new format works when sparse vector is in metadata
2380        let json = "{\"key\": \"value\", \"sparse\": {\"#type\": \"sparse_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}}";
2381        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();
2382
2383        let sparse_value = &map["sparse"];
2384        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
2385        assert_eq!(sv.indices, vec![0, 1]);
2386        assert_eq!(sv.values, vec![1.0, 2.0]);
2387    }
2388
#[test]
fn test_sparse_vector_tokens_roundtrip_old_to_new() {
    // Legacy payload: neither `#type` nor `tokens` is present.
    let legacy = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
    let parsed = serde_json::from_str::<SparseVector>(legacy).unwrap();

    assert_eq!(parsed.indices, vec![0, 1, 2]);
    assert_eq!(parsed.values, vec![1.0, 2.0, 3.0]);
    assert_eq!(parsed.tokens, None);

    // Writing it back out must add the `#type` tag and keep the payload intact.
    let reserialized = serde_json::to_value(&parsed).unwrap();
    let expected_indices = serde_json::json!([0, 1, 2]);
    let expected_values = serde_json::json!([1.0, 2.0, 3.0]);

    assert_eq!(reserialized["#type"], "sparse_vector");
    assert_eq!(reserialized["indices"], expected_indices);
    assert_eq!(reserialized["values"], expected_values);
    // Indexing a `Value` with an absent key also yields `Null`, so this holds
    // whether `tokens` is left out or written as an explicit null.
    assert_eq!(reserialized["tokens"], serde_json::Value::Null);
}
2405
#[test]
fn test_sparse_vector_tokens_roundtrip_new_to_new() {
    // Build a vector that carries one token per index.
    let expected_tokens = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
    let original = SparseVector::new_with_tokens(
        vec![0, 1, 2],
        vec![1.0, 2.0, 3.0],
        expected_tokens.clone(),
    )
    .unwrap();

    // The JSON form must carry both the type tag and the tokens key.
    let json = serde_json::to_string(&original).unwrap();
    assert!(json.contains("\"#type\":\"sparse_vector\""));
    assert!(json.contains("\"tokens\""));

    // Reading it back must reproduce every field, tokens included.
    let restored = serde_json::from_str::<SparseVector>(&json).unwrap();
    assert_eq!(restored.indices, vec![0, 1, 2]);
    assert_eq!(restored.values, vec![1.0, 2.0, 3.0]);
    assert_eq!(restored.tokens, Some(expected_tokens));
}
2434
#[test]
fn test_sparse_vector_tokens_deserialize_with_tokens_field() {
    // Tagged payload that spells out the optional `tokens` array.
    let payload = r##"{"#type": "sparse_vector", "indices": [5, 10], "values": [0.5, 1.0], "tokens": ["token1", "token2"]}"##;
    let parsed = serde_json::from_str::<SparseVector>(payload).unwrap();

    let expected_tokens = vec!["token1".to_string(), "token2".to_string()];
    assert_eq!(parsed.indices, vec![5, 10]);
    assert_eq!(parsed.values, vec![0.5, 1.0]);
    assert_eq!(parsed.tokens, Some(expected_tokens));
}
2447
#[test]
fn test_sparse_vector_tokens_backward_compatibility() {
    // Legacy shape: no `#type`, no `tokens`.
    let legacy_json = r#"{"indices": [1, 2], "values": [0.1, 0.2]}"#;
    // Current shape: tagged, with a token per index.
    let tagged_json = r##"{"#type": "sparse_vector", "indices": [1, 2], "values": [0.1, 0.2], "tokens": ["a", "b"]}"##;

    let legacy = serde_json::from_str::<SparseVector>(legacy_json).unwrap();
    let tagged = serde_json::from_str::<SparseVector>(tagged_json).unwrap();

    // The numeric payload is the same regardless of which shape was parsed.
    assert_eq!(legacy.indices, tagged.indices);
    assert_eq!(legacy.values, tagged.values);

    // Only the tagged payload supplies tokens.
    assert_eq!(legacy.tokens, None);
    let expected_tokens = vec!["a".to_string(), "b".to_string()];
    assert_eq!(tagged.tokens, Some(expected_tokens));
}
2466
#[test]
fn test_sparse_vector_from_triples_preserves_tokens() {
    // Each triple is (token, index, value).
    let triples = vec![
        ("apple".to_string(), 10, 0.5),
        ("banana".to_string(), 20, 0.7),
        ("cherry".to_string(), 30, 0.9),
    ];
    // The triples are not inspected again, so they are moved into the
    // constructor rather than deep-cloned.
    let sv = SparseVector::from_triples(triples);

    assert_eq!(sv.indices, vec![10, 20, 30]);
    assert_eq!(sv.values, vec![0.5, 0.7, 0.9]);
    assert_eq!(
        sv.tokens,
        Some(vec![
            "apple".to_string(),
            "banana".to_string(),
            "cherry".to_string()
        ])
    );

    // Roundtrip through JSON: every field, tokens included, must survive.
    let serialized = serde_json::to_string(&sv).unwrap();
    let deserialized: SparseVector = serde_json::from_str(&serialized).unwrap();

    assert_eq!(deserialized.indices, sv.indices);
    assert_eq!(deserialized.values, sv.values);
    assert_eq!(deserialized.tokens, sv.tokens);
}
2495
#[cfg(feature = "pyo3")]
#[test]
fn test_sparse_vector_pyo3_roundtrip_with_tokens() {
    ensure_python_interpreter();

    pyo3::Python::with_gil(|py| {
        use pyo3::types::PyDict;
        use pyo3::IntoPyObject;

        let expected_tokens = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];

        // Python -> Rust: a dict with indices, values and tokens extracts
        // into a `SparseVector` carrying all three.
        let dict_in = PyDict::new(py);
        dict_in.set_item("indices", vec![0u32, 1, 2]).unwrap();
        dict_in
            .set_item("values", vec![0.1f32, 0.2f32, 0.3f32])
            .unwrap();
        dict_in
            .set_item("tokens", vec!["foo", "bar", "baz"])
            .unwrap();

        let sparse: SparseVector = dict_in.clone().into_any().extract().unwrap();
        assert_eq!(sparse.indices, vec![0, 1, 2]);
        assert_eq!(sparse.values, vec![0.1, 0.2, 0.3]);
        assert_eq!(sparse.tokens, Some(expected_tokens.clone()));

        // Rust -> Python: the resulting dict must expose the same tokens.
        // `sparse` is not used again, so it is converted by value (no clone).
        let py_obj = sparse.into_pyobject(py).unwrap();
        let dict_out = py_obj.downcast::<PyDict>().unwrap();
        let tokens_obj = dict_out.get_item("tokens").unwrap();
        let tokens: Vec<String> = tokens_obj
            .expect("expected tokens key in Python dict")
            .extract()
            .unwrap();
        assert_eq!(tokens, expected_tokens);
    });
}
2539
#[cfg(feature = "pyo3")]
#[test]
fn test_sparse_vector_pyo3_roundtrip_without_tokens() {
    ensure_python_interpreter();

    pyo3::Python::with_gil(|py| {
        use pyo3::types::PyDict;
        use pyo3::IntoPyObject;

        // Python -> Rust: a dict without a "tokens" entry extracts with
        // `tokens == None`.
        let source = PyDict::new(py);
        source.set_item("indices", vec![5u32]).unwrap();
        source.set_item("values", vec![1.5f32]).unwrap();

        let extracted: SparseVector = source.clone().into_any().extract().unwrap();
        assert_eq!(extracted.indices, vec![5]);
        assert_eq!(extracted.values, vec![1.5]);
        assert!(extracted.tokens.is_none());

        // Rust -> Python: the "tokens" key must exist and hold Python `None`.
        let converted = extracted.into_pyobject(py).unwrap();
        let result = converted.downcast::<PyDict>().unwrap();
        let tokens_entry = result
            .get_item("tokens")
            .unwrap()
            .expect("expected tokens key in Python dict");
        assert!(
            tokens_entry.is_none(),
            "expected tokens value in Python dict to be None"
        );
    });
}
2568
#[test]
fn test_simplifies_identities() {
    let predicate = Key::field("foo").eq("bar");

    // A constant-true filter combined with itself stays constant-true.
    let always: Where = true.into();
    assert_eq!(always.clone() & always.clone(), always);
    assert_eq!(always.clone() | always.clone(), always);

    // AND-ing with constant-true leaves the other operand untouched,
    // whichever side it is on.
    assert_eq!(predicate.clone() & always.clone(), predicate);
    assert_eq!(always.clone() & predicate.clone(), predicate);

    // OR-ing with constant-false leaves the other operand untouched,
    // whichever side it is on.
    let never: Where = false.into();
    assert_eq!(predicate.clone() | never.clone(), predicate);
    assert_eq!(never | predicate.clone(), predicate);
}
2583
#[test]
fn test_flattens() {
    let lhs = Key::field("foo").eq("bar");
    let rhs = Key::field("baz").eq("quux");

    // Expected shape for both operators: a single composite whose children
    // appear in source order, with no nested composite.
    let flat = |operator: BooleanOperator| {
        Where::Composite(CompositeExpression {
            operator,
            children: vec![lhs.clone(), rhs.clone(), lhs.clone()],
        })
    };

    let nested_and = lhs.clone() & (rhs.clone() & lhs.clone());
    assert_eq!(nested_and, flat(BooleanOperator::And));

    let nested_or = lhs.clone() | (rhs.clone() | lhs.clone());
    assert_eq!(nested_or, flat(BooleanOperator::Or));
}
2607}