1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4 cmp::Ordering,
5 collections::{BinaryHeap, HashSet},
6 fmt,
7 hash::Hash,
8 ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13 chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14 DocumentExpression, DocumentOperator, Metadata, MetadataComparison, MetadataExpression,
15 MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding, SetOperator, SparseVector,
16 Where,
17};
18
19use super::error::QueryConversionError;
20
/// Alias for the unit type `()`.
///
/// NOTE(review): presumably the input handed to the first operator of a
/// pipeline, which has nothing to consume — confirm against the callers.
pub type InitialInput = ();
22
23#[derive(Clone, Debug)]
28pub struct Scan {
29 pub collection_and_segments: CollectionAndSegments,
30}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33 type Error = QueryConversionError;
34
35 fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36 Ok(Self {
37 collection_and_segments: CollectionAndSegments {
38 collection: value
39 .collection
40 .ok_or(QueryConversionError::field("collection"))?
41 .try_into()?,
42 metadata_segment: value
43 .metadata
44 .ok_or(QueryConversionError::field("metadata segment"))?
45 .try_into()?,
46 record_segment: value
47 .record
48 .ok_or(QueryConversionError::field("record segment"))?
49 .try_into()?,
50 vector_segment: value
51 .knn
52 .ok_or(QueryConversionError::field("vector segment"))?
53 .try_into()?,
54 },
55 })
56 }
57}
58
59#[derive(Debug, Error)]
60pub enum ScanToProtoError {
61 #[error("Could not convert collection to proto")]
62 CollectionToProto(#[from] crate::CollectionToProtoError),
63}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66 type Error = ScanToProtoError;
67
68 fn try_from(value: Scan) -> Result<Self, Self::Error> {
69 Ok(Self {
70 collection: Some(value.collection_and_segments.collection.try_into()?),
71 knn: Some(value.collection_and_segments.vector_segment.into()),
72 metadata: Some(value.collection_and_segments.metadata_segment.into()),
73 record: Some(value.collection_and_segments.record_segment.into()),
74 })
75 }
76}
77
/// Result of a count request.
#[derive(Clone, Debug)]
pub struct CountResult {
    /// The reported count.
    pub count: u32,
    /// Number of log bytes pulled, as reported alongside the count.
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Returns the in-memory size, in bytes, of the `count` field only;
    /// `pulled_log_bytes` is not included.
    pub fn size_bytes(&self) -> u64 {
        let count_width = size_of_val(&self.count);
        count_width as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91 fn from(value: chroma_proto::CountResult) -> Self {
92 Self {
93 count: value.count,
94 pulled_log_bytes: value.pulled_log_bytes,
95 }
96 }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100 fn from(value: CountResult) -> Self {
101 Self {
102 count: value.count,
103 pulled_log_bytes: value.pulled_log_bytes,
104 }
105 }
106}
107
/// Parameters describing which logs to fetch for a collection.
///
/// NOTE(review): field meanings below are read off the field names only;
/// the code consuming this struct is not visible here — confirm.
#[derive(Clone, Debug)]
pub struct FetchLog {
    /// Identifier of the collection.
    pub collection_uuid: CollectionUuid,
    /// Optional upper bound on the number of fetched entries (presumably
    /// unbounded when `None` — confirm).
    pub maximum_fetch_count: Option<u32>,
    /// Log offset id to start from.
    pub start_log_offset_id: u32,
}
120
121#[derive(Clone, Debug, Default)]
127pub struct Filter {
128 pub query_ids: Option<Vec<String>>,
129 pub where_clause: Option<Where>,
130}
131
132impl Serialize for Filter {
133 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134 where
135 S: Serializer,
136 {
137 match (&self.query_ids, &self.where_clause) {
141 (None, None) => {
142 let map = serializer.serialize_map(Some(0))?;
144 map.end()
145 }
146 (None, Some(where_clause)) => {
147 where_clause.serialize(serializer)
149 }
150 (Some(ids), None) => {
151 let id_where = Where::Metadata(MetadataExpression {
153 key: "#id".to_string(),
154 comparison: MetadataComparison::Set(
155 SetOperator::In,
156 MetadataSetValue::Str(ids.clone()),
157 ),
158 });
159 id_where.serialize(serializer)
160 }
161 (Some(ids), Some(where_clause)) => {
162 let id_where = Where::Metadata(MetadataExpression {
164 key: "#id".to_string(),
165 comparison: MetadataComparison::Set(
166 SetOperator::In,
167 MetadataSetValue::Str(ids.clone()),
168 ),
169 });
170 let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
171 combined.serialize(serializer)
172 }
173 }
174 }
175}
176
177impl<'de> Deserialize<'de> for Filter {
178 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
179 where
180 D: Deserializer<'de>,
181 {
182 let where_json = Value::deserialize(deserializer)?;
184 let where_clause =
185 if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
186 None
187 } else {
188 Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
189 };
190
191 Ok(Filter {
192 query_ids: None, where_clause,
194 })
195 }
196}
197
198impl TryFrom<chroma_proto::FilterOperator> for Filter {
199 type Error = QueryConversionError;
200
201 fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
202 let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
203 let where_document = value.where_document.map(TryInto::try_into).transpose()?;
204 let where_clause = match (where_metadata, where_document) {
205 (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
206 (Some(w), None) | (None, Some(w)) => Some(w),
207 _ => None,
208 };
209
210 Ok(Self {
211 query_ids: value.ids.map(|uids| uids.ids),
212 where_clause,
213 })
214 }
215}
216
217impl TryFrom<Filter> for chroma_proto::FilterOperator {
218 type Error = QueryConversionError;
219
220 fn try_from(value: Filter) -> Result<Self, Self::Error> {
221 Ok(Self {
222 ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
223 r#where: value.where_clause.map(TryInto::try_into).transpose()?,
224 where_document: None,
225 })
226 }
227}
228
/// A single nearest-neighbor request: one query embedding plus a fetch count.
#[derive(Clone, Debug)]
pub struct Knn {
    /// The query embedding.
    pub embedding: Vec<f32>,
    /// Fetch count carried with the query.
    pub fetch: u32,
}

impl From<KnnBatch> for Vec<Knn> {
    /// Splits a batch into one [`Knn`] per embedding, preserving order; every
    /// element inherits the batch's `fetch` value.
    fn from(value: KnnBatch) -> Self {
        let KnnBatch { embeddings, fetch } = value;

        let mut knns = Vec::with_capacity(embeddings.len());
        for embedding in embeddings {
            knns.push(Knn { embedding, fetch });
        }
        knns
    }
}

/// Several nearest-neighbor requests sharing a single fetch count.
#[derive(Clone, Debug)]
pub struct KnnBatch {
    /// One query embedding per request.
    pub embeddings: Vec<Vec<f32>>,
    /// Fetch count shared by every request in the batch.
    pub fetch: u32,
}
263
264impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
265 type Error = QueryConversionError;
266
267 fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
268 Ok(Self {
269 embeddings: value
270 .embeddings
271 .into_iter()
272 .map(|vec| vec.try_into().map(|(v, _)| v))
273 .collect::<Result<_, _>>()?,
274 fetch: value.fetch,
275 })
276 }
277}
278
279impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
280 type Error = QueryConversionError;
281
282 fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
283 Ok(Self {
284 embeddings: value
285 .embeddings
286 .into_iter()
287 .map(|embedding| {
288 let dim = embedding.len();
289 chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
290 })
291 .collect::<Result<_, _>>()?,
292 fetch: value.fetch,
293 })
294 }
295}
296
297#[derive(Clone, Debug, Default, Deserialize, Serialize)]
303pub struct Limit {
304 #[serde(default)]
305 pub offset: u32,
306 #[serde(default)]
307 pub limit: Option<u32>,
308}
309
310impl From<chroma_proto::LimitOperator> for Limit {
311 fn from(value: chroma_proto::LimitOperator) -> Self {
312 Self {
313 offset: value.offset,
314 limit: value.limit,
315 }
316 }
317}
318
319impl From<Limit> for chroma_proto::LimitOperator {
320 fn from(value: Limit) -> Self {
321 Self {
322 offset: value.offset,
323 limit: value.limit,
324 }
325 }
326}
327
/// A measure attached to the record identified by `offset_id`.
///
/// Equality looks only at `offset_id`. Ordering is driven by `measure`
/// (using the IEEE total order, so NaN is handled) with `offset_id` as the
/// tie-breaker, so two different records never compare as `Equal` merely
/// because their measures coincide.
#[derive(Clone, Debug)]
pub struct RecordMeasure {
    /// Offset id of the record.
    pub offset_id: u32,
    /// Measure associated with the record.
    pub measure: f32,
}

impl PartialEq for RecordMeasure {
    /// Two measures are equal when they refer to the same `offset_id`,
    /// regardless of `measure`.
    fn eq(&self, other: &Self) -> bool {
        self.offset_id.eq(&other.offset_id)
    }
}

impl Eq for RecordMeasure {}

impl Ord for RecordMeasure {
    /// Orders by `measure`, then by `offset_id`.
    ///
    /// The tie-breaker keeps the order deterministic and guarantees that
    /// `cmp` returning `Equal` implies `==`. It also places entries of the
    /// same record with the same measure next to each other, which the
    /// adjacent-duplicate check in `Merge::merge` relies on.
    fn cmp(&self, other: &Self) -> Ordering {
        self.measure
            .total_cmp(&other.measure)
            .then_with(|| self.offset_id.cmp(&other.offset_id))
    }
}

impl PartialOrd for RecordMeasure {
    /// Delegates to [`Ord::cmp`]; never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Output of a nearest-neighbor search: a list of record measures.
#[derive(Debug, Default)]
pub struct KnnOutput {
    /// Measures of the returned records.
    pub distances: Vec<RecordMeasure>,
}
359
/// Merges several batches into a single list of at most `k` elements.
#[derive(Clone, Debug)]
pub struct Merge {
    /// Maximum number of elements to return.
    pub k: u32,
}

impl Merge {
    /// Performs a k-way merge of `input`, returning at most `k` elements from
    /// largest to smallest.
    ///
    /// Elements are drawn through a max-heap seeded with the head of every
    /// batch, so each batch is expected to be sorted in descending order
    /// already (the heap only ever sees the current head of a batch). An
    /// element equal to the one emitted immediately before it is dropped, so
    /// adjacent duplicates across batches collapse into one.
    ///
    /// The output buffer is sized by the smaller of `k` and the total number
    /// of input elements, so a very large `k` does not trigger a very large
    /// allocation.
    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = usize::try_from(self.k).unwrap_or(usize::MAX);
        // Upper bound on what can possibly be emitted.
        let available = input
            .iter()
            .fold(0usize, |total, batch| total.saturating_add(batch.len()));

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the head of each non-empty batch, remembering
        // which batch it came from so its successor can be pulled later.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut fusion = Vec::with_capacity(limit.min(available));
        while fusion.len() < limit {
            let Some((m, idx)) = max_heap.pop() else {
                break;
            };
            // Refill the heap from the batch that just lost its head.
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            // Skip an element equal to the one emitted just before it.
            if fusion.last().is_some_and(|tail| tail == &m) {
                continue;
            }
            fusion.push(m);
        }
        fusion
    }
}
400
401#[derive(Clone, Debug, Default)]
408pub struct Projection {
409 pub document: bool,
410 pub embedding: bool,
411 pub metadata: bool,
412}
413
414impl From<chroma_proto::ProjectionOperator> for Projection {
415 fn from(value: chroma_proto::ProjectionOperator) -> Self {
416 Self {
417 document: value.document,
418 embedding: value.embedding,
419 metadata: value.metadata,
420 }
421 }
422}
423
424impl From<Projection> for chroma_proto::ProjectionOperator {
425 fn from(value: Projection) -> Self {
426 Self {
427 document: value.document,
428 embedding: value.embedding,
429 metadata: value.metadata,
430 }
431 }
432}
433
434#[derive(Clone, Debug, PartialEq)]
435pub struct ProjectionRecord {
436 pub id: String,
437 pub document: Option<String>,
438 pub embedding: Option<Vec<f32>>,
439 pub metadata: Option<Metadata>,
440}
441
442impl ProjectionRecord {
443 pub fn size_bytes(&self) -> u64 {
444 (self.id.len()
445 + self
446 .document
447 .as_ref()
448 .map(|doc| doc.len())
449 .unwrap_or_default()
450 + self
451 .embedding
452 .as_ref()
453 .map(|emb| size_of_val(&emb[..]))
454 .unwrap_or_default()
455 + self
456 .metadata
457 .as_ref()
458 .map(logical_size_of_metadata)
459 .unwrap_or_default()) as u64
460 }
461}
462
463impl Eq for ProjectionRecord {}
464
465impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
466 type Error = QueryConversionError;
467
468 fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
469 Ok(Self {
470 id: value.id,
471 document: value.document,
472 embedding: value
473 .embedding
474 .map(|vec| vec.try_into().map(|(v, _)| v))
475 .transpose()?,
476 metadata: value.metadata.map(TryInto::try_into).transpose()?,
477 })
478 }
479}
480
481impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
482 type Error = QueryConversionError;
483
484 fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
485 Ok(Self {
486 id: value.id,
487 document: value.document,
488 embedding: value
489 .embedding
490 .map(|embedding| {
491 let embedding_dimension = embedding.len();
492 chroma_proto::Vector::try_from((
493 embedding,
494 ScalarEncoding::FLOAT32,
495 embedding_dimension,
496 ))
497 })
498 .transpose()?,
499 metadata: value.metadata.map(|metadata| metadata.into()),
500 })
501 }
502}
503
504#[derive(Clone, Debug, Eq, PartialEq)]
505pub struct ProjectionOutput {
506 pub records: Vec<ProjectionRecord>,
507}
508
509#[derive(Clone, Debug, Eq, PartialEq)]
510pub struct GetResult {
511 pub pulled_log_bytes: u64,
512 pub result: ProjectionOutput,
513}
514
515impl GetResult {
516 pub fn size_bytes(&self) -> u64 {
517 self.result
518 .records
519 .iter()
520 .map(ProjectionRecord::size_bytes)
521 .sum()
522 }
523}
524
525impl TryFrom<chroma_proto::GetResult> for GetResult {
526 type Error = QueryConversionError;
527
528 fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
529 Ok(Self {
530 pulled_log_bytes: value.pulled_log_bytes,
531 result: ProjectionOutput {
532 records: value
533 .records
534 .into_iter()
535 .map(TryInto::try_into)
536 .collect::<Result<_, _>>()?,
537 },
538 })
539 }
540}
541
542impl TryFrom<GetResult> for chroma_proto::GetResult {
543 type Error = QueryConversionError;
544
545 fn try_from(value: GetResult) -> Result<Self, Self::Error> {
546 Ok(Self {
547 pulled_log_bytes: value.pulled_log_bytes,
548 records: value
549 .result
550 .records
551 .into_iter()
552 .map(TryInto::try_into)
553 .collect::<Result<_, _>>()?,
554 })
555 }
556}
557
558#[derive(Clone, Debug)]
566pub struct KnnProjection {
567 pub projection: Projection,
568 pub distance: bool,
569}
570
571impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
572 type Error = QueryConversionError;
573
574 fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
575 Ok(Self {
576 projection: value
577 .projection
578 .ok_or(QueryConversionError::field("projection"))?
579 .into(),
580 distance: value.distance,
581 })
582 }
583}
584
585impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
586 fn from(value: KnnProjection) -> Self {
587 Self {
588 projection: Some(value.projection.into()),
589 distance: value.distance,
590 }
591 }
592}
593
594#[derive(Clone, Debug)]
595pub struct KnnProjectionRecord {
596 pub record: ProjectionRecord,
597 pub distance: Option<f32>,
598}
599
600impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
601 type Error = QueryConversionError;
602
603 fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
604 Ok(Self {
605 record: value
606 .record
607 .ok_or(QueryConversionError::field("record"))?
608 .try_into()?,
609 distance: value.distance,
610 })
611 }
612}
613
614impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
615 type Error = QueryConversionError;
616
617 fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
618 Ok(Self {
619 record: Some(value.record.try_into()?),
620 distance: value.distance,
621 })
622 }
623}
624
625#[derive(Clone, Debug, Default)]
626pub struct KnnProjectionOutput {
627 pub records: Vec<KnnProjectionRecord>,
628}
629
630impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
631 type Error = QueryConversionError;
632
633 fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
634 Ok(Self {
635 records: value
636 .records
637 .into_iter()
638 .map(TryInto::try_into)
639 .collect::<Result<_, _>>()?,
640 })
641 }
642}
643
644impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
645 type Error = QueryConversionError;
646
647 fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
648 Ok(Self {
649 records: value
650 .records
651 .into_iter()
652 .map(TryInto::try_into)
653 .collect::<Result<_, _>>()?,
654 })
655 }
656}
657
658#[derive(Clone, Debug, Default)]
659pub struct KnnBatchResult {
660 pub pulled_log_bytes: u64,
661 pub results: Vec<KnnProjectionOutput>,
662}
663
664impl KnnBatchResult {
665 pub fn size_bytes(&self) -> u64 {
666 self.results
667 .iter()
668 .flat_map(|res| {
669 res.records
670 .iter()
671 .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
672 })
673 .sum()
674 }
675}
676
677impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
678 type Error = QueryConversionError;
679
680 fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
681 Ok(Self {
682 pulled_log_bytes: value.pulled_log_bytes,
683 results: value
684 .results
685 .into_iter()
686 .map(TryInto::try_into)
687 .collect::<Result<_, _>>()?,
688 })
689 }
690}
691
692impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
693 type Error = QueryConversionError;
694
695 fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
696 Ok(Self {
697 pulled_log_bytes: value.pulled_log_bytes,
698 results: value
699 .results
700 .into_iter()
701 .map(TryInto::try_into)
702 .collect::<Result<_, _>>()?,
703 })
704 }
705}
706
707#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
708#[serde(untagged)]
709pub enum QueryVector {
710 Dense(Vec<f32>),
711 Sparse(SparseVector),
712}
713
714impl TryFrom<chroma_proto::QueryVector> for QueryVector {
715 type Error = QueryConversionError;
716
717 fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
718 let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
719 match vector {
720 chroma_proto::query_vector::Vector::Dense(dense) => {
721 Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
722 }
723 chroma_proto::query_vector::Vector::Sparse(sparse) => {
724 Ok(QueryVector::Sparse(sparse.into()))
725 }
726 }
727 }
728}
729
730impl TryFrom<QueryVector> for chroma_proto::QueryVector {
731 type Error = QueryConversionError;
732
733 fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
734 match value {
735 QueryVector::Dense(vec) => {
736 let dim = vec.len();
737 Ok(chroma_proto::QueryVector {
738 vector: Some(chroma_proto::query_vector::Vector::Dense(
739 chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
740 )),
741 })
742 }
743 QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
744 vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
745 }),
746 }
747 }
748}
749
750impl From<Vec<f32>> for QueryVector {
751 fn from(vec: Vec<f32>) -> Self {
752 QueryVector::Dense(vec)
753 }
754}
755
756impl From<SparseVector> for QueryVector {
757 fn from(sparse: SparseVector) -> Self {
758 QueryVector::Sparse(sparse)
759 }
760}
761
762#[derive(Clone, Debug, PartialEq)]
763pub struct KnnQuery {
764 pub query: QueryVector,
765 pub key: Key,
766 pub limit: u32,
767}
768
769#[derive(Clone, Debug, Default, Deserialize, Serialize)]
770#[serde(transparent)]
771pub struct Rank {
772 pub expr: Option<RankExpr>,
773}
774
775impl Rank {
776 pub fn knn_queries(&self) -> Vec<KnnQuery> {
777 self.expr
778 .as_ref()
779 .map(RankExpr::knn_queries)
780 .unwrap_or_default()
781 }
782}
783
784impl TryFrom<chroma_proto::RankOperator> for Rank {
785 type Error = QueryConversionError;
786
787 fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
788 Ok(Rank {
789 expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
790 })
791 }
792}
793
794impl TryFrom<Rank> for chroma_proto::RankOperator {
795 type Error = QueryConversionError;
796
797 fn try_from(rank: Rank) -> Result<Self, Self::Error> {
798 Ok(chroma_proto::RankOperator {
799 expr: rank.expr.map(TryInto::try_into).transpose()?,
800 })
801 }
802}
803
804#[derive(Clone, Debug, Deserialize, Serialize)]
805pub enum RankExpr {
806 #[serde(rename = "$abs")]
807 Absolute(Box<RankExpr>),
808 #[serde(rename = "$div")]
809 Division {
810 left: Box<RankExpr>,
811 right: Box<RankExpr>,
812 },
813 #[serde(rename = "$exp")]
814 Exponentiation(Box<RankExpr>),
815 #[serde(rename = "$knn")]
816 Knn {
817 query: QueryVector,
818 #[serde(default = "RankExpr::default_knn_key")]
819 key: Key,
820 #[serde(default = "RankExpr::default_knn_limit")]
821 limit: u32,
822 #[serde(default)]
823 default: Option<f32>,
824 #[serde(default)]
825 return_rank: bool,
826 },
827 #[serde(rename = "$log")]
828 Logarithm(Box<RankExpr>),
829 #[serde(rename = "$max")]
830 Maximum(Vec<RankExpr>),
831 #[serde(rename = "$min")]
832 Minimum(Vec<RankExpr>),
833 #[serde(rename = "$mul")]
834 Multiplication(Vec<RankExpr>),
835 #[serde(rename = "$sub")]
836 Subtraction {
837 left: Box<RankExpr>,
838 right: Box<RankExpr>,
839 },
840 #[serde(rename = "$sum")]
841 Summation(Vec<RankExpr>),
842 #[serde(rename = "$val")]
843 Value(f32),
844}
845
846impl RankExpr {
847 pub fn default_knn_key() -> Key {
848 Key::Embedding
849 }
850
851 pub fn default_knn_limit() -> u32 {
852 16
853 }
854
855 pub fn knn_queries(&self) -> Vec<KnnQuery> {
856 match self {
857 RankExpr::Absolute(expr)
858 | RankExpr::Exponentiation(expr)
859 | RankExpr::Logarithm(expr) => expr.knn_queries(),
860 RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
861 .knn_queries()
862 .into_iter()
863 .chain(right.knn_queries())
864 .collect(),
865 RankExpr::Maximum(exprs)
866 | RankExpr::Minimum(exprs)
867 | RankExpr::Multiplication(exprs)
868 | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
869 RankExpr::Value(_) => Vec::new(),
870 RankExpr::Knn {
871 query,
872 key,
873 limit,
874 default: _,
875 return_rank: _,
876 } => vec![KnnQuery {
877 query: query.clone(),
878 key: key.clone(),
879 limit: *limit,
880 }],
881 }
882 }
883
884 pub fn exp(self) -> Self {
886 RankExpr::Exponentiation(Box::new(self))
887 }
888
889 pub fn log(self) -> Self {
891 RankExpr::Logarithm(Box::new(self))
892 }
893
894 pub fn abs(self) -> Self {
896 RankExpr::Absolute(Box::new(self))
897 }
898
899 pub fn max(self, other: impl Into<RankExpr>) -> Self {
901 let other = other.into();
902
903 match self {
904 RankExpr::Maximum(mut exprs) => match other {
905 RankExpr::Maximum(other_exprs) => {
906 exprs.extend(other_exprs);
907 RankExpr::Maximum(exprs)
908 }
909 _ => {
910 exprs.push(other);
911 RankExpr::Maximum(exprs)
912 }
913 },
914 _ => match other {
915 RankExpr::Maximum(mut exprs) => {
916 exprs.insert(0, self);
917 RankExpr::Maximum(exprs)
918 }
919 _ => RankExpr::Maximum(vec![self, other]),
920 },
921 }
922 }
923
924 pub fn min(self, other: impl Into<RankExpr>) -> Self {
926 let other = other.into();
927
928 match self {
929 RankExpr::Minimum(mut exprs) => match other {
930 RankExpr::Minimum(other_exprs) => {
931 exprs.extend(other_exprs);
932 RankExpr::Minimum(exprs)
933 }
934 _ => {
935 exprs.push(other);
936 RankExpr::Minimum(exprs)
937 }
938 },
939 _ => match other {
940 RankExpr::Minimum(mut exprs) => {
941 exprs.insert(0, self);
942 RankExpr::Minimum(exprs)
943 }
944 _ => RankExpr::Minimum(vec![self, other]),
945 },
946 }
947 }
948}
949
950impl Add for RankExpr {
951 type Output = RankExpr;
952
953 fn add(self, rhs: Self) -> Self::Output {
954 match self {
955 RankExpr::Summation(mut exprs) => match rhs {
956 RankExpr::Summation(rhs_exprs) => {
957 exprs.extend(rhs_exprs);
958 RankExpr::Summation(exprs)
959 }
960 _ => {
961 exprs.push(rhs);
962 RankExpr::Summation(exprs)
963 }
964 },
965 _ => match rhs {
966 RankExpr::Summation(mut exprs) => {
967 exprs.insert(0, self);
968 RankExpr::Summation(exprs)
969 }
970 _ => RankExpr::Summation(vec![self, rhs]),
971 },
972 }
973 }
974}
975
976impl Add<f32> for RankExpr {
977 type Output = RankExpr;
978
979 fn add(self, rhs: f32) -> Self::Output {
980 self + RankExpr::Value(rhs)
981 }
982}
983
984impl Add<RankExpr> for f32 {
985 type Output = RankExpr;
986
987 fn add(self, rhs: RankExpr) -> Self::Output {
988 RankExpr::Value(self) + rhs
989 }
990}
991
992impl Sub for RankExpr {
993 type Output = RankExpr;
994
995 fn sub(self, rhs: Self) -> Self::Output {
996 RankExpr::Subtraction {
997 left: Box::new(self),
998 right: Box::new(rhs),
999 }
1000 }
1001}
1002
1003impl Sub<f32> for RankExpr {
1004 type Output = RankExpr;
1005
1006 fn sub(self, rhs: f32) -> Self::Output {
1007 self - RankExpr::Value(rhs)
1008 }
1009}
1010
1011impl Sub<RankExpr> for f32 {
1012 type Output = RankExpr;
1013
1014 fn sub(self, rhs: RankExpr) -> Self::Output {
1015 RankExpr::Value(self) - rhs
1016 }
1017}
1018
1019impl Mul for RankExpr {
1020 type Output = RankExpr;
1021
1022 fn mul(self, rhs: Self) -> Self::Output {
1023 match self {
1024 RankExpr::Multiplication(mut exprs) => match rhs {
1025 RankExpr::Multiplication(rhs_exprs) => {
1026 exprs.extend(rhs_exprs);
1027 RankExpr::Multiplication(exprs)
1028 }
1029 _ => {
1030 exprs.push(rhs);
1031 RankExpr::Multiplication(exprs)
1032 }
1033 },
1034 _ => match rhs {
1035 RankExpr::Multiplication(mut exprs) => {
1036 exprs.insert(0, self);
1037 RankExpr::Multiplication(exprs)
1038 }
1039 _ => RankExpr::Multiplication(vec![self, rhs]),
1040 },
1041 }
1042 }
1043}
1044
1045impl Mul<f32> for RankExpr {
1046 type Output = RankExpr;
1047
1048 fn mul(self, rhs: f32) -> Self::Output {
1049 self * RankExpr::Value(rhs)
1050 }
1051}
1052
1053impl Mul<RankExpr> for f32 {
1054 type Output = RankExpr;
1055
1056 fn mul(self, rhs: RankExpr) -> Self::Output {
1057 RankExpr::Value(self) * rhs
1058 }
1059}
1060
1061impl Div for RankExpr {
1062 type Output = RankExpr;
1063
1064 fn div(self, rhs: Self) -> Self::Output {
1065 RankExpr::Division {
1066 left: Box::new(self),
1067 right: Box::new(rhs),
1068 }
1069 }
1070}
1071
1072impl Div<f32> for RankExpr {
1073 type Output = RankExpr;
1074
1075 fn div(self, rhs: f32) -> Self::Output {
1076 self / RankExpr::Value(rhs)
1077 }
1078}
1079
1080impl Div<RankExpr> for f32 {
1081 type Output = RankExpr;
1082
1083 fn div(self, rhs: RankExpr) -> Self::Output {
1084 RankExpr::Value(self) / rhs
1085 }
1086}
1087
1088impl Neg for RankExpr {
1089 type Output = RankExpr;
1090
1091 fn neg(self) -> Self::Output {
1092 RankExpr::Value(-1.0) * self
1093 }
1094}
1095
1096impl From<f32> for RankExpr {
1097 fn from(v: f32) -> Self {
1098 RankExpr::Value(v)
1099 }
1100}
1101
1102impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1103 type Error = QueryConversionError;
1104
1105 fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1106 match proto_expr.rank {
1107 Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1108 Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1109 }
1110 Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1111 let left = div.left.ok_or(QueryConversionError::field("left"))?;
1112 let right = div.right.ok_or(QueryConversionError::field("right"))?;
1113 Ok(RankExpr::Division {
1114 left: Box::new(RankExpr::try_from(*left)?),
1115 right: Box::new(RankExpr::try_from(*right)?),
1116 })
1117 }
1118 Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1119 RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1120 ),
1121 Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1122 let query = knn
1123 .query
1124 .ok_or(QueryConversionError::field("query"))?
1125 .try_into()?;
1126 Ok(RankExpr::Knn {
1127 query,
1128 key: Key::from(knn.key),
1129 limit: knn.limit,
1130 default: knn.default,
1131 return_rank: knn.return_rank,
1132 })
1133 }
1134 Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1135 Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1136 }
1137 Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1138 let exprs = max
1139 .exprs
1140 .into_iter()
1141 .map(RankExpr::try_from)
1142 .collect::<Result<Vec<_>, _>>()?;
1143 Ok(RankExpr::Maximum(exprs))
1144 }
1145 Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1146 let exprs = min
1147 .exprs
1148 .into_iter()
1149 .map(RankExpr::try_from)
1150 .collect::<Result<Vec<_>, _>>()?;
1151 Ok(RankExpr::Minimum(exprs))
1152 }
1153 Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1154 let exprs = mul
1155 .exprs
1156 .into_iter()
1157 .map(RankExpr::try_from)
1158 .collect::<Result<Vec<_>, _>>()?;
1159 Ok(RankExpr::Multiplication(exprs))
1160 }
1161 Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1162 let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1163 let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1164 Ok(RankExpr::Subtraction {
1165 left: Box::new(RankExpr::try_from(*left)?),
1166 right: Box::new(RankExpr::try_from(*right)?),
1167 })
1168 }
1169 Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1170 let exprs = sum
1171 .exprs
1172 .into_iter()
1173 .map(RankExpr::try_from)
1174 .collect::<Result<Vec<_>, _>>()?;
1175 Ok(RankExpr::Summation(exprs))
1176 }
1177 Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1178 None => Err(QueryConversionError::field("rank")),
1179 }
1180 }
1181}
1182
1183impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1184 type Error = QueryConversionError;
1185
1186 fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1187 let proto_rank = match rank_expr {
1188 RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1189 chroma_proto::RankExpr::try_from(*expr)?,
1190 )),
1191 RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1192 Box::new(chroma_proto::rank_expr::RankPair {
1193 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1194 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1195 }),
1196 ),
1197 RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1198 Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1199 ),
1200 RankExpr::Knn {
1201 query,
1202 key,
1203 limit,
1204 default,
1205 return_rank,
1206 } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1207 query: Some(query.try_into()?),
1208 key: key.to_string(),
1209 limit,
1210 default,
1211 return_rank,
1212 }),
1213 RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1214 chroma_proto::RankExpr::try_from(*expr)?,
1215 )),
1216 RankExpr::Maximum(exprs) => {
1217 let proto_exprs = exprs
1218 .into_iter()
1219 .map(chroma_proto::RankExpr::try_from)
1220 .collect::<Result<Vec<_>, _>>()?;
1221 chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1222 exprs: proto_exprs,
1223 })
1224 }
1225 RankExpr::Minimum(exprs) => {
1226 let proto_exprs = exprs
1227 .into_iter()
1228 .map(chroma_proto::RankExpr::try_from)
1229 .collect::<Result<Vec<_>, _>>()?;
1230 chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1231 exprs: proto_exprs,
1232 })
1233 }
1234 RankExpr::Multiplication(exprs) => {
1235 let proto_exprs = exprs
1236 .into_iter()
1237 .map(chroma_proto::RankExpr::try_from)
1238 .collect::<Result<Vec<_>, _>>()?;
1239 chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1240 exprs: proto_exprs,
1241 })
1242 }
1243 RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1244 Box::new(chroma_proto::rank_expr::RankPair {
1245 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1246 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1247 }),
1248 ),
1249 RankExpr::Summation(exprs) => {
1250 let proto_exprs = exprs
1251 .into_iter()
1252 .map(chroma_proto::RankExpr::try_from)
1253 .collect::<Result<Vec<_>, _>>()?;
1254 chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1255 exprs: proto_exprs,
1256 })
1257 }
1258 RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1259 };
1260
1261 Ok(chroma_proto::RankExpr {
1262 rank: Some(proto_rank),
1263 })
1264 }
1265}
1266
1267#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
1268#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1269pub enum Key {
1270 Document,
1272 Embedding,
1273 Metadata,
1274 Score,
1275 MetadataField(String),
1276}
1277
1278impl Serialize for Key {
1279 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1280 where
1281 S: serde::Serializer,
1282 {
1283 match self {
1284 Key::Document => serializer.serialize_str("#document"),
1285 Key::Embedding => serializer.serialize_str("#embedding"),
1286 Key::Metadata => serializer.serialize_str("#metadata"),
1287 Key::Score => serializer.serialize_str("#score"),
1288 Key::MetadataField(field) => serializer.serialize_str(field),
1289 }
1290 }
1291}
1292
1293impl<'de> Deserialize<'de> for Key {
1294 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1295 where
1296 D: Deserializer<'de>,
1297 {
1298 let s = String::deserialize(deserializer)?;
1299 Ok(Key::from(s))
1300 }
1301}
1302
1303impl fmt::Display for Key {
1304 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1305 match self {
1306 Key::Document => write!(f, "#document"),
1307 Key::Embedding => write!(f, "#embedding"),
1308 Key::Metadata => write!(f, "#metadata"),
1309 Key::Score => write!(f, "#score"),
1310 Key::MetadataField(field) => write!(f, "{}", field),
1311 }
1312 }
1313}
1314
1315impl From<&str> for Key {
1316 fn from(s: &str) -> Self {
1317 match s {
1318 "#document" => Key::Document,
1319 "#embedding" => Key::Embedding,
1320 "#metadata" => Key::Metadata,
1321 "#score" => Key::Score,
1322 field => Key::MetadataField(field.to_string()),
1324 }
1325 }
1326}
1327
1328impl From<String> for Key {
1329 fn from(s: String) -> Self {
1330 Key::from(s.as_str())
1331 }
1332}
1333
1334impl Key {
1335 pub fn field(name: impl Into<String>) -> Self {
1337 Key::MetadataField(name.into())
1338 }
1339
1340 pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1342 Where::Metadata(MetadataExpression {
1343 key: self.to_string(),
1344 comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1345 })
1346 }
1347
1348 pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1350 Where::Metadata(MetadataExpression {
1351 key: self.to_string(),
1352 comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1353 })
1354 }
1355
1356 pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1358 Where::Metadata(MetadataExpression {
1359 key: self.to_string(),
1360 comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1361 })
1362 }
1363
1364 pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1366 Where::Metadata(MetadataExpression {
1367 key: self.to_string(),
1368 comparison: MetadataComparison::Primitive(
1369 PrimitiveOperator::GreaterThanOrEqual,
1370 value.into(),
1371 ),
1372 })
1373 }
1374
1375 pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1377 Where::Metadata(MetadataExpression {
1378 key: self.to_string(),
1379 comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1380 })
1381 }
1382
1383 pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1385 Where::Metadata(MetadataExpression {
1386 key: self.to_string(),
1387 comparison: MetadataComparison::Primitive(
1388 PrimitiveOperator::LessThanOrEqual,
1389 value.into(),
1390 ),
1391 })
1392 }
1393
1394 pub fn is_in<I, T>(self, values: I) -> Where
1397 where
1398 I: IntoIterator<Item = T>,
1399 Vec<T>: Into<MetadataSetValue>,
1400 {
1401 let vec: Vec<T> = values.into_iter().collect();
1402 Where::Metadata(MetadataExpression {
1403 key: self.to_string(),
1404 comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1405 })
1406 }
1407
1408 pub fn not_in<I, T>(self, values: I) -> Where
1411 where
1412 I: IntoIterator<Item = T>,
1413 Vec<T>: Into<MetadataSetValue>,
1414 {
1415 let vec: Vec<T> = values.into_iter().collect();
1416 Where::Metadata(MetadataExpression {
1417 key: self.to_string(),
1418 comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
1419 })
1420 }
1421
1422 pub fn contains<S: Into<String>>(self, text: S) -> Where {
1424 Where::Document(DocumentExpression {
1425 operator: DocumentOperator::Contains,
1426 pattern: text.into(),
1427 })
1428 }
1429
1430 pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
1432 Where::Document(DocumentExpression {
1433 operator: DocumentOperator::NotContains,
1434 pattern: text.into(),
1435 })
1436 }
1437
1438 pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
1440 Where::Document(DocumentExpression {
1441 operator: DocumentOperator::Regex,
1442 pattern: pattern.into(),
1443 })
1444 }
1445
1446 pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
1448 Where::Document(DocumentExpression {
1449 operator: DocumentOperator::NotRegex,
1450 pattern: pattern.into(),
1451 })
1452 }
1453}
1454
/// A set of [`Key`]s, convertible to and from `chroma_proto::SelectOperator`.
///
/// Presumably names the parts of each record a search should return; confirm
/// against the operator that consumes it.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Select {
    /// Keys in the selection. Falls back to the empty set when the field is
    /// missing from the deserialized input.
    #[serde(default)]
    pub keys: HashSet<Key>,
}
1460
1461impl TryFrom<chroma_proto::SelectOperator> for Select {
1462 type Error = QueryConversionError;
1463
1464 fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
1465 let keys = value
1466 .keys
1467 .into_iter()
1468 .map(|key| {
1469 serde_json::from_value(serde_json::Value::String(key))
1471 .map_err(|_| QueryConversionError::field("keys"))
1472 })
1473 .collect::<Result<HashSet<_>, _>>()?;
1474
1475 Ok(Self { keys })
1476 }
1477}
1478
1479impl TryFrom<Select> for chroma_proto::SelectOperator {
1480 type Error = QueryConversionError;
1481
1482 fn try_from(value: Select) -> Result<Self, Self::Error> {
1483 let keys = value
1484 .keys
1485 .into_iter()
1486 .map(|key| {
1487 serde_json::to_value(&key)
1489 .ok()
1490 .and_then(|v| v.as_str().map(String::from))
1491 .ok_or(QueryConversionError::field("keys"))
1492 })
1493 .collect::<Result<Vec<_>, _>>()?;
1494
1495 Ok(Self { keys })
1496 }
1497}
1498
/// One record in a search result.
///
/// Only `id` is mandatory; every other part is optional. Presumably a part is
/// populated only when the matching [`Key`] was selected; confirm against the
/// code that builds these records.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SearchRecord {
    /// Record identifier.
    pub id: String,
    /// Document text, if present.
    pub document: Option<String>,
    /// Embedding values, if present.
    pub embedding: Option<Vec<f32>>,
    /// Record metadata, if present.
    pub metadata: Option<Metadata>,
    /// Score attached to the record, if present.
    pub score: Option<f32>,
}
1508
1509impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
1510 type Error = QueryConversionError;
1511
1512 fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
1513 Ok(Self {
1514 id: value.id,
1515 document: value.document,
1516 embedding: value
1517 .embedding
1518 .map(|vec| vec.try_into().map(|(v, _)| v))
1519 .transpose()?,
1520 metadata: value.metadata.map(TryInto::try_into).transpose()?,
1521 score: value.score,
1522 })
1523 }
1524}
1525
1526impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
1527 type Error = QueryConversionError;
1528
1529 fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
1530 Ok(Self {
1531 id: value.id,
1532 document: value.document,
1533 embedding: value
1534 .embedding
1535 .map(|embedding| {
1536 let embedding_dimension = embedding.len();
1537 chroma_proto::Vector::try_from((
1538 embedding,
1539 ScalarEncoding::FLOAT32,
1540 embedding_dimension,
1541 ))
1542 })
1543 .transpose()?,
1544 metadata: value.metadata.map(Into::into),
1545 score: value.score,
1546 })
1547 }
1548}
1549
/// The records of one entry in [`SearchResult::results`].
///
/// Presumably the records returned for a single search payload; confirm
/// against the producer.
#[derive(Clone, Debug, Default)]
pub struct SearchPayloadResult {
    /// Records of this result, in the order they were received or produced.
    pub records: Vec<SearchRecord>,
}
1554
1555impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
1556 type Error = QueryConversionError;
1557
1558 fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
1559 Ok(Self {
1560 records: value
1561 .records
1562 .into_iter()
1563 .map(TryInto::try_into)
1564 .collect::<Result<_, _>>()?,
1565 })
1566 }
1567}
1568
1569impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
1570 type Error = QueryConversionError;
1571
1572 fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
1573 Ok(Self {
1574 records: value
1575 .records
1576 .into_iter()
1577 .map(TryInto::try_into)
1578 .collect::<Result<Vec<_>, _>>()?,
1579 })
1580 }
1581}
1582
/// The outcome of a search: a list of payload results plus a byte counter.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// One entry per payload result, in the order received or produced.
    pub results: Vec<SearchPayloadResult>,
    /// Byte counter carried through unchanged to and from the proto message.
    /// NOTE(review): presumably the number of log bytes pulled while serving
    /// the search; confirm with the producer.
    pub pulled_log_bytes: u64,
}
1588
1589impl SearchResult {
1590 pub fn size_bytes(&self) -> u64 {
1591 self.results
1592 .iter()
1593 .flat_map(|result| {
1594 result.records.iter().map(|record| {
1595 (record.id.len()
1596 + record
1597 .document
1598 .as_ref()
1599 .map(|doc| doc.len())
1600 .unwrap_or_default()
1601 + record
1602 .embedding
1603 .as_ref()
1604 .map(|emb| size_of_val(&emb[..]))
1605 .unwrap_or_default()
1606 + record
1607 .metadata
1608 .as_ref()
1609 .map(logical_size_of_metadata)
1610 .unwrap_or_default()
1611 + record.score.as_ref().map(size_of_val).unwrap_or_default())
1612 as u64
1613 })
1614 })
1615 .sum()
1616 }
1617}
1618
1619impl TryFrom<chroma_proto::SearchResult> for SearchResult {
1620 type Error = QueryConversionError;
1621
1622 fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
1623 Ok(Self {
1624 results: value
1625 .results
1626 .into_iter()
1627 .map(TryInto::try_into)
1628 .collect::<Result<_, _>>()?,
1629 pulled_log_bytes: value.pulled_log_bytes,
1630 })
1631 }
1632}
1633
1634impl TryFrom<SearchResult> for chroma_proto::SearchResult {
1635 type Error = QueryConversionError;
1636
1637 fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
1638 Ok(Self {
1639 results: value
1640 .results
1641 .into_iter()
1642 .map(TryInto::try_into)
1643 .collect::<Result<Vec<_>, _>>()?,
1644 pulled_log_bytes: value.pulled_log_bytes,
1645 })
1646 }
1647}
1648
1649pub fn rrf(
1652 ranks: Vec<RankExpr>,
1653 k: Option<u32>,
1654 weights: Option<Vec<f32>>,
1655 normalize: bool,
1656) -> Result<RankExpr, QueryConversionError> {
1657 let k = k.unwrap_or(60);
1658
1659 if ranks.is_empty() {
1660 return Err(QueryConversionError::validation(
1661 "RRF requires at least one rank expression",
1662 ));
1663 }
1664
1665 let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
1666
1667 if weights.len() != ranks.len() {
1668 return Err(QueryConversionError::validation(format!(
1669 "RRF weights length ({}) must match ranks length ({})",
1670 weights.len(),
1671 ranks.len()
1672 )));
1673 }
1674
1675 let weights = if normalize {
1676 let sum: f32 = weights.iter().sum();
1677 if sum == 0.0 {
1678 return Err(QueryConversionError::validation(
1679 "RRF weights sum to zero, cannot normalize",
1680 ));
1681 }
1682 weights.into_iter().map(|w| w / sum).collect()
1683 } else {
1684 weights
1685 };
1686
1687 let terms: Vec<RankExpr> = weights
1688 .into_iter()
1689 .zip(ranks)
1690 .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
1691 .collect();
1692
1693 let sum = terms
1696 .into_iter()
1697 .reduce(|a, b| a + b)
1698 .unwrap_or(RankExpr::Value(0.0));
1699 Ok(-sum)
1700}
1701
// Unit tests for key parsing, query-vector and select (de)serialization,
// filter JSON parsing, and the k-way `Merge` operator.
#[cfg(test)]
mod tests {
    use super::*;

    /// `Key::from` maps the reserved `#…` names to their variants and every
    /// other string to `Key::MetadataField`, for both `&str` and `String`.
    #[test]
    fn test_key_from_string() {
        // Reserved names.
        assert_eq!(Key::from("#document"), Key::Document);
        assert_eq!(Key::from("#embedding"), Key::Embedding);
        assert_eq!(Key::from("#metadata"), Key::Metadata);
        assert_eq!(Key::from("#score"), Key::Score);

        // Anything else is a metadata field.
        assert_eq!(
            Key::from("custom_field"),
            Key::MetadataField("custom_field".to_string())
        );
        assert_eq!(
            Key::from("author"),
            Key::MetadataField("author".to_string())
        );

        // Owned strings go through the same mapping.
        assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
        assert_eq!(
            Key::from("year".to_string()),
            Key::MetadataField("year".to_string())
        );
    }

    /// A dense query vector survives a round trip through its proto form.
    #[test]
    fn test_query_vector_dense_proto_conversion() {
        let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let query_vector = QueryVector::Dense(dense_vec.clone());

        // Domain -> proto.
        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();

        // Proto -> domain.
        let converted: QueryVector = proto.try_into().unwrap();

        assert_eq!(converted, query_vector);
        if let QueryVector::Dense(v) = converted {
            assert_eq!(v, dense_vec);
        } else {
            panic!("Expected dense vector");
        }
    }

    /// A sparse query vector survives a round trip through its proto form.
    #[test]
    fn test_query_vector_sparse_proto_conversion() {
        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
        let query_vector = QueryVector::Sparse(sparse.clone());

        // Domain -> proto.
        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();

        // Proto -> domain.
        let converted: QueryVector = proto.try_into().unwrap();

        assert_eq!(converted, query_vector);
        if let QueryVector::Sparse(s) = converted {
            assert_eq!(s, sparse);
        } else {
            panic!("Expected sparse vector");
        }
    }

    /// `Filter` deserializes from a range of where-clause JSON shapes. In
    /// every case `query_ids` stays `None`; ids given through `#id` end up in
    /// the where clause.
    #[test]
    fn test_filter_json_deserialization() {
        // Simple direct equality on a metadata field.
        let simple_where = r#"{"author": "John Doe"}"#;
        let filter: Filter = serde_json::from_str(simple_where).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // Id filter expressed with `#id` and `$in`.
        let id_filter_json = serde_json::json!({
            "#id": {
                "$in": ["doc1", "doc2", "doc3"]
            }
        });
        let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // `$and` combining an id filter, a nested `$or`, and comparisons.
        let complex_json = serde_json::json!({
            "$and": [
                {
                    "#id": {
                        "$in": ["doc1", "doc2", "doc3"]
                    }
                },
                {
                    "$or": [
                        {
                            "author": {
                                "$eq": "John Doe"
                            }
                        },
                        {
                            "author": {
                                "$eq": "Jane Smith"
                            }
                        }
                    ]
                },
                {
                    "year": {
                        "$gte": 2020
                    }
                },
                {
                    "tags": {
                        "$contains": "machine-learning"
                    }
                }
            ]
        });

        let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // The parsed tree mirrors the JSON: AND with four children.
        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
            assert_eq!(composite.children.len(), 4);

            // The second child is the nested OR with two branches.
            if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
                assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
                assert_eq!(or_composite.children.len(), 2);
            } else {
                panic!("Expected OR composite in second child");
            }
        } else {
            panic!("Expected AND composite where clause");
        }

        // Mix of `$ne`, `$gt`, `$lt` and `$lte` under one `$and`.
        let mixed_operators_json = serde_json::json!({
            "$and": [
                {
                    "status": {
                        "$ne": "deleted"
                    }
                },
                {
                    "score": {
                        "$gt": 0.5
                    }
                },
                {
                    "score": {
                        "$lt": 0.9
                    }
                },
                {
                    "priority": {
                        "$lte": 10
                    }
                }
            ]
        });

        let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // Three levels of nesting: OR of ANDs, one of which contains an OR.
        let deeply_nested_json = serde_json::json!({
            "$or": [
                {
                    "$and": [
                        {
                            "#id": {
                                "$in": ["id1", "id2"]
                            }
                        },
                        {
                            "$or": [
                                {
                                    "category": "tech"
                                },
                                {
                                    "category": "science"
                                }
                            ]
                        }
                    ]
                },
                {
                    "$and": [
                        {
                            "author": "Admin"
                        },
                        {
                            "published": true
                        }
                    ]
                }
            ]
        });

        let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // Top level is an OR whose two children are both ANDs.
        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
            assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
            assert_eq!(composite.children.len(), 2);

            for child in &composite.children {
                if let crate::metadata::Where::Composite(and_composite) = child {
                    assert_eq!(
                        and_composite.operator,
                        crate::metadata::BooleanOperator::And
                    );
                } else {
                    panic!("Expected AND composite in OR children");
                }
            }
        } else {
            panic!("Expected OR composite at top level");
        }

        // A single id matched with `$eq`.
        let single_id_json = serde_json::json!({
            "#id": {
                "$eq": "single-doc-id"
            }
        });

        let filter: Filter = serde_json::from_value(single_id_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // An empty object yields a filter with no where clause at all.
        let empty_json = serde_json::json!({});
        let filter: Filter = serde_json::from_value(empty_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert_eq!(filter.where_clause, None);

        // Larger example: id filter, `$not_contains`, and an OR of two ANDs.
        let advanced_json = serde_json::json!({
            "$and": [
                {
                    "#id": {
                        "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
                    }
                },
                {
                    "tags": {
                        "$not_contains": "deprecated"
                    }
                },
                {
                    "$or": [
                        {
                            "$and": [
                                {
                                    "confidence": {
                                        "$gte": 0.8
                                    }
                                },
                                {
                                    "verified": true
                                }
                            ]
                        },
                        {
                            "$and": [
                                {
                                    "confidence": {
                                        "$gte": 0.6
                                    }
                                },
                                {
                                    "confidence": {
                                        "$lt": 0.8
                                    }
                                },
                                {
                                    "reviews": {
                                        "$gte": 5
                                    }
                                }
                            ]
                        }
                    ]
                }
            ]
        });

        let filter: Filter = serde_json::from_value(advanced_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // Only the top-level shape is checked here: AND with three children.
        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
            assert_eq!(composite.children.len(), 3);
        } else {
            panic!("Expected AND composite at top level");
        }
    }

    /// `Limit` round-trips through JSON with both fields preserved.
    #[test]
    fn test_limit_json_serialization() {
        let limit = Limit {
            offset: 10,
            limit: Some(20),
        };

        let json = serde_json::to_string(&limit).unwrap();
        let deserialized: Limit = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.offset, limit.offset);
        assert_eq!(deserialized.limit, limit.limit);
    }

    /// Dense and sparse query vectors both round-trip through JSON.
    #[test]
    fn test_query_vector_json_serialization() {
        // Dense variant.
        let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
        let json = serde_json::to_string(&dense).unwrap();
        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, dense);

        // Sparse variant.
        let sparse = QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]));
        let json = serde_json::to_string(&sparse).unwrap();
        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, sparse);
    }

    /// Keys serialize to their textual form, parse back from it, and a whole
    /// `Select` round-trips through JSON.
    #[test]
    fn test_select_key_json_serialization() {
        use std::collections::HashSet;

        // Reserved keys serialize with their `#` prefix.
        let doc_key = Key::Document;
        assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");

        let embed_key = Key::Embedding;
        assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");

        let meta_key = Key::Metadata;
        assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");

        let score_key = Key::Score;
        assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");

        // Metadata fields serialize as the bare field name.
        let custom_key = Key::MetadataField("custom_key".to_string());
        assert_eq!(
            serde_json::to_string(&custom_key).unwrap(),
            "\"custom_key\""
        );

        // Deserialization applies the inverse mapping.
        let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
        assert!(matches!(deserialized, Key::Document));

        let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
        assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));

        // A `Select` holding mixed keys round-trips as a whole.
        let mut keys = HashSet::new();
        keys.insert(Key::Document);
        keys.insert(Key::Embedding);
        keys.insert(Key::MetadataField("author".to_string()));

        let select = Select { keys };
        let json = serde_json::to_string(&select).unwrap();
        let deserialized: Select = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.keys.len(), 3);
        assert!(deserialized.keys.contains(&Key::Document));
        assert!(deserialized.keys.contains(&Key::Embedding));
        assert!(deserialized
            .keys
            .contains(&Key::MetadataField("author".to_string())));
    }

    /// Merging `Reverse`-wrapped integers yields the five smallest inner
    /// values, smallest first.
    #[test]
    fn test_merge_basic_integers() {
        use std::cmp::Reverse;

        let merge = Merge { k: 5 };

        // Each input is ascending by inner value, i.e. descending as `Reverse`.
        let input = vec![
            vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
            vec![Reverse(2), Reverse(5), Reverse(8)],
            vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
        ];

        let result = merge.merge(input);

        // Output is capped at k elements.
        assert_eq!(result.len(), 5);
        assert_eq!(
            result,
            vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
        );
    }

    /// Merging descending `u32` lists yields the six largest values in
    /// descending order.
    #[test]
    fn test_merge_u32_descending() {
        let merge = Merge { k: 6 };

        // Every input list is already sorted in descending order.
        let input = vec![
            vec![100u32, 75, 50, 25],
            vec![90, 60, 30],
            vec![95, 85, 70, 40, 10],
        ];

        let result = merge.merge(input);

        // Output is capped at k elements.
        assert_eq!(result.len(), 6);
        assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
    }

    /// Same as above with signed values, including negatives.
    #[test]
    fn test_merge_i32_descending() {
        let merge = Merge { k: 5 };

        // Descending inputs that cross zero.
        let input = vec![
            vec![50i32, 10, -10, -50],
            vec![30, 0, -30],
            vec![40, 20, -20, -40],
        ];

        let result = merge.merge(input);

        // Output is capped at k elements.
        assert_eq!(result.len(), 5);
        assert_eq!(result, vec![50, 40, 30, 20, 10]);
    }

    /// Values repeated within or across inputs appear only once in the
    /// expected output.
    #[test]
    fn test_merge_with_duplicates() {
        let merge = Merge { k: 10 };

        // 100, 80 and 60 each occur more than once.
        let input = vec![
            vec![100u32, 80, 80, 60, 40],
            vec![90, 80, 50, 30],
            vec![100, 70, 60, 20],
        ];

        let result = merge.merge(input);

        // Nine distinct values, fewer than k, each listed once.
        assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
    }

    /// Empty input lists are tolerated, alone or mixed with non-empty ones.
    #[test]
    fn test_merge_empty_vectors() {
        let merge = Merge { k: 5 };

        // All inputs empty.
        let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
        let result = merge.merge(input);
        assert_eq!(result.len(), 0);

        // Empty lists interleaved with non-empty ones.
        let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
        let result = merge.merge(input);
        assert_eq!(result, vec![1000, 850, 750, 600, 500]);

        // Exactly one non-empty list.
        let input = vec![vec![], vec![100i32, 50, 25], vec![]];
        let result = merge.merge(input);
        assert_eq!(result, vec![100, 50, 25]);
    }

    /// Boundary values of `k`: zero, one, and more than the available items.
    #[test]
    fn test_merge_k_boundary_conditions() {
        // k = 0 yields nothing.
        let merge = Merge { k: 0 };
        let input = vec![vec![100u32, 50], vec![75, 25]];
        let result = merge.merge(input);
        assert_eq!(result.len(), 0);

        // k = 1 yields only the overall maximum.
        let merge = Merge { k: 1 };
        let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
        let result = merge.merge(input);
        assert_eq!(result, vec![1000]);

        // k larger than the total yields every element.
        let merge = Merge { k: 100 };
        let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
        let result = merge.merge(input);
        assert_eq!(result, vec![10000, 8000, 5000, 3000]);
    }

    /// Merge works on any ordered type; strings compare lexicographically.
    #[test]
    fn test_merge_with_strings() {
        let merge = Merge { k: 4 };

        // Each list is in descending lexicographic order.
        let input = vec![
            vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
            vec!["elephant".to_string(), "banana".to_string()],
            vec!["fish".to_string(), "cat".to_string()],
        ];

        let result = merge.merge(input);

        // The four lexicographically largest strings, largest first.
        assert_eq!(
            result,
            vec![
                "zebra".to_string(),
                "fish".to_string(),
                "elephant".to_string(),
                "dog".to_string()
            ]
        );
    }

    /// Merge works with a user-defined type whose ordering is derived
    /// (field by field: `value`, then `id`).
    #[test]
    fn test_merge_with_custom_struct() {
        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
        struct Score {
            value: i32,
            id: String,
        }

        let merge = Merge { k: 3 };

        // Each list is in descending order of `value`.
        let input = vec![
            vec![
                Score {
                    value: 100,
                    id: "a".to_string(),
                },
                Score {
                    value: 80,
                    id: "b".to_string(),
                },
                Score {
                    value: 60,
                    id: "c".to_string(),
                },
            ],
            vec![
                Score {
                    value: 90,
                    id: "d".to_string(),
                },
                Score {
                    value: 70,
                    id: "e".to_string(),
                },
            ],
            vec![
                Score {
                    value: 95,
                    id: "f".to_string(),
                },
                Score {
                    value: 85,
                    id: "g".to_string(),
                },
            ],
        ];

        let result = merge.merge(input);

        assert_eq!(result.len(), 3);
        assert_eq!(
            result[0],
            Score {
                value: 100,
                id: "a".to_string()
            }
        );
        assert_eq!(
            result[1],
            Score {
                value: 95,
                id: "f".to_string()
            }
        );
        assert_eq!(
            result[2],
            Score {
                value: 90,
                id: "d".to_string()
            }
        );
    }

    /// The merged output is itself ordered: descending as `Reverse`, which is
    /// ascending by inner value.
    #[test]
    fn test_merge_preserves_order() {
        use std::cmp::Reverse;

        let merge = Merge { k: 10 };

        // Interleaved inputs, each ascending by inner value (descending as
        // `Reverse`).
        let input = vec![
            vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
            vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
            vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
        ];

        let result = merge.merge(input);

        // Check the ordering pairwise, in both the wrapper's and the inner
        // value's terms.
        for i in 1..result.len() {
            assert!(
                result[i - 1] >= result[i],
                "Output should be in descending Reverse order"
            );
            assert!(
                result[i - 1].0 <= result[i].0,
                "Inner values should be in ascending order"
            );
        }

        // The ten smallest inner values, in order.
        assert_eq!(
            result,
            vec![
                Reverse(1),
                Reverse(2),
                Reverse(3),
                Reverse(4),
                Reverse(5),
                Reverse(6),
                Reverse(7),
                Reverse(8),
                Reverse(9),
                Reverse(10)
            ]
        );
    }

    /// A single input list is simply truncated to `k` elements.
    #[test]
    fn test_merge_single_vector() {
        let merge = Merge { k: 3 };

        // One descending list, longer than k.
        let input = vec![vec![1000u64, 800, 600, 400, 200]];

        let result = merge.merge(input);

        assert_eq!(result, vec![1000, 800, 600]);
    }

    /// When every element is the same value, the expected output holds it
    /// once.
    #[test]
    fn test_merge_all_same_values() {
        let merge = Merge { k: 5 };

        // Nine copies of 42 spread over three lists.
        let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];

        let result = merge.merge(input);

        // A single 42 even though k is 5.
        assert_eq!(result, vec![42]);
    }

    /// The same behaviour holds for other integer types.
    #[test]
    fn test_merge_mixed_types_sizes() {
        // `usize` elements.
        let merge = Merge { k: 4 };
        let input = vec![
            vec![1000usize, 500, 100],
            vec![800, 300],
            vec![900, 600, 200],
        ];
        let result = merge.merge(input);
        assert_eq!(result, vec![1000, 900, 800, 600]);

        // Signed elements spanning negative values.
        let merge = Merge { k: 5 };
        let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
        let result = merge.merge(input);
        assert_eq!(result, vec![15, 10, 5, 0, -5]);
    }
}