1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4 cmp::Ordering,
5 collections::{BinaryHeap, HashSet},
6 fmt,
7 hash::Hash,
8 ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13 chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14 DocumentExpression, DocumentOperator, Metadata, MetadataComparison, MetadataExpression,
15 MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding, SetOperator, SparseVector,
16 Where,
17};
18
19use super::error::QueryConversionError;
20
21pub type InitialInput = ();
22
23#[derive(Clone, Debug)]
28pub struct Scan {
29 pub collection_and_segments: CollectionAndSegments,
30}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33 type Error = QueryConversionError;
34
35 fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36 Ok(Self {
37 collection_and_segments: CollectionAndSegments {
38 collection: value
39 .collection
40 .ok_or(QueryConversionError::field("collection"))?
41 .try_into()?,
42 metadata_segment: value
43 .metadata
44 .ok_or(QueryConversionError::field("metadata segment"))?
45 .try_into()?,
46 record_segment: value
47 .record
48 .ok_or(QueryConversionError::field("record segment"))?
49 .try_into()?,
50 vector_segment: value
51 .knn
52 .ok_or(QueryConversionError::field("vector segment"))?
53 .try_into()?,
54 },
55 })
56 }
57}
58
59#[derive(Debug, Error)]
60pub enum ScanToProtoError {
61 #[error("Could not convert collection to proto")]
62 CollectionToProto(#[from] crate::CollectionToProtoError),
63}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66 type Error = ScanToProtoError;
67
68 fn try_from(value: Scan) -> Result<Self, Self::Error> {
69 Ok(Self {
70 collection: Some(value.collection_and_segments.collection.try_into()?),
71 knn: Some(value.collection_and_segments.vector_segment.into()),
72 metadata: Some(value.collection_and_segments.metadata_segment.into()),
73 record: Some(value.collection_and_segments.record_segment.into()),
74 })
75 }
76}
77
/// Outcome of a count query.
#[derive(Clone, Debug)]
pub struct CountResult {
    /// Number of matching records.
    pub count: u32,
    /// Bytes of log pulled while answering the query.
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Logical payload size of this result in bytes.
    ///
    /// Only the `count` field is counted; `pulled_log_bytes` is not included.
    pub fn size_bytes(&self) -> u64 {
        let payload = std::mem::size_of_val(&self.count);
        payload as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91 fn from(value: chroma_proto::CountResult) -> Self {
92 Self {
93 count: value.count,
94 pulled_log_bytes: value.pulled_log_bytes,
95 }
96 }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100 fn from(value: CountResult) -> Self {
101 Self {
102 count: value.count,
103 pulled_log_bytes: value.pulled_log_bytes,
104 }
105 }
106}
107
/// Parameters of a log fetch for a single collection.
///
/// NOTE(review): semantics below are inferred from field names; confirm against the
/// operator that consumes this struct.
#[derive(Clone, Debug)]
pub struct FetchLog {
    /// Collection whose log is read.
    pub collection_uuid: CollectionUuid,
    /// Upper bound on the number of log records to fetch; `None` presumably means no
    /// bound — TODO confirm.
    pub maximum_fetch_count: Option<u32>,
    /// Log offset to start fetching from (inclusive vs. exclusive not visible here —
    /// confirm).
    pub start_log_offset_id: u32,
}
120
121#[derive(Clone, Debug, Default)]
170pub struct Filter {
171 pub query_ids: Option<Vec<String>>,
172 pub where_clause: Option<Where>,
173}
174
175impl Serialize for Filter {
176 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177 where
178 S: Serializer,
179 {
180 match (&self.query_ids, &self.where_clause) {
184 (None, None) => {
185 let map = serializer.serialize_map(Some(0))?;
187 map.end()
188 }
189 (None, Some(where_clause)) => {
190 where_clause.serialize(serializer)
192 }
193 (Some(ids), None) => {
194 let id_where = Where::Metadata(MetadataExpression {
196 key: "#id".to_string(),
197 comparison: MetadataComparison::Set(
198 SetOperator::In,
199 MetadataSetValue::Str(ids.clone()),
200 ),
201 });
202 id_where.serialize(serializer)
203 }
204 (Some(ids), Some(where_clause)) => {
205 let id_where = Where::Metadata(MetadataExpression {
207 key: "#id".to_string(),
208 comparison: MetadataComparison::Set(
209 SetOperator::In,
210 MetadataSetValue::Str(ids.clone()),
211 ),
212 });
213 let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214 combined.serialize(serializer)
215 }
216 }
217 }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where
223 D: Deserializer<'de>,
224 {
225 let where_json = Value::deserialize(deserializer)?;
227 let where_clause =
228 if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229 None
230 } else {
231 Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232 };
233
234 Ok(Filter {
235 query_ids: None, where_clause,
237 })
238 }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242 type Error = QueryConversionError;
243
244 fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245 let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246 let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247 let where_clause = match (where_metadata, where_document) {
248 (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249 (Some(w), None) | (None, Some(w)) => Some(w),
250 _ => None,
251 };
252
253 Ok(Self {
254 query_ids: value.ids.map(|uids| uids.ids),
255 where_clause,
256 })
257 }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261 type Error = QueryConversionError;
262
263 fn try_from(value: Filter) -> Result<Self, Self::Error> {
264 Ok(Self {
265 ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266 r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267 where_document: None,
268 })
269 }
270}
271
/// A single nearest-neighbour search: one query embedding and how many results to fetch.
#[derive(Clone, Debug)]
pub struct Knn {
    pub embedding: Vec<f32>,
    pub fetch: u32,
}

impl From<KnnBatch> for Vec<Knn> {
    /// Splits a batch into one [`Knn`] per embedding. Every entry keeps the batch's
    /// `fetch`, and the embedding order is preserved.
    fn from(value: KnnBatch) -> Self {
        let KnnBatch { embeddings, fetch } = value;
        embeddings
            .into_iter()
            .map(|embedding| Knn { embedding, fetch })
            .collect()
    }
}

/// A batch of nearest-neighbour searches sharing the same `fetch` count.
#[derive(Clone, Debug)]
pub struct KnnBatch {
    pub embeddings: Vec<Vec<f32>>,
    pub fetch: u32,
}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308 type Error = QueryConversionError;
309
310 fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311 Ok(Self {
312 embeddings: value
313 .embeddings
314 .into_iter()
315 .map(|vec| vec.try_into().map(|(v, _)| v))
316 .collect::<Result<_, _>>()?,
317 fetch: value.fetch,
318 })
319 }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323 type Error = QueryConversionError;
324
325 fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326 Ok(Self {
327 embeddings: value
328 .embeddings
329 .into_iter()
330 .map(|embedding| {
331 let dim = embedding.len();
332 chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333 })
334 .collect::<Result<_, _>>()?,
335 fetch: value.fetch,
336 })
337 }
338}
339
340#[derive(Clone, Debug, Default, Deserialize, Serialize)]
373pub struct Limit {
374 #[serde(default)]
375 pub offset: u32,
376 #[serde(default)]
377 pub limit: Option<u32>,
378}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381 fn from(value: chroma_proto::LimitOperator) -> Self {
382 Self {
383 offset: value.offset,
384 limit: value.limit,
385 }
386 }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390 fn from(value: Limit) -> Self {
391 Self {
392 offset: value.offset,
393 limit: value.limit,
394 }
395 }
396}
397
/// A record offset id paired with a score (for example a KNN distance).
#[derive(Clone, Debug)]
pub struct RecordMeasure {
    pub offset_id: u32,
    pub measure: f32,
}

// NOTE(review): equality looks only at `offset_id`, while ordering compares `measure` first
// and then `offset_id`. Two records with the same id but different measures are therefore
// `==` without being `Ordering::Equal`, which departs from the `Ord`/`PartialEq` consistency
// contract. Confirm whether callers rely on id-only equality before changing either impl.
impl PartialEq for RecordMeasure {
    fn eq(&self, other: &Self) -> bool {
        self.offset_id == other.offset_id
    }
}

impl Eq for RecordMeasure {}

impl Ord for RecordMeasure {
    /// Orders by `measure` using IEEE 754 total ordering (NaN and signed zeros get a fixed
    /// position), then breaks ties by `offset_id`.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.measure.total_cmp(&other.measure) {
            Ordering::Equal => self.offset_id.cmp(&other.offset_id),
            unequal => unequal,
        }
    }
}

impl PartialOrd for RecordMeasure {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Output of a KNN operator: the scored records it found.
#[derive(Debug, Default)]
pub struct KnnOutput {
    pub distances: Vec<RecordMeasure>,
}
431
/// K-way merge of several individually sorted batches, truncated to `k` items.
#[derive(Clone, Debug)]
pub struct Merge {
    /// Maximum number of items the merge returns.
    pub k: u32,
}

impl Merge {
    /// Merges `input` into a single descending list of at most `k` items.
    ///
    /// Each inner vector should already be sorted in descending order (by `Ord`).
    /// Items are drawn from a max-heap holding the current head of every batch, so
    /// unsorted batches produce an output that is not globally sorted. An item equal
    /// (`==`) to the previously emitted one is skipped, so adjacent duplicates appear
    /// once and do not count towards `k`.
    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        // Reserve no more than the merge can actually produce. `k` may be huge
        // (e.g. an effectively unbounded limit), and reserving `k` slots up front
        // would attempt an enormous allocation even for tiny inputs.
        let available: usize = input.iter().map(Vec::len).sum();
        let limit = usize::try_from(self.k).unwrap_or(usize::MAX);
        let mut fusion = Vec::with_capacity(limit.min(available));

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the head of every non-empty batch. The batch index breaks
        // ties and tells us which iterator to advance after a pop.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        while fusion.len() < limit {
            let Some((m, idx)) = max_heap.pop() else {
                break;
            };
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            if fusion.last().is_some_and(|tail| tail == &m) {
                continue;
            }
            fusion.push(m);
        }
        fusion
    }
}
472
473#[derive(Clone, Debug, Default)]
480pub struct Projection {
481 pub document: bool,
482 pub embedding: bool,
483 pub metadata: bool,
484}
485
486impl From<chroma_proto::ProjectionOperator> for Projection {
487 fn from(value: chroma_proto::ProjectionOperator) -> Self {
488 Self {
489 document: value.document,
490 embedding: value.embedding,
491 metadata: value.metadata,
492 }
493 }
494}
495
496impl From<Projection> for chroma_proto::ProjectionOperator {
497 fn from(value: Projection) -> Self {
498 Self {
499 document: value.document,
500 embedding: value.embedding,
501 metadata: value.metadata,
502 }
503 }
504}
505
506#[derive(Clone, Debug, PartialEq)]
507pub struct ProjectionRecord {
508 pub id: String,
509 pub document: Option<String>,
510 pub embedding: Option<Vec<f32>>,
511 pub metadata: Option<Metadata>,
512}
513
514impl ProjectionRecord {
515 pub fn size_bytes(&self) -> u64 {
516 (self.id.len()
517 + self
518 .document
519 .as_ref()
520 .map(|doc| doc.len())
521 .unwrap_or_default()
522 + self
523 .embedding
524 .as_ref()
525 .map(|emb| size_of_val(&emb[..]))
526 .unwrap_or_default()
527 + self
528 .metadata
529 .as_ref()
530 .map(logical_size_of_metadata)
531 .unwrap_or_default()) as u64
532 }
533}
534
535impl Eq for ProjectionRecord {}
536
537impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
538 type Error = QueryConversionError;
539
540 fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
541 Ok(Self {
542 id: value.id,
543 document: value.document,
544 embedding: value
545 .embedding
546 .map(|vec| vec.try_into().map(|(v, _)| v))
547 .transpose()?,
548 metadata: value.metadata.map(TryInto::try_into).transpose()?,
549 })
550 }
551}
552
553impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
554 type Error = QueryConversionError;
555
556 fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
557 Ok(Self {
558 id: value.id,
559 document: value.document,
560 embedding: value
561 .embedding
562 .map(|embedding| {
563 let embedding_dimension = embedding.len();
564 chroma_proto::Vector::try_from((
565 embedding,
566 ScalarEncoding::FLOAT32,
567 embedding_dimension,
568 ))
569 })
570 .transpose()?,
571 metadata: value.metadata.map(|metadata| metadata.into()),
572 })
573 }
574}
575
576#[derive(Clone, Debug, Eq, PartialEq)]
577pub struct ProjectionOutput {
578 pub records: Vec<ProjectionRecord>,
579}
580
581#[derive(Clone, Debug, Eq, PartialEq)]
582pub struct GetResult {
583 pub pulled_log_bytes: u64,
584 pub result: ProjectionOutput,
585}
586
587impl GetResult {
588 pub fn size_bytes(&self) -> u64 {
589 self.result
590 .records
591 .iter()
592 .map(ProjectionRecord::size_bytes)
593 .sum()
594 }
595}
596
597impl TryFrom<chroma_proto::GetResult> for GetResult {
598 type Error = QueryConversionError;
599
600 fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
601 Ok(Self {
602 pulled_log_bytes: value.pulled_log_bytes,
603 result: ProjectionOutput {
604 records: value
605 .records
606 .into_iter()
607 .map(TryInto::try_into)
608 .collect::<Result<_, _>>()?,
609 },
610 })
611 }
612}
613
614impl TryFrom<GetResult> for chroma_proto::GetResult {
615 type Error = QueryConversionError;
616
617 fn try_from(value: GetResult) -> Result<Self, Self::Error> {
618 Ok(Self {
619 pulled_log_bytes: value.pulled_log_bytes,
620 records: value
621 .result
622 .records
623 .into_iter()
624 .map(TryInto::try_into)
625 .collect::<Result<_, _>>()?,
626 })
627 }
628}
629
630#[derive(Clone, Debug)]
638pub struct KnnProjection {
639 pub projection: Projection,
640 pub distance: bool,
641}
642
643impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
644 type Error = QueryConversionError;
645
646 fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
647 Ok(Self {
648 projection: value
649 .projection
650 .ok_or(QueryConversionError::field("projection"))?
651 .into(),
652 distance: value.distance,
653 })
654 }
655}
656
657impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
658 fn from(value: KnnProjection) -> Self {
659 Self {
660 projection: Some(value.projection.into()),
661 distance: value.distance,
662 }
663 }
664}
665
666#[derive(Clone, Debug)]
667pub struct KnnProjectionRecord {
668 pub record: ProjectionRecord,
669 pub distance: Option<f32>,
670}
671
672impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
673 type Error = QueryConversionError;
674
675 fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
676 Ok(Self {
677 record: value
678 .record
679 .ok_or(QueryConversionError::field("record"))?
680 .try_into()?,
681 distance: value.distance,
682 })
683 }
684}
685
686impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
687 type Error = QueryConversionError;
688
689 fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
690 Ok(Self {
691 record: Some(value.record.try_into()?),
692 distance: value.distance,
693 })
694 }
695}
696
697#[derive(Clone, Debug, Default)]
698pub struct KnnProjectionOutput {
699 pub records: Vec<KnnProjectionRecord>,
700}
701
702impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
703 type Error = QueryConversionError;
704
705 fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
706 Ok(Self {
707 records: value
708 .records
709 .into_iter()
710 .map(TryInto::try_into)
711 .collect::<Result<_, _>>()?,
712 })
713 }
714}
715
716impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
717 type Error = QueryConversionError;
718
719 fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
720 Ok(Self {
721 records: value
722 .records
723 .into_iter()
724 .map(TryInto::try_into)
725 .collect::<Result<_, _>>()?,
726 })
727 }
728}
729
730#[derive(Clone, Debug, Default)]
731pub struct KnnBatchResult {
732 pub pulled_log_bytes: u64,
733 pub results: Vec<KnnProjectionOutput>,
734}
735
736impl KnnBatchResult {
737 pub fn size_bytes(&self) -> u64 {
738 self.results
739 .iter()
740 .flat_map(|res| {
741 res.records
742 .iter()
743 .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
744 })
745 .sum()
746 }
747}
748
749impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
750 type Error = QueryConversionError;
751
752 fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
753 Ok(Self {
754 pulled_log_bytes: value.pulled_log_bytes,
755 results: value
756 .results
757 .into_iter()
758 .map(TryInto::try_into)
759 .collect::<Result<_, _>>()?,
760 })
761 }
762}
763
764impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
765 type Error = QueryConversionError;
766
767 fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
768 Ok(Self {
769 pulled_log_bytes: value.pulled_log_bytes,
770 results: value
771 .results
772 .into_iter()
773 .map(TryInto::try_into)
774 .collect::<Result<_, _>>()?,
775 })
776 }
777}
778
779#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
843#[serde(untagged)]
844pub enum QueryVector {
845 Dense(Vec<f32>),
846 Sparse(SparseVector),
847}
848
849impl TryFrom<chroma_proto::QueryVector> for QueryVector {
850 type Error = QueryConversionError;
851
852 fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
853 let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
854 match vector {
855 chroma_proto::query_vector::Vector::Dense(dense) => {
856 Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
857 }
858 chroma_proto::query_vector::Vector::Sparse(sparse) => {
859 Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| {
860 QueryConversionError::validation("sparse vector length mismatch")
861 })?))
862 }
863 }
864 }
865}
866
867impl TryFrom<QueryVector> for chroma_proto::QueryVector {
868 type Error = QueryConversionError;
869
870 fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
871 match value {
872 QueryVector::Dense(vec) => {
873 let dim = vec.len();
874 Ok(chroma_proto::QueryVector {
875 vector: Some(chroma_proto::query_vector::Vector::Dense(
876 chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
877 )),
878 })
879 }
880 QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
881 vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
882 }),
883 }
884 }
885}
886
887impl From<Vec<f32>> for QueryVector {
888 fn from(vec: Vec<f32>) -> Self {
889 QueryVector::Dense(vec)
890 }
891}
892
893impl From<SparseVector> for QueryVector {
894 fn from(sparse: SparseVector) -> Self {
895 QueryVector::Sparse(sparse)
896 }
897}
898
899#[derive(Clone, Debug, PartialEq)]
900pub struct KnnQuery {
901 pub query: QueryVector,
902 pub key: Key,
903 pub limit: u32,
904}
905
906#[derive(Clone, Debug, Default, Deserialize, Serialize)]
937#[serde(transparent)]
938pub struct Rank {
939 pub expr: Option<RankExpr>,
940}
941
942impl Rank {
943 pub fn knn_queries(&self) -> Vec<KnnQuery> {
944 self.expr
945 .as_ref()
946 .map(RankExpr::knn_queries)
947 .unwrap_or_default()
948 }
949}
950
951impl TryFrom<chroma_proto::RankOperator> for Rank {
952 type Error = QueryConversionError;
953
954 fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
955 Ok(Rank {
956 expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
957 })
958 }
959}
960
961impl TryFrom<Rank> for chroma_proto::RankOperator {
962 type Error = QueryConversionError;
963
964 fn try_from(rank: Rank) -> Result<Self, Self::Error> {
965 Ok(chroma_proto::RankOperator {
966 expr: rank.expr.map(TryInto::try_into).transpose()?,
967 })
968 }
969}
970
971#[derive(Clone, Debug, Deserialize, Serialize)]
1134pub enum RankExpr {
1135 #[serde(rename = "$abs")]
1136 Absolute(Box<RankExpr>),
1137 #[serde(rename = "$div")]
1138 Division {
1139 left: Box<RankExpr>,
1140 right: Box<RankExpr>,
1141 },
1142 #[serde(rename = "$exp")]
1143 Exponentiation(Box<RankExpr>),
1144 #[serde(rename = "$knn")]
1145 Knn {
1146 query: QueryVector,
1147 #[serde(default = "RankExpr::default_knn_key")]
1148 key: Key,
1149 #[serde(default = "RankExpr::default_knn_limit")]
1150 limit: u32,
1151 #[serde(default)]
1152 default: Option<f32>,
1153 #[serde(default)]
1154 return_rank: bool,
1155 },
1156 #[serde(rename = "$log")]
1157 Logarithm(Box<RankExpr>),
1158 #[serde(rename = "$max")]
1159 Maximum(Vec<RankExpr>),
1160 #[serde(rename = "$min")]
1161 Minimum(Vec<RankExpr>),
1162 #[serde(rename = "$mul")]
1163 Multiplication(Vec<RankExpr>),
1164 #[serde(rename = "$sub")]
1165 Subtraction {
1166 left: Box<RankExpr>,
1167 right: Box<RankExpr>,
1168 },
1169 #[serde(rename = "$sum")]
1170 Summation(Vec<RankExpr>),
1171 #[serde(rename = "$val")]
1172 Value(f32),
1173}
1174
1175impl RankExpr {
1176 pub fn default_knn_key() -> Key {
1177 Key::Embedding
1178 }
1179
1180 pub fn default_knn_limit() -> u32 {
1181 16
1182 }
1183
1184 pub fn knn_queries(&self) -> Vec<KnnQuery> {
1185 match self {
1186 RankExpr::Absolute(expr)
1187 | RankExpr::Exponentiation(expr)
1188 | RankExpr::Logarithm(expr) => expr.knn_queries(),
1189 RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1190 .knn_queries()
1191 .into_iter()
1192 .chain(right.knn_queries())
1193 .collect(),
1194 RankExpr::Maximum(exprs)
1195 | RankExpr::Minimum(exprs)
1196 | RankExpr::Multiplication(exprs)
1197 | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1198 RankExpr::Value(_) => Vec::new(),
1199 RankExpr::Knn {
1200 query,
1201 key,
1202 limit,
1203 default: _,
1204 return_rank: _,
1205 } => vec![KnnQuery {
1206 query: query.clone(),
1207 key: key.clone(),
1208 limit: *limit,
1209 }],
1210 }
1211 }
1212
1213 pub fn exp(self) -> Self {
1233 RankExpr::Exponentiation(Box::new(self))
1234 }
1235
1236 pub fn log(self) -> Self {
1257 RankExpr::Logarithm(Box::new(self))
1258 }
1259
1260 pub fn abs(self) -> Self {
1287 RankExpr::Absolute(Box::new(self))
1288 }
1289
1290 pub fn max(self, other: impl Into<RankExpr>) -> Self {
1314 let other = other.into();
1315
1316 match self {
1317 RankExpr::Maximum(mut exprs) => match other {
1318 RankExpr::Maximum(other_exprs) => {
1319 exprs.extend(other_exprs);
1320 RankExpr::Maximum(exprs)
1321 }
1322 _ => {
1323 exprs.push(other);
1324 RankExpr::Maximum(exprs)
1325 }
1326 },
1327 _ => match other {
1328 RankExpr::Maximum(mut exprs) => {
1329 exprs.insert(0, self);
1330 RankExpr::Maximum(exprs)
1331 }
1332 _ => RankExpr::Maximum(vec![self, other]),
1333 },
1334 }
1335 }
1336
1337 pub fn min(self, other: impl Into<RankExpr>) -> Self {
1361 let other = other.into();
1362
1363 match self {
1364 RankExpr::Minimum(mut exprs) => match other {
1365 RankExpr::Minimum(other_exprs) => {
1366 exprs.extend(other_exprs);
1367 RankExpr::Minimum(exprs)
1368 }
1369 _ => {
1370 exprs.push(other);
1371 RankExpr::Minimum(exprs)
1372 }
1373 },
1374 _ => match other {
1375 RankExpr::Minimum(mut exprs) => {
1376 exprs.insert(0, self);
1377 RankExpr::Minimum(exprs)
1378 }
1379 _ => RankExpr::Minimum(vec![self, other]),
1380 },
1381 }
1382 }
1383}
1384
1385impl Add for RankExpr {
1386 type Output = RankExpr;
1387
1388 fn add(self, rhs: Self) -> Self::Output {
1389 match self {
1390 RankExpr::Summation(mut exprs) => match rhs {
1391 RankExpr::Summation(rhs_exprs) => {
1392 exprs.extend(rhs_exprs);
1393 RankExpr::Summation(exprs)
1394 }
1395 _ => {
1396 exprs.push(rhs);
1397 RankExpr::Summation(exprs)
1398 }
1399 },
1400 _ => match rhs {
1401 RankExpr::Summation(mut exprs) => {
1402 exprs.insert(0, self);
1403 RankExpr::Summation(exprs)
1404 }
1405 _ => RankExpr::Summation(vec![self, rhs]),
1406 },
1407 }
1408 }
1409}
1410
1411impl Add<f32> for RankExpr {
1412 type Output = RankExpr;
1413
1414 fn add(self, rhs: f32) -> Self::Output {
1415 self + RankExpr::Value(rhs)
1416 }
1417}
1418
1419impl Add<RankExpr> for f32 {
1420 type Output = RankExpr;
1421
1422 fn add(self, rhs: RankExpr) -> Self::Output {
1423 RankExpr::Value(self) + rhs
1424 }
1425}
1426
1427impl Sub for RankExpr {
1428 type Output = RankExpr;
1429
1430 fn sub(self, rhs: Self) -> Self::Output {
1431 RankExpr::Subtraction {
1432 left: Box::new(self),
1433 right: Box::new(rhs),
1434 }
1435 }
1436}
1437
1438impl Sub<f32> for RankExpr {
1439 type Output = RankExpr;
1440
1441 fn sub(self, rhs: f32) -> Self::Output {
1442 self - RankExpr::Value(rhs)
1443 }
1444}
1445
1446impl Sub<RankExpr> for f32 {
1447 type Output = RankExpr;
1448
1449 fn sub(self, rhs: RankExpr) -> Self::Output {
1450 RankExpr::Value(self) - rhs
1451 }
1452}
1453
1454impl Mul for RankExpr {
1455 type Output = RankExpr;
1456
1457 fn mul(self, rhs: Self) -> Self::Output {
1458 match self {
1459 RankExpr::Multiplication(mut exprs) => match rhs {
1460 RankExpr::Multiplication(rhs_exprs) => {
1461 exprs.extend(rhs_exprs);
1462 RankExpr::Multiplication(exprs)
1463 }
1464 _ => {
1465 exprs.push(rhs);
1466 RankExpr::Multiplication(exprs)
1467 }
1468 },
1469 _ => match rhs {
1470 RankExpr::Multiplication(mut exprs) => {
1471 exprs.insert(0, self);
1472 RankExpr::Multiplication(exprs)
1473 }
1474 _ => RankExpr::Multiplication(vec![self, rhs]),
1475 },
1476 }
1477 }
1478}
1479
1480impl Mul<f32> for RankExpr {
1481 type Output = RankExpr;
1482
1483 fn mul(self, rhs: f32) -> Self::Output {
1484 self * RankExpr::Value(rhs)
1485 }
1486}
1487
1488impl Mul<RankExpr> for f32 {
1489 type Output = RankExpr;
1490
1491 fn mul(self, rhs: RankExpr) -> Self::Output {
1492 RankExpr::Value(self) * rhs
1493 }
1494}
1495
1496impl Div for RankExpr {
1497 type Output = RankExpr;
1498
1499 fn div(self, rhs: Self) -> Self::Output {
1500 RankExpr::Division {
1501 left: Box::new(self),
1502 right: Box::new(rhs),
1503 }
1504 }
1505}
1506
1507impl Div<f32> for RankExpr {
1508 type Output = RankExpr;
1509
1510 fn div(self, rhs: f32) -> Self::Output {
1511 self / RankExpr::Value(rhs)
1512 }
1513}
1514
1515impl Div<RankExpr> for f32 {
1516 type Output = RankExpr;
1517
1518 fn div(self, rhs: RankExpr) -> Self::Output {
1519 RankExpr::Value(self) / rhs
1520 }
1521}
1522
1523impl Neg for RankExpr {
1524 type Output = RankExpr;
1525
1526 fn neg(self) -> Self::Output {
1527 RankExpr::Value(-1.0) * self
1528 }
1529}
1530
1531impl From<f32> for RankExpr {
1532 fn from(v: f32) -> Self {
1533 RankExpr::Value(v)
1534 }
1535}
1536
1537impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1538 type Error = QueryConversionError;
1539
1540 fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1541 match proto_expr.rank {
1542 Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1543 Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1544 }
1545 Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1546 let left = div.left.ok_or(QueryConversionError::field("left"))?;
1547 let right = div.right.ok_or(QueryConversionError::field("right"))?;
1548 Ok(RankExpr::Division {
1549 left: Box::new(RankExpr::try_from(*left)?),
1550 right: Box::new(RankExpr::try_from(*right)?),
1551 })
1552 }
1553 Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1554 RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1555 ),
1556 Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1557 let query = knn
1558 .query
1559 .ok_or(QueryConversionError::field("query"))?
1560 .try_into()?;
1561 Ok(RankExpr::Knn {
1562 query,
1563 key: Key::from(knn.key),
1564 limit: knn.limit,
1565 default: knn.default,
1566 return_rank: knn.return_rank,
1567 })
1568 }
1569 Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1570 Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1571 }
1572 Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1573 let exprs = max
1574 .exprs
1575 .into_iter()
1576 .map(RankExpr::try_from)
1577 .collect::<Result<Vec<_>, _>>()?;
1578 Ok(RankExpr::Maximum(exprs))
1579 }
1580 Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1581 let exprs = min
1582 .exprs
1583 .into_iter()
1584 .map(RankExpr::try_from)
1585 .collect::<Result<Vec<_>, _>>()?;
1586 Ok(RankExpr::Minimum(exprs))
1587 }
1588 Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1589 let exprs = mul
1590 .exprs
1591 .into_iter()
1592 .map(RankExpr::try_from)
1593 .collect::<Result<Vec<_>, _>>()?;
1594 Ok(RankExpr::Multiplication(exprs))
1595 }
1596 Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1597 let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1598 let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1599 Ok(RankExpr::Subtraction {
1600 left: Box::new(RankExpr::try_from(*left)?),
1601 right: Box::new(RankExpr::try_from(*right)?),
1602 })
1603 }
1604 Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1605 let exprs = sum
1606 .exprs
1607 .into_iter()
1608 .map(RankExpr::try_from)
1609 .collect::<Result<Vec<_>, _>>()?;
1610 Ok(RankExpr::Summation(exprs))
1611 }
1612 Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1613 None => Err(QueryConversionError::field("rank")),
1614 }
1615 }
1616}
1617
1618impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1619 type Error = QueryConversionError;
1620
1621 fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1622 let proto_rank = match rank_expr {
1623 RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1624 chroma_proto::RankExpr::try_from(*expr)?,
1625 )),
1626 RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1627 Box::new(chroma_proto::rank_expr::RankPair {
1628 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1629 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1630 }),
1631 ),
1632 RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1633 Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1634 ),
1635 RankExpr::Knn {
1636 query,
1637 key,
1638 limit,
1639 default,
1640 return_rank,
1641 } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1642 query: Some(query.try_into()?),
1643 key: key.to_string(),
1644 limit,
1645 default,
1646 return_rank,
1647 }),
1648 RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1649 chroma_proto::RankExpr::try_from(*expr)?,
1650 )),
1651 RankExpr::Maximum(exprs) => {
1652 let proto_exprs = exprs
1653 .into_iter()
1654 .map(chroma_proto::RankExpr::try_from)
1655 .collect::<Result<Vec<_>, _>>()?;
1656 chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1657 exprs: proto_exprs,
1658 })
1659 }
1660 RankExpr::Minimum(exprs) => {
1661 let proto_exprs = exprs
1662 .into_iter()
1663 .map(chroma_proto::RankExpr::try_from)
1664 .collect::<Result<Vec<_>, _>>()?;
1665 chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1666 exprs: proto_exprs,
1667 })
1668 }
1669 RankExpr::Multiplication(exprs) => {
1670 let proto_exprs = exprs
1671 .into_iter()
1672 .map(chroma_proto::RankExpr::try_from)
1673 .collect::<Result<Vec<_>, _>>()?;
1674 chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1675 exprs: proto_exprs,
1676 })
1677 }
1678 RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1679 Box::new(chroma_proto::rank_expr::RankPair {
1680 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1681 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1682 }),
1683 ),
1684 RankExpr::Summation(exprs) => {
1685 let proto_exprs = exprs
1686 .into_iter()
1687 .map(chroma_proto::RankExpr::try_from)
1688 .collect::<Result<Vec<_>, _>>()?;
1689 chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1690 exprs: proto_exprs,
1691 })
1692 }
1693 RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1694 };
1695
1696 Ok(chroma_proto::RankExpr {
1697 rank: Some(proto_rank),
1698 })
1699 }
1700}
1701
1702#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
1767#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1768pub enum Key {
1769 Document,
1771 Embedding,
1772 Metadata,
1773 Score,
1774 MetadataField(String),
1775}
1776
1777impl Serialize for Key {
1778 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1779 where
1780 S: serde::Serializer,
1781 {
1782 match self {
1783 Key::Document => serializer.serialize_str("#document"),
1784 Key::Embedding => serializer.serialize_str("#embedding"),
1785 Key::Metadata => serializer.serialize_str("#metadata"),
1786 Key::Score => serializer.serialize_str("#score"),
1787 Key::MetadataField(field) => serializer.serialize_str(field),
1788 }
1789 }
1790}
1791
1792impl<'de> Deserialize<'de> for Key {
1793 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1794 where
1795 D: Deserializer<'de>,
1796 {
1797 let s = String::deserialize(deserializer)?;
1798 Ok(Key::from(s))
1799 }
1800}
1801
1802impl fmt::Display for Key {
1803 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1804 match self {
1805 Key::Document => write!(f, "#document"),
1806 Key::Embedding => write!(f, "#embedding"),
1807 Key::Metadata => write!(f, "#metadata"),
1808 Key::Score => write!(f, "#score"),
1809 Key::MetadataField(field) => write!(f, "{}", field),
1810 }
1811 }
1812}
1813
1814impl From<&str> for Key {
1815 fn from(s: &str) -> Self {
1816 match s {
1817 "#document" => Key::Document,
1818 "#embedding" => Key::Embedding,
1819 "#metadata" => Key::Metadata,
1820 "#score" => Key::Score,
1821 field => Key::MetadataField(field.to_string()),
1823 }
1824 }
1825}
1826
1827impl From<String> for Key {
1828 fn from(s: String) -> Self {
1829 Key::from(s.as_str())
1830 }
1831}
1832
1833impl Key {
1834 pub fn field(name: impl Into<String>) -> Self {
1846 Key::MetadataField(name.into())
1847 }
1848
1849 pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1866 Where::Metadata(MetadataExpression {
1867 key: self.to_string(),
1868 comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1869 })
1870 }
1871
1872 pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1883 Where::Metadata(MetadataExpression {
1884 key: self.to_string(),
1885 comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1886 })
1887 }
1888
1889 pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1900 Where::Metadata(MetadataExpression {
1901 key: self.to_string(),
1902 comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1903 })
1904 }
1905
1906 pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1917 Where::Metadata(MetadataExpression {
1918 key: self.to_string(),
1919 comparison: MetadataComparison::Primitive(
1920 PrimitiveOperator::GreaterThanOrEqual,
1921 value.into(),
1922 ),
1923 })
1924 }
1925
1926 pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1937 Where::Metadata(MetadataExpression {
1938 key: self.to_string(),
1939 comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1940 })
1941 }
1942
1943 pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1954 Where::Metadata(MetadataExpression {
1955 key: self.to_string(),
1956 comparison: MetadataComparison::Primitive(
1957 PrimitiveOperator::LessThanOrEqual,
1958 value.into(),
1959 ),
1960 })
1961 }
1962
1963 pub fn is_in<I, T>(self, values: I) -> Where
1983 where
1984 I: IntoIterator<Item = T>,
1985 Vec<T>: Into<MetadataSetValue>,
1986 {
1987 let vec: Vec<T> = values.into_iter().collect();
1988 Where::Metadata(MetadataExpression {
1989 key: self.to_string(),
1990 comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1991 })
1992 }
1993
1994 pub fn not_in<I, T>(self, values: I) -> Where
2010 where
2011 I: IntoIterator<Item = T>,
2012 Vec<T>: Into<MetadataSetValue>,
2013 {
2014 let vec: Vec<T> = values.into_iter().collect();
2015 Where::Metadata(MetadataExpression {
2016 key: self.to_string(),
2017 comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2018 })
2019 }
2020
2021 pub fn contains<S: Into<String>>(self, text: S) -> Where {
2035 Where::Document(DocumentExpression {
2036 operator: DocumentOperator::Contains,
2037 pattern: text.into(),
2038 })
2039 }
2040
2041 pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2054 Where::Document(DocumentExpression {
2055 operator: DocumentOperator::NotContains,
2056 pattern: text.into(),
2057 })
2058 }
2059
2060 pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2077 Where::Document(DocumentExpression {
2078 operator: DocumentOperator::Regex,
2079 pattern: pattern.into(),
2080 })
2081 }
2082
2083 pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2099 Where::Document(DocumentExpression {
2100 operator: DocumentOperator::NotRegex,
2101 pattern: pattern.into(),
2102 })
2103 }
2104}
2105
/// Projection applied to search results: the set of record fields to return.
///
/// Each key is (de)serialized through `Key`'s string form, e.g. `"#document"` or a raw
/// metadata field name. Because of `#[serde(default)]`, a missing `keys` field deserializes
/// as an empty set.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Select {
    /// Fields to include. This is a `HashSet`, so duplicates collapse and iteration order is
    /// unspecified.
    #[serde(default)]
    pub keys: HashSet<Key>,
}
2164
2165impl TryFrom<chroma_proto::SelectOperator> for Select {
2166 type Error = QueryConversionError;
2167
2168 fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2169 let keys = value
2170 .keys
2171 .into_iter()
2172 .map(|key| {
2173 serde_json::from_value(serde_json::Value::String(key))
2175 .map_err(|_| QueryConversionError::field("keys"))
2176 })
2177 .collect::<Result<HashSet<_>, _>>()?;
2178
2179 Ok(Self { keys })
2180 }
2181}
2182
2183impl TryFrom<Select> for chroma_proto::SelectOperator {
2184 type Error = QueryConversionError;
2185
2186 fn try_from(value: Select) -> Result<Self, Self::Error> {
2187 let keys = value
2188 .keys
2189 .into_iter()
2190 .map(|key| {
2191 serde_json::to_value(&key)
2193 .ok()
2194 .and_then(|v| v.as_str().map(String::from))
2195 .ok_or(QueryConversionError::field("keys"))
2196 })
2197 .collect::<Result<Vec<_>, _>>()?;
2198
2199 Ok(Self { keys })
2200 }
2201}
2202
/// A single search hit.
///
/// Only `id` is always present. NOTE(review): presumably the optional fields are `None` when
/// they were not requested via [`Select`]; confirm against the producer.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SearchRecord {
    /// Record identifier.
    pub id: String,
    /// Document text, if present.
    pub document: Option<String>,
    /// Embedding as plain `f32` values, if present.
    pub embedding: Option<Vec<f32>>,
    /// Record metadata, if present.
    pub metadata: Option<Metadata>,
    /// Ranking score, if present.
    pub score: Option<f32>,
}
2248
2249impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2250 type Error = QueryConversionError;
2251
2252 fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2253 Ok(Self {
2254 id: value.id,
2255 document: value.document,
2256 embedding: value
2257 .embedding
2258 .map(|vec| vec.try_into().map(|(v, _)| v))
2259 .transpose()?,
2260 metadata: value.metadata.map(TryInto::try_into).transpose()?,
2261 score: value.score,
2262 })
2263 }
2264}
2265
2266impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2267 type Error = QueryConversionError;
2268
2269 fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2270 Ok(Self {
2271 id: value.id,
2272 document: value.document,
2273 embedding: value
2274 .embedding
2275 .map(|embedding| {
2276 let embedding_dimension = embedding.len();
2277 chroma_proto::Vector::try_from((
2278 embedding,
2279 ScalarEncoding::FLOAT32,
2280 embedding_dimension,
2281 ))
2282 })
2283 .transpose()?,
2284 metadata: value.metadata.map(Into::into),
2285 score: value.score,
2286 })
2287 }
2288}
2289
/// Records returned for one search payload.
///
/// NOTE(review): presumably one of these is produced per payload of a batched search
/// request; confirm against the caller.
#[derive(Clone, Debug, Default)]
pub struct SearchPayloadResult {
    /// Matching records. The proto conversions preserve their order.
    pub records: Vec<SearchRecord>,
}
2315
2316impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2317 type Error = QueryConversionError;
2318
2319 fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2320 Ok(Self {
2321 records: value
2322 .records
2323 .into_iter()
2324 .map(TryInto::try_into)
2325 .collect::<Result<_, _>>()?,
2326 })
2327 }
2328}
2329
2330impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2331 type Error = QueryConversionError;
2332
2333 fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2334 Ok(Self {
2335 records: value
2336 .records
2337 .into_iter()
2338 .map(TryInto::try_into)
2339 .collect::<Result<Vec<_>, _>>()?,
2340 })
2341 }
2342}
2343
/// Result of a search request: one [`SearchPayloadResult`] per entry in `results`, plus
/// accounting data.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Per-payload results.
    pub results: Vec<SearchPayloadResult>,
    /// Number of log bytes pulled while serving the search. Presumably used for metering;
    /// TODO confirm. Not counted by [`SearchResult::size_bytes`].
    pub pulled_log_bytes: u64,
}
2391
2392impl SearchResult {
2393 pub fn size_bytes(&self) -> u64 {
2394 self.results
2395 .iter()
2396 .flat_map(|result| {
2397 result.records.iter().map(|record| {
2398 (record.id.len()
2399 + record
2400 .document
2401 .as_ref()
2402 .map(|doc| doc.len())
2403 .unwrap_or_default()
2404 + record
2405 .embedding
2406 .as_ref()
2407 .map(|emb| size_of_val(&emb[..]))
2408 .unwrap_or_default()
2409 + record
2410 .metadata
2411 .as_ref()
2412 .map(logical_size_of_metadata)
2413 .unwrap_or_default()
2414 + record.score.as_ref().map(size_of_val).unwrap_or_default())
2415 as u64
2416 })
2417 })
2418 .sum()
2419 }
2420}
2421
2422impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2423 type Error = QueryConversionError;
2424
2425 fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2426 Ok(Self {
2427 results: value
2428 .results
2429 .into_iter()
2430 .map(TryInto::try_into)
2431 .collect::<Result<_, _>>()?,
2432 pulled_log_bytes: value.pulled_log_bytes,
2433 })
2434 }
2435}
2436
2437impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2438 type Error = QueryConversionError;
2439
2440 fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2441 Ok(Self {
2442 results: value
2443 .results
2444 .into_iter()
2445 .map(TryInto::try_into)
2446 .collect::<Result<Vec<_>, _>>()?,
2447 pulled_log_bytes: value.pulled_log_bytes,
2448 })
2449 }
2450}
2451
2452pub fn rrf(
2597 ranks: Vec<RankExpr>,
2598 k: Option<u32>,
2599 weights: Option<Vec<f32>>,
2600 normalize: bool,
2601) -> Result<RankExpr, QueryConversionError> {
2602 let k = k.unwrap_or(60);
2603
2604 if ranks.is_empty() {
2605 return Err(QueryConversionError::validation(
2606 "RRF requires at least one rank expression",
2607 ));
2608 }
2609
2610 let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2611
2612 if weights.len() != ranks.len() {
2613 return Err(QueryConversionError::validation(format!(
2614 "RRF weights length ({}) must match ranks length ({})",
2615 weights.len(),
2616 ranks.len()
2617 )));
2618 }
2619
2620 let weights = if normalize {
2621 let sum: f32 = weights.iter().sum();
2622 if sum == 0.0 {
2623 return Err(QueryConversionError::validation(
2624 "RRF weights sum to zero, cannot normalize",
2625 ));
2626 }
2627 weights.into_iter().map(|w| w / sum).collect()
2628 } else {
2629 weights
2630 };
2631
2632 let terms: Vec<RankExpr> = weights
2633 .into_iter()
2634 .zip(ranks)
2635 .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2636 .collect();
2637
2638 let sum = terms
2641 .into_iter()
2642 .reduce(|a, b| a + b)
2643 .unwrap_or(RankExpr::Value(0.0));
2644 Ok(-sum)
2645}
2646
#[cfg(test)]
mod tests {
    //! Unit tests for key parsing, query-vector / filter / limit / select (de)serialization,
    //! and the k-way `Merge` operator.

    use super::*;
    use crate::metadata::{BooleanOperator, Where as MetadataWhere};
    use std::cmp::Reverse;
    use std::collections::HashSet;

    /// Converts `vector` to its proto message and back again.
    fn proto_round_trip(vector: &QueryVector) -> QueryVector {
        let proto: chroma_proto::QueryVector = vector.clone().try_into().unwrap();
        proto.try_into().unwrap()
    }

    /// Deserializes a JSON filter and asserts that it produced no explicit query ids.
    fn filter_without_ids(json: serde_json::Value) -> Filter {
        let filter: Filter = serde_json::from_value(json).unwrap();
        assert_eq!(filter.query_ids, None);
        filter
    }

    #[test]
    fn test_key_from_string() {
        let borrowed_cases = [
            ("#document", Key::Document),
            ("#embedding", Key::Embedding),
            ("#metadata", Key::Metadata),
            ("#score", Key::Score),
            ("custom_field", Key::MetadataField("custom_field".to_string())),
            ("author", Key::MetadataField("author".to_string())),
        ];
        for (raw, expected) in borrowed_cases {
            assert_eq!(Key::from(raw), expected);
        }

        // Owned strings go through the same mapping.
        assert_eq!(Key::from(String::from("#embedding")), Key::Embedding);
        assert_eq!(
            Key::from(String::from("year")),
            Key::MetadataField("year".to_string())
        );
    }

    #[test]
    fn test_query_vector_dense_proto_conversion() {
        let dense_values = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let original = QueryVector::Dense(dense_values.clone());

        let restored = proto_round_trip(&original);

        assert_eq!(restored, original);
        match restored {
            QueryVector::Dense(values) => assert_eq!(values, dense_values),
            _ => panic!("Expected dense vector"),
        }
    }

    #[test]
    fn test_query_vector_sparse_proto_conversion() {
        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
        let original = QueryVector::Sparse(sparse.clone());

        let restored = proto_round_trip(&original);

        assert_eq!(restored, original);
        match restored {
            QueryVector::Sparse(values) => assert_eq!(values, sparse),
            _ => panic!("Expected sparse vector"),
        }
    }

    #[test]
    fn test_filter_json_deserialization() {
        // Shorthand equality on a single metadata field.
        let filter: Filter = serde_json::from_str(r#"{"author": "John Doe"}"#).unwrap();
        assert_eq!(filter.query_ids, None);
        assert!(filter.where_clause.is_some());

        // `#id` filters stay in the where clause rather than populating `query_ids`.
        let filter = filter_without_ids(serde_json::json!({
            "#id": { "$in": ["doc1", "doc2", "doc3"] }
        }));
        assert!(filter.where_clause.is_some());

        let filter = filter_without_ids(serde_json::json!({
            "$and": [
                { "#id": { "$in": ["doc1", "doc2", "doc3"] } },
                { "$or": [
                    { "author": { "$eq": "John Doe" } },
                    { "author": { "$eq": "Jane Smith" } }
                ] },
                { "year": { "$gte": 2020 } },
                { "tags": { "$contains": "machine-learning" } }
            ]
        }));
        let Some(MetadataWhere::Composite(and_clause)) = filter.where_clause else {
            panic!("Expected AND composite where clause");
        };
        assert_eq!(and_clause.operator, BooleanOperator::And);
        assert_eq!(and_clause.children.len(), 4);
        let MetadataWhere::Composite(or_clause) = &and_clause.children[1] else {
            panic!("Expected OR composite in second child");
        };
        assert_eq!(or_clause.operator, BooleanOperator::Or);
        assert_eq!(or_clause.children.len(), 2);

        let filter = filter_without_ids(serde_json::json!({
            "$and": [
                { "status": { "$ne": "deleted" } },
                { "score": { "$gt": 0.5 } },
                { "score": { "$lt": 0.9 } },
                { "priority": { "$lte": 10 } }
            ]
        }));
        assert!(filter.where_clause.is_some());

        let filter = filter_without_ids(serde_json::json!({
            "$or": [
                { "$and": [
                    { "#id": { "$in": ["id1", "id2"] } },
                    { "$or": [ { "category": "tech" }, { "category": "science" } ] }
                ] },
                { "$and": [ { "author": "Admin" }, { "published": true } ] }
            ]
        }));
        let Some(MetadataWhere::Composite(or_clause)) = filter.where_clause else {
            panic!("Expected OR composite at top level");
        };
        assert_eq!(or_clause.operator, BooleanOperator::Or);
        assert_eq!(or_clause.children.len(), 2);
        for child in &or_clause.children {
            let MetadataWhere::Composite(nested_and) = child else {
                panic!("Expected AND composite in OR children");
            };
            assert_eq!(nested_and.operator, BooleanOperator::And);
        }

        let filter = filter_without_ids(serde_json::json!({
            "#id": { "$eq": "single-doc-id" }
        }));
        assert!(filter.where_clause.is_some());

        // An empty object means "no filter at all".
        let filter = filter_without_ids(serde_json::json!({}));
        assert_eq!(filter.where_clause, None);

        let filter = filter_without_ids(serde_json::json!({
            "$and": [
                { "#id": { "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"] } },
                { "tags": { "$not_contains": "deprecated" } },
                { "$or": [
                    { "$and": [
                        { "confidence": { "$gte": 0.8 } },
                        { "verified": true }
                    ] },
                    { "$and": [
                        { "confidence": { "$gte": 0.6 } },
                        { "confidence": { "$lt": 0.8 } },
                        { "reviews": { "$gte": 5 } }
                    ] }
                ] }
            ]
        }));
        let Some(MetadataWhere::Composite(top_and)) = filter.where_clause else {
            panic!("Expected AND composite at top level");
        };
        assert_eq!(top_and.operator, BooleanOperator::And);
        assert_eq!(top_and.children.len(), 3);
    }

    #[test]
    fn test_limit_json_serialization() {
        let original = Limit {
            offset: 10,
            limit: Some(20),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let restored: Limit = serde_json::from_str(&encoded).unwrap();

        assert_eq!(
            (restored.offset, restored.limit),
            (original.offset, original.limit)
        );
    }

    #[test]
    fn test_query_vector_json_serialization() {
        let samples = [
            QueryVector::Dense(vec![0.1, 0.2, 0.3]),
            QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap()),
        ];
        for sample in samples {
            let encoded = serde_json::to_string(&sample).unwrap();
            let restored: QueryVector = serde_json::from_str(&encoded).unwrap();
            assert_eq!(restored, sample);
        }
    }

    #[test]
    fn test_select_key_json_serialization() {
        let expected_json = [
            (Key::Document, "\"#document\""),
            (Key::Embedding, "\"#embedding\""),
            (Key::Metadata, "\"#metadata\""),
            (Key::Score, "\"#score\""),
            (Key::MetadataField("custom_key".to_string()), "\"custom_key\""),
        ];
        for (key, json) in &expected_json {
            assert_eq!(serde_json::to_string(key).unwrap(), *json);
        }

        let parsed: Key = serde_json::from_str("\"#document\"").unwrap();
        assert!(matches!(parsed, Key::Document));

        let parsed: Key = serde_json::from_str("\"custom_field\"").unwrap();
        assert!(matches!(parsed, Key::MetadataField(ref name) if name.as_str() == "custom_field"));

        let keys = HashSet::from([
            Key::Document,
            Key::Embedding,
            Key::MetadataField("author".to_string()),
        ]);
        let select = Select { keys };
        let encoded = serde_json::to_string(&select).unwrap();
        let restored: Select = serde_json::from_str(&encoded).unwrap();

        assert_eq!(restored.keys.len(), 3);
        for key in [
            Key::Document,
            Key::Embedding,
            Key::MetadataField("author".to_string()),
        ] {
            assert!(restored.keys.contains(&key));
        }
    }

    #[test]
    fn test_merge_basic_integers() {
        // `Reverse` turns the descending merge into an ascending one on the inner values.
        let runs = vec![
            vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
            vec![Reverse(2), Reverse(5), Reverse(8)],
            vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
        ];

        let merged = Merge { k: 5 }.merge(runs);

        assert_eq!(merged.len(), 5);
        assert_eq!(merged, (1..=5).map(Reverse).collect::<Vec<_>>());
    }

    #[test]
    fn test_merge_u32_descending() {
        let merged = Merge { k: 6 }.merge(vec![
            vec![100u32, 75, 50, 25],
            vec![90, 60, 30],
            vec![95, 85, 70, 40, 10],
        ]);

        assert_eq!(merged.len(), 6);
        assert_eq!(merged, vec![100, 95, 90, 85, 75, 70]);
    }

    #[test]
    fn test_merge_i32_descending() {
        let merged = Merge { k: 5 }.merge(vec![
            vec![50i32, 10, -10, -50],
            vec![30, 0, -30],
            vec![40, 20, -20, -40],
        ]);

        assert_eq!(merged.len(), 5);
        assert_eq!(merged, vec![50, 40, 30, 20, 10]);
    }

    #[test]
    fn test_merge_with_duplicates() {
        // Equal values across (and within) runs collapse to a single output entry.
        let merged = Merge { k: 10 }.merge(vec![
            vec![100u32, 80, 80, 60, 40],
            vec![90, 80, 50, 30],
            vec![100, 70, 60, 20],
        ]);

        assert_eq!(merged, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
    }

    #[test]
    fn test_merge_empty_vectors() {
        let merge = Merge { k: 5 };

        let all_empty: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
        assert!(merge.merge(all_empty).is_empty());

        let some_empty = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
        assert_eq!(merge.merge(some_empty), vec![1000, 850, 750, 600, 500]);

        let one_populated = vec![vec![], vec![100i32, 50, 25], vec![]];
        assert_eq!(merge.merge(one_populated), vec![100, 50, 25]);
    }

    #[test]
    fn test_merge_k_boundary_conditions() {
        let nothing = Merge { k: 0 }.merge(vec![vec![100u32, 50], vec![75, 25]]);
        assert!(nothing.is_empty());

        let top_one = Merge { k: 1 }.merge(vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]]);
        assert_eq!(top_one, vec![1000]);

        // `k` larger than the total input returns everything.
        let everything = Merge { k: 100 }.merge(vec![vec![10000u128, 5000], vec![8000, 3000]]);
        assert_eq!(everything, vec![10000, 8000, 5000, 3000]);
    }

    #[test]
    fn test_merge_with_strings() {
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| item.to_string()).collect()
        }

        let merged = Merge { k: 4 }.merge(vec![
            owned(&["zebra", "dog", "apple"]),
            owned(&["elephant", "banana"]),
            owned(&["fish", "cat"]),
        ]);

        assert_eq!(merged, owned(&["zebra", "fish", "elephant", "dog"]));
    }

    #[test]
    fn test_merge_with_custom_struct() {
        // Field order matters: the derived `Ord` compares `value` first, then `id`.
        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
        struct ScoredItem {
            value: i32,
            id: String,
        }

        fn item(value: i32, id: &str) -> ScoredItem {
            ScoredItem {
                value,
                id: id.to_string(),
            }
        }

        let merged = Merge { k: 3 }.merge(vec![
            vec![item(100, "a"), item(80, "b"), item(60, "c")],
            vec![item(90, "d"), item(70, "e")],
            vec![item(95, "f"), item(85, "g")],
        ]);

        assert_eq!(merged.len(), 3);
        assert_eq!(merged, vec![item(100, "a"), item(95, "f"), item(90, "d")]);
    }

    #[test]
    fn test_merge_preserves_order() {
        let merged = Merge { k: 10 }.merge(vec![
            vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
            vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
            vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
        ]);

        for pair in merged.windows(2) {
            assert!(
                pair[0] >= pair[1],
                "Output should be in descending Reverse order"
            );
            assert!(
                pair[0].0 <= pair[1].0,
                "Inner values should be in ascending order"
            );
        }

        assert_eq!(merged, (1..=10).map(Reverse).collect::<Vec<_>>());
    }

    #[test]
    fn test_merge_single_vector() {
        let merged = Merge { k: 3 }.merge(vec![vec![1000u64, 800, 600, 400, 200]]);
        assert_eq!(merged, vec![1000, 800, 600]);
    }

    #[test]
    fn test_merge_all_same_values() {
        // Every input is identical, so deduplication leaves exactly one value.
        let merged = Merge { k: 5 }.merge(vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]]);
        assert_eq!(merged, vec![42]);
    }

    #[test]
    fn test_merge_mixed_types_sizes() {
        let unsigned = Merge { k: 4 }.merge(vec![
            vec![1000usize, 500, 100],
            vec![800, 300],
            vec![900, 600, 200],
        ]);
        assert_eq!(unsigned, vec![1000, 900, 800, 600]);

        let signed = Merge { k: 5 }.merge(vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]]);
        assert_eq!(signed, vec![15, 10, 5, 0, -5]);
    }
}