1use chroma_error::{ChromaError, ErrorCodes};
2use itertools::Itertools;
3use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
4use serde_json::{Number, Value};
5use sprs::CsVec;
6use std::{
7 cmp::Ordering,
8 collections::{HashMap, HashSet},
9 mem::size_of_val,
10 ops::{BitAnd, BitOr},
11};
12use thiserror::Error;
13
14use crate::chroma_proto;
15
16#[cfg(feature = "pyo3")]
17use pyo3::types::PyAnyMethods;
18
19#[cfg(feature = "testing")]
20use proptest::prelude::*;
21
/// Wire representation of a [`SparseVector`] used by its serde impls.
///
/// The optional `#type` field lets a payload tag itself; when the tag is present,
/// `SparseVector`'s `Deserialize` impl requires it to equal `"sparse_vector"`.
#[derive(Serialize, Deserialize)]
struct SparseVectorSerdeHelper {
    // Serialized under the key `#type`.
    #[serde(rename = "#type")]
    type_tag: Option<String>,
    // Dimension indices, paired positionally with `values`.
    indices: Vec<u32>,
    // Entry values, paired positionally with `indices`.
    values: Vec<f32>,
}
29
/// A sparse vector stored as parallel `indices` / `values` arrays.
///
/// Constructors do not check invariants; [`SparseVector::validate`] verifies that both
/// arrays have the same length and that `indices` is strictly ascending.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SparseVector {
    /// Dimension indices of the stored entries.
    pub indices: Vec<u32>,
    /// Values of the stored entries, positionally paired with `indices`.
    pub values: Vec<f32>,
}
44
45impl<'de> Deserialize<'de> for SparseVector {
47 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
48 where
49 D: Deserializer<'de>,
50 {
51 let helper = SparseVectorSerdeHelper::deserialize(deserializer)?;
52
53 if let Some(type_tag) = &helper.type_tag {
55 if type_tag != "sparse_vector" {
56 return Err(serde::de::Error::custom(format!(
57 "Expected #type='sparse_vector', got '{}'",
58 type_tag
59 )));
60 }
61 }
62
63 Ok(SparseVector {
64 indices: helper.indices,
65 values: helper.values,
66 })
67 }
68}
69
70impl Serialize for SparseVector {
72 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73 where
74 S: Serializer,
75 {
76 let helper = SparseVectorSerdeHelper {
77 type_tag: Some("sparse_vector".to_string()),
78 indices: self.indices.clone(),
79 values: self.values.clone(),
80 };
81 helper.serialize(serializer)
82 }
83}
84
85impl SparseVector {
86 pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Self {
88 Self { indices, values }
89 }
90
91 pub fn from_pairs(pairs: impl IntoIterator<Item = (u32, f32)>) -> Self {
93 let (indices, values) = pairs.into_iter().unzip();
94 Self { indices, values }
95 }
96
97 pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
99 self.indices
100 .iter()
101 .copied()
102 .zip(self.values.iter().copied())
103 }
104
105 pub fn validate(&self) -> Result<(), MetadataValueConversionError> {
107 if self.indices.len() != self.values.len() {
109 return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
110 }
111
112 for i in 1..self.indices.len() {
114 if self.indices[i] <= self.indices[i - 1] {
115 return Err(MetadataValueConversionError::SparseVectorIndicesNotSorted);
116 }
117 }
118
119 Ok(())
120 }
121}
122
// `Eq` is asserted manually because `f32` is not `Eq`.
// NOTE(review): the derived `PartialEq` compares `values` with `f32 ==`, so a vector holding
// NaN is not equal to itself and reflexivity does not strictly hold, whereas the `Ord` impl
// compares values with `total_cmp`.
impl Eq for SparseVector {}
124
125impl Ord for SparseVector {
126 fn cmp(&self, other: &Self) -> Ordering {
127 self.indices.cmp(&other.indices).then_with(|| {
128 for (a, b) in self.values.iter().zip(other.values.iter()) {
129 match a.total_cmp(b) {
130 Ordering::Equal => continue,
131 other => return other,
132 }
133 }
134 self.values.len().cmp(&other.values.len())
135 })
136 }
137}
138
139impl PartialOrd for SparseVector {
140 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
141 Some(self.cmp(other))
142 }
143}
144
145impl From<chroma_proto::SparseVector> for SparseVector {
146 fn from(proto: chroma_proto::SparseVector) -> Self {
147 SparseVector::new(proto.indices, proto.values)
148 }
149}
150
151impl From<SparseVector> for chroma_proto::SparseVector {
152 fn from(sparse: SparseVector) -> Self {
153 chroma_proto::SparseVector {
154 indices: sparse.indices,
155 values: sparse.values,
156 }
157 }
158}
159
160impl From<&SparseVector> for CsVec<f32> {
162 fn from(sparse: &SparseVector) -> Self {
163 let (indices, values) = sparse
164 .iter()
165 .map(|(index, value)| (index as usize, value))
166 .unzip();
167 CsVec::new(u32::MAX as usize, indices, values)
168 }
169}
170
171impl From<SparseVector> for CsVec<f32> {
172 fn from(sparse: SparseVector) -> Self {
173 (&sparse).into()
174 }
175}
176
/// Converts a `SparseVector` into a Python `dict` with `"indices"` and `"values"` keys.
///
/// Unlike the serde `Serialize` impl, no `#type` tag is written.
#[cfg(feature = "pyo3")]
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        use pyo3::types::PyDict;
        let dict = PyDict::new(py);
        // The vectors are moved into Python lists under fixed keys.
        dict.set_item("indices", self.indices)?;
        dict.set_item("values", self.values)?;
        Ok(dict.into_any())
    }
}
191
/// Extracts a `SparseVector` from a Python `dict` with `"indices"` and `"values"` keys.
///
/// Fails with a Python error if `ob` is not a dict or either entry cannot be extracted as a
/// list of numbers. No length or ordering checks are done here; see
/// [`SparseVector::validate`].
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for SparseVector {
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyDict;

        // Non-dict inputs are rejected by the downcast.
        let dict = ob.downcast::<PyDict>()?;
        // NOTE(review): which `get_item` this resolves to depends on the pyo3 traits in scope;
        // with only `PyAnyMethods` imported it presumably raises on a missing key — confirm.
        let indices_obj = dict.get_item("indices")?;
        let values_obj = dict.get_item("values")?;

        let indices: Vec<u32> = indices_obj.extract()?;
        let values: Vec<f32> = values_obj.extract()?;

        Ok(SparseVector::new(indices, values))
    }
}
207
/// A metadata value as it appears in an update, before conversion to [`MetadataValue`].
///
/// Deserialized with `#[serde(untagged)]`, so serde tries the variants in declaration order
/// and takes the first that matches; the variant order is therefore significant.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[serde(untagged)]
pub enum UpdateMetadataValue {
    Bool(bool),
    Int(i64),
    // Property tests draw floats from an `f32` range of ±1e6, widened to `f64`.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| UpdateMetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    // Never generated by property tests.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
    // Corresponds to an unset protobuf value; it has no `MetadataValue` equivalent.
    None,
}
227
/// Extracts an `UpdateMetadataValue` from an arbitrary Python object.
///
/// Candidates are tried in order and the first successful extraction wins: `None`, `bool`,
/// `i64`, `f64`, `str`, then a sparse-vector dict. Anything else raises `TypeError`.
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for UpdateMetadataValue {
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        // Order matters: `bool` is tried before `i64` because Python `bool` subclasses `int`,
        // so `True`/`False` would otherwise presumably come back as `Int(1)`/`Int(0)`.
        // Likewise `i64` before `f64` keeps Python ints as `Int`.
        if ob.is_none() {
            Ok(UpdateMetadataValue::None)
        } else if let Ok(value) = ob.extract::<bool>() {
            Ok(UpdateMetadataValue::Bool(value))
        } else if let Ok(value) = ob.extract::<i64>() {
            Ok(UpdateMetadataValue::Int(value))
        } else if let Ok(value) = ob.extract::<f64>() {
            Ok(UpdateMetadataValue::Float(value))
        } else if let Ok(value) = ob.extract::<String>() {
            Ok(UpdateMetadataValue::Str(value))
        } else if let Ok(value) = ob.extract::<SparseVector>() {
            Ok(UpdateMetadataValue::SparseVector(value))
        } else {
            Err(pyo3::exceptions::PyTypeError::new_err(
                "Cannot convert Python object to UpdateMetadataValue",
            ))
        }
    }
}
250
251impl From<bool> for UpdateMetadataValue {
252 fn from(b: bool) -> Self {
253 Self::Bool(b)
254 }
255}
256
257impl From<i64> for UpdateMetadataValue {
258 fn from(v: i64) -> Self {
259 Self::Int(v)
260 }
261}
262
263impl From<i32> for UpdateMetadataValue {
264 fn from(v: i32) -> Self {
265 Self::Int(v as i64)
266 }
267}
268
269impl From<f64> for UpdateMetadataValue {
270 fn from(v: f64) -> Self {
271 Self::Float(v)
272 }
273}
274
275impl From<f32> for UpdateMetadataValue {
276 fn from(v: f32) -> Self {
277 Self::Float(v as f64)
278 }
279}
280
281impl From<String> for UpdateMetadataValue {
282 fn from(v: String) -> Self {
283 Self::Str(v)
284 }
285}
286
287impl From<&str> for UpdateMetadataValue {
288 fn from(v: &str) -> Self {
289 Self::Str(v.to_string())
290 }
291}
292
293impl From<SparseVector> for UpdateMetadataValue {
294 fn from(v: SparseVector) -> Self {
295 Self::SparseVector(v)
296 }
297}
298
299#[derive(Error, Debug)]
300pub enum UpdateMetadataValueConversionError {
301 #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")]
302 InvalidValue,
303}
304
305impl ChromaError for UpdateMetadataValueConversionError {
306 fn code(&self) -> ErrorCodes {
307 match self {
308 UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
309 }
310 }
311}
312
313impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue {
314 type Error = UpdateMetadataValueConversionError;
315
316 fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
317 match &value.value {
318 Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
319 Ok(UpdateMetadataValue::Bool(*value))
320 }
321 Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
322 Ok(UpdateMetadataValue::Int(*value))
323 }
324 Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
325 Ok(UpdateMetadataValue::Float(*value))
326 }
327 Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
328 Ok(UpdateMetadataValue::Str(value.clone()))
329 }
330 Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
331 Ok(UpdateMetadataValue::SparseVector(value.clone().into()))
332 }
333 None => Ok(UpdateMetadataValue::None),
335 }
336 }
337}
338
339impl From<UpdateMetadataValue> for chroma_proto::UpdateMetadataValue {
340 fn from(value: UpdateMetadataValue) -> Self {
341 match value {
342 UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
343 value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
344 },
345 UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
346 value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
347 },
348 UpdateMetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
349 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
350 value,
351 )),
352 },
353 UpdateMetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
354 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
355 value,
356 )),
357 },
358 UpdateMetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
359 value: Some(
360 chroma_proto::update_metadata_value::Value::SparseVectorValue(
361 sparse_vec.into(),
362 ),
363 ),
364 },
365 UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None },
366 }
367 }
368}
369
370impl TryFrom<&UpdateMetadataValue> for MetadataValue {
371 type Error = MetadataValueConversionError;
372
373 fn try_from(value: &UpdateMetadataValue) -> Result<Self, Self::Error> {
374 match value {
375 UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)),
376 UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)),
377 UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)),
378 UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())),
379 UpdateMetadataValue::SparseVector(value) => {
380 Ok(MetadataValue::SparseVector(value.clone()))
381 }
382 UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue),
383 }
384 }
385}
386
/// A concrete (non-null) metadata value.
///
/// Deserialized with `#[serde(untagged)]`, so serde tries the variants in declaration order
/// and takes the first that matches. NOTE(review): the pyo3 `FromPyObject` derive presumably
/// also tries variants in order, so `Bool` must stay ahead of `Int`; keep the order stable.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
#[serde(untagged)]
pub enum MetadataValue {
    Bool(bool),
    Int(i64),
    // Property tests draw floats from an `f32` range of ±1e6, widened to `f64`.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| MetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    // Never generated by property tests.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
}
412
413impl std::fmt::Display for MetadataValue {
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415 match self {
416 MetadataValue::Bool(v) => write!(f, "{}", v),
417 MetadataValue::Int(v) => write!(f, "{}", v),
418 MetadataValue::Float(v) => write!(f, "{}", v),
419 MetadataValue::Str(v) => write!(f, "\"{}\"", v),
420 MetadataValue::SparseVector(v) => write!(f, "SparseVector(len={})", v.values.len()),
421 }
422 }
423}
424
// `Eq` is asserted manually because `f64` is not `Eq`.
// NOTE(review): the derived `PartialEq` compares floats with `==`, so `Float(NaN)` is not
// equal to itself and reflexivity does not strictly hold, whereas the `Ord` impl for
// `MetadataValue` compares floats with `total_cmp`.
impl Eq for MetadataValue {}
426
/// The variant of a [`MetadataValue`] without its payload, as returned by
/// [`MetadataValue::value_type`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MetadataValueType {
    Bool,
    Int,
    Float,
    Str,
    SparseVector,
}
435
436impl MetadataValue {
437 pub fn value_type(&self) -> MetadataValueType {
438 match self {
439 MetadataValue::Bool(_) => MetadataValueType::Bool,
440 MetadataValue::Int(_) => MetadataValueType::Int,
441 MetadataValue::Float(_) => MetadataValueType::Float,
442 MetadataValue::Str(_) => MetadataValueType::Str,
443 MetadataValue::SparseVector(_) => MetadataValueType::SparseVector,
444 }
445 }
446}
447
448impl From<&MetadataValue> for MetadataValueType {
449 fn from(value: &MetadataValue) -> Self {
450 value.value_type()
451 }
452}
453
454impl From<bool> for MetadataValue {
455 fn from(v: bool) -> Self {
456 MetadataValue::Bool(v)
457 }
458}
459
460impl From<i64> for MetadataValue {
461 fn from(v: i64) -> Self {
462 MetadataValue::Int(v)
463 }
464}
465
466impl From<i32> for MetadataValue {
467 fn from(v: i32) -> Self {
468 MetadataValue::Int(v as i64)
469 }
470}
471
472impl From<f64> for MetadataValue {
473 fn from(v: f64) -> Self {
474 MetadataValue::Float(v)
475 }
476}
477
478impl From<f32> for MetadataValue {
479 fn from(v: f32) -> Self {
480 MetadataValue::Float(v as f64)
481 }
482}
483
484impl From<String> for MetadataValue {
485 fn from(v: String) -> Self {
486 MetadataValue::Str(v)
487 }
488}
489
490impl From<&str> for MetadataValue {
491 fn from(v: &str) -> Self {
492 MetadataValue::Str(v.to_string())
493 }
494}
495
496impl From<SparseVector> for MetadataValue {
497 fn from(v: SparseVector) -> Self {
498 MetadataValue::SparseVector(v)
499 }
500}
501
502#[allow(clippy::derive_ord_xor_partial_ord)]
507impl Ord for MetadataValue {
508 fn cmp(&self, other: &Self) -> Ordering {
509 fn type_order(val: &MetadataValue) -> u8 {
511 match val {
512 MetadataValue::Bool(_) => 0,
513 MetadataValue::Int(_) => 1,
514 MetadataValue::Float(_) => 2,
515 MetadataValue::Str(_) => 3,
516 MetadataValue::SparseVector(_) => 4,
517 }
518 }
519
520 type_order(self).cmp(&type_order(other)).then_with(|| {
522 match (self, other) {
523 (MetadataValue::Bool(left), MetadataValue::Bool(right)) => left.cmp(right),
524 (MetadataValue::Int(left), MetadataValue::Int(right)) => left.cmp(right),
525 (MetadataValue::Float(left), MetadataValue::Float(right)) => left.total_cmp(right),
526 (MetadataValue::Str(left), MetadataValue::Str(right)) => left.cmp(right),
527 (MetadataValue::SparseVector(left), MetadataValue::SparseVector(right)) => {
528 left.cmp(right)
529 }
530 _ => Ordering::Equal, }
532 })
533 }
534}
535
536impl PartialOrd for MetadataValue {
537 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
538 Some(self.cmp(other))
539 }
540}
541
542impl TryFrom<&MetadataValue> for bool {
543 type Error = MetadataValueConversionError;
544
545 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
546 match value {
547 MetadataValue::Bool(value) => Ok(*value),
548 _ => Err(MetadataValueConversionError::InvalidValue),
549 }
550 }
551}
552
553impl TryFrom<&MetadataValue> for i64 {
554 type Error = MetadataValueConversionError;
555
556 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
557 match value {
558 MetadataValue::Int(value) => Ok(*value),
559 _ => Err(MetadataValueConversionError::InvalidValue),
560 }
561 }
562}
563
564impl TryFrom<&MetadataValue> for f64 {
565 type Error = MetadataValueConversionError;
566
567 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
568 match value {
569 MetadataValue::Float(value) => Ok(*value),
570 _ => Err(MetadataValueConversionError::InvalidValue),
571 }
572 }
573}
574
575impl TryFrom<&MetadataValue> for String {
576 type Error = MetadataValueConversionError;
577
578 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
579 match value {
580 MetadataValue::Str(value) => Ok(value.clone()),
581 _ => Err(MetadataValueConversionError::InvalidValue),
582 }
583 }
584}
585
586impl From<MetadataValue> for UpdateMetadataValue {
587 fn from(value: MetadataValue) -> Self {
588 match value {
589 MetadataValue::Bool(v) => UpdateMetadataValue::Bool(v),
590 MetadataValue::Int(v) => UpdateMetadataValue::Int(v),
591 MetadataValue::Float(v) => UpdateMetadataValue::Float(v),
592 MetadataValue::Str(v) => UpdateMetadataValue::Str(v),
593 MetadataValue::SparseVector(v) => UpdateMetadataValue::SparseVector(v),
594 }
595 }
596}
597
598impl From<MetadataValue> for Value {
599 fn from(value: MetadataValue) -> Self {
600 match value {
601 MetadataValue::Bool(val) => Self::Bool(val),
602 MetadataValue::Int(val) => Self::Number(
603 Number::from_i128(val as i128).expect("i64 should be representable in JSON"),
604 ),
605 MetadataValue::Float(val) => Self::Number(
606 Number::from_f64(val).expect("Inf and NaN should not be present in MetadataValue"),
607 ),
608 MetadataValue::Str(val) => Self::String(val),
609 MetadataValue::SparseVector(val) => {
610 let mut map = serde_json::Map::new();
611 map.insert(
612 "indices".to_string(),
613 Value::Array(
614 val.indices
615 .iter()
616 .map(|&i| Value::Number(i.into()))
617 .collect(),
618 ),
619 );
620 map.insert(
621 "values".to_string(),
622 Value::Array(
623 val.values
624 .iter()
625 .map(|&v| {
626 Value::Number(
627 Number::from_f64(v as f64)
628 .expect("Float number should not be NaN or infinite"),
629 )
630 })
631 .collect(),
632 ),
633 );
634 Self::Object(map)
635 }
636 }
637 }
638}
639
640#[derive(Error, Debug)]
641pub enum MetadataValueConversionError {
642 #[error("Invalid metadata value, valid values are: Int, Float, Str")]
643 InvalidValue,
644 #[error("Metadata key cannot start with '#' or '$': {0}")]
645 InvalidKey(String),
646 #[error("Sparse vector indices and values must have the same length")]
647 SparseVectorLengthMismatch,
648 #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")]
649 SparseVectorIndicesNotSorted,
650}
651
652impl ChromaError for MetadataValueConversionError {
653 fn code(&self) -> ErrorCodes {
654 match self {
655 MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
656 MetadataValueConversionError::InvalidKey(_) => ErrorCodes::InvalidArgument,
657 MetadataValueConversionError::SparseVectorLengthMismatch => ErrorCodes::InvalidArgument,
658 MetadataValueConversionError::SparseVectorIndicesNotSorted => {
659 ErrorCodes::InvalidArgument
660 }
661 }
662 }
663}
664
665impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue {
666 type Error = MetadataValueConversionError;
667
668 fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
669 match &value.value {
670 Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
671 Ok(MetadataValue::Bool(*value))
672 }
673 Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
674 Ok(MetadataValue::Int(*value))
675 }
676 Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
677 Ok(MetadataValue::Float(*value))
678 }
679 Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
680 Ok(MetadataValue::Str(value.clone()))
681 }
682 Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
683 Ok(MetadataValue::SparseVector(value.clone().into()))
684 }
685 _ => Err(MetadataValueConversionError::InvalidValue),
686 }
687 }
688}
689
690impl From<MetadataValue> for chroma_proto::UpdateMetadataValue {
691 fn from(value: MetadataValue) -> Self {
692 match value {
693 MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
694 value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
695 },
696 MetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
697 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
698 value,
699 )),
700 },
701 MetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
702 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
703 value,
704 )),
705 },
706 MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
707 value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
708 },
709 MetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
710 value: Some(
711 chroma_proto::update_metadata_value::Value::SparseVectorValue(
712 sparse_vec.into(),
713 ),
714 ),
715 },
716 }
717 }
718}
719
/// Metadata carried by an update. Values may be [`UpdateMetadataValue::None`], which has no
/// [`MetadataValue`] equivalent.
pub type UpdateMetadata = HashMap<String, UpdateMetadataValue>;
726
727pub fn are_update_metadatas_close_to_equal(
731 metadata1: &UpdateMetadata,
732 metadata2: &UpdateMetadata,
733) -> bool {
734 assert_eq!(metadata1.len(), metadata2.len());
735
736 for (key, value) in metadata1.iter() {
737 if !metadata2.contains_key(key) {
738 return false;
739 }
740 let other_value = metadata2.get(key).unwrap();
741
742 if let (UpdateMetadataValue::Float(value), UpdateMetadataValue::Float(other_value)) =
743 (value, other_value)
744 {
745 if (value - other_value).abs() > 1e-6 {
746 return false;
747 }
748 } else if value != other_value {
749 return false;
750 }
751 }
752
753 true
754}
755
756pub fn are_metadatas_close_to_equal(metadata1: &Metadata, metadata2: &Metadata) -> bool {
757 assert_eq!(metadata1.len(), metadata2.len());
758
759 for (key, value) in metadata1.iter() {
760 if !metadata2.contains_key(key) {
761 return false;
762 }
763 let other_value = metadata2.get(key).unwrap();
764
765 if let (MetadataValue::Float(value), MetadataValue::Float(other_value)) =
766 (value, other_value)
767 {
768 if (value - other_value).abs() > 1e-6 {
769 return false;
770 }
771 } else if value != other_value {
772 return false;
773 }
774 }
775
776 true
777}
778
779impl TryFrom<chroma_proto::UpdateMetadata> for UpdateMetadata {
780 type Error = UpdateMetadataValueConversionError;
781
782 fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
783 let mut metadata = UpdateMetadata::new();
784 for (key, value) in proto_metadata.metadata.iter() {
785 let value = match value.try_into() {
786 Ok(value) => value,
787 Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue),
788 };
789 metadata.insert(key.clone(), value);
790 }
791 Ok(metadata)
792 }
793}
794
795impl From<UpdateMetadata> for chroma_proto::UpdateMetadata {
796 fn from(metadata: UpdateMetadata) -> Self {
797 let mut metadata = metadata;
798 let mut proto_metadata = chroma_proto::UpdateMetadata {
799 metadata: HashMap::new(),
800 };
801 for (key, value) in metadata.drain() {
802 let proto_value = value.into();
803 proto_metadata.metadata.insert(key.clone(), proto_value);
804 }
805 proto_metadata
806 }
807}
808
/// Fully resolved metadata: every value is a concrete [`MetadataValue`].
pub type Metadata = HashMap<String, MetadataValue>;
/// A set of metadata keys. NOTE(review): presumably the keys removed by an update — confirm
/// against callers.
pub type DeletedMetadata = HashSet<String>;
817
818pub fn logical_size_of_metadata(metadata: &Metadata) -> usize {
819 metadata
820 .iter()
821 .map(|(k, v)| {
822 k.len()
823 + match v {
824 MetadataValue::Bool(b) => size_of_val(b),
825 MetadataValue::Int(i) => size_of_val(i),
826 MetadataValue::Float(f) => size_of_val(f),
827 MetadataValue::Str(s) => s.len(),
828 MetadataValue::SparseVector(v) => {
829 size_of_val(&v.indices[..]) + size_of_val(&v.values[..])
830 }
831 }
832 })
833 .sum()
834}
835
836pub fn get_metadata_value_as<'a, T>(
837 metadata: &'a Metadata,
838 key: &str,
839) -> Result<T, Box<MetadataValueConversionError>>
840where
841 T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
842{
843 let res = match metadata.get(key) {
844 Some(value) => T::try_from(value),
845 None => return Err(Box::new(MetadataValueConversionError::InvalidValue)),
846 };
847 match res {
848 Ok(value) => Ok(value),
849 Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)),
850 }
851}
852
853impl TryFrom<chroma_proto::UpdateMetadata> for Metadata {
854 type Error = MetadataValueConversionError;
855
856 fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
857 let mut metadata = Metadata::new();
858 for (key, value) in proto_metadata.metadata.iter() {
859 let maybe_value: Result<MetadataValue, Self::Error> = value.try_into();
860 if maybe_value.is_err() {
861 return Err(MetadataValueConversionError::InvalidValue);
862 }
863 let value = maybe_value.unwrap();
864 metadata.insert(key.clone(), value);
865 }
866 Ok(metadata)
867 }
868}
869
870impl From<Metadata> for chroma_proto::UpdateMetadata {
871 fn from(metadata: Metadata) -> Self {
872 let mut metadata = metadata;
873 let mut proto_metadata = chroma_proto::UpdateMetadata {
874 metadata: HashMap::new(),
875 };
876 for (key, value) in metadata.drain() {
877 let proto_value = value.into();
878 proto_metadata.metadata.insert(key.clone(), proto_value);
879 }
880 proto_metadata
881 }
882}
883
/// Borrowed description of metadata changes, keyed by metadata key.
///
/// NOTE(review): from the field shapes, `metadata_to_update` presumably maps a key to its
/// (old, new) values, while deletions and insertions carry a single value — confirm against
/// the code that populates this struct.
#[derive(Debug, Default)]
pub struct MetadataDelta<'referred_data> {
    pub metadata_to_update: HashMap<
        &'referred_data str,
        (&'referred_data MetadataValue, &'referred_data MetadataValue),
    >,
    pub metadata_to_delete: HashMap<&'referred_data str, &'referred_data MetadataValue>,
    pub metadata_to_insert: HashMap<&'referred_data str, &'referred_data MetadataValue>,
}
893
894impl MetadataDelta<'_> {
895 pub fn new() -> Self {
896 Self::default()
897 }
898}
899
/// Error produced while converting a [`Where`] to or from its protobuf form.
///
/// `Cause` holds the root message. `Trace` wraps an inner error with a context label and
/// displays as `"{context} -> {inner}"`.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum WhereConversionError {
    #[error("Error: {0}")]
    Cause(String),
    #[error("{0} -> {1}")]
    Trace(String, Box<Self>),
}
913
914impl WhereConversionError {
915 pub fn cause(msg: impl ToString) -> Self {
916 Self::Cause(msg.to_string())
917 }
918
919 pub fn trace(self, context: impl ToString) -> Self {
920 Self::Trace(context.to_string(), Box::new(self))
921 }
922}
923
/// A filter expression.
///
/// `Composite` combines child expressions with a [`BooleanOperator`]; `Document` and
/// `Metadata` are leaf predicates.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Where {
    Composite(CompositeExpression),
    Document(DocumentExpression),
    Metadata(MetadataExpression),
}
938
939impl std::fmt::Display for Where {
940 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
941 match self {
942 Where::Composite(composite) => {
943 let fragment = composite
944 .children
945 .iter()
946 .map(|child| format!("{}", child))
947 .collect::<Vec<_>>()
948 .join(match composite.operator {
949 BooleanOperator::And => " & ",
950 BooleanOperator::Or => " | ",
951 });
952 write!(f, "({})", fragment)
953 }
954 Where::Metadata(expr) => write!(f, "{}", expr),
955 Where::Document(expr) => write!(f, "{}", expr),
956 }
957 }
958}
959
960impl serde::Serialize for Where {
961 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
962 where
963 S: Serializer,
964 {
965 match self {
966 Where::Composite(composite) => {
967 let mut map = serializer.serialize_map(Some(1))?;
968 let op_key = match composite.operator {
969 BooleanOperator::And => "$and",
970 BooleanOperator::Or => "$or",
971 };
972 map.serialize_entry(op_key, &composite.children)?;
973 map.end()
974 }
975 Where::Document(doc) => {
976 let mut outer_map = serializer.serialize_map(Some(1))?;
977 let mut inner_map = serde_json::Map::new();
978 let op_key = match doc.operator {
979 DocumentOperator::Contains => "$contains",
980 DocumentOperator::NotContains => "$not_contains",
981 DocumentOperator::Regex => "$regex",
982 DocumentOperator::NotRegex => "$not_regex",
983 };
984 inner_map.insert(
985 op_key.to_string(),
986 serde_json::Value::String(doc.pattern.clone()),
987 );
988 outer_map.serialize_entry("#document", &inner_map)?;
989 outer_map.end()
990 }
991 Where::Metadata(meta) => {
992 let mut outer_map = serializer.serialize_map(Some(1))?;
993 let mut inner_map = serde_json::Map::new();
994
995 match &meta.comparison {
996 MetadataComparison::Primitive(op, value) => {
997 let op_key = match op {
998 PrimitiveOperator::Equal => "$eq",
999 PrimitiveOperator::NotEqual => "$ne",
1000 PrimitiveOperator::GreaterThan => "$gt",
1001 PrimitiveOperator::GreaterThanOrEqual => "$gte",
1002 PrimitiveOperator::LessThan => "$lt",
1003 PrimitiveOperator::LessThanOrEqual => "$lte",
1004 };
1005 let value_json =
1006 serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
1007 inner_map.insert(op_key.to_string(), value_json);
1008 }
1009 MetadataComparison::Set(op, set_value) => {
1010 let op_key = match op {
1011 SetOperator::In => "$in",
1012 SetOperator::NotIn => "$nin",
1013 };
1014 let values_json = match set_value {
1015 MetadataSetValue::Bool(v) => serde_json::to_value(v),
1016 MetadataSetValue::Int(v) => serde_json::to_value(v),
1017 MetadataSetValue::Float(v) => serde_json::to_value(v),
1018 MetadataSetValue::Str(v) => serde_json::to_value(v),
1019 }
1020 .map_err(serde::ser::Error::custom)?;
1021 inner_map.insert(op_key.to_string(), values_json);
1022 }
1023 }
1024
1025 outer_map.serialize_entry(&meta.key, &inner_map)?;
1026 outer_map.end()
1027 }
1028 }
1029 }
1030}
1031
1032impl From<bool> for Where {
1033 fn from(value: bool) -> Self {
1034 if value {
1035 Where::conjunction(vec![])
1036 } else {
1037 Where::disjunction(vec![])
1038 }
1039 }
1040}
1041
1042impl Where {
1043 pub fn conjunction(children: impl IntoIterator<Item = Where>) -> Self {
1044 let mut children: Vec<_> = children
1049 .into_iter()
1050 .flat_map(|expr| {
1051 if let Where::Composite(CompositeExpression {
1052 operator: BooleanOperator::And,
1053 children,
1054 }) = expr
1055 {
1056 return children;
1057 }
1058 vec![expr]
1059 })
1060 .dedup()
1061 .collect();
1062
1063 if children.len() == 1 {
1064 return children.pop().expect("just checked len is 1");
1065 }
1066
1067 Self::Composite(CompositeExpression {
1068 operator: BooleanOperator::And,
1069 children,
1070 })
1071 }
1072 pub fn disjunction(children: impl IntoIterator<Item = Where>) -> Self {
1073 let mut children: Vec<_> = children
1078 .into_iter()
1079 .flat_map(|expr| {
1080 if let Where::Composite(CompositeExpression {
1081 operator: BooleanOperator::Or,
1082 children,
1083 }) = expr
1084 {
1085 return children;
1086 }
1087 vec![expr]
1088 })
1089 .dedup()
1090 .collect();
1091
1092 if children.len() == 1 {
1093 return children.pop().expect("just checked len is 1");
1094 }
1095
1096 Self::Composite(CompositeExpression {
1097 operator: BooleanOperator::Or,
1098 children,
1099 })
1100 }
1101
1102 pub fn fts_query_length(&self) -> u64 {
1103 match self {
1104 Where::Composite(composite_expression) => composite_expression
1105 .children
1106 .iter()
1107 .map(Where::fts_query_length)
1108 .sum(),
1109 Where::Document(document_expression) => {
1111 document_expression.pattern.len().max(3) as u64 - 2
1112 }
1113 Where::Metadata(_) => 0,
1114 }
1115 }
1116
1117 pub fn metadata_predicate_count(&self) -> u64 {
1118 match self {
1119 Where::Composite(composite_expression) => composite_expression
1120 .children
1121 .iter()
1122 .map(Where::metadata_predicate_count)
1123 .sum(),
1124 Where::Document(_) => 0,
1125 Where::Metadata(metadata_expression) => match &metadata_expression.comparison {
1126 MetadataComparison::Primitive(_, _) => 1,
1127 MetadataComparison::Set(_, metadata_set_value) => match metadata_set_value {
1128 MetadataSetValue::Bool(items) => items.len() as u64,
1129 MetadataSetValue::Int(items) => items.len() as u64,
1130 MetadataSetValue::Float(items) => items.len() as u64,
1131 MetadataSetValue::Str(items) => items.len() as u64,
1132 },
1133 },
1134 }
1135 }
1136}
1137
1138impl BitAnd for Where {
1139 type Output = Where;
1140
1141 fn bitand(self, rhs: Self) -> Self::Output {
1142 Self::conjunction([self, rhs])
1143 }
1144}
1145
1146impl BitOr for Where {
1147 type Output = Where;
1148
1149 fn bitor(self, rhs: Self) -> Self::Output {
1150 Self::disjunction([self, rhs])
1151 }
1152}
1153
1154impl TryFrom<chroma_proto::Where> for Where {
1155 type Error = WhereConversionError;
1156
1157 fn try_from(proto_where: chroma_proto::Where) -> Result<Self, Self::Error> {
1158 let where_inner = proto_where
1159 .r#where
1160 .ok_or(WhereConversionError::cause("Invalid Where"))?;
1161 Ok(match where_inner {
1162 chroma_proto::r#where::Where::DirectComparison(direct_comparison) => {
1163 Self::Metadata(direct_comparison.try_into()?)
1164 }
1165 chroma_proto::r#where::Where::Children(where_children) => {
1166 Self::Composite(where_children.try_into()?)
1167 }
1168 chroma_proto::r#where::Where::DirectDocumentComparison(direct_where_document) => {
1169 Self::Document(direct_where_document.into())
1170 }
1171 })
1172 }
1173}
1174
1175impl TryFrom<Where> for chroma_proto::Where {
1176 type Error = WhereConversionError;
1177
1178 fn try_from(value: Where) -> Result<Self, Self::Error> {
1179 let proto_where = match value {
1180 Where::Composite(composite_expression) => {
1181 chroma_proto::r#where::Where::Children(composite_expression.try_into()?)
1182 }
1183 Where::Document(document_expression) => {
1184 chroma_proto::r#where::Where::DirectDocumentComparison(document_expression.into())
1185 }
1186 Where::Metadata(metadata_expression) => chroma_proto::r#where::Where::DirectComparison(
1187 chroma_proto::DirectComparison::try_from(metadata_expression)
1188 .map_err(|err| err.trace("MetadataExpression"))?,
1189 ),
1190 };
1191 Ok(Self {
1192 r#where: Some(proto_where),
1193 })
1194 }
1195}
1196
1197#[derive(Clone, Debug, PartialEq)]
1198#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1199pub struct CompositeExpression {
1200 pub operator: BooleanOperator,
1201 pub children: Vec<Where>,
1202}
1203
1204impl TryFrom<chroma_proto::WhereChildren> for CompositeExpression {
1205 type Error = WhereConversionError;
1206
1207 fn try_from(proto_children: chroma_proto::WhereChildren) -> Result<Self, Self::Error> {
1208 let operator = proto_children.operator().into();
1209 let children = proto_children
1210 .children
1211 .into_iter()
1212 .map(Where::try_from)
1213 .collect::<Result<Vec<_>, _>>()
1214 .map_err(|err| err.trace("Child Where of CompositeExpression"))?;
1215 Ok(Self { operator, children })
1216 }
1217}
1218
1219impl TryFrom<CompositeExpression> for chroma_proto::WhereChildren {
1220 type Error = WhereConversionError;
1221
1222 fn try_from(value: CompositeExpression) -> Result<Self, Self::Error> {
1223 Ok(Self {
1224 operator: chroma_proto::BooleanOperator::from(value.operator) as i32,
1225 children: value
1226 .children
1227 .into_iter()
1228 .map(chroma_proto::Where::try_from)
1229 .collect::<Result<_, _>>()?,
1230 })
1231 }
1232}
1233
1234#[derive(Clone, Debug, PartialEq)]
1235#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1236pub enum BooleanOperator {
1237 And,
1238 Or,
1239}
1240
1241impl From<chroma_proto::BooleanOperator> for BooleanOperator {
1242 fn from(value: chroma_proto::BooleanOperator) -> Self {
1243 match value {
1244 chroma_proto::BooleanOperator::And => Self::And,
1245 chroma_proto::BooleanOperator::Or => Self::Or,
1246 }
1247 }
1248}
1249
1250impl From<BooleanOperator> for chroma_proto::BooleanOperator {
1251 fn from(value: BooleanOperator) -> Self {
1252 match value {
1253 BooleanOperator::And => Self::And,
1254 BooleanOperator::Or => Self::Or,
1255 }
1256 }
1257}
1258
1259#[derive(Clone, Debug, PartialEq)]
1260#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1261pub struct DocumentExpression {
1262 pub operator: DocumentOperator,
1263 pub pattern: String,
1264}
1265
1266impl std::fmt::Display for DocumentExpression {
1267 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1268 let op_str = match self.operator {
1269 DocumentOperator::Contains => "CONTAINS",
1270 DocumentOperator::NotContains => "NOT CONTAINS",
1271 DocumentOperator::Regex => "REGEX",
1272 DocumentOperator::NotRegex => "NOT REGEX",
1273 };
1274 write!(f, "#document {} \"{}\"", op_str, self.pattern)
1275 }
1276}
1277
1278impl From<chroma_proto::DirectWhereDocument> for DocumentExpression {
1279 fn from(value: chroma_proto::DirectWhereDocument) -> Self {
1280 Self {
1281 operator: value.operator().into(),
1282 pattern: value.pattern,
1283 }
1284 }
1285}
1286
1287impl From<DocumentExpression> for chroma_proto::DirectWhereDocument {
1288 fn from(value: DocumentExpression) -> Self {
1289 Self {
1290 pattern: value.pattern,
1291 operator: chroma_proto::WhereDocumentOperator::from(value.operator) as i32,
1292 }
1293 }
1294}
1295
1296#[derive(Clone, Debug, PartialEq)]
1297#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1298pub enum DocumentOperator {
1299 Contains,
1300 NotContains,
1301 Regex,
1302 NotRegex,
1303}
1304impl From<chroma_proto::WhereDocumentOperator> for DocumentOperator {
1305 fn from(value: chroma_proto::WhereDocumentOperator) -> Self {
1306 match value {
1307 chroma_proto::WhereDocumentOperator::Contains => Self::Contains,
1308 chroma_proto::WhereDocumentOperator::NotContains => Self::NotContains,
1309 chroma_proto::WhereDocumentOperator::Regex => Self::Regex,
1310 chroma_proto::WhereDocumentOperator::NotRegex => Self::NotRegex,
1311 }
1312 }
1313}
1314
1315impl From<DocumentOperator> for chroma_proto::WhereDocumentOperator {
1316 fn from(value: DocumentOperator) -> Self {
1317 match value {
1318 DocumentOperator::Contains => Self::Contains,
1319 DocumentOperator::NotContains => Self::NotContains,
1320 DocumentOperator::Regex => Self::Regex,
1321 DocumentOperator::NotRegex => Self::NotRegex,
1322 }
1323 }
1324}
1325
1326#[derive(Clone, Debug, PartialEq)]
1327#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1328pub struct MetadataExpression {
1329 pub key: String,
1330 pub comparison: MetadataComparison,
1331}
1332
1333impl std::fmt::Display for MetadataExpression {
1334 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1335 match &self.comparison {
1336 MetadataComparison::Primitive(op, value) => {
1337 write!(f, "{} {} {}", self.key, op, value)
1338 }
1339 MetadataComparison::Set(op, set_value) => {
1340 write!(f, "{} {} {}", self.key, op, set_value)
1341 }
1342 }
1343 }
1344}
1345
1346impl TryFrom<chroma_proto::DirectComparison> for MetadataExpression {
1347 type Error = WhereConversionError;
1348
1349 fn try_from(value: chroma_proto::DirectComparison) -> Result<Self, Self::Error> {
1350 let proto_comparison = value
1351 .comparison
1352 .ok_or(WhereConversionError::cause("Invalid MetadataExpression"))?;
1353 let comparison = match proto_comparison {
1354 chroma_proto::direct_comparison::Comparison::SingleStringOperand(
1355 single_string_comparison,
1356 ) => MetadataComparison::Primitive(
1357 single_string_comparison.comparator().into(),
1358 MetadataValue::Str(single_string_comparison.value),
1359 ),
1360 chroma_proto::direct_comparison::Comparison::StringListOperand(
1361 string_list_comparison,
1362 ) => MetadataComparison::Set(
1363 string_list_comparison.list_operator().into(),
1364 MetadataSetValue::Str(string_list_comparison.values),
1365 ),
1366 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1367 single_int_comparison,
1368 ) => MetadataComparison::Primitive(
1369 match single_int_comparison
1370 .comparator
1371 .ok_or(WhereConversionError::cause(
1372 "Invalid scalar integer operator",
1373 ))? {
1374 chroma_proto::single_int_comparison::Comparator::GenericComparator(op) => {
1375 chroma_proto::GenericComparator::try_from(op)
1376 .map_err(WhereConversionError::cause)?
1377 .into()
1378 }
1379 chroma_proto::single_int_comparison::Comparator::NumberComparator(op) => {
1380 chroma_proto::NumberComparator::try_from(op)
1381 .map_err(WhereConversionError::cause)?
1382 .into()
1383 }
1384 },
1385 MetadataValue::Int(single_int_comparison.value),
1386 ),
1387 chroma_proto::direct_comparison::Comparison::IntListOperand(int_list_comparison) => {
1388 MetadataComparison::Set(
1389 int_list_comparison.list_operator().into(),
1390 MetadataSetValue::Int(int_list_comparison.values),
1391 )
1392 }
1393 chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(
1394 single_double_comparison,
1395 ) => MetadataComparison::Primitive(
1396 match single_double_comparison
1397 .comparator
1398 .ok_or(WhereConversionError::cause("Invalid scalar float operator"))?
1399 {
1400 chroma_proto::single_double_comparison::Comparator::GenericComparator(op) => {
1401 chroma_proto::GenericComparator::try_from(op)
1402 .map_err(WhereConversionError::cause)?
1403 .into()
1404 }
1405 chroma_proto::single_double_comparison::Comparator::NumberComparator(op) => {
1406 chroma_proto::NumberComparator::try_from(op)
1407 .map_err(WhereConversionError::cause)?
1408 .into()
1409 }
1410 },
1411 MetadataValue::Float(single_double_comparison.value),
1412 ),
1413 chroma_proto::direct_comparison::Comparison::DoubleListOperand(
1414 double_list_comparison,
1415 ) => MetadataComparison::Set(
1416 double_list_comparison.list_operator().into(),
1417 MetadataSetValue::Float(double_list_comparison.values),
1418 ),
1419 chroma_proto::direct_comparison::Comparison::BoolListOperand(bool_list_comparison) => {
1420 MetadataComparison::Set(
1421 bool_list_comparison.list_operator().into(),
1422 MetadataSetValue::Bool(bool_list_comparison.values),
1423 )
1424 }
1425 chroma_proto::direct_comparison::Comparison::SingleBoolOperand(
1426 single_bool_comparison,
1427 ) => MetadataComparison::Primitive(
1428 single_bool_comparison.comparator().into(),
1429 MetadataValue::Bool(single_bool_comparison.value),
1430 ),
1431 };
1432 Ok(Self {
1433 key: value.key,
1434 comparison,
1435 })
1436 }
1437}
1438
1439impl TryFrom<MetadataExpression> for chroma_proto::DirectComparison {
1440 type Error = WhereConversionError;
1441
1442 fn try_from(value: MetadataExpression) -> Result<Self, Self::Error> {
1443 let comparison = match value.comparison {
1444 MetadataComparison::Primitive(primitive_operator, metadata_value) => match metadata_value {
1445 MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1446 MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(match primitive_operator {
1447 generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1448 numeric => chroma_proto::single_int_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1449 }),
1450 MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(match primitive_operator {
1451 generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_double_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1452 numeric => chroma_proto::single_double_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1453 }),
1454 MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1455 MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Comparison with sparse vector is not supported".to_string())),
1456 },
1457 MetadataComparison::Set(set_operator, metadata_set_value) => match metadata_set_value {
1458 MetadataSetValue::Bool(vec) => chroma_proto::direct_comparison::Comparison::BoolListOperand(chroma_proto::BoolListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1459 MetadataSetValue::Int(vec) => chroma_proto::direct_comparison::Comparison::IntListOperand(chroma_proto::IntListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1460 MetadataSetValue::Float(vec) => chroma_proto::direct_comparison::Comparison::DoubleListOperand(chroma_proto::DoubleListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1461 MetadataSetValue::Str(vec) => chroma_proto::direct_comparison::Comparison::StringListOperand(chroma_proto::StringListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1462 },
1463 };
1464 Ok(Self {
1465 key: value.key,
1466 comparison: Some(comparison),
1467 })
1468 }
1469}
1470
1471#[derive(Clone, Debug, PartialEq)]
1472#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1473pub enum MetadataComparison {
1474 Primitive(PrimitiveOperator, MetadataValue),
1475 Set(SetOperator, MetadataSetValue),
1476}
1477
/// Single-value comparison operator used by `MetadataComparison::Primitive`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum PrimitiveOperator {
    Equal,
    NotEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LessThan,
    LessThanOrEqual,
}

impl std::fmt::Display for PrimitiveOperator {
    /// Writes the mathematical symbol for the operator: `=`, `≠`, `>`, `≥`, `<` or `≤`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Equal => "=",
            Self::NotEqual => "≠",
            Self::GreaterThan => ">",
            Self::GreaterThanOrEqual => "≥",
            Self::LessThan => "<",
            Self::LessThanOrEqual => "≤",
        };
        f.write_str(symbol)
    }
}
1502
1503impl From<chroma_proto::GenericComparator> for PrimitiveOperator {
1504 fn from(value: chroma_proto::GenericComparator) -> Self {
1505 match value {
1506 chroma_proto::GenericComparator::Eq => Self::Equal,
1507 chroma_proto::GenericComparator::Ne => Self::NotEqual,
1508 }
1509 }
1510}
1511
1512impl TryFrom<PrimitiveOperator> for chroma_proto::GenericComparator {
1513 type Error = WhereConversionError;
1514
1515 fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1516 match value {
1517 PrimitiveOperator::Equal => Ok(Self::Eq),
1518 PrimitiveOperator::NotEqual => Ok(Self::Ne),
1519 op => Err(WhereConversionError::cause(format!("{op:?} ∉ [=, ≠]"))),
1520 }
1521 }
1522}
1523
1524impl From<chroma_proto::NumberComparator> for PrimitiveOperator {
1525 fn from(value: chroma_proto::NumberComparator) -> Self {
1526 match value {
1527 chroma_proto::NumberComparator::Gt => Self::GreaterThan,
1528 chroma_proto::NumberComparator::Gte => Self::GreaterThanOrEqual,
1529 chroma_proto::NumberComparator::Lt => Self::LessThan,
1530 chroma_proto::NumberComparator::Lte => Self::LessThanOrEqual,
1531 }
1532 }
1533}
1534
1535impl TryFrom<PrimitiveOperator> for chroma_proto::NumberComparator {
1536 type Error = WhereConversionError;
1537
1538 fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1539 match value {
1540 PrimitiveOperator::GreaterThan => Ok(Self::Gt),
1541 PrimitiveOperator::GreaterThanOrEqual => Ok(Self::Gte),
1542 PrimitiveOperator::LessThan => Ok(Self::Lt),
1543 PrimitiveOperator::LessThanOrEqual => Ok(Self::Lte),
1544 op => Err(WhereConversionError::cause(format!(
1545 "{op:?} ∉ [≤, <, >, ≥]"
1546 ))),
1547 }
1548 }
1549}
1550
/// Set-membership operator used by `MetadataComparison::Set`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SetOperator {
    In,
    NotIn,
}

impl std::fmt::Display for SetOperator {
    /// Writes `∈` for `In` and `∉` for `NotIn`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::In => "∈",
            Self::NotIn => "∉",
        })
    }
}
1567
1568impl From<chroma_proto::ListOperator> for SetOperator {
1569 fn from(value: chroma_proto::ListOperator) -> Self {
1570 match value {
1571 chroma_proto::ListOperator::In => Self::In,
1572 chroma_proto::ListOperator::Nin => Self::NotIn,
1573 }
1574 }
1575}
1576
1577impl From<SetOperator> for chroma_proto::ListOperator {
1578 fn from(value: SetOperator) -> Self {
1579 match value {
1580 SetOperator::In => Self::In,
1581 SetOperator::NotIn => Self::Nin,
1582 }
1583 }
1584}
1585
/// The list operand of a set comparison (`∈` / `∉`); all elements share one type.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum MetadataSetValue {
    Bool(Vec<bool>),
    Int(Vec<i64>),
    Float(Vec<f64>),
    Str(Vec<String>),
}

impl std::fmt::Display for MetadataSetValue {
    /// Renders the list as `[a, b, c]`.
    ///
    /// Strings are wrapped in double quotes (not escaped). Booleans and numbers
    /// are written bare, so the bool list `[true, false]` stays visually
    /// distinct from the string list `["true", "false"]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Writes `items` comma-separated inside brackets, optionally quoting each.
        // Writes straight into the formatter, avoiding per-item `String`s.
        fn write_list<T: std::fmt::Display>(
            f: &mut std::fmt::Formatter<'_>,
            items: &[T],
            quoted: bool,
        ) -> std::fmt::Result {
            f.write_str("[")?;
            for (position, item) in items.iter().enumerate() {
                if position > 0 {
                    f.write_str(", ")?;
                }
                if quoted {
                    write!(f, "\"{item}\"")?;
                } else {
                    write!(f, "{item}")?;
                }
            }
            f.write_str("]")
        }

        match self {
            // Previously bools were quoted like strings; they are now bare, like numbers.
            MetadataSetValue::Bool(values) => write_list(f, values, false),
            MetadataSetValue::Int(values) => write_list(f, values, false),
            MetadataSetValue::Float(values) => write_list(f, values, false),
            MetadataSetValue::Str(values) => write_list(f, values, true),
        }
    }
}
1633
1634impl MetadataSetValue {
1635 pub fn value_type(&self) -> MetadataValueType {
1636 match self {
1637 MetadataSetValue::Bool(_) => MetadataValueType::Bool,
1638 MetadataSetValue::Int(_) => MetadataValueType::Int,
1639 MetadataSetValue::Float(_) => MetadataValueType::Float,
1640 MetadataSetValue::Str(_) => MetadataValueType::Str,
1641 }
1642 }
1643}
1644
1645impl From<Vec<bool>> for MetadataSetValue {
1646 fn from(values: Vec<bool>) -> Self {
1647 MetadataSetValue::Bool(values)
1648 }
1649}
1650
1651impl From<Vec<i64>> for MetadataSetValue {
1652 fn from(values: Vec<i64>) -> Self {
1653 MetadataSetValue::Int(values)
1654 }
1655}
1656
1657impl From<Vec<i32>> for MetadataSetValue {
1658 fn from(values: Vec<i32>) -> Self {
1659 MetadataSetValue::Int(values.into_iter().map(|v| v as i64).collect())
1660 }
1661}
1662
1663impl From<Vec<f64>> for MetadataSetValue {
1664 fn from(values: Vec<f64>) -> Self {
1665 MetadataSetValue::Float(values)
1666 }
1667}
1668
1669impl From<Vec<f32>> for MetadataSetValue {
1670 fn from(values: Vec<f32>) -> Self {
1671 MetadataSetValue::Float(values.into_iter().map(|v| v as f64).collect())
1672 }
1673}
1674
1675impl From<Vec<String>> for MetadataSetValue {
1676 fn from(values: Vec<String>) -> Self {
1677 MetadataSetValue::Str(values)
1678 }
1679}
1680
1681impl From<Vec<&str>> for MetadataSetValue {
1682 fn from(values: Vec<&str>) -> Self {
1683 MetadataSetValue::Str(values.into_iter().map(|s| s.to_string()).collect())
1684 }
1685}
1686
1687impl TryFrom<chroma_proto::WhereDocument> for Where {
1689 type Error = WhereConversionError;
1690
1691 fn try_from(proto_document: chroma_proto::WhereDocument) -> Result<Self, Self::Error> {
1692 match proto_document.r#where_document {
1693 Some(chroma_proto::where_document::WhereDocument::Direct(proto_comparison)) => {
1694 let operator = match TryInto::<chroma_proto::WhereDocumentOperator>::try_into(
1695 proto_comparison.operator,
1696 ) {
1697 Ok(operator) => operator,
1698 Err(_) => {
1699 return Err(WhereConversionError::cause(
1700 "[Deprecated] Invalid where document operator",
1701 ))
1702 }
1703 };
1704 let comparison = DocumentExpression {
1705 pattern: proto_comparison.pattern,
1706 operator: operator.into(),
1707 };
1708 Ok(Where::Document(comparison))
1709 }
1710 Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => {
1711 let operator = match TryInto::<chroma_proto::BooleanOperator>::try_into(
1712 proto_children.operator,
1713 ) {
1714 Ok(operator) => operator,
1715 Err(_) => {
1716 return Err(WhereConversionError::cause(
1717 "[Deprecated] Invalid boolean operator",
1718 ))
1719 }
1720 };
1721 let children = CompositeExpression {
1722 children: proto_children
1723 .children
1724 .into_iter()
1725 .map(|child| child.try_into())
1726 .collect::<Result<_, _>>()?,
1727 operator: operator.into(),
1728 };
1729 Ok(Where::Composite(children))
1730 }
1731 None => Err(WhereConversionError::cause("[Deprecated] Invalid where")),
1732 }
1733 }
1734}
1735
1736#[cfg(test)]
1737mod tests {
1738 use crate::operator::Key;
1739
1740 use super::*;
1741
1742 #[test]
1743 fn test_update_metadata_try_from() {
1744 let mut proto_metadata = chroma_proto::UpdateMetadata {
1745 metadata: HashMap::new(),
1746 };
1747 proto_metadata.metadata.insert(
1748 "foo".to_string(),
1749 chroma_proto::UpdateMetadataValue {
1750 value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
1751 },
1752 );
1753 proto_metadata.metadata.insert(
1754 "bar".to_string(),
1755 chroma_proto::UpdateMetadataValue {
1756 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
1757 },
1758 );
1759 proto_metadata.metadata.insert(
1760 "baz".to_string(),
1761 chroma_proto::UpdateMetadataValue {
1762 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1763 "42".to_string(),
1764 )),
1765 },
1766 );
1767 proto_metadata.metadata.insert(
1769 "sparse".to_string(),
1770 chroma_proto::UpdateMetadataValue {
1771 value: Some(
1772 chroma_proto::update_metadata_value::Value::SparseVectorValue(
1773 chroma_proto::SparseVector {
1774 indices: vec![0, 5, 10],
1775 values: vec![0.1, 0.5, 0.9],
1776 },
1777 ),
1778 ),
1779 },
1780 );
1781 let converted_metadata: UpdateMetadata = proto_metadata.try_into().unwrap();
1782 assert_eq!(converted_metadata.len(), 4);
1783 assert_eq!(
1784 converted_metadata.get("foo").unwrap(),
1785 &UpdateMetadataValue::Int(42)
1786 );
1787 assert_eq!(
1788 converted_metadata.get("bar").unwrap(),
1789 &UpdateMetadataValue::Float(42.0)
1790 );
1791 assert_eq!(
1792 converted_metadata.get("baz").unwrap(),
1793 &UpdateMetadataValue::Str("42".to_string())
1794 );
1795 assert_eq!(
1796 converted_metadata.get("sparse").unwrap(),
1797 &UpdateMetadataValue::SparseVector(SparseVector::new(
1798 vec![0, 5, 10],
1799 vec![0.1, 0.5, 0.9]
1800 ))
1801 );
1802 }
1803
1804 #[test]
1805 fn test_metadata_try_from() {
1806 let mut proto_metadata = chroma_proto::UpdateMetadata {
1807 metadata: HashMap::new(),
1808 };
1809 proto_metadata.metadata.insert(
1810 "foo".to_string(),
1811 chroma_proto::UpdateMetadataValue {
1812 value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
1813 },
1814 );
1815 proto_metadata.metadata.insert(
1816 "bar".to_string(),
1817 chroma_proto::UpdateMetadataValue {
1818 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
1819 },
1820 );
1821 proto_metadata.metadata.insert(
1822 "baz".to_string(),
1823 chroma_proto::UpdateMetadataValue {
1824 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1825 "42".to_string(),
1826 )),
1827 },
1828 );
1829 proto_metadata.metadata.insert(
1831 "sparse".to_string(),
1832 chroma_proto::UpdateMetadataValue {
1833 value: Some(
1834 chroma_proto::update_metadata_value::Value::SparseVectorValue(
1835 chroma_proto::SparseVector {
1836 indices: vec![1, 10, 100],
1837 values: vec![0.2, 0.4, 0.6],
1838 },
1839 ),
1840 ),
1841 },
1842 );
1843 let converted_metadata: Metadata = proto_metadata.try_into().unwrap();
1844 assert_eq!(converted_metadata.len(), 4);
1845 assert_eq!(
1846 converted_metadata.get("foo").unwrap(),
1847 &MetadataValue::Int(42)
1848 );
1849 assert_eq!(
1850 converted_metadata.get("bar").unwrap(),
1851 &MetadataValue::Float(42.0)
1852 );
1853 assert_eq!(
1854 converted_metadata.get("baz").unwrap(),
1855 &MetadataValue::Str("42".to_string())
1856 );
1857 assert_eq!(
1858 converted_metadata.get("sparse").unwrap(),
1859 &MetadataValue::SparseVector(SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]))
1860 );
1861 }
1862
1863 #[test]
1864 fn test_where_clause_simple_from() {
1865 let proto_where = chroma_proto::Where {
1866 r#where: Some(chroma_proto::r#where::Where::DirectComparison(
1867 chroma_proto::DirectComparison {
1868 key: "foo".to_string(),
1869 comparison: Some(
1870 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1871 chroma_proto::SingleIntComparison {
1872 value: 42,
1873 comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
1874 },
1875 ),
1876 ),
1877 },
1878 )),
1879 };
1880 let where_clause: Where = proto_where.try_into().unwrap();
1881 match where_clause {
1882 Where::Metadata(comparison) => {
1883 assert_eq!(comparison.key, "foo");
1884 match comparison.comparison {
1885 MetadataComparison::Primitive(_, value) => {
1886 assert_eq!(value, MetadataValue::Int(42));
1887 }
1888 _ => panic!("Invalid comparison type"),
1889 }
1890 }
1891 _ => panic!("Invalid where type"),
1892 }
1893 }
1894
1895 #[test]
1896 fn test_where_clause_with_children() {
1897 let proto_where = chroma_proto::Where {
1898 r#where: Some(chroma_proto::r#where::Where::Children(
1899 chroma_proto::WhereChildren {
1900 children: vec![
1901 chroma_proto::Where {
1902 r#where: Some(chroma_proto::r#where::Where::DirectComparison(
1903 chroma_proto::DirectComparison {
1904 key: "foo".to_string(),
1905 comparison: Some(
1906 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1907 chroma_proto::SingleIntComparison {
1908 value: 42,
1909 comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
1910 },
1911 ),
1912 ),
1913 },
1914 )),
1915 },
1916 chroma_proto::Where {
1917 r#where: Some(chroma_proto::r#where::Where::DirectComparison(
1918 chroma_proto::DirectComparison {
1919 key: "bar".to_string(),
1920 comparison: Some(
1921 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1922 chroma_proto::SingleIntComparison {
1923 value: 42,
1924 comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
1925 },
1926 ),
1927 ),
1928 },
1929 )),
1930 },
1931 ],
1932 operator: chroma_proto::BooleanOperator::And.into(),
1933 },
1934 )),
1935 };
1936 let where_clause: Where = proto_where.try_into().unwrap();
1937 match where_clause {
1938 Where::Composite(children) => {
1939 assert_eq!(children.children.len(), 2);
1940 assert_eq!(children.operator, BooleanOperator::And);
1941 }
1942 _ => panic!("Invalid where type"),
1943 }
1944 }
1945
1946 #[test]
1947 fn test_where_document_simple() {
1948 let proto_where = chroma_proto::WhereDocument {
1949 r#where_document: Some(chroma_proto::where_document::WhereDocument::Direct(
1950 chroma_proto::DirectWhereDocument {
1951 pattern: "foo".to_string(),
1952 operator: chroma_proto::WhereDocumentOperator::Contains.into(),
1953 },
1954 )),
1955 };
1956 let where_document: Where = proto_where.try_into().unwrap();
1957 match where_document {
1958 Where::Document(comparison) => {
1959 assert_eq!(comparison.pattern, "foo");
1960 assert_eq!(comparison.operator, DocumentOperator::Contains);
1961 }
1962 _ => panic!("Invalid where document type"),
1963 }
1964 }
1965
1966 #[test]
1967 fn test_where_document_with_children() {
1968 let proto_where = chroma_proto::WhereDocument {
1969 r#where_document: Some(chroma_proto::where_document::WhereDocument::Children(
1970 chroma_proto::WhereDocumentChildren {
1971 children: vec![
1972 chroma_proto::WhereDocument {
1973 r#where_document: Some(
1974 chroma_proto::where_document::WhereDocument::Direct(
1975 chroma_proto::DirectWhereDocument {
1976 pattern: "foo".to_string(),
1977 operator: chroma_proto::WhereDocumentOperator::Contains
1978 .into(),
1979 },
1980 ),
1981 ),
1982 },
1983 chroma_proto::WhereDocument {
1984 r#where_document: Some(
1985 chroma_proto::where_document::WhereDocument::Direct(
1986 chroma_proto::DirectWhereDocument {
1987 pattern: "bar".to_string(),
1988 operator: chroma_proto::WhereDocumentOperator::Contains
1989 .into(),
1990 },
1991 ),
1992 ),
1993 },
1994 ],
1995 operator: chroma_proto::BooleanOperator::And.into(),
1996 },
1997 )),
1998 };
1999 let where_document: Where = proto_where.try_into().unwrap();
2000 match where_document {
2001 Where::Composite(children) => {
2002 assert_eq!(children.children.len(), 2);
2003 assert_eq!(children.operator, BooleanOperator::And);
2004 }
2005 _ => panic!("Invalid where document type"),
2006 }
2007 }
2008
2009 #[test]
2010 fn test_sparse_vector_new() {
2011 let indices = vec![0, 5, 10];
2012 let values = vec![0.1, 0.5, 0.9];
2013 let sparse = SparseVector::new(indices.clone(), values.clone());
2014 assert_eq!(sparse.indices, indices);
2015 assert_eq!(sparse.values, values);
2016 }
2017
2018 #[test]
2019 fn test_sparse_vector_from_pairs() {
2020 let pairs = vec![(0, 0.1), (5, 0.5), (10, 0.9)];
2021 let sparse = SparseVector::from_pairs(pairs.clone());
2022 assert_eq!(sparse.indices, vec![0, 5, 10]);
2023 assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
2024 }
2025
2026 #[test]
2027 fn test_sparse_vector_iter() {
2028 let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
2029 let collected: Vec<(u32, f32)> = sparse.iter().collect();
2030 assert_eq!(collected, vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
2031 }
2032
2033 #[test]
2034 fn test_sparse_vector_ordering() {
2035 let sparse1 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
2036 let sparse2 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
2037 let sparse3 = SparseVector::new(vec![0, 6], vec![0.1, 0.5]);
2038 let sparse4 = SparseVector::new(vec![0, 5], vec![0.1, 0.6]);
2039
2040 assert_eq!(sparse1, sparse2);
2041 assert!(sparse1 < sparse3);
2042 assert!(sparse1 < sparse4);
2043 }
2044
2045 #[test]
2046 fn test_sparse_vector_proto_conversion() {
2047 let sparse = SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]);
2048 let proto: chroma_proto::SparseVector = sparse.clone().into();
2049 assert_eq!(proto.indices, vec![1, 10, 100]);
2050 assert_eq!(proto.values, vec![0.2, 0.4, 0.6]);
2051
2052 let converted: SparseVector = proto.into();
2053 assert_eq!(converted, sparse);
2054 }
2055
2056 #[test]
2057 fn test_sparse_vector_logical_size() {
2058 let metadata = Metadata::from([(
2059 "sparse".to_string(),
2060 MetadataValue::SparseVector(SparseVector::new(
2061 vec![0, 1, 2, 3, 4],
2062 vec![0.1, 0.2, 0.3, 0.4, 0.5],
2063 )),
2064 )]);
2065
2066 let size = logical_size_of_metadata(&metadata);
2067 assert_eq!(size, 46);
2070 }
2071
2072 #[test]
2073 fn test_sparse_vector_validation() {
2074 let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]);
2076 assert!(sparse.validate().is_ok());
2077
2078 let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]);
2080 let result = sparse.validate();
2081 assert!(result.is_err());
2082 assert!(matches!(
2083 result.unwrap_err(),
2084 MetadataValueConversionError::SparseVectorLengthMismatch
2085 ));
2086
2087 let sparse = SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]);
2089 let result = sparse.validate();
2090 assert!(result.is_err());
2091 assert!(matches!(
2092 result.unwrap_err(),
2093 MetadataValueConversionError::SparseVectorIndicesNotSorted
2094 ));
2095
2096 let sparse = SparseVector::new(vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]);
2098 let result = sparse.validate();
2099 assert!(result.is_err());
2100 assert!(matches!(
2101 result.unwrap_err(),
2102 MetadataValueConversionError::SparseVectorIndicesNotSorted
2103 ));
2104
2105 let sparse = SparseVector::new(vec![1, 3, 2], vec![0.1, 0.3, 0.2]);
2107 let result = sparse.validate();
2108 assert!(result.is_err());
2109 assert!(matches!(
2110 result.unwrap_err(),
2111 MetadataValueConversionError::SparseVectorIndicesNotSorted
2112 ));
2113 }
2114
2115 #[test]
2116 fn test_sparse_vector_deserialize_old_format() {
2117 let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
2119 let sv: SparseVector = serde_json::from_str(json).unwrap();
2120 assert_eq!(sv.indices, vec![0, 1, 2]);
2121 assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2122 }
2123
2124 #[test]
2125 fn test_sparse_vector_deserialize_new_format() {
2126 let json =
2128 "{\"#type\": \"sparse_vector\", \"indices\": [0, 1, 2], \"values\": [1.0, 2.0, 3.0]}";
2129 let sv: SparseVector = serde_json::from_str(json).unwrap();
2130 assert_eq!(sv.indices, vec![0, 1, 2]);
2131 assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2132 }
2133
#[test]
fn test_sparse_vector_deserialize_new_format_field_order() {
    // The `#type` tag may appear anywhere in the object, not only first.
    let shuffled = r##"{"indices": [5, 10], "#type": "sparse_vector", "values": [0.5, 1.0]}"##;
    let SparseVector { indices, values } =
        serde_json::from_str::<SparseVector>(shuffled).unwrap();
    assert_eq!(indices, [5, 10]);
    assert_eq!(values, [0.5, 1.0]);
}
2142
#[test]
fn test_sparse_vector_deserialize_wrong_type_tag() {
    // A present-but-wrong `#type` tag must be rejected rather than ignored.
    let json = r##"{"#type": "dense_vector", "indices": [0, 1], "values": [1.0, 2.0]}"##;
    let err = serde_json::from_str::<SparseVector>(json)
        .expect_err("a mismatched #type tag must fail to deserialize");
    let err_msg = err.to_string();
    // Pin the tag-check message, which names both the expected and the offending
    // tag. Checking only for "sparse_vector" would not prove the failure came from
    // the tag check, nor that the rejected tag is reported to the caller.
    assert!(
        err_msg.contains("Expected #type='sparse_vector', got 'dense_vector'"),
        "unexpected error message: {err_msg}"
    );
}
2152
#[test]
fn test_sparse_vector_serialize_always_has_type() {
    // Serialization always emits the `#type` tag alongside the payload fields.
    let encoded = serde_json::to_value(&SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]))
        .unwrap();

    let expected_fields = [
        ("#type", serde_json::json!("sparse_vector")),
        ("indices", serde_json::json!([0, 1, 2])),
        ("values", serde_json::json!([1.0, 2.0, 3.0])),
    ];
    for (field, want) in expected_fields {
        assert_eq!(encoded[field], want, "field `{field}`");
    }
}
2163
#[test]
fn test_sparse_vector_roundtrip_with_type() {
    let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]);

    let encoded = serde_json::to_string(&original).unwrap();
    // `to_string` emits compact JSON, so the tag appears with no whitespace around `:`.
    let compact_tag = r##""#type":"sparse_vector""##;
    assert!(encoded.contains(compact_tag));

    let decoded = serde_json::from_str::<SparseVector>(&encoded).unwrap();
    assert_eq!(decoded, original);
}
2176
#[test]
fn test_sparse_vector_in_metadata_old_format() {
    // An untagged sparse vector nested inside a metadata-like object still decodes.
    let json = r#"{"key": "value", "sparse": {"indices": [0, 1], "values": [1.0, 2.0]}}"#;
    let mut metadata: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();

    // Move the nested value out instead of cloning it.
    let nested = metadata.remove("sparse").expect("`sparse` key present");
    let SparseVector { indices, values } = serde_json::from_value::<SparseVector>(nested).unwrap();
    assert_eq!(indices, [0, 1]);
    assert_eq!(values, [1.0, 2.0]);
}
2188
#[test]
fn test_sparse_vector_in_metadata_new_format() {
    // A `#type`-tagged sparse vector nested inside a metadata-like object decodes.
    let json = r##"{"key": "value", "sparse": {"#type": "sparse_vector", "indices": [0, 1], "values": [1.0, 2.0]}}"##;
    let mut metadata: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();

    // Move the nested value out instead of cloning it.
    let nested = metadata.remove("sparse").expect("`sparse` key present");
    let SparseVector { indices, values } = serde_json::from_value::<SparseVector>(nested).unwrap();
    assert_eq!(indices, [0, 1]);
    assert_eq!(values, [1.0, 2.0]);
}
2200
#[test]
fn test_simplifies_identities() {
    let always: Where = true.into();
    let never: Where = false.into();
    let foo = Key::field("foo").eq("bar");

    // Combining `true` with itself stays `true` under both operators.
    assert_eq!(always.clone() & always.clone(), true.into());
    assert_eq!(always.clone() | always.clone(), true.into());

    // `true` is the identity for AND, in either operand position.
    assert_eq!(foo.clone() & always.clone(), foo);
    assert_eq!(always & foo.clone(), foo);

    // `false` is the identity for OR, in either operand position.
    assert_eq!(foo.clone() | never.clone(), foo);
    assert_eq!(never | foo.clone(), foo);
}
2215
#[test]
fn test_flattens() {
    let foo = Key::field("foo").eq("bar");
    let baz = Key::field("baz").eq("quux");

    // `foo OP (baz OP foo)` should collapse into a single flat composite
    // whose children keep their left-to-right order.
    let flat = |operator: BooleanOperator| {
        Where::Composite(CompositeExpression {
            operator,
            children: vec![foo.clone(), baz.clone(), foo.clone()],
        })
    };

    assert_eq!(
        foo.clone() & (baz.clone() & foo.clone()),
        flat(BooleanOperator::And)
    );
    assert_eq!(
        foo.clone() | (baz.clone() | foo.clone()),
        flat(BooleanOperator::Or)
    );
}
2239}