1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4 cmp::Ordering,
5 collections::{BinaryHeap, HashSet},
6 fmt,
7 hash::Hash,
8 ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13 chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14 ContainsOperator, DocumentExpression, DocumentOperator, Metadata, MetadataComparison,
15 MetadataExpression, MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding,
16 SetOperator, SparseVector, Where,
17};
18
19use super::error::QueryConversionError;
20
/// Alias for the unit type `()`.
// NOTE(review): presumably the input type of an operator that needs no upstream data — confirm against callers.
pub type InitialInput = ();
22
/// Scan operator: wraps the collection together with its metadata, record,
/// and vector segments.
#[derive(Clone, Debug)]
pub struct Scan {
    /// The collection and the segments that belong to it.
    pub collection_and_segments: CollectionAndSegments,
}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33 type Error = QueryConversionError;
34
35 fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36 Ok(Self {
37 collection_and_segments: CollectionAndSegments {
38 collection: value
39 .collection
40 .ok_or(QueryConversionError::field("collection"))?
41 .try_into()?,
42 metadata_segment: value
43 .metadata
44 .ok_or(QueryConversionError::field("metadata segment"))?
45 .try_into()?,
46 record_segment: value
47 .record
48 .ok_or(QueryConversionError::field("record segment"))?
49 .try_into()?,
50 vector_segment: value
51 .knn
52 .ok_or(QueryConversionError::field("vector segment"))?
53 .try_into()?,
54 },
55 })
56 }
57}
58
/// Error produced when a [`Scan`] cannot be converted to its protobuf form.
#[derive(Debug, Error)]
pub enum ScanToProtoError {
    /// The collection could not be converted; wraps the underlying error.
    #[error("Could not convert collection to proto")]
    CollectionToProto(#[from] crate::CollectionToProtoError),
}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66 type Error = ScanToProtoError;
67
68 fn try_from(value: Scan) -> Result<Self, Self::Error> {
69 Ok(Self {
70 collection: Some(value.collection_and_segments.collection.try_into()?),
71 knn: Some(value.collection_and_segments.vector_segment.into()),
72 metadata: Some(value.collection_and_segments.metadata_segment.into()),
73 record: Some(value.collection_and_segments.record_segment.into()),
74 })
75 }
76}
77
/// Result of a count query.
#[derive(Clone, Debug)]
pub struct CountResult {
    /// The reported count.
    pub count: u32,
    /// Byte counter carried alongside the result.
    // NOTE(review): presumably the number of log bytes pulled to answer the query — confirm.
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Size in bytes of the `count` field only; `pulled_log_bytes` is not
    /// included.
    pub fn size_bytes(&self) -> u64 {
        let count_width = std::mem::size_of_val(&self.count);
        count_width as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91 fn from(value: chroma_proto::CountResult) -> Self {
92 Self {
93 count: value.count,
94 pulled_log_bytes: value.pulled_log_bytes,
95 }
96 }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100 fn from(value: CountResult) -> Self {
101 Self {
102 count: value.count,
103 pulled_log_bytes: value.pulled_log_bytes,
104 }
105 }
106}
107
/// Parameters for fetching log entries of a collection.
#[derive(Clone, Debug)]
pub struct FetchLog {
    /// Identifier of the collection.
    pub collection_uuid: CollectionUuid,
    /// Optional upper bound on the number of entries to fetch.
    // NOTE(review): presumably `None` means "no limit" — confirm against the consumer.
    pub maximum_fetch_count: Option<u32>,
    /// Log offset id to start from.
    pub start_log_offset_id: u32,
}
120
/// Filter operator: an optional list of ids and an optional where clause.
///
/// When serialized, `query_ids` is expressed as an `#id` set-membership
/// predicate and combined with `where_clause`; deserialization only ever
/// populates `where_clause`.
#[derive(Clone, Debug, Default)]
pub struct Filter {
    /// Ids to restrict the query to, if any.
    pub query_ids: Option<Vec<String>>,
    /// Where clause to apply, if any.
    pub where_clause: Option<Where>,
}
174
175impl Serialize for Filter {
176 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177 where
178 S: Serializer,
179 {
180 match (&self.query_ids, &self.where_clause) {
184 (None, None) => {
185 let map = serializer.serialize_map(Some(0))?;
187 map.end()
188 }
189 (None, Some(where_clause)) => {
190 where_clause.serialize(serializer)
192 }
193 (Some(ids), None) => {
194 let id_where = Where::Metadata(MetadataExpression {
196 key: "#id".to_string(),
197 comparison: MetadataComparison::Set(
198 SetOperator::In,
199 MetadataSetValue::Str(ids.clone()),
200 ),
201 });
202 id_where.serialize(serializer)
203 }
204 (Some(ids), Some(where_clause)) => {
205 let id_where = Where::Metadata(MetadataExpression {
207 key: "#id".to_string(),
208 comparison: MetadataComparison::Set(
209 SetOperator::In,
210 MetadataSetValue::Str(ids.clone()),
211 ),
212 });
213 let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214 combined.serialize(serializer)
215 }
216 }
217 }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where
223 D: Deserializer<'de>,
224 {
225 let where_json = Value::deserialize(deserializer)?;
227 let where_clause =
228 if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229 None
230 } else {
231 Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232 };
233
234 Ok(Filter {
235 query_ids: None, where_clause,
237 })
238 }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242 type Error = QueryConversionError;
243
244 fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245 let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246 let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247 let where_clause = match (where_metadata, where_document) {
248 (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249 (Some(w), None) | (None, Some(w)) => Some(w),
250 _ => None,
251 };
252
253 Ok(Self {
254 query_ids: value.ids.map(|uids| uids.ids),
255 where_clause,
256 })
257 }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261 type Error = QueryConversionError;
262
263 fn try_from(value: Filter) -> Result<Self, Self::Error> {
264 Ok(Self {
265 ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266 r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267 where_document: None,
268 })
269 }
270}
271
/// A single k-nearest-neighbour query: one embedding plus a fetch count.
#[derive(Clone, Debug)]
pub struct Knn {
    /// Query embedding.
    pub embedding: Vec<f32>,
    /// Fetch count carried over from the batch.
    // NOTE(review): presumably the number of neighbours to retrieve — confirm.
    pub fetch: u32,
}

/// Splits a batch into one [`Knn`] per embedding, in order, each carrying the
/// batch's `fetch` value. An empty batch yields an empty vector.
impl From<KnnBatch> for Vec<Knn> {
    fn from(value: KnnBatch) -> Self {
        // Destructure up front so the closure captures only the `Copy` fetch
        // count rather than reading a field of the partially moved `value`.
        let KnnBatch { embeddings, fetch } = value;
        embeddings
            .into_iter()
            .map(|embedding| Knn { embedding, fetch })
            .collect()
    }
}

/// A batch of k-nearest-neighbour queries sharing one fetch count.
#[derive(Clone, Debug)]
pub struct KnnBatch {
    /// One query embedding per entry.
    pub embeddings: Vec<Vec<f32>>,
    /// Fetch count applied to every query in the batch.
    pub fetch: u32,
}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308 type Error = QueryConversionError;
309
310 fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311 Ok(Self {
312 embeddings: value
313 .embeddings
314 .into_iter()
315 .map(|vec| vec.try_into().map(|(v, _)| v))
316 .collect::<Result<_, _>>()?,
317 fetch: value.fetch,
318 })
319 }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323 type Error = QueryConversionError;
324
325 fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326 Ok(Self {
327 embeddings: value
328 .embeddings
329 .into_iter()
330 .map(|embedding| {
331 let dim = embedding.len();
332 chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333 })
334 .collect::<Result<_, _>>()?,
335 fetch: value.fetch,
336 })
337 }
338}
339
/// Limit operator: an offset and an optional limit.
///
/// Both fields fall back to their `Default` value (`0` and `None`) when
/// missing from the serialized input.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Limit {
    /// Offset; defaults to `0` when absent.
    // NOTE(review): presumably the number of leading results to skip — confirm.
    #[serde(default)]
    pub offset: u32,
    /// Optional limit; defaults to `None` when absent.
    #[serde(default)]
    pub limit: Option<u32>,
}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381 fn from(value: chroma_proto::LimitOperator) -> Self {
382 Self {
383 offset: value.offset,
384 limit: value.limit,
385 }
386 }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390 fn from(value: Limit) -> Self {
391 Self {
392 offset: value.offset,
393 limit: value.limit,
394 }
395 }
396}
397
/// Pairs a record's offset id with a measure computed for it.
///
/// Equality looks only at `offset_id`, whereas ordering compares `measure`
/// first (using the IEEE total order) and then `offset_id`. Two values can
/// therefore compare `==` without being `Ordering::Equal`.
#[derive(Clone, Copy, Debug)]
pub struct RecordMeasure {
    /// Offset id of the record.
    pub offset_id: u32,
    /// Measure associated with the record.
    pub measure: f32,
}

impl PartialEq for RecordMeasure {
    /// Two measures are equal when they refer to the same offset id,
    /// regardless of the measure value.
    fn eq(&self, other: &Self) -> bool {
        self.offset_id == other.offset_id
    }
}

impl Eq for RecordMeasure {}

impl Ord for RecordMeasure {
    /// Orders by `measure` (total order over `f32`), breaking ties by
    /// `offset_id`.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.measure.total_cmp(&other.measure) {
            Ordering::Equal => self.offset_id.cmp(&other.offset_id),
            decided => decided,
        }
    }
}

impl PartialOrd for RecordMeasure {
    /// Delegates to [`Ord::cmp`]; never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
426
/// Output of a KNN operator: a list of per-record measures.
#[derive(Debug, Default)]
pub struct KnnOutput {
    /// One [`RecordMeasure`] per returned record.
    pub distances: Vec<RecordMeasure>,
}
431
/// Merge operator: combines several batches into at most `k` elements.
#[derive(Clone, Debug)]
pub struct Merge {
    /// Maximum number of elements to return.
    pub k: u32,
}

impl Merge {
    /// K-way merges `input` into a single vector of at most `k` elements,
    /// largest first.
    ///
    /// The front element of every batch is placed in a max-heap; the largest
    /// is popped repeatedly and replaced by the next element of the batch it
    /// came from. An element equal to the one most recently emitted is
    /// skipped, so adjacent duplicates collapse into one.
    ///
    /// The result is globally ordered only if each batch is itself ordered
    /// from largest to smallest, since only batch fronts are compared.
    ///
    /// The output buffer is sized by the smaller of `k` and the number of
    /// available elements, so a very large `k` does not trigger a huge
    /// up-front allocation.
    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = self.k as usize;
        let available: usize = input.iter().map(Vec::len).sum();
        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the head of every non-empty batch, remembering
        // which batch each element came from.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut fusion = Vec::with_capacity(limit.min(available));
        while let Some((m, idx)) = max_heap.pop() {
            if fusion.len() >= limit {
                break;
            }
            // Refill from the batch that just yielded the popped element.
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            // Drop an element equal to the one emitted immediately before it.
            if fusion.last().is_some_and(|tail| tail == &m) {
                continue;
            }
            fusion.push(m);
        }
        fusion
    }
}
472
/// Projection operator: three boolean flags, all `false` by default.
// NOTE(review): presumably each flag selects whether that part of a record is returned — confirm against the executor.
#[derive(Clone, Debug, Default)]
pub struct Projection {
    /// Flag for the record's document.
    pub document: bool,
    /// Flag for the record's embedding.
    pub embedding: bool,
    /// Flag for the record's metadata.
    pub metadata: bool,
}
485
486impl From<chroma_proto::ProjectionOperator> for Projection {
487 fn from(value: chroma_proto::ProjectionOperator) -> Self {
488 Self {
489 document: value.document,
490 embedding: value.embedding,
491 metadata: value.metadata,
492 }
493 }
494}
495
496impl From<Projection> for chroma_proto::ProjectionOperator {
497 fn from(value: Projection) -> Self {
498 Self {
499 document: value.document,
500 embedding: value.embedding,
501 metadata: value.metadata,
502 }
503 }
504}
505
/// A projected record: its id plus optional document, embedding, and
/// metadata.
#[derive(Clone, Debug, PartialEq)]
pub struct ProjectionRecord {
    /// Record id.
    pub id: String,
    /// Document text, if present.
    pub document: Option<String>,
    /// Embedding values, if present.
    pub embedding: Option<Vec<f32>>,
    /// Metadata, if present.
    pub metadata: Option<Metadata>,
}
513
514impl ProjectionRecord {
515 pub fn size_bytes(&self) -> u64 {
516 (self.id.len()
517 + self
518 .document
519 .as_ref()
520 .map(|doc| doc.len())
521 .unwrap_or_default()
522 + self
523 .embedding
524 .as_ref()
525 .map(|emb| size_of_val(&emb[..]))
526 .unwrap_or_default()
527 + self
528 .metadata
529 .as_ref()
530 .map(logical_size_of_metadata)
531 .unwrap_or_default()) as u64
532 }
533}
534
// Marker impl on top of the derived `PartialEq`.
// NOTE(review): `embedding` holds `f32` values, so a record containing NaN does
// not compare equal to itself — confirm no caller relies on reflexivity there.
impl Eq for ProjectionRecord {}
536
537impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
538 type Error = QueryConversionError;
539
540 fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
541 Ok(Self {
542 id: value.id,
543 document: value.document,
544 embedding: value
545 .embedding
546 .map(|vec| vec.try_into().map(|(v, _)| v))
547 .transpose()?,
548 metadata: value.metadata.map(TryInto::try_into).transpose()?,
549 })
550 }
551}
552
553impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
554 type Error = QueryConversionError;
555
556 fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
557 Ok(Self {
558 id: value.id,
559 document: value.document,
560 embedding: value
561 .embedding
562 .map(|embedding| {
563 let embedding_dimension = embedding.len();
564 chroma_proto::Vector::try_from((
565 embedding,
566 ScalarEncoding::FLOAT32,
567 embedding_dimension,
568 ))
569 })
570 .transpose()?,
571 metadata: value.metadata.map(|metadata| metadata.into()),
572 })
573 }
574}
575
/// Output of a projection: the projected records.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProjectionOutput {
    /// Projected records.
    pub records: Vec<ProjectionRecord>,
}
580
/// Result of a get query: the projected records plus a byte counter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GetResult {
    /// Byte counter carried alongside the result.
    // NOTE(review): presumably the number of log bytes pulled to answer the query — confirm.
    pub pulled_log_bytes: u64,
    /// The projected records.
    pub result: ProjectionOutput,
}
586
587impl GetResult {
588 pub fn size_bytes(&self) -> u64 {
589 self.result
590 .records
591 .iter()
592 .map(ProjectionRecord::size_bytes)
593 .sum()
594 }
595}
596
597impl TryFrom<chroma_proto::GetResult> for GetResult {
598 type Error = QueryConversionError;
599
600 fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
601 Ok(Self {
602 pulled_log_bytes: value.pulled_log_bytes,
603 result: ProjectionOutput {
604 records: value
605 .records
606 .into_iter()
607 .map(TryInto::try_into)
608 .collect::<Result<_, _>>()?,
609 },
610 })
611 }
612}
613
614impl TryFrom<GetResult> for chroma_proto::GetResult {
615 type Error = QueryConversionError;
616
617 fn try_from(value: GetResult) -> Result<Self, Self::Error> {
618 Ok(Self {
619 pulled_log_bytes: value.pulled_log_bytes,
620 records: value
621 .result
622 .records
623 .into_iter()
624 .map(TryInto::try_into)
625 .collect::<Result<_, _>>()?,
626 })
627 }
628}
629
/// Projection for KNN results: a regular [`Projection`] plus a distance flag.
#[derive(Clone, Debug)]
pub struct KnnProjection {
    /// Projection applied to each record.
    pub projection: Projection,
    /// Distance flag.
    // NOTE(review): presumably controls whether the distance is returned with each record — confirm.
    pub distance: bool,
}
642
643impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
644 type Error = QueryConversionError;
645
646 fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
647 Ok(Self {
648 projection: value
649 .projection
650 .ok_or(QueryConversionError::field("projection"))?
651 .into(),
652 distance: value.distance,
653 })
654 }
655}
656
657impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
658 fn from(value: KnnProjection) -> Self {
659 Self {
660 projection: Some(value.projection.into()),
661 distance: value.distance,
662 }
663 }
664}
665
/// A projected record from a KNN query, with an optional distance.
#[derive(Clone, Debug)]
pub struct KnnProjectionRecord {
    /// The projected record.
    pub record: ProjectionRecord,
    /// Distance for this record, if present.
    pub distance: Option<f32>,
}
671
672impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
673 type Error = QueryConversionError;
674
675 fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
676 Ok(Self {
677 record: value
678 .record
679 .ok_or(QueryConversionError::field("record"))?
680 .try_into()?,
681 distance: value.distance,
682 })
683 }
684}
685
686impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
687 type Error = QueryConversionError;
688
689 fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
690 Ok(Self {
691 record: Some(value.record.try_into()?),
692 distance: value.distance,
693 })
694 }
695}
696
/// Output of a KNN projection: the projected records with their distances.
#[derive(Clone, Debug, Default)]
pub struct KnnProjectionOutput {
    /// Projected KNN records.
    pub records: Vec<KnnProjectionRecord>,
}
701
702impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
703 type Error = QueryConversionError;
704
705 fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
706 Ok(Self {
707 records: value
708 .records
709 .into_iter()
710 .map(TryInto::try_into)
711 .collect::<Result<_, _>>()?,
712 })
713 }
714}
715
716impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
717 type Error = QueryConversionError;
718
719 fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
720 Ok(Self {
721 records: value
722 .records
723 .into_iter()
724 .map(TryInto::try_into)
725 .collect::<Result<_, _>>()?,
726 })
727 }
728}
729
/// Result of a batched KNN query: one output per query plus a byte counter.
#[derive(Clone, Debug, Default)]
pub struct KnnBatchResult {
    /// Byte counter carried alongside the result.
    // NOTE(review): presumably the number of log bytes pulled to answer the query — confirm.
    pub pulled_log_bytes: u64,
    /// One projection output per entry.
    pub results: Vec<KnnProjectionOutput>,
}
735
736impl KnnBatchResult {
737 pub fn size_bytes(&self) -> u64 {
738 self.results
739 .iter()
740 .flat_map(|res| {
741 res.records
742 .iter()
743 .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
744 })
745 .sum()
746 }
747}
748
749impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
750 type Error = QueryConversionError;
751
752 fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
753 Ok(Self {
754 pulled_log_bytes: value.pulled_log_bytes,
755 results: value
756 .results
757 .into_iter()
758 .map(TryInto::try_into)
759 .collect::<Result<_, _>>()?,
760 })
761 }
762}
763
764impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
765 type Error = QueryConversionError;
766
767 fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
768 Ok(Self {
769 pulled_log_bytes: value.pulled_log_bytes,
770 results: value
771 .results
772 .into_iter()
773 .map(TryInto::try_into)
774 .collect::<Result<_, _>>()?,
775 })
776 }
777}
778
/// A query vector, either dense or sparse.
///
/// Serialized untagged: the payload of the variant is written directly, and
/// on input the variants are tried in declaration order (dense first).
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum QueryVector {
    /// Dense vector of `f32` values.
    Dense(Vec<f32>),
    /// Sparse vector.
    Sparse(SparseVector),
}
848
849impl TryFrom<chroma_proto::QueryVector> for QueryVector {
850 type Error = QueryConversionError;
851
852 fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
853 let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
854 match vector {
855 chroma_proto::query_vector::Vector::Dense(dense) => {
856 Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
857 }
858 chroma_proto::query_vector::Vector::Sparse(sparse) => {
859 Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| {
860 QueryConversionError::validation("sparse vector length mismatch")
861 })?))
862 }
863 }
864 }
865}
866
867impl TryFrom<QueryVector> for chroma_proto::QueryVector {
868 type Error = QueryConversionError;
869
870 fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
871 match value {
872 QueryVector::Dense(vec) => {
873 let dim = vec.len();
874 Ok(chroma_proto::QueryVector {
875 vector: Some(chroma_proto::query_vector::Vector::Dense(
876 chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
877 )),
878 })
879 }
880 QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
881 vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
882 }),
883 }
884 }
885}
886
887impl From<Vec<f32>> for QueryVector {
888 fn from(vec: Vec<f32>) -> Self {
889 QueryVector::Dense(vec)
890 }
891}
892
893impl From<SparseVector> for QueryVector {
894 fn from(sparse: SparseVector) -> Self {
895 QueryVector::Sparse(sparse)
896 }
897}
898
/// A single KNN lookup extracted from a rank expression: the query vector,
/// the key it targets, and the limit.
#[derive(Clone, Debug, PartialEq)]
pub struct KnnQuery {
    /// Vector to search with.
    pub query: QueryVector,
    /// Key the search is run against.
    pub key: Key,
    /// Limit copied from the originating `Knn` expression.
    pub limit: u32,
}
905
/// Rank operator: an optional ranking expression.
///
/// Serialized transparently, i.e. exactly as the inner `Option<RankExpr>`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Rank {
    /// The ranking expression, if any.
    pub expr: Option<RankExpr>,
}
941
942impl Rank {
943 pub fn knn_queries(&self) -> Vec<KnnQuery> {
944 self.expr
945 .as_ref()
946 .map(RankExpr::knn_queries)
947 .unwrap_or_default()
948 }
949}
950
951impl TryFrom<chroma_proto::RankOperator> for Rank {
952 type Error = QueryConversionError;
953
954 fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
955 Ok(Rank {
956 expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
957 })
958 }
959}
960
961impl TryFrom<Rank> for chroma_proto::RankOperator {
962 type Error = QueryConversionError;
963
964 fn try_from(rank: Rank) -> Result<Self, Self::Error> {
965 Ok(chroma_proto::RankOperator {
966 expr: rank.expr.map(TryInto::try_into).transpose()?,
967 })
968 }
969}
970
/// A ranking expression tree.
///
/// Each variant is serialized under the operator name given by its
/// `#[serde(rename = ...)]` attribute (for example `$abs`, `$knn`, `$sum`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RankExpr {
    /// `$abs`: wraps a single operand.
    #[serde(rename = "$abs")]
    Absolute(Box<RankExpr>),
    /// `$div`: a `left` and a `right` operand.
    #[serde(rename = "$div")]
    Division {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// `$exp`: wraps a single operand.
    #[serde(rename = "$exp")]
    Exponentiation(Box<RankExpr>),
    /// `$knn`: a KNN lookup leaf.
    #[serde(rename = "$knn")]
    Knn {
        /// Vector to search with (required).
        query: QueryVector,
        /// Key to search against; when omitted, `RankExpr::default_knn_key` supplies it.
        #[serde(default = "RankExpr::default_knn_key")]
        key: Key,
        /// Limit; when omitted, `RankExpr::default_knn_limit` supplies it.
        #[serde(default = "RankExpr::default_knn_limit")]
        limit: u32,
        /// Optional default value; `None` when omitted.
        // NOTE(review): presumably the score used for records missing from the KNN result — confirm.
        #[serde(default)]
        default: Option<f32>,
        /// Rank flag; `false` when omitted.
        // NOTE(review): presumably selects rank position instead of distance as the value — confirm.
        #[serde(default)]
        return_rank: bool,
    },
    /// `$log`: wraps a single operand.
    #[serde(rename = "$log")]
    Logarithm(Box<RankExpr>),
    /// `$max`: a list of operands.
    #[serde(rename = "$max")]
    Maximum(Vec<RankExpr>),
    /// `$min`: a list of operands.
    #[serde(rename = "$min")]
    Minimum(Vec<RankExpr>),
    /// `$mul`: a list of operands.
    #[serde(rename = "$mul")]
    Multiplication(Vec<RankExpr>),
    /// `$sub`: a `left` and a `right` operand.
    #[serde(rename = "$sub")]
    Subtraction {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// `$sum`: a list of operands.
    #[serde(rename = "$sum")]
    Summation(Vec<RankExpr>),
    /// `$val`: a constant.
    #[serde(rename = "$val")]
    Value(f32),
}
1174
1175impl RankExpr {
1176 pub fn default_knn_key() -> Key {
1177 Key::Embedding
1178 }
1179
1180 pub fn default_knn_limit() -> u32 {
1181 16
1182 }
1183
1184 pub fn knn_queries(&self) -> Vec<KnnQuery> {
1185 match self {
1186 RankExpr::Absolute(expr)
1187 | RankExpr::Exponentiation(expr)
1188 | RankExpr::Logarithm(expr) => expr.knn_queries(),
1189 RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1190 .knn_queries()
1191 .into_iter()
1192 .chain(right.knn_queries())
1193 .collect(),
1194 RankExpr::Maximum(exprs)
1195 | RankExpr::Minimum(exprs)
1196 | RankExpr::Multiplication(exprs)
1197 | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1198 RankExpr::Value(_) => Vec::new(),
1199 RankExpr::Knn {
1200 query,
1201 key,
1202 limit,
1203 default: _,
1204 return_rank: _,
1205 } => vec![KnnQuery {
1206 query: query.clone(),
1207 key: key.clone(),
1208 limit: *limit,
1209 }],
1210 }
1211 }
1212
1213 pub fn exp(self) -> Self {
1233 RankExpr::Exponentiation(Box::new(self))
1234 }
1235
1236 pub fn log(self) -> Self {
1257 RankExpr::Logarithm(Box::new(self))
1258 }
1259
1260 pub fn abs(self) -> Self {
1287 RankExpr::Absolute(Box::new(self))
1288 }
1289
1290 pub fn max(self, other: impl Into<RankExpr>) -> Self {
1314 let other = other.into();
1315
1316 match self {
1317 RankExpr::Maximum(mut exprs) => match other {
1318 RankExpr::Maximum(other_exprs) => {
1319 exprs.extend(other_exprs);
1320 RankExpr::Maximum(exprs)
1321 }
1322 _ => {
1323 exprs.push(other);
1324 RankExpr::Maximum(exprs)
1325 }
1326 },
1327 _ => match other {
1328 RankExpr::Maximum(mut exprs) => {
1329 exprs.insert(0, self);
1330 RankExpr::Maximum(exprs)
1331 }
1332 _ => RankExpr::Maximum(vec![self, other]),
1333 },
1334 }
1335 }
1336
1337 pub fn min(self, other: impl Into<RankExpr>) -> Self {
1361 let other = other.into();
1362
1363 match self {
1364 RankExpr::Minimum(mut exprs) => match other {
1365 RankExpr::Minimum(other_exprs) => {
1366 exprs.extend(other_exprs);
1367 RankExpr::Minimum(exprs)
1368 }
1369 _ => {
1370 exprs.push(other);
1371 RankExpr::Minimum(exprs)
1372 }
1373 },
1374 _ => match other {
1375 RankExpr::Minimum(mut exprs) => {
1376 exprs.insert(0, self);
1377 RankExpr::Minimum(exprs)
1378 }
1379 _ => RankExpr::Minimum(vec![self, other]),
1380 },
1381 }
1382 }
1383}
1384
1385impl Add for RankExpr {
1386 type Output = RankExpr;
1387
1388 fn add(self, rhs: Self) -> Self::Output {
1389 match self {
1390 RankExpr::Summation(mut exprs) => match rhs {
1391 RankExpr::Summation(rhs_exprs) => {
1392 exprs.extend(rhs_exprs);
1393 RankExpr::Summation(exprs)
1394 }
1395 _ => {
1396 exprs.push(rhs);
1397 RankExpr::Summation(exprs)
1398 }
1399 },
1400 _ => match rhs {
1401 RankExpr::Summation(mut exprs) => {
1402 exprs.insert(0, self);
1403 RankExpr::Summation(exprs)
1404 }
1405 _ => RankExpr::Summation(vec![self, rhs]),
1406 },
1407 }
1408 }
1409}
1410
1411impl Add<f32> for RankExpr {
1412 type Output = RankExpr;
1413
1414 fn add(self, rhs: f32) -> Self::Output {
1415 self + RankExpr::Value(rhs)
1416 }
1417}
1418
1419impl Add<RankExpr> for f32 {
1420 type Output = RankExpr;
1421
1422 fn add(self, rhs: RankExpr) -> Self::Output {
1423 RankExpr::Value(self) + rhs
1424 }
1425}
1426
1427impl Sub for RankExpr {
1428 type Output = RankExpr;
1429
1430 fn sub(self, rhs: Self) -> Self::Output {
1431 RankExpr::Subtraction {
1432 left: Box::new(self),
1433 right: Box::new(rhs),
1434 }
1435 }
1436}
1437
1438impl Sub<f32> for RankExpr {
1439 type Output = RankExpr;
1440
1441 fn sub(self, rhs: f32) -> Self::Output {
1442 self - RankExpr::Value(rhs)
1443 }
1444}
1445
1446impl Sub<RankExpr> for f32 {
1447 type Output = RankExpr;
1448
1449 fn sub(self, rhs: RankExpr) -> Self::Output {
1450 RankExpr::Value(self) - rhs
1451 }
1452}
1453
1454impl Mul for RankExpr {
1455 type Output = RankExpr;
1456
1457 fn mul(self, rhs: Self) -> Self::Output {
1458 match self {
1459 RankExpr::Multiplication(mut exprs) => match rhs {
1460 RankExpr::Multiplication(rhs_exprs) => {
1461 exprs.extend(rhs_exprs);
1462 RankExpr::Multiplication(exprs)
1463 }
1464 _ => {
1465 exprs.push(rhs);
1466 RankExpr::Multiplication(exprs)
1467 }
1468 },
1469 _ => match rhs {
1470 RankExpr::Multiplication(mut exprs) => {
1471 exprs.insert(0, self);
1472 RankExpr::Multiplication(exprs)
1473 }
1474 _ => RankExpr::Multiplication(vec![self, rhs]),
1475 },
1476 }
1477 }
1478}
1479
1480impl Mul<f32> for RankExpr {
1481 type Output = RankExpr;
1482
1483 fn mul(self, rhs: f32) -> Self::Output {
1484 self * RankExpr::Value(rhs)
1485 }
1486}
1487
1488impl Mul<RankExpr> for f32 {
1489 type Output = RankExpr;
1490
1491 fn mul(self, rhs: RankExpr) -> Self::Output {
1492 RankExpr::Value(self) * rhs
1493 }
1494}
1495
1496impl Div for RankExpr {
1497 type Output = RankExpr;
1498
1499 fn div(self, rhs: Self) -> Self::Output {
1500 RankExpr::Division {
1501 left: Box::new(self),
1502 right: Box::new(rhs),
1503 }
1504 }
1505}
1506
1507impl Div<f32> for RankExpr {
1508 type Output = RankExpr;
1509
1510 fn div(self, rhs: f32) -> Self::Output {
1511 self / RankExpr::Value(rhs)
1512 }
1513}
1514
1515impl Div<RankExpr> for f32 {
1516 type Output = RankExpr;
1517
1518 fn div(self, rhs: RankExpr) -> Self::Output {
1519 RankExpr::Value(self) / rhs
1520 }
1521}
1522
1523impl Neg for RankExpr {
1524 type Output = RankExpr;
1525
1526 fn neg(self) -> Self::Output {
1527 RankExpr::Value(-1.0) * self
1528 }
1529}
1530
1531impl From<f32> for RankExpr {
1532 fn from(v: f32) -> Self {
1533 RankExpr::Value(v)
1534 }
1535}
1536
1537impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1538 type Error = QueryConversionError;
1539
1540 fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1541 match proto_expr.rank {
1542 Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1543 Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1544 }
1545 Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1546 let left = div.left.ok_or(QueryConversionError::field("left"))?;
1547 let right = div.right.ok_or(QueryConversionError::field("right"))?;
1548 Ok(RankExpr::Division {
1549 left: Box::new(RankExpr::try_from(*left)?),
1550 right: Box::new(RankExpr::try_from(*right)?),
1551 })
1552 }
1553 Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1554 RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1555 ),
1556 Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1557 let query = knn
1558 .query
1559 .ok_or(QueryConversionError::field("query"))?
1560 .try_into()?;
1561 Ok(RankExpr::Knn {
1562 query,
1563 key: Key::from(knn.key),
1564 limit: knn.limit,
1565 default: knn.default,
1566 return_rank: knn.return_rank,
1567 })
1568 }
1569 Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1570 Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1571 }
1572 Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1573 let exprs = max
1574 .exprs
1575 .into_iter()
1576 .map(RankExpr::try_from)
1577 .collect::<Result<Vec<_>, _>>()?;
1578 Ok(RankExpr::Maximum(exprs))
1579 }
1580 Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1581 let exprs = min
1582 .exprs
1583 .into_iter()
1584 .map(RankExpr::try_from)
1585 .collect::<Result<Vec<_>, _>>()?;
1586 Ok(RankExpr::Minimum(exprs))
1587 }
1588 Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1589 let exprs = mul
1590 .exprs
1591 .into_iter()
1592 .map(RankExpr::try_from)
1593 .collect::<Result<Vec<_>, _>>()?;
1594 Ok(RankExpr::Multiplication(exprs))
1595 }
1596 Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1597 let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1598 let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1599 Ok(RankExpr::Subtraction {
1600 left: Box::new(RankExpr::try_from(*left)?),
1601 right: Box::new(RankExpr::try_from(*right)?),
1602 })
1603 }
1604 Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1605 let exprs = sum
1606 .exprs
1607 .into_iter()
1608 .map(RankExpr::try_from)
1609 .collect::<Result<Vec<_>, _>>()?;
1610 Ok(RankExpr::Summation(exprs))
1611 }
1612 Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1613 None => Err(QueryConversionError::field("rank")),
1614 }
1615 }
1616}
1617
1618impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1619 type Error = QueryConversionError;
1620
1621 fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1622 let proto_rank = match rank_expr {
1623 RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1624 chroma_proto::RankExpr::try_from(*expr)?,
1625 )),
1626 RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1627 Box::new(chroma_proto::rank_expr::RankPair {
1628 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1629 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1630 }),
1631 ),
1632 RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1633 Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1634 ),
1635 RankExpr::Knn {
1636 query,
1637 key,
1638 limit,
1639 default,
1640 return_rank,
1641 } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1642 query: Some(query.try_into()?),
1643 key: key.to_string(),
1644 limit,
1645 default,
1646 return_rank,
1647 }),
1648 RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1649 chroma_proto::RankExpr::try_from(*expr)?,
1650 )),
1651 RankExpr::Maximum(exprs) => {
1652 let proto_exprs = exprs
1653 .into_iter()
1654 .map(chroma_proto::RankExpr::try_from)
1655 .collect::<Result<Vec<_>, _>>()?;
1656 chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1657 exprs: proto_exprs,
1658 })
1659 }
1660 RankExpr::Minimum(exprs) => {
1661 let proto_exprs = exprs
1662 .into_iter()
1663 .map(chroma_proto::RankExpr::try_from)
1664 .collect::<Result<Vec<_>, _>>()?;
1665 chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1666 exprs: proto_exprs,
1667 })
1668 }
1669 RankExpr::Multiplication(exprs) => {
1670 let proto_exprs = exprs
1671 .into_iter()
1672 .map(chroma_proto::RankExpr::try_from)
1673 .collect::<Result<Vec<_>, _>>()?;
1674 chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1675 exprs: proto_exprs,
1676 })
1677 }
1678 RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1679 Box::new(chroma_proto::rank_expr::RankPair {
1680 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1681 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1682 }),
1683 ),
1684 RankExpr::Summation(exprs) => {
1685 let proto_exprs = exprs
1686 .into_iter()
1687 .map(chroma_proto::RankExpr::try_from)
1688 .collect::<Result<Vec<_>, _>>()?;
1689 chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1690 exprs: proto_exprs,
1691 })
1692 }
1693 RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1694 };
1695
1696 Ok(chroma_proto::RankExpr {
1697 rank: Some(proto_rank),
1698 })
1699 }
1700}
1701
/// Names a field of a record: one of the built-in fields or a user metadata
/// field.
///
/// The built-in variants have the textual forms `#document`, `#embedding`,
/// `#metadata`, and `#score`; any other string maps to
/// [`Key::MetadataField`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Key {
    /// Textual form `#document`.
    Document,
    /// Textual form `#embedding`.
    Embedding,
    /// Textual form `#metadata`.
    Metadata,
    /// Textual form `#score`.
    Score,
    /// A metadata field; its textual form is the field name itself.
    MetadataField(String),
}
1776
1777impl Serialize for Key {
1778 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1779 where
1780 S: serde::Serializer,
1781 {
1782 match self {
1783 Key::Document => serializer.serialize_str("#document"),
1784 Key::Embedding => serializer.serialize_str("#embedding"),
1785 Key::Metadata => serializer.serialize_str("#metadata"),
1786 Key::Score => serializer.serialize_str("#score"),
1787 Key::MetadataField(field) => serializer.serialize_str(field),
1788 }
1789 }
1790}
1791
1792impl<'de> Deserialize<'de> for Key {
1793 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1794 where
1795 D: Deserializer<'de>,
1796 {
1797 let s = String::deserialize(deserializer)?;
1798 Ok(Key::from(s))
1799 }
1800}
1801
1802impl fmt::Display for Key {
1803 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1804 match self {
1805 Key::Document => write!(f, "#document"),
1806 Key::Embedding => write!(f, "#embedding"),
1807 Key::Metadata => write!(f, "#metadata"),
1808 Key::Score => write!(f, "#score"),
1809 Key::MetadataField(field) => write!(f, "{}", field),
1810 }
1811 }
1812}
1813
1814impl From<&str> for Key {
1815 fn from(s: &str) -> Self {
1816 match s {
1817 "#document" => Key::Document,
1818 "#embedding" => Key::Embedding,
1819 "#metadata" => Key::Metadata,
1820 "#score" => Key::Score,
1821 field => Key::MetadataField(field.to_string()),
1823 }
1824 }
1825}
1826
1827impl From<String> for Key {
1828 fn from(s: String) -> Self {
1829 Key::from(s.as_str())
1830 }
1831}
1832
1833impl Key {
1834 pub fn field(name: impl Into<String>) -> Self {
1846 Key::MetadataField(name.into())
1847 }
1848
1849 pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1866 Where::Metadata(MetadataExpression {
1867 key: self.to_string(),
1868 comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1869 })
1870 }
1871
1872 pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1883 Where::Metadata(MetadataExpression {
1884 key: self.to_string(),
1885 comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1886 })
1887 }
1888
1889 pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1900 Where::Metadata(MetadataExpression {
1901 key: self.to_string(),
1902 comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1903 })
1904 }
1905
1906 pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1917 Where::Metadata(MetadataExpression {
1918 key: self.to_string(),
1919 comparison: MetadataComparison::Primitive(
1920 PrimitiveOperator::GreaterThanOrEqual,
1921 value.into(),
1922 ),
1923 })
1924 }
1925
1926 pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1937 Where::Metadata(MetadataExpression {
1938 key: self.to_string(),
1939 comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1940 })
1941 }
1942
1943 pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1954 Where::Metadata(MetadataExpression {
1955 key: self.to_string(),
1956 comparison: MetadataComparison::Primitive(
1957 PrimitiveOperator::LessThanOrEqual,
1958 value.into(),
1959 ),
1960 })
1961 }
1962
1963 pub fn is_in<I, T>(self, values: I) -> Where
1983 where
1984 I: IntoIterator<Item = T>,
1985 Vec<T>: Into<MetadataSetValue>,
1986 {
1987 let vec: Vec<T> = values.into_iter().collect();
1988 Where::Metadata(MetadataExpression {
1989 key: self.to_string(),
1990 comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1991 })
1992 }
1993
1994 pub fn not_in<I, T>(self, values: I) -> Where
2010 where
2011 I: IntoIterator<Item = T>,
2012 Vec<T>: Into<MetadataSetValue>,
2013 {
2014 let vec: Vec<T> = values.into_iter().collect();
2015 Where::Metadata(MetadataExpression {
2016 key: self.to_string(),
2017 comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2018 })
2019 }
2020
2021 pub fn contains<S: Into<String>>(self, text: S) -> Where {
2037 Where::Document(DocumentExpression {
2038 operator: DocumentOperator::Contains,
2039 pattern: text.into(),
2040 })
2041 }
2042
2043 pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2059 Where::Document(DocumentExpression {
2060 operator: DocumentOperator::NotContains,
2061 pattern: text.into(),
2062 })
2063 }
2064
2065 pub fn contains_value<T: Into<MetadataValue>>(self, value: T) -> Where {
2078 Where::Metadata(MetadataExpression {
2079 key: self.to_string(),
2080 comparison: MetadataComparison::ArrayContains(ContainsOperator::Contains, value.into()),
2081 })
2082 }
2083
2084 pub fn not_contains_value<T: Into<MetadataValue>>(self, value: T) -> Where {
2096 Where::Metadata(MetadataExpression {
2097 key: self.to_string(),
2098 comparison: MetadataComparison::ArrayContains(
2099 ContainsOperator::NotContains,
2100 value.into(),
2101 ),
2102 })
2103 }
2104
2105 pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2122 Where::Document(DocumentExpression {
2123 operator: DocumentOperator::Regex,
2124 pattern: pattern.into(),
2125 })
2126 }
2127
2128 pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2144 Where::Document(DocumentExpression {
2145 operator: DocumentOperator::NotRegex,
2146 pattern: pattern.into(),
2147 })
2148 }
2149}
2150
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
/// A set of [`Key`]s, convertible to and from `chroma_proto::SelectOperator`.
///
/// NOTE(review): presumably the fields to return for each search record —
/// confirm against the code that executes the operator.
pub struct Select {
    /// Selected keys. Stored in a `HashSet`, so duplicates collapse and order
    /// is not preserved. Defaults to the empty set when missing from the
    /// serialized form.
    #[serde(default)]
    pub keys: HashSet<Key>,
}
2209
2210impl TryFrom<chroma_proto::SelectOperator> for Select {
2211 type Error = QueryConversionError;
2212
2213 fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2214 let keys = value
2215 .keys
2216 .into_iter()
2217 .map(|key| {
2218 serde_json::from_value(serde_json::Value::String(key))
2220 .map_err(|_| QueryConversionError::field("keys"))
2221 })
2222 .collect::<Result<HashSet<_>, _>>()?;
2223
2224 Ok(Self { keys })
2225 }
2226}
2227
2228impl TryFrom<Select> for chroma_proto::SelectOperator {
2229 type Error = QueryConversionError;
2230
2231 fn try_from(value: Select) -> Result<Self, Self::Error> {
2232 let keys = value
2233 .keys
2234 .into_iter()
2235 .map(|key| {
2236 serde_json::to_value(&key)
2238 .ok()
2239 .and_then(|v| v.as_str().map(String::from))
2240 .ok_or(QueryConversionError::field("keys"))
2241 })
2242 .collect::<Result<Vec<_>, _>>()?;
2243
2244 Ok(Self { keys })
2245 }
2246}
2247
#[derive(Clone, Debug, Deserialize, Serialize)]
/// Aggregation attached to a [`GroupBy`].
///
/// Serialized as an externally tagged object: `{"$min_k": {"keys": [...], "k": n}}`
/// or `{"$max_k": {...}}`.
///
/// NOTE(review): presumably `MinK`/`MaxK` keep the `k` records per group with
/// the smallest/largest values of `keys` — confirm against the executor.
pub enum Aggregate {
    /// Tagged `$min_k` in the serialized form.
    #[serde(rename = "$min_k")]
    MinK {
        /// Keys the aggregate operates on, in the order given.
        keys: Vec<Key>,
        /// The `k` parameter of the aggregate.
        k: u32,
    },
    /// Tagged `$max_k` in the serialized form.
    #[serde(rename = "$max_k")]
    MaxK {
        /// Keys the aggregate operates on, in the order given.
        keys: Vec<Key>,
        /// The `k` parameter of the aggregate.
        k: u32,
    },
}
2305
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
/// Grouping keys plus an optional [`Aggregate`], convertible to and from
/// `chroma_proto::GroupByOperator`.
///
/// Both fields fall back to their defaults (no keys, no aggregate) when
/// missing from the serialized form.
pub struct GroupBy {
    /// Keys to group by, in the order given.
    #[serde(default)]
    pub keys: Vec<Key>,
    /// Aggregate to apply, if any.
    #[serde(default)]
    pub aggregate: Option<Aggregate>,
}
2356
2357impl TryFrom<chroma_proto::Aggregate> for Aggregate {
2358 type Error = QueryConversionError;
2359
2360 fn try_from(value: chroma_proto::Aggregate) -> Result<Self, Self::Error> {
2361 match value
2362 .aggregate
2363 .ok_or(QueryConversionError::field("aggregate"))?
2364 {
2365 chroma_proto::aggregate::Aggregate::MinK(min_k) => {
2366 let keys = min_k.keys.into_iter().map(Key::from).collect();
2367 Ok(Aggregate::MinK { keys, k: min_k.k })
2368 }
2369 chroma_proto::aggregate::Aggregate::MaxK(max_k) => {
2370 let keys = max_k.keys.into_iter().map(Key::from).collect();
2371 Ok(Aggregate::MaxK { keys, k: max_k.k })
2372 }
2373 }
2374 }
2375}
2376
2377impl From<Aggregate> for chroma_proto::Aggregate {
2378 fn from(value: Aggregate) -> Self {
2379 let aggregate = match value {
2380 Aggregate::MinK { keys, k } => {
2381 chroma_proto::aggregate::Aggregate::MinK(chroma_proto::aggregate::MinK {
2382 keys: keys.into_iter().map(|k| k.to_string()).collect(),
2383 k,
2384 })
2385 }
2386 Aggregate::MaxK { keys, k } => {
2387 chroma_proto::aggregate::Aggregate::MaxK(chroma_proto::aggregate::MaxK {
2388 keys: keys.into_iter().map(|k| k.to_string()).collect(),
2389 k,
2390 })
2391 }
2392 };
2393
2394 chroma_proto::Aggregate {
2395 aggregate: Some(aggregate),
2396 }
2397 }
2398}
2399
2400impl TryFrom<chroma_proto::GroupByOperator> for GroupBy {
2401 type Error = QueryConversionError;
2402
2403 fn try_from(value: chroma_proto::GroupByOperator) -> Result<Self, Self::Error> {
2404 let keys = value.keys.into_iter().map(Key::from).collect();
2405 let aggregate = value.aggregate.map(TryInto::try_into).transpose()?;
2406
2407 Ok(Self { keys, aggregate })
2408 }
2409}
2410
2411impl TryFrom<GroupBy> for chroma_proto::GroupByOperator {
2412 type Error = QueryConversionError;
2413
2414 fn try_from(value: GroupBy) -> Result<Self, Self::Error> {
2415 let keys = value.keys.into_iter().map(|k| k.to_string()).collect();
2416 let aggregate = value.aggregate.map(Into::into);
2417
2418 Ok(Self { keys, aggregate })
2419 }
2420}
2421
#[derive(Clone, Debug, Deserialize, Serialize)]
/// One record of a search result.
///
/// Only `id` is mandatory; each of the other fields is `None` when absent.
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SearchRecord {
    /// Record identifier.
    pub id: String,
    /// Document text, if present.
    pub document: Option<String>,
    /// Embedding values, if present.
    pub embedding: Option<Vec<f32>>,
    /// Record metadata, if present.
    pub metadata: Option<Metadata>,
    /// Score, if present.
    pub score: Option<f32>,
}
2467
2468impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2469 type Error = QueryConversionError;
2470
2471 fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2472 Ok(Self {
2473 id: value.id,
2474 document: value.document,
2475 embedding: value
2476 .embedding
2477 .map(|vec| vec.try_into().map(|(v, _)| v))
2478 .transpose()?,
2479 metadata: value.metadata.map(TryInto::try_into).transpose()?,
2480 score: value.score,
2481 })
2482 }
2483}
2484
2485impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2486 type Error = QueryConversionError;
2487
2488 fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2489 Ok(Self {
2490 id: value.id,
2491 document: value.document,
2492 embedding: value
2493 .embedding
2494 .map(|embedding| {
2495 let embedding_dimension = embedding.len();
2496 chroma_proto::Vector::try_from((
2497 embedding,
2498 ScalarEncoding::FLOAT32,
2499 embedding_dimension,
2500 ))
2501 })
2502 .transpose()?,
2503 metadata: value.metadata.map(Into::into),
2504 score: value.score,
2505 })
2506 }
2507}
2508
#[derive(Clone, Debug, Default)]
/// The records of one entry in [`SearchResult::results`].
///
/// NOTE(review): presumably one instance per search payload in the request —
/// confirm against the caller.
pub struct SearchPayloadResult {
    /// Records, in the order they were received or produced.
    pub records: Vec<SearchRecord>,
}
2534
2535impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2536 type Error = QueryConversionError;
2537
2538 fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2539 Ok(Self {
2540 records: value
2541 .records
2542 .into_iter()
2543 .map(TryInto::try_into)
2544 .collect::<Result<_, _>>()?,
2545 })
2546 }
2547}
2548
2549impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2550 type Error = QueryConversionError;
2551
2552 fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2553 Ok(Self {
2554 records: value
2555 .records
2556 .into_iter()
2557 .map(TryInto::try_into)
2558 .collect::<Result<Vec<_>, _>>()?,
2559 })
2560 }
2561}
2562
#[derive(Clone, Debug)]
/// The complete result of a search, convertible to and from
/// `chroma_proto::SearchResult`.
pub struct SearchResult {
    /// Per-payload results, in order.
    pub results: Vec<SearchPayloadResult>,
    /// Copied unchanged to and from the proto message.
    /// NOTE(review): presumably the number of log bytes pulled while serving
    /// the search — confirm against the producer.
    pub pulled_log_bytes: u64,
}
2610
2611impl SearchResult {
2612 pub fn size_bytes(&self) -> u64 {
2613 self.results
2614 .iter()
2615 .flat_map(|result| {
2616 result.records.iter().map(|record| {
2617 (record.id.len()
2618 + record
2619 .document
2620 .as_ref()
2621 .map(|doc| doc.len())
2622 .unwrap_or_default()
2623 + record
2624 .embedding
2625 .as_ref()
2626 .map(|emb| size_of_val(&emb[..]))
2627 .unwrap_or_default()
2628 + record
2629 .metadata
2630 .as_ref()
2631 .map(logical_size_of_metadata)
2632 .unwrap_or_default()
2633 + record.score.as_ref().map(size_of_val).unwrap_or_default())
2634 as u64
2635 })
2636 })
2637 .sum()
2638 }
2639}
2640
2641impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2642 type Error = QueryConversionError;
2643
2644 fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2645 Ok(Self {
2646 results: value
2647 .results
2648 .into_iter()
2649 .map(TryInto::try_into)
2650 .collect::<Result<_, _>>()?,
2651 pulled_log_bytes: value.pulled_log_bytes,
2652 })
2653 }
2654}
2655
2656impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2657 type Error = QueryConversionError;
2658
2659 fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2660 Ok(Self {
2661 results: value
2662 .results
2663 .into_iter()
2664 .map(TryInto::try_into)
2665 .collect::<Result<Vec<_>, _>>()?,
2666 pulled_log_bytes: value.pulled_log_bytes,
2667 })
2668 }
2669}
2670
2671pub fn rrf(
2816 ranks: Vec<RankExpr>,
2817 k: Option<u32>,
2818 weights: Option<Vec<f32>>,
2819 normalize: bool,
2820) -> Result<RankExpr, QueryConversionError> {
2821 let k = k.unwrap_or(60);
2822
2823 if ranks.is_empty() {
2824 return Err(QueryConversionError::validation(
2825 "RRF requires at least one rank expression",
2826 ));
2827 }
2828
2829 let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2830
2831 if weights.len() != ranks.len() {
2832 return Err(QueryConversionError::validation(format!(
2833 "RRF weights length ({}) must match ranks length ({})",
2834 weights.len(),
2835 ranks.len()
2836 )));
2837 }
2838
2839 let weights = if normalize {
2840 let sum: f32 = weights.iter().sum();
2841 if sum == 0.0 {
2842 return Err(QueryConversionError::validation(
2843 "RRF weights sum to zero, cannot normalize",
2844 ));
2845 }
2846 weights.into_iter().map(|w| w / sum).collect()
2847 } else {
2848 weights
2849 };
2850
2851 let terms: Vec<RankExpr> = weights
2852 .into_iter()
2853 .zip(ranks)
2854 .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2855 .collect();
2856
2857 let sum = terms
2860 .into_iter()
2861 .reduce(|a, b| a + b)
2862 .unwrap_or(RankExpr::Value(0.0));
2863 Ok(-sum)
2864}
2865
2866#[cfg(test)]
2867mod tests {
2868 use super::*;
2869
2870 #[test]
2871 fn test_key_from_string() {
2872 assert_eq!(Key::from("#document"), Key::Document);
2874 assert_eq!(Key::from("#embedding"), Key::Embedding);
2875 assert_eq!(Key::from("#metadata"), Key::Metadata);
2876 assert_eq!(Key::from("#score"), Key::Score);
2877
2878 assert_eq!(
2880 Key::from("custom_field"),
2881 Key::MetadataField("custom_field".to_string())
2882 );
2883 assert_eq!(
2884 Key::from("author"),
2885 Key::MetadataField("author".to_string())
2886 );
2887
2888 assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
2890 assert_eq!(
2891 Key::from("year".to_string()),
2892 Key::MetadataField("year".to_string())
2893 );
2894 }
2895
2896 #[test]
2897 fn test_query_vector_dense_proto_conversion() {
2898 let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2899 let query_vector = QueryVector::Dense(dense_vec.clone());
2900
2901 let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2903
2904 let converted: QueryVector = proto.try_into().unwrap();
2906
2907 assert_eq!(converted, query_vector);
2908 if let QueryVector::Dense(v) = converted {
2909 assert_eq!(v, dense_vec);
2910 } else {
2911 panic!("Expected dense vector");
2912 }
2913 }
2914
2915 #[test]
2916 fn test_query_vector_sparse_proto_conversion() {
2917 let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2918 let query_vector = QueryVector::Sparse(sparse.clone());
2919
2920 let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2922
2923 let converted: QueryVector = proto.try_into().unwrap();
2925
2926 assert_eq!(converted, query_vector);
2927 if let QueryVector::Sparse(s) = converted {
2928 assert_eq!(s, sparse);
2929 } else {
2930 panic!("Expected sparse vector");
2931 }
2932 }
2933
2934 #[test]
2935 fn test_filter_json_deserialization() {
2936 let simple_where = r#"{"author": "John Doe"}"#;
2940 let filter: Filter = serde_json::from_str(simple_where).unwrap();
2941 assert_eq!(filter.query_ids, None);
2942 assert!(filter.where_clause.is_some());
2943
2944 let id_filter_json = serde_json::json!({
2946 "#id": {
2947 "$in": ["doc1", "doc2", "doc3"]
2948 }
2949 });
2950 let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
2951 assert_eq!(filter.query_ids, None);
2952 assert!(filter.where_clause.is_some());
2953
2954 let complex_json = serde_json::json!({
2956 "$and": [
2957 {
2958 "#id": {
2959 "$in": ["doc1", "doc2", "doc3"]
2960 }
2961 },
2962 {
2963 "$or": [
2964 {
2965 "author": {
2966 "$eq": "John Doe"
2967 }
2968 },
2969 {
2970 "author": {
2971 "$eq": "Jane Smith"
2972 }
2973 }
2974 ]
2975 },
2976 {
2977 "year": {
2978 "$gte": 2020
2979 }
2980 },
2981 {
2982 "tags": {
2983 "$contains": "machine-learning"
2984 }
2985 }
2986 ]
2987 });
2988
2989 let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
2990 assert_eq!(filter.query_ids, None);
2991 assert!(filter.where_clause.is_some());
2992
2993 if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2995 assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2996 assert_eq!(composite.children.len(), 4);
2997
2998 if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
3000 assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
3001 assert_eq!(or_composite.children.len(), 2);
3002 } else {
3003 panic!("Expected OR composite in second child");
3004 }
3005 } else {
3006 panic!("Expected AND composite where clause");
3007 }
3008
3009 let mixed_operators_json = serde_json::json!({
3011 "$and": [
3012 {
3013 "status": {
3014 "$ne": "deleted"
3015 }
3016 },
3017 {
3018 "score": {
3019 "$gt": 0.5
3020 }
3021 },
3022 {
3023 "score": {
3024 "$lt": 0.9
3025 }
3026 },
3027 {
3028 "priority": {
3029 "$lte": 10
3030 }
3031 }
3032 ]
3033 });
3034
3035 let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
3036 assert_eq!(filter.query_ids, None);
3037 assert!(filter.where_clause.is_some());
3038
3039 let deeply_nested_json = serde_json::json!({
3041 "$or": [
3042 {
3043 "$and": [
3044 {
3045 "#id": {
3046 "$in": ["id1", "id2"]
3047 }
3048 },
3049 {
3050 "$or": [
3051 {
3052 "category": "tech"
3053 },
3054 {
3055 "category": "science"
3056 }
3057 ]
3058 }
3059 ]
3060 },
3061 {
3062 "$and": [
3063 {
3064 "author": "Admin"
3065 },
3066 {
3067 "published": true
3068 }
3069 ]
3070 }
3071 ]
3072 });
3073
3074 let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
3075 assert_eq!(filter.query_ids, None);
3076 assert!(filter.where_clause.is_some());
3077
3078 if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3080 assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
3081 assert_eq!(composite.children.len(), 2);
3082
3083 for child in &composite.children {
3085 if let crate::metadata::Where::Composite(and_composite) = child {
3086 assert_eq!(
3087 and_composite.operator,
3088 crate::metadata::BooleanOperator::And
3089 );
3090 } else {
3091 panic!("Expected AND composite in OR children");
3092 }
3093 }
3094 } else {
3095 panic!("Expected OR composite at top level");
3096 }
3097
3098 let single_id_json = serde_json::json!({
3100 "#id": {
3101 "$eq": "single-doc-id"
3102 }
3103 });
3104
3105 let filter: Filter = serde_json::from_value(single_id_json).unwrap();
3106 assert_eq!(filter.query_ids, None);
3107 assert!(filter.where_clause.is_some());
3108
3109 let empty_json = serde_json::json!({});
3111 let filter: Filter = serde_json::from_value(empty_json).unwrap();
3112 assert_eq!(filter.query_ids, None);
3113 assert_eq!(filter.where_clause, None);
3115
3116 let advanced_json = serde_json::json!({
3118 "$and": [
3119 {
3120 "#id": {
3121 "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
3122 }
3123 },
3124 {
3125 "tags": {
3126 "$not_contains": "deprecated"
3127 }
3128 },
3129 {
3130 "$or": [
3131 {
3132 "$and": [
3133 {
3134 "confidence": {
3135 "$gte": 0.8
3136 }
3137 },
3138 {
3139 "verified": true
3140 }
3141 ]
3142 },
3143 {
3144 "$and": [
3145 {
3146 "confidence": {
3147 "$gte": 0.6
3148 }
3149 },
3150 {
3151 "confidence": {
3152 "$lt": 0.8
3153 }
3154 },
3155 {
3156 "reviews": {
3157 "$gte": 5
3158 }
3159 }
3160 ]
3161 }
3162 ]
3163 }
3164 ]
3165 });
3166
3167 let filter: Filter = serde_json::from_value(advanced_json).unwrap();
3168 assert_eq!(filter.query_ids, None);
3169 assert!(filter.where_clause.is_some());
3170
3171 if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3173 assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
3174 assert_eq!(composite.children.len(), 3);
3175 } else {
3176 panic!("Expected AND composite at top level");
3177 }
3178 }
3179
3180 #[test]
3181 fn test_limit_json_serialization() {
3182 let limit = Limit {
3183 offset: 10,
3184 limit: Some(20),
3185 };
3186
3187 let json = serde_json::to_string(&limit).unwrap();
3188 let deserialized: Limit = serde_json::from_str(&json).unwrap();
3189
3190 assert_eq!(deserialized.offset, limit.offset);
3191 assert_eq!(deserialized.limit, limit.limit);
3192 }
3193
3194 #[test]
3195 fn test_query_vector_json_serialization() {
3196 let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
3198 let json = serde_json::to_string(&dense).unwrap();
3199 let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3200 assert_eq!(deserialized, dense);
3201
3202 let sparse =
3204 QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap());
3205 let json = serde_json::to_string(&sparse).unwrap();
3206 let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3207 assert_eq!(deserialized, sparse);
3208 }
3209
3210 #[test]
3211 fn test_select_key_json_serialization() {
3212 use std::collections::HashSet;
3213
3214 let doc_key = Key::Document;
3216 assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
3217
3218 let embed_key = Key::Embedding;
3219 assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
3220
3221 let meta_key = Key::Metadata;
3222 assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
3223
3224 let score_key = Key::Score;
3225 assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
3226
3227 let custom_key = Key::MetadataField("custom_key".to_string());
3229 assert_eq!(
3230 serde_json::to_string(&custom_key).unwrap(),
3231 "\"custom_key\""
3232 );
3233
3234 let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
3236 assert!(matches!(deserialized, Key::Document));
3237
3238 let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
3239 assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
3240
3241 let mut keys = HashSet::new();
3243 keys.insert(Key::Document);
3244 keys.insert(Key::Embedding);
3245 keys.insert(Key::MetadataField("author".to_string()));
3246
3247 let select = Select { keys };
3248 let json = serde_json::to_string(&select).unwrap();
3249 let deserialized: Select = serde_json::from_str(&json).unwrap();
3250
3251 assert_eq!(deserialized.keys.len(), 3);
3252 assert!(deserialized.keys.contains(&Key::Document));
3253 assert!(deserialized.keys.contains(&Key::Embedding));
3254 assert!(deserialized
3255 .keys
3256 .contains(&Key::MetadataField("author".to_string())));
3257 }
3258
3259 #[test]
3260 fn test_merge_basic_integers() {
3261 use std::cmp::Reverse;
3262
3263 let merge = Merge { k: 5 };
3264
3265 let input = vec![
3267 vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
3268 vec![Reverse(2), Reverse(5), Reverse(8)],
3269 vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
3270 ];
3271
3272 let result = merge.merge(input);
3273
3274 assert_eq!(result.len(), 5);
3276 assert_eq!(
3277 result,
3278 vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
3279 );
3280 }
3281
3282 #[test]
3283 fn test_merge_u32_descending() {
3284 let merge = Merge { k: 6 };
3285
3286 let input = vec![
3288 vec![100u32, 75, 50, 25],
3289 vec![90, 60, 30],
3290 vec![95, 85, 70, 40, 10],
3291 ];
3292
3293 let result = merge.merge(input);
3294
3295 assert_eq!(result.len(), 6);
3297 assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
3298 }
3299
3300 #[test]
3301 fn test_merge_i32_descending() {
3302 let merge = Merge { k: 5 };
3303
3304 let input = vec![
3306 vec![50i32, 10, -10, -50],
3307 vec![30, 0, -30],
3308 vec![40, 20, -20, -40],
3309 ];
3310
3311 let result = merge.merge(input);
3312
3313 assert_eq!(result.len(), 5);
3315 assert_eq!(result, vec![50, 40, 30, 20, 10]);
3316 }
3317
3318 #[test]
3319 fn test_merge_with_duplicates() {
3320 let merge = Merge { k: 10 };
3321
3322 let input = vec![
3324 vec![100u32, 80, 80, 60, 40],
3325 vec![90, 80, 50, 30],
3326 vec![100, 70, 60, 20],
3327 ];
3328
3329 let result = merge.merge(input);
3330
3331 assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
3333 }
3334
3335 #[test]
3336 fn test_merge_empty_vectors() {
3337 let merge = Merge { k: 5 };
3338
3339 let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
3341 let result = merge.merge(input);
3342 assert_eq!(result.len(), 0);
3343
3344 let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
3346 let result = merge.merge(input);
3347 assert_eq!(result, vec![1000, 850, 750, 600, 500]);
3348
3349 let input = vec![vec![], vec![100i32, 50, 25], vec![]];
3351 let result = merge.merge(input);
3352 assert_eq!(result, vec![100, 50, 25]);
3353 }
3354
3355 #[test]
3356 fn test_merge_k_boundary_conditions() {
3357 let merge = Merge { k: 0 };
3359 let input = vec![vec![100u32, 50], vec![75, 25]];
3360 let result = merge.merge(input);
3361 assert_eq!(result.len(), 0);
3362
3363 let merge = Merge { k: 1 };
3365 let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
3366 let result = merge.merge(input);
3367 assert_eq!(result, vec![1000]);
3368
3369 let merge = Merge { k: 100 };
3371 let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
3372 let result = merge.merge(input);
3373 assert_eq!(result, vec![10000, 8000, 5000, 3000]);
3374 }
3375
3376 #[test]
3377 fn test_merge_with_strings() {
3378 let merge = Merge { k: 4 };
3379
3380 let input = vec![
3382 vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
3383 vec!["elephant".to_string(), "banana".to_string()],
3384 vec!["fish".to_string(), "cat".to_string()],
3385 ];
3386
3387 let result = merge.merge(input);
3388
3389 assert_eq!(
3391 result,
3392 vec![
3393 "zebra".to_string(),
3394 "fish".to_string(),
3395 "elephant".to_string(),
3396 "dog".to_string()
3397 ]
3398 );
3399 }
3400
3401 #[test]
3402 fn test_merge_with_custom_struct() {
3403 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
3404 struct Score {
3405 value: i32,
3406 id: String,
3407 }
3408
3409 let merge = Merge { k: 3 };
3410
3411 let input = vec![
3413 vec![
3414 Score {
3415 value: 100,
3416 id: "a".to_string(),
3417 },
3418 Score {
3419 value: 80,
3420 id: "b".to_string(),
3421 },
3422 Score {
3423 value: 60,
3424 id: "c".to_string(),
3425 },
3426 ],
3427 vec![
3428 Score {
3429 value: 90,
3430 id: "d".to_string(),
3431 },
3432 Score {
3433 value: 70,
3434 id: "e".to_string(),
3435 },
3436 ],
3437 vec![
3438 Score {
3439 value: 95,
3440 id: "f".to_string(),
3441 },
3442 Score {
3443 value: 85,
3444 id: "g".to_string(),
3445 },
3446 ],
3447 ];
3448
3449 let result = merge.merge(input);
3450
3451 assert_eq!(result.len(), 3);
3452 assert_eq!(
3453 result[0],
3454 Score {
3455 value: 100,
3456 id: "a".to_string()
3457 }
3458 );
3459 assert_eq!(
3460 result[1],
3461 Score {
3462 value: 95,
3463 id: "f".to_string()
3464 }
3465 );
3466 assert_eq!(
3467 result[2],
3468 Score {
3469 value: 90,
3470 id: "d".to_string()
3471 }
3472 );
3473 }
3474
3475 #[test]
3476 fn test_merge_preserves_order() {
3477 use std::cmp::Reverse;
3478
3479 let merge = Merge { k: 10 };
3480
3481 let input = vec![
3484 vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
3485 vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
3486 vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
3487 ];
3488
3489 let result = merge.merge(input);
3490
3491 for i in 1..result.len() {
3494 assert!(
3495 result[i - 1] >= result[i],
3496 "Output should be in descending Reverse order"
3497 );
3498 assert!(
3499 result[i - 1].0 <= result[i].0,
3500 "Inner values should be in ascending order"
3501 );
3502 }
3503
3504 assert_eq!(
3506 result,
3507 vec![
3508 Reverse(1),
3509 Reverse(2),
3510 Reverse(3),
3511 Reverse(4),
3512 Reverse(5),
3513 Reverse(6),
3514 Reverse(7),
3515 Reverse(8),
3516 Reverse(9),
3517 Reverse(10)
3518 ]
3519 );
3520 }
3521
3522 #[test]
3523 fn test_merge_single_vector() {
3524 let merge = Merge { k: 3 };
3525
3526 let input = vec![vec![1000u64, 800, 600, 400, 200]];
3528
3529 let result = merge.merge(input);
3530
3531 assert_eq!(result, vec![1000, 800, 600]);
3532 }
3533
3534 #[test]
3535 fn test_merge_all_same_values() {
3536 let merge = Merge { k: 5 };
3537
3538 let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];
3540
3541 let result = merge.merge(input);
3542
3543 assert_eq!(result, vec![42]);
3545 }
3546
3547 #[test]
3548 fn test_merge_mixed_types_sizes() {
3549 let merge = Merge { k: 4 };
3551 let input = vec![
3552 vec![1000usize, 500, 100],
3553 vec![800, 300],
3554 vec![900, 600, 200],
3555 ];
3556 let result = merge.merge(input);
3557 assert_eq!(result, vec![1000, 900, 800, 600]);
3558
3559 let merge = Merge { k: 5 };
3561 let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
3562 let result = merge.merge(input);
3563 assert_eq!(result, vec![15, 10, 5, 0, -5]);
3564 }
3565
3566 #[test]
3567 fn test_aggregate_json_serialization() {
3568 let min_k = Aggregate::MinK {
3570 keys: vec![Key::Score, Key::field("date")],
3571 k: 3,
3572 };
3573 let json = serde_json::to_value(&min_k).unwrap();
3574 assert!(json.get("$min_k").is_some());
3575 assert_eq!(json["$min_k"]["k"], 3);
3576
3577 let min_k_json = serde_json::json!({
3579 "$min_k": {
3580 "keys": ["#score", "date"],
3581 "k": 5
3582 }
3583 });
3584 let deserialized: Aggregate = serde_json::from_value(min_k_json).unwrap();
3585 match deserialized {
3586 Aggregate::MinK { keys, k } => {
3587 assert_eq!(k, 5);
3588 assert_eq!(keys.len(), 2);
3589 assert_eq!(keys[0], Key::Score);
3590 assert_eq!(keys[1], Key::field("date"));
3591 }
3592 _ => panic!("Expected MinK"),
3593 }
3594
3595 let max_k = Aggregate::MaxK {
3597 keys: vec![Key::field("timestamp")],
3598 k: 10,
3599 };
3600 let json = serde_json::to_value(&max_k).unwrap();
3601 assert!(json.get("$max_k").is_some());
3602 assert_eq!(json["$max_k"]["k"], 10);
3603
3604 let max_k_json = serde_json::json!({
3606 "$max_k": {
3607 "keys": ["timestamp"],
3608 "k": 2
3609 }
3610 });
3611 let deserialized: Aggregate = serde_json::from_value(max_k_json).unwrap();
3612 match deserialized {
3613 Aggregate::MaxK { keys, k } => {
3614 assert_eq!(k, 2);
3615 assert_eq!(keys.len(), 1);
3616 assert_eq!(keys[0], Key::field("timestamp"));
3617 }
3618 _ => panic!("Expected MaxK"),
3619 }
3620 }
3621
3622 #[test]
3623 fn test_group_by_json_serialization() {
3624 let group_by = GroupBy {
3626 keys: vec![Key::field("category"), Key::field("author")],
3627 aggregate: Some(Aggregate::MinK {
3628 keys: vec![Key::Score],
3629 k: 3,
3630 }),
3631 };
3632
3633 let json = serde_json::to_value(&group_by).unwrap();
3634 assert_eq!(json["keys"].as_array().unwrap().len(), 2);
3635 assert!(json["aggregate"]["$min_k"].is_object());
3636
3637 let deserialized: GroupBy = serde_json::from_value(json).unwrap();
3639 assert_eq!(deserialized.keys.len(), 2);
3640 assert_eq!(deserialized.keys[0], Key::field("category"));
3641 assert_eq!(deserialized.keys[1], Key::field("author"));
3642 assert!(deserialized.aggregate.is_some());
3643
3644 let empty_group_by = GroupBy::default();
3646 let json = serde_json::to_value(&empty_group_by).unwrap();
3647 let deserialized: GroupBy = serde_json::from_value(json).unwrap();
3648 assert!(deserialized.keys.is_empty());
3649 assert!(deserialized.aggregate.is_none());
3650
3651 let json = serde_json::json!({
3653 "keys": ["category"],
3654 "aggregate": {
3655 "$max_k": {
3656 "keys": ["#score", "priority"],
3657 "k": 5
3658 }
3659 }
3660 });
3661 let group_by: GroupBy = serde_json::from_value(json).unwrap();
3662 assert_eq!(group_by.keys.len(), 1);
3663 assert_eq!(group_by.keys[0], Key::field("category"));
3664 match group_by.aggregate {
3665 Some(Aggregate::MaxK { keys, k }) => {
3666 assert_eq!(k, 5);
3667 assert_eq!(keys.len(), 2);
3668 assert_eq!(keys[0], Key::Score);
3669 }
3670 _ => panic!("Expected MaxK aggregate"),
3671 }
3672 }
3673}