1use chroma_error::{ChromaError, ErrorCodes};
2use itertools::Itertools;
3use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
4use serde_json::{Number, Value};
5use sprs::CsVec;
6use std::{
7 cmp::Ordering,
8 collections::{HashMap, HashSet},
9 mem::size_of_val,
10 ops::{BitAnd, BitOr},
11};
12use thiserror::Error;
13
14use crate::chroma_proto;
15
16#[cfg(feature = "pyo3")]
17use pyo3::types::{PyAnyMethods, PyDictMethods};
18
19#[cfg(feature = "testing")]
20use proptest::prelude::*;
21
/// Plain serde mirror of [`SparseVector`] used by its hand-written
/// `Serialize` / `Deserialize` impls.
///
/// All fields use their own names on the wire except `type_tag`, which is
/// stored under the `#type` key.
#[derive(Serialize, Deserialize)]
struct SparseVectorSerdeHelper {
    /// Optional type discriminator, (de)serialized as `#type`.
    #[serde(rename = "#type")]
    type_tag: Option<String>,
    indices: Vec<u32>,
    values: Vec<f32>,
    tokens: Option<Vec<String>>,
}
30
/// Sparse vector stored as parallel `indices` / `values` lists, optionally
/// accompanied by one token string per entry.
///
/// The fields are public, so the type itself does not enforce any invariant;
/// use [`SparseVector::validate`] to check lengths and index ordering.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SparseVector {
    /// Positions of the stored entries. `validate` requires them to be
    /// strictly ascending.
    pub indices: Vec<u32>,
    /// Entry values; `values[i]` is paired with `indices[i]` by `iter`.
    pub values: Vec<f32>,
    /// Optional per-entry tokens. When present, `validate` requires exactly
    /// one token per index.
    pub tokens: Option<Vec<String>>,
}
47
48impl<'de> Deserialize<'de> for SparseVector {
50 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51 where
52 D: Deserializer<'de>,
53 {
54 let helper = SparseVectorSerdeHelper::deserialize(deserializer)?;
55
56 if let Some(type_tag) = &helper.type_tag {
58 if type_tag != "sparse_vector" {
59 return Err(serde::de::Error::custom(format!(
60 "Expected #type='sparse_vector', got '{}'",
61 type_tag
62 )));
63 }
64 }
65
66 Ok(SparseVector {
67 indices: helper.indices,
68 values: helper.values,
69 tokens: helper.tokens,
70 })
71 }
72}
73
74impl Serialize for SparseVector {
76 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77 where
78 S: Serializer,
79 {
80 let helper = SparseVectorSerdeHelper {
81 type_tag: Some("sparse_vector".to_string()),
82 indices: self.indices.clone(),
83 values: self.values.clone(),
84 tokens: self.tokens.clone(),
85 };
86 helper.serialize(serializer)
87 }
88}
89
/// Error returned by the checked [`SparseVector`] constructors when the
/// index, value and (optional) token lists do not all have the same length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SparseVectorLengthMismatch;
93
94impl std::fmt::Display for SparseVectorLengthMismatch {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 write!(
97 f,
98 "Sparse vector indices, values, and tokens (when present) must have the same length"
99 )
100 }
101}
102
// Marker impl: the default `std::error::Error` methods are sufficient.
impl std::error::Error for SparseVectorLengthMismatch {}
104
105impl ChromaError for SparseVectorLengthMismatch {
106 fn code(&self) -> ErrorCodes {
107 ErrorCodes::InvalidArgument
108 }
109}
110
111impl SparseVector {
112 pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Result<Self, SparseVectorLengthMismatch> {
114 if indices.len() != values.len() {
115 return Err(SparseVectorLengthMismatch);
116 }
117 Ok(Self {
118 indices,
119 values,
120 tokens: None,
121 })
122 }
123
124 pub fn new_with_tokens(
126 indices: Vec<u32>,
127 values: Vec<f32>,
128 tokens: Vec<String>,
129 ) -> Result<Self, SparseVectorLengthMismatch> {
130 if indices.len() != values.len() {
131 return Err(SparseVectorLengthMismatch);
132 }
133 if tokens.len() != indices.len() {
134 return Err(SparseVectorLengthMismatch);
135 }
136 Ok(Self {
137 indices,
138 values,
139 tokens: Some(tokens),
140 })
141 }
142
143 pub fn from_pairs(pairs: impl IntoIterator<Item = (u32, f32)>) -> Self {
145 let mut indices = vec![];
146 let mut values = vec![];
147 for (index, value) in pairs {
148 indices.push(index);
149 values.push(value);
150 }
151 let tokens = None;
152 Self {
153 indices,
154 values,
155 tokens,
156 }
157 }
158
159 pub fn from_triples(triples: impl IntoIterator<Item = (String, u32, f32)>) -> Self {
161 let mut tokens = vec![];
162 let mut indices = vec![];
163 let mut values = vec![];
164 for (token, index, value) in triples {
165 tokens.push(token);
166 indices.push(index);
167 values.push(value);
168 }
169 let tokens = Some(tokens);
170 Self {
171 indices,
172 values,
173 tokens,
174 }
175 }
176
177 pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
179 self.indices
180 .iter()
181 .copied()
182 .zip(self.values.iter().copied())
183 }
184
185 pub fn validate(&self) -> Result<(), MetadataValueConversionError> {
187 if self.indices.len() != self.values.len() {
189 return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
190 }
191
192 if let Some(tokens) = self.tokens.as_ref() {
194 if tokens.len() != self.indices.len() {
195 return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
196 }
197 }
198
199 for i in 1..self.indices.len() {
201 if self.indices[i] <= self.indices[i - 1] {
202 return Err(MetadataValueConversionError::SparseVectorIndicesNotSorted);
203 }
204 }
205
206 Ok(())
207 }
208}
209
// Marker impl on top of the derived `PartialEq`.
// NOTE(review): `values` holds `f32`, so a vector containing NaN does not
// compare equal to itself under the derived `PartialEq` — confirm callers
// never store NaN.
impl Eq for SparseVector {}
211
212impl Ord for SparseVector {
213 fn cmp(&self, other: &Self) -> Ordering {
214 self.indices.cmp(&other.indices).then_with(|| {
215 for (a, b) in self.values.iter().zip(other.values.iter()) {
216 match a.total_cmp(b) {
217 Ordering::Equal => continue,
218 other => return other,
219 }
220 }
221 self.values.len().cmp(&other.values.len())
222 })
223 }
224}
225
226impl PartialOrd for SparseVector {
227 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
228 Some(self.cmp(other))
229 }
230}
231
232impl TryFrom<chroma_proto::SparseVector> for SparseVector {
233 type Error = SparseVectorLengthMismatch;
234
235 fn try_from(proto: chroma_proto::SparseVector) -> Result<Self, Self::Error> {
236 if proto.tokens.is_empty() {
237 SparseVector::new(proto.indices, proto.values)
238 } else {
239 SparseVector::new_with_tokens(proto.indices, proto.values, proto.tokens)
240 }
241 }
242}
243
244impl From<SparseVector> for chroma_proto::SparseVector {
245 fn from(sparse: SparseVector) -> Self {
246 chroma_proto::SparseVector {
247 indices: sparse.indices,
248 values: sparse.values,
249 tokens: sparse.tokens.unwrap_or_default(),
250 }
251 }
252}
253
254impl From<&SparseVector> for CsVec<f32> {
256 fn from(sparse: &SparseVector) -> Self {
257 let (indices, values) = sparse
258 .iter()
259 .map(|(index, value)| (index as usize, value))
260 .unzip();
261 CsVec::new(u32::MAX as usize, indices, values)
262 }
263}
264
265impl From<SparseVector> for CsVec<f32> {
266 fn from(sparse: SparseVector) -> Self {
267 (&sparse).into()
268 }
269}
270
#[cfg(feature = "pyo3")]
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    /// Converts the vector into a Python `dict` with the keys `indices`,
    /// `values` and `tokens`. The `tokens` key is always set, even when the
    /// field is `None`.
    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        use pyo3::types::PyDict;

        let SparseVector {
            indices,
            values,
            tokens,
        } = self;

        let py_dict = PyDict::new(py);
        py_dict.set_item("indices", indices)?;
        py_dict.set_item("values", values)?;
        py_dict.set_item("tokens", tokens)?;
        Ok(py_dict.into_any())
    }
}
287
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for SparseVector {
    /// Extracts a sparse vector from a Python `dict`.
    ///
    /// `indices` and `values` are required keys (a `KeyError` is raised when
    /// either is missing). `tokens` is optional and may also be Python `None`.
    /// Length mismatches are surfaced as a `ValueError`.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyDict;

        let dict = ob.downcast::<PyDict>()?;

        let indices: Vec<u32> = dict
            .get_item("indices")?
            .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("missing 'indices' key"))?
            .extract()?;

        let values: Vec<f32> = dict
            .get_item("values")?
            .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("missing 'values' key"))?
            .extract()?;

        // A missing key and an explicit Python `None` both mean "no tokens".
        let tokens = match dict.get_item("tokens")? {
            Some(obj) if !obj.is_none() => Some(obj.extract::<Vec<String>>()?),
            _ => None,
        };

        let built = if let Some(tokens) = tokens {
            SparseVector::new_with_tokens(indices, values, tokens)
        } else {
            SparseVector::new(indices, values)
        };

        built.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
    }
}
325
/// A metadata value as supplied in an update, including an explicit `None`.
///
/// The enum is `#[serde(untagged)]`, so the declaration order of the variants
/// determines which one an ambiguous payload deserializes into; do not
/// reorder them casually.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[serde(untagged)]
pub enum UpdateMetadataValue {
    Bool(bool),
    Int(i64),
    // Property tests only generate floats in [-1e6, 1e6] (sampled as f32).
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| UpdateMetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    // Sparse vectors and all array variants are skipped by the proptest
    // `Arbitrary` derive.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
    #[cfg_attr(feature = "testing", proptest(skip))]
    BoolArray(Vec<bool>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    IntArray(Vec<i64>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    FloatArray(Vec<f64>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    StringArray(Vec<String>),
    /// Explicit absence of a value; has no `MetadataValue` counterpart
    /// (converting it yields `MetadataValueConversionError::InvalidValue`).
    None,
}
355
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for UpdateMetadataValue {
    /// Extracts an update value from a Python object.
    ///
    /// Candidates are tried in a fixed order: `None`, `bool`, `i64`, `f64`,
    /// `String`, sparse-vector dict, then `list`. For lists the element types
    /// are tried as `bool`, `i64`, `f64`, `String`. Empty lists raise a
    /// `ValueError`; anything unmatched raises a `TypeError`.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyList;

        if ob.is_none() {
            return Ok(UpdateMetadataValue::None);
        }
        // The scalar order is significant: `bool` is attempted before `i64`.
        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(UpdateMetadataValue::Bool(flag));
        }
        if let Ok(int) = ob.extract::<i64>() {
            return Ok(UpdateMetadataValue::Int(int));
        }
        if let Ok(float) = ob.extract::<f64>() {
            return Ok(UpdateMetadataValue::Float(float));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(UpdateMetadataValue::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(UpdateMetadataValue::SparseVector(sparse));
        }

        let list = match ob.downcast::<PyList>() {
            Ok(list) => list,
            Err(_) => {
                return Err(pyo3::exceptions::PyTypeError::new_err(
                    "Cannot convert Python object to UpdateMetadataValue",
                ));
            }
        };

        if list.is_empty()? {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Empty lists are not allowed as metadata values",
            ));
        }

        if let Ok(items) = list.extract::<Vec<bool>>() {
            return Ok(UpdateMetadataValue::BoolArray(items));
        }
        if let Ok(items) = list.extract::<Vec<i64>>() {
            return Ok(UpdateMetadataValue::IntArray(items));
        }
        if let Ok(items) = list.extract::<Vec<f64>>() {
            return Ok(UpdateMetadataValue::FloatArray(items));
        }
        if let Ok(items) = list.extract::<Vec<String>>() {
            return Ok(UpdateMetadataValue::StringArray(items));
        }

        Err(pyo3::exceptions::PyTypeError::new_err(
            "Cannot convert Python list to UpdateMetadataValue: mixed or unsupported element types",
        ))
    }
}
403
404impl From<bool> for UpdateMetadataValue {
405 fn from(b: bool) -> Self {
406 Self::Bool(b)
407 }
408}
409
410impl From<i64> for UpdateMetadataValue {
411 fn from(v: i64) -> Self {
412 Self::Int(v)
413 }
414}
415
416impl From<i32> for UpdateMetadataValue {
417 fn from(v: i32) -> Self {
418 Self::Int(v as i64)
419 }
420}
421
422impl From<f64> for UpdateMetadataValue {
423 fn from(v: f64) -> Self {
424 Self::Float(v)
425 }
426}
427
428impl From<f32> for UpdateMetadataValue {
429 fn from(v: f32) -> Self {
430 Self::Float(v as f64)
431 }
432}
433
434impl From<String> for UpdateMetadataValue {
435 fn from(v: String) -> Self {
436 Self::Str(v)
437 }
438}
439
440impl From<&str> for UpdateMetadataValue {
441 fn from(v: &str) -> Self {
442 Self::Str(v.to_string())
443 }
444}
445
446impl From<SparseVector> for UpdateMetadataValue {
447 fn from(v: SparseVector) -> Self {
448 Self::SparseVector(v)
449 }
450}
451
452impl From<Vec<bool>> for UpdateMetadataValue {
453 fn from(v: Vec<bool>) -> Self {
454 Self::BoolArray(v)
455 }
456}
457
458impl From<Vec<i64>> for UpdateMetadataValue {
459 fn from(v: Vec<i64>) -> Self {
460 Self::IntArray(v)
461 }
462}
463
464impl From<Vec<f64>> for UpdateMetadataValue {
465 fn from(v: Vec<f64>) -> Self {
466 Self::FloatArray(v)
467 }
468}
469
470impl From<Vec<String>> for UpdateMetadataValue {
471 fn from(v: Vec<String>) -> Self {
472 Self::StringArray(v)
473 }
474}
475
/// Error produced when a protobuf update value cannot be converted into an
/// [`UpdateMetadataValue`].
#[derive(Error, Debug)]
pub enum UpdateMetadataValueConversionError {
    /// The incoming value could not be converted (in this file: a sparse
    /// vector whose lists have mismatched lengths).
    // NOTE(review): the message lists only scalar kinds although sparse
    // vectors and arrays are also accepted — confirm whether it should be
    // refreshed.
    #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")]
    InvalidValue,
}
481
482impl ChromaError for UpdateMetadataValueConversionError {
483 fn code(&self) -> ErrorCodes {
484 match self {
485 UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
486 }
487 }
488}
489
490impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue {
491 type Error = UpdateMetadataValueConversionError;
492
493 fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
494 match &value.value {
495 Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
496 Ok(UpdateMetadataValue::Bool(*value))
497 }
498 Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
499 Ok(UpdateMetadataValue::Int(*value))
500 }
501 Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
502 Ok(UpdateMetadataValue::Float(*value))
503 }
504 Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
505 Ok(UpdateMetadataValue::Str(value.clone()))
506 }
507 Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
508 let sparse = value
509 .clone()
510 .try_into()
511 .map_err(|_| UpdateMetadataValueConversionError::InvalidValue)?;
512 Ok(UpdateMetadataValue::SparseVector(sparse))
513 }
514 Some(chroma_proto::update_metadata_value::Value::BoolListValue(value)) => {
515 Ok(UpdateMetadataValue::BoolArray(value.values.clone()))
516 }
517 Some(chroma_proto::update_metadata_value::Value::IntListValue(value)) => {
518 Ok(UpdateMetadataValue::IntArray(value.values.clone()))
519 }
520 Some(chroma_proto::update_metadata_value::Value::DoubleListValue(value)) => {
521 Ok(UpdateMetadataValue::FloatArray(value.values.clone()))
522 }
523 Some(chroma_proto::update_metadata_value::Value::StringListValue(value)) => {
524 Ok(UpdateMetadataValue::StringArray(value.values.clone()))
525 }
526 None => Ok(UpdateMetadataValue::None),
528 }
529 }
530}
531
532impl From<UpdateMetadataValue> for chroma_proto::UpdateMetadataValue {
533 fn from(value: UpdateMetadataValue) -> Self {
534 match value {
535 UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
536 value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
537 },
538 UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
539 value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
540 },
541 UpdateMetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
542 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
543 value,
544 )),
545 },
546 UpdateMetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
547 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
548 value,
549 )),
550 },
551 UpdateMetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
552 value: Some(
553 chroma_proto::update_metadata_value::Value::SparseVectorValue(
554 sparse_vec.into(),
555 ),
556 ),
557 },
558 UpdateMetadataValue::BoolArray(values) => chroma_proto::UpdateMetadataValue {
559 value: Some(chroma_proto::update_metadata_value::Value::BoolListValue(
560 chroma_proto::BoolListValue { values },
561 )),
562 },
563 UpdateMetadataValue::IntArray(values) => chroma_proto::UpdateMetadataValue {
564 value: Some(chroma_proto::update_metadata_value::Value::IntListValue(
565 chroma_proto::IntListValue { values },
566 )),
567 },
568 UpdateMetadataValue::FloatArray(values) => chroma_proto::UpdateMetadataValue {
569 value: Some(chroma_proto::update_metadata_value::Value::DoubleListValue(
570 chroma_proto::DoubleListValue { values },
571 )),
572 },
573 UpdateMetadataValue::StringArray(values) => chroma_proto::UpdateMetadataValue {
574 value: Some(chroma_proto::update_metadata_value::Value::StringListValue(
575 chroma_proto::StringListValue { values },
576 )),
577 },
578 UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None },
579 }
580 }
581}
582
583impl TryFrom<&UpdateMetadataValue> for MetadataValue {
584 type Error = MetadataValueConversionError;
585
586 fn try_from(value: &UpdateMetadataValue) -> Result<Self, Self::Error> {
587 match value {
588 UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)),
589 UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)),
590 UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)),
591 UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())),
592 UpdateMetadataValue::SparseVector(value) => {
593 Ok(MetadataValue::SparseVector(value.clone()))
594 }
595 UpdateMetadataValue::BoolArray(value) => Ok(MetadataValue::BoolArray(value.clone())),
596 UpdateMetadataValue::IntArray(value) => Ok(MetadataValue::IntArray(value.clone())),
597 UpdateMetadataValue::FloatArray(value) => Ok(MetadataValue::FloatArray(value.clone())),
598 UpdateMetadataValue::StringArray(value) => {
599 Ok(MetadataValue::StringArray(value.clone()))
600 }
601 UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue),
602 }
603 }
604}
605
/// A stored metadata value: a scalar, a sparse vector, or a homogeneous array.
///
/// The enum is `#[serde(untagged)]`, so the declaration order of the variants
/// determines which one an ambiguous payload deserializes into; do not
/// reorder them casually.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[cfg_attr(feature = "pyo3", derive(pyo3::IntoPyObject))]
#[serde(untagged)]
pub enum MetadataValue {
    Bool(bool),
    Int(i64),
    // Property tests only generate floats in [-1e6, 1e6] (sampled as f32).
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| MetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    // Sparse vectors and all array variants are skipped by the proptest
    // `Arbitrary` derive.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
    #[cfg_attr(feature = "testing", proptest(skip))]
    BoolArray(Vec<bool>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    IntArray(Vec<i64>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    FloatArray(Vec<f64>),
    #[cfg_attr(feature = "testing", proptest(skip))]
    StringArray(Vec<String>),
}
641
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for MetadataValue {
    /// Extracts a metadata value from a Python object.
    ///
    /// Candidates are tried in a fixed order: `bool`, `i64`, `f64`, `String`,
    /// sparse-vector dict, then `list`. For lists the element types are tried
    /// as `bool`, `i64`, `f64`, `String`. Empty lists raise a `ValueError`;
    /// anything unmatched raises a `TypeError`.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyList;

        // The scalar order is significant: `bool` is attempted before `i64`.
        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(MetadataValue::Bool(flag));
        }
        if let Ok(int) = ob.extract::<i64>() {
            return Ok(MetadataValue::Int(int));
        }
        if let Ok(float) = ob.extract::<f64>() {
            return Ok(MetadataValue::Float(float));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(MetadataValue::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(MetadataValue::SparseVector(sparse));
        }

        let list = match ob.downcast::<PyList>() {
            Ok(list) => list,
            Err(_) => {
                return Err(pyo3::exceptions::PyTypeError::new_err(
                    "Cannot convert Python object to MetadataValue",
                ));
            }
        };

        if list.is_empty()? {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Empty lists are not allowed as metadata values",
            ));
        }

        if let Ok(items) = list.extract::<Vec<bool>>() {
            return Ok(MetadataValue::BoolArray(items));
        }
        if let Ok(items) = list.extract::<Vec<i64>>() {
            return Ok(MetadataValue::IntArray(items));
        }
        if let Ok(items) = list.extract::<Vec<f64>>() {
            return Ok(MetadataValue::FloatArray(items));
        }
        if let Ok(items) = list.extract::<Vec<String>>() {
            return Ok(MetadataValue::StringArray(items));
        }

        Err(pyo3::exceptions::PyTypeError::new_err(
            "Cannot convert Python list to MetadataValue: mixed or unsupported element types",
        ))
    }
}
687
688impl std::fmt::Display for MetadataValue {
689 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
690 match self {
691 MetadataValue::Bool(v) => write!(f, "{}", v),
692 MetadataValue::Int(v) => write!(f, "{}", v),
693 MetadataValue::Float(v) => write!(f, "{}", v),
694 MetadataValue::Str(v) => write!(f, "\"{}\"", v),
695 MetadataValue::SparseVector(v) => write!(f, "SparseVector(len={})", v.values.len()),
696 MetadataValue::BoolArray(v) => write!(f, "BoolArray(len={})", v.len()),
697 MetadataValue::IntArray(v) => write!(f, "IntArray(len={})", v.len()),
698 MetadataValue::FloatArray(v) => write!(f, "FloatArray(len={})", v.len()),
699 MetadataValue::StringArray(v) => write!(f, "StringArray(len={})", v.len()),
700 }
701 }
702}
703
// Marker impl on top of the derived `PartialEq`.
// NOTE(review): the float-carrying variants compare with `==`, so a value
// containing NaN does not equal itself — confirm NaN is rejected upstream.
impl Eq for MetadataValue {}
705
/// Payload-free discriminant of [`MetadataValue`]: one variant per
/// `MetadataValue` variant, with the same names.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MetadataValueType {
    Bool,
    Int,
    Float,
    Str,
    SparseVector,
    BoolArray,
    IntArray,
    FloatArray,
    StringArray,
}
718
719impl MetadataValue {
720 pub fn value_type(&self) -> MetadataValueType {
721 match self {
722 MetadataValue::Bool(_) => MetadataValueType::Bool,
723 MetadataValue::Int(_) => MetadataValueType::Int,
724 MetadataValue::Float(_) => MetadataValueType::Float,
725 MetadataValue::Str(_) => MetadataValueType::Str,
726 MetadataValue::SparseVector(_) => MetadataValueType::SparseVector,
727 MetadataValue::BoolArray(_) => MetadataValueType::BoolArray,
728 MetadataValue::IntArray(_) => MetadataValueType::IntArray,
729 MetadataValue::FloatArray(_) => MetadataValueType::FloatArray,
730 MetadataValue::StringArray(_) => MetadataValueType::StringArray,
731 }
732 }
733}
734
735impl From<&MetadataValue> for MetadataValueType {
736 fn from(value: &MetadataValue) -> Self {
737 value.value_type()
738 }
739}
740
741impl From<bool> for MetadataValue {
742 fn from(v: bool) -> Self {
743 MetadataValue::Bool(v)
744 }
745}
746
747impl From<i64> for MetadataValue {
748 fn from(v: i64) -> Self {
749 MetadataValue::Int(v)
750 }
751}
752
753impl From<i32> for MetadataValue {
754 fn from(v: i32) -> Self {
755 MetadataValue::Int(v as i64)
756 }
757}
758
759impl From<f64> for MetadataValue {
760 fn from(v: f64) -> Self {
761 MetadataValue::Float(v)
762 }
763}
764
765impl From<f32> for MetadataValue {
766 fn from(v: f32) -> Self {
767 MetadataValue::Float(v as f64)
768 }
769}
770
771impl From<String> for MetadataValue {
772 fn from(v: String) -> Self {
773 MetadataValue::Str(v)
774 }
775}
776
777impl From<&str> for MetadataValue {
778 fn from(v: &str) -> Self {
779 MetadataValue::Str(v.to_string())
780 }
781}
782
783impl From<SparseVector> for MetadataValue {
784 fn from(v: SparseVector) -> Self {
785 MetadataValue::SparseVector(v)
786 }
787}
788
789impl From<Vec<bool>> for MetadataValue {
790 fn from(v: Vec<bool>) -> Self {
791 MetadataValue::BoolArray(v)
792 }
793}
794
795impl From<Vec<i64>> for MetadataValue {
796 fn from(v: Vec<i64>) -> Self {
797 MetadataValue::IntArray(v)
798 }
799}
800
801impl From<Vec<i32>> for MetadataValue {
802 fn from(v: Vec<i32>) -> Self {
803 MetadataValue::IntArray(v.into_iter().map(|x| x as i64).collect())
804 }
805}
806
807impl From<Vec<f64>> for MetadataValue {
808 fn from(v: Vec<f64>) -> Self {
809 MetadataValue::FloatArray(v)
810 }
811}
812
813impl From<Vec<f32>> for MetadataValue {
814 fn from(v: Vec<f32>) -> Self {
815 MetadataValue::FloatArray(v.into_iter().map(|x| x as f64).collect())
816 }
817}
818
819impl From<Vec<String>> for MetadataValue {
820 fn from(v: Vec<String>) -> Self {
821 MetadataValue::StringArray(v)
822 }
823}
824
825impl From<Vec<&str>> for MetadataValue {
826 fn from(v: Vec<&str>) -> Self {
827 MetadataValue::StringArray(v.into_iter().map(|s| s.to_string()).collect())
828 }
829}
830
831#[allow(clippy::derive_ord_xor_partial_ord)]
836impl Ord for MetadataValue {
837 fn cmp(&self, other: &Self) -> Ordering {
838 fn type_order(val: &MetadataValue) -> u8 {
840 match val {
841 MetadataValue::Bool(_) => 0,
842 MetadataValue::Int(_) => 1,
843 MetadataValue::Float(_) => 2,
844 MetadataValue::Str(_) => 3,
845 MetadataValue::SparseVector(_) => 4,
846 MetadataValue::BoolArray(_) => 5,
847 MetadataValue::IntArray(_) => 6,
848 MetadataValue::FloatArray(_) => 7,
849 MetadataValue::StringArray(_) => 8,
850 }
851 }
852
853 type_order(self).cmp(&type_order(other)).then_with(|| {
855 match (self, other) {
856 (MetadataValue::Bool(left), MetadataValue::Bool(right)) => left.cmp(right),
857 (MetadataValue::Int(left), MetadataValue::Int(right)) => left.cmp(right),
858 (MetadataValue::Float(left), MetadataValue::Float(right)) => left.total_cmp(right),
859 (MetadataValue::Str(left), MetadataValue::Str(right)) => left.cmp(right),
860 (MetadataValue::SparseVector(left), MetadataValue::SparseVector(right)) => {
861 left.cmp(right)
862 }
863 (MetadataValue::BoolArray(left), MetadataValue::BoolArray(right)) => {
864 left.cmp(right)
865 }
866 (MetadataValue::IntArray(left), MetadataValue::IntArray(right)) => left.cmp(right),
867 (MetadataValue::FloatArray(left), MetadataValue::FloatArray(right)) => {
868 for (l, r) in left.iter().zip(right.iter()) {
870 match l.total_cmp(r) {
871 Ordering::Equal => continue,
872 other => return other,
873 }
874 }
875 left.len().cmp(&right.len())
876 }
877 (MetadataValue::StringArray(left), MetadataValue::StringArray(right)) => {
878 left.cmp(right)
879 }
880 _ => Ordering::Equal, }
882 })
883 }
884}
885
886impl PartialOrd for MetadataValue {
887 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
888 Some(self.cmp(other))
889 }
890}
891
892impl TryFrom<&MetadataValue> for bool {
893 type Error = MetadataValueConversionError;
894
895 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
896 match value {
897 MetadataValue::Bool(value) => Ok(*value),
898 _ => Err(MetadataValueConversionError::InvalidValue),
899 }
900 }
901}
902
903impl TryFrom<&MetadataValue> for i64 {
904 type Error = MetadataValueConversionError;
905
906 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
907 match value {
908 MetadataValue::Int(value) => Ok(*value),
909 _ => Err(MetadataValueConversionError::InvalidValue),
910 }
911 }
912}
913
914impl TryFrom<&MetadataValue> for f64 {
915 type Error = MetadataValueConversionError;
916
917 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
918 match value {
919 MetadataValue::Float(value) => Ok(*value),
920 _ => Err(MetadataValueConversionError::InvalidValue),
921 }
922 }
923}
924
925impl TryFrom<&MetadataValue> for String {
926 type Error = MetadataValueConversionError;
927
928 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
929 match value {
930 MetadataValue::Str(value) => Ok(value.clone()),
931 _ => Err(MetadataValueConversionError::InvalidValue),
932 }
933 }
934}
935
936impl From<MetadataValue> for UpdateMetadataValue {
937 fn from(value: MetadataValue) -> Self {
938 match value {
939 MetadataValue::Bool(v) => UpdateMetadataValue::Bool(v),
940 MetadataValue::Int(v) => UpdateMetadataValue::Int(v),
941 MetadataValue::Float(v) => UpdateMetadataValue::Float(v),
942 MetadataValue::Str(v) => UpdateMetadataValue::Str(v),
943 MetadataValue::SparseVector(v) => UpdateMetadataValue::SparseVector(v),
944 MetadataValue::BoolArray(v) => UpdateMetadataValue::BoolArray(v),
945 MetadataValue::IntArray(v) => UpdateMetadataValue::IntArray(v),
946 MetadataValue::FloatArray(v) => UpdateMetadataValue::FloatArray(v),
947 MetadataValue::StringArray(v) => UpdateMetadataValue::StringArray(v),
948 }
949 }
950}
951
952impl From<MetadataValue> for Value {
953 fn from(value: MetadataValue) -> Self {
954 match value {
955 MetadataValue::Bool(val) => Self::Bool(val),
956 MetadataValue::Int(val) => Self::Number(
957 Number::from_i128(val as i128).expect("i64 should be representable in JSON"),
958 ),
959 MetadataValue::Float(val) => Self::Number(
960 Number::from_f64(val).expect("Inf and NaN should not be present in MetadataValue"),
961 ),
962 MetadataValue::Str(val) => Self::String(val),
963 MetadataValue::SparseVector(val) => {
964 let mut map = serde_json::Map::new();
965 map.insert(
966 "indices".to_string(),
967 Value::Array(
968 val.indices
969 .iter()
970 .map(|&i| Value::Number(i.into()))
971 .collect(),
972 ),
973 );
974 map.insert(
975 "values".to_string(),
976 Value::Array(
977 val.values
978 .iter()
979 .map(|&v| {
980 Value::Number(
981 Number::from_f64(v as f64)
982 .expect("Float number should not be NaN or infinite"),
983 )
984 })
985 .collect(),
986 ),
987 );
988 Self::Object(map)
989 }
990 MetadataValue::BoolArray(vals) => {
991 Self::Array(vals.into_iter().map(Value::Bool).collect())
992 }
993 MetadataValue::IntArray(vals) => Self::Array(
994 vals.into_iter()
995 .map(|v| {
996 Value::Number(
997 Number::from_i128(v as i128)
998 .expect("i64 should be representable in JSON"),
999 )
1000 })
1001 .collect(),
1002 ),
1003 MetadataValue::FloatArray(vals) => Self::Array(
1004 vals.into_iter()
1005 .map(|v| {
1006 Value::Number(
1007 Number::from_f64(v)
1008 .expect("Inf and NaN should not be present in MetadataValue"),
1009 )
1010 })
1011 .collect(),
1012 ),
1013 MetadataValue::StringArray(vals) => {
1014 Self::Array(vals.into_iter().map(Value::String).collect())
1015 }
1016 }
1017 }
1018}
1019
/// Errors raised while converting to or validating metadata values.
#[derive(Error, Debug)]
pub enum MetadataValueConversionError {
    /// The source value has no valid `MetadataValue` representation.
    // NOTE(review): the message lists only Int/Float/Str although other
    // variants are accepted elsewhere in this file — confirm whether it
    // should be refreshed.
    #[error("Invalid metadata value, valid values are: Int, Float, Str")]
    InvalidValue,
    /// A metadata key uses a reserved prefix; the payload is the offending key.
    #[error("Metadata key cannot start with '#' or '$': {0}")]
    InvalidKey(String),
    /// The sparse vector's lists do not all have the same length.
    #[error("Sparse vector indices, values, and tokens (when present) must have the same length")]
    SparseVectorLengthMismatch,
    /// The sparse vector's indices are not strictly ascending.
    #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")]
    SparseVectorIndicesNotSorted,
}
1031
1032impl ChromaError for MetadataValueConversionError {
1033 fn code(&self) -> ErrorCodes {
1034 match self {
1035 MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
1036 MetadataValueConversionError::InvalidKey(_) => ErrorCodes::InvalidArgument,
1037 MetadataValueConversionError::SparseVectorLengthMismatch => ErrorCodes::InvalidArgument,
1038 MetadataValueConversionError::SparseVectorIndicesNotSorted => {
1039 ErrorCodes::InvalidArgument
1040 }
1041 }
1042 }
1043}
1044
1045impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue {
1046 type Error = MetadataValueConversionError;
1047
1048 fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
1049 match &value.value {
1050 Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
1051 Ok(MetadataValue::Bool(*value))
1052 }
1053 Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
1054 Ok(MetadataValue::Int(*value))
1055 }
1056 Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
1057 Ok(MetadataValue::Float(*value))
1058 }
1059 Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
1060 Ok(MetadataValue::Str(value.clone()))
1061 }
1062 Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
1063 let sparse = value
1064 .clone()
1065 .try_into()
1066 .map_err(|_| MetadataValueConversionError::SparseVectorLengthMismatch)?;
1067 Ok(MetadataValue::SparseVector(sparse))
1068 }
1069 Some(chroma_proto::update_metadata_value::Value::BoolListValue(value)) => {
1070 Ok(MetadataValue::BoolArray(value.values.clone()))
1071 }
1072 Some(chroma_proto::update_metadata_value::Value::IntListValue(value)) => {
1073 Ok(MetadataValue::IntArray(value.values.clone()))
1074 }
1075 Some(chroma_proto::update_metadata_value::Value::DoubleListValue(value)) => {
1076 Ok(MetadataValue::FloatArray(value.values.clone()))
1077 }
1078 Some(chroma_proto::update_metadata_value::Value::StringListValue(value)) => {
1079 Ok(MetadataValue::StringArray(value.values.clone()))
1080 }
1081 _ => Err(MetadataValueConversionError::InvalidValue),
1082 }
1083 }
1084}
1085
1086impl From<MetadataValue> for chroma_proto::UpdateMetadataValue {
1087 fn from(value: MetadataValue) -> Self {
1088 match value {
1089 MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
1090 value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
1091 },
1092 MetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
1093 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
1094 value,
1095 )),
1096 },
1097 MetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
1098 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
1099 value,
1100 )),
1101 },
1102 MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
1103 value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
1104 },
1105 MetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
1106 value: Some(
1107 chroma_proto::update_metadata_value::Value::SparseVectorValue(
1108 sparse_vec.into(),
1109 ),
1110 ),
1111 },
1112 MetadataValue::BoolArray(values) => chroma_proto::UpdateMetadataValue {
1113 value: Some(chroma_proto::update_metadata_value::Value::BoolListValue(
1114 chroma_proto::BoolListValue { values },
1115 )),
1116 },
1117 MetadataValue::IntArray(values) => chroma_proto::UpdateMetadataValue {
1118 value: Some(chroma_proto::update_metadata_value::Value::IntListValue(
1119 chroma_proto::IntListValue { values },
1120 )),
1121 },
1122 MetadataValue::FloatArray(values) => chroma_proto::UpdateMetadataValue {
1123 value: Some(chroma_proto::update_metadata_value::Value::DoubleListValue(
1124 chroma_proto::DoubleListValue { values },
1125 )),
1126 },
1127 MetadataValue::StringArray(values) => chroma_proto::UpdateMetadataValue {
1128 value: Some(chroma_proto::update_metadata_value::Value::StringListValue(
1129 chroma_proto::StringListValue { values },
1130 )),
1131 },
1132 }
1133 }
1134}
1135
/// Metadata keyed by name, with each value expressed as an [`UpdateMetadataValue`].
pub type UpdateMetadata = HashMap<String, UpdateMetadataValue>;
1142
1143pub fn are_update_metadatas_close_to_equal(
1147 metadata1: &UpdateMetadata,
1148 metadata2: &UpdateMetadata,
1149) -> bool {
1150 assert_eq!(metadata1.len(), metadata2.len());
1151
1152 for (key, value) in metadata1.iter() {
1153 if !metadata2.contains_key(key) {
1154 return false;
1155 }
1156 let other_value = metadata2.get(key).unwrap();
1157
1158 if let (UpdateMetadataValue::Float(value), UpdateMetadataValue::Float(other_value)) =
1159 (value, other_value)
1160 {
1161 if (value - other_value).abs() > 1e-6 {
1162 return false;
1163 }
1164 } else if value != other_value {
1165 return false;
1166 }
1167 }
1168
1169 true
1170}
1171
1172pub fn are_metadatas_close_to_equal(metadata1: &Metadata, metadata2: &Metadata) -> bool {
1173 assert_eq!(metadata1.len(), metadata2.len());
1174
1175 for (key, value) in metadata1.iter() {
1176 if !metadata2.contains_key(key) {
1177 return false;
1178 }
1179 let other_value = metadata2.get(key).unwrap();
1180
1181 if let (MetadataValue::Float(value), MetadataValue::Float(other_value)) =
1182 (value, other_value)
1183 {
1184 if (value - other_value).abs() > 1e-6 {
1185 return false;
1186 }
1187 } else if value != other_value {
1188 return false;
1189 }
1190 }
1191
1192 true
1193}
1194
1195impl TryFrom<chroma_proto::UpdateMetadata> for UpdateMetadata {
1196 type Error = UpdateMetadataValueConversionError;
1197
1198 fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
1199 let mut metadata = UpdateMetadata::with_capacity(proto_metadata.metadata.len());
1200 for (key, value) in proto_metadata.metadata.into_iter() {
1201 let value = match (&value).try_into() {
1202 Ok(value) => value,
1203 Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue),
1204 };
1205 metadata.insert(key, value);
1206 }
1207 Ok(metadata)
1208 }
1209}
1210
1211impl From<UpdateMetadata> for chroma_proto::UpdateMetadata {
1212 fn from(metadata: UpdateMetadata) -> Self {
1213 let mut proto_metadata = chroma_proto::UpdateMetadata {
1214 metadata: HashMap::with_capacity(metadata.len()),
1215 };
1216 for (key, value) in metadata.into_iter() {
1217 let proto_value = value.into();
1218 proto_metadata.metadata.insert(key, proto_value);
1219 }
1220 proto_metadata
1221 }
1222}
1223
/// Metadata keyed by name, with each value stored as a [`MetadataValue`].
pub type Metadata = HashMap<String, MetadataValue>;
/// Set of metadata keys.
// NOTE(review): presumably the keys whose metadata entries were deleted — confirm against callers.
pub type DeletedMetadata = HashSet<String>;
1232
1233pub fn logical_size_of_metadata(metadata: &Metadata) -> usize {
1234 metadata
1235 .iter()
1236 .map(|(k, v)| {
1237 k.len()
1238 + match v {
1239 MetadataValue::Bool(b) => size_of_val(b),
1240 MetadataValue::Int(i) => size_of_val(i),
1241 MetadataValue::Float(f) => size_of_val(f),
1242 MetadataValue::Str(s) => s.len(),
1243 MetadataValue::SparseVector(v) => {
1244 size_of_val(&v.indices[..]) + size_of_val(&v.values[..])
1245 }
1246 MetadataValue::BoolArray(arr) => size_of_val(&arr[..]),
1247 MetadataValue::IntArray(arr) => size_of_val(&arr[..]),
1248 MetadataValue::FloatArray(arr) => size_of_val(&arr[..]),
1249 MetadataValue::StringArray(arr) => arr.iter().map(|s| s.len()).sum::<usize>(),
1250 }
1251 })
1252 .sum()
1253}
1254
1255pub fn get_metadata_value_as<'a, T>(
1256 metadata: &'a Metadata,
1257 key: &str,
1258) -> Result<T, Box<MetadataValueConversionError>>
1259where
1260 T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
1261{
1262 let res = match metadata.get(key) {
1263 Some(value) => T::try_from(value),
1264 None => return Err(Box::new(MetadataValueConversionError::InvalidValue)),
1265 };
1266 match res {
1267 Ok(value) => Ok(value),
1268 Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)),
1269 }
1270}
1271
1272impl TryFrom<chroma_proto::UpdateMetadata> for Metadata {
1273 type Error = MetadataValueConversionError;
1274
1275 fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
1276 let mut metadata = Metadata::new();
1277 for (key, value) in proto_metadata.metadata.iter() {
1278 let maybe_value: Result<MetadataValue, Self::Error> = value.try_into();
1279 if maybe_value.is_err() {
1280 return Err(MetadataValueConversionError::InvalidValue);
1281 }
1282 let value = maybe_value.unwrap();
1283 metadata.insert(key.clone(), value);
1284 }
1285 Ok(metadata)
1286 }
1287}
1288
1289impl From<Metadata> for chroma_proto::UpdateMetadata {
1290 fn from(metadata: Metadata) -> Self {
1291 let mut metadata = metadata;
1292 let mut proto_metadata = chroma_proto::UpdateMetadata {
1293 metadata: HashMap::new(),
1294 };
1295 for (key, value) in metadata.drain() {
1296 let proto_value = value.into();
1297 proto_metadata.metadata.insert(key.clone(), proto_value);
1298 }
1299 proto_metadata
1300 }
1301}
1302
/// Borrowed description of metadata changes, grouped into three maps keyed by
/// metadata key: entries to update, entries to delete and entries to insert.
///
/// All keys and values are references with the `'referred_data` lifetime; the
/// struct owns only the maps themselves.
#[derive(Debug, Default)]
pub struct MetadataDelta<'referred_data> {
    // Pair of values per updated key.
    // NOTE(review): presumably (old value, new value) — confirm against the code that fills it.
    pub metadata_to_update: HashMap<
        &'referred_data str,
        (&'referred_data MetadataValue, &'referred_data MetadataValue),
    >,
    // Value associated with each key to delete.
    pub metadata_to_delete: HashMap<&'referred_data str, &'referred_data MetadataValue>,
    // Value associated with each key to insert.
    pub metadata_to_insert: HashMap<&'referred_data str, &'referred_data MetadataValue>,
}
1312
1313impl MetadataDelta<'_> {
1314 pub fn new() -> Self {
1315 Self::default()
1316 }
1317}
1318
/// Error produced while converting a `Where` filter to or from another representation.
///
/// The error forms a chain: `Cause` holds the root message and renders as
/// `Error: <message>`, while `Trace` wraps an inner error with a context label
/// and renders as `<context> -> <inner>`.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum WhereConversionError {
    // Root of the chain: the message describing what went wrong.
    #[error("Error: {0}")]
    Cause(String),
    // A context label plus the error it wraps.
    #[error("{0} -> {1}")]
    Trace(String, Box<Self>),
}
1332
1333impl WhereConversionError {
1334 pub fn cause(msg: impl ToString) -> Self {
1335 Self::Cause(msg.to_string())
1336 }
1337
1338 pub fn trace(self, context: impl ToString) -> Self {
1339 Self::Trace(context.to_string(), Box::new(self))
1340 }
1341}
1342
/// A filter expression tree.
///
/// A node is either a boolean combination of child filters (`Composite`), a
/// predicate on the document text (`Document`), or a predicate on a single
/// metadata key (`Metadata`).
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Where {
    // `And` / `Or` over a list of child filters.
    Composite(CompositeExpression),
    // Operator plus pattern applied to the document.
    Document(DocumentExpression),
    // Key plus comparison applied to a metadata value.
    Metadata(MetadataExpression),
}
1357
1358impl std::fmt::Display for Where {
1359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1360 match self {
1361 Where::Composite(composite) => {
1362 let fragment = composite
1363 .children
1364 .iter()
1365 .map(|child| format!("{}", child))
1366 .collect::<Vec<_>>()
1367 .join(match composite.operator {
1368 BooleanOperator::And => " & ",
1369 BooleanOperator::Or => " | ",
1370 });
1371 write!(f, "({})", fragment)
1372 }
1373 Where::Metadata(expr) => write!(f, "{}", expr),
1374 Where::Document(expr) => write!(f, "{}", expr),
1375 }
1376 }
1377}
1378
1379impl serde::Serialize for Where {
1380 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1381 where
1382 S: Serializer,
1383 {
1384 match self {
1385 Where::Composite(composite) => {
1386 let mut map = serializer.serialize_map(Some(1))?;
1387 let op_key = match composite.operator {
1388 BooleanOperator::And => "$and",
1389 BooleanOperator::Or => "$or",
1390 };
1391 map.serialize_entry(op_key, &composite.children)?;
1392 map.end()
1393 }
1394 Where::Document(doc) => {
1395 let mut outer_map = serializer.serialize_map(Some(1))?;
1396 let mut inner_map = serde_json::Map::new();
1397 let op_key = match doc.operator {
1398 DocumentOperator::Contains => "$contains",
1399 DocumentOperator::NotContains => "$not_contains",
1400 DocumentOperator::Regex => "$regex",
1401 DocumentOperator::NotRegex => "$not_regex",
1402 };
1403 inner_map.insert(
1404 op_key.to_string(),
1405 serde_json::Value::String(doc.pattern.clone()),
1406 );
1407 outer_map.serialize_entry("#document", &inner_map)?;
1408 outer_map.end()
1409 }
1410 Where::Metadata(meta) => {
1411 let mut outer_map = serializer.serialize_map(Some(1))?;
1412 let mut inner_map = serde_json::Map::new();
1413
1414 match &meta.comparison {
1415 MetadataComparison::Primitive(op, value) => {
1416 let op_key = match op {
1417 PrimitiveOperator::Equal => "$eq",
1418 PrimitiveOperator::NotEqual => "$ne",
1419 PrimitiveOperator::GreaterThan => "$gt",
1420 PrimitiveOperator::GreaterThanOrEqual => "$gte",
1421 PrimitiveOperator::LessThan => "$lt",
1422 PrimitiveOperator::LessThanOrEqual => "$lte",
1423 };
1424 let value_json =
1425 serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
1426 inner_map.insert(op_key.to_string(), value_json);
1427 }
1428 MetadataComparison::Set(op, set_value) => {
1429 let op_key = match op {
1430 SetOperator::In => "$in",
1431 SetOperator::NotIn => "$nin",
1432 };
1433 let values_json = match set_value {
1434 MetadataSetValue::Bool(v) => serde_json::to_value(v),
1435 MetadataSetValue::Int(v) => serde_json::to_value(v),
1436 MetadataSetValue::Float(v) => serde_json::to_value(v),
1437 MetadataSetValue::Str(v) => serde_json::to_value(v),
1438 }
1439 .map_err(serde::ser::Error::custom)?;
1440 inner_map.insert(op_key.to_string(), values_json);
1441 }
1442 MetadataComparison::ArrayContains(op, value) => {
1443 let op_key = match op {
1444 ContainsOperator::Contains => "$contains",
1445 ContainsOperator::NotContains => "$not_contains",
1446 };
1447 let value_json =
1448 serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
1449 inner_map.insert(op_key.to_string(), value_json);
1450 }
1451 }
1452
1453 outer_map.serialize_entry(&meta.key, &inner_map)?;
1454 outer_map.end()
1455 }
1456 }
1457 }
1458}
1459
1460impl From<bool> for Where {
1461 fn from(value: bool) -> Self {
1462 if value {
1463 Where::conjunction(vec![])
1464 } else {
1465 Where::disjunction(vec![])
1466 }
1467 }
1468}
1469
1470impl Where {
1471 pub fn conjunction(children: impl IntoIterator<Item = Where>) -> Self {
1472 let mut children: Vec<_> = children
1477 .into_iter()
1478 .flat_map(|expr| {
1479 if let Where::Composite(CompositeExpression {
1480 operator: BooleanOperator::And,
1481 children,
1482 }) = expr
1483 {
1484 return children;
1485 }
1486 vec![expr]
1487 })
1488 .dedup()
1489 .collect();
1490
1491 if children.len() == 1 {
1492 return children.pop().expect("just checked len is 1");
1493 }
1494
1495 Self::Composite(CompositeExpression {
1496 operator: BooleanOperator::And,
1497 children,
1498 })
1499 }
1500 pub fn disjunction(children: impl IntoIterator<Item = Where>) -> Self {
1501 let mut children: Vec<_> = children
1506 .into_iter()
1507 .flat_map(|expr| {
1508 if let Where::Composite(CompositeExpression {
1509 operator: BooleanOperator::Or,
1510 children,
1511 }) = expr
1512 {
1513 return children;
1514 }
1515 vec![expr]
1516 })
1517 .dedup()
1518 .collect();
1519
1520 if children.len() == 1 {
1521 return children.pop().expect("just checked len is 1");
1522 }
1523
1524 Self::Composite(CompositeExpression {
1525 operator: BooleanOperator::Or,
1526 children,
1527 })
1528 }
1529
1530 pub fn fts_query_length(&self) -> u64 {
1531 match self {
1532 Where::Composite(composite_expression) => composite_expression
1533 .children
1534 .iter()
1535 .map(Where::fts_query_length)
1536 .sum(),
1537 Where::Document(document_expression) => {
1539 document_expression.pattern.len().max(3) as u64 - 2
1540 }
1541 Where::Metadata(_) => 0,
1542 }
1543 }
1544
1545 pub fn metadata_predicate_count(&self) -> u64 {
1546 match self {
1547 Where::Composite(composite_expression) => composite_expression
1548 .children
1549 .iter()
1550 .map(Where::metadata_predicate_count)
1551 .sum(),
1552 Where::Document(_) => 0,
1553 Where::Metadata(metadata_expression) => match &metadata_expression.comparison {
1554 MetadataComparison::Primitive(_, _) => 1,
1555 MetadataComparison::Set(_, metadata_set_value) => match metadata_set_value {
1556 MetadataSetValue::Bool(items) => items.len() as u64,
1557 MetadataSetValue::Int(items) => items.len() as u64,
1558 MetadataSetValue::Float(items) => items.len() as u64,
1559 MetadataSetValue::Str(items) => items.len() as u64,
1560 },
1561 MetadataComparison::ArrayContains(_, _) => 1,
1562 },
1563 }
1564 }
1565}
1566
1567impl BitAnd for Where {
1568 type Output = Where;
1569
1570 fn bitand(self, rhs: Self) -> Self::Output {
1571 Self::conjunction([self, rhs])
1572 }
1573}
1574
1575impl BitOr for Where {
1576 type Output = Where;
1577
1578 fn bitor(self, rhs: Self) -> Self::Output {
1579 Self::disjunction([self, rhs])
1580 }
1581}
1582
1583impl TryFrom<chroma_proto::Where> for Where {
1584 type Error = WhereConversionError;
1585
1586 fn try_from(proto_where: chroma_proto::Where) -> Result<Self, Self::Error> {
1587 let where_inner = proto_where
1588 .r#where
1589 .ok_or(WhereConversionError::cause("Invalid Where"))?;
1590 Ok(match where_inner {
1591 chroma_proto::r#where::Where::DirectComparison(direct_comparison) => {
1592 Self::Metadata(direct_comparison.try_into()?)
1593 }
1594 chroma_proto::r#where::Where::Children(where_children) => {
1595 Self::Composite(where_children.try_into()?)
1596 }
1597 chroma_proto::r#where::Where::DirectDocumentComparison(direct_where_document) => {
1598 Self::Document(direct_where_document.into())
1599 }
1600 })
1601 }
1602}
1603
1604impl TryFrom<Where> for chroma_proto::Where {
1605 type Error = WhereConversionError;
1606
1607 fn try_from(value: Where) -> Result<Self, Self::Error> {
1608 let proto_where = match value {
1609 Where::Composite(composite_expression) => {
1610 chroma_proto::r#where::Where::Children(composite_expression.try_into()?)
1611 }
1612 Where::Document(document_expression) => {
1613 chroma_proto::r#where::Where::DirectDocumentComparison(document_expression.into())
1614 }
1615 Where::Metadata(metadata_expression) => chroma_proto::r#where::Where::DirectComparison(
1616 chroma_proto::DirectComparison::try_from(metadata_expression)
1617 .map_err(|err| err.trace("MetadataExpression"))?,
1618 ),
1619 };
1620 Ok(Self {
1621 r#where: Some(proto_where),
1622 })
1623 }
1624}
1625
/// A boolean combination of filters: `operator` applied across `children`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct CompositeExpression {
    // Whether the children are combined with `And` or `Or`.
    pub operator: BooleanOperator,
    // The filters being combined; may be empty.
    pub children: Vec<Where>,
}
1632
1633impl TryFrom<chroma_proto::WhereChildren> for CompositeExpression {
1634 type Error = WhereConversionError;
1635
1636 fn try_from(proto_children: chroma_proto::WhereChildren) -> Result<Self, Self::Error> {
1637 let operator = proto_children.operator().into();
1638 let children = proto_children
1639 .children
1640 .into_iter()
1641 .map(Where::try_from)
1642 .collect::<Result<Vec<_>, _>>()
1643 .map_err(|err| err.trace("Child Where of CompositeExpression"))?;
1644 Ok(Self { operator, children })
1645 }
1646}
1647
1648impl TryFrom<CompositeExpression> for chroma_proto::WhereChildren {
1649 type Error = WhereConversionError;
1650
1651 fn try_from(value: CompositeExpression) -> Result<Self, Self::Error> {
1652 Ok(Self {
1653 operator: chroma_proto::BooleanOperator::from(value.operator) as i32,
1654 children: value
1655 .children
1656 .into_iter()
1657 .map(chroma_proto::Where::try_from)
1658 .collect::<Result<_, _>>()?,
1659 })
1660 }
1661}
1662
/// Boolean connective used by a [`CompositeExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum BooleanOperator {
    // Logical AND; serialized under the `$and` key.
    And,
    // Logical OR; serialized under the `$or` key.
    Or,
}
1669
1670impl From<chroma_proto::BooleanOperator> for BooleanOperator {
1671 fn from(value: chroma_proto::BooleanOperator) -> Self {
1672 match value {
1673 chroma_proto::BooleanOperator::And => Self::And,
1674 chroma_proto::BooleanOperator::Or => Self::Or,
1675 }
1676 }
1677}
1678
1679impl From<BooleanOperator> for chroma_proto::BooleanOperator {
1680 fn from(value: BooleanOperator) -> Self {
1681 match value {
1682 BooleanOperator::And => Self::And,
1683 BooleanOperator::Or => Self::Or,
1684 }
1685 }
1686}
1687
/// A predicate on the document text: `operator` applied with `pattern`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct DocumentExpression {
    // How the pattern is matched (contains / regex, or their negations).
    pub operator: DocumentOperator,
    // The text or regular expression to match with.
    pub pattern: String,
}
1694
1695impl std::fmt::Display for DocumentExpression {
1696 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1697 let op_str = match self.operator {
1698 DocumentOperator::Contains => "CONTAINS",
1699 DocumentOperator::NotContains => "NOT CONTAINS",
1700 DocumentOperator::Regex => "REGEX",
1701 DocumentOperator::NotRegex => "NOT REGEX",
1702 };
1703 write!(f, "#document {} \"{}\"", op_str, self.pattern)
1704 }
1705}
1706
1707impl From<chroma_proto::DirectWhereDocument> for DocumentExpression {
1708 fn from(value: chroma_proto::DirectWhereDocument) -> Self {
1709 Self {
1710 operator: value.operator().into(),
1711 pattern: value.pattern,
1712 }
1713 }
1714}
1715
1716impl From<DocumentExpression> for chroma_proto::DirectWhereDocument {
1717 fn from(value: DocumentExpression) -> Self {
1718 Self {
1719 pattern: value.pattern,
1720 operator: chroma_proto::WhereDocumentOperator::from(value.operator) as i32,
1721 }
1722 }
1723}
1724
/// Matching mode of a [`DocumentExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum DocumentOperator {
    // Serialized as `$contains`.
    Contains,
    // Serialized as `$not_contains`.
    NotContains,
    // Serialized as `$regex`.
    Regex,
    // Serialized as `$not_regex`.
    NotRegex,
}
1733impl From<chroma_proto::WhereDocumentOperator> for DocumentOperator {
1734 fn from(value: chroma_proto::WhereDocumentOperator) -> Self {
1735 match value {
1736 chroma_proto::WhereDocumentOperator::Contains => Self::Contains,
1737 chroma_proto::WhereDocumentOperator::NotContains => Self::NotContains,
1738 chroma_proto::WhereDocumentOperator::Regex => Self::Regex,
1739 chroma_proto::WhereDocumentOperator::NotRegex => Self::NotRegex,
1740 }
1741 }
1742}
1743
1744impl From<DocumentOperator> for chroma_proto::WhereDocumentOperator {
1745 fn from(value: DocumentOperator) -> Self {
1746 match value {
1747 DocumentOperator::Contains => Self::Contains,
1748 DocumentOperator::NotContains => Self::NotContains,
1749 DocumentOperator::Regex => Self::Regex,
1750 DocumentOperator::NotRegex => Self::NotRegex,
1751 }
1752 }
1753}
1754
/// A predicate on one metadata entry: `comparison` applied to the value stored under `key`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct MetadataExpression {
    // Name of the metadata entry being tested.
    pub key: String,
    // Operator and operand(s) of the test.
    pub comparison: MetadataComparison,
}
1761
1762impl std::fmt::Display for MetadataExpression {
1763 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1764 match &self.comparison {
1765 MetadataComparison::Primitive(op, value) => {
1766 write!(f, "{} {} {}", self.key, op, value)
1767 }
1768 MetadataComparison::Set(op, set_value) => {
1769 write!(f, "{} {} {}", self.key, op, set_value)
1770 }
1771 MetadataComparison::ArrayContains(op, value) => {
1772 write!(f, "{} {} {}", self.key, op, value)
1773 }
1774 }
1775 }
1776}
1777
1778fn generic_comparator_to_metadata_comparison(
1782 comparator: chroma_proto::GenericComparator,
1783 value: MetadataValue,
1784) -> MetadataComparison {
1785 match comparator {
1786 chroma_proto::GenericComparator::Eq | chroma_proto::GenericComparator::Ne => {
1787 MetadataComparison::Primitive(comparator.try_into().unwrap(), value)
1790 }
1791 chroma_proto::GenericComparator::ArrayContains => {
1792 MetadataComparison::ArrayContains(ContainsOperator::Contains, value)
1793 }
1794 chroma_proto::GenericComparator::ArrayNotContains => {
1795 MetadataComparison::ArrayContains(ContainsOperator::NotContains, value)
1796 }
1797 }
1798}
1799
1800impl TryFrom<chroma_proto::DirectComparison> for MetadataExpression {
1801 type Error = WhereConversionError;
1802
1803 fn try_from(value: chroma_proto::DirectComparison) -> Result<Self, Self::Error> {
1804 let proto_comparison = value
1805 .comparison
1806 .ok_or(WhereConversionError::cause("Invalid MetadataExpression"))?;
1807 let comparison = match proto_comparison {
1808 chroma_proto::direct_comparison::Comparison::SingleStringOperand(
1809 single_string_comparison,
1810 ) => generic_comparator_to_metadata_comparison(
1811 single_string_comparison.comparator(),
1812 MetadataValue::Str(single_string_comparison.value),
1813 ),
1814 chroma_proto::direct_comparison::Comparison::StringListOperand(
1815 string_list_comparison,
1816 ) => MetadataComparison::Set(
1817 string_list_comparison.list_operator().into(),
1818 MetadataSetValue::Str(string_list_comparison.values),
1819 ),
1820 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1821 single_int_comparison,
1822 ) => {
1823 let comparator =
1824 single_int_comparison
1825 .comparator
1826 .ok_or(WhereConversionError::cause(
1827 "Invalid scalar integer operator",
1828 ))?;
1829 let value = MetadataValue::Int(single_int_comparison.value);
1830 match comparator {
1831 chroma_proto::single_int_comparison::Comparator::GenericComparator(op) => {
1832 let generic = chroma_proto::GenericComparator::try_from(op)
1833 .map_err(WhereConversionError::cause)?;
1834 generic_comparator_to_metadata_comparison(generic, value)
1835 }
1836 chroma_proto::single_int_comparison::Comparator::NumberComparator(op) => {
1837 MetadataComparison::Primitive(
1838 chroma_proto::NumberComparator::try_from(op)
1839 .map_err(WhereConversionError::cause)?
1840 .into(),
1841 value,
1842 )
1843 }
1844 }
1845 }
1846 chroma_proto::direct_comparison::Comparison::IntListOperand(int_list_comparison) => {
1847 MetadataComparison::Set(
1848 int_list_comparison.list_operator().into(),
1849 MetadataSetValue::Int(int_list_comparison.values),
1850 )
1851 }
1852 chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(
1853 single_double_comparison,
1854 ) => {
1855 let comparator = single_double_comparison
1856 .comparator
1857 .ok_or(WhereConversionError::cause("Invalid scalar float operator"))?;
1858 let value = MetadataValue::Float(single_double_comparison.value);
1859 match comparator {
1860 chroma_proto::single_double_comparison::Comparator::GenericComparator(op) => {
1861 let generic = chroma_proto::GenericComparator::try_from(op)
1862 .map_err(WhereConversionError::cause)?;
1863 generic_comparator_to_metadata_comparison(generic, value)
1864 }
1865 chroma_proto::single_double_comparison::Comparator::NumberComparator(op) => {
1866 MetadataComparison::Primitive(
1867 chroma_proto::NumberComparator::try_from(op)
1868 .map_err(WhereConversionError::cause)?
1869 .into(),
1870 value,
1871 )
1872 }
1873 }
1874 }
1875 chroma_proto::direct_comparison::Comparison::DoubleListOperand(
1876 double_list_comparison,
1877 ) => MetadataComparison::Set(
1878 double_list_comparison.list_operator().into(),
1879 MetadataSetValue::Float(double_list_comparison.values),
1880 ),
1881 chroma_proto::direct_comparison::Comparison::BoolListOperand(bool_list_comparison) => {
1882 MetadataComparison::Set(
1883 bool_list_comparison.list_operator().into(),
1884 MetadataSetValue::Bool(bool_list_comparison.values),
1885 )
1886 }
1887 chroma_proto::direct_comparison::Comparison::SingleBoolOperand(
1888 single_bool_comparison,
1889 ) => generic_comparator_to_metadata_comparison(
1890 single_bool_comparison.comparator(),
1891 MetadataValue::Bool(single_bool_comparison.value),
1892 ),
1893 };
1894 Ok(Self {
1895 key: value.key,
1896 comparison,
1897 })
1898 }
1899}
1900
1901impl TryFrom<MetadataExpression> for chroma_proto::DirectComparison {
1902 type Error = WhereConversionError;
1903
1904 fn try_from(value: MetadataExpression) -> Result<Self, Self::Error> {
1905 let comparison = match value.comparison {
1906 MetadataComparison::Primitive(primitive_operator, metadata_value) => match metadata_value {
1907 MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1908 MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(match primitive_operator {
1909 generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1910 numeric => chroma_proto::single_int_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1911 }),
1912 MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(match primitive_operator {
1913 generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_double_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1914 numeric => chroma_proto::single_double_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1915 }),
1916 MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1917 MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Comparison with sparse vector is not supported".to_string())),
1918 MetadataValue::BoolArray(_) | MetadataValue::IntArray(_) | MetadataValue::FloatArray(_) | MetadataValue::StringArray(_) => {
1919 return Err(WhereConversionError::Cause("Primitive comparison with array metadata values is not supported".to_string()))
1920 }
1921 },
1922 MetadataComparison::Set(set_operator, metadata_set_value) => match metadata_set_value {
1923 MetadataSetValue::Bool(vec) => chroma_proto::direct_comparison::Comparison::BoolListOperand(chroma_proto::BoolListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1924 MetadataSetValue::Int(vec) => chroma_proto::direct_comparison::Comparison::IntListOperand(chroma_proto::IntListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1925 MetadataSetValue::Float(vec) => chroma_proto::direct_comparison::Comparison::DoubleListOperand(chroma_proto::DoubleListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1926 MetadataSetValue::Str(vec) => chroma_proto::direct_comparison::Comparison::StringListOperand(chroma_proto::StringListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1927 },
1928 MetadataComparison::ArrayContains(contains_operator, metadata_value) => {
1929 let comparator = chroma_proto::GenericComparator::from(contains_operator) as i32;
1930 match metadata_value {
1931 MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator }),
1932 MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(comparator)) }),
1933 MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(chroma_proto::single_double_comparison::Comparator::GenericComparator(comparator)) }),
1934 MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator }),
1935 MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Contains comparison with sparse vector is not supported".to_string())),
1936 MetadataValue::BoolArray(_) | MetadataValue::IntArray(_) | MetadataValue::FloatArray(_) | MetadataValue::StringArray(_) => {
1937 return Err(WhereConversionError::Cause("Contains comparison value must be a scalar, not an array".to_string()))
1938 }
1939 }
1940 },
1941 };
1942 Ok(Self {
1943 key: value.key,
1944 comparison: Some(comparison),
1945 })
1946 }
1947}
1948
/// The operator-and-operand part of a [`MetadataExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum MetadataComparison {
    // Equality / ordering operator applied with a single value.
    Primitive(PrimitiveOperator, MetadataValue),
    // Membership operator (`In` / `NotIn`) applied with a list of values.
    Set(SetOperator, MetadataSetValue),
    // Contains / not-contains operator applied with a single value.
    ArrayContains(ContainsOperator, MetadataValue),
}
1958
1959impl std::fmt::Display for MetadataComparison {
1960 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1961 match self {
1962 MetadataComparison::Primitive(op, val) => {
1963 let type_name = match val {
1964 MetadataValue::Bool(_) => "Bool",
1965 MetadataValue::Int(_) => "Int",
1966 MetadataValue::Float(_) => "Float",
1967 MetadataValue::Str(_) => "Str",
1968 MetadataValue::SparseVector(_) => "SparseVector",
1969 MetadataValue::BoolArray(_) => "BoolArray",
1970 MetadataValue::IntArray(_) => "IntArray",
1971 MetadataValue::FloatArray(_) => "FloatArray",
1972 MetadataValue::StringArray(_) => "StringArray",
1973 };
1974 write!(f, "Primitive({}, {})", op, type_name)
1975 }
1976 MetadataComparison::Set(op, val) => {
1977 let type_name = match val {
1978 MetadataSetValue::Bool(_) => "Bool",
1979 MetadataSetValue::Int(_) => "Int",
1980 MetadataSetValue::Float(_) => "Float",
1981 MetadataSetValue::Str(_) => "Str",
1982 };
1983 write!(f, "Set({}, {})", op, type_name)
1984 }
1985 MetadataComparison::ArrayContains(op, val) => {
1986 let type_name = match val {
1987 MetadataValue::Bool(_) => "Bool",
1988 MetadataValue::Int(_) => "Int",
1989 MetadataValue::Float(_) => "Float",
1990 MetadataValue::Str(_) => "Str",
1991 MetadataValue::SparseVector(_) => "SparseVector",
1992 MetadataValue::BoolArray(_) => "BoolArray",
1993 MetadataValue::IntArray(_) => "IntArray",
1994 MetadataValue::FloatArray(_) => "FloatArray",
1995 MetadataValue::StringArray(_) => "StringArray",
1996 };
1997 write!(f, "ArrayContains({}, {})", op, type_name)
1998 }
1999 }
2000 }
2001}
2002
/// Operator of a primitive (single-value) metadata comparison.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum PrimitiveOperator {
    // Displayed as `=`.
    Equal,
    // Displayed as `≠`.
    NotEqual,
    // Displayed as `>`.
    GreaterThan,
    // Displayed as `≥`.
    GreaterThanOrEqual,
    // Displayed as `<`.
    LessThan,
    // Displayed as `≤`.
    LessThanOrEqual,
}
2013
2014impl std::fmt::Display for PrimitiveOperator {
2015 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2016 let op_str = match self {
2017 PrimitiveOperator::Equal => "=",
2018 PrimitiveOperator::NotEqual => "≠",
2019 PrimitiveOperator::GreaterThan => ">",
2020 PrimitiveOperator::GreaterThanOrEqual => "≥",
2021 PrimitiveOperator::LessThan => "<",
2022 PrimitiveOperator::LessThanOrEqual => "≤",
2023 };
2024 write!(f, "{}", op_str)
2025 }
2026}
2027
2028impl TryFrom<chroma_proto::GenericComparator> for PrimitiveOperator {
2029 type Error = WhereConversionError;
2030
2031 fn try_from(value: chroma_proto::GenericComparator) -> Result<Self, Self::Error> {
2032 match value {
2033 chroma_proto::GenericComparator::Eq => Ok(Self::Equal),
2034 chroma_proto::GenericComparator::Ne => Ok(Self::NotEqual),
2035 chroma_proto::GenericComparator::ArrayContains
2036 | chroma_proto::GenericComparator::ArrayNotContains => {
2037 Err(WhereConversionError::cause(
2038 "ArrayContains/ArrayNotContains cannot be converted to PrimitiveOperator",
2039 ))
2040 }
2041 }
2042 }
2043}
2044
2045impl TryFrom<PrimitiveOperator> for chroma_proto::GenericComparator {
2046 type Error = WhereConversionError;
2047
2048 fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
2049 match value {
2050 PrimitiveOperator::Equal => Ok(Self::Eq),
2051 PrimitiveOperator::NotEqual => Ok(Self::Ne),
2052 op => Err(WhereConversionError::cause(format!("{op:?} ∉ [=, ≠]"))),
2053 }
2054 }
2055}
2056
2057impl From<chroma_proto::NumberComparator> for PrimitiveOperator {
2058 fn from(value: chroma_proto::NumberComparator) -> Self {
2059 match value {
2060 chroma_proto::NumberComparator::Gt => Self::GreaterThan,
2061 chroma_proto::NumberComparator::Gte => Self::GreaterThanOrEqual,
2062 chroma_proto::NumberComparator::Lt => Self::LessThan,
2063 chroma_proto::NumberComparator::Lte => Self::LessThanOrEqual,
2064 }
2065 }
2066}
2067
2068impl TryFrom<PrimitiveOperator> for chroma_proto::NumberComparator {
2069 type Error = WhereConversionError;
2070
2071 fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
2072 match value {
2073 PrimitiveOperator::GreaterThan => Ok(Self::Gt),
2074 PrimitiveOperator::GreaterThanOrEqual => Ok(Self::Gte),
2075 PrimitiveOperator::LessThan => Ok(Self::Lt),
2076 PrimitiveOperator::LessThanOrEqual => Ok(Self::Lte),
2077 op => Err(WhereConversionError::cause(format!(
2078 "{op:?} ∉ [≤, <, >, ≥]"
2079 ))),
2080 }
2081 }
2082}
2083
/// Membership operator comparing a metadata field against a set of values.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum SetOperator {
    /// Displayed as `∈`; converts to the protobuf `ListOperator::In`.
    In,
    /// Displayed as `∉`; converts to the protobuf `ListOperator::Nin`.
    NotIn,
}
2090
2091impl std::fmt::Display for SetOperator {
2092 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2093 let op_str = match self {
2094 SetOperator::In => "∈",
2095 SetOperator::NotIn => "∉",
2096 };
2097 write!(f, "{}", op_str)
2098 }
2099}
2100
2101impl From<chroma_proto::ListOperator> for SetOperator {
2102 fn from(value: chroma_proto::ListOperator) -> Self {
2103 match value {
2104 chroma_proto::ListOperator::In => Self::In,
2105 chroma_proto::ListOperator::Nin => Self::NotIn,
2106 }
2107 }
2108}
2109
2110impl From<SetOperator> for chroma_proto::ListOperator {
2111 fn from(value: SetOperator) -> Self {
2112 match value {
2113 SetOperator::In => Self::In,
2114 SetOperator::NotIn => Self::Nin,
2115 }
2116 }
2117}
2118
/// Operator testing whether an array-valued metadata field contains a value.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum ContainsOperator {
    /// Displayed as `contains`; converts to the protobuf `ArrayContains` comparator.
    Contains,
    /// Displayed as `not_contains`; converts to the protobuf `ArrayNotContains` comparator.
    NotContains,
}
2125
2126impl std::fmt::Display for ContainsOperator {
2127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2128 let op_str = match self {
2129 ContainsOperator::Contains => "contains",
2130 ContainsOperator::NotContains => "not_contains",
2131 };
2132 write!(f, "{}", op_str)
2133 }
2134}
2135
2136impl From<ContainsOperator> for chroma_proto::GenericComparator {
2137 fn from(value: ContainsOperator) -> Self {
2138 match value {
2139 ContainsOperator::Contains => Self::ArrayContains,
2140 ContainsOperator::NotContains => Self::ArrayNotContains,
2141 }
2142 }
2143}
2144
/// Homogeneous list of scalar values used as the right-hand side of a set comparison.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
pub enum MetadataSetValue {
    /// A list of booleans.
    Bool(Vec<bool>),
    /// A list of 64-bit signed integers.
    Int(Vec<i64>),
    /// A list of 64-bit floats.
    Float(Vec<f64>),
    /// A list of strings.
    Str(Vec<String>),
}
2153
2154impl std::fmt::Display for MetadataSetValue {
2155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2156 match self {
2157 MetadataSetValue::Bool(values) => {
2158 let values_str = values
2159 .iter()
2160 .map(|v| format!("\"{}\"", v))
2161 .collect::<Vec<_>>()
2162 .join(", ");
2163 write!(f, "[{}]", values_str)
2164 }
2165 MetadataSetValue::Int(values) => {
2166 let values_str = values
2167 .iter()
2168 .map(|v| v.to_string())
2169 .collect::<Vec<_>>()
2170 .join(", ");
2171 write!(f, "[{}]", values_str)
2172 }
2173 MetadataSetValue::Float(values) => {
2174 let values_str = values
2175 .iter()
2176 .map(|v| v.to_string())
2177 .collect::<Vec<_>>()
2178 .join(", ");
2179 write!(f, "[{}]", values_str)
2180 }
2181 MetadataSetValue::Str(values) => {
2182 let values_str = values
2183 .iter()
2184 .map(|v| format!("\"{}\"", v))
2185 .collect::<Vec<_>>()
2186 .join(", ");
2187 write!(f, "[{}]", values_str)
2188 }
2189 }
2190 }
2191}
2192
2193impl MetadataSetValue {
2194 pub fn value_type(&self) -> MetadataValueType {
2195 match self {
2196 MetadataSetValue::Bool(_) => MetadataValueType::Bool,
2197 MetadataSetValue::Int(_) => MetadataValueType::Int,
2198 MetadataSetValue::Float(_) => MetadataValueType::Float,
2199 MetadataSetValue::Str(_) => MetadataValueType::Str,
2200 }
2201 }
2202}
2203
2204impl From<Vec<bool>> for MetadataSetValue {
2205 fn from(values: Vec<bool>) -> Self {
2206 MetadataSetValue::Bool(values)
2207 }
2208}
2209
2210impl From<Vec<i64>> for MetadataSetValue {
2211 fn from(values: Vec<i64>) -> Self {
2212 MetadataSetValue::Int(values)
2213 }
2214}
2215
2216impl From<Vec<i32>> for MetadataSetValue {
2217 fn from(values: Vec<i32>) -> Self {
2218 MetadataSetValue::Int(values.into_iter().map(|v| v as i64).collect())
2219 }
2220}
2221
2222impl From<Vec<f64>> for MetadataSetValue {
2223 fn from(values: Vec<f64>) -> Self {
2224 MetadataSetValue::Float(values)
2225 }
2226}
2227
2228impl From<Vec<f32>> for MetadataSetValue {
2229 fn from(values: Vec<f32>) -> Self {
2230 MetadataSetValue::Float(values.into_iter().map(|v| v as f64).collect())
2231 }
2232}
2233
2234impl From<Vec<String>> for MetadataSetValue {
2235 fn from(values: Vec<String>) -> Self {
2236 MetadataSetValue::Str(values)
2237 }
2238}
2239
2240impl From<Vec<&str>> for MetadataSetValue {
2241 fn from(values: Vec<&str>) -> Self {
2242 MetadataSetValue::Str(values.into_iter().map(|s| s.to_string()).collect())
2243 }
2244}
2245
2246impl TryFrom<chroma_proto::WhereDocument> for Where {
2248 type Error = WhereConversionError;
2249
2250 fn try_from(proto_document: chroma_proto::WhereDocument) -> Result<Self, Self::Error> {
2251 match proto_document.r#where_document {
2252 Some(chroma_proto::where_document::WhereDocument::Direct(proto_comparison)) => {
2253 let operator = match TryInto::<chroma_proto::WhereDocumentOperator>::try_into(
2254 proto_comparison.operator,
2255 ) {
2256 Ok(operator) => operator,
2257 Err(_) => {
2258 return Err(WhereConversionError::cause(
2259 "[Deprecated] Invalid where document operator",
2260 ))
2261 }
2262 };
2263 let comparison = DocumentExpression {
2264 pattern: proto_comparison.pattern,
2265 operator: operator.into(),
2266 };
2267 Ok(Where::Document(comparison))
2268 }
2269 Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => {
2270 let operator = match TryInto::<chroma_proto::BooleanOperator>::try_into(
2271 proto_children.operator,
2272 ) {
2273 Ok(operator) => operator,
2274 Err(_) => {
2275 return Err(WhereConversionError::cause(
2276 "[Deprecated] Invalid boolean operator",
2277 ))
2278 }
2279 };
2280 let children = CompositeExpression {
2281 children: proto_children
2282 .children
2283 .into_iter()
2284 .map(|child| child.try_into())
2285 .collect::<Result<_, _>>()?,
2286 operator: operator.into(),
2287 };
2288 Ok(Where::Composite(children))
2289 }
2290 None => Err(WhereConversionError::cause("[Deprecated] Invalid where")),
2291 }
2292 }
2293}
2294
2295#[cfg(test)]
2296mod tests {
2297 use crate::operator::Key;
2298
2299 use super::*;
2300
2301 #[cfg(feature = "pyo3")]
2303 fn ensure_python_interpreter() {
2304 static PYTHON_INIT: std::sync::Once = std::sync::Once::new();
2305 PYTHON_INIT.call_once(|| {
2306 pyo3::prepare_freethreaded_python();
2307 });
2308 }
2309
2310 #[test]
2311 fn test_update_metadata_try_from() {
2312 let mut proto_metadata = chroma_proto::UpdateMetadata {
2313 metadata: HashMap::new(),
2314 };
2315 proto_metadata.metadata.insert(
2316 "foo".to_string(),
2317 chroma_proto::UpdateMetadataValue {
2318 value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
2319 },
2320 );
2321 proto_metadata.metadata.insert(
2322 "bar".to_string(),
2323 chroma_proto::UpdateMetadataValue {
2324 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
2325 },
2326 );
2327 proto_metadata.metadata.insert(
2328 "baz".to_string(),
2329 chroma_proto::UpdateMetadataValue {
2330 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
2331 "42".to_string(),
2332 )),
2333 },
2334 );
2335 proto_metadata.metadata.insert(
2337 "sparse".to_string(),
2338 chroma_proto::UpdateMetadataValue {
2339 value: Some(
2340 chroma_proto::update_metadata_value::Value::SparseVectorValue(
2341 chroma_proto::SparseVector {
2342 indices: vec![0, 5, 10],
2343 values: vec![0.1, 0.5, 0.9],
2344 tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
2345 },
2346 ),
2347 ),
2348 },
2349 );
2350 let converted_metadata: UpdateMetadata = proto_metadata.try_into().unwrap();
2351 assert_eq!(converted_metadata.len(), 4);
2352 assert_eq!(
2353 converted_metadata.get("foo").unwrap(),
2354 &UpdateMetadataValue::Int(42)
2355 );
2356 assert_eq!(
2357 converted_metadata.get("bar").unwrap(),
2358 &UpdateMetadataValue::Float(42.0)
2359 );
2360 assert_eq!(
2361 converted_metadata.get("baz").unwrap(),
2362 &UpdateMetadataValue::Str("42".to_string())
2363 );
2364 assert_eq!(
2365 converted_metadata.get("sparse").unwrap(),
2366 &UpdateMetadataValue::SparseVector(
2367 SparseVector::new_with_tokens(
2368 vec![0, 5, 10],
2369 vec![0.1, 0.5, 0.9],
2370 vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),],
2371 )
2372 .unwrap()
2373 )
2374 );
2375 }
2376
2377 #[test]
2378 fn test_metadata_try_from() {
2379 let mut proto_metadata = chroma_proto::UpdateMetadata {
2380 metadata: HashMap::new(),
2381 };
2382 proto_metadata.metadata.insert(
2383 "foo".to_string(),
2384 chroma_proto::UpdateMetadataValue {
2385 value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
2386 },
2387 );
2388 proto_metadata.metadata.insert(
2389 "bar".to_string(),
2390 chroma_proto::UpdateMetadataValue {
2391 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
2392 },
2393 );
2394 proto_metadata.metadata.insert(
2395 "baz".to_string(),
2396 chroma_proto::UpdateMetadataValue {
2397 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
2398 "42".to_string(),
2399 )),
2400 },
2401 );
2402 proto_metadata.metadata.insert(
2404 "sparse".to_string(),
2405 chroma_proto::UpdateMetadataValue {
2406 value: Some(
2407 chroma_proto::update_metadata_value::Value::SparseVectorValue(
2408 chroma_proto::SparseVector {
2409 indices: vec![1, 10, 100],
2410 values: vec![0.2, 0.4, 0.6],
2411 tokens: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
2412 },
2413 ),
2414 ),
2415 },
2416 );
2417 let converted_metadata: Metadata = proto_metadata.try_into().unwrap();
2418 assert_eq!(converted_metadata.len(), 4);
2419 assert_eq!(
2420 converted_metadata.get("foo").unwrap(),
2421 &MetadataValue::Int(42)
2422 );
2423 assert_eq!(
2424 converted_metadata.get("bar").unwrap(),
2425 &MetadataValue::Float(42.0)
2426 );
2427 assert_eq!(
2428 converted_metadata.get("baz").unwrap(),
2429 &MetadataValue::Str("42".to_string())
2430 );
2431 assert_eq!(
2432 converted_metadata.get("sparse").unwrap(),
2433 &MetadataValue::SparseVector(
2434 SparseVector::new_with_tokens(
2435 vec![1, 10, 100],
2436 vec![0.2, 0.4, 0.6],
2437 vec!["foo".to_string(), "bar".to_string(), "baz".to_string(),],
2438 )
2439 .unwrap()
2440 )
2441 );
2442 }
2443
2444 #[test]
2445 fn test_where_clause_simple_from() {
2446 let proto_where = chroma_proto::Where {
2447 r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2448 chroma_proto::DirectComparison {
2449 key: "foo".to_string(),
2450 comparison: Some(
2451 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2452 chroma_proto::SingleIntComparison {
2453 value: 42,
2454 comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2455 },
2456 ),
2457 ),
2458 },
2459 )),
2460 };
2461 let where_clause: Where = proto_where.try_into().unwrap();
2462 match where_clause {
2463 Where::Metadata(comparison) => {
2464 assert_eq!(comparison.key, "foo");
2465 match comparison.comparison {
2466 MetadataComparison::Primitive(_, value) => {
2467 assert_eq!(value, MetadataValue::Int(42));
2468 }
2469 _ => panic!("Invalid comparison type"),
2470 }
2471 }
2472 _ => panic!("Invalid where type"),
2473 }
2474 }
2475
2476 #[test]
2477 fn test_where_clause_with_children() {
2478 let proto_where = chroma_proto::Where {
2479 r#where: Some(chroma_proto::r#where::Where::Children(
2480 chroma_proto::WhereChildren {
2481 children: vec![
2482 chroma_proto::Where {
2483 r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2484 chroma_proto::DirectComparison {
2485 key: "foo".to_string(),
2486 comparison: Some(
2487 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2488 chroma_proto::SingleIntComparison {
2489 value: 42,
2490 comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2491 },
2492 ),
2493 ),
2494 },
2495 )),
2496 },
2497 chroma_proto::Where {
2498 r#where: Some(chroma_proto::r#where::Where::DirectComparison(
2499 chroma_proto::DirectComparison {
2500 key: "bar".to_string(),
2501 comparison: Some(
2502 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
2503 chroma_proto::SingleIntComparison {
2504 value: 42,
2505 comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
2506 },
2507 ),
2508 ),
2509 },
2510 )),
2511 },
2512 ],
2513 operator: chroma_proto::BooleanOperator::And.into(),
2514 },
2515 )),
2516 };
2517 let where_clause: Where = proto_where.try_into().unwrap();
2518 match where_clause {
2519 Where::Composite(children) => {
2520 assert_eq!(children.children.len(), 2);
2521 assert_eq!(children.operator, BooleanOperator::And);
2522 }
2523 _ => panic!("Invalid where type"),
2524 }
2525 }
2526
2527 #[test]
2528 fn test_where_document_simple() {
2529 let proto_where = chroma_proto::WhereDocument {
2530 r#where_document: Some(chroma_proto::where_document::WhereDocument::Direct(
2531 chroma_proto::DirectWhereDocument {
2532 pattern: "foo".to_string(),
2533 operator: chroma_proto::WhereDocumentOperator::Contains.into(),
2534 },
2535 )),
2536 };
2537 let where_document: Where = proto_where.try_into().unwrap();
2538 match where_document {
2539 Where::Document(comparison) => {
2540 assert_eq!(comparison.pattern, "foo");
2541 assert_eq!(comparison.operator, DocumentOperator::Contains);
2542 }
2543 _ => panic!("Invalid where document type"),
2544 }
2545 }
2546
2547 #[test]
2548 fn test_where_document_with_children() {
2549 let proto_where = chroma_proto::WhereDocument {
2550 r#where_document: Some(chroma_proto::where_document::WhereDocument::Children(
2551 chroma_proto::WhereDocumentChildren {
2552 children: vec![
2553 chroma_proto::WhereDocument {
2554 r#where_document: Some(
2555 chroma_proto::where_document::WhereDocument::Direct(
2556 chroma_proto::DirectWhereDocument {
2557 pattern: "foo".to_string(),
2558 operator: chroma_proto::WhereDocumentOperator::Contains
2559 .into(),
2560 },
2561 ),
2562 ),
2563 },
2564 chroma_proto::WhereDocument {
2565 r#where_document: Some(
2566 chroma_proto::where_document::WhereDocument::Direct(
2567 chroma_proto::DirectWhereDocument {
2568 pattern: "bar".to_string(),
2569 operator: chroma_proto::WhereDocumentOperator::Contains
2570 .into(),
2571 },
2572 ),
2573 ),
2574 },
2575 ],
2576 operator: chroma_proto::BooleanOperator::And.into(),
2577 },
2578 )),
2579 };
2580 let where_document: Where = proto_where.try_into().unwrap();
2581 match where_document {
2582 Where::Composite(children) => {
2583 assert_eq!(children.children.len(), 2);
2584 assert_eq!(children.operator, BooleanOperator::And);
2585 }
2586 _ => panic!("Invalid where document type"),
2587 }
2588 }
2589
2590 #[test]
2591 fn test_sparse_vector_new() {
2592 let indices = vec![0, 5, 10];
2593 let values = vec![0.1, 0.5, 0.9];
2594 let sparse = SparseVector::new(indices.clone(), values.clone()).unwrap();
2595 assert_eq!(sparse.indices, indices);
2596 assert_eq!(sparse.values, values);
2597 }
2598
2599 #[test]
2600 fn test_sparse_vector_from_pairs() {
2601 let pairs = vec![(0, 0.1), (5, 0.5), (10, 0.9)];
2602 let sparse = SparseVector::from_pairs(pairs.clone());
2603 assert_eq!(sparse.indices, vec![0, 5, 10]);
2604 assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
2605 }
2606
2607 #[test]
2608 fn test_sparse_vector_from_triples() {
2609 let triples = vec![
2610 ("foo".to_string(), 0, 0.1),
2611 ("bar".to_string(), 5, 0.5),
2612 ("baz".to_string(), 10, 0.9),
2613 ];
2614 let sparse = SparseVector::from_triples(triples.clone());
2615 assert_eq!(sparse.indices, vec![0, 5, 10]);
2616 assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
2617 }
2618
2619 #[test]
2620 fn test_sparse_vector_iter() {
2621 let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2622 let collected: Vec<(u32, f32)> = sparse.iter().collect();
2623 assert_eq!(collected, vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
2624 }
2625
2626 #[test]
2627 fn test_sparse_vector_ordering() {
2628 let sparse1 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]).unwrap();
2629 let sparse2 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]).unwrap();
2630 let sparse3 = SparseVector::new(vec![0, 6], vec![0.1, 0.5]).unwrap();
2631 let sparse4 = SparseVector::new(vec![0, 5], vec![0.1, 0.6]).unwrap();
2632
2633 assert_eq!(sparse1, sparse2);
2634 assert!(sparse1 < sparse3);
2635 assert!(sparse1 < sparse4);
2636 }
2637
2638 #[test]
2639 fn test_sparse_vector_proto_conversion() {
2640 let tokens = vec![
2641 "token1".to_string(),
2642 "token2".to_string(),
2643 "token3".to_string(),
2644 ];
2645 let sparse =
2646 SparseVector::new_with_tokens(vec![1, 10, 100], vec![0.2, 0.4, 0.6], tokens.clone())
2647 .unwrap();
2648 let proto: chroma_proto::SparseVector = sparse.clone().into();
2649 assert_eq!(proto.indices, vec![1, 10, 100]);
2650 assert_eq!(proto.values, vec![0.2, 0.4, 0.6]);
2651 assert_eq!(proto.tokens, tokens.clone());
2652
2653 let converted: SparseVector = proto.try_into().unwrap();
2654 assert_eq!(converted, sparse);
2655 assert_eq!(converted.tokens, Some(tokens));
2656 }
2657
2658 #[test]
2659 fn test_sparse_vector_proto_conversion_empty_tokens() {
2660 let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2661 let proto: chroma_proto::SparseVector = sparse.clone().into();
2662 assert_eq!(proto.indices, vec![0, 5, 10]);
2663 assert_eq!(proto.values, vec![0.1, 0.5, 0.9]);
2664 assert_eq!(proto.tokens, Vec::<String>::new());
2665
2666 let converted: SparseVector = proto.try_into().unwrap();
2667 assert_eq!(converted, sparse);
2668 assert_eq!(converted.tokens, None);
2669 }
2670
2671 #[test]
2672 fn test_sparse_vector_logical_size() {
2673 let metadata = Metadata::from([(
2674 "sparse".to_string(),
2675 MetadataValue::SparseVector(
2676 SparseVector::new(vec![0, 1, 2, 3, 4], vec![0.1, 0.2, 0.3, 0.4, 0.5]).unwrap(),
2677 ),
2678 )]);
2679
2680 let size = logical_size_of_metadata(&metadata);
2681 assert_eq!(size, 46);
2684 }
2685
2686 #[test]
2687 fn test_sparse_vector_validation() {
2688 let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap();
2690 assert!(sparse.validate().is_ok());
2691
2692 let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]);
2694 assert!(sparse.is_err());
2695 let result = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3])
2696 .unwrap()
2697 .validate();
2698 assert!(result.is_ok());
2699
2700 let sparse = SparseVector::new_with_tokens(
2702 vec![1, 2, 3],
2703 vec![0.1, 0.2, 0.3],
2704 vec!["a".to_string(), "b".to_string()],
2705 );
2706 assert!(sparse.is_err());
2707
2708 let sparse = SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]).unwrap();
2710 let result = sparse.validate();
2711 assert!(result.is_err());
2712 assert!(matches!(
2713 result.unwrap_err(),
2714 MetadataValueConversionError::SparseVectorIndicesNotSorted
2715 ));
2716
2717 let sparse = SparseVector::new(vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]).unwrap();
2719 let result = sparse.validate();
2720 assert!(result.is_err());
2721 assert!(matches!(
2722 result.unwrap_err(),
2723 MetadataValueConversionError::SparseVectorIndicesNotSorted
2724 ));
2725
2726 let sparse = SparseVector::new(vec![1, 3, 2], vec![0.1, 0.3, 0.2]).unwrap();
2728 let result = sparse.validate();
2729 assert!(result.is_err());
2730 assert!(matches!(
2731 result.unwrap_err(),
2732 MetadataValueConversionError::SparseVectorIndicesNotSorted
2733 ));
2734 }
2735
2736 #[test]
2737 fn test_sparse_vector_deserialize_old_format() {
2738 let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
2740 let sv: SparseVector = serde_json::from_str(json).unwrap();
2741 assert_eq!(sv.indices, vec![0, 1, 2]);
2742 assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2743 }
2744
2745 #[test]
2746 fn test_sparse_vector_deserialize_new_format() {
2747 let json =
2749 "{\"#type\": \"sparse_vector\", \"indices\": [0, 1, 2], \"values\": [1.0, 2.0, 3.0]}";
2750 let sv: SparseVector = serde_json::from_str(json).unwrap();
2751 assert_eq!(sv.indices, vec![0, 1, 2]);
2752 assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2753 }
2754
2755 #[test]
2756 fn test_sparse_vector_deserialize_new_format_field_order() {
2757 let json = "{\"indices\": [5, 10], \"#type\": \"sparse_vector\", \"values\": [0.5, 1.0]}";
2759 let sv: SparseVector = serde_json::from_str(json).unwrap();
2760 assert_eq!(sv.indices, vec![5, 10]);
2761 assert_eq!(sv.values, vec![0.5, 1.0]);
2762 }
2763
2764 #[test]
2765 fn test_sparse_vector_deserialize_wrong_type_tag() {
2766 let json = "{\"#type\": \"dense_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}";
2768 let result: Result<SparseVector, _> = serde_json::from_str(json);
2769 assert!(result.is_err());
2770 let err_msg = result.unwrap_err().to_string();
2771 assert!(err_msg.contains("sparse_vector"));
2772 }
2773
2774 #[test]
2775 fn test_sparse_vector_serialize_always_has_type() {
2776 let sv = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]).unwrap();
2778 let json = serde_json::to_value(&sv).unwrap();
2779
2780 assert_eq!(json["#type"], "sparse_vector");
2781 assert_eq!(json["indices"], serde_json::json!([0, 1, 2]));
2782 assert_eq!(json["values"], serde_json::json!([1.0, 2.0, 3.0]));
2783 }
2784
2785 #[test]
2786 fn test_sparse_vector_roundtrip_with_type() {
2787 let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]).unwrap();
2789 let json = serde_json::to_string(&original).unwrap();
2790
2791 assert!(json.contains("\"#type\":\"sparse_vector\""));
2793
2794 let deserialized: SparseVector = serde_json::from_str(&json).unwrap();
2795 assert_eq!(original, deserialized);
2796 }
2797
2798 #[test]
2799 fn test_sparse_vector_in_metadata_old_format() {
2800 let json = r#"{"key": "value", "sparse": {"indices": [0, 1], "values": [1.0, 2.0]}}"#;
2802 let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();
2803
2804 let sparse_value = &map["sparse"];
2805 let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
2806 assert_eq!(sv.indices, vec![0, 1]);
2807 assert_eq!(sv.values, vec![1.0, 2.0]);
2808 }
2809
2810 #[test]
2811 fn test_sparse_vector_in_metadata_new_format() {
2812 let json = "{\"key\": \"value\", \"sparse\": {\"#type\": \"sparse_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}}";
2814 let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();
2815
2816 let sparse_value = &map["sparse"];
2817 let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
2818 assert_eq!(sv.indices, vec![0, 1]);
2819 assert_eq!(sv.values, vec![1.0, 2.0]);
2820 }
2821
2822 #[test]
2823 fn test_sparse_vector_tokens_roundtrip_old_to_new() {
2824 let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
2826 let sv: SparseVector = serde_json::from_str(json).unwrap();
2827 assert_eq!(sv.indices, vec![0, 1, 2]);
2828 assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
2829 assert_eq!(sv.tokens, None);
2830
2831 let serialized = serde_json::to_value(&sv).unwrap();
2833 assert_eq!(serialized["#type"], "sparse_vector");
2834 assert_eq!(serialized["indices"], serde_json::json!([0, 1, 2]));
2835 assert_eq!(serialized["values"], serde_json::json!([1.0, 2.0, 3.0]));
2836 assert_eq!(serialized["tokens"], serde_json::Value::Null);
2837 }
2838
2839 #[test]
2840 fn test_sparse_vector_tokens_roundtrip_new_to_new() {
2841 let sv_with_tokens = SparseVector::new_with_tokens(
2843 vec![0, 1, 2],
2844 vec![1.0, 2.0, 3.0],
2845 vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
2846 )
2847 .unwrap();
2848
2849 let serialized = serde_json::to_string(&sv_with_tokens).unwrap();
2851 assert!(serialized.contains("\"#type\":\"sparse_vector\""));
2852 assert!(serialized.contains("\"tokens\""));
2853
2854 let deserialized: SparseVector = serde_json::from_str(&serialized).unwrap();
2856 assert_eq!(deserialized.indices, vec![0, 1, 2]);
2857 assert_eq!(deserialized.values, vec![1.0, 2.0, 3.0]);
2858 assert_eq!(
2859 deserialized.tokens,
2860 Some(vec![
2861 "foo".to_string(),
2862 "bar".to_string(),
2863 "baz".to_string()
2864 ])
2865 );
2866 }
2867
2868 #[test]
2869 fn test_sparse_vector_tokens_deserialize_with_tokens_field() {
2870 let json = r##"{"#type": "sparse_vector", "indices": [5, 10], "values": [0.5, 1.0], "tokens": ["token1", "token2"]}"##;
2872 let sv: SparseVector = serde_json::from_str(json).unwrap();
2873 assert_eq!(sv.indices, vec![5, 10]);
2874 assert_eq!(sv.values, vec![0.5, 1.0]);
2875 assert_eq!(
2876 sv.tokens,
2877 Some(vec!["token1".to_string(), "token2".to_string()])
2878 );
2879 }
2880
2881 #[test]
2882 fn test_sparse_vector_tokens_backward_compatibility() {
2883 let old_json = r#"{"indices": [1, 2], "values": [0.1, 0.2]}"#;
2885 let old_sv: SparseVector = serde_json::from_str(old_json).unwrap();
2886
2887 let new_json = r##"{"#type": "sparse_vector", "indices": [1, 2], "values": [0.1, 0.2], "tokens": ["a", "b"]}"##;
2889 let new_sv: SparseVector = serde_json::from_str(new_json).unwrap();
2890
2891 assert_eq!(old_sv.indices, new_sv.indices);
2893 assert_eq!(old_sv.values, new_sv.values);
2894
2895 assert_eq!(old_sv.tokens, None);
2897 assert_eq!(new_sv.tokens, Some(vec!["a".to_string(), "b".to_string()]));
2898 }
2899
2900 #[test]
2901 fn test_sparse_vector_from_triples_preserves_tokens() {
2902 let triples = vec![
2903 ("apple".to_string(), 10, 0.5),
2904 ("banana".to_string(), 20, 0.7),
2905 ("cherry".to_string(), 30, 0.9),
2906 ];
2907 let sv = SparseVector::from_triples(triples.clone());
2908
2909 assert_eq!(sv.indices, vec![10, 20, 30]);
2910 assert_eq!(sv.values, vec![0.5, 0.7, 0.9]);
2911 assert_eq!(
2912 sv.tokens,
2913 Some(vec![
2914 "apple".to_string(),
2915 "banana".to_string(),
2916 "cherry".to_string()
2917 ])
2918 );
2919
2920 let serialized = serde_json::to_string(&sv).unwrap();
2922 let deserialized: SparseVector = serde_json::from_str(&serialized).unwrap();
2923
2924 assert_eq!(deserialized.indices, sv.indices);
2925 assert_eq!(deserialized.values, sv.values);
2926 assert_eq!(deserialized.tokens, sv.tokens);
2927 }
2928
2929 #[cfg(feature = "pyo3")]
2930 #[test]
2931 fn test_sparse_vector_pyo3_roundtrip_with_tokens() {
2932 ensure_python_interpreter();
2933
2934 pyo3::Python::with_gil(|py| {
2935 use pyo3::types::PyDict;
2936 use pyo3::IntoPyObject;
2937
2938 let dict_in = PyDict::new(py);
2939 dict_in.set_item("indices", vec![0u32, 1, 2]).unwrap();
2940 dict_in
2941 .set_item("values", vec![0.1f32, 0.2f32, 0.3f32])
2942 .unwrap();
2943 dict_in
2944 .set_item("tokens", vec!["foo", "bar", "baz"])
2945 .unwrap();
2946
2947 let sparse: SparseVector = dict_in.clone().into_any().extract().unwrap();
2948 assert_eq!(sparse.indices, vec![0, 1, 2]);
2949 assert_eq!(sparse.values, vec![0.1, 0.2, 0.3]);
2950 assert_eq!(
2951 sparse.tokens,
2952 Some(vec![
2953 "foo".to_string(),
2954 "bar".to_string(),
2955 "baz".to_string()
2956 ])
2957 );
2958
2959 let py_obj = sparse.clone().into_pyobject(py).unwrap();
2960 let dict_out = py_obj.downcast::<PyDict>().unwrap();
2961 let tokens_obj = dict_out.get_item("tokens").unwrap();
2962 let tokens: Vec<String> = tokens_obj
2963 .expect("expected tokens key in Python dict")
2964 .extract()
2965 .unwrap();
2966 assert_eq!(
2967 tokens,
2968 vec!["foo".to_string(), "bar".to_string(), "baz".to_string()]
2969 );
2970 });
2971 }
2972
2973 #[cfg(feature = "pyo3")]
2974 #[test]
2975 fn test_sparse_vector_pyo3_roundtrip_without_tokens() {
2976 ensure_python_interpreter();
2977
2978 pyo3::Python::with_gil(|py| {
2979 use pyo3::types::PyDict;
2980 use pyo3::IntoPyObject;
2981
2982 let dict_in = PyDict::new(py);
2983 dict_in.set_item("indices", vec![5u32]).unwrap();
2984 dict_in.set_item("values", vec![1.5f32]).unwrap();
2985
2986 let sparse: SparseVector = dict_in.clone().into_any().extract().unwrap();
2987 assert_eq!(sparse.indices, vec![5]);
2988 assert_eq!(sparse.values, vec![1.5]);
2989 assert!(sparse.tokens.is_none());
2990
2991 let py_obj = sparse.into_pyobject(py).unwrap();
2992 let dict_out = py_obj.downcast::<PyDict>().unwrap();
2993 let tokens_obj = dict_out.get_item("tokens").unwrap();
2994 let tokens_value = tokens_obj.expect("expected tokens key in Python dict");
2995 assert!(
2996 tokens_value.is_none(),
2997 "expected tokens value in Python dict to be None"
2998 );
2999 });
3000 }
3001
3002 #[test]
3003 fn test_simplifies_identities() {
3004 let all: Where = true.into();
3005 assert_eq!(all.clone() & all.clone(), true.into());
3006 assert_eq!(all.clone() | all.clone(), true.into());
3007
3008 let foo = Key::field("foo").eq("bar");
3009 assert_eq!(foo.clone() & all.clone(), foo.clone());
3010 assert_eq!(all.clone() & foo.clone(), foo.clone());
3011
3012 let none: Where = false.into();
3013 assert_eq!(foo.clone() | none.clone(), foo.clone());
3014 assert_eq!(none | foo.clone(), foo);
3015 }
3016
3017 #[test]
3018 fn test_flattens() {
3019 let foo = Key::field("foo").eq("bar");
3020 let baz = Key::field("baz").eq("quux");
3021
3022 let and_nested = foo.clone() & (baz.clone() & foo.clone());
3023 assert_eq!(
3024 and_nested,
3025 Where::Composite(CompositeExpression {
3026 operator: BooleanOperator::And,
3027 children: vec![foo.clone(), baz.clone(), foo.clone()]
3028 })
3029 );
3030
3031 let or_nested = foo.clone() | (baz.clone() | foo.clone());
3032 assert_eq!(
3033 or_nested,
3034 Where::Composite(CompositeExpression {
3035 operator: BooleanOperator::Or,
3036 children: vec![foo.clone(), baz.clone(), foo.clone()]
3037 })
3038 );
3039 }
3040}