1use chroma_error::{ChromaError, ErrorCodes};
2use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
3use serde_json::{Number, Value};
4use sprs::CsVec;
5use std::{
6 cmp::Ordering,
7 collections::{HashMap, HashSet},
8 mem::size_of_val,
9 ops::{BitAnd, BitOr},
10};
11use thiserror::Error;
12
13use crate::chroma_proto;
14
15#[cfg(feature = "pyo3")]
16use pyo3::types::PyAnyMethods;
17
18#[cfg(feature = "testing")]
19use proptest::prelude::*;
20
/// Wire representation shared by the serde impls of [`SparseVector`].
///
/// On input the `#type` tag is optional; on output [`SparseVector`] always writes
/// `"sparse_vector"` into it.
#[derive(Serialize, Deserialize)]
struct SparseVectorSerdeHelper {
    // Serialized under the key "#type". With derived serde, fields are written in
    // declaration order, so this order is part of the wire format.
    #[serde(rename = "#type")]
    type_tag: Option<String>,
    indices: Vec<u32>,
    values: Vec<f32>,
}
28
/// A sparse vector stored as two parallel arrays: `indices` and `values`.
///
/// The fields are public and are not checked on construction. Call
/// [`SparseVector::validate`] to ensure the arrays have equal length and strictly
/// ascending indices.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SparseVector {
    /// Dimension indices of the stored entries.
    pub indices: Vec<u32>,
    /// Entry values, paired positionally with `indices`.
    pub values: Vec<f32>,
}
43
44impl<'de> Deserialize<'de> for SparseVector {
46 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47 where
48 D: Deserializer<'de>,
49 {
50 let helper = SparseVectorSerdeHelper::deserialize(deserializer)?;
51
52 if let Some(type_tag) = &helper.type_tag {
54 if type_tag != "sparse_vector" {
55 return Err(serde::de::Error::custom(format!(
56 "Expected #type='sparse_vector', got '{}'",
57 type_tag
58 )));
59 }
60 }
61
62 Ok(SparseVector {
63 indices: helper.indices,
64 values: helper.values,
65 })
66 }
67}
68
69impl Serialize for SparseVector {
71 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
72 where
73 S: Serializer,
74 {
75 let helper = SparseVectorSerdeHelper {
76 type_tag: Some("sparse_vector".to_string()),
77 indices: self.indices.clone(),
78 values: self.values.clone(),
79 };
80 helper.serialize(serializer)
81 }
82}
83
84impl SparseVector {
85 pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Self {
87 Self { indices, values }
88 }
89
90 pub fn from_pairs(pairs: impl IntoIterator<Item = (u32, f32)>) -> Self {
92 let (indices, values) = pairs.into_iter().unzip();
93 Self { indices, values }
94 }
95
96 pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
98 self.indices
99 .iter()
100 .copied()
101 .zip(self.values.iter().copied())
102 }
103
104 pub fn validate(&self) -> Result<(), MetadataValueConversionError> {
106 if self.indices.len() != self.values.len() {
108 return Err(MetadataValueConversionError::SparseVectorLengthMismatch);
109 }
110
111 for i in 1..self.indices.len() {
113 if self.indices[i] <= self.indices[i - 1] {
114 return Err(MetadataValueConversionError::SparseVectorIndicesNotSorted);
115 }
116 }
117
118 Ok(())
119 }
120}
121
// Marker impl that lets `SparseVector` be used where `Eq`/`Ord` are required.
// NOTE(review): `PartialEq` is derived, so `values` are compared with f32 `==`. A NaN entry
// is therefore not equal to itself, and 0.0 == -0.0 even though `total_cmp` orders them
// differently. Confirm this is acceptable for callers that rely on `Eq`.
impl Eq for SparseVector {}
123
124impl Ord for SparseVector {
125 fn cmp(&self, other: &Self) -> Ordering {
126 self.indices.cmp(&other.indices).then_with(|| {
127 for (a, b) in self.values.iter().zip(other.values.iter()) {
128 match a.total_cmp(b) {
129 Ordering::Equal => continue,
130 other => return other,
131 }
132 }
133 self.values.len().cmp(&other.values.len())
134 })
135 }
136}
137
138impl PartialOrd for SparseVector {
139 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
140 Some(self.cmp(other))
141 }
142}
143
144impl From<chroma_proto::SparseVector> for SparseVector {
145 fn from(proto: chroma_proto::SparseVector) -> Self {
146 SparseVector::new(proto.indices, proto.values)
147 }
148}
149
150impl From<SparseVector> for chroma_proto::SparseVector {
151 fn from(sparse: SparseVector) -> Self {
152 chroma_proto::SparseVector {
153 indices: sparse.indices,
154 values: sparse.values,
155 }
156 }
157}
158
159impl From<&SparseVector> for CsVec<f32> {
161 fn from(sparse: &SparseVector) -> Self {
162 let (indices, values) = sparse
163 .iter()
164 .map(|(index, value)| (index as usize, value))
165 .unzip();
166 CsVec::new(u32::MAX as usize, indices, values)
167 }
168}
169
170impl From<SparseVector> for CsVec<f32> {
171 fn from(sparse: SparseVector) -> Self {
172 (&sparse).into()
173 }
174}
175
#[cfg(feature = "pyo3")]
impl<'py> pyo3::IntoPyObject<'py> for SparseVector {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    /// Converts into a Python `dict` with keys `"indices"` and `"values"`, inserted in
    /// that order.
    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        use pyo3::types::PyDict;

        let SparseVector { indices, values } = self;
        let py_dict = PyDict::new(py);
        py_dict.set_item("indices", indices)?;
        py_dict.set_item("values", values)?;
        Ok(py_dict.into_any())
    }
}
190
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for SparseVector {
    /// Extracts from a Python `dict` whose `"indices"` and `"values"` items are
    /// extractable as `Vec<u32>` and `Vec<f32>` respectively.
    ///
    /// Errors from the downcast, the key lookups or the extractions are propagated. The
    /// result is not validated.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyDict;

        let py_dict = ob.downcast::<PyDict>()?;
        // Both keys are looked up before either value is extracted.
        let (raw_indices, raw_values) =
            (py_dict.get_item("indices")?, py_dict.get_item("values")?);

        Ok(SparseVector {
            indices: raw_indices.extract::<Vec<u32>>()?,
            values: raw_values.extract::<Vec<f32>>()?,
        })
    }
}
206
/// A metadata value carried by an update request.
///
/// The enum is `#[serde(untagged)]`: deserialization tries the variants in declaration
/// order and keeps the first that matches, so the variant order is significant.
/// `None` has no `MetadataValue` counterpart; converting it to one yields
/// `MetadataValueConversionError::InvalidValue`.
/// NOTE(review): `None` presumably marks a key for removal. Confirm with callers.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[serde(untagged)]
pub enum UpdateMetadataValue {
    Bool(bool),
    Int(i64),
    // Property tests draw floats from [-1e6, 1e6] as f32 and widen them to f64.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| UpdateMetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    // Excluded from property-test generation.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
    // Untagged unit variant: serialized as `null`.
    None,
}
226
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for UpdateMetadataValue {
    /// Converts a Python value by trying each of these in order; the first successful
    /// extraction wins:
    ///
    /// `None`, `bool`, `int` (as `i64`), `float`, `str`, then a sparse-vector dict.
    ///
    /// Anything else raises `TypeError`. Python `bool` is a subclass of `int`, so `bool`
    /// is checked first to keep `True`/`False` from being read as integers.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        if ob.is_none() {
            return Ok(UpdateMetadataValue::None);
        }
        if let Ok(flag) = ob.extract::<bool>() {
            return Ok(UpdateMetadataValue::Bool(flag));
        }
        if let Ok(number) = ob.extract::<i64>() {
            return Ok(UpdateMetadataValue::Int(number));
        }
        if let Ok(number) = ob.extract::<f64>() {
            return Ok(UpdateMetadataValue::Float(number));
        }
        if let Ok(text) = ob.extract::<String>() {
            return Ok(UpdateMetadataValue::Str(text));
        }
        if let Ok(sparse) = ob.extract::<SparseVector>() {
            return Ok(UpdateMetadataValue::SparseVector(sparse));
        }
        Err(pyo3::exceptions::PyTypeError::new_err(
            "Cannot convert Python object to UpdateMetadataValue",
        ))
    }
}
249
250impl From<bool> for UpdateMetadataValue {
251 fn from(b: bool) -> Self {
252 Self::Bool(b)
253 }
254}
255
256impl From<i64> for UpdateMetadataValue {
257 fn from(v: i64) -> Self {
258 Self::Int(v)
259 }
260}
261
262impl From<i32> for UpdateMetadataValue {
263 fn from(v: i32) -> Self {
264 Self::Int(v as i64)
265 }
266}
267
268impl From<f64> for UpdateMetadataValue {
269 fn from(v: f64) -> Self {
270 Self::Float(v)
271 }
272}
273
274impl From<f32> for UpdateMetadataValue {
275 fn from(v: f32) -> Self {
276 Self::Float(v as f64)
277 }
278}
279
280impl From<String> for UpdateMetadataValue {
281 fn from(v: String) -> Self {
282 Self::Str(v)
283 }
284}
285
286impl From<&str> for UpdateMetadataValue {
287 fn from(v: &str) -> Self {
288 Self::Str(v.to_string())
289 }
290}
291
292impl From<SparseVector> for UpdateMetadataValue {
293 fn from(v: SparseVector) -> Self {
294 Self::SparseVector(v)
295 }
296}
297
298#[derive(Error, Debug)]
299pub enum UpdateMetadataValueConversionError {
300 #[error("Invalid metadata value, valid values are: Int, Float, Str, Bool, None")]
301 InvalidValue,
302}
303
304impl ChromaError for UpdateMetadataValueConversionError {
305 fn code(&self) -> ErrorCodes {
306 match self {
307 UpdateMetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
308 }
309 }
310}
311
312impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue {
313 type Error = UpdateMetadataValueConversionError;
314
315 fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
316 match &value.value {
317 Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
318 Ok(UpdateMetadataValue::Bool(*value))
319 }
320 Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
321 Ok(UpdateMetadataValue::Int(*value))
322 }
323 Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
324 Ok(UpdateMetadataValue::Float(*value))
325 }
326 Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
327 Ok(UpdateMetadataValue::Str(value.clone()))
328 }
329 Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
330 Ok(UpdateMetadataValue::SparseVector(value.clone().into()))
331 }
332 None => Ok(UpdateMetadataValue::None),
334 }
335 }
336}
337
338impl From<UpdateMetadataValue> for chroma_proto::UpdateMetadataValue {
339 fn from(value: UpdateMetadataValue) -> Self {
340 match value {
341 UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
342 value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
343 },
344 UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
345 value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
346 },
347 UpdateMetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
348 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
349 value,
350 )),
351 },
352 UpdateMetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
353 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
354 value,
355 )),
356 },
357 UpdateMetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
358 value: Some(
359 chroma_proto::update_metadata_value::Value::SparseVectorValue(
360 sparse_vec.into(),
361 ),
362 ),
363 },
364 UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None },
365 }
366 }
367}
368
369impl TryFrom<&UpdateMetadataValue> for MetadataValue {
370 type Error = MetadataValueConversionError;
371
372 fn try_from(value: &UpdateMetadataValue) -> Result<Self, Self::Error> {
373 match value {
374 UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)),
375 UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)),
376 UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)),
377 UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())),
378 UpdateMetadataValue::SparseVector(value) => {
379 Ok(MetadataValue::SparseVector(value.clone()))
380 }
381 UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue),
382 }
383 }
384}
385
/// A stored metadata value.
///
/// The enum is `#[serde(untagged)]`: deserialization tries the variants in declaration
/// order and keeps the first match. With the `pyo3` feature, the derived `FromPyObject`
/// also tries variants in order. Reordering variants therefore changes how inputs are
/// parsed.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject, pyo3::IntoPyObject))]
#[serde(untagged)]
pub enum MetadataValue {
    Bool(bool),
    Int(i64),
    // Property tests draw floats from [-1e6, 1e6] as f32 and widen them to f64.
    #[cfg_attr(
        feature = "testing",
        proptest(
            strategy = "(-1e6..=1e6f32).prop_map(|v| MetadataValue::Float(v as f64)).boxed()"
        )
    )]
    Float(f64),
    Str(String),
    // Excluded from property-test generation.
    #[cfg_attr(feature = "testing", proptest(skip))]
    SparseVector(SparseVector),
}
411
// Marker impl that lets `MetadataValue` be used where `Eq`/`Ord` are required.
// NOTE(review): `PartialEq` is derived, so `Float` payloads are compared with f64 `==` and a
// NaN value is not equal to itself. Confirm NaN cannot reach code that relies on `Eq`.
impl Eq for MetadataValue {}
413
/// Payload-free tag naming the variant of a [`MetadataValue`], with one variant per value
/// kind.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MetadataValueType {
    Bool,
    Int,
    Float,
    Str,
    SparseVector,
}
422
423impl MetadataValue {
424 pub fn value_type(&self) -> MetadataValueType {
425 match self {
426 MetadataValue::Bool(_) => MetadataValueType::Bool,
427 MetadataValue::Int(_) => MetadataValueType::Int,
428 MetadataValue::Float(_) => MetadataValueType::Float,
429 MetadataValue::Str(_) => MetadataValueType::Str,
430 MetadataValue::SparseVector(_) => MetadataValueType::SparseVector,
431 }
432 }
433}
434
435impl From<&MetadataValue> for MetadataValueType {
436 fn from(value: &MetadataValue) -> Self {
437 value.value_type()
438 }
439}
440
441impl From<bool> for MetadataValue {
442 fn from(v: bool) -> Self {
443 MetadataValue::Bool(v)
444 }
445}
446
447impl From<i64> for MetadataValue {
448 fn from(v: i64) -> Self {
449 MetadataValue::Int(v)
450 }
451}
452
453impl From<i32> for MetadataValue {
454 fn from(v: i32) -> Self {
455 MetadataValue::Int(v as i64)
456 }
457}
458
459impl From<f64> for MetadataValue {
460 fn from(v: f64) -> Self {
461 MetadataValue::Float(v)
462 }
463}
464
465impl From<f32> for MetadataValue {
466 fn from(v: f32) -> Self {
467 MetadataValue::Float(v as f64)
468 }
469}
470
471impl From<String> for MetadataValue {
472 fn from(v: String) -> Self {
473 MetadataValue::Str(v)
474 }
475}
476
477impl From<&str> for MetadataValue {
478 fn from(v: &str) -> Self {
479 MetadataValue::Str(v.to_string())
480 }
481}
482
483impl From<SparseVector> for MetadataValue {
484 fn from(v: SparseVector) -> Self {
485 MetadataValue::SparseVector(v)
486 }
487}
488
489#[allow(clippy::derive_ord_xor_partial_ord)]
494impl Ord for MetadataValue {
495 fn cmp(&self, other: &Self) -> Ordering {
496 fn type_order(val: &MetadataValue) -> u8 {
498 match val {
499 MetadataValue::Bool(_) => 0,
500 MetadataValue::Int(_) => 1,
501 MetadataValue::Float(_) => 2,
502 MetadataValue::Str(_) => 3,
503 MetadataValue::SparseVector(_) => 4,
504 }
505 }
506
507 type_order(self).cmp(&type_order(other)).then_with(|| {
509 match (self, other) {
510 (MetadataValue::Bool(left), MetadataValue::Bool(right)) => left.cmp(right),
511 (MetadataValue::Int(left), MetadataValue::Int(right)) => left.cmp(right),
512 (MetadataValue::Float(left), MetadataValue::Float(right)) => left.total_cmp(right),
513 (MetadataValue::Str(left), MetadataValue::Str(right)) => left.cmp(right),
514 (MetadataValue::SparseVector(left), MetadataValue::SparseVector(right)) => {
515 left.cmp(right)
516 }
517 _ => Ordering::Equal, }
519 })
520 }
521}
522
523impl PartialOrd for MetadataValue {
524 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
525 Some(self.cmp(other))
526 }
527}
528
529impl TryFrom<&MetadataValue> for bool {
530 type Error = MetadataValueConversionError;
531
532 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
533 match value {
534 MetadataValue::Bool(value) => Ok(*value),
535 _ => Err(MetadataValueConversionError::InvalidValue),
536 }
537 }
538}
539
540impl TryFrom<&MetadataValue> for i64 {
541 type Error = MetadataValueConversionError;
542
543 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
544 match value {
545 MetadataValue::Int(value) => Ok(*value),
546 _ => Err(MetadataValueConversionError::InvalidValue),
547 }
548 }
549}
550
551impl TryFrom<&MetadataValue> for f64 {
552 type Error = MetadataValueConversionError;
553
554 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
555 match value {
556 MetadataValue::Float(value) => Ok(*value),
557 _ => Err(MetadataValueConversionError::InvalidValue),
558 }
559 }
560}
561
562impl TryFrom<&MetadataValue> for String {
563 type Error = MetadataValueConversionError;
564
565 fn try_from(value: &MetadataValue) -> Result<Self, Self::Error> {
566 match value {
567 MetadataValue::Str(value) => Ok(value.clone()),
568 _ => Err(MetadataValueConversionError::InvalidValue),
569 }
570 }
571}
572
573impl From<MetadataValue> for UpdateMetadataValue {
574 fn from(value: MetadataValue) -> Self {
575 match value {
576 MetadataValue::Bool(v) => UpdateMetadataValue::Bool(v),
577 MetadataValue::Int(v) => UpdateMetadataValue::Int(v),
578 MetadataValue::Float(v) => UpdateMetadataValue::Float(v),
579 MetadataValue::Str(v) => UpdateMetadataValue::Str(v),
580 MetadataValue::SparseVector(v) => UpdateMetadataValue::SparseVector(v),
581 }
582 }
583}
584
585impl From<MetadataValue> for Value {
586 fn from(value: MetadataValue) -> Self {
587 match value {
588 MetadataValue::Bool(val) => Self::Bool(val),
589 MetadataValue::Int(val) => Self::Number(
590 Number::from_i128(val as i128).expect("i64 should be representable in JSON"),
591 ),
592 MetadataValue::Float(val) => Self::Number(
593 Number::from_f64(val).expect("Inf and NaN should not be present in MetadataValue"),
594 ),
595 MetadataValue::Str(val) => Self::String(val),
596 MetadataValue::SparseVector(val) => {
597 let mut map = serde_json::Map::new();
598 map.insert(
599 "indices".to_string(),
600 Value::Array(
601 val.indices
602 .iter()
603 .map(|&i| Value::Number(i.into()))
604 .collect(),
605 ),
606 );
607 map.insert(
608 "values".to_string(),
609 Value::Array(
610 val.values
611 .iter()
612 .map(|&v| {
613 Value::Number(
614 Number::from_f64(v as f64)
615 .expect("Float number should not be NaN or infinite"),
616 )
617 })
618 .collect(),
619 ),
620 );
621 Self::Object(map)
622 }
623 }
624 }
625}
626
627#[derive(Error, Debug)]
628pub enum MetadataValueConversionError {
629 #[error("Invalid metadata value, valid values are: Int, Float, Str")]
630 InvalidValue,
631 #[error("Metadata key cannot start with '#' or '$': {0}")]
632 InvalidKey(String),
633 #[error("Sparse vector indices and values must have the same length")]
634 SparseVectorLengthMismatch,
635 #[error("Sparse vector indices must be sorted in strictly ascending order (no duplicates)")]
636 SparseVectorIndicesNotSorted,
637}
638
639impl ChromaError for MetadataValueConversionError {
640 fn code(&self) -> ErrorCodes {
641 match self {
642 MetadataValueConversionError::InvalidValue => ErrorCodes::InvalidArgument,
643 MetadataValueConversionError::InvalidKey(_) => ErrorCodes::InvalidArgument,
644 MetadataValueConversionError::SparseVectorLengthMismatch => ErrorCodes::InvalidArgument,
645 MetadataValueConversionError::SparseVectorIndicesNotSorted => {
646 ErrorCodes::InvalidArgument
647 }
648 }
649 }
650}
651
652impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue {
653 type Error = MetadataValueConversionError;
654
655 fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result<Self, Self::Error> {
656 match &value.value {
657 Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => {
658 Ok(MetadataValue::Bool(*value))
659 }
660 Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => {
661 Ok(MetadataValue::Int(*value))
662 }
663 Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => {
664 Ok(MetadataValue::Float(*value))
665 }
666 Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => {
667 Ok(MetadataValue::Str(value.clone()))
668 }
669 Some(chroma_proto::update_metadata_value::Value::SparseVectorValue(value)) => {
670 Ok(MetadataValue::SparseVector(value.clone().into()))
671 }
672 _ => Err(MetadataValueConversionError::InvalidValue),
673 }
674 }
675}
676
677impl From<MetadataValue> for chroma_proto::UpdateMetadataValue {
678 fn from(value: MetadataValue) -> Self {
679 match value {
680 MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue {
681 value: Some(chroma_proto::update_metadata_value::Value::IntValue(value)),
682 },
683 MetadataValue::Float(value) => chroma_proto::UpdateMetadataValue {
684 value: Some(chroma_proto::update_metadata_value::Value::FloatValue(
685 value,
686 )),
687 },
688 MetadataValue::Str(value) => chroma_proto::UpdateMetadataValue {
689 value: Some(chroma_proto::update_metadata_value::Value::StringValue(
690 value,
691 )),
692 },
693 MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue {
694 value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)),
695 },
696 MetadataValue::SparseVector(sparse_vec) => chroma_proto::UpdateMetadataValue {
697 value: Some(
698 chroma_proto::update_metadata_value::Value::SparseVectorValue(
699 sparse_vec.into(),
700 ),
701 ),
702 },
703 }
704 }
705}
706
/// Metadata as carried by an update, keyed by metadata key. Values may be
/// [`UpdateMetadataValue::None`].
pub type UpdateMetadata = HashMap<String, UpdateMetadataValue>;
713
714pub fn are_update_metadatas_close_to_equal(
718 metadata1: &UpdateMetadata,
719 metadata2: &UpdateMetadata,
720) -> bool {
721 assert_eq!(metadata1.len(), metadata2.len());
722
723 for (key, value) in metadata1.iter() {
724 if !metadata2.contains_key(key) {
725 return false;
726 }
727 let other_value = metadata2.get(key).unwrap();
728
729 if let (UpdateMetadataValue::Float(value), UpdateMetadataValue::Float(other_value)) =
730 (value, other_value)
731 {
732 if (value - other_value).abs() > 1e-6 {
733 return false;
734 }
735 } else if value != other_value {
736 return false;
737 }
738 }
739
740 true
741}
742
743pub fn are_metadatas_close_to_equal(metadata1: &Metadata, metadata2: &Metadata) -> bool {
744 assert_eq!(metadata1.len(), metadata2.len());
745
746 for (key, value) in metadata1.iter() {
747 if !metadata2.contains_key(key) {
748 return false;
749 }
750 let other_value = metadata2.get(key).unwrap();
751
752 if let (MetadataValue::Float(value), MetadataValue::Float(other_value)) =
753 (value, other_value)
754 {
755 if (value - other_value).abs() > 1e-6 {
756 return false;
757 }
758 } else if value != other_value {
759 return false;
760 }
761 }
762
763 true
764}
765
766impl TryFrom<chroma_proto::UpdateMetadata> for UpdateMetadata {
767 type Error = UpdateMetadataValueConversionError;
768
769 fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
770 let mut metadata = UpdateMetadata::new();
771 for (key, value) in proto_metadata.metadata.iter() {
772 let value = match value.try_into() {
773 Ok(value) => value,
774 Err(_) => return Err(UpdateMetadataValueConversionError::InvalidValue),
775 };
776 metadata.insert(key.clone(), value);
777 }
778 Ok(metadata)
779 }
780}
781
782impl From<UpdateMetadata> for chroma_proto::UpdateMetadata {
783 fn from(metadata: UpdateMetadata) -> Self {
784 let mut metadata = metadata;
785 let mut proto_metadata = chroma_proto::UpdateMetadata {
786 metadata: HashMap::new(),
787 };
788 for (key, value) in metadata.drain() {
789 let proto_value = value.into();
790 proto_metadata.metadata.insert(key.clone(), proto_value);
791 }
792 proto_metadata
793 }
794}
795
/// Stored metadata, keyed by metadata key.
pub type Metadata = HashMap<String, MetadataValue>;
/// A set of metadata keys.
/// NOTE(review): presumably the keys removed by an update. Confirm with callers.
pub type DeletedMetadata = HashSet<String>;
804
805pub fn logical_size_of_metadata(metadata: &Metadata) -> usize {
806 metadata
807 .iter()
808 .map(|(k, v)| {
809 k.len()
810 + match v {
811 MetadataValue::Bool(b) => size_of_val(b),
812 MetadataValue::Int(i) => size_of_val(i),
813 MetadataValue::Float(f) => size_of_val(f),
814 MetadataValue::Str(s) => s.len(),
815 MetadataValue::SparseVector(v) => {
816 size_of_val(&v.indices[..]) + size_of_val(&v.values[..])
817 }
818 }
819 })
820 .sum()
821}
822
823pub fn get_metadata_value_as<'a, T>(
824 metadata: &'a Metadata,
825 key: &str,
826) -> Result<T, Box<MetadataValueConversionError>>
827where
828 T: TryFrom<&'a MetadataValue, Error = MetadataValueConversionError>,
829{
830 let res = match metadata.get(key) {
831 Some(value) => T::try_from(value),
832 None => return Err(Box::new(MetadataValueConversionError::InvalidValue)),
833 };
834 match res {
835 Ok(value) => Ok(value),
836 Err(_) => Err(Box::new(MetadataValueConversionError::InvalidValue)),
837 }
838}
839
840impl TryFrom<chroma_proto::UpdateMetadata> for Metadata {
841 type Error = MetadataValueConversionError;
842
843 fn try_from(proto_metadata: chroma_proto::UpdateMetadata) -> Result<Self, Self::Error> {
844 let mut metadata = Metadata::new();
845 for (key, value) in proto_metadata.metadata.iter() {
846 let maybe_value: Result<MetadataValue, Self::Error> = value.try_into();
847 if maybe_value.is_err() {
848 return Err(MetadataValueConversionError::InvalidValue);
849 }
850 let value = maybe_value.unwrap();
851 metadata.insert(key.clone(), value);
852 }
853 Ok(metadata)
854 }
855}
856
857impl From<Metadata> for chroma_proto::UpdateMetadata {
858 fn from(metadata: Metadata) -> Self {
859 let mut metadata = metadata;
860 let mut proto_metadata = chroma_proto::UpdateMetadata {
861 metadata: HashMap::new(),
862 };
863 for (key, value) in metadata.drain() {
864 let proto_value = value.into();
865 proto_metadata.metadata.insert(key.clone(), proto_value);
866 }
867 proto_metadata
868 }
869}
870
/// Borrowed description of metadata changes, keyed by metadata key.
///
/// The keys and values borrow from data that must outlive `'referred_data`.
/// NOTE(review): the value pair in `metadata_to_update` presumably holds (old, new)
/// values. Confirm with callers.
#[derive(Debug, Default)]
pub struct MetadataDelta<'referred_data> {
    /// Keys whose value changes, each with a pair of values.
    pub metadata_to_update: HashMap<
        &'referred_data str,
        (&'referred_data MetadataValue, &'referred_data MetadataValue),
    >,
    /// Entries to delete, each with an associated value.
    pub metadata_to_delete: HashMap<&'referred_data str, &'referred_data MetadataValue>,
    /// Entries to insert, each with its value.
    pub metadata_to_insert: HashMap<&'referred_data str, &'referred_data MetadataValue>,
}
880
881impl MetadataDelta<'_> {
882 pub fn new() -> Self {
883 Self::default()
884 }
885}
886
/// Error raised while converting a [`Where`] filter to or from its proto form.
///
/// Errors can be nested with `Trace` to record where in the filter the failure occurred.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum WhereConversionError {
    /// Root cause; displayed as `"Error: {msg}"`.
    #[error("Error: {0}")]
    Cause(String),
    /// Context label wrapping an inner error; displayed as `"{context} -> {inner}"`.
    #[error("{0} -> {1}")]
    Trace(String, Box<Self>),
}
900
901impl WhereConversionError {
902 pub fn cause(msg: impl ToString) -> Self {
903 Self::Cause(msg.to_string())
904 }
905
906 pub fn trace(self, context: impl ToString) -> Self {
907 Self::Trace(context.to_string(), Box::new(self))
908 }
909}
910
/// A filter expression: a boolean combination of document and metadata predicates.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Where {
    /// `$and` / `$or` over child expressions.
    Composite(CompositeExpression),
    /// Predicate on document text; serialized under the `#document` key.
    Document(DocumentExpression),
    /// Predicate on a single metadata key.
    Metadata(MetadataExpression),
}
925
926impl serde::Serialize for Where {
927 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
928 where
929 S: Serializer,
930 {
931 match self {
932 Where::Composite(composite) => {
933 let mut map = serializer.serialize_map(Some(1))?;
934 let op_key = match composite.operator {
935 BooleanOperator::And => "$and",
936 BooleanOperator::Or => "$or",
937 };
938 map.serialize_entry(op_key, &composite.children)?;
939 map.end()
940 }
941 Where::Document(doc) => {
942 let mut outer_map = serializer.serialize_map(Some(1))?;
943 let mut inner_map = serde_json::Map::new();
944 let op_key = match doc.operator {
945 DocumentOperator::Contains => "$contains",
946 DocumentOperator::NotContains => "$not_contains",
947 DocumentOperator::Regex => "$regex",
948 DocumentOperator::NotRegex => "$not_regex",
949 };
950 inner_map.insert(
951 op_key.to_string(),
952 serde_json::Value::String(doc.pattern.clone()),
953 );
954 outer_map.serialize_entry("#document", &inner_map)?;
955 outer_map.end()
956 }
957 Where::Metadata(meta) => {
958 let mut outer_map = serializer.serialize_map(Some(1))?;
959 let mut inner_map = serde_json::Map::new();
960
961 match &meta.comparison {
962 MetadataComparison::Primitive(op, value) => {
963 let op_key = match op {
964 PrimitiveOperator::Equal => "$eq",
965 PrimitiveOperator::NotEqual => "$ne",
966 PrimitiveOperator::GreaterThan => "$gt",
967 PrimitiveOperator::GreaterThanOrEqual => "$gte",
968 PrimitiveOperator::LessThan => "$lt",
969 PrimitiveOperator::LessThanOrEqual => "$lte",
970 };
971 let value_json =
972 serde_json::to_value(value).map_err(serde::ser::Error::custom)?;
973 inner_map.insert(op_key.to_string(), value_json);
974 }
975 MetadataComparison::Set(op, set_value) => {
976 let op_key = match op {
977 SetOperator::In => "$in",
978 SetOperator::NotIn => "$nin",
979 };
980 let values_json = match set_value {
981 MetadataSetValue::Bool(v) => serde_json::to_value(v),
982 MetadataSetValue::Int(v) => serde_json::to_value(v),
983 MetadataSetValue::Float(v) => serde_json::to_value(v),
984 MetadataSetValue::Str(v) => serde_json::to_value(v),
985 }
986 .map_err(serde::ser::Error::custom)?;
987 inner_map.insert(op_key.to_string(), values_json);
988 }
989 }
990
991 outer_map.serialize_entry(&meta.key, &inner_map)?;
992 outer_map.end()
993 }
994 }
995 }
996}
997
998impl Where {
999 pub fn conjunction(children: Vec<Where>) -> Self {
1000 Self::Composite(CompositeExpression {
1001 operator: BooleanOperator::And,
1002 children,
1003 })
1004 }
1005 pub fn disjunction(children: Vec<Where>) -> Self {
1006 Self::Composite(CompositeExpression {
1007 operator: BooleanOperator::Or,
1008 children,
1009 })
1010 }
1011
1012 pub fn fts_query_length(&self) -> u64 {
1013 match self {
1014 Where::Composite(composite_expression) => composite_expression
1015 .children
1016 .iter()
1017 .map(Where::fts_query_length)
1018 .sum(),
1019 Where::Document(document_expression) => {
1021 document_expression.pattern.len().max(3) as u64 - 2
1022 }
1023 Where::Metadata(_) => 0,
1024 }
1025 }
1026
1027 pub fn metadata_predicate_count(&self) -> u64 {
1028 match self {
1029 Where::Composite(composite_expression) => composite_expression
1030 .children
1031 .iter()
1032 .map(Where::metadata_predicate_count)
1033 .sum(),
1034 Where::Document(_) => 0,
1035 Where::Metadata(metadata_expression) => match &metadata_expression.comparison {
1036 MetadataComparison::Primitive(_, _) => 1,
1037 MetadataComparison::Set(_, metadata_set_value) => match metadata_set_value {
1038 MetadataSetValue::Bool(items) => items.len() as u64,
1039 MetadataSetValue::Int(items) => items.len() as u64,
1040 MetadataSetValue::Float(items) => items.len() as u64,
1041 MetadataSetValue::Str(items) => items.len() as u64,
1042 },
1043 },
1044 }
1045 }
1046}
1047
1048impl BitAnd for Where {
1049 type Output = Where;
1050
1051 fn bitand(self, rhs: Self) -> Self::Output {
1052 match self {
1053 Where::Composite(CompositeExpression {
1054 operator: BooleanOperator::And,
1055 mut children,
1056 }) => match rhs {
1057 Where::Composite(CompositeExpression {
1058 operator: BooleanOperator::And,
1059 children: rhs_children,
1060 }) => {
1061 children.extend(rhs_children);
1062 Where::Composite(CompositeExpression {
1063 operator: BooleanOperator::And,
1064 children,
1065 })
1066 }
1067 _ => {
1068 children.push(rhs);
1069 Where::Composite(CompositeExpression {
1070 operator: BooleanOperator::And,
1071 children,
1072 })
1073 }
1074 },
1075 _ => match rhs {
1076 Where::Composite(CompositeExpression {
1077 operator: BooleanOperator::And,
1078 mut children,
1079 }) => {
1080 children.insert(0, self);
1081 Where::Composite(CompositeExpression {
1082 operator: BooleanOperator::And,
1083 children,
1084 })
1085 }
1086 _ => Where::conjunction(vec![self, rhs]),
1087 },
1088 }
1089 }
1090}
1091
1092impl BitOr for Where {
1093 type Output = Where;
1094
1095 fn bitor(self, rhs: Self) -> Self::Output {
1096 match self {
1097 Where::Composite(CompositeExpression {
1098 operator: BooleanOperator::Or,
1099 mut children,
1100 }) => match rhs {
1101 Where::Composite(CompositeExpression {
1102 operator: BooleanOperator::Or,
1103 children: rhs_children,
1104 }) => {
1105 children.extend(rhs_children);
1106 Where::Composite(CompositeExpression {
1107 operator: BooleanOperator::Or,
1108 children,
1109 })
1110 }
1111 _ => {
1112 children.push(rhs);
1113 Where::Composite(CompositeExpression {
1114 operator: BooleanOperator::Or,
1115 children,
1116 })
1117 }
1118 },
1119 _ => match rhs {
1120 Where::Composite(CompositeExpression {
1121 operator: BooleanOperator::Or,
1122 mut children,
1123 }) => {
1124 children.insert(0, self);
1125 Where::Composite(CompositeExpression {
1126 operator: BooleanOperator::Or,
1127 children,
1128 })
1129 }
1130 _ => Where::disjunction(vec![self, rhs]),
1131 },
1132 }
1133 }
1134}
1135
1136impl TryFrom<chroma_proto::Where> for Where {
1137 type Error = WhereConversionError;
1138
1139 fn try_from(proto_where: chroma_proto::Where) -> Result<Self, Self::Error> {
1140 let where_inner = proto_where
1141 .r#where
1142 .ok_or(WhereConversionError::cause("Invalid Where"))?;
1143 Ok(match where_inner {
1144 chroma_proto::r#where::Where::DirectComparison(direct_comparison) => {
1145 Self::Metadata(direct_comparison.try_into()?)
1146 }
1147 chroma_proto::r#where::Where::Children(where_children) => {
1148 Self::Composite(where_children.try_into()?)
1149 }
1150 chroma_proto::r#where::Where::DirectDocumentComparison(direct_where_document) => {
1151 Self::Document(direct_where_document.into())
1152 }
1153 })
1154 }
1155}
1156
1157impl TryFrom<Where> for chroma_proto::Where {
1158 type Error = WhereConversionError;
1159
1160 fn try_from(value: Where) -> Result<Self, Self::Error> {
1161 let proto_where = match value {
1162 Where::Composite(composite_expression) => {
1163 chroma_proto::r#where::Where::Children(composite_expression.try_into()?)
1164 }
1165 Where::Document(document_expression) => {
1166 chroma_proto::r#where::Where::DirectDocumentComparison(document_expression.into())
1167 }
1168 Where::Metadata(metadata_expression) => chroma_proto::r#where::Where::DirectComparison(
1169 chroma_proto::DirectComparison::try_from(metadata_expression)
1170 .map_err(|err| err.trace("MetadataExpression"))?,
1171 ),
1172 };
1173 Ok(Self {
1174 r#where: Some(proto_where),
1175 })
1176 }
1177}
1178
1179#[derive(Clone, Debug, PartialEq)]
1180#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1181pub struct CompositeExpression {
1182 pub operator: BooleanOperator,
1183 pub children: Vec<Where>,
1184}
1185
1186impl TryFrom<chroma_proto::WhereChildren> for CompositeExpression {
1187 type Error = WhereConversionError;
1188
1189 fn try_from(proto_children: chroma_proto::WhereChildren) -> Result<Self, Self::Error> {
1190 let operator = proto_children.operator().into();
1191 let children = proto_children
1192 .children
1193 .into_iter()
1194 .map(Where::try_from)
1195 .collect::<Result<Vec<_>, _>>()
1196 .map_err(|err| err.trace("Child Where of CompositeExpression"))?;
1197 Ok(Self { operator, children })
1198 }
1199}
1200
1201impl TryFrom<CompositeExpression> for chroma_proto::WhereChildren {
1202 type Error = WhereConversionError;
1203
1204 fn try_from(value: CompositeExpression) -> Result<Self, Self::Error> {
1205 Ok(Self {
1206 operator: chroma_proto::BooleanOperator::from(value.operator) as i32,
1207 children: value
1208 .children
1209 .into_iter()
1210 .map(chroma_proto::Where::try_from)
1211 .collect::<Result<_, _>>()?,
1212 })
1213 }
1214}
1215
1216#[derive(Clone, Debug, PartialEq)]
1217#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1218pub enum BooleanOperator {
1219 And,
1220 Or,
1221}
1222
1223impl From<chroma_proto::BooleanOperator> for BooleanOperator {
1224 fn from(value: chroma_proto::BooleanOperator) -> Self {
1225 match value {
1226 chroma_proto::BooleanOperator::And => Self::And,
1227 chroma_proto::BooleanOperator::Or => Self::Or,
1228 }
1229 }
1230}
1231
1232impl From<BooleanOperator> for chroma_proto::BooleanOperator {
1233 fn from(value: BooleanOperator) -> Self {
1234 match value {
1235 BooleanOperator::And => Self::And,
1236 BooleanOperator::Or => Self::Or,
1237 }
1238 }
1239}
1240
1241#[derive(Clone, Debug, PartialEq)]
1242#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1243pub struct DocumentExpression {
1244 pub operator: DocumentOperator,
1245 pub pattern: String,
1246}
1247
1248impl From<chroma_proto::DirectWhereDocument> for DocumentExpression {
1249 fn from(value: chroma_proto::DirectWhereDocument) -> Self {
1250 Self {
1251 operator: value.operator().into(),
1252 pattern: value.pattern,
1253 }
1254 }
1255}
1256
1257impl From<DocumentExpression> for chroma_proto::DirectWhereDocument {
1258 fn from(value: DocumentExpression) -> Self {
1259 Self {
1260 pattern: value.pattern,
1261 operator: chroma_proto::WhereDocumentOperator::from(value.operator) as i32,
1262 }
1263 }
1264}
1265
1266#[derive(Clone, Debug, PartialEq)]
1267#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1268pub enum DocumentOperator {
1269 Contains,
1270 NotContains,
1271 Regex,
1272 NotRegex,
1273}
1274impl From<chroma_proto::WhereDocumentOperator> for DocumentOperator {
1275 fn from(value: chroma_proto::WhereDocumentOperator) -> Self {
1276 match value {
1277 chroma_proto::WhereDocumentOperator::Contains => Self::Contains,
1278 chroma_proto::WhereDocumentOperator::NotContains => Self::NotContains,
1279 chroma_proto::WhereDocumentOperator::Regex => Self::Regex,
1280 chroma_proto::WhereDocumentOperator::NotRegex => Self::NotRegex,
1281 }
1282 }
1283}
1284
1285impl From<DocumentOperator> for chroma_proto::WhereDocumentOperator {
1286 fn from(value: DocumentOperator) -> Self {
1287 match value {
1288 DocumentOperator::Contains => Self::Contains,
1289 DocumentOperator::NotContains => Self::NotContains,
1290 DocumentOperator::Regex => Self::Regex,
1291 DocumentOperator::NotRegex => Self::NotRegex,
1292 }
1293 }
1294}
1295
/// A predicate on a single metadata field.
///
/// `key` names the field; `comparison` carries the operator and its operand(s).
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct MetadataExpression {
    /// Name of the metadata field being compared.
    pub key: String,
    /// Operator plus operand(s) applied to the field.
    pub comparison: MetadataComparison,
}
1302
1303impl TryFrom<chroma_proto::DirectComparison> for MetadataExpression {
1304 type Error = WhereConversionError;
1305
1306 fn try_from(value: chroma_proto::DirectComparison) -> Result<Self, Self::Error> {
1307 let proto_comparison = value
1308 .comparison
1309 .ok_or(WhereConversionError::cause("Invalid MetadataExpression"))?;
1310 let comparison = match proto_comparison {
1311 chroma_proto::direct_comparison::Comparison::SingleStringOperand(
1312 single_string_comparison,
1313 ) => MetadataComparison::Primitive(
1314 single_string_comparison.comparator().into(),
1315 MetadataValue::Str(single_string_comparison.value),
1316 ),
1317 chroma_proto::direct_comparison::Comparison::StringListOperand(
1318 string_list_comparison,
1319 ) => MetadataComparison::Set(
1320 string_list_comparison.list_operator().into(),
1321 MetadataSetValue::Str(string_list_comparison.values),
1322 ),
1323 chroma_proto::direct_comparison::Comparison::SingleIntOperand(
1324 single_int_comparison,
1325 ) => MetadataComparison::Primitive(
1326 match single_int_comparison
1327 .comparator
1328 .ok_or(WhereConversionError::cause(
1329 "Invalid scalar integer operator",
1330 ))? {
1331 chroma_proto::single_int_comparison::Comparator::GenericComparator(op) => {
1332 chroma_proto::GenericComparator::try_from(op)
1333 .map_err(WhereConversionError::cause)?
1334 .into()
1335 }
1336 chroma_proto::single_int_comparison::Comparator::NumberComparator(op) => {
1337 chroma_proto::NumberComparator::try_from(op)
1338 .map_err(WhereConversionError::cause)?
1339 .into()
1340 }
1341 },
1342 MetadataValue::Int(single_int_comparison.value),
1343 ),
1344 chroma_proto::direct_comparison::Comparison::IntListOperand(int_list_comparison) => {
1345 MetadataComparison::Set(
1346 int_list_comparison.list_operator().into(),
1347 MetadataSetValue::Int(int_list_comparison.values),
1348 )
1349 }
1350 chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(
1351 single_double_comparison,
1352 ) => MetadataComparison::Primitive(
1353 match single_double_comparison
1354 .comparator
1355 .ok_or(WhereConversionError::cause("Invalid scalar float operator"))?
1356 {
1357 chroma_proto::single_double_comparison::Comparator::GenericComparator(op) => {
1358 chroma_proto::GenericComparator::try_from(op)
1359 .map_err(WhereConversionError::cause)?
1360 .into()
1361 }
1362 chroma_proto::single_double_comparison::Comparator::NumberComparator(op) => {
1363 chroma_proto::NumberComparator::try_from(op)
1364 .map_err(WhereConversionError::cause)?
1365 .into()
1366 }
1367 },
1368 MetadataValue::Float(single_double_comparison.value),
1369 ),
1370 chroma_proto::direct_comparison::Comparison::DoubleListOperand(
1371 double_list_comparison,
1372 ) => MetadataComparison::Set(
1373 double_list_comparison.list_operator().into(),
1374 MetadataSetValue::Float(double_list_comparison.values),
1375 ),
1376 chroma_proto::direct_comparison::Comparison::BoolListOperand(bool_list_comparison) => {
1377 MetadataComparison::Set(
1378 bool_list_comparison.list_operator().into(),
1379 MetadataSetValue::Bool(bool_list_comparison.values),
1380 )
1381 }
1382 chroma_proto::direct_comparison::Comparison::SingleBoolOperand(
1383 single_bool_comparison,
1384 ) => MetadataComparison::Primitive(
1385 single_bool_comparison.comparator().into(),
1386 MetadataValue::Bool(single_bool_comparison.value),
1387 ),
1388 };
1389 Ok(Self {
1390 key: value.key,
1391 comparison,
1392 })
1393 }
1394}
1395
1396impl TryFrom<MetadataExpression> for chroma_proto::DirectComparison {
1397 type Error = WhereConversionError;
1398
1399 fn try_from(value: MetadataExpression) -> Result<Self, Self::Error> {
1400 let comparison = match value.comparison {
1401 MetadataComparison::Primitive(primitive_operator, metadata_value) => match metadata_value {
1402 MetadataValue::Bool(value) => chroma_proto::direct_comparison::Comparison::SingleBoolOperand(chroma_proto::SingleBoolComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1403 MetadataValue::Int(value) => chroma_proto::direct_comparison::Comparison::SingleIntOperand(chroma_proto::SingleIntComparison { value, comparator: Some(match primitive_operator {
1404 generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1405 numeric => chroma_proto::single_int_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1406 }),
1407 MetadataValue::Float(value) => chroma_proto::direct_comparison::Comparison::SingleDoubleOperand(chroma_proto::SingleDoubleComparison { value, comparator: Some(match primitive_operator {
1408 generic_operator @ PrimitiveOperator::Equal | generic_operator @ PrimitiveOperator::NotEqual => chroma_proto::single_double_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::try_from(generic_operator)? as i32),
1409 numeric => chroma_proto::single_double_comparison::Comparator::NumberComparator(chroma_proto::NumberComparator::try_from(numeric)? as i32) }),
1410 }),
1411 MetadataValue::Str(value) => chroma_proto::direct_comparison::Comparison::SingleStringOperand(chroma_proto::SingleStringComparison { value, comparator: chroma_proto::GenericComparator::try_from(primitive_operator)? as i32 }),
1412 MetadataValue::SparseVector(_) => return Err(WhereConversionError::Cause("Comparison with sparse vector is not supported".to_string())),
1413 },
1414 MetadataComparison::Set(set_operator, metadata_set_value) => match metadata_set_value {
1415 MetadataSetValue::Bool(vec) => chroma_proto::direct_comparison::Comparison::BoolListOperand(chroma_proto::BoolListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1416 MetadataSetValue::Int(vec) => chroma_proto::direct_comparison::Comparison::IntListOperand(chroma_proto::IntListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1417 MetadataSetValue::Float(vec) => chroma_proto::direct_comparison::Comparison::DoubleListOperand(chroma_proto::DoubleListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1418 MetadataSetValue::Str(vec) => chroma_proto::direct_comparison::Comparison::StringListOperand(chroma_proto::StringListComparison { values: vec, list_operator: chroma_proto::ListOperator::from(set_operator) as i32 }),
1419 },
1420 };
1421 Ok(Self {
1422 key: value.key,
1423 comparison: Some(comparison),
1424 })
1425 }
1426}
1427
/// The comparison part of a [`MetadataExpression`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum MetadataComparison {
    /// Single-operand comparison using a [`PrimitiveOperator`].
    Primitive(PrimitiveOperator, MetadataValue),
    /// List-operand comparison using a [`SetOperator`] (`In` / `NotIn`).
    Set(SetOperator, MetadataSetValue),
}
1434
1435#[derive(Clone, Debug, PartialEq)]
1436#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
1437pub enum PrimitiveOperator {
1438 Equal,
1439 NotEqual,
1440 GreaterThan,
1441 GreaterThanOrEqual,
1442 LessThan,
1443 LessThanOrEqual,
1444}
1445
1446impl From<chroma_proto::GenericComparator> for PrimitiveOperator {
1447 fn from(value: chroma_proto::GenericComparator) -> Self {
1448 match value {
1449 chroma_proto::GenericComparator::Eq => Self::Equal,
1450 chroma_proto::GenericComparator::Ne => Self::NotEqual,
1451 }
1452 }
1453}
1454
1455impl TryFrom<PrimitiveOperator> for chroma_proto::GenericComparator {
1456 type Error = WhereConversionError;
1457
1458 fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1459 match value {
1460 PrimitiveOperator::Equal => Ok(Self::Eq),
1461 PrimitiveOperator::NotEqual => Ok(Self::Ne),
1462 op => Err(WhereConversionError::cause(format!("{op:?} ∉ [=, ≠]"))),
1463 }
1464 }
1465}
1466
1467impl From<chroma_proto::NumberComparator> for PrimitiveOperator {
1468 fn from(value: chroma_proto::NumberComparator) -> Self {
1469 match value {
1470 chroma_proto::NumberComparator::Gt => Self::GreaterThan,
1471 chroma_proto::NumberComparator::Gte => Self::GreaterThanOrEqual,
1472 chroma_proto::NumberComparator::Lt => Self::LessThan,
1473 chroma_proto::NumberComparator::Lte => Self::LessThanOrEqual,
1474 }
1475 }
1476}
1477
1478impl TryFrom<PrimitiveOperator> for chroma_proto::NumberComparator {
1479 type Error = WhereConversionError;
1480
1481 fn try_from(value: PrimitiveOperator) -> Result<Self, Self::Error> {
1482 match value {
1483 PrimitiveOperator::GreaterThan => Ok(Self::Gt),
1484 PrimitiveOperator::GreaterThanOrEqual => Ok(Self::Gte),
1485 PrimitiveOperator::LessThan => Ok(Self::Lt),
1486 PrimitiveOperator::LessThanOrEqual => Ok(Self::Lte),
1487 op => Err(WhereConversionError::cause(format!(
1488 "{op:?} ∉ [≤, <, >, ≥]"
1489 ))),
1490 }
1491 }
1492}
1493
1494#[derive(Clone, Debug, PartialEq)]
1495#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
1496pub enum SetOperator {
1497 In,
1498 NotIn,
1499}
1500
1501impl From<chroma_proto::ListOperator> for SetOperator {
1502 fn from(value: chroma_proto::ListOperator) -> Self {
1503 match value {
1504 chroma_proto::ListOperator::In => Self::In,
1505 chroma_proto::ListOperator::Nin => Self::NotIn,
1506 }
1507 }
1508}
1509
1510impl From<SetOperator> for chroma_proto::ListOperator {
1511 fn from(value: SetOperator) -> Self {
1512 match value {
1513 SetOperator::In => Self::In,
1514 SetOperator::NotIn => Self::Nin,
1515 }
1516 }
1517}
1518
1519#[derive(Clone, Debug, PartialEq)]
1520#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
1521pub enum MetadataSetValue {
1522 Bool(Vec<bool>),
1523 Int(Vec<i64>),
1524 Float(Vec<f64>),
1525 Str(Vec<String>),
1526}
1527
1528impl MetadataSetValue {
1529 pub fn value_type(&self) -> MetadataValueType {
1530 match self {
1531 MetadataSetValue::Bool(_) => MetadataValueType::Bool,
1532 MetadataSetValue::Int(_) => MetadataValueType::Int,
1533 MetadataSetValue::Float(_) => MetadataValueType::Float,
1534 MetadataSetValue::Str(_) => MetadataValueType::Str,
1535 }
1536 }
1537}
1538
1539impl From<Vec<bool>> for MetadataSetValue {
1540 fn from(values: Vec<bool>) -> Self {
1541 MetadataSetValue::Bool(values)
1542 }
1543}
1544
1545impl From<Vec<i64>> for MetadataSetValue {
1546 fn from(values: Vec<i64>) -> Self {
1547 MetadataSetValue::Int(values)
1548 }
1549}
1550
1551impl From<Vec<i32>> for MetadataSetValue {
1552 fn from(values: Vec<i32>) -> Self {
1553 MetadataSetValue::Int(values.into_iter().map(|v| v as i64).collect())
1554 }
1555}
1556
1557impl From<Vec<f64>> for MetadataSetValue {
1558 fn from(values: Vec<f64>) -> Self {
1559 MetadataSetValue::Float(values)
1560 }
1561}
1562
1563impl From<Vec<f32>> for MetadataSetValue {
1564 fn from(values: Vec<f32>) -> Self {
1565 MetadataSetValue::Float(values.into_iter().map(|v| v as f64).collect())
1566 }
1567}
1568
1569impl From<Vec<String>> for MetadataSetValue {
1570 fn from(values: Vec<String>) -> Self {
1571 MetadataSetValue::Str(values)
1572 }
1573}
1574
1575impl From<Vec<&str>> for MetadataSetValue {
1576 fn from(values: Vec<&str>) -> Self {
1577 MetadataSetValue::Str(values.into_iter().map(|s| s.to_string()).collect())
1578 }
1579}
1580
1581impl TryFrom<chroma_proto::WhereDocument> for Where {
1583 type Error = WhereConversionError;
1584
1585 fn try_from(proto_document: chroma_proto::WhereDocument) -> Result<Self, Self::Error> {
1586 match proto_document.r#where_document {
1587 Some(chroma_proto::where_document::WhereDocument::Direct(proto_comparison)) => {
1588 let operator = match TryInto::<chroma_proto::WhereDocumentOperator>::try_into(
1589 proto_comparison.operator,
1590 ) {
1591 Ok(operator) => operator,
1592 Err(_) => {
1593 return Err(WhereConversionError::cause(
1594 "[Deprecated] Invalid where document operator",
1595 ))
1596 }
1597 };
1598 let comparison = DocumentExpression {
1599 pattern: proto_comparison.pattern,
1600 operator: operator.into(),
1601 };
1602 Ok(Where::Document(comparison))
1603 }
1604 Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => {
1605 let operator = match TryInto::<chroma_proto::BooleanOperator>::try_into(
1606 proto_children.operator,
1607 ) {
1608 Ok(operator) => operator,
1609 Err(_) => {
1610 return Err(WhereConversionError::cause(
1611 "[Deprecated] Invalid boolean operator",
1612 ))
1613 }
1614 };
1615 let children = CompositeExpression {
1616 children: proto_children
1617 .children
1618 .into_iter()
1619 .map(|child| child.try_into())
1620 .collect::<Result<_, _>>()?,
1621 operator: operator.into(),
1622 };
1623 Ok(Where::Composite(children))
1624 }
1625 None => Err(WhereConversionError::cause("[Deprecated] Invalid where")),
1626 }
1627 }
1628}
1629
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_update_metadata_try_from() {
        let mut proto_metadata = chroma_proto::UpdateMetadata {
            metadata: HashMap::new(),
        };
        proto_metadata.metadata.insert(
            "foo".to_string(),
            chroma_proto::UpdateMetadataValue {
                value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
            },
        );
        proto_metadata.metadata.insert(
            "bar".to_string(),
            chroma_proto::UpdateMetadataValue {
                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
            },
        );
        proto_metadata.metadata.insert(
            "baz".to_string(),
            chroma_proto::UpdateMetadataValue {
                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
                    "42".to_string(),
                )),
            },
        );
        // Sparse vectors travel as their own oneof variant.
        proto_metadata.metadata.insert(
            "sparse".to_string(),
            chroma_proto::UpdateMetadataValue {
                value: Some(
                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
                        chroma_proto::SparseVector {
                            indices: vec![0, 5, 10],
                            values: vec![0.1, 0.5, 0.9],
                        },
                    ),
                ),
            },
        );
        let converted_metadata: UpdateMetadata = proto_metadata.try_into().unwrap();
        assert_eq!(converted_metadata.len(), 4);
        assert_eq!(
            converted_metadata.get("foo").unwrap(),
            &UpdateMetadataValue::Int(42)
        );
        assert_eq!(
            converted_metadata.get("bar").unwrap(),
            &UpdateMetadataValue::Float(42.0)
        );
        assert_eq!(
            converted_metadata.get("baz").unwrap(),
            &UpdateMetadataValue::Str("42".to_string())
        );
        assert_eq!(
            converted_metadata.get("sparse").unwrap(),
            &UpdateMetadataValue::SparseVector(SparseVector::new(
                vec![0, 5, 10],
                vec![0.1, 0.5, 0.9]
            ))
        );
    }

    #[test]
    fn test_metadata_try_from() {
        let mut proto_metadata = chroma_proto::UpdateMetadata {
            metadata: HashMap::new(),
        };
        proto_metadata.metadata.insert(
            "foo".to_string(),
            chroma_proto::UpdateMetadataValue {
                value: Some(chroma_proto::update_metadata_value::Value::IntValue(42)),
            },
        );
        proto_metadata.metadata.insert(
            "bar".to_string(),
            chroma_proto::UpdateMetadataValue {
                value: Some(chroma_proto::update_metadata_value::Value::FloatValue(42.0)),
            },
        );
        proto_metadata.metadata.insert(
            "baz".to_string(),
            chroma_proto::UpdateMetadataValue {
                value: Some(chroma_proto::update_metadata_value::Value::StringValue(
                    "42".to_string(),
                )),
            },
        );
        // Sparse vectors travel as their own oneof variant.
        proto_metadata.metadata.insert(
            "sparse".to_string(),
            chroma_proto::UpdateMetadataValue {
                value: Some(
                    chroma_proto::update_metadata_value::Value::SparseVectorValue(
                        chroma_proto::SparseVector {
                            indices: vec![1, 10, 100],
                            values: vec![0.2, 0.4, 0.6],
                        },
                    ),
                ),
            },
        );
        let converted_metadata: Metadata = proto_metadata.try_into().unwrap();
        assert_eq!(converted_metadata.len(), 4);
        assert_eq!(
            converted_metadata.get("foo").unwrap(),
            &MetadataValue::Int(42)
        );
        assert_eq!(
            converted_metadata.get("bar").unwrap(),
            &MetadataValue::Float(42.0)
        );
        assert_eq!(
            converted_metadata.get("baz").unwrap(),
            &MetadataValue::Str("42".to_string())
        );
        assert_eq!(
            converted_metadata.get("sparse").unwrap(),
            &MetadataValue::SparseVector(SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]))
        );
    }

    #[test]
    fn test_where_clause_simple_from() {
        let proto_where = chroma_proto::Where {
            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
                chroma_proto::DirectComparison {
                    key: "foo".to_string(),
                    comparison: Some(
                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
                            chroma_proto::SingleIntComparison {
                                value: 42,
                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
                            },
                        ),
                    ),
                },
            )),
        };
        let where_clause: Where = proto_where.try_into().unwrap();
        match where_clause {
            Where::Metadata(comparison) => {
                assert_eq!(comparison.key, "foo");
                match comparison.comparison {
                    MetadataComparison::Primitive(operator, value) => {
                        // The operator must survive decoding, not just the operand.
                        assert_eq!(operator, PrimitiveOperator::Equal);
                        assert_eq!(value, MetadataValue::Int(42));
                    }
                    _ => panic!("Invalid comparison type"),
                }
            }
            _ => panic!("Invalid where type"),
        }
    }

    #[test]
    fn test_where_clause_with_children() {
        let proto_where = chroma_proto::Where {
            r#where: Some(chroma_proto::r#where::Where::Children(
                chroma_proto::WhereChildren {
                    children: vec![
                        chroma_proto::Where {
                            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
                                chroma_proto::DirectComparison {
                                    key: "foo".to_string(),
                                    comparison: Some(
                                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
                                            chroma_proto::SingleIntComparison {
                                                value: 42,
                                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
                                            },
                                        ),
                                    ),
                                },
                            )),
                        },
                        chroma_proto::Where {
                            r#where: Some(chroma_proto::r#where::Where::DirectComparison(
                                chroma_proto::DirectComparison {
                                    key: "bar".to_string(),
                                    comparison: Some(
                                        chroma_proto::direct_comparison::Comparison::SingleIntOperand(
                                            chroma_proto::SingleIntComparison {
                                                value: 42,
                                                comparator: Some(chroma_proto::single_int_comparison::Comparator::GenericComparator(chroma_proto::GenericComparator::Eq as i32)),
                                            },
                                        ),
                                    ),
                                },
                            )),
                        },
                    ],
                    operator: chroma_proto::BooleanOperator::And.into(),
                },
            )),
        };
        let where_clause: Where = proto_where.try_into().unwrap();
        match where_clause {
            Where::Composite(children) => {
                assert_eq!(children.children.len(), 2);
                assert_eq!(children.operator, BooleanOperator::And);
                assert!(children
                    .children
                    .iter()
                    .all(|child| matches!(child, Where::Metadata(_))));
            }
            _ => panic!("Invalid where type"),
        }
    }

    #[test]
    fn test_where_proto_roundtrip() {
        // A nested clause mixing document and metadata predicates must survive
        // encode -> decode unchanged.
        let clause = Where::Composite(CompositeExpression {
            operator: BooleanOperator::Or,
            children: vec![
                Where::Document(DocumentExpression {
                    operator: DocumentOperator::NotRegex,
                    pattern: "^foo".to_string(),
                }),
                Where::Metadata(MetadataExpression {
                    key: "count".to_string(),
                    comparison: MetadataComparison::Primitive(
                        PrimitiveOperator::LessThan,
                        MetadataValue::Int(10),
                    ),
                }),
            ],
        });
        let proto = chroma_proto::Where::try_from(clause.clone()).unwrap();
        let decoded = Where::try_from(proto).unwrap();
        assert_eq!(decoded, clause);
    }

    #[test]
    fn test_metadata_expression_proto_roundtrip() {
        let expressions = vec![
            MetadataExpression {
                key: "count".to_string(),
                comparison: MetadataComparison::Primitive(
                    PrimitiveOperator::GreaterThanOrEqual,
                    MetadataValue::Int(7),
                ),
            },
            MetadataExpression {
                key: "score".to_string(),
                comparison: MetadataComparison::Primitive(
                    PrimitiveOperator::NotEqual,
                    MetadataValue::Float(0.5),
                ),
            },
            MetadataExpression {
                key: "flag".to_string(),
                comparison: MetadataComparison::Primitive(
                    PrimitiveOperator::Equal,
                    MetadataValue::Bool(true),
                ),
            },
            MetadataExpression {
                key: "tags".to_string(),
                comparison: MetadataComparison::Set(
                    SetOperator::NotIn,
                    MetadataSetValue::Str(vec!["a".to_string(), "b".to_string()]),
                ),
            },
        ];
        for expression in expressions {
            let proto = chroma_proto::DirectComparison::try_from(expression.clone()).unwrap();
            let decoded = MetadataExpression::try_from(proto).unwrap();
            assert_eq!(decoded, expression);
        }
    }

    #[test]
    fn test_metadata_expression_unencodable_comparisons() {
        // Sparse vectors cannot be compared on the wire.
        let sparse = MetadataExpression {
            key: "sparse".to_string(),
            comparison: MetadataComparison::Primitive(
                PrimitiveOperator::Equal,
                MetadataValue::SparseVector(SparseVector::new(vec![0], vec![1.0])),
            ),
        };
        assert!(chroma_proto::DirectComparison::try_from(sparse).is_err());

        // Booleans only support = and ≠.
        let ordered_bool = MetadataExpression {
            key: "flag".to_string(),
            comparison: MetadataComparison::Primitive(
                PrimitiveOperator::GreaterThan,
                MetadataValue::Bool(true),
            ),
        };
        assert!(chroma_proto::DirectComparison::try_from(ordered_bool).is_err());
    }

    #[test]
    fn test_where_document_simple() {
        let proto_where = chroma_proto::WhereDocument {
            r#where_document: Some(chroma_proto::where_document::WhereDocument::Direct(
                chroma_proto::DirectWhereDocument {
                    pattern: "foo".to_string(),
                    operator: chroma_proto::WhereDocumentOperator::Contains.into(),
                },
            )),
        };
        let where_document: Where = proto_where.try_into().unwrap();
        match where_document {
            Where::Document(comparison) => {
                assert_eq!(comparison.pattern, "foo");
                assert_eq!(comparison.operator, DocumentOperator::Contains);
            }
            _ => panic!("Invalid where document type"),
        }
    }

    #[test]
    fn test_where_document_with_children() {
        let proto_where = chroma_proto::WhereDocument {
            r#where_document: Some(chroma_proto::where_document::WhereDocument::Children(
                chroma_proto::WhereDocumentChildren {
                    children: vec![
                        chroma_proto::WhereDocument {
                            r#where_document: Some(
                                chroma_proto::where_document::WhereDocument::Direct(
                                    chroma_proto::DirectWhereDocument {
                                        pattern: "foo".to_string(),
                                        operator: chroma_proto::WhereDocumentOperator::Contains
                                            .into(),
                                    },
                                ),
                            ),
                        },
                        chroma_proto::WhereDocument {
                            r#where_document: Some(
                                chroma_proto::where_document::WhereDocument::Direct(
                                    chroma_proto::DirectWhereDocument {
                                        pattern: "bar".to_string(),
                                        operator: chroma_proto::WhereDocumentOperator::Contains
                                            .into(),
                                    },
                                ),
                            ),
                        },
                    ],
                    operator: chroma_proto::BooleanOperator::And.into(),
                },
            )),
        };
        let where_document: Where = proto_where.try_into().unwrap();
        match where_document {
            Where::Composite(children) => {
                assert_eq!(children.children.len(), 2);
                assert_eq!(children.operator, BooleanOperator::And);
            }
            _ => panic!("Invalid where document type"),
        }
    }

    #[test]
    fn test_sparse_vector_new() {
        let indices = vec![0, 5, 10];
        let values = vec![0.1, 0.5, 0.9];
        let sparse = SparseVector::new(indices.clone(), values.clone());
        assert_eq!(sparse.indices, indices);
        assert_eq!(sparse.values, values);
    }

    #[test]
    fn test_sparse_vector_from_pairs() {
        let pairs = vec![(0, 0.1), (5, 0.5), (10, 0.9)];
        let sparse = SparseVector::from_pairs(pairs);
        assert_eq!(sparse.indices, vec![0, 5, 10]);
        assert_eq!(sparse.values, vec![0.1, 0.5, 0.9]);
    }

    #[test]
    fn test_sparse_vector_iter() {
        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
        let collected: Vec<(u32, f32)> = sparse.iter().collect();
        assert_eq!(collected, vec![(0, 0.1), (5, 0.5), (10, 0.9)]);
    }

    #[test]
    fn test_sparse_vector_ordering() {
        let sparse1 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
        let sparse2 = SparseVector::new(vec![0, 5], vec![0.1, 0.5]);
        let sparse3 = SparseVector::new(vec![0, 6], vec![0.1, 0.5]);
        let sparse4 = SparseVector::new(vec![0, 5], vec![0.1, 0.6]);

        assert_eq!(sparse1, sparse2);
        assert!(sparse1 < sparse3);
        assert!(sparse1 < sparse4);
    }

    #[test]
    fn test_sparse_vector_proto_conversion() {
        let sparse = SparseVector::new(vec![1, 10, 100], vec![0.2, 0.4, 0.6]);
        let proto: chroma_proto::SparseVector = sparse.clone().into();
        assert_eq!(proto.indices, vec![1, 10, 100]);
        assert_eq!(proto.values, vec![0.2, 0.4, 0.6]);

        let converted: SparseVector = proto.into();
        assert_eq!(converted, sparse);
    }

    #[test]
    fn test_sparse_vector_logical_size() {
        let metadata = Metadata::from([(
            "sparse".to_string(),
            MetadataValue::SparseVector(SparseVector::new(
                vec![0, 1, 2, 3, 4],
                vec![0.1, 0.2, 0.3, 0.4, 0.5],
            )),
        )]);

        let size = logical_size_of_metadata(&metadata);
        // 6 bytes of key + 5 * 4 bytes of indices + 5 * 4 bytes of values.
        assert_eq!(size, 46);
    }

    #[test]
    fn test_sparse_vector_validation() {
        let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]);
        assert!(sparse.validate().is_ok());

        let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2]);
        let result = sparse.validate();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MetadataValueConversionError::SparseVectorLengthMismatch
        ));

        let sparse = SparseVector::new(vec![3, 1, 2], vec![0.3, 0.1, 0.2]);
        let result = sparse.validate();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MetadataValueConversionError::SparseVectorIndicesNotSorted
        ));

        let sparse = SparseVector::new(vec![1, 2, 2, 3], vec![0.1, 0.2, 0.3, 0.4]);
        let result = sparse.validate();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MetadataValueConversionError::SparseVectorIndicesNotSorted
        ));

        let sparse = SparseVector::new(vec![1, 3, 2], vec![0.1, 0.3, 0.2]);
        let result = sparse.validate();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MetadataValueConversionError::SparseVectorIndicesNotSorted
        ));
    }

    #[test]
    fn test_sparse_vector_deserialize_old_format() {
        let json = r#"{"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]}"#;
        let sv: SparseVector = serde_json::from_str(json).unwrap();
        assert_eq!(sv.indices, vec![0, 1, 2]);
        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sparse_vector_deserialize_new_format() {
        let json =
            "{\"#type\": \"sparse_vector\", \"indices\": [0, 1, 2], \"values\": [1.0, 2.0, 3.0]}";
        let sv: SparseVector = serde_json::from_str(json).unwrap();
        assert_eq!(sv.indices, vec![0, 1, 2]);
        assert_eq!(sv.values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sparse_vector_deserialize_new_format_field_order() {
        let json = "{\"indices\": [5, 10], \"#type\": \"sparse_vector\", \"values\": [0.5, 1.0]}";
        let sv: SparseVector = serde_json::from_str(json).unwrap();
        assert_eq!(sv.indices, vec![5, 10]);
        assert_eq!(sv.values, vec![0.5, 1.0]);
    }

    #[test]
    fn test_sparse_vector_deserialize_wrong_type_tag() {
        let json = "{\"#type\": \"dense_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}";
        let result: Result<SparseVector, _> = serde_json::from_str(json);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("sparse_vector"));
    }

    #[test]
    fn test_sparse_vector_serialize_always_has_type() {
        let sv = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]);
        let json = serde_json::to_value(&sv).unwrap();

        assert_eq!(json["#type"], "sparse_vector");
        assert_eq!(json["indices"], serde_json::json!([0, 1, 2]));
        assert_eq!(json["values"], serde_json::json!([1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_sparse_vector_roundtrip_with_type() {
        let original = SparseVector::new(vec![0, 5, 10, 15], vec![0.1, 0.5, 1.0, 1.5]);
        let json = serde_json::to_string(&original).unwrap();

        assert!(json.contains("\"#type\":\"sparse_vector\""));

        let deserialized: SparseVector = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_sparse_vector_in_metadata_old_format() {
        let json = r#"{"key": "value", "sparse": {"indices": [0, 1], "values": [1.0, 2.0]}}"#;
        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();

        let sparse_value = &map["sparse"];
        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
        assert_eq!(sv.indices, vec![0, 1]);
        assert_eq!(sv.values, vec![1.0, 2.0]);
    }

    #[test]
    fn test_sparse_vector_in_metadata_new_format() {
        let json = "{\"key\": \"value\", \"sparse\": {\"#type\": \"sparse_vector\", \"indices\": [0, 1], \"values\": [1.0, 2.0]}}";
        let map: HashMap<String, serde_json::Value> = serde_json::from_str(json).unwrap();

        let sparse_value = &map["sparse"];
        let sv: SparseVector = serde_json::from_value(sparse_value.clone()).unwrap();
        assert_eq!(sv.indices, vec![0, 1]);
        assert_eq!(sv.values, vec![1.0, 2.0]);
    }
}