1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4 cmp::Ordering,
5 collections::{BinaryHeap, HashSet},
6 fmt,
7 hash::{Hash, Hasher},
8 ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13 chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14 ContainsOperator, DocumentExpression, DocumentOperator, Metadata, MetadataComparison,
15 MetadataExpression, MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding,
16 SetOperator, SparseVector, Where,
17};
18
19use super::error::QueryConversionError;
20
/// Unit alias named for the "initial" input of an operator.
///
/// NOTE(review): presumably used as the input type of operators that have no
/// upstream value — confirm against the call sites.
pub type InitialInput = ();
22
/// Scan operator: carries the collection and the segments a query runs against.
#[derive(Clone, Debug)]
pub struct Scan {
    /// The collection together with its metadata, record and vector segments.
    pub collection_and_segments: CollectionAndSegments,
}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33 type Error = QueryConversionError;
34
35 fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36 Ok(Self {
37 collection_and_segments: CollectionAndSegments {
38 collection: value
39 .collection
40 .ok_or(QueryConversionError::field("collection"))?
41 .try_into()?,
42 metadata_segment: value
43 .metadata
44 .ok_or(QueryConversionError::field("metadata segment"))?
45 .try_into()?,
46 record_segment: value
47 .record
48 .ok_or(QueryConversionError::field("record segment"))?
49 .try_into()?,
50 vector_segment: value
51 .knn
52 .ok_or(QueryConversionError::field("vector segment"))?
53 .try_into()?,
54 },
55 })
56 }
57}
58
/// Error raised while encoding a [`Scan`] into its proto representation.
#[derive(Debug, Error)]
pub enum ScanToProtoError {
    /// The collection could not be converted to proto; wraps the source error.
    #[error("Could not convert collection to proto")]
    CollectionToProto(#[from] crate::CollectionToProtoError),
}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66 type Error = ScanToProtoError;
67
68 fn try_from(value: Scan) -> Result<Self, Self::Error> {
69 Ok(Self {
70 collection: Some(value.collection_and_segments.collection.try_into()?),
71 knn: Some(value.collection_and_segments.vector_segment.into()),
72 metadata: Some(value.collection_and_segments.metadata_segment.into()),
73 record: Some(value.collection_and_segments.record_segment.into()),
74 })
75 }
76}
77
/// Result of a count query.
#[derive(Clone, Debug)]
pub struct CountResult {
    /// Number of records counted.
    pub count: u32,
    /// Bytes of log pulled while answering the query (as reported by the proto).
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Size in bytes of the result payload.
    ///
    /// Only the `count` field is measured; `pulled_log_bytes` is not included.
    pub fn size_bytes(&self) -> u64 {
        let payload_bytes = std::mem::size_of_val(&self.count);
        payload_bytes as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91 fn from(value: chroma_proto::CountResult) -> Self {
92 Self {
93 count: value.count,
94 pulled_log_bytes: value.pulled_log_bytes,
95 }
96 }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100 fn from(value: CountResult) -> Self {
101 Self {
102 count: value.count,
103 pulled_log_bytes: value.pulled_log_bytes,
104 }
105 }
106}
107
/// Parameters for fetching log records of a collection.
#[derive(Clone, Debug)]
pub struct FetchLog {
    /// Collection whose log is fetched.
    pub collection_uuid: CollectionUuid,
    /// Optional cap on the number of log records to fetch.
    /// NOTE(review): presumably `None` means "no cap" — confirm with the operator implementation.
    pub maximum_fetch_count: Option<u32>,
    /// Log offset id the fetch starts from.
    /// NOTE(review): whether this bound is inclusive is not visible here — confirm.
    pub start_log_offset_id: u32,
}
120
/// Filter operator: restricts a query by explicit record ids and/or a where clause.
///
/// The default value has neither restriction set.
#[derive(Clone, Debug, Default)]
pub struct Filter {
    /// Explicit user ids to restrict the query to, if any.
    pub query_ids: Option<Vec<String>>,
    /// Where clause over metadata and/or documents, if any.
    pub where_clause: Option<Where>,
}
174
175impl Serialize for Filter {
176 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177 where
178 S: Serializer,
179 {
180 match (&self.query_ids, &self.where_clause) {
184 (None, None) => {
185 let map = serializer.serialize_map(Some(0))?;
187 map.end()
188 }
189 (None, Some(where_clause)) => {
190 where_clause.serialize(serializer)
192 }
193 (Some(ids), None) => {
194 let id_where = Where::Metadata(MetadataExpression {
196 key: "#id".to_string(),
197 comparison: MetadataComparison::Set(
198 SetOperator::In,
199 MetadataSetValue::Str(ids.clone()),
200 ),
201 });
202 id_where.serialize(serializer)
203 }
204 (Some(ids), Some(where_clause)) => {
205 let id_where = Where::Metadata(MetadataExpression {
207 key: "#id".to_string(),
208 comparison: MetadataComparison::Set(
209 SetOperator::In,
210 MetadataSetValue::Str(ids.clone()),
211 ),
212 });
213 let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214 combined.serialize(serializer)
215 }
216 }
217 }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where
223 D: Deserializer<'de>,
224 {
225 let where_json = Value::deserialize(deserializer)?;
227 let where_clause =
228 if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229 None
230 } else {
231 Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232 };
233
234 Ok(Filter {
235 query_ids: None, where_clause,
237 })
238 }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242 type Error = QueryConversionError;
243
244 fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245 let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246 let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247 let where_clause = match (where_metadata, where_document) {
248 (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249 (Some(w), None) | (None, Some(w)) => Some(w),
250 _ => None,
251 };
252
253 Ok(Self {
254 query_ids: value.ids.map(|uids| uids.ids),
255 where_clause,
256 })
257 }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261 type Error = QueryConversionError;
262
263 fn try_from(value: Filter) -> Result<Self, Self::Error> {
264 Ok(Self {
265 ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266 r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267 where_document: None,
268 })
269 }
270}
271
/// A single nearest-neighbour query.
#[derive(Clone, Debug)]
pub struct Knn {
    /// Query embedding.
    pub embedding: Vec<f32>,
    /// Number of neighbours to fetch.
    pub fetch: u32,
}

impl From<KnnBatch> for Vec<Knn> {
    /// Splits a batch into one [`Knn`] per embedding, in input order, each
    /// carrying the batch's `fetch` value.
    fn from(value: KnnBatch) -> Self {
        let KnnBatch { embeddings, fetch } = value;
        let mut queries = Vec::with_capacity(embeddings.len());
        for embedding in embeddings {
            queries.push(Knn { embedding, fetch });
        }
        queries
    }
}

/// A batch of nearest-neighbour queries that share one `fetch` value.
#[derive(Clone, Debug)]
pub struct KnnBatch {
    /// One query embedding per query.
    pub embeddings: Vec<Vec<f32>>,
    /// Number of neighbours to fetch for every query in the batch.
    pub fetch: u32,
}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308 type Error = QueryConversionError;
309
310 fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311 Ok(Self {
312 embeddings: value
313 .embeddings
314 .into_iter()
315 .map(|vec| vec.try_into().map(|(v, _)| v))
316 .collect::<Result<_, _>>()?,
317 fetch: value.fetch,
318 })
319 }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323 type Error = QueryConversionError;
324
325 fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326 Ok(Self {
327 embeddings: value
328 .embeddings
329 .into_iter()
330 .map(|embedding| {
331 let dim = embedding.len();
332 chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333 })
334 .collect::<Result<_, _>>()?,
335 fetch: value.fetch,
336 })
337 }
338}
339
/// Limit operator: paging window over a result set.
///
/// Both fields are optional when deserializing (`#[serde(default)]`): a missing
/// `offset` becomes `0` and a missing `limit` becomes `None`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Limit {
    /// Number of leading results to skip.
    #[serde(default)]
    pub offset: u32,
    /// Maximum number of results to return.
    /// NOTE(review): presumably `None` means "no limit" — confirm with the operator implementation.
    #[serde(default)]
    pub limit: Option<u32>,
}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381 fn from(value: chroma_proto::LimitOperator) -> Self {
382 Self {
383 offset: value.offset,
384 limit: value.limit,
385 }
386 }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390 fn from(value: Limit) -> Self {
391 Self {
392 offset: value.offset,
393 limit: value.limit,
394 }
395 }
396}
397
/// A record's offset id paired with a floating-point measure.
///
/// Identity and ordering deliberately use different fields:
/// * equality and hashing look at `offset_id` only, so two measurements of the
///   same record collapse to one entry in a hash set;
/// * ordering looks at `measure` first (IEEE total order) and uses `offset_id`
///   only to break ties.
///
/// As a consequence `a == b` does not imply `a.cmp(&b) == Ordering::Equal`.
#[derive(Clone, Copy, Debug)]
pub struct RecordMeasure {
    /// Offset id of the record.
    pub offset_id: u32,
    /// Measure attached to the record.
    pub measure: f32,
}

impl PartialEq for RecordMeasure {
    /// Two measures are equal when they refer to the same `offset_id`.
    fn eq(&self, other: &Self) -> bool {
        self.offset_id == other.offset_id
    }
}

impl Eq for RecordMeasure {}

impl Hash for RecordMeasure {
    /// Hashes `offset_id` only, matching [`PartialEq`].
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.offset_id, state);
    }
}

impl Ord for RecordMeasure {
    /// Orders by `measure` using `f32::total_cmp`, then by `offset_id`.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.measure.total_cmp(&other.measure) {
            Ordering::Equal => self.offset_id.cmp(&other.offset_id),
            decided => decided,
        }
    }
}

impl PartialOrd for RecordMeasure {
    /// Always `Some`, delegating to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
432
/// Output of a KNN operator: one [`RecordMeasure`] per returned record.
#[derive(Debug, Default)]
pub struct KnnOutput {
    /// Records with their measures.
    /// NOTE(review): the field name suggests the measure is a distance and that
    /// the list is ordered — neither is established in this file; confirm.
    pub distances: Vec<RecordMeasure>,
}
437
/// Merge operator: fuses several ranked batches into one list of at most `k`
/// distinct elements.
#[derive(Clone, Debug)]
pub struct Merge {
    /// Maximum number of elements to return.
    pub k: u32,
}

impl Merge {
    /// Merges `input` into a single list of at most `k` distinct elements.
    ///
    /// Each batch is consumed front to back. At every step the largest current
    /// head across all batches is taken next (max-heap on `(element, batch
    /// index)`), so the output is globally descending only when every batch is
    /// itself sorted in descending order. Elements that compare equal under
    /// `Eq`/`Hash` to one already emitted are skipped; the first occurrence
    /// wins.
    ///
    /// Scratch capacity is bounded by the total number of input elements
    /// rather than by `k` alone, so a very large `k` (up to `u32::MAX`) on a
    /// small input does not reserve memory proportional to `k`.
    pub fn merge<M: Clone + Eq + Hash + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = self.k as usize;
        // The result can never hold more elements than the input provides.
        let available = input
            .iter()
            .fold(0usize, |total, batch| total.saturating_add(batch.len()));
        let capacity = limit.min(available);

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the head of every non-empty batch.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut seen = HashSet::with_capacity(capacity);
        let mut fusion = Vec::with_capacity(capacity);
        while let Some((m, idx)) = max_heap.pop() {
            if fusion.len() >= limit {
                break;
            }
            // Refill from the batch the popped element came from.
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            // `insert` returns false for duplicates, which are dropped.
            if !seen.insert(m.clone()) {
                continue;
            }
            fusion.push(m);
        }
        fusion
    }
}
479
/// Projection operator: selects which optional parts of a record are returned.
///
/// The default value selects none of them.
#[derive(Clone, Debug, Default)]
pub struct Projection {
    /// Include the document.
    pub document: bool,
    /// Include the embedding.
    pub embedding: bool,
    /// Include the metadata.
    pub metadata: bool,
}
492
493impl From<chroma_proto::ProjectionOperator> for Projection {
494 fn from(value: chroma_proto::ProjectionOperator) -> Self {
495 Self {
496 document: value.document,
497 embedding: value.embedding,
498 metadata: value.metadata,
499 }
500 }
501}
502
503impl From<Projection> for chroma_proto::ProjectionOperator {
504 fn from(value: Projection) -> Self {
505 Self {
506 document: value.document,
507 embedding: value.embedding,
508 metadata: value.metadata,
509 }
510 }
511}
512
/// One projected record: its id plus whichever optional parts were returned.
#[derive(Clone, Debug, PartialEq)]
pub struct ProjectionRecord {
    /// Record id.
    pub id: String,
    /// Document text, when present.
    pub document: Option<String>,
    /// Embedding values, when present.
    pub embedding: Option<Vec<f32>>,
    /// Metadata, when present.
    pub metadata: Option<Metadata>,
}
520
521impl ProjectionRecord {
522 pub fn size_bytes(&self) -> u64 {
523 (self.id.len()
524 + self
525 .document
526 .as_ref()
527 .map(|doc| doc.len())
528 .unwrap_or_default()
529 + self
530 .embedding
531 .as_ref()
532 .map(|emb| size_of_val(&emb[..]))
533 .unwrap_or_default()
534 + self
535 .metadata
536 .as_ref()
537 .map(logical_size_of_metadata)
538 .unwrap_or_default()) as u64
539 }
540}
541
// Marker impl on top of the derived `PartialEq`. The embedding holds `f32`
// values, so a record whose embedding contains NaN compares unequal to itself;
// the reflexivity that `Eq` promises only holds for NaN-free embeddings.
impl Eq for ProjectionRecord {}
543
544impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
545 type Error = QueryConversionError;
546
547 fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
548 Ok(Self {
549 id: value.id,
550 document: value.document,
551 embedding: value
552 .embedding
553 .map(|vec| vec.try_into().map(|(v, _)| v))
554 .transpose()?,
555 metadata: value.metadata.map(TryInto::try_into).transpose()?,
556 })
557 }
558}
559
560impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
561 type Error = QueryConversionError;
562
563 fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
564 Ok(Self {
565 id: value.id,
566 document: value.document,
567 embedding: value
568 .embedding
569 .map(|embedding| {
570 let embedding_dimension = embedding.len();
571 chroma_proto::Vector::try_from((
572 embedding,
573 ScalarEncoding::FLOAT32,
574 embedding_dimension,
575 ))
576 })
577 .transpose()?,
578 metadata: value.metadata.map(|metadata| metadata.into()),
579 })
580 }
581}
582
/// Output of a projection: the projected records.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProjectionOutput {
    /// Projected records.
    pub records: Vec<ProjectionRecord>,
}
587
/// Result of a get query.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GetResult {
    /// Bytes of log pulled while answering the query (as reported by the proto).
    pub pulled_log_bytes: u64,
    /// The projected records.
    pub result: ProjectionOutput,
}
593
594impl GetResult {
595 pub fn size_bytes(&self) -> u64 {
596 self.result
597 .records
598 .iter()
599 .map(ProjectionRecord::size_bytes)
600 .sum()
601 }
602}
603
604impl TryFrom<chroma_proto::GetResult> for GetResult {
605 type Error = QueryConversionError;
606
607 fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
608 Ok(Self {
609 pulled_log_bytes: value.pulled_log_bytes,
610 result: ProjectionOutput {
611 records: value
612 .records
613 .into_iter()
614 .map(TryInto::try_into)
615 .collect::<Result<_, _>>()?,
616 },
617 })
618 }
619}
620
621impl TryFrom<GetResult> for chroma_proto::GetResult {
622 type Error = QueryConversionError;
623
624 fn try_from(value: GetResult) -> Result<Self, Self::Error> {
625 Ok(Self {
626 pulled_log_bytes: value.pulled_log_bytes,
627 records: value
628 .result
629 .records
630 .into_iter()
631 .map(TryInto::try_into)
632 .collect::<Result<_, _>>()?,
633 })
634 }
635}
636
/// Projection operator for KNN results: a regular projection plus a flag for
/// returning the distance.
#[derive(Clone, Debug)]
pub struct KnnProjection {
    /// Which record parts to return.
    pub projection: Projection,
    /// Whether to return the distance alongside each record.
    pub distance: bool,
}
649
650impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
651 type Error = QueryConversionError;
652
653 fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
654 Ok(Self {
655 projection: value
656 .projection
657 .ok_or(QueryConversionError::field("projection"))?
658 .into(),
659 distance: value.distance,
660 })
661 }
662}
663
664impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
665 fn from(value: KnnProjection) -> Self {
666 Self {
667 projection: Some(value.projection.into()),
668 distance: value.distance,
669 }
670 }
671}
672
/// One projected KNN hit: the record plus its optional distance.
#[derive(Clone, Debug)]
pub struct KnnProjectionRecord {
    /// The projected record.
    pub record: ProjectionRecord,
    /// Distance of the record, when present.
    pub distance: Option<f32>,
}
678
679impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
680 type Error = QueryConversionError;
681
682 fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
683 Ok(Self {
684 record: value
685 .record
686 .ok_or(QueryConversionError::field("record"))?
687 .try_into()?,
688 distance: value.distance,
689 })
690 }
691}
692
693impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
694 type Error = QueryConversionError;
695
696 fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
697 Ok(Self {
698 record: Some(value.record.try_into()?),
699 distance: value.distance,
700 })
701 }
702}
703
/// Output of a KNN projection for a single query: its projected hits.
#[derive(Clone, Debug, Default)]
pub struct KnnProjectionOutput {
    /// Projected hits.
    pub records: Vec<KnnProjectionRecord>,
}
708
709impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
710 type Error = QueryConversionError;
711
712 fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
713 Ok(Self {
714 records: value
715 .records
716 .into_iter()
717 .map(TryInto::try_into)
718 .collect::<Result<_, _>>()?,
719 })
720 }
721}
722
723impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
724 type Error = QueryConversionError;
725
726 fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
727 Ok(Self {
728 records: value
729 .records
730 .into_iter()
731 .map(TryInto::try_into)
732 .collect::<Result<_, _>>()?,
733 })
734 }
735}
736
/// Result of a batched KNN query: one [`KnnProjectionOutput`] per query.
#[derive(Clone, Debug, Default)]
pub struct KnnBatchResult {
    /// Bytes of log pulled while answering the query (as reported by the proto).
    pub pulled_log_bytes: u64,
    /// Per-query outputs.
    pub results: Vec<KnnProjectionOutput>,
}
742
743impl KnnBatchResult {
744 pub fn size_bytes(&self) -> u64 {
745 self.results
746 .iter()
747 .flat_map(|res| {
748 res.records
749 .iter()
750 .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
751 })
752 .sum()
753 }
754}
755
756impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
757 type Error = QueryConversionError;
758
759 fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
760 Ok(Self {
761 pulled_log_bytes: value.pulled_log_bytes,
762 results: value
763 .results
764 .into_iter()
765 .map(TryInto::try_into)
766 .collect::<Result<_, _>>()?,
767 })
768 }
769}
770
771impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
772 type Error = QueryConversionError;
773
774 fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
775 Ok(Self {
776 pulled_log_bytes: value.pulled_log_bytes,
777 results: value
778 .results
779 .into_iter()
780 .map(TryInto::try_into)
781 .collect::<Result<_, _>>()?,
782 })
783 }
784}
785
/// A query vector, either dense or sparse.
///
/// Serialized untagged: the value is written as the bare content of the active
/// variant, and on deserialization the variants are tried in declaration order
/// (`Dense` first, then `Sparse`).
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum QueryVector {
    /// Dense vector of `f32` values.
    Dense(Vec<f32>),
    /// Sparse vector.
    Sparse(SparseVector),
}
855
856impl TryFrom<chroma_proto::QueryVector> for QueryVector {
857 type Error = QueryConversionError;
858
859 fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
860 let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
861 match vector {
862 chroma_proto::query_vector::Vector::Dense(dense) => {
863 Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
864 }
865 chroma_proto::query_vector::Vector::Sparse(sparse) => {
866 Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| {
867 QueryConversionError::validation("sparse vector length mismatch")
868 })?))
869 }
870 }
871 }
872}
873
874impl TryFrom<QueryVector> for chroma_proto::QueryVector {
875 type Error = QueryConversionError;
876
877 fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
878 match value {
879 QueryVector::Dense(vec) => {
880 let dim = vec.len();
881 Ok(chroma_proto::QueryVector {
882 vector: Some(chroma_proto::query_vector::Vector::Dense(
883 chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
884 )),
885 })
886 }
887 QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
888 vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
889 }),
890 }
891 }
892}
893
894impl From<Vec<f32>> for QueryVector {
895 fn from(vec: Vec<f32>) -> Self {
896 QueryVector::Dense(vec)
897 }
898}
899
900impl From<SparseVector> for QueryVector {
901 fn from(sparse: SparseVector) -> Self {
902 QueryVector::Sparse(sparse)
903 }
904}
905
/// A single KNN lookup extracted from a rank expression.
#[derive(Clone, Debug, PartialEq)]
pub struct KnnQuery {
    /// The query vector.
    pub query: QueryVector,
    /// The key the lookup runs against.
    pub key: Key,
    /// The `limit` of the originating `Knn` rank expression.
    pub limit: u32,
}
912
/// Rank operator: an optional ranking expression.
///
/// `#[serde(transparent)]` makes the struct serialize and deserialize exactly
/// as its `expr` field. The default value has no expression.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Rank {
    /// The ranking expression, if any.
    pub expr: Option<RankExpr>,
}
948
949impl Rank {
950 pub fn knn_queries(&self) -> Vec<KnnQuery> {
951 self.expr
952 .as_ref()
953 .map(RankExpr::knn_queries)
954 .unwrap_or_default()
955 }
956}
957
958impl TryFrom<chroma_proto::RankOperator> for Rank {
959 type Error = QueryConversionError;
960
961 fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
962 Ok(Rank {
963 expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
964 })
965 }
966}
967
968impl TryFrom<Rank> for chroma_proto::RankOperator {
969 type Error = QueryConversionError;
970
971 fn try_from(rank: Rank) -> Result<Self, Self::Error> {
972 Ok(chroma_proto::RankOperator {
973 expr: rank.expr.map(TryInto::try_into).transpose()?,
974 })
975 }
976}
977
/// Ranking expression tree.
///
/// Leaves are constants (`Value`) and KNN lookups (`Knn`); inner nodes are
/// arithmetic combinators. With serde's default (externally tagged) enum
/// representation each node is written as a single-key map whose key is the
/// variant's renamed tag, e.g. `"$sum"` or `"$knn"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RankExpr {
    /// Absolute value of the operand. Tag: `$abs`.
    #[serde(rename = "$abs")]
    Absolute(Box<RankExpr>),
    /// `left / right`. Tag: `$div`.
    #[serde(rename = "$div")]
    Division {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// Exponential of the operand. Tag: `$exp`.
    #[serde(rename = "$exp")]
    Exponentiation(Box<RankExpr>),
    /// KNN lookup. Tag: `$knn`.
    #[serde(rename = "$knn")]
    Knn {
        /// Query vector.
        query: QueryVector,
        /// Key to search; defaults to [`RankExpr::default_knn_key`] when omitted.
        #[serde(default = "RankExpr::default_knn_key")]
        key: Key,
        /// Lookup limit; defaults to [`RankExpr::default_knn_limit`] when omitted.
        #[serde(default = "RankExpr::default_knn_limit")]
        limit: u32,
        /// Optional default value; `None` when omitted.
        /// NOTE(review): presumably the score used for records the lookup does
        /// not return — confirm with the evaluator.
        #[serde(default)]
        default: Option<f32>,
        /// `false` when omitted.
        /// NOTE(review): presumably selects rank position instead of distance as
        /// the produced value — confirm with the evaluator.
        #[serde(default)]
        return_rank: bool,
    },
    /// Logarithm of the operand. Tag: `$log`.
    #[serde(rename = "$log")]
    Logarithm(Box<RankExpr>),
    /// Maximum over the operands. Tag: `$max`.
    #[serde(rename = "$max")]
    Maximum(Vec<RankExpr>),
    /// Minimum over the operands. Tag: `$min`.
    #[serde(rename = "$min")]
    Minimum(Vec<RankExpr>),
    /// Product of the operands. Tag: `$mul`.
    #[serde(rename = "$mul")]
    Multiplication(Vec<RankExpr>),
    /// `left - right`. Tag: `$sub`.
    #[serde(rename = "$sub")]
    Subtraction {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// Sum of the operands. Tag: `$sum`.
    #[serde(rename = "$sum")]
    Summation(Vec<RankExpr>),
    /// Constant value. Tag: `$val`.
    #[serde(rename = "$val")]
    Value(f32),
}
1181
1182impl RankExpr {
1183 pub fn default_knn_key() -> Key {
1184 Key::Embedding
1185 }
1186
1187 pub fn default_knn_limit() -> u32 {
1188 16
1189 }
1190
1191 pub fn knn_queries(&self) -> Vec<KnnQuery> {
1192 match self {
1193 RankExpr::Absolute(expr)
1194 | RankExpr::Exponentiation(expr)
1195 | RankExpr::Logarithm(expr) => expr.knn_queries(),
1196 RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1197 .knn_queries()
1198 .into_iter()
1199 .chain(right.knn_queries())
1200 .collect(),
1201 RankExpr::Maximum(exprs)
1202 | RankExpr::Minimum(exprs)
1203 | RankExpr::Multiplication(exprs)
1204 | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1205 RankExpr::Value(_) => Vec::new(),
1206 RankExpr::Knn {
1207 query,
1208 key,
1209 limit,
1210 default: _,
1211 return_rank: _,
1212 } => vec![KnnQuery {
1213 query: query.clone(),
1214 key: key.clone(),
1215 limit: *limit,
1216 }],
1217 }
1218 }
1219
1220 pub fn exp(self) -> Self {
1240 RankExpr::Exponentiation(Box::new(self))
1241 }
1242
1243 pub fn log(self) -> Self {
1264 RankExpr::Logarithm(Box::new(self))
1265 }
1266
1267 pub fn abs(self) -> Self {
1294 RankExpr::Absolute(Box::new(self))
1295 }
1296
1297 pub fn max(self, other: impl Into<RankExpr>) -> Self {
1321 let other = other.into();
1322
1323 match self {
1324 RankExpr::Maximum(mut exprs) => match other {
1325 RankExpr::Maximum(other_exprs) => {
1326 exprs.extend(other_exprs);
1327 RankExpr::Maximum(exprs)
1328 }
1329 _ => {
1330 exprs.push(other);
1331 RankExpr::Maximum(exprs)
1332 }
1333 },
1334 _ => match other {
1335 RankExpr::Maximum(mut exprs) => {
1336 exprs.insert(0, self);
1337 RankExpr::Maximum(exprs)
1338 }
1339 _ => RankExpr::Maximum(vec![self, other]),
1340 },
1341 }
1342 }
1343
1344 pub fn min(self, other: impl Into<RankExpr>) -> Self {
1368 let other = other.into();
1369
1370 match self {
1371 RankExpr::Minimum(mut exprs) => match other {
1372 RankExpr::Minimum(other_exprs) => {
1373 exprs.extend(other_exprs);
1374 RankExpr::Minimum(exprs)
1375 }
1376 _ => {
1377 exprs.push(other);
1378 RankExpr::Minimum(exprs)
1379 }
1380 },
1381 _ => match other {
1382 RankExpr::Minimum(mut exprs) => {
1383 exprs.insert(0, self);
1384 RankExpr::Minimum(exprs)
1385 }
1386 _ => RankExpr::Minimum(vec![self, other]),
1387 },
1388 }
1389 }
1390}
1391
1392impl Add for RankExpr {
1393 type Output = RankExpr;
1394
1395 fn add(self, rhs: Self) -> Self::Output {
1396 match self {
1397 RankExpr::Summation(mut exprs) => match rhs {
1398 RankExpr::Summation(rhs_exprs) => {
1399 exprs.extend(rhs_exprs);
1400 RankExpr::Summation(exprs)
1401 }
1402 _ => {
1403 exprs.push(rhs);
1404 RankExpr::Summation(exprs)
1405 }
1406 },
1407 _ => match rhs {
1408 RankExpr::Summation(mut exprs) => {
1409 exprs.insert(0, self);
1410 RankExpr::Summation(exprs)
1411 }
1412 _ => RankExpr::Summation(vec![self, rhs]),
1413 },
1414 }
1415 }
1416}
1417
1418impl Add<f32> for RankExpr {
1419 type Output = RankExpr;
1420
1421 fn add(self, rhs: f32) -> Self::Output {
1422 self + RankExpr::Value(rhs)
1423 }
1424}
1425
1426impl Add<RankExpr> for f32 {
1427 type Output = RankExpr;
1428
1429 fn add(self, rhs: RankExpr) -> Self::Output {
1430 RankExpr::Value(self) + rhs
1431 }
1432}
1433
1434impl Sub for RankExpr {
1435 type Output = RankExpr;
1436
1437 fn sub(self, rhs: Self) -> Self::Output {
1438 RankExpr::Subtraction {
1439 left: Box::new(self),
1440 right: Box::new(rhs),
1441 }
1442 }
1443}
1444
1445impl Sub<f32> for RankExpr {
1446 type Output = RankExpr;
1447
1448 fn sub(self, rhs: f32) -> Self::Output {
1449 self - RankExpr::Value(rhs)
1450 }
1451}
1452
1453impl Sub<RankExpr> for f32 {
1454 type Output = RankExpr;
1455
1456 fn sub(self, rhs: RankExpr) -> Self::Output {
1457 RankExpr::Value(self) - rhs
1458 }
1459}
1460
1461impl Mul for RankExpr {
1462 type Output = RankExpr;
1463
1464 fn mul(self, rhs: Self) -> Self::Output {
1465 match self {
1466 RankExpr::Multiplication(mut exprs) => match rhs {
1467 RankExpr::Multiplication(rhs_exprs) => {
1468 exprs.extend(rhs_exprs);
1469 RankExpr::Multiplication(exprs)
1470 }
1471 _ => {
1472 exprs.push(rhs);
1473 RankExpr::Multiplication(exprs)
1474 }
1475 },
1476 _ => match rhs {
1477 RankExpr::Multiplication(mut exprs) => {
1478 exprs.insert(0, self);
1479 RankExpr::Multiplication(exprs)
1480 }
1481 _ => RankExpr::Multiplication(vec![self, rhs]),
1482 },
1483 }
1484 }
1485}
1486
1487impl Mul<f32> for RankExpr {
1488 type Output = RankExpr;
1489
1490 fn mul(self, rhs: f32) -> Self::Output {
1491 self * RankExpr::Value(rhs)
1492 }
1493}
1494
1495impl Mul<RankExpr> for f32 {
1496 type Output = RankExpr;
1497
1498 fn mul(self, rhs: RankExpr) -> Self::Output {
1499 RankExpr::Value(self) * rhs
1500 }
1501}
1502
1503impl Div for RankExpr {
1504 type Output = RankExpr;
1505
1506 fn div(self, rhs: Self) -> Self::Output {
1507 RankExpr::Division {
1508 left: Box::new(self),
1509 right: Box::new(rhs),
1510 }
1511 }
1512}
1513
1514impl Div<f32> for RankExpr {
1515 type Output = RankExpr;
1516
1517 fn div(self, rhs: f32) -> Self::Output {
1518 self / RankExpr::Value(rhs)
1519 }
1520}
1521
1522impl Div<RankExpr> for f32 {
1523 type Output = RankExpr;
1524
1525 fn div(self, rhs: RankExpr) -> Self::Output {
1526 RankExpr::Value(self) / rhs
1527 }
1528}
1529
1530impl Neg for RankExpr {
1531 type Output = RankExpr;
1532
1533 fn neg(self) -> Self::Output {
1534 RankExpr::Value(-1.0) * self
1535 }
1536}
1537
1538impl From<f32> for RankExpr {
1539 fn from(v: f32) -> Self {
1540 RankExpr::Value(v)
1541 }
1542}
1543
1544impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1545 type Error = QueryConversionError;
1546
1547 fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1548 match proto_expr.rank {
1549 Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1550 Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1551 }
1552 Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1553 let left = div.left.ok_or(QueryConversionError::field("left"))?;
1554 let right = div.right.ok_or(QueryConversionError::field("right"))?;
1555 Ok(RankExpr::Division {
1556 left: Box::new(RankExpr::try_from(*left)?),
1557 right: Box::new(RankExpr::try_from(*right)?),
1558 })
1559 }
1560 Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1561 RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1562 ),
1563 Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1564 let query = knn
1565 .query
1566 .ok_or(QueryConversionError::field("query"))?
1567 .try_into()?;
1568 Ok(RankExpr::Knn {
1569 query,
1570 key: Key::from(knn.key),
1571 limit: knn.limit,
1572 default: knn.default,
1573 return_rank: knn.return_rank,
1574 })
1575 }
1576 Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1577 Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1578 }
1579 Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1580 let exprs = max
1581 .exprs
1582 .into_iter()
1583 .map(RankExpr::try_from)
1584 .collect::<Result<Vec<_>, _>>()?;
1585 Ok(RankExpr::Maximum(exprs))
1586 }
1587 Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1588 let exprs = min
1589 .exprs
1590 .into_iter()
1591 .map(RankExpr::try_from)
1592 .collect::<Result<Vec<_>, _>>()?;
1593 Ok(RankExpr::Minimum(exprs))
1594 }
1595 Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1596 let exprs = mul
1597 .exprs
1598 .into_iter()
1599 .map(RankExpr::try_from)
1600 .collect::<Result<Vec<_>, _>>()?;
1601 Ok(RankExpr::Multiplication(exprs))
1602 }
1603 Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1604 let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1605 let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1606 Ok(RankExpr::Subtraction {
1607 left: Box::new(RankExpr::try_from(*left)?),
1608 right: Box::new(RankExpr::try_from(*right)?),
1609 })
1610 }
1611 Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1612 let exprs = sum
1613 .exprs
1614 .into_iter()
1615 .map(RankExpr::try_from)
1616 .collect::<Result<Vec<_>, _>>()?;
1617 Ok(RankExpr::Summation(exprs))
1618 }
1619 Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1620 None => Err(QueryConversionError::field("rank")),
1621 }
1622 }
1623}
1624
1625impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1626 type Error = QueryConversionError;
1627
1628 fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1629 let proto_rank = match rank_expr {
1630 RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1631 chroma_proto::RankExpr::try_from(*expr)?,
1632 )),
1633 RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1634 Box::new(chroma_proto::rank_expr::RankPair {
1635 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1636 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1637 }),
1638 ),
1639 RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1640 Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1641 ),
1642 RankExpr::Knn {
1643 query,
1644 key,
1645 limit,
1646 default,
1647 return_rank,
1648 } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1649 query: Some(query.try_into()?),
1650 key: key.to_string(),
1651 limit,
1652 default,
1653 return_rank,
1654 }),
1655 RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1656 chroma_proto::RankExpr::try_from(*expr)?,
1657 )),
1658 RankExpr::Maximum(exprs) => {
1659 let proto_exprs = exprs
1660 .into_iter()
1661 .map(chroma_proto::RankExpr::try_from)
1662 .collect::<Result<Vec<_>, _>>()?;
1663 chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1664 exprs: proto_exprs,
1665 })
1666 }
1667 RankExpr::Minimum(exprs) => {
1668 let proto_exprs = exprs
1669 .into_iter()
1670 .map(chroma_proto::RankExpr::try_from)
1671 .collect::<Result<Vec<_>, _>>()?;
1672 chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1673 exprs: proto_exprs,
1674 })
1675 }
1676 RankExpr::Multiplication(exprs) => {
1677 let proto_exprs = exprs
1678 .into_iter()
1679 .map(chroma_proto::RankExpr::try_from)
1680 .collect::<Result<Vec<_>, _>>()?;
1681 chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1682 exprs: proto_exprs,
1683 })
1684 }
1685 RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1686 Box::new(chroma_proto::rank_expr::RankPair {
1687 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1688 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1689 }),
1690 ),
1691 RankExpr::Summation(exprs) => {
1692 let proto_exprs = exprs
1693 .into_iter()
1694 .map(chroma_proto::RankExpr::try_from)
1695 .collect::<Result<Vec<_>, _>>()?;
1696 chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1697 exprs: proto_exprs,
1698 })
1699 }
1700 RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1701 };
1702
1703 Ok(chroma_proto::RankExpr {
1704 rank: Some(proto_rank),
1705 })
1706 }
1707}
1708
/// Names a part of a record: one of the built-in parts or a metadata field.
///
/// The built-in parts have the reserved textual names `#document`,
/// `#embedding`, `#metadata` and `#score`; any other name is a metadata field.
/// The derived ordering follows the declaration order of the variants.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Key {
    /// The document (`#document`).
    Document,
    /// The embedding (`#embedding`).
    Embedding,
    /// The metadata (`#metadata`).
    Metadata,
    /// The score (`#score`).
    Score,
    /// A metadata field, identified by its name.
    MetadataField(String),
}
1783
1784impl Serialize for Key {
1785 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1786 where
1787 S: serde::Serializer,
1788 {
1789 match self {
1790 Key::Document => serializer.serialize_str("#document"),
1791 Key::Embedding => serializer.serialize_str("#embedding"),
1792 Key::Metadata => serializer.serialize_str("#metadata"),
1793 Key::Score => serializer.serialize_str("#score"),
1794 Key::MetadataField(field) => serializer.serialize_str(field),
1795 }
1796 }
1797}
1798
1799impl<'de> Deserialize<'de> for Key {
1800 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1801 where
1802 D: Deserializer<'de>,
1803 {
1804 let s = String::deserialize(deserializer)?;
1805 Ok(Key::from(s))
1806 }
1807}
1808
1809impl fmt::Display for Key {
1810 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1811 match self {
1812 Key::Document => write!(f, "#document"),
1813 Key::Embedding => write!(f, "#embedding"),
1814 Key::Metadata => write!(f, "#metadata"),
1815 Key::Score => write!(f, "#score"),
1816 Key::MetadataField(field) => write!(f, "{}", field),
1817 }
1818 }
1819}
1820
1821impl From<&str> for Key {
1822 fn from(s: &str) -> Self {
1823 match s {
1824 "#document" => Key::Document,
1825 "#embedding" => Key::Embedding,
1826 "#metadata" => Key::Metadata,
1827 "#score" => Key::Score,
1828 field => Key::MetadataField(field.to_string()),
1830 }
1831 }
1832}
1833
1834impl From<String> for Key {
1835 fn from(s: String) -> Self {
1836 Key::from(s.as_str())
1837 }
1838}
1839
1840impl Key {
1841 pub fn field(name: impl Into<String>) -> Self {
1853 Key::MetadataField(name.into())
1854 }
1855
1856 pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1873 Where::Metadata(MetadataExpression {
1874 key: self.to_string(),
1875 comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1876 })
1877 }
1878
1879 pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1890 Where::Metadata(MetadataExpression {
1891 key: self.to_string(),
1892 comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1893 })
1894 }
1895
1896 pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1907 Where::Metadata(MetadataExpression {
1908 key: self.to_string(),
1909 comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1910 })
1911 }
1912
1913 pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1924 Where::Metadata(MetadataExpression {
1925 key: self.to_string(),
1926 comparison: MetadataComparison::Primitive(
1927 PrimitiveOperator::GreaterThanOrEqual,
1928 value.into(),
1929 ),
1930 })
1931 }
1932
1933 pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1944 Where::Metadata(MetadataExpression {
1945 key: self.to_string(),
1946 comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1947 })
1948 }
1949
1950 pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1961 Where::Metadata(MetadataExpression {
1962 key: self.to_string(),
1963 comparison: MetadataComparison::Primitive(
1964 PrimitiveOperator::LessThanOrEqual,
1965 value.into(),
1966 ),
1967 })
1968 }
1969
1970 pub fn is_in<I, T>(self, values: I) -> Where
1990 where
1991 I: IntoIterator<Item = T>,
1992 Vec<T>: Into<MetadataSetValue>,
1993 {
1994 let vec: Vec<T> = values.into_iter().collect();
1995 Where::Metadata(MetadataExpression {
1996 key: self.to_string(),
1997 comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1998 })
1999 }
2000
2001 pub fn not_in<I, T>(self, values: I) -> Where
2017 where
2018 I: IntoIterator<Item = T>,
2019 Vec<T>: Into<MetadataSetValue>,
2020 {
2021 let vec: Vec<T> = values.into_iter().collect();
2022 Where::Metadata(MetadataExpression {
2023 key: self.to_string(),
2024 comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2025 })
2026 }
2027
2028 pub fn contains<S: Into<String>>(self, text: S) -> Where {
2044 Where::Document(DocumentExpression {
2045 operator: DocumentOperator::Contains,
2046 pattern: text.into(),
2047 })
2048 }
2049
2050 pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2066 Where::Document(DocumentExpression {
2067 operator: DocumentOperator::NotContains,
2068 pattern: text.into(),
2069 })
2070 }
2071
2072 pub fn contains_value<T: Into<MetadataValue>>(self, value: T) -> Where {
2085 Where::Metadata(MetadataExpression {
2086 key: self.to_string(),
2087 comparison: MetadataComparison::ArrayContains(ContainsOperator::Contains, value.into()),
2088 })
2089 }
2090
2091 pub fn not_contains_value<T: Into<MetadataValue>>(self, value: T) -> Where {
2103 Where::Metadata(MetadataExpression {
2104 key: self.to_string(),
2105 comparison: MetadataComparison::ArrayContains(
2106 ContainsOperator::NotContains,
2107 value.into(),
2108 ),
2109 })
2110 }
2111
2112 pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2129 Where::Document(DocumentExpression {
2130 operator: DocumentOperator::Regex,
2131 pattern: pattern.into(),
2132 })
2133 }
2134
2135 pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2151 Where::Document(DocumentExpression {
2152 operator: DocumentOperator::NotRegex,
2153 pattern: pattern.into(),
2154 })
2155 }
2156}
2157
/// A set of [`Key`]s to select.
///
/// When deserialized, a missing `keys` entry yields an empty set.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Select {
    /// The selected keys; duplicates collapse because this is a set.
    #[serde(default)]
    pub keys: HashSet<Key>,
}
2216
2217impl TryFrom<chroma_proto::SelectOperator> for Select {
2218 type Error = QueryConversionError;
2219
2220 fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2221 let keys = value
2222 .keys
2223 .into_iter()
2224 .map(|key| {
2225 serde_json::from_value(serde_json::Value::String(key))
2227 .map_err(|_| QueryConversionError::field("keys"))
2228 })
2229 .collect::<Result<HashSet<_>, _>>()?;
2230
2231 Ok(Self { keys })
2232 }
2233}
2234
2235impl TryFrom<Select> for chroma_proto::SelectOperator {
2236 type Error = QueryConversionError;
2237
2238 fn try_from(value: Select) -> Result<Self, Self::Error> {
2239 let keys = value
2240 .keys
2241 .into_iter()
2242 .map(|key| {
2243 serde_json::to_value(&key)
2245 .ok()
2246 .and_then(|v| v.as_str().map(String::from))
2247 .ok_or(QueryConversionError::field("keys"))
2248 })
2249 .collect::<Result<Vec<_>, _>>()?;
2250
2251 Ok(Self { keys })
2252 }
2253}
2254
/// Aggregation applied by a [`GroupBy`].
///
/// Serialized as an externally tagged object whose tag is `$min_k` or
/// `$max_k`.
// NOTE(review): presumably MinK/MaxK keep the `k` records with the smallest /
// largest values of `keys` within each group — confirm against the executor.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum Aggregate {
    /// Serialized under the `$min_k` tag.
    #[serde(rename = "$min_k")]
    MinK {
        /// Keys the aggregate is computed over.
        keys: Vec<Key>,
        /// The `k` parameter of the aggregate.
        k: u32,
    },
    /// Serialized under the `$max_k` tag.
    #[serde(rename = "$max_k")]
    MaxK {
        /// Keys the aggregate is computed over.
        keys: Vec<Key>,
        /// The `k` parameter of the aggregate.
        k: u32,
    },
}
2312
/// Grouping specification: the keys to group by and an optional aggregate.
///
/// Both fields fall back to their defaults (no keys, no aggregate) when
/// absent from the serialized form.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct GroupBy {
    /// Keys to group by.
    #[serde(default)]
    pub keys: Vec<Key>,
    /// Optional aggregate attached to the grouping.
    #[serde(default)]
    pub aggregate: Option<Aggregate>,
}
2363
2364impl TryFrom<chroma_proto::Aggregate> for Aggregate {
2365 type Error = QueryConversionError;
2366
2367 fn try_from(value: chroma_proto::Aggregate) -> Result<Self, Self::Error> {
2368 match value
2369 .aggregate
2370 .ok_or(QueryConversionError::field("aggregate"))?
2371 {
2372 chroma_proto::aggregate::Aggregate::MinK(min_k) => {
2373 let keys = min_k.keys.into_iter().map(Key::from).collect();
2374 Ok(Aggregate::MinK { keys, k: min_k.k })
2375 }
2376 chroma_proto::aggregate::Aggregate::MaxK(max_k) => {
2377 let keys = max_k.keys.into_iter().map(Key::from).collect();
2378 Ok(Aggregate::MaxK { keys, k: max_k.k })
2379 }
2380 }
2381 }
2382}
2383
2384impl From<Aggregate> for chroma_proto::Aggregate {
2385 fn from(value: Aggregate) -> Self {
2386 let aggregate = match value {
2387 Aggregate::MinK { keys, k } => {
2388 chroma_proto::aggregate::Aggregate::MinK(chroma_proto::aggregate::MinK {
2389 keys: keys.into_iter().map(|k| k.to_string()).collect(),
2390 k,
2391 })
2392 }
2393 Aggregate::MaxK { keys, k } => {
2394 chroma_proto::aggregate::Aggregate::MaxK(chroma_proto::aggregate::MaxK {
2395 keys: keys.into_iter().map(|k| k.to_string()).collect(),
2396 k,
2397 })
2398 }
2399 };
2400
2401 chroma_proto::Aggregate {
2402 aggregate: Some(aggregate),
2403 }
2404 }
2405}
2406
2407impl TryFrom<chroma_proto::GroupByOperator> for GroupBy {
2408 type Error = QueryConversionError;
2409
2410 fn try_from(value: chroma_proto::GroupByOperator) -> Result<Self, Self::Error> {
2411 let keys = value.keys.into_iter().map(Key::from).collect();
2412 let aggregate = value.aggregate.map(TryInto::try_into).transpose()?;
2413
2414 Ok(Self { keys, aggregate })
2415 }
2416}
2417
2418impl TryFrom<GroupBy> for chroma_proto::GroupByOperator {
2419 type Error = QueryConversionError;
2420
2421 fn try_from(value: GroupBy) -> Result<Self, Self::Error> {
2422 let keys = value.keys.into_iter().map(|k| k.to_string()).collect();
2423 let aggregate = value.aggregate.map(Into::into);
2424
2425 Ok(Self { keys, aggregate })
2426 }
2427}
2428
/// One record returned by a search.
///
/// Only `id` is mandatory; every other field is optional.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SearchRecord {
    /// Record identifier.
    pub id: String,
    /// Document text, if present.
    pub document: Option<String>,
    /// Embedding values, if present.
    pub embedding: Option<Vec<f32>>,
    /// Metadata, if present.
    pub metadata: Option<Metadata>,
    /// Score, if present.
    pub score: Option<f32>,
}
2474
2475impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2476 type Error = QueryConversionError;
2477
2478 fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2479 Ok(Self {
2480 id: value.id,
2481 document: value.document,
2482 embedding: value
2483 .embedding
2484 .map(|vec| vec.try_into().map(|(v, _)| v))
2485 .transpose()?,
2486 metadata: value.metadata.map(TryInto::try_into).transpose()?,
2487 score: value.score,
2488 })
2489 }
2490}
2491
2492impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2493 type Error = QueryConversionError;
2494
2495 fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2496 Ok(Self {
2497 id: value.id,
2498 document: value.document,
2499 embedding: value
2500 .embedding
2501 .map(|embedding| {
2502 let embedding_dimension = embedding.len();
2503 chroma_proto::Vector::try_from((
2504 embedding,
2505 ScalarEncoding::FLOAT32,
2506 embedding_dimension,
2507 ))
2508 })
2509 .transpose()?,
2510 metadata: value.metadata.map(Into::into),
2511 score: value.score,
2512 })
2513 }
2514}
2515
/// The records returned for one search payload.
#[derive(Clone, Debug, Default)]
pub struct SearchPayloadResult {
    /// Records in the order they were returned.
    pub records: Vec<SearchRecord>,
}
2541
2542impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2543 type Error = QueryConversionError;
2544
2545 fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2546 Ok(Self {
2547 records: value
2548 .records
2549 .into_iter()
2550 .map(TryInto::try_into)
2551 .collect::<Result<_, _>>()?,
2552 })
2553 }
2554}
2555
2556impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2557 type Error = QueryConversionError;
2558
2559 fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2560 Ok(Self {
2561 records: value
2562 .records
2563 .into_iter()
2564 .map(TryInto::try_into)
2565 .collect::<Result<Vec<_>, _>>()?,
2566 })
2567 }
2568}
2569
/// Results of a search: one [`SearchPayloadResult`] per entry in `results`,
/// plus a log-bytes counter.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Per-payload results.
    pub results: Vec<SearchPayloadResult>,
    /// Byte counter carried through unchanged by the proto conversions.
    // NOTE(review): presumably the number of log bytes pulled while serving
    // the search — confirm with the producer of this value.
    pub pulled_log_bytes: u64,
}
2617
2618impl SearchResult {
2619 pub fn size_bytes(&self) -> u64 {
2620 self.results
2621 .iter()
2622 .flat_map(|result| {
2623 result.records.iter().map(|record| {
2624 (record.id.len()
2625 + record
2626 .document
2627 .as_ref()
2628 .map(|doc| doc.len())
2629 .unwrap_or_default()
2630 + record
2631 .embedding
2632 .as_ref()
2633 .map(|emb| size_of_val(&emb[..]))
2634 .unwrap_or_default()
2635 + record
2636 .metadata
2637 .as_ref()
2638 .map(logical_size_of_metadata)
2639 .unwrap_or_default()
2640 + record.score.as_ref().map(size_of_val).unwrap_or_default())
2641 as u64
2642 })
2643 })
2644 .sum()
2645 }
2646}
2647
2648impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2649 type Error = QueryConversionError;
2650
2651 fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2652 Ok(Self {
2653 results: value
2654 .results
2655 .into_iter()
2656 .map(TryInto::try_into)
2657 .collect::<Result<_, _>>()?,
2658 pulled_log_bytes: value.pulled_log_bytes,
2659 })
2660 }
2661}
2662
2663impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2664 type Error = QueryConversionError;
2665
2666 fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2667 Ok(Self {
2668 results: value
2669 .results
2670 .into_iter()
2671 .map(TryInto::try_into)
2672 .collect::<Result<Vec<_>, _>>()?,
2673 pulled_log_bytes: value.pulled_log_bytes,
2674 })
2675 }
2676}
2677
2678pub fn rrf(
2823 ranks: Vec<RankExpr>,
2824 k: Option<u32>,
2825 weights: Option<Vec<f32>>,
2826 normalize: bool,
2827) -> Result<RankExpr, QueryConversionError> {
2828 let k = k.unwrap_or(60);
2829
2830 if ranks.is_empty() {
2831 return Err(QueryConversionError::validation(
2832 "RRF requires at least one rank expression",
2833 ));
2834 }
2835
2836 let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2837
2838 if weights.len() != ranks.len() {
2839 return Err(QueryConversionError::validation(format!(
2840 "RRF weights length ({}) must match ranks length ({})",
2841 weights.len(),
2842 ranks.len()
2843 )));
2844 }
2845
2846 let weights = if normalize {
2847 let sum: f32 = weights.iter().sum();
2848 if sum == 0.0 {
2849 return Err(QueryConversionError::validation(
2850 "RRF weights sum to zero, cannot normalize",
2851 ));
2852 }
2853 weights.into_iter().map(|w| w / sum).collect()
2854 } else {
2855 weights
2856 };
2857
2858 let terms: Vec<RankExpr> = weights
2859 .into_iter()
2860 .zip(ranks)
2861 .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2862 .collect();
2863
2864 let sum = terms
2867 .into_iter()
2868 .reduce(|a, b| a + b)
2869 .unwrap_or(RankExpr::Value(0.0));
2870 Ok(-sum)
2871}
2872
2873#[cfg(test)]
2874mod tests {
2875 use super::*;
2876
2877 #[test]
2878 fn test_key_from_string() {
2879 assert_eq!(Key::from("#document"), Key::Document);
2881 assert_eq!(Key::from("#embedding"), Key::Embedding);
2882 assert_eq!(Key::from("#metadata"), Key::Metadata);
2883 assert_eq!(Key::from("#score"), Key::Score);
2884
2885 assert_eq!(
2887 Key::from("custom_field"),
2888 Key::MetadataField("custom_field".to_string())
2889 );
2890 assert_eq!(
2891 Key::from("author"),
2892 Key::MetadataField("author".to_string())
2893 );
2894
2895 assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
2897 assert_eq!(
2898 Key::from("year".to_string()),
2899 Key::MetadataField("year".to_string())
2900 );
2901 }
2902
2903 #[test]
2904 fn test_query_vector_dense_proto_conversion() {
2905 let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2906 let query_vector = QueryVector::Dense(dense_vec.clone());
2907
2908 let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2910
2911 let converted: QueryVector = proto.try_into().unwrap();
2913
2914 assert_eq!(converted, query_vector);
2915 if let QueryVector::Dense(v) = converted {
2916 assert_eq!(v, dense_vec);
2917 } else {
2918 panic!("Expected dense vector");
2919 }
2920 }
2921
2922 #[test]
2923 fn test_query_vector_sparse_proto_conversion() {
2924 let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2925 let query_vector = QueryVector::Sparse(sparse.clone());
2926
2927 let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2929
2930 let converted: QueryVector = proto.try_into().unwrap();
2932
2933 assert_eq!(converted, query_vector);
2934 if let QueryVector::Sparse(s) = converted {
2935 assert_eq!(s, sparse);
2936 } else {
2937 panic!("Expected sparse vector");
2938 }
2939 }
2940
    /// Deserializes `Filter` from JSON where-clauses of increasing complexity.
    ///
    /// In every non-empty case the test expects `query_ids` to stay `None` and
    /// the whole JSON document to land in `where_clause`.
    #[test]
    fn test_filter_json_deserialization() {
        // Shorthand equality: a bare `{"field": value}` object.
        let simple_where = r#"{"author": "John Doe"}"#;
        let filter: Filter = serde_json::from_str(simple_where).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // An `#id` condition is part of the where clause, not `query_ids`.
        let id_filter_json = serde_json::json!({
            "#id": {
                "$in": ["doc1", "doc2", "doc3"]
            }
        });
        let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // `$and` with four children, the second of which is a nested `$or`.
        let complex_json = serde_json::json!({
            "$and": [
                {
                    "#id": {
                        "$in": ["doc1", "doc2", "doc3"]
                    }
                },
                {
                    "$or": [
                        {
                            "author": {
                                "$eq": "John Doe"
                            }
                        },
                        {
                            "author": {
                                "$eq": "Jane Smith"
                            }
                        }
                    ]
                },
                {
                    "year": {
                        "$gte": 2020
                    }
                },
                {
                    "tags": {
                        "$contains": "machine-learning"
                    }
                }
            ]
        });

        let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // The composite structure is preserved: AND at the top, OR as child 1.
        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
            assert_eq!(composite.children.len(), 4);

            if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
                assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
                assert_eq!(or_composite.children.len(), 2);
            } else {
                panic!("Expected OR composite in second child");
            }
        } else {
            panic!("Expected AND composite where clause");
        }

        // A mix of comparison operators: `$ne`, `$gt`, `$lt`, `$lte`.
        let mixed_operators_json = serde_json::json!({
            "$and": [
                {
                    "status": {
                        "$ne": "deleted"
                    }
                },
                {
                    "score": {
                        "$gt": 0.5
                    }
                },
                {
                    "score": {
                        "$lt": 0.9
                    }
                },
                {
                    "priority": {
                        "$lte": 10
                    }
                }
            ]
        });

        let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // Three levels of nesting: OR -> AND -> OR.
        let deeply_nested_json = serde_json::json!({
            "$or": [
                {
                    "$and": [
                        {
                            "#id": {
                                "$in": ["id1", "id2"]
                            }
                        },
                        {
                            "$or": [
                                {
                                    "category": "tech"
                                },
                                {
                                    "category": "science"
                                }
                            ]
                        }
                    ]
                },
                {
                    "$and": [
                        {
                            "author": "Admin"
                        },
                        {
                            "published": true
                        }
                    ]
                }
            ]
        });

        let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // Top level is OR, and both of its children are AND composites.
        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
            assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
            assert_eq!(composite.children.len(), 2);

            for child in &composite.children {
                if let crate::metadata::Where::Composite(and_composite) = child {
                    assert_eq!(
                        and_composite.operator,
                        crate::metadata::BooleanOperator::And
                    );
                } else {
                    panic!("Expected AND composite in OR children");
                }
            }
        } else {
            panic!("Expected OR composite at top level");
        }

        // A single `#id` equality.
        let single_id_json = serde_json::json!({
            "#id": {
                "$eq": "single-doc-id"
            }
        });

        let filter: Filter = serde_json::from_value(single_id_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // An empty object yields a filter with no where clause at all.
        let empty_json = serde_json::json!({});
        let filter: Filter = serde_json::from_value(empty_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert_eq!(filter.where_clause, None);

        // `$not_contains` plus nested OR/AND groups under a top-level AND.
        let advanced_json = serde_json::json!({
            "$and": [
                {
                    "#id": {
                        "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
                    }
                },
                {
                    "tags": {
                        "$not_contains": "deprecated"
                    }
                },
                {
                    "$or": [
                        {
                            "$and": [
                                {
                                    "confidence": {
                                        "$gte": 0.8
                                    }
                                },
                                {
                                    "verified": true
                                }
                            ]
                        },
                        {
                            "$and": [
                                {
                                    "confidence": {
                                        "$gte": 0.6
                                    }
                                },
                                {
                                    "confidence": {
                                        "$lt": 0.8
                                    }
                                },
                                {
                                    "reviews": {
                                        "$gte": 5
                                    }
                                }
                            ]
                        }
                    ]
                }
            ]
        });

        let filter: Filter = serde_json::from_value(advanced_json).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
            assert_eq!(composite.children.len(), 3);
        } else {
            panic!("Expected AND composite at top level");
        }
    }
3186
3187 #[test]
3188 fn test_limit_json_serialization() {
3189 let limit = Limit {
3190 offset: 10,
3191 limit: Some(20),
3192 };
3193
3194 let json = serde_json::to_string(&limit).unwrap();
3195 let deserialized: Limit = serde_json::from_str(&json).unwrap();
3196
3197 assert_eq!(deserialized.offset, limit.offset);
3198 assert_eq!(deserialized.limit, limit.limit);
3199 }
3200
3201 #[test]
3202 fn test_query_vector_json_serialization() {
3203 let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
3205 let json = serde_json::to_string(&dense).unwrap();
3206 let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3207 assert_eq!(deserialized, dense);
3208
3209 let sparse =
3211 QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap());
3212 let json = serde_json::to_string(&sparse).unwrap();
3213 let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3214 assert_eq!(deserialized, sparse);
3215 }
3216
3217 #[test]
3218 fn test_select_key_json_serialization() {
3219 use std::collections::HashSet;
3220
3221 let doc_key = Key::Document;
3223 assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
3224
3225 let embed_key = Key::Embedding;
3226 assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
3227
3228 let meta_key = Key::Metadata;
3229 assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
3230
3231 let score_key = Key::Score;
3232 assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
3233
3234 let custom_key = Key::MetadataField("custom_key".to_string());
3236 assert_eq!(
3237 serde_json::to_string(&custom_key).unwrap(),
3238 "\"custom_key\""
3239 );
3240
3241 let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
3243 assert!(matches!(deserialized, Key::Document));
3244
3245 let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
3246 assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
3247
3248 let mut keys = HashSet::new();
3250 keys.insert(Key::Document);
3251 keys.insert(Key::Embedding);
3252 keys.insert(Key::MetadataField("author".to_string()));
3253
3254 let select = Select { keys };
3255 let json = serde_json::to_string(&select).unwrap();
3256 let deserialized: Select = serde_json::from_str(&json).unwrap();
3257
3258 assert_eq!(deserialized.keys.len(), 3);
3259 assert!(deserialized.keys.contains(&Key::Document));
3260 assert!(deserialized.keys.contains(&Key::Embedding));
3261 assert!(deserialized
3262 .keys
3263 .contains(&Key::MetadataField("author".to_string())));
3264 }
3265
3266 #[test]
3267 fn test_merge_basic_integers() {
3268 use std::cmp::Reverse;
3269
3270 let merge = Merge { k: 5 };
3271
3272 let input = vec![
3274 vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
3275 vec![Reverse(2), Reverse(5), Reverse(8)],
3276 vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
3277 ];
3278
3279 let result = merge.merge(input);
3280
3281 assert_eq!(result.len(), 5);
3283 assert_eq!(
3284 result,
3285 vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
3286 );
3287 }
3288
3289 #[test]
3290 fn test_merge_u32_descending() {
3291 let merge = Merge { k: 6 };
3292
3293 let input = vec![
3295 vec![100u32, 75, 50, 25],
3296 vec![90, 60, 30],
3297 vec![95, 85, 70, 40, 10],
3298 ];
3299
3300 let result = merge.merge(input);
3301
3302 assert_eq!(result.len(), 6);
3304 assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
3305 }
3306
3307 #[test]
3308 fn test_merge_i32_descending() {
3309 let merge = Merge { k: 5 };
3310
3311 let input = vec![
3313 vec![50i32, 10, -10, -50],
3314 vec![30, 0, -30],
3315 vec![40, 20, -20, -40],
3316 ];
3317
3318 let result = merge.merge(input);
3319
3320 assert_eq!(result.len(), 5);
3322 assert_eq!(result, vec![50, 40, 30, 20, 10]);
3323 }
3324
3325 #[test]
3326 fn test_merge_with_duplicates() {
3327 let merge = Merge { k: 10 };
3328
3329 let input = vec![
3331 vec![100u32, 80, 80, 60, 40],
3332 vec![90, 80, 50, 30],
3333 vec![100, 70, 60, 20],
3334 ];
3335
3336 let result = merge.merge(input);
3337
3338 assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
3340 }
3341
3342 #[test]
3343 fn test_merge_empty_vectors() {
3344 let merge = Merge { k: 5 };
3345
3346 let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
3348 let result = merge.merge(input);
3349 assert_eq!(result.len(), 0);
3350
3351 let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
3353 let result = merge.merge(input);
3354 assert_eq!(result, vec![1000, 850, 750, 600, 500]);
3355
3356 let input = vec![vec![], vec![100i32, 50, 25], vec![]];
3358 let result = merge.merge(input);
3359 assert_eq!(result, vec![100, 50, 25]);
3360 }
3361
3362 #[test]
3363 fn test_merge_k_boundary_conditions() {
3364 let merge = Merge { k: 0 };
3366 let input = vec![vec![100u32, 50], vec![75, 25]];
3367 let result = merge.merge(input);
3368 assert_eq!(result.len(), 0);
3369
3370 let merge = Merge { k: 1 };
3372 let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
3373 let result = merge.merge(input);
3374 assert_eq!(result, vec![1000]);
3375
3376 let merge = Merge { k: 100 };
3378 let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
3379 let result = merge.merge(input);
3380 assert_eq!(result, vec![10000, 8000, 5000, 3000]);
3381 }
3382
3383 #[test]
3384 fn test_merge_with_strings() {
3385 let merge = Merge { k: 4 };
3386
3387 let input = vec![
3389 vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
3390 vec!["elephant".to_string(), "banana".to_string()],
3391 vec!["fish".to_string(), "cat".to_string()],
3392 ];
3393
3394 let result = merge.merge(input);
3395
3396 assert_eq!(
3398 result,
3399 vec![
3400 "zebra".to_string(),
3401 "fish".to_string(),
3402 "elephant".to_string(),
3403 "dog".to_string()
3404 ]
3405 );
3406 }
3407
    /// Merging works for a user-defined type whose derived `Ord` compares
    /// `value` first and `id` second.
    #[test]
    fn test_merge_with_custom_struct() {
        #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
        struct Score {
            value: i32,
            id: String,
        }

        let merge = Merge { k: 3 };

        // Each inner list is already in descending order of `value`.
        let input = vec![
            vec![
                Score {
                    value: 100,
                    id: "a".to_string(),
                },
                Score {
                    value: 80,
                    id: "b".to_string(),
                },
                Score {
                    value: 60,
                    id: "c".to_string(),
                },
            ],
            vec![
                Score {
                    value: 90,
                    id: "d".to_string(),
                },
                Score {
                    value: 70,
                    id: "e".to_string(),
                },
            ],
            vec![
                Score {
                    value: 95,
                    id: "f".to_string(),
                },
                Score {
                    value: 85,
                    id: "g".to_string(),
                },
            ],
        ];

        let result = merge.merge(input);

        // The three largest values, taken from three different inputs.
        assert_eq!(result.len(), 3);
        assert_eq!(
            result[0],
            Score {
                value: 100,
                id: "a".to_string()
            }
        );
        assert_eq!(
            result[1],
            Score {
                value: 95,
                id: "f".to_string()
            }
        );
        assert_eq!(
            result[2],
            Score {
                value: 90,
                id: "d".to_string()
            }
        );
    }
3481
    /// The merged output stays sorted: descending in `Reverse` order, which
    /// means ascending in the wrapped values.
    #[test]
    fn test_merge_preserves_order() {
        use std::cmp::Reverse;

        let merge = Merge { k: 10 };

        // Inner values ascend within each list, i.e. each list is descending
        // under `Reverse` ordering.
        let input = vec![
            vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
            vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
            vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
        ];

        let result = merge.merge(input);

        // Check every adjacent pair in both the wrapper and the inner order.
        for i in 1..result.len() {
            assert!(
                result[i - 1] >= result[i],
                "Output should be in descending Reverse order"
            );
            assert!(
                result[i - 1].0 <= result[i].0,
                "Inner values should be in ascending order"
            );
        }

        // Only the first ten of the thirteen inputs are kept.
        assert_eq!(
            result,
            vec![
                Reverse(1),
                Reverse(2),
                Reverse(3),
                Reverse(4),
                Reverse(5),
                Reverse(6),
                Reverse(7),
                Reverse(8),
                Reverse(9),
                Reverse(10)
            ]
        );
    }
3528
3529 #[test]
3530 fn test_merge_single_vector() {
3531 let merge = Merge { k: 3 };
3532
3533 let input = vec![vec![1000u64, 800, 600, 400, 200]];
3535
3536 let result = merge.merge(input);
3537
3538 assert_eq!(result, vec![1000, 800, 600]);
3539 }
3540
3541 #[test]
3542 fn test_merge_all_same_values() {
3543 let merge = Merge { k: 5 };
3544
3545 let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];
3547
3548 let result = merge.merge(input);
3549
3550 assert_eq!(result, vec![42]);
3552 }
3553
3554 #[test]
3555 fn test_merge_mixed_types_sizes() {
3556 let merge = Merge { k: 4 };
3558 let input = vec![
3559 vec![1000usize, 500, 100],
3560 vec![800, 300],
3561 vec![900, 600, 200],
3562 ];
3563 let result = merge.merge(input);
3564 assert_eq!(result, vec![1000, 900, 800, 600]);
3565
3566 let merge = Merge { k: 5 };
3568 let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
3569 let result = merge.merge(input);
3570 assert_eq!(result, vec![15, 10, 5, 0, -5]);
3571 }
3572
    /// The same `offset_id` appearing in several inputs with different
    /// measures must be emitted once, with the smallest of its measures.
    #[test]
    fn test_merge_dedup_same_id_different_scores() {
        use std::cmp::Reverse;

        let merge = Merge { k: 5 };

        // Each list ascends by measure. Ids 1, 2, 3 and 4 each occur in two
        // lists with different measures; id 5 occurs once.
        let input: Vec<Vec<Reverse<RecordMeasure>>> = vec![
            vec![
                Reverse(RecordMeasure {
                    offset_id: 1,
                    measure: 0.10,
                }),
                Reverse(RecordMeasure {
                    offset_id: 4,
                    measure: 0.50,
                }),
                Reverse(RecordMeasure {
                    offset_id: 5,
                    measure: 0.70,
                }),
            ],
            vec![
                Reverse(RecordMeasure {
                    offset_id: 2,
                    measure: 0.20,
                }),
                // Second occurrence of id 1, with a larger measure.
                Reverse(RecordMeasure {
                    offset_id: 1,
                    measure: 0.25,
                }),
                // Occurrence of id 3 with the larger of its two measures.
                Reverse(RecordMeasure {
                    offset_id: 3,
                    measure: 0.60,
                }),
            ],
            vec![
                // Occurrence of id 3 with the smaller of its two measures.
                Reverse(RecordMeasure {
                    offset_id: 3,
                    measure: 0.15,
                }),
                // Occurrence of id 4 with the smaller of its two measures.
                Reverse(RecordMeasure {
                    offset_id: 4,
                    measure: 0.35,
                }),
                // Second occurrence of id 2, with a larger measure.
                Reverse(RecordMeasure {
                    offset_id: 2,
                    measure: 0.80,
                }),
            ],
        ];

        let result: Vec<Reverse<RecordMeasure>> = merge.merge(input);
        let ids: Vec<u32> = result.iter().map(|Reverse(r)| r.offset_id).collect();
        let measures: Vec<f32> = result.iter().map(|Reverse(r)| r.measure).collect();

        // Five distinct ids, ordered by their smallest measure.
        assert_eq!(ids, vec![1, 3, 2, 4, 5]);
        assert_eq!(measures, vec![0.10, 0.15, 0.20, 0.35, 0.70]);
    }
3646
    /// `Aggregate` serializes under its `$min_k` / `$max_k` tag and parses
    /// back from the same shape.
    #[test]
    fn test_aggregate_json_serialization() {
        // MinK -> JSON.
        let min_k = Aggregate::MinK {
            keys: vec![Key::Score, Key::field("date")],
            k: 3,
        };
        let json = serde_json::to_value(&min_k).unwrap();
        assert!(json.get("$min_k").is_some());
        assert_eq!(json["$min_k"]["k"], 3);

        // JSON -> MinK; keys are parsed with the usual `Key` string mapping.
        let min_k_json = serde_json::json!({
            "$min_k": {
                "keys": ["#score", "date"],
                "k": 5
            }
        });
        let deserialized: Aggregate = serde_json::from_value(min_k_json).unwrap();
        match deserialized {
            Aggregate::MinK { keys, k } => {
                assert_eq!(k, 5);
                assert_eq!(keys.len(), 2);
                assert_eq!(keys[0], Key::Score);
                assert_eq!(keys[1], Key::field("date"));
            }
            _ => panic!("Expected MinK"),
        }

        // MaxK -> JSON.
        let max_k = Aggregate::MaxK {
            keys: vec![Key::field("timestamp")],
            k: 10,
        };
        let json = serde_json::to_value(&max_k).unwrap();
        assert!(json.get("$max_k").is_some());
        assert_eq!(json["$max_k"]["k"], 10);

        // JSON -> MaxK.
        let max_k_json = serde_json::json!({
            "$max_k": {
                "keys": ["timestamp"],
                "k": 2
            }
        });
        let deserialized: Aggregate = serde_json::from_value(max_k_json).unwrap();
        match deserialized {
            Aggregate::MaxK { keys, k } => {
                assert_eq!(k, 2);
                assert_eq!(keys.len(), 1);
                assert_eq!(keys[0], Key::field("timestamp"));
            }
            _ => panic!("Expected MaxK"),
        }
    }
3702
    /// `GroupBy` survives a JSON round trip, including the default (empty)
    /// value, and parses from hand-written JSON.
    #[test]
    fn test_group_by_json_serialization() {
        // Populated value -> JSON.
        let group_by = GroupBy {
            keys: vec![Key::field("category"), Key::field("author")],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::Score],
                k: 3,
            }),
        };

        let json = serde_json::to_value(&group_by).unwrap();
        assert_eq!(json["keys"].as_array().unwrap().len(), 2);
        assert!(json["aggregate"]["$min_k"].is_object());

        // ...and back again, preserving key order.
        let deserialized: GroupBy = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized.keys.len(), 2);
        assert_eq!(deserialized.keys[0], Key::field("category"));
        assert_eq!(deserialized.keys[1], Key::field("author"));
        assert!(deserialized.aggregate.is_some());

        // The default value round-trips to no keys and no aggregate.
        let empty_group_by = GroupBy::default();
        let json = serde_json::to_value(&empty_group_by).unwrap();
        let deserialized: GroupBy = serde_json::from_value(json).unwrap();
        assert!(deserialized.keys.is_empty());
        assert!(deserialized.aggregate.is_none());

        // Hand-written JSON with a `$max_k` aggregate.
        let json = serde_json::json!({
            "keys": ["category"],
            "aggregate": {
                "$max_k": {
                    "keys": ["#score", "priority"],
                    "k": 5
                }
            }
        });
        let group_by: GroupBy = serde_json::from_value(json).unwrap();
        assert_eq!(group_by.keys.len(), 1);
        assert_eq!(group_by.keys[0], Key::field("category"));
        match group_by.aggregate {
            Some(Aggregate::MaxK { keys, k }) => {
                assert_eq!(k, 5);
                assert_eq!(keys.len(), 2);
                assert_eq!(keys[0], Key::Score);
            }
            _ => panic!("Expected MaxK aggregate"),
        }
    }
3754}