1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4 cmp::Ordering,
5 collections::{BinaryHeap, HashSet},
6 fmt,
7 hash::Hash,
8 ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13 chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14 DocumentExpression, DocumentOperator, Metadata, MetadataComparison, MetadataExpression,
15 MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding, SetOperator, SparseVector,
16 Where,
17};
18
19use super::error::QueryConversionError;
20
/// Alias for the unit type `()`.
///
/// NOTE(review): presumably the input type of an operator that needs no upstream data —
/// confirm against callers.
pub type InitialInput = ();
22
23#[derive(Clone, Debug)]
28pub struct Scan {
29 pub collection_and_segments: CollectionAndSegments,
30}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33 type Error = QueryConversionError;
34
35 fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36 Ok(Self {
37 collection_and_segments: CollectionAndSegments {
38 collection: value
39 .collection
40 .ok_or(QueryConversionError::field("collection"))?
41 .try_into()?,
42 metadata_segment: value
43 .metadata
44 .ok_or(QueryConversionError::field("metadata segment"))?
45 .try_into()?,
46 record_segment: value
47 .record
48 .ok_or(QueryConversionError::field("record segment"))?
49 .try_into()?,
50 vector_segment: value
51 .knn
52 .ok_or(QueryConversionError::field("vector segment"))?
53 .try_into()?,
54 },
55 })
56 }
57}
58
59#[derive(Debug, Error)]
60pub enum ScanToProtoError {
61 #[error("Could not convert collection to proto")]
62 CollectionToProto(#[from] crate::CollectionToProtoError),
63}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66 type Error = ScanToProtoError;
67
68 fn try_from(value: Scan) -> Result<Self, Self::Error> {
69 Ok(Self {
70 collection: Some(value.collection_and_segments.collection.try_into()?),
71 knn: Some(value.collection_and_segments.vector_segment.into()),
72 metadata: Some(value.collection_and_segments.metadata_segment.into()),
73 record: Some(value.collection_and_segments.record_segment.into()),
74 })
75 }
76}
77
/// Outcome of a count operation.
#[derive(Debug, Clone)]
pub struct CountResult {
    /// The count value carried by the result.
    pub count: u32,
    /// Byte counter copied verbatim to/from the protobuf message.
    /// NOTE(review): presumably the number of log bytes pulled while answering — confirm.
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Size in bytes of the `count` field (4 for a `u32`).
    ///
    /// `pulled_log_bytes` is not included in the figure.
    pub fn size_bytes(&self) -> u64 {
        let count_bytes = std::mem::size_of_val(&self.count);
        count_bytes as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91 fn from(value: chroma_proto::CountResult) -> Self {
92 Self {
93 count: value.count,
94 pulled_log_bytes: value.pulled_log_bytes,
95 }
96 }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100 fn from(value: CountResult) -> Self {
101 Self {
102 count: value.count,
103 pulled_log_bytes: value.pulled_log_bytes,
104 }
105 }
106}
107
108#[derive(Clone, Debug)]
115pub struct FetchLog {
116 pub collection_uuid: CollectionUuid,
117 pub maximum_fetch_count: Option<u32>,
118 pub start_log_offset_id: u32,
119}
120
121#[derive(Clone, Debug, Default)]
170pub struct Filter {
171 pub query_ids: Option<Vec<String>>,
172 pub where_clause: Option<Where>,
173}
174
175impl Serialize for Filter {
176 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177 where
178 S: Serializer,
179 {
180 match (&self.query_ids, &self.where_clause) {
184 (None, None) => {
185 let map = serializer.serialize_map(Some(0))?;
187 map.end()
188 }
189 (None, Some(where_clause)) => {
190 where_clause.serialize(serializer)
192 }
193 (Some(ids), None) => {
194 let id_where = Where::Metadata(MetadataExpression {
196 key: "#id".to_string(),
197 comparison: MetadataComparison::Set(
198 SetOperator::In,
199 MetadataSetValue::Str(ids.clone()),
200 ),
201 });
202 id_where.serialize(serializer)
203 }
204 (Some(ids), Some(where_clause)) => {
205 let id_where = Where::Metadata(MetadataExpression {
207 key: "#id".to_string(),
208 comparison: MetadataComparison::Set(
209 SetOperator::In,
210 MetadataSetValue::Str(ids.clone()),
211 ),
212 });
213 let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214 combined.serialize(serializer)
215 }
216 }
217 }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where
223 D: Deserializer<'de>,
224 {
225 let where_json = Value::deserialize(deserializer)?;
227 let where_clause =
228 if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229 None
230 } else {
231 Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232 };
233
234 Ok(Filter {
235 query_ids: None, where_clause,
237 })
238 }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242 type Error = QueryConversionError;
243
244 fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245 let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246 let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247 let where_clause = match (where_metadata, where_document) {
248 (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249 (Some(w), None) | (None, Some(w)) => Some(w),
250 _ => None,
251 };
252
253 Ok(Self {
254 query_ids: value.ids.map(|uids| uids.ids),
255 where_clause,
256 })
257 }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261 type Error = QueryConversionError;
262
263 fn try_from(value: Filter) -> Result<Self, Self::Error> {
264 Ok(Self {
265 ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266 r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267 where_document: None,
268 })
269 }
270}
271
/// A single nearest-neighbour query: one embedding plus a fetch count.
#[derive(Debug, Clone)]
pub struct Knn {
    /// Query embedding.
    pub embedding: Vec<f32>,
    /// Fetch count copied from the batch this query came from.
    /// NOTE(review): presumably the number of neighbours to retrieve — confirm.
    pub fetch: u32,
}

impl From<KnnBatch> for Vec<Knn> {
    /// Splits a batch into one [`Knn`] per embedding, each carrying the batch's `fetch`.
    /// The order of the embeddings is preserved.
    fn from(value: KnnBatch) -> Self {
        let KnnBatch { embeddings, fetch } = value;

        let mut queries = Vec::with_capacity(embeddings.len());
        for embedding in embeddings {
            queries.push(Knn { embedding, fetch });
        }
        queries
    }
}

/// A batch of nearest-neighbour queries that share one fetch count.
#[derive(Debug, Clone)]
pub struct KnnBatch {
    /// One query embedding per entry.
    pub embeddings: Vec<Vec<f32>>,
    /// Fetch count applied to every embedding in the batch.
    pub fetch: u32,
}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308 type Error = QueryConversionError;
309
310 fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311 Ok(Self {
312 embeddings: value
313 .embeddings
314 .into_iter()
315 .map(|vec| vec.try_into().map(|(v, _)| v))
316 .collect::<Result<_, _>>()?,
317 fetch: value.fetch,
318 })
319 }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323 type Error = QueryConversionError;
324
325 fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326 Ok(Self {
327 embeddings: value
328 .embeddings
329 .into_iter()
330 .map(|embedding| {
331 let dim = embedding.len();
332 chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333 })
334 .collect::<Result<_, _>>()?,
335 fetch: value.fetch,
336 })
337 }
338}
339
340#[derive(Clone, Debug, Default, Deserialize, Serialize)]
373pub struct Limit {
374 #[serde(default)]
375 pub offset: u32,
376 #[serde(default)]
377 pub limit: Option<u32>,
378}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381 fn from(value: chroma_proto::LimitOperator) -> Self {
382 Self {
383 offset: value.offset,
384 limit: value.limit,
385 }
386 }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390 fn from(value: Limit) -> Self {
391 Self {
392 offset: value.offset,
393 limit: value.limit,
394 }
395 }
396}
397
/// A record's offset id paired with a numeric measure.
///
/// NOTE(review): equality looks only at `offset_id`, whereas ordering compares `measure`
/// first and `offset_id` second. Two values can therefore be `==` while comparing as
/// `Less`/`Greater`, which departs from the usual `Eq`/`Ord` agreement. This looks
/// deliberate (it lets de-duplication work per record) — confirm before relying on it
/// in ordered collections.
#[derive(Clone, Copy, Debug)]
pub struct RecordMeasure {
    /// Offset id of the record.
    pub offset_id: u32,
    /// Measure attached to the record. NOTE(review): presumably a distance or score.
    pub measure: f32,
}

impl PartialEq for RecordMeasure {
    /// Two measures are equal when they refer to the same `offset_id`; `measure` is
    /// ignored.
    fn eq(&self, other: &Self) -> bool {
        self.offset_id == other.offset_id
    }
}

impl Eq for RecordMeasure {}

impl Ord for RecordMeasure {
    /// Orders by `measure` using `f32::total_cmp`, breaking ties by `offset_id`.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.measure.total_cmp(&other.measure) {
            Ordering::Equal => self.offset_id.cmp(&other.offset_id),
            decided => decided,
        }
    }
}

impl PartialOrd for RecordMeasure {
    /// Always `Some`, delegating to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
426
427#[derive(Debug, Default)]
428pub struct KnnOutput {
429 pub distances: Vec<RecordMeasure>,
430}
431
/// K-way merge operator that keeps at most `k` elements.
#[derive(Clone, Debug)]
pub struct Merge {
    /// Maximum number of elements in the merged output.
    pub k: u32,
}

impl Merge {
    /// Merges several vectors into one, largest element first, keeping at most `k`.
    ///
    /// The head of every input vector is placed in a max-heap; the largest head is
    /// repeatedly taken and replaced by the next element of the vector it came from.
    /// The output is therefore in descending order provided each input vector is itself
    /// sorted in descending order. An element equal (per `Eq`) to the last emitted one
    /// is dropped, so adjacent duplicates collapse into a single entry.
    ///
    /// Elements that compare equal are taken from the input with the higher index
    /// first, because the heap is keyed on `(element, input index)`.
    ///
    /// The output buffer is pre-sized to `min(k, total input length)` so that a large
    /// `k` never reserves more memory than the input could possibly fill.
    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = self.k as usize;
        let available: usize = input.iter().map(Vec::len).sum();

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the first element of every non-empty input.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut fusion = Vec::with_capacity(limit.min(available));
        while fusion.len() < limit {
            let Some((m, idx)) = max_heap.pop() else {
                break;
            };
            // Refill from the same input before deciding whether to keep `m`, so the
            // heap always holds the current head of every unfinished input.
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            if fusion.last().is_some_and(|tail| tail == &m) {
                continue;
            }
            fusion.push(m);
        }
        fusion
    }
}
472
/// Flags selecting which parts of a record are included in a result.
///
/// The default value has every flag set to `false`.
#[derive(Debug, Clone, Default)]
pub struct Projection {
    /// Include the document.
    pub document: bool,
    /// Include the embedding.
    pub embedding: bool,
    /// Include the metadata.
    pub metadata: bool,
}
485
486impl From<chroma_proto::ProjectionOperator> for Projection {
487 fn from(value: chroma_proto::ProjectionOperator) -> Self {
488 Self {
489 document: value.document,
490 embedding: value.embedding,
491 metadata: value.metadata,
492 }
493 }
494}
495
496impl From<Projection> for chroma_proto::ProjectionOperator {
497 fn from(value: Projection) -> Self {
498 Self {
499 document: value.document,
500 embedding: value.embedding,
501 metadata: value.metadata,
502 }
503 }
504}
505
506#[derive(Clone, Debug, PartialEq)]
507pub struct ProjectionRecord {
508 pub id: String,
509 pub document: Option<String>,
510 pub embedding: Option<Vec<f32>>,
511 pub metadata: Option<Metadata>,
512}
513
514impl ProjectionRecord {
515 pub fn size_bytes(&self) -> u64 {
516 (self.id.len()
517 + self
518 .document
519 .as_ref()
520 .map(|doc| doc.len())
521 .unwrap_or_default()
522 + self
523 .embedding
524 .as_ref()
525 .map(|emb| size_of_val(&emb[..]))
526 .unwrap_or_default()
527 + self
528 .metadata
529 .as_ref()
530 .map(logical_size_of_metadata)
531 .unwrap_or_default()) as u64
532 }
533}
534
535impl Eq for ProjectionRecord {}
536
537impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
538 type Error = QueryConversionError;
539
540 fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
541 Ok(Self {
542 id: value.id,
543 document: value.document,
544 embedding: value
545 .embedding
546 .map(|vec| vec.try_into().map(|(v, _)| v))
547 .transpose()?,
548 metadata: value.metadata.map(TryInto::try_into).transpose()?,
549 })
550 }
551}
552
553impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
554 type Error = QueryConversionError;
555
556 fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
557 Ok(Self {
558 id: value.id,
559 document: value.document,
560 embedding: value
561 .embedding
562 .map(|embedding| {
563 let embedding_dimension = embedding.len();
564 chroma_proto::Vector::try_from((
565 embedding,
566 ScalarEncoding::FLOAT32,
567 embedding_dimension,
568 ))
569 })
570 .transpose()?,
571 metadata: value.metadata.map(|metadata| metadata.into()),
572 })
573 }
574}
575
576#[derive(Clone, Debug, Eq, PartialEq)]
577pub struct ProjectionOutput {
578 pub records: Vec<ProjectionRecord>,
579}
580
581#[derive(Clone, Debug, Eq, PartialEq)]
582pub struct GetResult {
583 pub pulled_log_bytes: u64,
584 pub result: ProjectionOutput,
585}
586
587impl GetResult {
588 pub fn size_bytes(&self) -> u64 {
589 self.result
590 .records
591 .iter()
592 .map(ProjectionRecord::size_bytes)
593 .sum()
594 }
595}
596
597impl TryFrom<chroma_proto::GetResult> for GetResult {
598 type Error = QueryConversionError;
599
600 fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
601 Ok(Self {
602 pulled_log_bytes: value.pulled_log_bytes,
603 result: ProjectionOutput {
604 records: value
605 .records
606 .into_iter()
607 .map(TryInto::try_into)
608 .collect::<Result<_, _>>()?,
609 },
610 })
611 }
612}
613
614impl TryFrom<GetResult> for chroma_proto::GetResult {
615 type Error = QueryConversionError;
616
617 fn try_from(value: GetResult) -> Result<Self, Self::Error> {
618 Ok(Self {
619 pulled_log_bytes: value.pulled_log_bytes,
620 records: value
621 .result
622 .records
623 .into_iter()
624 .map(TryInto::try_into)
625 .collect::<Result<_, _>>()?,
626 })
627 }
628}
629
630#[derive(Clone, Debug)]
638pub struct KnnProjection {
639 pub projection: Projection,
640 pub distance: bool,
641}
642
643impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
644 type Error = QueryConversionError;
645
646 fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
647 Ok(Self {
648 projection: value
649 .projection
650 .ok_or(QueryConversionError::field("projection"))?
651 .into(),
652 distance: value.distance,
653 })
654 }
655}
656
657impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
658 fn from(value: KnnProjection) -> Self {
659 Self {
660 projection: Some(value.projection.into()),
661 distance: value.distance,
662 }
663 }
664}
665
666#[derive(Clone, Debug)]
667pub struct KnnProjectionRecord {
668 pub record: ProjectionRecord,
669 pub distance: Option<f32>,
670}
671
672impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
673 type Error = QueryConversionError;
674
675 fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
676 Ok(Self {
677 record: value
678 .record
679 .ok_or(QueryConversionError::field("record"))?
680 .try_into()?,
681 distance: value.distance,
682 })
683 }
684}
685
686impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
687 type Error = QueryConversionError;
688
689 fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
690 Ok(Self {
691 record: Some(value.record.try_into()?),
692 distance: value.distance,
693 })
694 }
695}
696
697#[derive(Clone, Debug, Default)]
698pub struct KnnProjectionOutput {
699 pub records: Vec<KnnProjectionRecord>,
700}
701
702impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
703 type Error = QueryConversionError;
704
705 fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
706 Ok(Self {
707 records: value
708 .records
709 .into_iter()
710 .map(TryInto::try_into)
711 .collect::<Result<_, _>>()?,
712 })
713 }
714}
715
716impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
717 type Error = QueryConversionError;
718
719 fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
720 Ok(Self {
721 records: value
722 .records
723 .into_iter()
724 .map(TryInto::try_into)
725 .collect::<Result<_, _>>()?,
726 })
727 }
728}
729
730#[derive(Clone, Debug, Default)]
731pub struct KnnBatchResult {
732 pub pulled_log_bytes: u64,
733 pub results: Vec<KnnProjectionOutput>,
734}
735
736impl KnnBatchResult {
737 pub fn size_bytes(&self) -> u64 {
738 self.results
739 .iter()
740 .flat_map(|res| {
741 res.records
742 .iter()
743 .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
744 })
745 .sum()
746 }
747}
748
749impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
750 type Error = QueryConversionError;
751
752 fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
753 Ok(Self {
754 pulled_log_bytes: value.pulled_log_bytes,
755 results: value
756 .results
757 .into_iter()
758 .map(TryInto::try_into)
759 .collect::<Result<_, _>>()?,
760 })
761 }
762}
763
764impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
765 type Error = QueryConversionError;
766
767 fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
768 Ok(Self {
769 pulled_log_bytes: value.pulled_log_bytes,
770 results: value
771 .results
772 .into_iter()
773 .map(TryInto::try_into)
774 .collect::<Result<_, _>>()?,
775 })
776 }
777}
778
779#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
843#[serde(untagged)]
844pub enum QueryVector {
845 Dense(Vec<f32>),
846 Sparse(SparseVector),
847}
848
849impl TryFrom<chroma_proto::QueryVector> for QueryVector {
850 type Error = QueryConversionError;
851
852 fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
853 let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
854 match vector {
855 chroma_proto::query_vector::Vector::Dense(dense) => {
856 Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
857 }
858 chroma_proto::query_vector::Vector::Sparse(sparse) => {
859 Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| {
860 QueryConversionError::validation("sparse vector length mismatch")
861 })?))
862 }
863 }
864 }
865}
866
867impl TryFrom<QueryVector> for chroma_proto::QueryVector {
868 type Error = QueryConversionError;
869
870 fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
871 match value {
872 QueryVector::Dense(vec) => {
873 let dim = vec.len();
874 Ok(chroma_proto::QueryVector {
875 vector: Some(chroma_proto::query_vector::Vector::Dense(
876 chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
877 )),
878 })
879 }
880 QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
881 vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
882 }),
883 }
884 }
885}
886
887impl From<Vec<f32>> for QueryVector {
888 fn from(vec: Vec<f32>) -> Self {
889 QueryVector::Dense(vec)
890 }
891}
892
893impl From<SparseVector> for QueryVector {
894 fn from(sparse: SparseVector) -> Self {
895 QueryVector::Sparse(sparse)
896 }
897}
898
899#[derive(Clone, Debug, PartialEq)]
900pub struct KnnQuery {
901 pub query: QueryVector,
902 pub key: Key,
903 pub limit: u32,
904}
905
906#[derive(Clone, Debug, Default, Deserialize, Serialize)]
937#[serde(transparent)]
938pub struct Rank {
939 pub expr: Option<RankExpr>,
940}
941
942impl Rank {
943 pub fn knn_queries(&self) -> Vec<KnnQuery> {
944 self.expr
945 .as_ref()
946 .map(RankExpr::knn_queries)
947 .unwrap_or_default()
948 }
949}
950
951impl TryFrom<chroma_proto::RankOperator> for Rank {
952 type Error = QueryConversionError;
953
954 fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
955 Ok(Rank {
956 expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
957 })
958 }
959}
960
961impl TryFrom<Rank> for chroma_proto::RankOperator {
962 type Error = QueryConversionError;
963
964 fn try_from(rank: Rank) -> Result<Self, Self::Error> {
965 Ok(chroma_proto::RankOperator {
966 expr: rank.expr.map(TryInto::try_into).transpose()?,
967 })
968 }
969}
970
971#[derive(Clone, Debug, Deserialize, Serialize)]
1134pub enum RankExpr {
1135 #[serde(rename = "$abs")]
1136 Absolute(Box<RankExpr>),
1137 #[serde(rename = "$div")]
1138 Division {
1139 left: Box<RankExpr>,
1140 right: Box<RankExpr>,
1141 },
1142 #[serde(rename = "$exp")]
1143 Exponentiation(Box<RankExpr>),
1144 #[serde(rename = "$knn")]
1145 Knn {
1146 query: QueryVector,
1147 #[serde(default = "RankExpr::default_knn_key")]
1148 key: Key,
1149 #[serde(default = "RankExpr::default_knn_limit")]
1150 limit: u32,
1151 #[serde(default)]
1152 default: Option<f32>,
1153 #[serde(default)]
1154 return_rank: bool,
1155 },
1156 #[serde(rename = "$log")]
1157 Logarithm(Box<RankExpr>),
1158 #[serde(rename = "$max")]
1159 Maximum(Vec<RankExpr>),
1160 #[serde(rename = "$min")]
1161 Minimum(Vec<RankExpr>),
1162 #[serde(rename = "$mul")]
1163 Multiplication(Vec<RankExpr>),
1164 #[serde(rename = "$sub")]
1165 Subtraction {
1166 left: Box<RankExpr>,
1167 right: Box<RankExpr>,
1168 },
1169 #[serde(rename = "$sum")]
1170 Summation(Vec<RankExpr>),
1171 #[serde(rename = "$val")]
1172 Value(f32),
1173}
1174
1175impl RankExpr {
1176 pub fn default_knn_key() -> Key {
1177 Key::Embedding
1178 }
1179
1180 pub fn default_knn_limit() -> u32 {
1181 16
1182 }
1183
1184 pub fn knn_queries(&self) -> Vec<KnnQuery> {
1185 match self {
1186 RankExpr::Absolute(expr)
1187 | RankExpr::Exponentiation(expr)
1188 | RankExpr::Logarithm(expr) => expr.knn_queries(),
1189 RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1190 .knn_queries()
1191 .into_iter()
1192 .chain(right.knn_queries())
1193 .collect(),
1194 RankExpr::Maximum(exprs)
1195 | RankExpr::Minimum(exprs)
1196 | RankExpr::Multiplication(exprs)
1197 | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1198 RankExpr::Value(_) => Vec::new(),
1199 RankExpr::Knn {
1200 query,
1201 key,
1202 limit,
1203 default: _,
1204 return_rank: _,
1205 } => vec![KnnQuery {
1206 query: query.clone(),
1207 key: key.clone(),
1208 limit: *limit,
1209 }],
1210 }
1211 }
1212
1213 pub fn exp(self) -> Self {
1233 RankExpr::Exponentiation(Box::new(self))
1234 }
1235
1236 pub fn log(self) -> Self {
1257 RankExpr::Logarithm(Box::new(self))
1258 }
1259
1260 pub fn abs(self) -> Self {
1287 RankExpr::Absolute(Box::new(self))
1288 }
1289
1290 pub fn max(self, other: impl Into<RankExpr>) -> Self {
1314 let other = other.into();
1315
1316 match self {
1317 RankExpr::Maximum(mut exprs) => match other {
1318 RankExpr::Maximum(other_exprs) => {
1319 exprs.extend(other_exprs);
1320 RankExpr::Maximum(exprs)
1321 }
1322 _ => {
1323 exprs.push(other);
1324 RankExpr::Maximum(exprs)
1325 }
1326 },
1327 _ => match other {
1328 RankExpr::Maximum(mut exprs) => {
1329 exprs.insert(0, self);
1330 RankExpr::Maximum(exprs)
1331 }
1332 _ => RankExpr::Maximum(vec![self, other]),
1333 },
1334 }
1335 }
1336
1337 pub fn min(self, other: impl Into<RankExpr>) -> Self {
1361 let other = other.into();
1362
1363 match self {
1364 RankExpr::Minimum(mut exprs) => match other {
1365 RankExpr::Minimum(other_exprs) => {
1366 exprs.extend(other_exprs);
1367 RankExpr::Minimum(exprs)
1368 }
1369 _ => {
1370 exprs.push(other);
1371 RankExpr::Minimum(exprs)
1372 }
1373 },
1374 _ => match other {
1375 RankExpr::Minimum(mut exprs) => {
1376 exprs.insert(0, self);
1377 RankExpr::Minimum(exprs)
1378 }
1379 _ => RankExpr::Minimum(vec![self, other]),
1380 },
1381 }
1382 }
1383}
1384
1385impl Add for RankExpr {
1386 type Output = RankExpr;
1387
1388 fn add(self, rhs: Self) -> Self::Output {
1389 match self {
1390 RankExpr::Summation(mut exprs) => match rhs {
1391 RankExpr::Summation(rhs_exprs) => {
1392 exprs.extend(rhs_exprs);
1393 RankExpr::Summation(exprs)
1394 }
1395 _ => {
1396 exprs.push(rhs);
1397 RankExpr::Summation(exprs)
1398 }
1399 },
1400 _ => match rhs {
1401 RankExpr::Summation(mut exprs) => {
1402 exprs.insert(0, self);
1403 RankExpr::Summation(exprs)
1404 }
1405 _ => RankExpr::Summation(vec![self, rhs]),
1406 },
1407 }
1408 }
1409}
1410
1411impl Add<f32> for RankExpr {
1412 type Output = RankExpr;
1413
1414 fn add(self, rhs: f32) -> Self::Output {
1415 self + RankExpr::Value(rhs)
1416 }
1417}
1418
1419impl Add<RankExpr> for f32 {
1420 type Output = RankExpr;
1421
1422 fn add(self, rhs: RankExpr) -> Self::Output {
1423 RankExpr::Value(self) + rhs
1424 }
1425}
1426
1427impl Sub for RankExpr {
1428 type Output = RankExpr;
1429
1430 fn sub(self, rhs: Self) -> Self::Output {
1431 RankExpr::Subtraction {
1432 left: Box::new(self),
1433 right: Box::new(rhs),
1434 }
1435 }
1436}
1437
1438impl Sub<f32> for RankExpr {
1439 type Output = RankExpr;
1440
1441 fn sub(self, rhs: f32) -> Self::Output {
1442 self - RankExpr::Value(rhs)
1443 }
1444}
1445
1446impl Sub<RankExpr> for f32 {
1447 type Output = RankExpr;
1448
1449 fn sub(self, rhs: RankExpr) -> Self::Output {
1450 RankExpr::Value(self) - rhs
1451 }
1452}
1453
1454impl Mul for RankExpr {
1455 type Output = RankExpr;
1456
1457 fn mul(self, rhs: Self) -> Self::Output {
1458 match self {
1459 RankExpr::Multiplication(mut exprs) => match rhs {
1460 RankExpr::Multiplication(rhs_exprs) => {
1461 exprs.extend(rhs_exprs);
1462 RankExpr::Multiplication(exprs)
1463 }
1464 _ => {
1465 exprs.push(rhs);
1466 RankExpr::Multiplication(exprs)
1467 }
1468 },
1469 _ => match rhs {
1470 RankExpr::Multiplication(mut exprs) => {
1471 exprs.insert(0, self);
1472 RankExpr::Multiplication(exprs)
1473 }
1474 _ => RankExpr::Multiplication(vec![self, rhs]),
1475 },
1476 }
1477 }
1478}
1479
1480impl Mul<f32> for RankExpr {
1481 type Output = RankExpr;
1482
1483 fn mul(self, rhs: f32) -> Self::Output {
1484 self * RankExpr::Value(rhs)
1485 }
1486}
1487
1488impl Mul<RankExpr> for f32 {
1489 type Output = RankExpr;
1490
1491 fn mul(self, rhs: RankExpr) -> Self::Output {
1492 RankExpr::Value(self) * rhs
1493 }
1494}
1495
1496impl Div for RankExpr {
1497 type Output = RankExpr;
1498
1499 fn div(self, rhs: Self) -> Self::Output {
1500 RankExpr::Division {
1501 left: Box::new(self),
1502 right: Box::new(rhs),
1503 }
1504 }
1505}
1506
1507impl Div<f32> for RankExpr {
1508 type Output = RankExpr;
1509
1510 fn div(self, rhs: f32) -> Self::Output {
1511 self / RankExpr::Value(rhs)
1512 }
1513}
1514
1515impl Div<RankExpr> for f32 {
1516 type Output = RankExpr;
1517
1518 fn div(self, rhs: RankExpr) -> Self::Output {
1519 RankExpr::Value(self) / rhs
1520 }
1521}
1522
1523impl Neg for RankExpr {
1524 type Output = RankExpr;
1525
1526 fn neg(self) -> Self::Output {
1527 RankExpr::Value(-1.0) * self
1528 }
1529}
1530
1531impl From<f32> for RankExpr {
1532 fn from(v: f32) -> Self {
1533 RankExpr::Value(v)
1534 }
1535}
1536
1537impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1538 type Error = QueryConversionError;
1539
1540 fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1541 match proto_expr.rank {
1542 Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1543 Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1544 }
1545 Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1546 let left = div.left.ok_or(QueryConversionError::field("left"))?;
1547 let right = div.right.ok_or(QueryConversionError::field("right"))?;
1548 Ok(RankExpr::Division {
1549 left: Box::new(RankExpr::try_from(*left)?),
1550 right: Box::new(RankExpr::try_from(*right)?),
1551 })
1552 }
1553 Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1554 RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1555 ),
1556 Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1557 let query = knn
1558 .query
1559 .ok_or(QueryConversionError::field("query"))?
1560 .try_into()?;
1561 Ok(RankExpr::Knn {
1562 query,
1563 key: Key::from(knn.key),
1564 limit: knn.limit,
1565 default: knn.default,
1566 return_rank: knn.return_rank,
1567 })
1568 }
1569 Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1570 Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1571 }
1572 Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1573 let exprs = max
1574 .exprs
1575 .into_iter()
1576 .map(RankExpr::try_from)
1577 .collect::<Result<Vec<_>, _>>()?;
1578 Ok(RankExpr::Maximum(exprs))
1579 }
1580 Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1581 let exprs = min
1582 .exprs
1583 .into_iter()
1584 .map(RankExpr::try_from)
1585 .collect::<Result<Vec<_>, _>>()?;
1586 Ok(RankExpr::Minimum(exprs))
1587 }
1588 Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1589 let exprs = mul
1590 .exprs
1591 .into_iter()
1592 .map(RankExpr::try_from)
1593 .collect::<Result<Vec<_>, _>>()?;
1594 Ok(RankExpr::Multiplication(exprs))
1595 }
1596 Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1597 let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1598 let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1599 Ok(RankExpr::Subtraction {
1600 left: Box::new(RankExpr::try_from(*left)?),
1601 right: Box::new(RankExpr::try_from(*right)?),
1602 })
1603 }
1604 Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1605 let exprs = sum
1606 .exprs
1607 .into_iter()
1608 .map(RankExpr::try_from)
1609 .collect::<Result<Vec<_>, _>>()?;
1610 Ok(RankExpr::Summation(exprs))
1611 }
1612 Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1613 None => Err(QueryConversionError::field("rank")),
1614 }
1615 }
1616}
1617
1618impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1619 type Error = QueryConversionError;
1620
1621 fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1622 let proto_rank = match rank_expr {
1623 RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1624 chroma_proto::RankExpr::try_from(*expr)?,
1625 )),
1626 RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1627 Box::new(chroma_proto::rank_expr::RankPair {
1628 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1629 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1630 }),
1631 ),
1632 RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1633 Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1634 ),
1635 RankExpr::Knn {
1636 query,
1637 key,
1638 limit,
1639 default,
1640 return_rank,
1641 } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1642 query: Some(query.try_into()?),
1643 key: key.to_string(),
1644 limit,
1645 default,
1646 return_rank,
1647 }),
1648 RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1649 chroma_proto::RankExpr::try_from(*expr)?,
1650 )),
1651 RankExpr::Maximum(exprs) => {
1652 let proto_exprs = exprs
1653 .into_iter()
1654 .map(chroma_proto::RankExpr::try_from)
1655 .collect::<Result<Vec<_>, _>>()?;
1656 chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1657 exprs: proto_exprs,
1658 })
1659 }
1660 RankExpr::Minimum(exprs) => {
1661 let proto_exprs = exprs
1662 .into_iter()
1663 .map(chroma_proto::RankExpr::try_from)
1664 .collect::<Result<Vec<_>, _>>()?;
1665 chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1666 exprs: proto_exprs,
1667 })
1668 }
1669 RankExpr::Multiplication(exprs) => {
1670 let proto_exprs = exprs
1671 .into_iter()
1672 .map(chroma_proto::RankExpr::try_from)
1673 .collect::<Result<Vec<_>, _>>()?;
1674 chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1675 exprs: proto_exprs,
1676 })
1677 }
1678 RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1679 Box::new(chroma_proto::rank_expr::RankPair {
1680 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1681 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1682 }),
1683 ),
1684 RankExpr::Summation(exprs) => {
1685 let proto_exprs = exprs
1686 .into_iter()
1687 .map(chroma_proto::RankExpr::try_from)
1688 .collect::<Result<Vec<_>, _>>()?;
1689 chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1690 exprs: proto_exprs,
1691 })
1692 }
1693 RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1694 };
1695
1696 Ok(chroma_proto::RankExpr {
1697 rank: Some(proto_rank),
1698 })
1699 }
1700}
1701
/// Identifies a part of a record: one of four reserved parts or a named metadata field.
///
/// The reserved parts are written as `#document`, `#embedding`, `#metadata` and
/// `#score`; any other text denotes a metadata field.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Key {
    /// Written as `#document`.
    Document,
    /// Written as `#embedding`.
    Embedding,
    /// Written as `#metadata`.
    Metadata,
    /// Written as `#score`.
    Score,
    /// A metadata field, written as the field name itself.
    MetadataField(String),
}
1776
1777impl Serialize for Key {
1778 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1779 where
1780 S: serde::Serializer,
1781 {
1782 match self {
1783 Key::Document => serializer.serialize_str("#document"),
1784 Key::Embedding => serializer.serialize_str("#embedding"),
1785 Key::Metadata => serializer.serialize_str("#metadata"),
1786 Key::Score => serializer.serialize_str("#score"),
1787 Key::MetadataField(field) => serializer.serialize_str(field),
1788 }
1789 }
1790}
1791
1792impl<'de> Deserialize<'de> for Key {
1793 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1794 where
1795 D: Deserializer<'de>,
1796 {
1797 let s = String::deserialize(deserializer)?;
1798 Ok(Key::from(s))
1799 }
1800}
1801
1802impl fmt::Display for Key {
1803 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1804 match self {
1805 Key::Document => write!(f, "#document"),
1806 Key::Embedding => write!(f, "#embedding"),
1807 Key::Metadata => write!(f, "#metadata"),
1808 Key::Score => write!(f, "#score"),
1809 Key::MetadataField(field) => write!(f, "{}", field),
1810 }
1811 }
1812}
1813
1814impl From<&str> for Key {
1815 fn from(s: &str) -> Self {
1816 match s {
1817 "#document" => Key::Document,
1818 "#embedding" => Key::Embedding,
1819 "#metadata" => Key::Metadata,
1820 "#score" => Key::Score,
1821 field => Key::MetadataField(field.to_string()),
1823 }
1824 }
1825}
1826
1827impl From<String> for Key {
1828 fn from(s: String) -> Self {
1829 Key::from(s.as_str())
1830 }
1831}
1832
1833impl Key {
1834 pub fn field(name: impl Into<String>) -> Self {
1846 Key::MetadataField(name.into())
1847 }
1848
1849 pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1866 Where::Metadata(MetadataExpression {
1867 key: self.to_string(),
1868 comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1869 })
1870 }
1871
1872 pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1883 Where::Metadata(MetadataExpression {
1884 key: self.to_string(),
1885 comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1886 })
1887 }
1888
1889 pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1900 Where::Metadata(MetadataExpression {
1901 key: self.to_string(),
1902 comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1903 })
1904 }
1905
1906 pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1917 Where::Metadata(MetadataExpression {
1918 key: self.to_string(),
1919 comparison: MetadataComparison::Primitive(
1920 PrimitiveOperator::GreaterThanOrEqual,
1921 value.into(),
1922 ),
1923 })
1924 }
1925
1926 pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1937 Where::Metadata(MetadataExpression {
1938 key: self.to_string(),
1939 comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1940 })
1941 }
1942
1943 pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1954 Where::Metadata(MetadataExpression {
1955 key: self.to_string(),
1956 comparison: MetadataComparison::Primitive(
1957 PrimitiveOperator::LessThanOrEqual,
1958 value.into(),
1959 ),
1960 })
1961 }
1962
1963 pub fn is_in<I, T>(self, values: I) -> Where
1983 where
1984 I: IntoIterator<Item = T>,
1985 Vec<T>: Into<MetadataSetValue>,
1986 {
1987 let vec: Vec<T> = values.into_iter().collect();
1988 Where::Metadata(MetadataExpression {
1989 key: self.to_string(),
1990 comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1991 })
1992 }
1993
1994 pub fn not_in<I, T>(self, values: I) -> Where
2010 where
2011 I: IntoIterator<Item = T>,
2012 Vec<T>: Into<MetadataSetValue>,
2013 {
2014 let vec: Vec<T> = values.into_iter().collect();
2015 Where::Metadata(MetadataExpression {
2016 key: self.to_string(),
2017 comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2018 })
2019 }
2020
2021 pub fn contains<S: Into<String>>(self, text: S) -> Where {
2035 Where::Document(DocumentExpression {
2036 operator: DocumentOperator::Contains,
2037 pattern: text.into(),
2038 })
2039 }
2040
2041 pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2054 Where::Document(DocumentExpression {
2055 operator: DocumentOperator::NotContains,
2056 pattern: text.into(),
2057 })
2058 }
2059
2060 pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2077 Where::Document(DocumentExpression {
2078 operator: DocumentOperator::Regex,
2079 pattern: pattern.into(),
2080 })
2081 }
2082
2083 pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2099 Where::Document(DocumentExpression {
2100 operator: DocumentOperator::NotRegex,
2101 pattern: pattern.into(),
2102 })
2103 }
2104}
2105
/// Selection operator: an unordered, de-duplicated set of [`Key`]s.
///
/// NOTE(review): presumably these are the fields to return for each record —
/// confirm against the executor that consumes this operator.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Select {
    /// Selected keys; defaults to the empty set when absent from the input.
    #[serde(default)]
    pub keys: HashSet<Key>,
}
2164
2165impl TryFrom<chroma_proto::SelectOperator> for Select {
2166 type Error = QueryConversionError;
2167
2168 fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2169 let keys = value
2170 .keys
2171 .into_iter()
2172 .map(|key| {
2173 serde_json::from_value(serde_json::Value::String(key))
2175 .map_err(|_| QueryConversionError::field("keys"))
2176 })
2177 .collect::<Result<HashSet<_>, _>>()?;
2178
2179 Ok(Self { keys })
2180 }
2181}
2182
2183impl TryFrom<Select> for chroma_proto::SelectOperator {
2184 type Error = QueryConversionError;
2185
2186 fn try_from(value: Select) -> Result<Self, Self::Error> {
2187 let keys = value
2188 .keys
2189 .into_iter()
2190 .map(|key| {
2191 serde_json::to_value(&key)
2193 .ok()
2194 .and_then(|v| v.as_str().map(String::from))
2195 .ok_or(QueryConversionError::field("keys"))
2196 })
2197 .collect::<Result<Vec<_>, _>>()?;
2198
2199 Ok(Self { keys })
2200 }
2201}
2202
/// Aggregation applied by a [`GroupBy`], serialized as an externally tagged
/// object keyed by `$min_k` or `$max_k`.
///
/// NOTE(review): presumably `keys` are the ordering keys and `k` is the number
/// of records kept per group — confirm against the group-by executor.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum Aggregate {
    /// Serialized under the `$min_k` tag.
    #[serde(rename = "$min_k")]
    MinK {
        /// Keys the aggregation operates on.
        keys: Vec<Key>,
        /// The `k` parameter of the aggregation.
        k: u32,
    },
    /// Serialized under the `$max_k` tag.
    #[serde(rename = "$max_k")]
    MaxK {
        /// Keys the aggregation operates on.
        keys: Vec<Key>,
        /// The `k` parameter of the aggregation.
        k: u32,
    },
}
2260
/// Group-by operator: a list of grouping keys plus an optional [`Aggregate`].
///
/// Both fields are optional in the serialized form; `Default` yields no keys
/// and no aggregate.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct GroupBy {
    /// Grouping keys; defaults to an empty list when absent.
    #[serde(default)]
    pub keys: Vec<Key>,
    /// Optional aggregation; defaults to `None` when absent.
    #[serde(default)]
    pub aggregate: Option<Aggregate>,
}
2311
2312impl TryFrom<chroma_proto::Aggregate> for Aggregate {
2313 type Error = QueryConversionError;
2314
2315 fn try_from(value: chroma_proto::Aggregate) -> Result<Self, Self::Error> {
2316 match value
2317 .aggregate
2318 .ok_or(QueryConversionError::field("aggregate"))?
2319 {
2320 chroma_proto::aggregate::Aggregate::MinK(min_k) => {
2321 let keys = min_k.keys.into_iter().map(Key::from).collect();
2322 Ok(Aggregate::MinK { keys, k: min_k.k })
2323 }
2324 chroma_proto::aggregate::Aggregate::MaxK(max_k) => {
2325 let keys = max_k.keys.into_iter().map(Key::from).collect();
2326 Ok(Aggregate::MaxK { keys, k: max_k.k })
2327 }
2328 }
2329 }
2330}
2331
2332impl From<Aggregate> for chroma_proto::Aggregate {
2333 fn from(value: Aggregate) -> Self {
2334 let aggregate = match value {
2335 Aggregate::MinK { keys, k } => {
2336 chroma_proto::aggregate::Aggregate::MinK(chroma_proto::aggregate::MinK {
2337 keys: keys.into_iter().map(|k| k.to_string()).collect(),
2338 k,
2339 })
2340 }
2341 Aggregate::MaxK { keys, k } => {
2342 chroma_proto::aggregate::Aggregate::MaxK(chroma_proto::aggregate::MaxK {
2343 keys: keys.into_iter().map(|k| k.to_string()).collect(),
2344 k,
2345 })
2346 }
2347 };
2348
2349 chroma_proto::Aggregate {
2350 aggregate: Some(aggregate),
2351 }
2352 }
2353}
2354
2355impl TryFrom<chroma_proto::GroupByOperator> for GroupBy {
2356 type Error = QueryConversionError;
2357
2358 fn try_from(value: chroma_proto::GroupByOperator) -> Result<Self, Self::Error> {
2359 let keys = value.keys.into_iter().map(Key::from).collect();
2360 let aggregate = value.aggregate.map(TryInto::try_into).transpose()?;
2361
2362 Ok(Self { keys, aggregate })
2363 }
2364}
2365
2366impl TryFrom<GroupBy> for chroma_proto::GroupByOperator {
2367 type Error = QueryConversionError;
2368
2369 fn try_from(value: GroupBy) -> Result<Self, Self::Error> {
2370 let keys = value.keys.into_iter().map(|k| k.to_string()).collect();
2371 let aggregate = value.aggregate.map(Into::into);
2372
2373 Ok(Self { keys, aggregate })
2374 }
2375}
2376
/// A single record returned by a search.
///
/// Only `id` is always present; every other field is optional.
/// NOTE(review): presumably the optional fields are populated according to the
/// selected keys — confirm against the code that builds these records.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SearchRecord {
    /// Record identifier.
    pub id: String,
    /// Document text, if present.
    pub document: Option<String>,
    /// Dense embedding as `f32` components, if present.
    pub embedding: Option<Vec<f32>>,
    /// Record metadata, if present.
    pub metadata: Option<Metadata>,
    /// Score of the record, if present.
    pub score: Option<f32>,
}
2422
2423impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2424 type Error = QueryConversionError;
2425
2426 fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2427 Ok(Self {
2428 id: value.id,
2429 document: value.document,
2430 embedding: value
2431 .embedding
2432 .map(|vec| vec.try_into().map(|(v, _)| v))
2433 .transpose()?,
2434 metadata: value.metadata.map(TryInto::try_into).transpose()?,
2435 score: value.score,
2436 })
2437 }
2438}
2439
2440impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2441 type Error = QueryConversionError;
2442
2443 fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2444 Ok(Self {
2445 id: value.id,
2446 document: value.document,
2447 embedding: value
2448 .embedding
2449 .map(|embedding| {
2450 let embedding_dimension = embedding.len();
2451 chroma_proto::Vector::try_from((
2452 embedding,
2453 ScalarEncoding::FLOAT32,
2454 embedding_dimension,
2455 ))
2456 })
2457 .transpose()?,
2458 metadata: value.metadata.map(Into::into),
2459 score: value.score,
2460 })
2461 }
2462}
2463
/// The records produced for one search payload.
///
/// NOTE(review): presumably the records are in result order — confirm against
/// the producer.
#[derive(Clone, Debug, Default)]
pub struct SearchPayloadResult {
    /// Records returned for this payload.
    pub records: Vec<SearchRecord>,
}
2489
2490impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2491 type Error = QueryConversionError;
2492
2493 fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2494 Ok(Self {
2495 records: value
2496 .records
2497 .into_iter()
2498 .map(TryInto::try_into)
2499 .collect::<Result<_, _>>()?,
2500 })
2501 }
2502}
2503
2504impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2505 type Error = QueryConversionError;
2506
2507 fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2508 Ok(Self {
2509 records: value
2510 .records
2511 .into_iter()
2512 .map(TryInto::try_into)
2513 .collect::<Result<Vec<_>, _>>()?,
2514 })
2515 }
2516}
2517
/// Result of a search request: one [`SearchPayloadResult`] per entry in
/// `results`, plus a byte counter.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Per-payload results.
    pub results: Vec<SearchPayloadResult>,
    /// NOTE(review): presumably the number of log bytes pulled while serving
    /// the request — confirm against the producer.
    pub pulled_log_bytes: u64,
}
2565
2566impl SearchResult {
2567 pub fn size_bytes(&self) -> u64 {
2568 self.results
2569 .iter()
2570 .flat_map(|result| {
2571 result.records.iter().map(|record| {
2572 (record.id.len()
2573 + record
2574 .document
2575 .as_ref()
2576 .map(|doc| doc.len())
2577 .unwrap_or_default()
2578 + record
2579 .embedding
2580 .as_ref()
2581 .map(|emb| size_of_val(&emb[..]))
2582 .unwrap_or_default()
2583 + record
2584 .metadata
2585 .as_ref()
2586 .map(logical_size_of_metadata)
2587 .unwrap_or_default()
2588 + record.score.as_ref().map(size_of_val).unwrap_or_default())
2589 as u64
2590 })
2591 })
2592 .sum()
2593 }
2594}
2595
2596impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2597 type Error = QueryConversionError;
2598
2599 fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2600 Ok(Self {
2601 results: value
2602 .results
2603 .into_iter()
2604 .map(TryInto::try_into)
2605 .collect::<Result<_, _>>()?,
2606 pulled_log_bytes: value.pulled_log_bytes,
2607 })
2608 }
2609}
2610
2611impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2612 type Error = QueryConversionError;
2613
2614 fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2615 Ok(Self {
2616 results: value
2617 .results
2618 .into_iter()
2619 .map(TryInto::try_into)
2620 .collect::<Result<Vec<_>, _>>()?,
2621 pulled_log_bytes: value.pulled_log_bytes,
2622 })
2623 }
2624}
2625
2626pub fn rrf(
2771 ranks: Vec<RankExpr>,
2772 k: Option<u32>,
2773 weights: Option<Vec<f32>>,
2774 normalize: bool,
2775) -> Result<RankExpr, QueryConversionError> {
2776 let k = k.unwrap_or(60);
2777
2778 if ranks.is_empty() {
2779 return Err(QueryConversionError::validation(
2780 "RRF requires at least one rank expression",
2781 ));
2782 }
2783
2784 let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2785
2786 if weights.len() != ranks.len() {
2787 return Err(QueryConversionError::validation(format!(
2788 "RRF weights length ({}) must match ranks length ({})",
2789 weights.len(),
2790 ranks.len()
2791 )));
2792 }
2793
2794 let weights = if normalize {
2795 let sum: f32 = weights.iter().sum();
2796 if sum == 0.0 {
2797 return Err(QueryConversionError::validation(
2798 "RRF weights sum to zero, cannot normalize",
2799 ));
2800 }
2801 weights.into_iter().map(|w| w / sum).collect()
2802 } else {
2803 weights
2804 };
2805
2806 let terms: Vec<RankExpr> = weights
2807 .into_iter()
2808 .zip(ranks)
2809 .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2810 .collect();
2811
2812 let sum = terms
2815 .into_iter()
2816 .reduce(|a, b| a + b)
2817 .unwrap_or(RankExpr::Value(0.0));
2818 Ok(-sum)
2819}
2820
2821#[cfg(test)]
2822mod tests {
2823 use super::*;
2824
2825 #[test]
2826 fn test_key_from_string() {
2827 assert_eq!(Key::from("#document"), Key::Document);
2829 assert_eq!(Key::from("#embedding"), Key::Embedding);
2830 assert_eq!(Key::from("#metadata"), Key::Metadata);
2831 assert_eq!(Key::from("#score"), Key::Score);
2832
2833 assert_eq!(
2835 Key::from("custom_field"),
2836 Key::MetadataField("custom_field".to_string())
2837 );
2838 assert_eq!(
2839 Key::from("author"),
2840 Key::MetadataField("author".to_string())
2841 );
2842
2843 assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
2845 assert_eq!(
2846 Key::from("year".to_string()),
2847 Key::MetadataField("year".to_string())
2848 );
2849 }
2850
2851 #[test]
2852 fn test_query_vector_dense_proto_conversion() {
2853 let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2854 let query_vector = QueryVector::Dense(dense_vec.clone());
2855
2856 let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2858
2859 let converted: QueryVector = proto.try_into().unwrap();
2861
2862 assert_eq!(converted, query_vector);
2863 if let QueryVector::Dense(v) = converted {
2864 assert_eq!(v, dense_vec);
2865 } else {
2866 panic!("Expected dense vector");
2867 }
2868 }
2869
2870 #[test]
2871 fn test_query_vector_sparse_proto_conversion() {
2872 let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2873 let query_vector = QueryVector::Sparse(sparse.clone());
2874
2875 let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2877
2878 let converted: QueryVector = proto.try_into().unwrap();
2880
2881 assert_eq!(converted, query_vector);
2882 if let QueryVector::Sparse(s) = converted {
2883 assert_eq!(s, sparse);
2884 } else {
2885 panic!("Expected sparse vector");
2886 }
2887 }
2888
2889 #[test]
2890 fn test_filter_json_deserialization() {
2891 let simple_where = r#"{"author": "John Doe"}"#;
2895 let filter: Filter = serde_json::from_str(simple_where).unwrap();
2896 assert_eq!(filter.query_ids, None);
2897 assert!(filter.where_clause.is_some());
2898
2899 let id_filter_json = serde_json::json!({
2901 "#id": {
2902 "$in": ["doc1", "doc2", "doc3"]
2903 }
2904 });
2905 let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
2906 assert_eq!(filter.query_ids, None);
2907 assert!(filter.where_clause.is_some());
2908
2909 let complex_json = serde_json::json!({
2911 "$and": [
2912 {
2913 "#id": {
2914 "$in": ["doc1", "doc2", "doc3"]
2915 }
2916 },
2917 {
2918 "$or": [
2919 {
2920 "author": {
2921 "$eq": "John Doe"
2922 }
2923 },
2924 {
2925 "author": {
2926 "$eq": "Jane Smith"
2927 }
2928 }
2929 ]
2930 },
2931 {
2932 "year": {
2933 "$gte": 2020
2934 }
2935 },
2936 {
2937 "tags": {
2938 "$contains": "machine-learning"
2939 }
2940 }
2941 ]
2942 });
2943
2944 let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
2945 assert_eq!(filter.query_ids, None);
2946 assert!(filter.where_clause.is_some());
2947
2948 if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2950 assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2951 assert_eq!(composite.children.len(), 4);
2952
2953 if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
2955 assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
2956 assert_eq!(or_composite.children.len(), 2);
2957 } else {
2958 panic!("Expected OR composite in second child");
2959 }
2960 } else {
2961 panic!("Expected AND composite where clause");
2962 }
2963
2964 let mixed_operators_json = serde_json::json!({
2966 "$and": [
2967 {
2968 "status": {
2969 "$ne": "deleted"
2970 }
2971 },
2972 {
2973 "score": {
2974 "$gt": 0.5
2975 }
2976 },
2977 {
2978 "score": {
2979 "$lt": 0.9
2980 }
2981 },
2982 {
2983 "priority": {
2984 "$lte": 10
2985 }
2986 }
2987 ]
2988 });
2989
2990 let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
2991 assert_eq!(filter.query_ids, None);
2992 assert!(filter.where_clause.is_some());
2993
2994 let deeply_nested_json = serde_json::json!({
2996 "$or": [
2997 {
2998 "$and": [
2999 {
3000 "#id": {
3001 "$in": ["id1", "id2"]
3002 }
3003 },
3004 {
3005 "$or": [
3006 {
3007 "category": "tech"
3008 },
3009 {
3010 "category": "science"
3011 }
3012 ]
3013 }
3014 ]
3015 },
3016 {
3017 "$and": [
3018 {
3019 "author": "Admin"
3020 },
3021 {
3022 "published": true
3023 }
3024 ]
3025 }
3026 ]
3027 });
3028
3029 let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
3030 assert_eq!(filter.query_ids, None);
3031 assert!(filter.where_clause.is_some());
3032
3033 if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3035 assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
3036 assert_eq!(composite.children.len(), 2);
3037
3038 for child in &composite.children {
3040 if let crate::metadata::Where::Composite(and_composite) = child {
3041 assert_eq!(
3042 and_composite.operator,
3043 crate::metadata::BooleanOperator::And
3044 );
3045 } else {
3046 panic!("Expected AND composite in OR children");
3047 }
3048 }
3049 } else {
3050 panic!("Expected OR composite at top level");
3051 }
3052
3053 let single_id_json = serde_json::json!({
3055 "#id": {
3056 "$eq": "single-doc-id"
3057 }
3058 });
3059
3060 let filter: Filter = serde_json::from_value(single_id_json).unwrap();
3061 assert_eq!(filter.query_ids, None);
3062 assert!(filter.where_clause.is_some());
3063
3064 let empty_json = serde_json::json!({});
3066 let filter: Filter = serde_json::from_value(empty_json).unwrap();
3067 assert_eq!(filter.query_ids, None);
3068 assert_eq!(filter.where_clause, None);
3070
3071 let advanced_json = serde_json::json!({
3073 "$and": [
3074 {
3075 "#id": {
3076 "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
3077 }
3078 },
3079 {
3080 "tags": {
3081 "$not_contains": "deprecated"
3082 }
3083 },
3084 {
3085 "$or": [
3086 {
3087 "$and": [
3088 {
3089 "confidence": {
3090 "$gte": 0.8
3091 }
3092 },
3093 {
3094 "verified": true
3095 }
3096 ]
3097 },
3098 {
3099 "$and": [
3100 {
3101 "confidence": {
3102 "$gte": 0.6
3103 }
3104 },
3105 {
3106 "confidence": {
3107 "$lt": 0.8
3108 }
3109 },
3110 {
3111 "reviews": {
3112 "$gte": 5
3113 }
3114 }
3115 ]
3116 }
3117 ]
3118 }
3119 ]
3120 });
3121
3122 let filter: Filter = serde_json::from_value(advanced_json).unwrap();
3123 assert_eq!(filter.query_ids, None);
3124 assert!(filter.where_clause.is_some());
3125
3126 if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3128 assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
3129 assert_eq!(composite.children.len(), 3);
3130 } else {
3131 panic!("Expected AND composite at top level");
3132 }
3133 }
3134
3135 #[test]
3136 fn test_limit_json_serialization() {
3137 let limit = Limit {
3138 offset: 10,
3139 limit: Some(20),
3140 };
3141
3142 let json = serde_json::to_string(&limit).unwrap();
3143 let deserialized: Limit = serde_json::from_str(&json).unwrap();
3144
3145 assert_eq!(deserialized.offset, limit.offset);
3146 assert_eq!(deserialized.limit, limit.limit);
3147 }
3148
3149 #[test]
3150 fn test_query_vector_json_serialization() {
3151 let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
3153 let json = serde_json::to_string(&dense).unwrap();
3154 let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3155 assert_eq!(deserialized, dense);
3156
3157 let sparse =
3159 QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap());
3160 let json = serde_json::to_string(&sparse).unwrap();
3161 let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3162 assert_eq!(deserialized, sparse);
3163 }
3164
3165 #[test]
3166 fn test_select_key_json_serialization() {
3167 use std::collections::HashSet;
3168
3169 let doc_key = Key::Document;
3171 assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
3172
3173 let embed_key = Key::Embedding;
3174 assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
3175
3176 let meta_key = Key::Metadata;
3177 assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
3178
3179 let score_key = Key::Score;
3180 assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
3181
3182 let custom_key = Key::MetadataField("custom_key".to_string());
3184 assert_eq!(
3185 serde_json::to_string(&custom_key).unwrap(),
3186 "\"custom_key\""
3187 );
3188
3189 let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
3191 assert!(matches!(deserialized, Key::Document));
3192
3193 let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
3194 assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
3195
3196 let mut keys = HashSet::new();
3198 keys.insert(Key::Document);
3199 keys.insert(Key::Embedding);
3200 keys.insert(Key::MetadataField("author".to_string()));
3201
3202 let select = Select { keys };
3203 let json = serde_json::to_string(&select).unwrap();
3204 let deserialized: Select = serde_json::from_str(&json).unwrap();
3205
3206 assert_eq!(deserialized.keys.len(), 3);
3207 assert!(deserialized.keys.contains(&Key::Document));
3208 assert!(deserialized.keys.contains(&Key::Embedding));
3209 assert!(deserialized
3210 .keys
3211 .contains(&Key::MetadataField("author".to_string())));
3212 }
3213
3214 #[test]
3215 fn test_merge_basic_integers() {
3216 use std::cmp::Reverse;
3217
3218 let merge = Merge { k: 5 };
3219
3220 let input = vec![
3222 vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
3223 vec![Reverse(2), Reverse(5), Reverse(8)],
3224 vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
3225 ];
3226
3227 let result = merge.merge(input);
3228
3229 assert_eq!(result.len(), 5);
3231 assert_eq!(
3232 result,
3233 vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
3234 );
3235 }
3236
3237 #[test]
3238 fn test_merge_u32_descending() {
3239 let merge = Merge { k: 6 };
3240
3241 let input = vec![
3243 vec![100u32, 75, 50, 25],
3244 vec![90, 60, 30],
3245 vec![95, 85, 70, 40, 10],
3246 ];
3247
3248 let result = merge.merge(input);
3249
3250 assert_eq!(result.len(), 6);
3252 assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
3253 }
3254
3255 #[test]
3256 fn test_merge_i32_descending() {
3257 let merge = Merge { k: 5 };
3258
3259 let input = vec![
3261 vec![50i32, 10, -10, -50],
3262 vec![30, 0, -30],
3263 vec![40, 20, -20, -40],
3264 ];
3265
3266 let result = merge.merge(input);
3267
3268 assert_eq!(result.len(), 5);
3270 assert_eq!(result, vec![50, 40, 30, 20, 10]);
3271 }
3272
3273 #[test]
3274 fn test_merge_with_duplicates() {
3275 let merge = Merge { k: 10 };
3276
3277 let input = vec![
3279 vec![100u32, 80, 80, 60, 40],
3280 vec![90, 80, 50, 30],
3281 vec![100, 70, 60, 20],
3282 ];
3283
3284 let result = merge.merge(input);
3285
3286 assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
3288 }
3289
3290 #[test]
3291 fn test_merge_empty_vectors() {
3292 let merge = Merge { k: 5 };
3293
3294 let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
3296 let result = merge.merge(input);
3297 assert_eq!(result.len(), 0);
3298
3299 let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
3301 let result = merge.merge(input);
3302 assert_eq!(result, vec![1000, 850, 750, 600, 500]);
3303
3304 let input = vec![vec![], vec![100i32, 50, 25], vec![]];
3306 let result = merge.merge(input);
3307 assert_eq!(result, vec![100, 50, 25]);
3308 }
3309
3310 #[test]
3311 fn test_merge_k_boundary_conditions() {
3312 let merge = Merge { k: 0 };
3314 let input = vec![vec![100u32, 50], vec![75, 25]];
3315 let result = merge.merge(input);
3316 assert_eq!(result.len(), 0);
3317
3318 let merge = Merge { k: 1 };
3320 let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
3321 let result = merge.merge(input);
3322 assert_eq!(result, vec![1000]);
3323
3324 let merge = Merge { k: 100 };
3326 let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
3327 let result = merge.merge(input);
3328 assert_eq!(result, vec![10000, 8000, 5000, 3000]);
3329 }
3330
3331 #[test]
3332 fn test_merge_with_strings() {
3333 let merge = Merge { k: 4 };
3334
3335 let input = vec![
3337 vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
3338 vec!["elephant".to_string(), "banana".to_string()],
3339 vec!["fish".to_string(), "cat".to_string()],
3340 ];
3341
3342 let result = merge.merge(input);
3343
3344 assert_eq!(
3346 result,
3347 vec![
3348 "zebra".to_string(),
3349 "fish".to_string(),
3350 "elephant".to_string(),
3351 "dog".to_string()
3352 ]
3353 );
3354 }
3355
3356 #[test]
3357 fn test_merge_with_custom_struct() {
3358 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
3359 struct Score {
3360 value: i32,
3361 id: String,
3362 }
3363
3364 let merge = Merge { k: 3 };
3365
3366 let input = vec![
3368 vec![
3369 Score {
3370 value: 100,
3371 id: "a".to_string(),
3372 },
3373 Score {
3374 value: 80,
3375 id: "b".to_string(),
3376 },
3377 Score {
3378 value: 60,
3379 id: "c".to_string(),
3380 },
3381 ],
3382 vec![
3383 Score {
3384 value: 90,
3385 id: "d".to_string(),
3386 },
3387 Score {
3388 value: 70,
3389 id: "e".to_string(),
3390 },
3391 ],
3392 vec![
3393 Score {
3394 value: 95,
3395 id: "f".to_string(),
3396 },
3397 Score {
3398 value: 85,
3399 id: "g".to_string(),
3400 },
3401 ],
3402 ];
3403
3404 let result = merge.merge(input);
3405
3406 assert_eq!(result.len(), 3);
3407 assert_eq!(
3408 result[0],
3409 Score {
3410 value: 100,
3411 id: "a".to_string()
3412 }
3413 );
3414 assert_eq!(
3415 result[1],
3416 Score {
3417 value: 95,
3418 id: "f".to_string()
3419 }
3420 );
3421 assert_eq!(
3422 result[2],
3423 Score {
3424 value: 90,
3425 id: "d".to_string()
3426 }
3427 );
3428 }
3429
3430 #[test]
3431 fn test_merge_preserves_order() {
3432 use std::cmp::Reverse;
3433
3434 let merge = Merge { k: 10 };
3435
3436 let input = vec![
3439 vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
3440 vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
3441 vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
3442 ];
3443
3444 let result = merge.merge(input);
3445
3446 for i in 1..result.len() {
3449 assert!(
3450 result[i - 1] >= result[i],
3451 "Output should be in descending Reverse order"
3452 );
3453 assert!(
3454 result[i - 1].0 <= result[i].0,
3455 "Inner values should be in ascending order"
3456 );
3457 }
3458
3459 assert_eq!(
3461 result,
3462 vec![
3463 Reverse(1),
3464 Reverse(2),
3465 Reverse(3),
3466 Reverse(4),
3467 Reverse(5),
3468 Reverse(6),
3469 Reverse(7),
3470 Reverse(8),
3471 Reverse(9),
3472 Reverse(10)
3473 ]
3474 );
3475 }
3476
3477 #[test]
3478 fn test_merge_single_vector() {
3479 let merge = Merge { k: 3 };
3480
3481 let input = vec![vec![1000u64, 800, 600, 400, 200]];
3483
3484 let result = merge.merge(input);
3485
3486 assert_eq!(result, vec![1000, 800, 600]);
3487 }
3488
3489 #[test]
3490 fn test_merge_all_same_values() {
3491 let merge = Merge { k: 5 };
3492
3493 let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];
3495
3496 let result = merge.merge(input);
3497
3498 assert_eq!(result, vec![42]);
3500 }
3501
3502 #[test]
3503 fn test_merge_mixed_types_sizes() {
3504 let merge = Merge { k: 4 };
3506 let input = vec![
3507 vec![1000usize, 500, 100],
3508 vec![800, 300],
3509 vec![900, 600, 200],
3510 ];
3511 let result = merge.merge(input);
3512 assert_eq!(result, vec![1000, 900, 800, 600]);
3513
3514 let merge = Merge { k: 5 };
3516 let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
3517 let result = merge.merge(input);
3518 assert_eq!(result, vec![15, 10, 5, 0, -5]);
3519 }
3520
3521 #[test]
3522 fn test_aggregate_json_serialization() {
3523 let min_k = Aggregate::MinK {
3525 keys: vec![Key::Score, Key::field("date")],
3526 k: 3,
3527 };
3528 let json = serde_json::to_value(&min_k).unwrap();
3529 assert!(json.get("$min_k").is_some());
3530 assert_eq!(json["$min_k"]["k"], 3);
3531
3532 let min_k_json = serde_json::json!({
3534 "$min_k": {
3535 "keys": ["#score", "date"],
3536 "k": 5
3537 }
3538 });
3539 let deserialized: Aggregate = serde_json::from_value(min_k_json).unwrap();
3540 match deserialized {
3541 Aggregate::MinK { keys, k } => {
3542 assert_eq!(k, 5);
3543 assert_eq!(keys.len(), 2);
3544 assert_eq!(keys[0], Key::Score);
3545 assert_eq!(keys[1], Key::field("date"));
3546 }
3547 _ => panic!("Expected MinK"),
3548 }
3549
3550 let max_k = Aggregate::MaxK {
3552 keys: vec![Key::field("timestamp")],
3553 k: 10,
3554 };
3555 let json = serde_json::to_value(&max_k).unwrap();
3556 assert!(json.get("$max_k").is_some());
3557 assert_eq!(json["$max_k"]["k"], 10);
3558
3559 let max_k_json = serde_json::json!({
3561 "$max_k": {
3562 "keys": ["timestamp"],
3563 "k": 2
3564 }
3565 });
3566 let deserialized: Aggregate = serde_json::from_value(max_k_json).unwrap();
3567 match deserialized {
3568 Aggregate::MaxK { keys, k } => {
3569 assert_eq!(k, 2);
3570 assert_eq!(keys.len(), 1);
3571 assert_eq!(keys[0], Key::field("timestamp"));
3572 }
3573 _ => panic!("Expected MaxK"),
3574 }
3575 }
3576
3577 #[test]
3578 fn test_group_by_json_serialization() {
3579 let group_by = GroupBy {
3581 keys: vec![Key::field("category"), Key::field("author")],
3582 aggregate: Some(Aggregate::MinK {
3583 keys: vec![Key::Score],
3584 k: 3,
3585 }),
3586 };
3587
3588 let json = serde_json::to_value(&group_by).unwrap();
3589 assert_eq!(json["keys"].as_array().unwrap().len(), 2);
3590 assert!(json["aggregate"]["$min_k"].is_object());
3591
3592 let deserialized: GroupBy = serde_json::from_value(json).unwrap();
3594 assert_eq!(deserialized.keys.len(), 2);
3595 assert_eq!(deserialized.keys[0], Key::field("category"));
3596 assert_eq!(deserialized.keys[1], Key::field("author"));
3597 assert!(deserialized.aggregate.is_some());
3598
3599 let empty_group_by = GroupBy::default();
3601 let json = serde_json::to_value(&empty_group_by).unwrap();
3602 let deserialized: GroupBy = serde_json::from_value(json).unwrap();
3603 assert!(deserialized.keys.is_empty());
3604 assert!(deserialized.aggregate.is_none());
3605
3606 let json = serde_json::json!({
3608 "keys": ["category"],
3609 "aggregate": {
3610 "$max_k": {
3611 "keys": ["#score", "priority"],
3612 "k": 5
3613 }
3614 }
3615 });
3616 let group_by: GroupBy = serde_json::from_value(json).unwrap();
3617 assert_eq!(group_by.keys.len(), 1);
3618 assert_eq!(group_by.keys[0], Key::field("category"));
3619 match group_by.aggregate {
3620 Some(Aggregate::MaxK { keys, k }) => {
3621 assert_eq!(k, 5);
3622 assert_eq!(keys.len(), 2);
3623 assert_eq!(keys[0], Key::Score);
3624 }
3625 _ => panic!("Expected MaxK aggregate"),
3626 }
3627 }
3628}