1use chroma_error::{ChromaError, ErrorCodes};
2use itertools::Itertools;
3use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
4use serde_json::{Number, Value};
5use sprs::CsVec;
6use std::{
7 cmp::Ordering,
8 collections::{HashMap, HashSet},
9 mem::size_of_val,
10 ops::{BitAnd, BitOr},
11};
12use thiserror::Error;
13
14use crate::chroma_proto;
15
16#[cfg(feature = "pyo3")]
17use pyo3::types::{PyAnyMethods, PyDictMethods};
18
19#[cfg(feature = "testing")]
20use proptest::prelude::*;
21
22#[derive(Serialize, Deserialize)]
23struct SparseVectorSerdeHelper {
24 #[serde(rename = "#type")]
25 type_tag: Option<String>,
26 indices: Vec<u32>,
27 values: Vec<f32>,
28 tokens: Option<Vec<String>>,
29}
30
// Sparse vector stored as parallel lists.
//
// `indices[i]` and `values[i]` describe the same entry; when `tokens` is
// present, the constructors and `validate` require it to have the same length.
// Plain `//` comments are used so that derive macros see no extra doc text.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SparseVector {
    // Index of each stored entry.
    pub indices: Vec<u32>,
    // Value of each stored entry.
    pub values: Vec<f32>,
    // Optional token strings, one per entry.
    pub tokens: Option<Vec<String>>,
}
47
48impl<'de> Deserialize<'de> for SparseVector {
50 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51 where
52 D: Deserializer<'de>,
53 {
54 let helper = SparseVectorSerdeHelper::deserialize(deserializer)?;
55
56 if let Some(type_tag) = &helper.type_tag {
58 if type_tag != "sparse_vector" {
59 return Err(serde::de::Error::custom(format!(
60 "Expected #type='sparse_vector', got '{}'",
61 type_tag
62 )));
63 }
64 }
65
66 Ok(SparseVector {
67 indices: helper.indices,
68 values: helper.values,
69 tokens: helper.tokens,
70 })
71 }
72}
73
74impl Serialize for SparseVector {
76 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77 where
78 S: Serializer,
79 {
80 let helper = SparseVectorSerdeHelper {
81 type_tag: Some("sparse_vector".to_string()),
82 indices: self.indices.clone(),
83 values: self.values.clone(),
84 tokens: self.tokens.clone(),
85 };
86 helper.serialize(serializer)
87 }
88}
89
/// Marker error: the parallel lists of a sparse vector do not all have the
/// same length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SparseVectorLengthMismatch;
93
94impl std::fmt::Display for SparseVectorLengthMismatch {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 write!(
97 f,
98 "Sparse vector indices, values, and tokens (when present) must have the same length"
99 )
100 }
101}
102
103impl std::error::Error for SparseVectorLengthMismatch {}
104
105impl ChromaError for SparseVectorLengthMismatch {
106 fn code(&self) -> ErrorCodes {
107 ErrorCodes::InvalidArgument
108 }
109}
110
111impl SparseVector {
112 pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Result<Self, SparseVectorLengthMismatch> {
114 if indices.len() != values.len() {
115 return Err(SparseVectorLengthMismatch);
116 }
117 Ok(Self {
118 indices,
119 values,
120 tokens: None,
121 })
122 }
123
124 pub fn new_with_tokens(
126 indices: Vec<u32>,
127 values: Vec<f32>,
128 tokens: Vec<String>,
129 ) -> Result<Self, SparseVectorLengthMismatch> {
130 if indices.len() != values.len() {
131 return Err(SparseVectorLengthMismatch);
132 }
133 if tokens.len() != indices.len() {
134 return Err(SparseVectorLengthMismatch);
135 }
136 Ok(Self {
137 indices,
138 values,
139 tokens: Some(tokens),
140 })
141 }
142
143 pub fn from_pairs(pairs: impl IntoIterator<Item = (u32, f32)>) -> Self {
145 let mut indices = vec![];
146 let mut values = vec![];
147 for (index, value) in pairs {
148 indices.push(index);
149 values.push(value);
150 }
151 let tokens = None;
152 Self {
153 indices,
154 values,
155 tokens,
156 }
157 }
158
159 pub fn from_triples(triples: impl IntoIterator<Item = (String, u32, f32)>) -> Self {
161 let mut tokens = vec![];
162 let mut indices = vec![];
163 let mut values = vec![];
164 for (token, index, value) in triples {
165 tokens.push(token);
166 indices.push(index);
167 values.push(value);
168 }
169 let tokens = Some(tokens);
170 Self {
171 indices,
172 values,
173 tokens,
174 }
175 }
176
177 pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
179 self.indices
180 .iter()
181 .copied()
182 .zip(self.values.iter().copied())
183 }
184
185 pub fn validate(&self) -> Result<(), MetadataValueConversionError> {
187 if self.indices.len() != self.values.len() {
189 return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
190 }
191
192 if let Some(tokens) = self.tokens.as_ref() {
194 if tokens.len() != self.indices.len() {
195 return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
196 }
197 }
198
199 for i in 1..self.indices.len() {
201 if self.indices[i] <= self.indices[i - 1] {
202 return Err(MetadataValueConversionError::SparseVectorIndicesNotSorted);
203 }
204 }
205
206 Ok(())
207 }
208}
209
210impl Eq for SparseVector {}
211
212impl Ord for SparseVector {
213 fn cmp(&self, other: &Self) -> Ordering {
214 self.indices.cmp(&other.indices).then_with(|| {
215 for (a, b) in self.values.iter().zip(other.values.iter()) {
216 match a.total_cmp(b) {
217 Ordering::Equal => continue,
218 other => return other,
219 }
220 }
221 self.values.len().cmp(&other.values.len())
222 })
223 }
224}
225
226impl PartialOrd for SparseVector {
227 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
228 Some(self.cmp(other))
229 }
230}
231
232impl TryFrom<chroma_proto::SparseVector> for SparseVector {
233 type Error = SparseVectorLengthMismatch;
234
235 fn try_from(proto: chroma_proto::SparseVector) -> Result<Self, Self::Error> {
236 if proto.tokens.is_empty() {
237 SparseVector::new(proto.indices, proto.values)
238 } else {
239 SparseVector::new_with_tokens(proto.indices, proto.values, proto.tokens)
240 }
241 }
242}
243
244impl From<SparseVector> for chroma_proto::SparseVector {
245 fn from(sparse: SparseVector) -> Self {
246 chroma_proto::SparseVector {
247 indices: sparse.indices,
248 values: sparse.values,
249 tokens: sparse.tokens.unwrap_or_default(),
250 }
251 }
252}
253
254impl From<&SparseVector> for CsVec<f32> {
256 fn from(sparse: &SparseVector) -> Self {
257 let (indices, values) = sparse
258 .iter()
259 .map(|(index, value)| (index as usize, value))
260 .unzip();
261 CsVec::new(u32::MAX as usize, indices, values)
262 }
263}
264
265impl From<SparseVector> for CsVec<f32> {
266 fn from(sparse: SparseVector) -> Self {
267 (&sparse).into()
268 }
269}
270
#[cfg(feature = "pyo3")]
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    /// Converts the vector into a Python dict with the keys `"indices"`,
    /// `"values"` and `"tokens"`.
    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        let SparseVector {
            indices,
            values,
            tokens,
        } = self;

        let py_dict = pyo3::types::PyDict::new(py);
        py_dict.set_item("indices", indices)?;
        py_dict.set_item("values", values)?;
        py_dict.set_item("tokens", tokens)?;
        Ok(py_dict.into_any())
    }
}
287
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for SparseVector {
    /// Extracts a sparse vector from a Python dict.
    ///
    /// `"indices"` and `"values"` are required (a `KeyError` is raised when
    /// either is missing); `"tokens"` may be absent or Python `None`. A length
    /// mismatch between the lists is reported as a `ValueError`.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        let dict = ob.downcast::<pyo3::types::PyDict>()?;

        let indices: Vec<u32> = dict
            .get_item("indices")?
            .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("missing 'indices' key"))?
            .extract()?;

        let values: Vec<f32> = dict
            .get_item("values")?
            .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("missing 'values' key"))?
            .extract()?;

        // Both "key missing" and "key present but None" mean "no tokens".
        let tokens = match dict.get_item("tokens")? {
            Some(obj) if !obj.is_none() => Some(obj.extract::<Vec<String>>()?),
            _ => None,
        };

        let built = if let Some(tokens) = tokens {
            SparseVector::new_with_tokens(indices, values, tokens)
        } else {
            SparseVector::new(indices, values)
        };

        built.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
    }
}
325
326#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
327#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
328#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
329#[serde(untagged)]
330pub enum UpdateMetadataValue {
331 Bool(bool),
332 Int(i64),
333 #[cfg_attr(
334 feature = "testing",
335 proptest(
336 strategy = "(-1e6..=1e6f32).prop_map(|v| UpdateMetadataValue::Float(v as f64)).boxed()"
337 )
338 )]
339 Float(f64),
340 Str(String),
341 #[cfg_attr(feature = "testing", proptest(skip))]
342 SparseVector(SparseVector),
343 None,
344}
345
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for UpdateMetadataValue {
    /// Converts a Python object by trying, in order: `None`, `bool`, `i64`,
    /// `f64`, `String`, and finally a sparse-vector dict.
    ///
    /// A `TypeError` is raised when none of the extractions succeeds.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        if ob.is_none() {
            return Ok(UpdateMetadataValue::None);
        }
        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(UpdateMetadataValue::Bool(flag));
        }
        if let Ok(integer) = ob.extract::<i64>() {
            return Ok(UpdateMetadataValue::Int(integer));
        }
        if let Ok(float) = ob.extract::<f64>() {
            return Ok(UpdateMetadataValue::Float(float));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(UpdateMetadataValue::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(UpdateMetadataValue::SparseVector(sparse));
        }
        Err(pyo3::exceptions::PyTypeError::new_err(
            "Cannot convert Python object to UpdateMetadataValue",
        ))
    }
}
368
369impl From<bool> for UpdateMetadataValue {
370 fn from(b: bool) -> Self {
371 Self::Bool(b)
372 }
373}
374
375impl From<i64> for UpdateMetadataValue {
376 fn from(v: i64) -> Self {
377 Self::Int(v)
378 }
379}
380
381impl From<i32> for UpdateMetadataValue {
382 fn from(v: i32) -> Self {
383 Self::Int(v as i64)
384 }
385}
386
387impl From<f64> for UpdateMetadataValue {
388 fn from(v: f64) -> Self {
389 Self::Float(v)
390 }
391}
392
393impl From<f32> for UpdateMetadataValue {
394 fn from(v: f32) -> Self {
395 Self::Float(v as f64)
396 }
397}
398
399impl From<String> for UpdateMetadataValue {
400 fn from(v: String) -> Self {
401 Self::Str(v)
402 }
403}
404
405impl From<&str> for UpdateMetadataValue {
406 fn from(v: &str) -> Self {
407 Self::Str(v.to_string())
408 }
409}
410
411impl From<SparseVector> for UpdateMetadataValue {
412 fn from(v: SparseVector) -> Self {
413 Self::SparseVector(v)
414 }
415}
416
417#[derive(Error, Debug)]
418pub enum UpdateMetadataValueConversionError {
419 #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")]
420 InvalidValue,
421}
422
423impl ChromaError for UpdateMetadataValueConversionError {
424 fn code(&self) -> ErrorCodes {
425 match self {
426 UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
427 }
428 }
429}
430
431impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue {
432 type Error = UpdateMetadataValueConversionError;
433
434 fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
435 match &value.value {
436 Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
437 Ok(UpdateMetadataValue::Bool(*value))
438 }
439 Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
440 Ok(UpdateMetadataValue::Int(*value))
441 }
442 Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
443 Ok(UpdateMetadataValue::Float(*value))
444 }
445 Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
446 Ok(UpdateMetadataValue::Str(value.clone()))
447 }
448 Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
449 let sparse = value
450 .clone()
451 .try_into()
452 .map_err(|_| UpdateMetadataValueConversionError::InvalidValue)?;
453 Ok(UpdateMetadataValue::SparseVector(sparse))
454 }
455 None => Ok(UpdateMetadataValue::None),
457 }
458 }
459}
460
461impl From<UpdateMetadataValue> for chroma_proto::UpdateMetadataValue {
462 fn from(value: UpdateMetadataValue) -> Self {
463 match value {
464 UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
465 value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
466 },
467 UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
468 value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
469 },
470 UpdateMetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
471 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
472 value,
473 )),
474 },
475 UpdateMetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
476 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
477 value,
478 )),
479 },
480 UpdateMetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
481 value: Some(
482 chroma_proto::update_metadata_value::Value::SparseVectorValue(
483 sparse_vec.into(),
484 ),
485 ),
486 },
487 UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None },
488 }
489 }
490}
491
492impl TryFrom<&UpdateMetadataValue> for MetadataValue {
493 type Error = MetadataValueConversionError;
494
495 fn try_from(value: &UpdateMetadataValue) -> Result<Self, Self::Error> {
496 match value {
497 UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)),
498 UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)),
499 UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)),
500 UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())),
501 UpdateMetadataValue::SparseVector(value) => {
502 Ok(MetadataValue::SparseVector(value.clone()))
503 }
504 UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue),
505 }
506 }
507}
508
509#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
516#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
517#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
518#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
519#[serde(untagged)]
520pub enum MetadataValue {
521 Bool(bool),
522 Int(i64),
523 #[cfg_attr(
524 feature = "testing",
525 proptest(
526 strategy = "(-1e6..=1e6f32).prop_map(|v| MetadataValue::Float(v as f64)).boxed()"
527 )
528 )]
529 Float(f64),
530 Str(String),
531 #[cfg_attr(feature = "testing", proptest(skip))]
532 SparseVector(SparseVector),
533}
534
535impl std::fmt::Display for MetadataValue {
536 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537 match self {
538 MetadataValue::Bool(v) => write!(f, "{}", v),
539 MetadataValue::Int(v) => write!(f, "{}", v),
540 MetadataValue::Float(v) => write!(f, "{}", v),
541 MetadataValue::Str(v) => write!(f, "\"{}\"", v),
542 MetadataValue::SparseVector(v) => write!(f, "SparseVector(len={})", v.values.len()),
543 }
544 }
545}
546
547impl Eq for MetadataValue {}
548
/// Payload-free tag with one variant per [`MetadataValue`] variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MetadataValueType {
    /// Tag for boolean values.
    Bool,
    /// Tag for integer values.
    Int,
    /// Tag for floating-point values.
    Float,
    /// Tag for string values.
    Str,
    /// Tag for sparse-vector values.
    SparseVector,
}
557
558impl MetadataValue {
559 pub fn value_type(&self) -> MetadataValueType {
560 match self {
561 MetadataValue::Bool(_) => MetadataValueType::Bool,
562 MetadataValue::Int(_) => MetadataValueType::Int,
563 MetadataValue::Float(_) => MetadataValueType::Float,
564 MetadataValue::Str(_) => MetadataValueType::Str,
565 MetadataValue::SparseVector(_) => MetadataValueType::SparseVector,
566 }
567 }
568}
569
570impl From<&MetadataValue> for MetadataValueType {
571 fn from(value: &MetadataValue) -> Self {
572 value.value_type()
573 }
574}
575
576impl From<bool> for MetadataValue {
577 fn from(v: bool) -> Self {
578 MetadataValue::Bool(v)
579 }
580}
581
582impl From<i64> for MetadataValue {
583 fn from(v: i64) -> Self {
584 MetadataValue::Int(v)
585 }
586}
587
588impl From<i32> for MetadataValue {
589 fn from(v: i32) -> Self {
590 MetadataValue::Int(v as i64)
591 }
592}
593
594impl From<f64> for MetadataValue {
595 fn from(v: f64) -> Self {
596 MetadataValue::Float(v)
597 }
598}
599
600impl From<f32> for MetadataValue {
601 fn from(v: f32) -> Self {
602 MetadataValue::Float(v as f64)
603 }
604}
605
606impl From<String> for MetadataValue {
607 fn from(v: String) -> Self {
608 MetadataValue::Str(v)
609 }
610}
611
612impl From<&str> for MetadataValue {
613 fn from(v: &str) -> Self {
614 MetadataValue::Str(v.to_string())
615 }
616}
617
618impl From<SparseVector> for MetadataValue {
619 fn from(v: SparseVector) -> Self {
620 MetadataValue::SparseVector(v)
621 }
622}
623
624#[allow(clippy::derive_ord_xor_partial_ord)]
629impl Ord for MetadataValue {
630 fn cmp(&self, other: &Self) -> Ordering {
631 fn type_order(val: &MetadataValue) -> u8 {
633 match val {
634 MetadataValue::Bool(_) => 0,
635 MetadataValue::Int(_) => 1,
636 MetadataValue::Float(_) => 2,
637 MetadataValue::Str(_) => 3,
638 MetadataValue::SparseVector(_) => 4,
639 }
640 }
641
642 type_order(self).cmp(&type_order(other)).then_with(|| {
644 match (self, other) {
645 (MetadataValue::Bool(left), MetadataValue::Bool(right)) => left.cmp(right),
646 (MetadataValue::Int(left), MetadataValue::Int(right)) => left.cmp(right),
647 (MetadataValue::Float(left), MetadataValue::Float(right)) => left.total_cmp(right),
648 (MetadataValue::Str(left), MetadataValue::Str(right)) => left.cmp(right),
649 (MetadataValue::SparseVector(left), MetadataValue::SparseVector(right)) => {
650 left.cmp(right)
651 }
652 _ => Ordering::Equal, }
654 })
655 }
656}
657
658impl PartialOrd for MetadataValue {
659 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
660 Some(self.cmp(other))
661 }
662}
663
664impl TryFrom<&MetadataValue> for bool {
665 type Error = MetadataValueConversionError;
666
667 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
668 match value {
669 MetadataValue::Bool(value) => Ok(*value),
670 _ => Err(MetadataValueConversionError::InvalidValue),
671 }
672 }
673}
674
675impl TryFrom<&MetadataValue> for i64 {
676 type Error = MetadataValueConversionError;
677
678 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
679 match value {
680 MetadataValue::Int(value) => Ok(*value),
681 _ => Err(MetadataValueConversionError::InvalidValue),
682 }
683 }
684}
685
686impl TryFrom<&MetadataValue> for f64 {
687 type Error = MetadataValueConversionError;
688
689 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
690 match value {
691 MetadataValue::Float(value) => Ok(*value),
692 _ => Err(MetadataValueConversionError::InvalidValue),
693 }
694 }
695}
696
697impl TryFrom<&MetadataValue> for String {
698 type Error = MetadataValueConversionError;
699
700 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
701 match value {
702 MetadataValue::Str(value) => Ok(value.clone()),
703 _ => Err(MetadataValueConversionError::InvalidValue),
704 }
705 }
706}
707
708impl From<MetadataValue> for UpdateMetadataValue {
709 fn from(value: MetadataValue) -> Self {
710 match value {
711 MetadataValue::Bool(v) => UpdateMetadataValue::Bool(v),
712 MetadataValue::Int(v) => UpdateMetadataValue::Int(v),
713 MetadataValue::Float(v) => UpdateMetadataValue::Float(v),
714 MetadataValue::Str(v) => UpdateMetadataValue::Str(v),
715 MetadataValue::SparseVector(v) => UpdateMetadataValue::SparseVector(v),
716 }
717 }
718}
719
720impl From<MetadataValue> for Value {
721 fn from(value: MetadataValue) -> Self {
722 match value {
723 MetadataValue::Bool(val) => Self::Bool(val),
724 MetadataValue::Int(val) => Self::Number(
725 Number::from_i128(val as i128).expect("i64 should be representable in JSON"),
726 ),
727 MetadataValue::Float(val) => Self::Number(
728 Number::from_f64(val).expect("Inf and NaN should not be present in MetadataValue"),
729 ),
730 MetadataValue::Str(val) => Self::String(val),
731 MetadataValue::SparseVector(val) => {
732 let mut map = serde_json::Map::new();
733 map.insert(
734 "indices".to_string(),
735 Value::Array(
736 val.indices
737 .iter()
738 .map(|&i| Value::Number(i.into()))
739 .collect(),
740 ),
741 );
742 map.insert(
743 "values".to_string(),
744 Value::Array(
745 val.values
746 .iter()
747 .map(|&v| {
748 Value::Number(
749 Number::from_f64(v as f64)
750 .expect("Float number should not be NaN or infinite"),
751 )
752 })
753 .collect(),
754 ),
755 );
756 Self::Object(map)
757 }
758 }
759 }
760}
761
762#[derive(Error, Debug)]
763pub enum MetadataValueConversionError {
764 #[error("Invalid metadata value, valid values are: Int, Float, Str")]
765 InvalidValue,
766 #[error("Metadata key cannot start with '#' or '$': {0}")]
767 InvalidKey(String),
768 #[error("Sparse vector indices, values, and tokens (when present) must have the same length")]
769 SparseVectorLengthMismatch,
770 #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")]
771 SparseVectorIndicesNotSorted,
772}
773
774impl ChromaError for MetadataValueConversionError {
775 fn code(&self) -> ErrorCodes {
776 match self {
777 MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
778 MetadataValueConversionError::InvalidKey(_) => ErrorCodes::InvalidArgument,
779 MetadataValueConversionError::SparseVectorLengthMismatch => ErrorCodes::InvalidArgument,
780 MetadataValueConversionError::SparseVectorIndicesNotSorted => {
781 ErrorCodes::InvalidArgument
782 }
783 }
784 }
785}
786
787impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue {
788 type Error = MetadataValueConversionError;
789
790 fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
791 match &value.value {
792 Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
793 Ok(MetadataValue::Bool(*value))
794 }
795 Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
796 Ok(MetadataValue::Int(*value))
797 }
798 Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
799 Ok(MetadataValue::Float(*value))
800 }
801 Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
802 Ok(MetadataValue::Str(value.clone()))
803 }
804 Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
805 let sparse = value
806 .clone()
807 .try_into()
808 .map_err(|_| MetadataValueConversionError::SparseVectorLengthMismatch)?;
809 Ok(MetadataValue::SparseVector(sparse))
810 }
811 _ => Err(MetadataValueConversionError::InvalidValue),
812 }
813 }
814}
815
816impl From<MetadataValue> for chroma_proto::UpdateMetadataValue {
817 fn from(value: MetadataValue) -> Self {
818 match value {
819 MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
820 value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
821 },
822 MetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
823 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
824 value,
825 )),
826 },
827 MetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
828 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
829 value,
830 )),
831 },
832 MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
833 value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
834 },
835 MetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
836 value: Some(
837 chroma_proto::update_metadata_value::Value::SparseVectorValue(
838 sparse_vec.into(),
839 ),
840 ),
841 },
842 }
843 }
844}
845
846pub type UpdateMetadata = HashMap<String, UpdateMetadataValue>;
852
853pub fn are_update_metadatas_close_to_equal(
857 metadata1: &UpdateMetadata,
858 metadata2: &UpdateMetadata,
859) -> bool {
860 assert_eq!(metadata1.len(), metadata2.len());
861
862 for (key, value) in metadata1.iter() {
863 if !metadata2.contains_key(key) {
864 return false;
865 }
866 let other_value = metadata2.get(key).unwrap();
867
868 if let (UpdateMetadataValue::Float(value), UpdateMetadataValue::Float(other_value)) =
869 (value, other_value)
870 {
871 if (value - other_value).abs() > 1e-6 {
872 return false;
873 }
874 } else if value != other_value {
875 return false;
876 }
877 }
878
879 true
880}
881
882pub fn are_metadatas_close_to_equal(metadata1: &Metadata, metadata2: &Metadata) -> bool {
883 assert_eq!(metadata1.len(), metadata2.len());
884
885 for (key, value) in metadata1.iter() {
886 if !metadata2.contains_key(key) {
887 return false;
888 }
889 let other_value = metadata2.get(key).unwrap();
890
891 if let (MetadataValue::Float(value), MetadataValue::Float(other_value)) =
892 (value, other_value)
893 {
894 if (value - other_value).abs() > 1e-6 {
895 return false;
896 }
897 } else if value != other_value {
898 return false;
899 }
900 }
901
902 true
903}
904
905impl TryFrom<chroma_proto::UpdateMetadata> for UpdateMetadata {
906 type Error = UpdateMetadataValueConversionError;
907
908 fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
909 let mut metadata = UpdateMetadata::new();
910 for (key, value) in proto_metadata.metadata.iter() {
911 let value = match value.try_into() {
912 Ok(value) => value,
913 Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue),
914 };
915 metadata.insert(key.clone(), value);
916 }
917 Ok(metadata)
918 }
919}
920
921impl From<UpdateMetadata> for chroma_proto::UpdateMetadata {
922 fn from(metadata: UpdateMetadata) -> Self {
923 let mut metadata = metadata;
924 let mut proto_metadata = chroma_proto::UpdateMetadata {
925 metadata: HashMap::new(),
926 };
927 for (key, value) in metadata.drain() {
928 let proto_value = value.into();
929 proto_metadata.metadata.insert(key.clone(), proto_value);
930 }
931 proto_metadata
932 }
933}
934
935pub type Metadata = HashMap<String, MetadataValue>;
942pub type DeletedMetadata = HashSet<String>;
943
944pub fn logical_size_of_metadata(metadata: &Metadata) -> usize {
945 metadata
946 .iter()
947 .map(|(k, v)| {
948 k.len()
949 + match v {
950 MetadataValue::Bool(b) => size_of_val(b),
951 MetadataValue::Int(i) => size_of_val(i),
952 MetadataValue::Float(f) => size_of_val(f),
953 MetadataValue::Str(s) => s.len(),
954 MetadataValue::SparseVector(v) => {
955 size_of_val(&v.indices[..]) + size_of_val(&v.values[..])
956 }
957 }
958 })
959 .sum()
960}
961
962pub fn get_metadata_value_as<'a, T>(
963 metadata: &'a Metadata,
964 key: &str,
965) -> Result<T, Box<MetadataValueConversionError>>
966where
967 T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
968{
969 let res = match metadata.get(key) {
970 Some(value) => T::try_from(value),
971 None => return Err(Box::new(MetadataValueConversionError::InvalidValue)),
972 };
973 match res {
974 Ok(value) => Ok(value),
975 Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)),
976 }
977}
978
979impl TryFrom<chroma_proto::UpdateMetadata> for Metadata {
980 type Error = MetadataValueConversionError;
981
982 fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
983 let mut metadata = Metadata::new();
984 for (key, value) in proto_metadata.metadata.iter() {
985 let maybe_value: Result<MetadataValue, Self::Error> = value.try_into();
986 if maybe_value.is_err() {
987 return Err(MetadataValueConversionError::InvalidValue);
988 }
989 let value = maybe_value.unwrap();
990 metadata.insert(key.clone(), value);
991 }
992 Ok(metadata)
993 }
994}
995
996impl From<Metadata> for chroma_proto::UpdateMetadata {
997 fn from(metadata: Metadata) -> Self {
998 let mut metadata = metadata;
999 let mut proto_metadata = chroma_proto::UpdateMetadata {
1000 metadata: HashMap::new(),
1001 };
1002 for (key, value) in metadata.drain() {
1003 let proto_value = value.into();
1004 proto_metadata.metadata.insert(key.clone(), proto_value);
1005 }
1006 proto_metadata
1007 }
1008}
1009
1010#[derive(Debug, Default)]
1011pub struct MetadataDelta<'referred_data> {
1012 pub metadata_to_update: HashMap<
1013 &'referred_data str,
1014 (&'referred_data MetadataValue, &'referred_data MetadataValue),
1015 >,
1016 pub metadata_to_delete: HashMap<&'referred_data str, &'referred_data MetadataValue>,
1017 pub metadata_to_insert: HashMap<&'referred_data str, &'referred_data MetadataValue>,
1018}
1019
1020impl MetadataDelta<'_> {
1021 pub fn new() -> Self {
1022 Self::default()
1023 }
1024}
1025
1026#[derive(Clone, Debug, Error, PartialEq)]
1033pub enum WhereConversionError {
1034 #[error("Error: {0}")]
1035 Cause(String),
1036 #[error("{0} -> {1}")]
1037 Trace(String, Box<Self>),
1038}
1039
1040impl WhereConversionError {
1041 pub fn cause(msg: impl ToString) -> Self {
1042 Self::Cause(msg.to_string())
1043 }
1044
1045 pub fn trace(self, context: impl ToString) -> Self {
1046 Self::Trace(context.to_string(), Box::new(self))
1047 }
1048}
1049
1050#[derive(Clone, Debug, PartialEq)]
1058#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1059pub enum Where {
1060 Composite(CompositeExpression),
1061 Document(DocumentExpression),
1062 Metadata(MetadataExpression),
1063}
1064
1065impl std::fmt::Display for Where {
1066 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1067 match self {
1068 Where::Composite(composite) => {
1069 let fragment = composite
1070 .children
1071 .iter()
1072 .map(|child| format!("{}", child))
1073 .collect::<Vec<_>>()
1074 .join(match composite.operator {
1075 BooleanOperator::And => " & ",
1076 BooleanOperator::Or => " | ",
1077 });
1078 write!(f, "({})", fragment)
1079 }
1080 Where::Metadata(expr) => write!(f, "{}", expr),
1081 Where::Document(expr) => write!(f, "{}", expr),
1082 }
1083 }
1084}
1085
1086impl serde::Serialize for Where {
1087 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1088 where
1089 S: Serializer,
1090 {
1091 match self {
1092 Where::Composite(composite) => {
1093 let mut map = serializer.serialize_map(Some(1))?;
1094 let op_key = match composite.operator {
1095 BooleanOperator::And => "$and",
1096 BooleanOperator::Or => "$or",
1097 };
1098 map.serialize_entry(op_key, &composite.children)?;
1099 map.end()
1100 }
1101 Where::Document(doc) => {
1102 let mut outer_map = serializer.serialize_map(Some(1))?;
1103 let mut inner_map = serde_json::Map::new();
1104 let op_key = match doc.operator {
1105 DocumentOperator::Contains => "$contains",
1106 DocumentOperator::NotContains => "$not_contains",
1107 DocumentOperator::Regex => "$regex",
1108 DocumentOperator::NotRegex => "$not_regex",
1109 };
1110 inner_map.insert(
1111 op_key.to_string(),
1112 serde_json::Value::String(doc.pattern.clone()),
1113 );
1114 outer_map.serialize_entry("#document", &inner_map)?;
1115 outer_map.end()
1116 }
1117 Where::Metadata(meta) => {
1118 let mut outer_map = serializer.serialize_map(Some(1))?;
1119 let mut inner_map = serde_json::Map::new();
1120
1121 match &meta.comparison {
1122 MetadataComparison::Primitive(op, value) => {
1123 let op_key = match op {
1124 PrimitiveOperator::Equal => "$eq",
1125 PrimitiveOperator::NotEqual => "$ne",
1126 PrimitiveOperator::GreaterThan => "$gt",
1127 PrimitiveOperator::GreaterThanOrEqual => "$gte",
1128 PrimitiveOperator::LessThan => "$lt",
1129 PrimitiveOperator::LessThanOrEqual => "$lte",
1130 };
1131 let value_json =
1132 serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
1133 inner_map.insert(op_key.to_string(), value_json);
1134 }
1135 MetadataComparison::Set(op, set_value) => {
1136 let op_key = match op {
1137 SetOperator::In => "$in",
1138 SetOperator::NotIn => "$nin",
1139 };
1140 let values_json = match set_value {
1141 MetadataSetValue::Bool(v) => serde_json::to_value(v),
1142 MetadataSetValue::Int(v) => serde_json::to_value(v),
1143 MetadataSetValue::Float(v) => serde_json::to_value(v),
1144 MetadataSetValue::Str(v) => serde_json::to_value(v),
1145 }
1146 .map_err(serde::ser::Error::custom)?;
1147 inner_map.insert(op_key.to_string(), values_json);
1148 }
1149 }
1150
1151 outer_map.serialize_entry(&meta.key, &inner_map)?;
1152 outer_map.end()
1153 }
1154 }
1155 }
1156}
1157
1158impl From<bool> for Where {
1159 fn from(value: bool) -> Self {
1160 if value {
1161 Where::conjunction(vec![])
1162 } else {
1163 Where::disjunction(vec![])
1164 }
1165 }
1166}
1167
1168impl Where {
1169 pub fn conjunction(children: impl IntoIterator<Item = Where>) -> Self {
1170 let mut children: Vec<_> = children
1175 .into_iter()
1176 .flat_map(|expr| {
1177 if let Where::Composite(CompositeExpression {
1178 operator: BooleanOperator::And,
1179 children,
1180 }) = expr
1181 {
1182 return children;
1183 }
1184 vec![expr]
1185 })
1186 .dedup()
1187 .collect();
1188
1189 if children.len() == 1 {
1190 return children.pop().expect("just checked len is 1");
1191 }
1192
1193 Self::Composite(CompositeExpression {
1194 operator: BooleanOperator::And,
1195 children,
1196 })
1197 }
1198 pub fn disjunction(children: impl IntoIterator<Item = Where>) -> Self {
1199 let mut children: Vec<_> = children
1204 .into_iter()
1205 .flat_map(|expr| {
1206 if let Where::Composite(CompositeExpression {
1207 operator: BooleanOperator::Or,
1208 children,
1209 }) = expr
1210 {
1211 return children;
1212 }
1213 vec![expr]
1214 })
1215 .dedup()
1216 .collect();
1217
1218 if children.len() == 1 {
1219 return children.pop().expect("just checked len is 1");
1220 }
1221
1222 Self::Composite(CompositeExpression {
1223 operator: BooleanOperator::Or,
1224 children,
1225 })
1226 }
1227
1228 pub fn fts_query_length(&self) -> u64 {
1229 match self {
1230 Where::Composite(composite_expression) => composite_expression
1231 .children
1232 .iter()
1233 .map(Where::fts_query_length)
1234 .sum(),
1235 Where::Document(document_expression) => {
1237 document_expression.pattern.len().max(3) as u64 - 2
1238 }
1239 Where::Metadata(_) => 0,
1240 }
1241 }
1242
1243 pub fn metadata_predicate_count(&self) -> u64 {
1244 match self {
1245 Where::Composite(composite_expression) => composite_expression
1246 .children
1247 .iter()
1248 .map(Where::metadata_predicate_count)
1249 .sum(),
1250 Where::Document(_) => 0,
1251 Where::Metadata(metadata_expression) => match &metadata_expression.comparison {
1252 MetadataComparison::Primitive(_, _) => 1,
1253 MetadataComparison::Set(_, metadata_set_value) => match metadata_set_value {
1254 MetadataSetValue::Bool(items) => items.len() as u64,
1255 MetadataSetValue::Int(items) => items.len() as u64,
1256 MetadataSetValue::Float(items) => items.len() as u64,
1257 MetadataSetValue::Str(items) => items.len() as u64,
1258 },
1259 },
1260 }
1261 }
1262}
1263
1264impl BitAnd for Where {
1265 type Output = Where;
1266
1267 fn bitand(self, rhs: Self) -> Self::Output {
1268 Self::conjunction([self, rhs])
1269 }
1270}
1271
1272impl BitOr for Where {
1273 type Output = Where;
1274
1275 fn bitor(self, rhs: Self) -> Self::Output {
1276 Self::disjunction([self, rhs])
1277 }
1278}
1279
1280impl TryFrom<chroma_proto::Where> for Where {
1281 type Error = WhereConversionError;
1282
1283 fn try_from(proto_where: chroma_proto::Where) -> Result<Self, Self::Error> {
1284 let where_inner = proto_where
1285 .r#where
1286 .ok_or(WhereConversionError::cause("Invalid Where"))?;
1287 Ok(match where_inner {
1288 chroma_proto::r#where::Where::DirectComparison(direct_comparison) => {
1289 Self::Metadata(direct_comparison.try_into()?)
1290 }
1291 chroma_proto::r#where::Where::Children(where_children) => {
1292 Self::Composite(where_children.try_into()?)
1293 }
1294 chroma_proto::r#where::Where::DirectDocumentComparison(direct_where_document) => {
1295 Self::Document(direct_where_document.into())
1296 }
1297 })
1298 }
1299}
1300
1301impl TryFrom<Where> for chroma_proto::Where {
1302 type Error = WhereConversionError;
1303
1304 fn try_from(value: Where) -> Result<Self, Self::Error> {
1305 let proto_where = match value {
1306 Where::Composite(composite_expression) => {
1307 chroma_proto::r#where::Where::Children(composite_expression.try_into()?)
1308 }
1309 Where::Document(document_expression) => {
1310 chroma_proto::r#where::Where::DirectDocumentComparison(document_expression.into())
1311 }
1312 Where::Metadata(metadata_expression) => chroma_proto::r#where::Where::DirectComparison(
1313 chroma_proto::DirectComparison::try_from(metadata_expression)
1314 .map_err(|err| err.trace("MetadataExpression"))?,
1315 ),
1316 };
1317 Ok(Self {
1318 r#where: Some(proto_where),
1319 })
1320 }
1321}
1322
/// A boolean combination of child [`Where`] expressions.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct CompositeExpression {
    /// Operator (`And` / `Or`) joining the children.
    pub operator: BooleanOperator,
    /// The expressions being combined.
    pub children: Vec<Where>,
}
1329
1330impl TryFrom<chroma_proto::WhereChildren> for CompositeExpression {
1331 type Error = WhereConversionError;
1332
1333 fn try_from(proto_children: chroma_proto::WhereChildren) -> Result<Self, Self::Error> {
1334 let operator = proto_children.operator().into();
1335 let children = proto_children
1336 .children
1337 .into_iter()
1338 .map(Where::try_from)
1339 .collect::<Result<Vec<_>, _>>()
1340 .map_err(|err| err.trace("Child Where of CompositeExpression"))?;
1341 Ok(Self { operator, children })
1342 }
1343}
1344
1345impl TryFrom<CompositeExpression> for chroma_proto::WhereChildren {
1346 type Error = WhereConversionError;
1347
1348 fn try_from(value: CompositeExpression) -> Result<Self, Self::Error> {
1349 Ok(Self {
1350 operator: chroma_proto::BooleanOperator::from(value.operator) as i32,
1351 children: value
1352 .children
1353 .into_iter()
1354 .map(chroma_proto::Where::try_from)
1355 .collect::<Result<_, _>>()?,
1356 })
1357 }
1358}
1359
/// Boolean operator joining the children of a [`CompositeExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum BooleanOperator {
    /// Logical AND of the children.
    And,
    /// Logical OR of the children.
    Or,
}
1366
1367impl From<chroma_proto::BooleanOperator> for BooleanOperator {
1368 fn from(value: chroma_proto::BooleanOperator) -> Self {
1369 match value {
1370 chroma_proto::BooleanOperator::And => Self::And,
1371 chroma_proto::BooleanOperator::Or => Self::Or,
1372 }
1373 }
1374}
1375
1376impl From<BooleanOperator> for chroma_proto::BooleanOperator {
1377 fn from(value: BooleanOperator) -> Self {
1378 match value {
1379 BooleanOperator::And => Self::And,
1380 BooleanOperator::Or => Self::Or,
1381 }
1382 }
1383}
1384
/// A predicate over the document text: an operator applied to a pattern.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct DocumentExpression {
    /// How `pattern` is matched against the document.
    pub operator: DocumentOperator,
    /// The text or regular expression to look for.
    pub pattern: String,
}
1391
1392impl std::fmt::Display for DocumentExpression {
1393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1394 let op_str = match self.operator {
1395 DocumentOperator::Contains => "CONTAINS",
1396 DocumentOperator::NotContains => "NOT CONTAINS",
1397 DocumentOperator::Regex => "REGEX",
1398 DocumentOperator::NotRegex => "NOT REGEX",
1399 };
1400 write!(f, "#document {} \"{}\"", op_str, self.pattern)
1401 }
1402}
1403
1404impl From<chroma_proto::DirectWhereDocument> for DocumentExpression {
1405 fn from(value: chroma_proto::DirectWhereDocument) -> Self {
1406 Self {
1407 operator: value.operator().into(),
1408 pattern: value.pattern,
1409 }
1410 }
1411}
1412
1413impl From<DocumentExpression> for chroma_proto::DirectWhereDocument {
1414 fn from(value: DocumentExpression) -> Self {
1415 Self {
1416 pattern: value.pattern,
1417 operator: chroma_proto::WhereDocumentOperator::from(value.operator) as i32,
1418 }
1419 }
1420}
1421
/// Matching mode of a [`DocumentExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum DocumentOperator {
    /// Displayed as `CONTAINS`.
    Contains,
    /// Displayed as `NOT CONTAINS`.
    NotContains,
    /// Displayed as `REGEX`.
    Regex,
    /// Displayed as `NOT REGEX`.
    NotRegex,
}
1430impl From<chroma_proto::WhereDocumentOperator> for DocumentOperator {
1431 fn from(value: chroma_proto::WhereDocumentOperator) -> Self {
1432 match value {
1433 chroma_proto::WhereDocumentOperator::Contains => Self::Contains,
1434 chroma_proto::WhereDocumentOperator::NotContains => Self::NotContains,
1435 chroma_proto::WhereDocumentOperator::Regex => Self::Regex,
1436 chroma_proto::WhereDocumentOperator::NotRegex => Self::NotRegex,
1437 }
1438 }
1439}
1440
1441impl From<DocumentOperator> for chroma_proto::WhereDocumentOperator {
1442 fn from(value: DocumentOperator) -> Self {
1443 match value {
1444 DocumentOperator::Contains => Self::Contains,
1445 DocumentOperator::NotContains => Self::NotContains,
1446 DocumentOperator::Regex => Self::Regex,
1447 DocumentOperator::NotRegex => Self::NotRegex,
1448 }
1449 }
1450}
1451
/// A predicate over one metadata key.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct MetadataExpression {
    /// The metadata key the comparison applies to.
    pub key: String,
    /// The comparison (operator plus operand) evaluated for `key`.
    pub comparison: MetadataComparison,
}
1458
1459impl std::fmt::Display for MetadataExpression {
1460 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1461 match &self.comparison {
1462 MetadataComparison::Primitive(op, value) => {
1463 write!(f, "{} {} {}", self.key, op, value)
1464 }
1465 MetadataComparison::Set(op, set_value) => {
1466 write!(f, "{} {} {}", self.key, op, set_value)
1467 }
1468 }
1469 }
1470}
1471
1472impl TryFrom<chroma_proto::DirectComparison> for MetadataExpression {
1473 type Error = WhereConversionError;
1474
1475 fn try_from(value: chroma_proto::DirectComparison) -> Result<Self, Self::Error> {
1476 let proto_comparison = value
1477 .comparison
1478 .ok_or(WhereConversionError::cause("Invalid MetadataExpression"))?;
1479 let comparison = match proto_comparison {
1480 chroma_proto::direct_comparison::Comparison::SingleStringOperand(
1481 single_string_comparison,
1482 ) => MetadataComparison::Primitive(
1483 single_string_comparison.comparator().into(),
1484 MetadataValue::Str(single_string_comparison.value),
1485 ),
1486 chroma_proto::direct_comparison::Comparison::StringListOperand(
1487 string_list_comparison,
1488 ) => MetadataComparison::Set(
1489 string_list_comparison.list_operator().into(),
1490 MetadataSetValue::Str(string_list_comparison.values),
1491 ),
1492 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1493 single_int_comparison,
1494 ) => MetadataComparison::Primitive(
1495 match single_int_comparison
1496 .comparator
1497 .ok_or(WhereConversionError::cause(
1498 "Invalid scalar integer operator",
1499 ))? {
1500 chroma_proto::single_int_comparison::Comparator::GenericComparator(op) => {
1501 chroma_proto::GenericComparator::try_from(op)
1502 .map_err(WhereConversionError::cause)?
1503 .into()
1504 }
1505 chroma_proto::single_int_comparison::Comparator::NumberComparator(op) => {
1506 chroma_proto::NumberComparator::try_from(op)
1507 .map_err(WhereConversionError::cause)?
1508 .into()
1509 }
1510 },
1511 MetadataValue::Int(single_int_comparison.value),
1512 ),
1513 chroma_proto::direct_comparison::Comparison::IntListOperand(int_list_comparison) => {
1514 MetadataComparison::Set(
1515 int_list_comparison.list_operator().into(),
1516 MetadataSetValue::Int(int_list_comparison.values),
1517 )
1518 }
1519 chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(
1520 single_double_comparison,
1521 ) => MetadataComparison::Primitive(
1522 match single_double_comparison
1523 .comparator
1524 .ok_or(WhereConversionError::cause("Invalid scalar float operator"))?
1525 {
1526 chroma_proto::single_double_comparison::Comparator::GenericComparator(op) => {
1527 chroma_proto::GenericComparator::try_from(op)
1528 .map_err(WhereConversionError::cause)?
1529 .into()
1530 }
1531 chroma_proto::single_double_comparison::Comparator::NumberComparator(op) => {
1532 chroma_proto::NumberComparator::try_from(op)
1533 .map_err(WhereConversionError::cause)?
1534 .into()
1535 }
1536 },
1537 MetadataValue::Float(single_double_comparison.value),
1538 ),
1539 chroma_proto::direct_comparison::Comparison::DoubleListOperand(
1540 double_list_comparison,
1541 ) => MetadataComparison::Set(
1542 double_list_comparison.list_operator().into(),
1543 MetadataSetValue::Float(double_list_comparison.values),
1544 ),
1545 chroma_proto::direct_comparison::Comparison::BoolListOperand(bool_list_comparison) => {
1546 MetadataComparison::Set(
1547 bool_list_comparison.list_operator().into(),
1548 MetadataSetValue::Bool(bool_list_comparison.values),
1549 )
1550 }
1551 chroma_proto::direct_comparison::Comparison::SingleBoolOperand(
1552 single_bool_comparison,
1553 ) => MetadataComparison::Primitive(
1554 single_bool_comparison.comparator().into(),
1555 MetadataValue::Bool(single_bool_comparison.value),
1556 ),
1557 };
1558 Ok(Self {
1559 key: value.key,
1560 comparison,
1561 })
1562 }
1563}
1564
1565impl TryFrom<MetadataExpression> for chroma_proto::DirectComparison {
1566 type Error = WhereConversionError;
1567
1568 fn try_from(value: MetadataExpression) -> Result<Self, Self::Error> {
1569 let comparison = match value.comparison {
1570 MetadataComparison::Primitive(primitive_operator, metadata_value) => match metadata_value {
1571 MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1572 MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(match primitive_operator {
1573 generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1574 numeric => chroma_proto::single_int_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1575 }),
1576 MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(match primitive_operator {
1577 generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_double_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1578 numeric => chroma_proto::single_double_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1579 }),
1580 MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1581 MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Comparison with sparse vector is not supported".to_string())),
1582 },
1583 MetadataComparison::Set(set_operator, metadata_set_value) => match metadata_set_value {
1584 MetadataSetValue::Bool(vec) => chroma_proto::direct_comparison::Comparison::BoolListOperand(chroma_proto::BoolListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1585 MetadataSetValue::Int(vec) => chroma_proto::direct_comparison::Comparison::IntListOperand(chroma_proto::IntListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1586 MetadataSetValue::Float(vec) => chroma_proto::direct_comparison::Comparison::DoubleListOperand(chroma_proto::DoubleListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1587 MetadataSetValue::Str(vec) => chroma_proto::direct_comparison::Comparison::StringListOperand(chroma_proto::StringListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1588 },
1589 };
1590 Ok(Self {
1591 key: value.key,
1592 comparison: Some(comparison),
1593 })
1594 }
1595}
1596
/// The operator/operand pair of a [`MetadataExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum MetadataComparison {
    /// Comparison against a single value.
    Primitive(PrimitiveOperator, MetadataValue),
    /// Membership test against a list of values.
    Set(SetOperator, MetadataSetValue),
}
1603
/// Comparison operator applied to a single metadata value.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum PrimitiveOperator {
    /// Displayed as `=`.
    Equal,
    /// Displayed as `≠`.
    NotEqual,
    /// Displayed as `>`.
    GreaterThan,
    /// Displayed as `≥`.
    GreaterThanOrEqual,
    /// Displayed as `<`.
    LessThan,
    /// Displayed as `≤`.
    LessThanOrEqual,
}

/// Renders the operator as its mathematical symbol.
impl std::fmt::Display for PrimitiveOperator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            PrimitiveOperator::Equal => "=",
            PrimitiveOperator::NotEqual => "≠",
            PrimitiveOperator::GreaterThan => ">",
            PrimitiveOperator::GreaterThanOrEqual => "≥",
            PrimitiveOperator::LessThan => "<",
            PrimitiveOperator::LessThanOrEqual => "≤",
        })
    }
}
1628
1629impl From<chroma_proto::GenericComparator> for PrimitiveOperator {
1630 fn from(value: chroma_proto::GenericComparator) -> Self {
1631 match value {
1632 chroma_proto::GenericComparator::Eq => Self::Equal,
1633 chroma_proto::GenericComparator::Ne => Self::NotEqual,
1634 }
1635 }
1636}
1637
1638impl TryFrom<PrimitiveOperator> for chroma_proto::GenericComparator {
1639 type Error = WhereConversionError;
1640
1641 fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1642 match value {
1643 PrimitiveOperator::Equal => Ok(Self::Eq),
1644 PrimitiveOperator::NotEqual => Ok(Self::Ne),
1645 op => Err(WhereConversionError::cause(format!("{op:?} ∉ [=, ≠]"))),
1646 }
1647 }
1648}
1649
1650impl From<chroma_proto::NumberComparator> for PrimitiveOperator {
1651 fn from(value: chroma_proto::NumberComparator) -> Self {
1652 match value {
1653 chroma_proto::NumberComparator::Gt => Self::GreaterThan,
1654 chroma_proto::NumberComparator::Gte => Self::GreaterThanOrEqual,
1655 chroma_proto::NumberComparator::Lt => Self::LessThan,
1656 chroma_proto::NumberComparator::Lte => Self::LessThanOrEqual,
1657 }
1658 }
1659}
1660
1661impl TryFrom<PrimitiveOperator> for chroma_proto::NumberComparator {
1662 type Error = WhereConversionError;
1663
1664 fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1665 match value {
1666 PrimitiveOperator::GreaterThan => Ok(Self::Gt),
1667 PrimitiveOperator::GreaterThanOrEqual => Ok(Self::Gte),
1668 PrimitiveOperator::LessThan => Ok(Self::Lt),
1669 PrimitiveOperator::LessThanOrEqual => Ok(Self::Lte),
1670 op => Err(WhereConversionError::cause(format!(
1671 "{op:?} ∉ [≤, <, >, ≥]"
1672 ))),
1673 }
1674 }
1675}
1676
/// Membership operator applied to a list of metadata values.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SetOperator {
    /// Displayed as `∈`.
    In,
    /// Displayed as `∉`.
    NotIn,
}

/// Renders the operator as its set-membership symbol.
impl std::fmt::Display for SetOperator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            SetOperator::In => "∈",
            SetOperator::NotIn => "∉",
        })
    }
}
1693
1694impl From<chroma_proto::ListOperator> for SetOperator {
1695 fn from(value: chroma_proto::ListOperator) -> Self {
1696 match value {
1697 chroma_proto::ListOperator::In => Self::In,
1698 chroma_proto::ListOperator::Nin => Self::NotIn,
1699 }
1700 }
1701}
1702
1703impl From<SetOperator> for chroma_proto::ListOperator {
1704 fn from(value: SetOperator) -> Self {
1705 match value {
1706 SetOperator::In => Self::In,
1707 SetOperator::NotIn => Self::Nin,
1708 }
1709 }
1710}
1711
/// A homogeneous list of metadata values used as the operand of a set comparison.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum MetadataSetValue {
    /// A list of booleans.
    Bool(Vec<bool>),
    /// A list of 64-bit signed integers.
    Int(Vec<i64>),
    /// A list of 64-bit floats.
    Float(Vec<f64>),
    /// A list of strings.
    Str(Vec<String>),
}

/// Renders the list as `[a, b, c]`.
///
/// Strings are wrapped in double quotes; booleans, integers and floats are
/// written bare, so a boolean list (`[true]`) can be told apart from a string
/// list (`["true"]`).
impl std::fmt::Display for MetadataSetValue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Renders every element with `render` and joins the results with `", "`.
        fn join<T>(values: &[T], render: impl Fn(&T) -> String) -> String {
            values.iter().map(render).collect::<Vec<_>>().join(", ")
        }

        let body = match self {
            MetadataSetValue::Bool(values) => join(values, |v| v.to_string()),
            MetadataSetValue::Int(values) => join(values, |v| v.to_string()),
            MetadataSetValue::Float(values) => join(values, |v| v.to_string()),
            // Only strings are quoted.
            MetadataSetValue::Str(values) => join(values, |v| format!("\"{}\"", v)),
        };
        write!(f, "[{}]", body)
    }
}
1759
1760impl MetadataSetValue {
1761 pub fn value_type(&self) -> MetadataValueType {
1762 match self {
1763 MetadataSetValue::Bool(_) => MetadataValueType::Bool,
1764 MetadataSetValue::Int(_) => MetadataValueType::Int,
1765 MetadataSetValue::Float(_) => MetadataValueType::Float,
1766 MetadataSetValue::Str(_) => MetadataValueType::Str,
1767 }
1768 }
1769}
1770
1771impl From<Vec<bool>> for MetadataSetValue {
1772 fn from(values: Vec<bool>) -> Self {
1773 MetadataSetValue::Bool(values)
1774 }
1775}
1776
1777impl From<Vec<i64>> for MetadataSetValue {
1778 fn from(values: Vec<i64>) -> Self {
1779 MetadataSetValue::Int(values)
1780 }
1781}
1782
1783impl From<Vec<i32>> for MetadataSetValue {
1784 fn from(values: Vec<i32>) -> Self {
1785 MetadataSetValue::Int(values.into_iter().map(|v| v as i64).collect())
1786 }
1787}
1788
1789impl From<Vec<f64>> for MetadataSetValue {
1790 fn from(values: Vec<f64>) -> Self {
1791 MetadataSetValue::Float(values)
1792 }
1793}
1794
1795impl From<Vec<f32>> for MetadataSetValue {
1796 fn from(values: Vec<f32>) -> Self {
1797 MetadataSetValue::Float(values.into_iter().map(|v| v as f64).collect())
1798 }
1799}
1800
1801impl From<Vec<String>> for MetadataSetValue {
1802 fn from(values: Vec<String>) -> Self {
1803 MetadataSetValue::Str(values)
1804 }
1805}
1806
1807impl From<Vec<&str>> for MetadataSetValue {
1808 fn from(values: Vec<&str>) -> Self {
1809 MetadataSetValue::Str(values.into_iter().map(|s| s.to_string()).collect())
1810 }
1811}
1812
1813impl TryFrom<chroma_proto::WhereDocument> for Where {
1815 type Error = WhereConversionError;
1816
1817 fn try_from(proto_document: chroma_proto::WhereDocument) -> Result<Self, Self::Error> {
1818 match proto_document.r#where_document {
1819 Some(chroma_proto::where_document::WhereDocument::Direct(proto_comparison)) => {
1820 let operator = match TryInto::<chroma_proto::WhereDocumentOperator>::try_into(
1821 proto_comparison.operator,
1822 ) {
1823 Ok(operator) => operator,
1824 Err(_) => {
1825 return Err(WhereConversionError::cause(
1826 "[Deprecated] Invalid where document operator",
1827 ))
1828 }
1829 };
1830 let comparison = DocumentExpression {
1831 pattern: proto_comparison.pattern,
1832 operator: operator.into(),
1833 };
1834 Ok(Where::Document(comparison))
1835 }
1836 Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => {
1837 let operator = match TryInto::<chroma_proto::BooleanOperator>::try_into(
1838 proto_children.operator,
1839 ) {
1840 Ok(operator) => operator,
1841 Err(_) => {
1842 return Err(WhereConversionError::cause(
1843 "[Deprecated] Invalid boolean operator",
1844 ))
1845 }
1846 };
1847 let children = CompositeExpression {
1848 children: proto_children
1849 .children
1850 .into_iter()
1851 .map(|child| child.try_into())
1852 .collect::<Result<_, _>>()?,
1853 operator: operator.into(),
1854 };
1855 Ok(Where::Composite(children))
1856 }
1857 None => Err(WhereConversionError::cause("[Deprecated] Invalid where")),
1858 }
1859 }
1860}
1861
1862#[cfg(test)]
1863mod tests {
1864 use crate::operator::Key;
1865
1866 use super::*;
1867
1868 #[cfg(feature = "pyo3")]
1870 fn ensure_python_interpreter() {
1871 static PYTHON_INIT: std::sync::Once = std::sync::Once::new();
1872 PYTHON_INIT.call_once(|| {
1873 pyo3::prepare_freethreaded_python();
1874 });
1875 }
1876
1877 #[test]
1878 fn test_update_metadata_try_from() {
1879 let mut proto_metadata = chroma_proto::UpdateMetadata {
1880 metadata: HashMap::new(),
1881 };
1882 proto_metadata.metadata.insert(
1883 "foo".to_string(),
1884 chroma_proto::UpdateMetadataValue {
1885 value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
1886 },
1887 );
1888 proto_metadata.metadata.insert(
1889 "bar".to_string(),
1890 chroma_proto::UpdateMetadataValue {
1891 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
1892 },
1893 );
1894 proto_metadata.metadata.insert(
1895 "baz".to_string(),
1896 chroma_proto::UpdateMetadataValue {
1897 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1898 "42".to_string(),
1899 )),
1900 },
1901 );
1902 proto_metadata.metadata.insert(
1904 "sparse".to_string(),
1905 chroma_proto::UpdateMetadataValue {
1906 value: Some(
1907 chroma_proto::update_metadata_value::Value::SparseVectorValue(
1908 chroma_proto::SparseVector {
1909 indices: vec![0, 5, 10],
1910 values: vec![0.1, 0.5, 0.9],
1911 tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
1912 },
1913 ),
1914 ),
1915 },
1916 );
1917 let converted_metadata: UpdateMetadata = proto_metadata.try_into().unwrap();
1918 assert_eq!(converted_metadata.len(), 4);
1919 assert_eq!(
1920 converted_metadata.get("foo").unwrap(),
1921 &UpdateMetadataValue::Int(42)
1922 );
1923 assert_eq!(
1924 converted_metadata.get("bar").unwrap(),
1925 &UpdateMetadataValue::Float(42.0)
1926 );
1927 assert_eq!(
1928 converted_metadata.get("baz").unwrap(),
1929 &UpdateMetadataValue::Str("42".to_string())
1930 );
1931 assert_eq!(
1932 converted_metadata.get("sparse").unwrap(),
1933 &UpdateMetadataValue::SparseVector(
1934 SparseVector::new_with_tokens(
1935 vec![0, 5, 10],
1936 vec![0.1, 0.5, 0.9],
1937 vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),],
1938 )
1939 .unwrap()
1940 )
1941 );
1942 }
1943
1944 #[test]
1945 fn test_metadata_try_from() {
1946 let mut proto_metadata = chroma_proto::UpdateMetadata {
1947 metadata: HashMap::new(),
1948 };
1949 proto_metadata.metadata.insert(
1950 "foo".to_string(),
1951 chroma_proto::UpdateMetadataValue {
1952 value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
1953 },
1954 );
1955 proto_metadata.metadata.insert(
1956 "bar".to_string(),
1957 chroma_proto::UpdateMetadataValue {
1958 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
1959 },
1960 );
1961 proto_metadata.metadata.insert(
1962 "baz".to_string(),
1963 chroma_proto::UpdateMetadataValue {
1964 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1965 "42".to_string(),
1966 )),
1967 },
1968 );
1969 proto_metadata.metadata.insert(
1971 "sparse".to_string(),
1972 chroma_proto::UpdateMetadataValue {
1973 value: Some(
1974 chroma_proto::update_metadata_value::Value::SparseVectorValue(
1975 chroma_proto::SparseVector {
1976 indices: vec![1, 10, 100],
1977 values: vec![0.2, 0.4, 0.6],
1978 tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
1979 },
1980 ),
1981 ),
1982 },
1983 );
1984 let converted_metadata: Metadata = proto_metadata.try_into().unwrap();
1985 assert_eq!(converted_metadata.len(), 4);
1986 assert_eq!(
1987 converted_metadata.get("foo").unwrap(),
1988 &MetadataValue::Int(42)
1989 );
1990 assert_eq!(
1991 converted_metadata.get("bar").unwrap(),
1992 &MetadataValue::Float(42.0)
1993 );
1994 assert_eq!(
1995 converted_metadata.get("baz").unwrap(),
1996 &MetadataValue::Str("42".to_string())
1997 );
1998 assert_eq!(
1999 converted_metadata.get("sparse").unwrap(),
2000 &MetadataValue::SparseVector(
2001 SparseVector::new_with_tokens(
2002 vec![1, 10, 100],
2003 vec![0.2, 0.4, 0.6],
2004 vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),],
2005 )
2006 .unwrap()
2007 )
2008 );
2009 }
2010
2011 #[test]
2012 fn test_where_clause_simple_from() {
2013 let proto_where = chroma_proto::Where {
2014 r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2015 chroma_proto::DirectComparison {
2016 key: "foo".to_string(),
2017 comparison: Some(
2018 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2019 chroma_proto::SingleIntComparison {
2020 value: 42,
2021 comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2022 },
2023 ),
2024 ),
2025 },
2026 )),
2027 };
2028 let where_clause: Where = proto_where.try_into().unwrap();
2029 match where_clause {
2030 Where::Metadata(comparison) => {
2031 assert_eq!(comparison.key, "foo");
2032 match comparison.comparison {
2033 MetadataComparison::Primitive(_, value) => {
2034 assert_eq!(value, MetadataValue::Int(42));
2035 }
2036 _ => panic!("Invalid comparison type"),
2037 }
2038 }
2039 _ => panic!("Invalid where type"),
2040 }
2041 }
2042
2043 #[test]
2044 fn test_where_clause_with_children() {
2045 let proto_where = chroma_proto::Where {
2046 r#where: Some(chroma_proto::r#where::Where::Children(
2047 chroma_proto::WhereChildren {
2048 children: vec![
2049 chroma_proto::Where {
2050 r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2051 chroma_proto::DirectComparison {
2052 key: "foo".to_string(),
2053 comparison: Some(
2054 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2055 chroma_proto::SingleIntComparison {
2056 value: 42,
2057 comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2058 },
2059 ),
2060 ),
2061 },
2062 )),
2063 },
2064 chroma_proto::Where {
2065 r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2066 chroma_proto::DirectComparison {
2067 key: "bar".to_string(),
2068 comparison: Some(
2069 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2070 chroma_proto::SingleIntComparison {
2071 value: 42,
2072 comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2073 },
2074 ),
2075 ),
2076 },
2077 )),
2078 },
2079 ],
2080 operator: chroma_proto::BooleanOperator::And.into(),
2081 },
2082 )),
2083 };
2084 let where_clause: Where = proto_where.try_into().unwrap();
2085 match where_clause {
2086 Where::Composite(children) => {
2087 assert_eq!(children.children.len(), 2);
2088 assert_eq!(children.operator, BooleanOperator::And);
2089 }
2090 _ => panic!("Invalid where type"),
2091 }
2092 }
2093
2094 #[test]
2095 fn test_where_document_simple() {
2096 let proto_where = chroma_proto::WhereDocument {
2097 r#where_document: Some(chroma_proto::where_document::WhereDocument::Direct(
2098 chroma_proto::DirectWhereDocument {
2099 pattern: "foo".to_string(),
2100 operator: chroma_proto::WhereDocumentOperator::Contains.into(),
2101 },
2102 )),
2103 };
2104 let where_document: Where = proto_where.try_into().unwrap();
2105 match where_document {
2106 Where::Document(comparison) => {
2107 assert_eq!(comparison.pattern, "foo");
2108 assert_eq!(comparison.operator, DocumentOperator::Contains);
2109 }
2110 _ => panic!("Invalid where document type"),
2111 }
2112 }
2113
2114 #[test]
2115 fn test_where_document_with_children() {
2116 let proto_where = chroma_proto::WhereDocument {
2117 r#where_document: Some(chroma_proto::where_document::WhereDocument::Children(
2118 chroma_proto::WhereDocumentChildren {
2119 children: vec![
2120 chroma_proto::WhereDocument {
2121 r#where_document: Some(
2122 chroma_proto::where_document::WhereDocument::Direct(
2123 chroma_proto::DirectWhereDocument {
2124 pattern: "foo".to_string(),
2125 operator: chroma_proto::WhereDocumentOperator::Contains
2126 .into(),
2127 },
2128 ),
2129 ),
2130 },
2131 chroma_proto::WhereDocument {
2132 r#where_document: Some(
2133 chroma_proto::where_document::WhereDocument::Direct(
2134 chroma_proto::DirectWhereDocument {
2135 pattern: "bar".to_string(),
2136 operator: chroma_proto::WhereDocumentOperator::Contains
2137 .into(),
2138 },
2139 ),
2140 ),
2141 },
2142 ],
2143 operator: chroma_proto::BooleanOperator::And.into(),
2144 },
2145 )),
2146 };
2147 let where_document: Where = proto_where.try_into().unwrap();
2148 match where_document {
2149 Where::Composite(children) => {
2150 assert_eq!(children.children.len(), 2);
2151 assert_eq!(children.operator, BooleanOperator::And);
2152 }
2153 _ => panic!("Invalid where document type"),
2154 }
2155 }
2156
// `SparseVector::new` must store the supplied indices and values unchanged.
#[test]
fn test_sparse_vector_new() {
    let expected_indices = vec![0, 5, 10];
    let expected_values = vec![0.1, 0.5, 0.9];

    let SparseVector {
        indices, values, ..
    } = SparseVector::new(expected_indices.clone(), expected_values.clone()).unwrap();

    assert_eq!(indices, expected_indices);
    assert_eq!(values, expected_values);
}
2165
// `SparseVector::from_pairs` must split `(index, value)` pairs into the
// parallel `indices` and `values` vectors, preserving input order.
#[test]
fn test_sparse_vector_from_pairs() {
    // The pairs are not needed after construction, so they are moved in
    // directly instead of being cloned first.
    let sparse = SparseVector::from_pairs(vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
    assert_eq!(sparse.indices, vec![0, 5, 10]);
    assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
}
2173
// `SparseVector::from_triples` must place the index and value of each
// `(token, index, value)` triple into `indices` and `values`, in input order.
#[test]
fn test_sparse_vector_from_triples() {
    let triples = vec![
        ("foo".to_string(), 0, 0.1),
        ("bar".to_string(), 5, 0.5),
        ("baz".to_string(), 10, 0.9),
    ];
    // The triples are not reused below, so they are moved rather than cloned
    // (a clone would copy every owned token string for nothing).
    let sparse = SparseVector::from_triples(triples);
    assert_eq!(sparse.indices, vec![0, 5, 10]);
    assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
}
2185
// `SparseVector::iter` must yield `(index, value)` pairs in stored order.
#[test]
fn test_sparse_vector_iter() {
    let vector = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
    let pairs = vector.iter().collect::<Vec<(u32, f32)>>();
    assert_eq!(pairs, vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
}
2192
// Equal contents compare equal; a larger index or a larger value at the same
// position makes a vector compare greater.
#[test]
fn test_sparse_vector_ordering() {
    let build = |indices: Vec<u32>, values: Vec<f32>| SparseVector::new(indices, values).unwrap();

    let baseline = build(vec![0, 5], vec![0.1, 0.5]);
    let identical = build(vec![0, 5], vec![0.1, 0.5]);
    let larger_index = build(vec![0, 6], vec![0.1, 0.5]);
    let larger_value = build(vec![0, 5], vec![0.1, 0.6]);

    assert_eq!(baseline, identical);
    assert!(baseline < larger_index);
    assert!(baseline < larger_value);
}
2204
// A tokenized sparse vector must survive the trip to the proto message and
// back with indices, values and tokens intact.
#[test]
fn test_sparse_vector_proto_conversion() {
    let expected_tokens: Vec<String> = ["token1", "token2", "token3"]
        .iter()
        .map(|token| token.to_string())
        .collect();
    let original = SparseVector::new_with_tokens(
        vec![1, 10, 100],
        vec![0.2, 0.4, 0.6],
        expected_tokens.clone(),
    )
    .unwrap();

    // Domain type -> proto message.
    let as_proto: chroma_proto::SparseVector = original.clone().into();
    assert_eq!(as_proto.indices, vec![1, 10, 100]);
    assert_eq!(as_proto.values, vec![0.2, 0.4, 0.6]);
    assert_eq!(as_proto.tokens, expected_tokens);

    // Proto message -> domain type.
    let restored: SparseVector = as_proto.try_into().unwrap();
    assert_eq!(restored, original);
    assert_eq!(restored.tokens, Some(expected_tokens));
}
2224
// A vector built without tokens must map to an empty proto token list, and
// converting that proto back must give `tokens == None`.
#[test]
fn test_sparse_vector_proto_conversion_empty_tokens() {
    let untokenized = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();

    let as_proto: chroma_proto::SparseVector = untokenized.clone().into();
    assert_eq!(as_proto.indices, vec![0, 5, 10]);
    assert_eq!(as_proto.values, vec![0.1, 0.5, 0.9]);
    assert!(as_proto.tokens.is_empty());

    let restored: SparseVector = as_proto.try_into().unwrap();
    assert_eq!(restored, untokenized);
    assert!(restored.tokens.is_none());
}
2237
// Pins the logical size reported for metadata holding one five-entry sparse
// vector stored under the key "sparse".
#[test]
fn test_sparse_vector_logical_size() {
    let five_entries =
        SparseVector::new(vec![0, 1, 2, 3, 4], vec![0.1, 0.2, 0.3, 0.4, 0.5]).unwrap();
    let metadata = Metadata::from([(
        "sparse".to_string(),
        MetadataValue::SparseVector(five_entries),
    )]);

    // NOTE(review): 46 is presumably 6 key bytes + 5 * 4 index bytes + 5 * 4
    // value bytes; confirm against `logical_size_of_metadata`.
    assert_eq!(logical_size_of_metadata(&metadata), 46);
}
2252
// Exercises construction-time length checks and `validate`'s index-ordering
// check on `SparseVector`.
#[test]
fn test_sparse_vector_validation() {
    // Strictly increasing indices with matching lengths pass validation.
    let well_formed = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap();
    assert!(well_formed.validate().is_ok());

    // Fewer values than indices is rejected by the constructor itself.
    let short_values = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]);
    assert!(short_values.is_err());
    // The matching-length counterpart constructs and validates.
    let matching_lengths = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3])
        .unwrap()
        .validate();
    assert!(matching_lengths.is_ok());

    // Fewer tokens than indices is rejected by the constructor as well.
    let short_tokens = vec!["a".to_string(), "b".to_string()];
    let token_mismatch =
        SparseVector::new_with_tokens(vec![1, 2, 3], vec![0.1, 0.2, 0.3], short_tokens);
    assert!(token_mismatch.is_err());

    // Fully unsorted, duplicated, and partially unsorted indices all construct
    // successfully but fail validation with the same error variant.
    let bad_orderings = vec![
        (vec![3, 1, 2], vec![0.3, 0.1, 0.2]),
        (vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]),
        (vec![1, 3, 2], vec![0.1, 0.3, 0.2]),
    ];
    for (indices, values) in bad_orderings {
        let outcome = SparseVector::new(indices, values).unwrap().validate();
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            MetadataValueConversionError::SparseVectorIndicesNotSorted
        ));
    }
}
2302
// A payload without the `#type` tag must still deserialize.
#[test]
fn test_sparse_vector_deserialize_old_format() {
    let untagged = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
    let parsed = serde_json::from_str::<SparseVector>(untagged).unwrap();
    assert_eq!(parsed.indices, vec![0, 1, 2]);
    assert_eq!(parsed.values, vec![1.0, 2.0, 3.0]);
}
2311
// A payload carrying `#type: "sparse_vector"` must deserialize.
#[test]
fn test_sparse_vector_deserialize_new_format() {
    let tagged =
        r##"{"#type": "sparse_vector", "indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"##;
    let parsed = serde_json::from_str::<SparseVector>(tagged).unwrap();
    assert_eq!(parsed.indices, vec![0, 1, 2]);
    assert_eq!(parsed.values, vec![1.0, 2.0, 3.0]);
}
2321
// The `#type` tag may appear between the other fields; key order must not
// affect deserialization.
#[test]
fn test_sparse_vector_deserialize_new_format_field_order() {
    let tag_in_middle = r##"{"indices": [5, 10], "#type": "sparse_vector", "values": [0.5, 1.0]}"##;
    let parsed = serde_json::from_str::<SparseVector>(tag_in_middle).unwrap();
    assert_eq!(parsed.indices, vec![5, 10]);
    assert_eq!(parsed.values, vec![0.5, 1.0]);
}
2330
// A `#type` tag other than "sparse_vector" must be rejected, and the error
// text must mention the expected tag.
#[test]
fn test_sparse_vector_deserialize_wrong_type_tag() {
    let mistagged = r##"{"#type": "dense_vector", "indices": [0, 1], "values": [1.0, 2.0]}"##;
    let outcome = serde_json::from_str::<SparseVector>(mistagged);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("sparse_vector"));
}
2340
// Serialization must emit the `#type` tag alongside indices and values.
#[test]
fn test_sparse_vector_serialize_always_has_type() {
    let vector = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]).unwrap();
    let encoded = serde_json::to_value(&vector).unwrap();

    assert_eq!(encoded["#type"], "sparse_vector");
    assert_eq!(encoded["indices"], serde_json::json!([0, 1, 2]));
    assert_eq!(encoded["values"], serde_json::json!([1.0, 2.0, 3.0]));
}
2351
// Serializing to a JSON string and parsing it back must reproduce the
// original vector; the encoded text must contain the type tag.
#[test]
fn test_sparse_vector_roundtrip_with_type() {
    let source = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]).unwrap();

    let encoded = serde_json::to_string(&source).unwrap();
    assert!(encoded.contains("\"#type\":\"sparse_vector\""));

    let decoded = serde_json::from_str::<SparseVector>(&encoded).unwrap();
    assert_eq!(source, decoded);
}
2364
// An untagged sparse vector nested inside a larger JSON object must be
// extractable from the generic `Value` held under its key.
#[test]
fn test_sparse_vector_in_metadata_old_format() {
    let document = r#"{"key": "value", "sparse": {"indices": [0, 1], "values": [1.0, 2.0]}}"#;
    let fields: HashMap<String, serde_json::Value> = serde_json::from_str(document).unwrap();

    let nested = fields["sparse"].clone();
    let parsed: SparseVector = serde_json::from_value(nested).unwrap();
    assert_eq!(parsed.indices, vec![0, 1]);
    assert_eq!(parsed.values, vec![1.0, 2.0]);
}
2376
// A tagged sparse vector nested inside a larger JSON object must be
// extractable from the generic `Value` held under its key.
#[test]
fn test_sparse_vector_in_metadata_new_format() {
    let document = r##"{"key": "value", "sparse": {"#type": "sparse_vector", "indices": [0, 1], "values": [1.0, 2.0]}}"##;
    let fields: HashMap<String, serde_json::Value> = serde_json::from_str(document).unwrap();

    let nested = fields["sparse"].clone();
    let parsed: SparseVector = serde_json::from_value(nested).unwrap();
    assert_eq!(parsed.indices, vec![0, 1]);
    assert_eq!(parsed.values, vec![1.0, 2.0]);
}
2388
// Parsing an untagged, token-less payload yields `tokens == None`;
// re-serializing it adds the type tag and emits `tokens` as JSON null.
#[test]
fn test_sparse_vector_tokens_roundtrip_old_to_new() {
    let legacy = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
    let parsed = serde_json::from_str::<SparseVector>(legacy).unwrap();
    assert_eq!(parsed.indices, vec![0, 1, 2]);
    assert_eq!(parsed.values, vec![1.0, 2.0, 3.0]);
    assert!(parsed.tokens.is_none());

    let reencoded = serde_json::to_value(&parsed).unwrap();
    assert_eq!(reencoded["#type"], "sparse_vector");
    assert_eq!(reencoded["indices"], serde_json::json!([0, 1, 2]));
    assert_eq!(reencoded["values"], serde_json::json!([1.0, 2.0, 3.0]));
    assert!(reencoded["tokens"].is_null());
}
2405
// A tokenized vector must keep indices, values and tokens through a JSON
// string round trip, and the encoded text must contain the tag and tokens key.
#[test]
fn test_sparse_vector_tokens_roundtrip_new_to_new() {
    let expected_tokens = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
    let tokenized = SparseVector::new_with_tokens(
        vec![0, 1, 2],
        vec![1.0, 2.0, 3.0],
        expected_tokens.clone(),
    )
    .unwrap();

    let encoded = serde_json::to_string(&tokenized).unwrap();
    assert!(encoded.contains("\"#type\":\"sparse_vector\""));
    assert!(encoded.contains("\"tokens\""));

    let decoded = serde_json::from_str::<SparseVector>(&encoded).unwrap();
    assert_eq!(decoded.indices, vec![0, 1, 2]);
    assert_eq!(decoded.values, vec![1.0, 2.0, 3.0]);
    assert_eq!(decoded.tokens, Some(expected_tokens));
}
2434
// A tagged payload with an explicit `tokens` array must populate `tokens`.
#[test]
fn test_sparse_vector_tokens_deserialize_with_tokens_field() {
    let payload = r##"{"#type": "sparse_vector", "indices": [5, 10], "values": [0.5, 1.0], "tokens": ["token1", "token2"]}"##;
    let parsed = serde_json::from_str::<SparseVector>(payload).unwrap();

    let expected_tokens = vec!["token1".to_string(), "token2".to_string()];
    assert_eq!(parsed.indices, vec![5, 10]);
    assert_eq!(parsed.values, vec![0.5, 1.0]);
    assert_eq!(parsed.tokens, Some(expected_tokens));
}
2447
// The untagged/token-less payload and the tagged/tokenized payload must agree
// on indices and values and differ only in `tokens`.
#[test]
fn test_sparse_vector_tokens_backward_compatibility() {
    let legacy_payload = r#"{"indices": [1, 2], "values": [0.1, 0.2]}"#;
    let current_payload = r##"{"#type": "sparse_vector", "indices": [1, 2], "values": [0.1, 0.2], "tokens": ["a", "b"]}"##;

    let legacy = serde_json::from_str::<SparseVector>(legacy_payload).unwrap();
    let current = serde_json::from_str::<SparseVector>(current_payload).unwrap();

    // The numeric content is identical across formats.
    assert_eq!(legacy.indices, current.indices);
    assert_eq!(legacy.values, current.values);

    // Only the newer payload carries tokens.
    assert!(legacy.tokens.is_none());
    assert_eq!(current.tokens, Some(vec!["a".to_string(), "b".to_string()]));
}
2466
// `SparseVector::from_triples` must keep the token of every triple, and those
// tokens must survive a JSON string round trip.
#[test]
fn test_sparse_vector_from_triples_preserves_tokens() {
    let triples = vec![
        ("apple".to_string(), 10, 0.5),
        ("banana".to_string(), 20, 0.7),
        ("cherry".to_string(), 30, 0.9),
    ];
    // The triples are not reused afterwards, so they are moved rather than
    // cloned (a clone would copy every owned token string for nothing).
    let sv = SparseVector::from_triples(triples);

    assert_eq!(sv.indices, vec![10, 20, 30]);
    assert_eq!(sv.values, vec![0.5, 0.7, 0.9]);
    assert_eq!(
        sv.tokens,
        Some(vec![
            "apple".to_string(),
            "banana".to_string(),
            "cherry".to_string()
        ])
    );

    // Encode then decode, and compare each field against the in-memory value.
    let serialized = serde_json::to_string(&sv).unwrap();
    let deserialized: SparseVector = serde_json::from_str(&serialized).unwrap();

    assert_eq!(deserialized.indices, sv.indices);
    assert_eq!(deserialized.values, sv.values);
    assert_eq!(deserialized.tokens, sv.tokens);
}
2495
// A Python dict with `indices`, `values` and `tokens` must extract into a
// `SparseVector`, and converting that vector back to Python must yield a dict
// whose `tokens` entry holds the same strings.
#[cfg(feature = "pyo3")]
#[test]
fn test_sparse_vector_pyo3_roundtrip_with_tokens() {
    ensure_python_interpreter();

    pyo3::Python::with_gil(|py| {
        use pyo3::types::PyDict;
        use pyo3::IntoPyObject;

        let dict_in = PyDict::new(py);
        dict_in.set_item("indices", vec![0u32, 1, 2]).unwrap();
        dict_in
            .set_item("values", vec![0.1f32, 0.2f32, 0.3f32])
            .unwrap();
        dict_in
            .set_item("tokens", vec!["foo", "bar", "baz"])
            .unwrap();

        // Python -> Rust.
        let sparse: SparseVector = dict_in.clone().into_any().extract().unwrap();
        assert_eq!(sparse.indices, vec![0, 1, 2]);
        assert_eq!(sparse.values, vec![0.1, 0.2, 0.3]);
        assert_eq!(
            sparse.tokens,
            Some(vec![
                "foo".to_string(),
                "bar".to_string(),
                "baz".to_string()
            ])
        );

        // Rust -> Python. `sparse` is not used again, so it is moved into the
        // conversion instead of being cloned first.
        let py_obj = sparse.into_pyobject(py).unwrap();
        let dict_out = py_obj.downcast::<PyDict>().unwrap();
        let tokens_obj = dict_out.get_item("tokens").unwrap();
        let tokens: Vec<String> = tokens_obj
            .expect("expected tokens key in Python dict")
            .extract()
            .unwrap();
        assert_eq!(
            tokens,
            vec!["foo".to_string(), "bar".to_string(), "baz".to_string()]
        );
    });
}
2539
// A Python dict without a `tokens` entry must extract with `tokens == None`,
// and converting back must yield a dict whose `tokens` key exists and holds
// Python `None`.
#[cfg(feature = "pyo3")]
#[test]
fn test_sparse_vector_pyo3_roundtrip_without_tokens() {
    ensure_python_interpreter();

    pyo3::Python::with_gil(|py| {
        use pyo3::types::PyDict;
        use pyo3::IntoPyObject;

        let source_dict = PyDict::new(py);
        source_dict.set_item("indices", vec![5u32]).unwrap();
        source_dict.set_item("values", vec![1.5f32]).unwrap();

        // Python -> Rust.
        let extracted: SparseVector = source_dict.clone().into_any().extract().unwrap();
        assert_eq!(extracted.indices, vec![5]);
        assert_eq!(extracted.values, vec![1.5]);
        assert!(extracted.tokens.is_none());

        // Rust -> Python.
        let converted = extracted.into_pyobject(py).unwrap();
        let result_dict = converted.downcast::<PyDict>().unwrap();
        let tokens_entry = result_dict
            .get_item("tokens")
            .unwrap()
            .expect("expected tokens key in Python dict");
        assert!(
            tokens_entry.is_none(),
            "expected tokens value in Python dict to be None"
        );
    });
}
2568
// Combining with the constant-true / constant-false filters must simplify:
// true is the identity for `&`, false is the identity for `|`.
#[test]
fn test_simplifies_identities() {
    let always: Where = true.into();
    let never: Where = false.into();
    let field_filter = Key::field("foo").eq("bar");

    // true combined with itself stays true under both operators.
    let expected_true: Where = true.into();
    assert_eq!(always.clone() & always.clone(), expected_true);
    assert_eq!(always.clone() | always.clone(), expected_true);

    // `&` with true leaves the other operand unchanged, on either side.
    assert_eq!(field_filter.clone() & always.clone(), field_filter.clone());
    assert_eq!(always.clone() & field_filter.clone(), field_filter.clone());

    // `|` with false leaves the other operand unchanged, on either side.
    assert_eq!(field_filter.clone() | never.clone(), field_filter.clone());
    assert_eq!(never | field_filter.clone(), field_filter);
}
2583
// Nesting the same boolean operator must produce a single flat composite with
// all three operands in order, rather than a nested tree.
#[test]
fn test_flattens() {
    let foo = Key::field("foo").eq("bar");
    let baz = Key::field("baz").eq("quux");

    let and_nested = foo.clone() & (baz.clone() & foo.clone());
    let or_nested = foo.clone() | (baz.clone() | foo.clone());

    // Both operators are expected to flatten into the same child list.
    let flat_children = vec![foo.clone(), baz.clone(), foo.clone()];
    let cases = vec![
        (and_nested, BooleanOperator::And),
        (or_nested, BooleanOperator::Or),
    ];
    for (combined, operator) in cases {
        assert_eq!(
            combined,
            Where::Composite(CompositeExpression {
                operator,
                children: flat_children.clone()
            })
        );
    }
}
2607}