1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4 cmp::Ordering,
5 collections::{BinaryHeap, HashSet},
6 fmt,
7 hash::Hash,
8 ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13 chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14 DocumentExpression, DocumentOperator, Metadata, MetadataComparison, MetadataExpression,
15 MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding, SetOperator, SparseVector,
16 Where,
17};
18
19use super::error::QueryConversionError;
20
/// Alias for the unit type `()`.
///
/// NOTE(review): presumably the input type of operators that take no upstream
/// data — confirm against the callers.
pub type InitialInput = ();
22
23#[derive(Clone, Debug)]
28pub struct Scan {
29 pub collection_and_segments: CollectionAndSegments,
30}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33 type Error = QueryConversionError;
34
35 fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36 Ok(Self {
37 collection_and_segments: CollectionAndSegments {
38 collection: value
39 .collection
40 .ok_or(QueryConversionError::field("collection"))?
41 .try_into()?,
42 metadata_segment: value
43 .metadata
44 .ok_or(QueryConversionError::field("metadata segment"))?
45 .try_into()?,
46 record_segment: value
47 .record
48 .ok_or(QueryConversionError::field("record segment"))?
49 .try_into()?,
50 vector_segment: value
51 .knn
52 .ok_or(QueryConversionError::field("vector segment"))?
53 .try_into()?,
54 },
55 })
56 }
57}
58
59#[derive(Debug, Error)]
60pub enum ScanToProtoError {
61 #[error("Could not convert collection to proto")]
62 CollectionToProto(#[from] crate::CollectionToProtoError),
63}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66 type Error = ScanToProtoError;
67
68 fn try_from(value: Scan) -> Result<Self, Self::Error> {
69 Ok(Self {
70 collection: Some(value.collection_and_segments.collection.try_into()?),
71 knn: Some(value.collection_and_segments.vector_segment.into()),
72 metadata: Some(value.collection_and_segments.metadata_segment.into()),
73 record: Some(value.collection_and_segments.record_segment.into()),
74 })
75 }
76}
77
/// Result of a count query.
#[derive(Debug, Clone)]
pub struct CountResult {
    /// The count that was computed.
    pub count: u32,
    /// NOTE(review): presumably the number of log bytes pulled while serving
    /// the request — confirm against the producer.
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Returns the in-memory size, in bytes, of the `count` field.
    ///
    /// `pulled_log_bytes` is not part of the reported size.
    pub fn size_bytes(&self) -> u64 {
        let count_bytes = std::mem::size_of_val(&self.count);
        count_bytes as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91 fn from(value: chroma_proto::CountResult) -> Self {
92 Self {
93 count: value.count,
94 pulled_log_bytes: value.pulled_log_bytes,
95 }
96 }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100 fn from(value: CountResult) -> Self {
101 Self {
102 count: value.count,
103 pulled_log_bytes: value.pulled_log_bytes,
104 }
105 }
106}
107
108#[derive(Clone, Debug)]
115pub struct FetchLog {
116 pub collection_uuid: CollectionUuid,
117 pub maximum_fetch_count: Option<u32>,
118 pub start_log_offset_id: u32,
119}
120
121#[derive(Clone, Debug, Default)]
170pub struct Filter {
171 pub query_ids: Option<Vec<String>>,
172 pub where_clause: Option<Where>,
173}
174
175impl Serialize for Filter {
176 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177 where
178 S: Serializer,
179 {
180 match (&self.query_ids, &self.where_clause) {
184 (None, None) => {
185 let map = serializer.serialize_map(Some(0))?;
187 map.end()
188 }
189 (None, Some(where_clause)) => {
190 where_clause.serialize(serializer)
192 }
193 (Some(ids), None) => {
194 let id_where = Where::Metadata(MetadataExpression {
196 key: "#id".to_string(),
197 comparison: MetadataComparison::Set(
198 SetOperator::In,
199 MetadataSetValue::Str(ids.clone()),
200 ),
201 });
202 id_where.serialize(serializer)
203 }
204 (Some(ids), Some(where_clause)) => {
205 let id_where = Where::Metadata(MetadataExpression {
207 key: "#id".to_string(),
208 comparison: MetadataComparison::Set(
209 SetOperator::In,
210 MetadataSetValue::Str(ids.clone()),
211 ),
212 });
213 let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214 combined.serialize(serializer)
215 }
216 }
217 }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where
223 D: Deserializer<'de>,
224 {
225 let where_json = Value::deserialize(deserializer)?;
227 let where_clause =
228 if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229 None
230 } else {
231 Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232 };
233
234 Ok(Filter {
235 query_ids: None, where_clause,
237 })
238 }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242 type Error = QueryConversionError;
243
244 fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245 let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246 let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247 let where_clause = match (where_metadata, where_document) {
248 (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249 (Some(w), None) | (None, Some(w)) => Some(w),
250 _ => None,
251 };
252
253 Ok(Self {
254 query_ids: value.ids.map(|uids| uids.ids),
255 where_clause,
256 })
257 }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261 type Error = QueryConversionError;
262
263 fn try_from(value: Filter) -> Result<Self, Self::Error> {
264 Ok(Self {
265 ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266 r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267 where_document: None,
268 })
269 }
270}
271
/// A single nearest-neighbour query: one embedding and a fetch count.
#[derive(Debug, Clone)]
pub struct Knn {
    /// The query embedding.
    pub embedding: Vec<f32>,
    /// Fetch count carried with the query.
    pub fetch: u32,
}

impl From<KnnBatch> for Vec<Knn> {
    /// Splits a batch into one [`Knn`] per embedding, in the original order,
    /// each carrying the batch's `fetch` value.
    fn from(value: KnnBatch) -> Self {
        let KnnBatch { embeddings, fetch } = value;

        let mut queries = Vec::with_capacity(embeddings.len());
        for embedding in embeddings {
            queries.push(Knn { embedding, fetch });
        }
        queries
    }
}

/// A batch of nearest-neighbour queries sharing one fetch count.
#[derive(Debug, Clone)]
pub struct KnnBatch {
    /// One query embedding per entry.
    pub embeddings: Vec<Vec<f32>>,
    /// Fetch count shared by every query in the batch.
    pub fetch: u32,
}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308 type Error = QueryConversionError;
309
310 fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311 Ok(Self {
312 embeddings: value
313 .embeddings
314 .into_iter()
315 .map(|vec| vec.try_into().map(|(v, _)| v))
316 .collect::<Result<_, _>>()?,
317 fetch: value.fetch,
318 })
319 }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323 type Error = QueryConversionError;
324
325 fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326 Ok(Self {
327 embeddings: value
328 .embeddings
329 .into_iter()
330 .map(|embedding| {
331 let dim = embedding.len();
332 chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333 })
334 .collect::<Result<_, _>>()?,
335 fetch: value.fetch,
336 })
337 }
338}
339
340#[derive(Clone, Debug, Default, Deserialize, Serialize)]
373pub struct Limit {
374 #[serde(default)]
375 pub offset: u32,
376 #[serde(default)]
377 pub limit: Option<u32>,
378}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381 fn from(value: chroma_proto::LimitOperator) -> Self {
382 Self {
383 offset: value.offset,
384 limit: value.limit,
385 }
386 }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390 fn from(value: Limit) -> Self {
391 Self {
392 offset: value.offset,
393 limit: value.limit,
394 }
395 }
396}
397
/// A record, identified by `offset_id`, paired with a numeric `measure`.
///
/// Equality looks only at `offset_id`. Ordering compares `measure` first
/// (using the IEEE total order, so NaN is handled deterministically) and
/// breaks ties by `offset_id`, so two records compare `Equal` only when they
/// are also `==`.
#[derive(Clone, Debug)]
pub struct RecordMeasure {
    /// Identifier of the record.
    pub offset_id: u32,
    /// Measure associated with the record.
    pub measure: f32,
}

impl PartialEq for RecordMeasure {
    /// Two values are equal when they refer to the same `offset_id`.
    fn eq(&self, other: &Self) -> bool {
        self.offset_id.eq(&other.offset_id)
    }
}

impl Eq for RecordMeasure {}

impl Ord for RecordMeasure {
    /// Orders by `measure`, then by `offset_id`.
    ///
    /// The tie-break matters: without it, distinct records with the same
    /// measure compare `Equal`, and consumers that rely on equal records being
    /// adjacent in sorted order (such as adjacent de-duplication in a k-way
    /// merge) can emit duplicates.
    fn cmp(&self, other: &Self) -> Ordering {
        self.measure
            .total_cmp(&other.measure)
            .then_with(|| self.offset_id.cmp(&other.offset_id))
    }
}

impl PartialOrd for RecordMeasure {
    /// Delegates to [`Ord::cmp`]; the ordering is total.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
424
425#[derive(Debug, Default)]
426pub struct KnnOutput {
427 pub distances: Vec<RecordMeasure>,
428}
429
/// K-way merge that keeps at most `k` elements.
#[derive(Clone, Debug)]
pub struct Merge {
    /// Maximum number of elements in the merged output.
    pub k: u32,
}

impl Merge {
    /// Merges several batches into one vector of at most `k` elements.
    ///
    /// Elements are taken greatest-first from a max-heap seeded with the head
    /// of every batch. For the output to be globally ordered each batch must
    /// already be sorted in descending order. An element equal to the one most
    /// recently emitted is dropped, which removes duplicates that are adjacent
    /// in the merged order.
    ///
    /// The output buffer is sized by `min(k, total input length)`, so a very
    /// large `k` does not trigger a very large allocation.
    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = usize::try_from(self.k).unwrap_or(usize::MAX);
        let total_len: usize = input.iter().map(Vec::len).sum();

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the first element of every non-empty batch,
        // remembering which batch each element came from.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut fusion = Vec::with_capacity(limit.min(total_len));
        while let Some((m, idx)) = max_heap.pop() {
            if fusion.len() >= limit {
                break;
            }
            // Refill the heap from the batch the popped element belonged to.
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            // Skip an element equal to the one emitted last.
            if fusion.last().is_some_and(|tail| tail == &m) {
                continue;
            }
            fusion.push(m);
        }
        fusion
    }
}
470
471#[derive(Clone, Debug, Default)]
478pub struct Projection {
479 pub document: bool,
480 pub embedding: bool,
481 pub metadata: bool,
482}
483
484impl From<chroma_proto::ProjectionOperator> for Projection {
485 fn from(value: chroma_proto::ProjectionOperator) -> Self {
486 Self {
487 document: value.document,
488 embedding: value.embedding,
489 metadata: value.metadata,
490 }
491 }
492}
493
494impl From<Projection> for chroma_proto::ProjectionOperator {
495 fn from(value: Projection) -> Self {
496 Self {
497 document: value.document,
498 embedding: value.embedding,
499 metadata: value.metadata,
500 }
501 }
502}
503
504#[derive(Clone, Debug, PartialEq)]
505pub struct ProjectionRecord {
506 pub id: String,
507 pub document: Option<String>,
508 pub embedding: Option<Vec<f32>>,
509 pub metadata: Option<Metadata>,
510}
511
512impl ProjectionRecord {
513 pub fn size_bytes(&self) -> u64 {
514 (self.id.len()
515 + self
516 .document
517 .as_ref()
518 .map(|doc| doc.len())
519 .unwrap_or_default()
520 + self
521 .embedding
522 .as_ref()
523 .map(|emb| size_of_val(&emb[..]))
524 .unwrap_or_default()
525 + self
526 .metadata
527 .as_ref()
528 .map(logical_size_of_metadata)
529 .unwrap_or_default()) as u64
530 }
531}
532
533impl Eq for ProjectionRecord {}
534
535impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
536 type Error = QueryConversionError;
537
538 fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
539 Ok(Self {
540 id: value.id,
541 document: value.document,
542 embedding: value
543 .embedding
544 .map(|vec| vec.try_into().map(|(v, _)| v))
545 .transpose()?,
546 metadata: value.metadata.map(TryInto::try_into).transpose()?,
547 })
548 }
549}
550
551impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
552 type Error = QueryConversionError;
553
554 fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
555 Ok(Self {
556 id: value.id,
557 document: value.document,
558 embedding: value
559 .embedding
560 .map(|embedding| {
561 let embedding_dimension = embedding.len();
562 chroma_proto::Vector::try_from((
563 embedding,
564 ScalarEncoding::FLOAT32,
565 embedding_dimension,
566 ))
567 })
568 .transpose()?,
569 metadata: value.metadata.map(|metadata| metadata.into()),
570 })
571 }
572}
573
574#[derive(Clone, Debug, Eq, PartialEq)]
575pub struct ProjectionOutput {
576 pub records: Vec<ProjectionRecord>,
577}
578
579#[derive(Clone, Debug, Eq, PartialEq)]
580pub struct GetResult {
581 pub pulled_log_bytes: u64,
582 pub result: ProjectionOutput,
583}
584
585impl GetResult {
586 pub fn size_bytes(&self) -> u64 {
587 self.result
588 .records
589 .iter()
590 .map(ProjectionRecord::size_bytes)
591 .sum()
592 }
593}
594
595impl TryFrom<chroma_proto::GetResult> for GetResult {
596 type Error = QueryConversionError;
597
598 fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
599 Ok(Self {
600 pulled_log_bytes: value.pulled_log_bytes,
601 result: ProjectionOutput {
602 records: value
603 .records
604 .into_iter()
605 .map(TryInto::try_into)
606 .collect::<Result<_, _>>()?,
607 },
608 })
609 }
610}
611
612impl TryFrom<GetResult> for chroma_proto::GetResult {
613 type Error = QueryConversionError;
614
615 fn try_from(value: GetResult) -> Result<Self, Self::Error> {
616 Ok(Self {
617 pulled_log_bytes: value.pulled_log_bytes,
618 records: value
619 .result
620 .records
621 .into_iter()
622 .map(TryInto::try_into)
623 .collect::<Result<_, _>>()?,
624 })
625 }
626}
627
628#[derive(Clone, Debug)]
636pub struct KnnProjection {
637 pub projection: Projection,
638 pub distance: bool,
639}
640
641impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
642 type Error = QueryConversionError;
643
644 fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
645 Ok(Self {
646 projection: value
647 .projection
648 .ok_or(QueryConversionError::field("projection"))?
649 .into(),
650 distance: value.distance,
651 })
652 }
653}
654
655impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
656 fn from(value: KnnProjection) -> Self {
657 Self {
658 projection: Some(value.projection.into()),
659 distance: value.distance,
660 }
661 }
662}
663
664#[derive(Clone, Debug)]
665pub struct KnnProjectionRecord {
666 pub record: ProjectionRecord,
667 pub distance: Option<f32>,
668}
669
670impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
671 type Error = QueryConversionError;
672
673 fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
674 Ok(Self {
675 record: value
676 .record
677 .ok_or(QueryConversionError::field("record"))?
678 .try_into()?,
679 distance: value.distance,
680 })
681 }
682}
683
684impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
685 type Error = QueryConversionError;
686
687 fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
688 Ok(Self {
689 record: Some(value.record.try_into()?),
690 distance: value.distance,
691 })
692 }
693}
694
695#[derive(Clone, Debug, Default)]
696pub struct KnnProjectionOutput {
697 pub records: Vec<KnnProjectionRecord>,
698}
699
700impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
701 type Error = QueryConversionError;
702
703 fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
704 Ok(Self {
705 records: value
706 .records
707 .into_iter()
708 .map(TryInto::try_into)
709 .collect::<Result<_, _>>()?,
710 })
711 }
712}
713
714impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
715 type Error = QueryConversionError;
716
717 fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
718 Ok(Self {
719 records: value
720 .records
721 .into_iter()
722 .map(TryInto::try_into)
723 .collect::<Result<_, _>>()?,
724 })
725 }
726}
727
728#[derive(Clone, Debug, Default)]
729pub struct KnnBatchResult {
730 pub pulled_log_bytes: u64,
731 pub results: Vec<KnnProjectionOutput>,
732}
733
734impl KnnBatchResult {
735 pub fn size_bytes(&self) -> u64 {
736 self.results
737 .iter()
738 .flat_map(|res| {
739 res.records
740 .iter()
741 .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
742 })
743 .sum()
744 }
745}
746
747impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
748 type Error = QueryConversionError;
749
750 fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
751 Ok(Self {
752 pulled_log_bytes: value.pulled_log_bytes,
753 results: value
754 .results
755 .into_iter()
756 .map(TryInto::try_into)
757 .collect::<Result<_, _>>()?,
758 })
759 }
760}
761
762impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
763 type Error = QueryConversionError;
764
765 fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
766 Ok(Self {
767 pulled_log_bytes: value.pulled_log_bytes,
768 results: value
769 .results
770 .into_iter()
771 .map(TryInto::try_into)
772 .collect::<Result<_, _>>()?,
773 })
774 }
775}
776
777#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
841#[serde(untagged)]
842pub enum QueryVector {
843 Dense(Vec<f32>),
844 Sparse(SparseVector),
845}
846
847impl TryFrom<chroma_proto::QueryVector> for QueryVector {
848 type Error = QueryConversionError;
849
850 fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
851 let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
852 match vector {
853 chroma_proto::query_vector::Vector::Dense(dense) => {
854 Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
855 }
856 chroma_proto::query_vector::Vector::Sparse(sparse) => {
857 Ok(QueryVector::Sparse(sparse.into()))
858 }
859 }
860 }
861}
862
863impl TryFrom<QueryVector> for chroma_proto::QueryVector {
864 type Error = QueryConversionError;
865
866 fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
867 match value {
868 QueryVector::Dense(vec) => {
869 let dim = vec.len();
870 Ok(chroma_proto::QueryVector {
871 vector: Some(chroma_proto::query_vector::Vector::Dense(
872 chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
873 )),
874 })
875 }
876 QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
877 vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
878 }),
879 }
880 }
881}
882
883impl From<Vec<f32>> for QueryVector {
884 fn from(vec: Vec<f32>) -> Self {
885 QueryVector::Dense(vec)
886 }
887}
888
889impl From<SparseVector> for QueryVector {
890 fn from(sparse: SparseVector) -> Self {
891 QueryVector::Sparse(sparse)
892 }
893}
894
895#[derive(Clone, Debug, PartialEq)]
896pub struct KnnQuery {
897 pub query: QueryVector,
898 pub key: Key,
899 pub limit: u32,
900}
901
902#[derive(Clone, Debug, Default, Deserialize, Serialize)]
933#[serde(transparent)]
934pub struct Rank {
935 pub expr: Option<RankExpr>,
936}
937
938impl Rank {
939 pub fn knn_queries(&self) -> Vec<KnnQuery> {
940 self.expr
941 .as_ref()
942 .map(RankExpr::knn_queries)
943 .unwrap_or_default()
944 }
945}
946
947impl TryFrom<chroma_proto::RankOperator> for Rank {
948 type Error = QueryConversionError;
949
950 fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
951 Ok(Rank {
952 expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
953 })
954 }
955}
956
957impl TryFrom<Rank> for chroma_proto::RankOperator {
958 type Error = QueryConversionError;
959
960 fn try_from(rank: Rank) -> Result<Self, Self::Error> {
961 Ok(chroma_proto::RankOperator {
962 expr: rank.expr.map(TryInto::try_into).transpose()?,
963 })
964 }
965}
966
967#[derive(Clone, Debug, Deserialize, Serialize)]
1130pub enum RankExpr {
1131 #[serde(rename = "$abs")]
1132 Absolute(Box<RankExpr>),
1133 #[serde(rename = "$div")]
1134 Division {
1135 left: Box<RankExpr>,
1136 right: Box<RankExpr>,
1137 },
1138 #[serde(rename = "$exp")]
1139 Exponentiation(Box<RankExpr>),
1140 #[serde(rename = "$knn")]
1141 Knn {
1142 query: QueryVector,
1143 #[serde(default = "RankExpr::default_knn_key")]
1144 key: Key,
1145 #[serde(default = "RankExpr::default_knn_limit")]
1146 limit: u32,
1147 #[serde(default)]
1148 default: Option<f32>,
1149 #[serde(default)]
1150 return_rank: bool,
1151 },
1152 #[serde(rename = "$log")]
1153 Logarithm(Box<RankExpr>),
1154 #[serde(rename = "$max")]
1155 Maximum(Vec<RankExpr>),
1156 #[serde(rename = "$min")]
1157 Minimum(Vec<RankExpr>),
1158 #[serde(rename = "$mul")]
1159 Multiplication(Vec<RankExpr>),
1160 #[serde(rename = "$sub")]
1161 Subtraction {
1162 left: Box<RankExpr>,
1163 right: Box<RankExpr>,
1164 },
1165 #[serde(rename = "$sum")]
1166 Summation(Vec<RankExpr>),
1167 #[serde(rename = "$val")]
1168 Value(f32),
1169}
1170
1171impl RankExpr {
1172 pub fn default_knn_key() -> Key {
1173 Key::Embedding
1174 }
1175
1176 pub fn default_knn_limit() -> u32 {
1177 16
1178 }
1179
1180 pub fn knn_queries(&self) -> Vec<KnnQuery> {
1181 match self {
1182 RankExpr::Absolute(expr)
1183 | RankExpr::Exponentiation(expr)
1184 | RankExpr::Logarithm(expr) => expr.knn_queries(),
1185 RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1186 .knn_queries()
1187 .into_iter()
1188 .chain(right.knn_queries())
1189 .collect(),
1190 RankExpr::Maximum(exprs)
1191 | RankExpr::Minimum(exprs)
1192 | RankExpr::Multiplication(exprs)
1193 | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1194 RankExpr::Value(_) => Vec::new(),
1195 RankExpr::Knn {
1196 query,
1197 key,
1198 limit,
1199 default: _,
1200 return_rank: _,
1201 } => vec![KnnQuery {
1202 query: query.clone(),
1203 key: key.clone(),
1204 limit: *limit,
1205 }],
1206 }
1207 }
1208
1209 pub fn exp(self) -> Self {
1229 RankExpr::Exponentiation(Box::new(self))
1230 }
1231
1232 pub fn log(self) -> Self {
1253 RankExpr::Logarithm(Box::new(self))
1254 }
1255
1256 pub fn abs(self) -> Self {
1283 RankExpr::Absolute(Box::new(self))
1284 }
1285
1286 pub fn max(self, other: impl Into<RankExpr>) -> Self {
1310 let other = other.into();
1311
1312 match self {
1313 RankExpr::Maximum(mut exprs) => match other {
1314 RankExpr::Maximum(other_exprs) => {
1315 exprs.extend(other_exprs);
1316 RankExpr::Maximum(exprs)
1317 }
1318 _ => {
1319 exprs.push(other);
1320 RankExpr::Maximum(exprs)
1321 }
1322 },
1323 _ => match other {
1324 RankExpr::Maximum(mut exprs) => {
1325 exprs.insert(0, self);
1326 RankExpr::Maximum(exprs)
1327 }
1328 _ => RankExpr::Maximum(vec![self, other]),
1329 },
1330 }
1331 }
1332
1333 pub fn min(self, other: impl Into<RankExpr>) -> Self {
1357 let other = other.into();
1358
1359 match self {
1360 RankExpr::Minimum(mut exprs) => match other {
1361 RankExpr::Minimum(other_exprs) => {
1362 exprs.extend(other_exprs);
1363 RankExpr::Minimum(exprs)
1364 }
1365 _ => {
1366 exprs.push(other);
1367 RankExpr::Minimum(exprs)
1368 }
1369 },
1370 _ => match other {
1371 RankExpr::Minimum(mut exprs) => {
1372 exprs.insert(0, self);
1373 RankExpr::Minimum(exprs)
1374 }
1375 _ => RankExpr::Minimum(vec![self, other]),
1376 },
1377 }
1378 }
1379}
1380
1381impl Add for RankExpr {
1382 type Output = RankExpr;
1383
1384 fn add(self, rhs: Self) -> Self::Output {
1385 match self {
1386 RankExpr::Summation(mut exprs) => match rhs {
1387 RankExpr::Summation(rhs_exprs) => {
1388 exprs.extend(rhs_exprs);
1389 RankExpr::Summation(exprs)
1390 }
1391 _ => {
1392 exprs.push(rhs);
1393 RankExpr::Summation(exprs)
1394 }
1395 },
1396 _ => match rhs {
1397 RankExpr::Summation(mut exprs) => {
1398 exprs.insert(0, self);
1399 RankExpr::Summation(exprs)
1400 }
1401 _ => RankExpr::Summation(vec![self, rhs]),
1402 },
1403 }
1404 }
1405}
1406
1407impl Add<f32> for RankExpr {
1408 type Output = RankExpr;
1409
1410 fn add(self, rhs: f32) -> Self::Output {
1411 self + RankExpr::Value(rhs)
1412 }
1413}
1414
1415impl Add<RankExpr> for f32 {
1416 type Output = RankExpr;
1417
1418 fn add(self, rhs: RankExpr) -> Self::Output {
1419 RankExpr::Value(self) + rhs
1420 }
1421}
1422
1423impl Sub for RankExpr {
1424 type Output = RankExpr;
1425
1426 fn sub(self, rhs: Self) -> Self::Output {
1427 RankExpr::Subtraction {
1428 left: Box::new(self),
1429 right: Box::new(rhs),
1430 }
1431 }
1432}
1433
1434impl Sub<f32> for RankExpr {
1435 type Output = RankExpr;
1436
1437 fn sub(self, rhs: f32) -> Self::Output {
1438 self - RankExpr::Value(rhs)
1439 }
1440}
1441
1442impl Sub<RankExpr> for f32 {
1443 type Output = RankExpr;
1444
1445 fn sub(self, rhs: RankExpr) -> Self::Output {
1446 RankExpr::Value(self) - rhs
1447 }
1448}
1449
1450impl Mul for RankExpr {
1451 type Output = RankExpr;
1452
1453 fn mul(self, rhs: Self) -> Self::Output {
1454 match self {
1455 RankExpr::Multiplication(mut exprs) => match rhs {
1456 RankExpr::Multiplication(rhs_exprs) => {
1457 exprs.extend(rhs_exprs);
1458 RankExpr::Multiplication(exprs)
1459 }
1460 _ => {
1461 exprs.push(rhs);
1462 RankExpr::Multiplication(exprs)
1463 }
1464 },
1465 _ => match rhs {
1466 RankExpr::Multiplication(mut exprs) => {
1467 exprs.insert(0, self);
1468 RankExpr::Multiplication(exprs)
1469 }
1470 _ => RankExpr::Multiplication(vec![self, rhs]),
1471 },
1472 }
1473 }
1474}
1475
1476impl Mul<f32> for RankExpr {
1477 type Output = RankExpr;
1478
1479 fn mul(self, rhs: f32) -> Self::Output {
1480 self * RankExpr::Value(rhs)
1481 }
1482}
1483
1484impl Mul<RankExpr> for f32 {
1485 type Output = RankExpr;
1486
1487 fn mul(self, rhs: RankExpr) -> Self::Output {
1488 RankExpr::Value(self) * rhs
1489 }
1490}
1491
1492impl Div for RankExpr {
1493 type Output = RankExpr;
1494
1495 fn div(self, rhs: Self) -> Self::Output {
1496 RankExpr::Division {
1497 left: Box::new(self),
1498 right: Box::new(rhs),
1499 }
1500 }
1501}
1502
1503impl Div<f32> for RankExpr {
1504 type Output = RankExpr;
1505
1506 fn div(self, rhs: f32) -> Self::Output {
1507 self / RankExpr::Value(rhs)
1508 }
1509}
1510
1511impl Div<RankExpr> for f32 {
1512 type Output = RankExpr;
1513
1514 fn div(self, rhs: RankExpr) -> Self::Output {
1515 RankExpr::Value(self) / rhs
1516 }
1517}
1518
1519impl Neg for RankExpr {
1520 type Output = RankExpr;
1521
1522 fn neg(self) -> Self::Output {
1523 RankExpr::Value(-1.0) * self
1524 }
1525}
1526
1527impl From<f32> for RankExpr {
1528 fn from(v: f32) -> Self {
1529 RankExpr::Value(v)
1530 }
1531}
1532
1533impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1534 type Error = QueryConversionError;
1535
1536 fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1537 match proto_expr.rank {
1538 Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1539 Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1540 }
1541 Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1542 let left = div.left.ok_or(QueryConversionError::field("left"))?;
1543 let right = div.right.ok_or(QueryConversionError::field("right"))?;
1544 Ok(RankExpr::Division {
1545 left: Box::new(RankExpr::try_from(*left)?),
1546 right: Box::new(RankExpr::try_from(*right)?),
1547 })
1548 }
1549 Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1550 RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1551 ),
1552 Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1553 let query = knn
1554 .query
1555 .ok_or(QueryConversionError::field("query"))?
1556 .try_into()?;
1557 Ok(RankExpr::Knn {
1558 query,
1559 key: Key::from(knn.key),
1560 limit: knn.limit,
1561 default: knn.default,
1562 return_rank: knn.return_rank,
1563 })
1564 }
1565 Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1566 Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1567 }
1568 Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1569 let exprs = max
1570 .exprs
1571 .into_iter()
1572 .map(RankExpr::try_from)
1573 .collect::<Result<Vec<_>, _>>()?;
1574 Ok(RankExpr::Maximum(exprs))
1575 }
1576 Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1577 let exprs = min
1578 .exprs
1579 .into_iter()
1580 .map(RankExpr::try_from)
1581 .collect::<Result<Vec<_>, _>>()?;
1582 Ok(RankExpr::Minimum(exprs))
1583 }
1584 Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1585 let exprs = mul
1586 .exprs
1587 .into_iter()
1588 .map(RankExpr::try_from)
1589 .collect::<Result<Vec<_>, _>>()?;
1590 Ok(RankExpr::Multiplication(exprs))
1591 }
1592 Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1593 let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1594 let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1595 Ok(RankExpr::Subtraction {
1596 left: Box::new(RankExpr::try_from(*left)?),
1597 right: Box::new(RankExpr::try_from(*right)?),
1598 })
1599 }
1600 Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1601 let exprs = sum
1602 .exprs
1603 .into_iter()
1604 .map(RankExpr::try_from)
1605 .collect::<Result<Vec<_>, _>>()?;
1606 Ok(RankExpr::Summation(exprs))
1607 }
1608 Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1609 None => Err(QueryConversionError::field("rank")),
1610 }
1611 }
1612}
1613
1614impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1615 type Error = QueryConversionError;
1616
1617 fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1618 let proto_rank = match rank_expr {
1619 RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1620 chroma_proto::RankExpr::try_from(*expr)?,
1621 )),
1622 RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1623 Box::new(chroma_proto::rank_expr::RankPair {
1624 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1625 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1626 }),
1627 ),
1628 RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1629 Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1630 ),
1631 RankExpr::Knn {
1632 query,
1633 key,
1634 limit,
1635 default,
1636 return_rank,
1637 } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1638 query: Some(query.try_into()?),
1639 key: key.to_string(),
1640 limit,
1641 default,
1642 return_rank,
1643 }),
1644 RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1645 chroma_proto::RankExpr::try_from(*expr)?,
1646 )),
1647 RankExpr::Maximum(exprs) => {
1648 let proto_exprs = exprs
1649 .into_iter()
1650 .map(chroma_proto::RankExpr::try_from)
1651 .collect::<Result<Vec<_>, _>>()?;
1652 chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1653 exprs: proto_exprs,
1654 })
1655 }
1656 RankExpr::Minimum(exprs) => {
1657 let proto_exprs = exprs
1658 .into_iter()
1659 .map(chroma_proto::RankExpr::try_from)
1660 .collect::<Result<Vec<_>, _>>()?;
1661 chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1662 exprs: proto_exprs,
1663 })
1664 }
1665 RankExpr::Multiplication(exprs) => {
1666 let proto_exprs = exprs
1667 .into_iter()
1668 .map(chroma_proto::RankExpr::try_from)
1669 .collect::<Result<Vec<_>, _>>()?;
1670 chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1671 exprs: proto_exprs,
1672 })
1673 }
1674 RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1675 Box::new(chroma_proto::rank_expr::RankPair {
1676 left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1677 right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1678 }),
1679 ),
1680 RankExpr::Summation(exprs) => {
1681 let proto_exprs = exprs
1682 .into_iter()
1683 .map(chroma_proto::RankExpr::try_from)
1684 .collect::<Result<Vec<_>, _>>()?;
1685 chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1686 exprs: proto_exprs,
1687 })
1688 }
1689 RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1690 };
1691
1692 Ok(chroma_proto::RankExpr {
1693 rank: Some(proto_rank),
1694 })
1695 }
1696}
1697
/// Names a part of a record: one of the reserved `#`-prefixed parts or a
/// metadata field.
///
/// The derived ordering follows the declaration order of the variants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Key {
    /// The document, written `#document`.
    Document,
    /// The embedding, written `#embedding`.
    Embedding,
    /// The metadata, written `#metadata`.
    Metadata,
    /// The score, written `#score`.
    Score,
    /// A metadata field, written as the field name itself.
    MetadataField(String),
}
1772
1773impl Serialize for Key {
1774 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1775 where
1776 S: serde::Serializer,
1777 {
1778 match self {
1779 Key::Document => serializer.serialize_str("#document"),
1780 Key::Embedding => serializer.serialize_str("#embedding"),
1781 Key::Metadata => serializer.serialize_str("#metadata"),
1782 Key::Score => serializer.serialize_str("#score"),
1783 Key::MetadataField(field) => serializer.serialize_str(field),
1784 }
1785 }
1786}
1787
1788impl<'de> Deserialize<'de> for Key {
1789 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1790 where
1791 D: Deserializer<'de>,
1792 {
1793 let s = String::deserialize(deserializer)?;
1794 Ok(Key::from(s))
1795 }
1796}
1797
1798impl fmt::Display for Key {
1799 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1800 match self {
1801 Key::Document => write!(f, "#document"),
1802 Key::Embedding => write!(f, "#embedding"),
1803 Key::Metadata => write!(f, "#metadata"),
1804 Key::Score => write!(f, "#score"),
1805 Key::MetadataField(field) => write!(f, "{}", field),
1806 }
1807 }
1808}
1809
1810impl From<&str> for Key {
1811 fn from(s: &str) -> Self {
1812 match s {
1813 "#document" => Key::Document,
1814 "#embedding" => Key::Embedding,
1815 "#metadata" => Key::Metadata,
1816 "#score" => Key::Score,
1817 field => Key::MetadataField(field.to_string()),
1819 }
1820 }
1821}
1822
1823impl From<String> for Key {
1824 fn from(s: String) -> Self {
1825 Key::from(s.as_str())
1826 }
1827}
1828
1829impl Key {
1830 pub fn field(name: impl Into<String>) -> Self {
1842 Key::MetadataField(name.into())
1843 }
1844
1845 pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1862 Where::Metadata(MetadataExpression {
1863 key: self.to_string(),
1864 comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1865 })
1866 }
1867
1868 pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1879 Where::Metadata(MetadataExpression {
1880 key: self.to_string(),
1881 comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1882 })
1883 }
1884
1885 pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1896 Where::Metadata(MetadataExpression {
1897 key: self.to_string(),
1898 comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1899 })
1900 }
1901
1902 pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1913 Where::Metadata(MetadataExpression {
1914 key: self.to_string(),
1915 comparison: MetadataComparison::Primitive(
1916 PrimitiveOperator::GreaterThanOrEqual,
1917 value.into(),
1918 ),
1919 })
1920 }
1921
1922 pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1933 Where::Metadata(MetadataExpression {
1934 key: self.to_string(),
1935 comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1936 })
1937 }
1938
1939 pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1950 Where::Metadata(MetadataExpression {
1951 key: self.to_string(),
1952 comparison: MetadataComparison::Primitive(
1953 PrimitiveOperator::LessThanOrEqual,
1954 value.into(),
1955 ),
1956 })
1957 }
1958
1959 pub fn is_in<I, T>(self, values: I) -> Where
1979 where
1980 I: IntoIterator<Item = T>,
1981 Vec<T>: Into<MetadataSetValue>,
1982 {
1983 let vec: Vec<T> = values.into_iter().collect();
1984 Where::Metadata(MetadataExpression {
1985 key: self.to_string(),
1986 comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1987 })
1988 }
1989
1990 pub fn not_in<I, T>(self, values: I) -> Where
2006 where
2007 I: IntoIterator<Item = T>,
2008 Vec<T>: Into<MetadataSetValue>,
2009 {
2010 let vec: Vec<T> = values.into_iter().collect();
2011 Where::Metadata(MetadataExpression {
2012 key: self.to_string(),
2013 comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2014 })
2015 }
2016
2017 pub fn contains<S: Into<String>>(self, text: S) -> Where {
2031 Where::Document(DocumentExpression {
2032 operator: DocumentOperator::Contains,
2033 pattern: text.into(),
2034 })
2035 }
2036
2037 pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2050 Where::Document(DocumentExpression {
2051 operator: DocumentOperator::NotContains,
2052 pattern: text.into(),
2053 })
2054 }
2055
2056 pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2073 Where::Document(DocumentExpression {
2074 operator: DocumentOperator::Regex,
2075 pattern: pattern.into(),
2076 })
2077 }
2078
2079 pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2095 Where::Document(DocumentExpression {
2096 operator: DocumentOperator::NotRegex,
2097 pattern: pattern.into(),
2098 })
2099 }
2100}
2101
/// A set of [`Key`]s, convertible to and from
/// [`chroma_proto::SelectOperator`].
///
/// NOTE(review): presumably these are the record fields a search should
/// return — confirm against the operator that consumes this type.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Select {
    /// The selected keys. A missing `keys` entry deserializes to an empty
    /// set because of `#[serde(default)]`.
    #[serde(default)]
    pub keys: HashSet<Key>,
}
2160
2161impl TryFrom<chroma_proto::SelectOperator> for Select {
2162 type Error = QueryConversionError;
2163
2164 fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2165 let keys = value
2166 .keys
2167 .into_iter()
2168 .map(|key| {
2169 serde_json::from_value(serde_json::Value::String(key))
2171 .map_err(|_| QueryConversionError::field("keys"))
2172 })
2173 .collect::<Result<HashSet<_>, _>>()?;
2174
2175 Ok(Self { keys })
2176 }
2177}
2178
2179impl TryFrom<Select> for chroma_proto::SelectOperator {
2180 type Error = QueryConversionError;
2181
2182 fn try_from(value: Select) -> Result<Self, Self::Error> {
2183 let keys = value
2184 .keys
2185 .into_iter()
2186 .map(|key| {
2187 serde_json::to_value(&key)
2189 .ok()
2190 .and_then(|v| v.as_str().map(String::from))
2191 .ok_or(QueryConversionError::field("keys"))
2192 })
2193 .collect::<Result<Vec<_>, _>>()?;
2194
2195 Ok(Self { keys })
2196 }
2197}
2198
/// A single record in a search result.
///
/// Only `id` is always present; every other part is optional.
/// NOTE(review): presumably a part is `None` when it was not selected for
/// the search — confirm against the code that builds these records.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SearchRecord {
    /// Identifier of the record.
    pub id: String,
    /// Document text, if present.
    pub document: Option<String>,
    /// Embedding as `f32` values, if present.
    pub embedding: Option<Vec<f32>>,
    /// Record metadata, if present.
    pub metadata: Option<Metadata>,
    /// Score of the record, if present.
    pub score: Option<f32>,
}
2244
2245impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2246 type Error = QueryConversionError;
2247
2248 fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2249 Ok(Self {
2250 id: value.id,
2251 document: value.document,
2252 embedding: value
2253 .embedding
2254 .map(|vec| vec.try_into().map(|(v, _)| v))
2255 .transpose()?,
2256 metadata: value.metadata.map(TryInto::try_into).transpose()?,
2257 score: value.score,
2258 })
2259 }
2260}
2261
2262impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2263 type Error = QueryConversionError;
2264
2265 fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2266 Ok(Self {
2267 id: value.id,
2268 document: value.document,
2269 embedding: value
2270 .embedding
2271 .map(|embedding| {
2272 let embedding_dimension = embedding.len();
2273 chroma_proto::Vector::try_from((
2274 embedding,
2275 ScalarEncoding::FLOAT32,
2276 embedding_dimension,
2277 ))
2278 })
2279 .transpose()?,
2280 metadata: value.metadata.map(Into::into),
2281 score: value.score,
2282 })
2283 }
2284}
2285
/// A list of [`SearchRecord`]s.
///
/// NOTE(review): presumably the records returned for one search payload,
/// in result order — confirm against the producer of this type.
#[derive(Clone, Debug, Default)]
pub struct SearchPayloadResult {
    /// The records of this result.
    pub records: Vec<SearchRecord>,
}
2311
2312impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2313 type Error = QueryConversionError;
2314
2315 fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2316 Ok(Self {
2317 records: value
2318 .records
2319 .into_iter()
2320 .map(TryInto::try_into)
2321 .collect::<Result<_, _>>()?,
2322 })
2323 }
2324}
2325
2326impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2327 type Error = QueryConversionError;
2328
2329 fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2330 Ok(Self {
2331 records: value
2332 .records
2333 .into_iter()
2334 .map(TryInto::try_into)
2335 .collect::<Result<Vec<_>, _>>()?,
2336 })
2337 }
2338}
2339
/// The result of a search: a list of [`SearchPayloadResult`]s plus a byte
/// counter.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// The payload results.
    /// NOTE(review): presumably one entry per search payload of the request —
    /// confirm against the caller.
    pub results: Vec<SearchPayloadResult>,
    /// NOTE(review): presumably the number of log bytes pulled while serving
    /// the search — confirm. It is passed through unchanged by the proto
    /// conversions and is not included in `size_bytes`.
    pub pulled_log_bytes: u64,
}
2387
2388impl SearchResult {
2389 pub fn size_bytes(&self) -> u64 {
2390 self.results
2391 .iter()
2392 .flat_map(|result| {
2393 result.records.iter().map(|record| {
2394 (record.id.len()
2395 + record
2396 .document
2397 .as_ref()
2398 .map(|doc| doc.len())
2399 .unwrap_or_default()
2400 + record
2401 .embedding
2402 .as_ref()
2403 .map(|emb| size_of_val(&emb[..]))
2404 .unwrap_or_default()
2405 + record
2406 .metadata
2407 .as_ref()
2408 .map(logical_size_of_metadata)
2409 .unwrap_or_default()
2410 + record.score.as_ref().map(size_of_val).unwrap_or_default())
2411 as u64
2412 })
2413 })
2414 .sum()
2415 }
2416}
2417
2418impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2419 type Error = QueryConversionError;
2420
2421 fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2422 Ok(Self {
2423 results: value
2424 .results
2425 .into_iter()
2426 .map(TryInto::try_into)
2427 .collect::<Result<_, _>>()?,
2428 pulled_log_bytes: value.pulled_log_bytes,
2429 })
2430 }
2431}
2432
2433impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2434 type Error = QueryConversionError;
2435
2436 fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2437 Ok(Self {
2438 results: value
2439 .results
2440 .into_iter()
2441 .map(TryInto::try_into)
2442 .collect::<Result<Vec<_>, _>>()?,
2443 pulled_log_bytes: value.pulled_log_bytes,
2444 })
2445 }
2446}
2447
2448pub fn rrf(
2593 ranks: Vec<RankExpr>,
2594 k: Option<u32>,
2595 weights: Option<Vec<f32>>,
2596 normalize: bool,
2597) -> Result<RankExpr, QueryConversionError> {
2598 let k = k.unwrap_or(60);
2599
2600 if ranks.is_empty() {
2601 return Err(QueryConversionError::validation(
2602 "RRF requires at least one rank expression",
2603 ));
2604 }
2605
2606 let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2607
2608 if weights.len() != ranks.len() {
2609 return Err(QueryConversionError::validation(format!(
2610 "RRF weights length ({}) must match ranks length ({})",
2611 weights.len(),
2612 ranks.len()
2613 )));
2614 }
2615
2616 let weights = if normalize {
2617 let sum: f32 = weights.iter().sum();
2618 if sum == 0.0 {
2619 return Err(QueryConversionError::validation(
2620 "RRF weights sum to zero, cannot normalize",
2621 ));
2622 }
2623 weights.into_iter().map(|w| w / sum).collect()
2624 } else {
2625 weights
2626 };
2627
2628 let terms: Vec<RankExpr> = weights
2629 .into_iter()
2630 .zip(ranks)
2631 .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2632 .collect();
2633
2634 let sum = terms
2637 .into_iter()
2638 .reduce(|a, b| a + b)
2639 .unwrap_or(RankExpr::Value(0.0));
2640 Ok(-sum)
2641}
2642
2643#[cfg(test)]
2644mod tests {
2645 use super::*;
2646
2647 #[test]
2648 fn test_key_from_string() {
2649 assert_eq!(Key::from("#document"), Key::Document);
2651 assert_eq!(Key::from("#embedding"), Key::Embedding);
2652 assert_eq!(Key::from("#metadata"), Key::Metadata);
2653 assert_eq!(Key::from("#score"), Key::Score);
2654
2655 assert_eq!(
2657 Key::from("custom_field"),
2658 Key::MetadataField("custom_field".to_string())
2659 );
2660 assert_eq!(
2661 Key::from("author"),
2662 Key::MetadataField("author".to_string())
2663 );
2664
2665 assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
2667 assert_eq!(
2668 Key::from("year".to_string()),
2669 Key::MetadataField("year".to_string())
2670 );
2671 }
2672
2673 #[test]
2674 fn test_query_vector_dense_proto_conversion() {
2675 let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2676 let query_vector = QueryVector::Dense(dense_vec.clone());
2677
2678 let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2680
2681 let converted: QueryVector = proto.try_into().unwrap();
2683
2684 assert_eq!(converted, query_vector);
2685 if let QueryVector::Dense(v) = converted {
2686 assert_eq!(v, dense_vec);
2687 } else {
2688 panic!("Expected dense vector");
2689 }
2690 }
2691
2692 #[test]
2693 fn test_query_vector_sparse_proto_conversion() {
2694 let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
2695 let query_vector = QueryVector::Sparse(sparse.clone());
2696
2697 let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2699
2700 let converted: QueryVector = proto.try_into().unwrap();
2702
2703 assert_eq!(converted, query_vector);
2704 if let QueryVector::Sparse(s) = converted {
2705 assert_eq!(s, sparse);
2706 } else {
2707 panic!("Expected sparse vector");
2708 }
2709 }
2710
2711 #[test]
2712 fn test_filter_json_deserialization() {
2713 let simple_where = r#"{"author": "John Doe"}"#;
2717 let filter: Filter = serde_json::from_str(simple_where).unwrap();
2718 assert_eq!(filter.query_ids, None);
2719 assert!(filter.where_clause.is_some());
2720
2721 let id_filter_json = serde_json::json!({
2723 "#id": {
2724 "$in": ["doc1", "doc2", "doc3"]
2725 }
2726 });
2727 let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
2728 assert_eq!(filter.query_ids, None);
2729 assert!(filter.where_clause.is_some());
2730
2731 let complex_json = serde_json::json!({
2733 "$and": [
2734 {
2735 "#id": {
2736 "$in": ["doc1", "doc2", "doc3"]
2737 }
2738 },
2739 {
2740 "$or": [
2741 {
2742 "author": {
2743 "$eq": "John Doe"
2744 }
2745 },
2746 {
2747 "author": {
2748 "$eq": "Jane Smith"
2749 }
2750 }
2751 ]
2752 },
2753 {
2754 "year": {
2755 "$gte": 2020
2756 }
2757 },
2758 {
2759 "tags": {
2760 "$contains": "machine-learning"
2761 }
2762 }
2763 ]
2764 });
2765
2766 let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
2767 assert_eq!(filter.query_ids, None);
2768 assert!(filter.where_clause.is_some());
2769
2770 if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2772 assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2773 assert_eq!(composite.children.len(), 4);
2774
2775 if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
2777 assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
2778 assert_eq!(or_composite.children.len(), 2);
2779 } else {
2780 panic!("Expected OR composite in second child");
2781 }
2782 } else {
2783 panic!("Expected AND composite where clause");
2784 }
2785
2786 let mixed_operators_json = serde_json::json!({
2788 "$and": [
2789 {
2790 "status": {
2791 "$ne": "deleted"
2792 }
2793 },
2794 {
2795 "score": {
2796 "$gt": 0.5
2797 }
2798 },
2799 {
2800 "score": {
2801 "$lt": 0.9
2802 }
2803 },
2804 {
2805 "priority": {
2806 "$lte": 10
2807 }
2808 }
2809 ]
2810 });
2811
2812 let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
2813 assert_eq!(filter.query_ids, None);
2814 assert!(filter.where_clause.is_some());
2815
2816 let deeply_nested_json = serde_json::json!({
2818 "$or": [
2819 {
2820 "$and": [
2821 {
2822 "#id": {
2823 "$in": ["id1", "id2"]
2824 }
2825 },
2826 {
2827 "$or": [
2828 {
2829 "category": "tech"
2830 },
2831 {
2832 "category": "science"
2833 }
2834 ]
2835 }
2836 ]
2837 },
2838 {
2839 "$and": [
2840 {
2841 "author": "Admin"
2842 },
2843 {
2844 "published": true
2845 }
2846 ]
2847 }
2848 ]
2849 });
2850
2851 let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
2852 assert_eq!(filter.query_ids, None);
2853 assert!(filter.where_clause.is_some());
2854
2855 if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2857 assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
2858 assert_eq!(composite.children.len(), 2);
2859
2860 for child in &composite.children {
2862 if let crate::metadata::Where::Composite(and_composite) = child {
2863 assert_eq!(
2864 and_composite.operator,
2865 crate::metadata::BooleanOperator::And
2866 );
2867 } else {
2868 panic!("Expected AND composite in OR children");
2869 }
2870 }
2871 } else {
2872 panic!("Expected OR composite at top level");
2873 }
2874
2875 let single_id_json = serde_json::json!({
2877 "#id": {
2878 "$eq": "single-doc-id"
2879 }
2880 });
2881
2882 let filter: Filter = serde_json::from_value(single_id_json).unwrap();
2883 assert_eq!(filter.query_ids, None);
2884 assert!(filter.where_clause.is_some());
2885
2886 let empty_json = serde_json::json!({});
2888 let filter: Filter = serde_json::from_value(empty_json).unwrap();
2889 assert_eq!(filter.query_ids, None);
2890 assert_eq!(filter.where_clause, None);
2892
2893 let advanced_json = serde_json::json!({
2895 "$and": [
2896 {
2897 "#id": {
2898 "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
2899 }
2900 },
2901 {
2902 "tags": {
2903 "$not_contains": "deprecated"
2904 }
2905 },
2906 {
2907 "$or": [
2908 {
2909 "$and": [
2910 {
2911 "confidence": {
2912 "$gte": 0.8
2913 }
2914 },
2915 {
2916 "verified": true
2917 }
2918 ]
2919 },
2920 {
2921 "$and": [
2922 {
2923 "confidence": {
2924 "$gte": 0.6
2925 }
2926 },
2927 {
2928 "confidence": {
2929 "$lt": 0.8
2930 }
2931 },
2932 {
2933 "reviews": {
2934 "$gte": 5
2935 }
2936 }
2937 ]
2938 }
2939 ]
2940 }
2941 ]
2942 });
2943
2944 let filter: Filter = serde_json::from_value(advanced_json).unwrap();
2945 assert_eq!(filter.query_ids, None);
2946 assert!(filter.where_clause.is_some());
2947
2948 if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2950 assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2951 assert_eq!(composite.children.len(), 3);
2952 } else {
2953 panic!("Expected AND composite at top level");
2954 }
2955 }
2956
2957 #[test]
2958 fn test_limit_json_serialization() {
2959 let limit = Limit {
2960 offset: 10,
2961 limit: Some(20),
2962 };
2963
2964 let json = serde_json::to_string(&limit).unwrap();
2965 let deserialized: Limit = serde_json::from_str(&json).unwrap();
2966
2967 assert_eq!(deserialized.offset, limit.offset);
2968 assert_eq!(deserialized.limit, limit.limit);
2969 }
2970
2971 #[test]
2972 fn test_query_vector_json_serialization() {
2973 let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
2975 let json = serde_json::to_string(&dense).unwrap();
2976 let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
2977 assert_eq!(deserialized, dense);
2978
2979 let sparse = QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]));
2981 let json = serde_json::to_string(&sparse).unwrap();
2982 let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
2983 assert_eq!(deserialized, sparse);
2984 }
2985
2986 #[test]
2987 fn test_select_key_json_serialization() {
2988 use std::collections::HashSet;
2989
2990 let doc_key = Key::Document;
2992 assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
2993
2994 let embed_key = Key::Embedding;
2995 assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
2996
2997 let meta_key = Key::Metadata;
2998 assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
2999
3000 let score_key = Key::Score;
3001 assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
3002
3003 let custom_key = Key::MetadataField("custom_key".to_string());
3005 assert_eq!(
3006 serde_json::to_string(&custom_key).unwrap(),
3007 "\"custom_key\""
3008 );
3009
3010 let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
3012 assert!(matches!(deserialized, Key::Document));
3013
3014 let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
3015 assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
3016
3017 let mut keys = HashSet::new();
3019 keys.insert(Key::Document);
3020 keys.insert(Key::Embedding);
3021 keys.insert(Key::MetadataField("author".to_string()));
3022
3023 let select = Select { keys };
3024 let json = serde_json::to_string(&select).unwrap();
3025 let deserialized: Select = serde_json::from_str(&json).unwrap();
3026
3027 assert_eq!(deserialized.keys.len(), 3);
3028 assert!(deserialized.keys.contains(&Key::Document));
3029 assert!(deserialized.keys.contains(&Key::Embedding));
3030 assert!(deserialized
3031 .keys
3032 .contains(&Key::MetadataField("author".to_string())));
3033 }
3034
3035 #[test]
3036 fn test_merge_basic_integers() {
3037 use std::cmp::Reverse;
3038
3039 let merge = Merge { k: 5 };
3040
3041 let input = vec![
3043 vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
3044 vec![Reverse(2), Reverse(5), Reverse(8)],
3045 vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
3046 ];
3047
3048 let result = merge.merge(input);
3049
3050 assert_eq!(result.len(), 5);
3052 assert_eq!(
3053 result,
3054 vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
3055 );
3056 }
3057
3058 #[test]
3059 fn test_merge_u32_descending() {
3060 let merge = Merge { k: 6 };
3061
3062 let input = vec![
3064 vec![100u32, 75, 50, 25],
3065 vec![90, 60, 30],
3066 vec![95, 85, 70, 40, 10],
3067 ];
3068
3069 let result = merge.merge(input);
3070
3071 assert_eq!(result.len(), 6);
3073 assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
3074 }
3075
3076 #[test]
3077 fn test_merge_i32_descending() {
3078 let merge = Merge { k: 5 };
3079
3080 let input = vec![
3082 vec![50i32, 10, -10, -50],
3083 vec![30, 0, -30],
3084 vec![40, 20, -20, -40],
3085 ];
3086
3087 let result = merge.merge(input);
3088
3089 assert_eq!(result.len(), 5);
3091 assert_eq!(result, vec![50, 40, 30, 20, 10]);
3092 }
3093
3094 #[test]
3095 fn test_merge_with_duplicates() {
3096 let merge = Merge { k: 10 };
3097
3098 let input = vec![
3100 vec![100u32, 80, 80, 60, 40],
3101 vec![90, 80, 50, 30],
3102 vec![100, 70, 60, 20],
3103 ];
3104
3105 let result = merge.merge(input);
3106
3107 assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
3109 }
3110
3111 #[test]
3112 fn test_merge_empty_vectors() {
3113 let merge = Merge { k: 5 };
3114
3115 let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
3117 let result = merge.merge(input);
3118 assert_eq!(result.len(), 0);
3119
3120 let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
3122 let result = merge.merge(input);
3123 assert_eq!(result, vec![1000, 850, 750, 600, 500]);
3124
3125 let input = vec![vec![], vec![100i32, 50, 25], vec![]];
3127 let result = merge.merge(input);
3128 assert_eq!(result, vec![100, 50, 25]);
3129 }
3130
3131 #[test]
3132 fn test_merge_k_boundary_conditions() {
3133 let merge = Merge { k: 0 };
3135 let input = vec![vec![100u32, 50], vec![75, 25]];
3136 let result = merge.merge(input);
3137 assert_eq!(result.len(), 0);
3138
3139 let merge = Merge { k: 1 };
3141 let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
3142 let result = merge.merge(input);
3143 assert_eq!(result, vec![1000]);
3144
3145 let merge = Merge { k: 100 };
3147 let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
3148 let result = merge.merge(input);
3149 assert_eq!(result, vec![10000, 8000, 5000, 3000]);
3150 }
3151
3152 #[test]
3153 fn test_merge_with_strings() {
3154 let merge = Merge { k: 4 };
3155
3156 let input = vec![
3158 vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
3159 vec!["elephant".to_string(), "banana".to_string()],
3160 vec!["fish".to_string(), "cat".to_string()],
3161 ];
3162
3163 let result = merge.merge(input);
3164
3165 assert_eq!(
3167 result,
3168 vec![
3169 "zebra".to_string(),
3170 "fish".to_string(),
3171 "elephant".to_string(),
3172 "dog".to_string()
3173 ]
3174 );
3175 }
3176
3177 #[test]
3178 fn test_merge_with_custom_struct() {
3179 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
3180 struct Score {
3181 value: i32,
3182 id: String,
3183 }
3184
3185 let merge = Merge { k: 3 };
3186
3187 let input = vec![
3189 vec![
3190 Score {
3191 value: 100,
3192 id: "a".to_string(),
3193 },
3194 Score {
3195 value: 80,
3196 id: "b".to_string(),
3197 },
3198 Score {
3199 value: 60,
3200 id: "c".to_string(),
3201 },
3202 ],
3203 vec![
3204 Score {
3205 value: 90,
3206 id: "d".to_string(),
3207 },
3208 Score {
3209 value: 70,
3210 id: "e".to_string(),
3211 },
3212 ],
3213 vec![
3214 Score {
3215 value: 95,
3216 id: "f".to_string(),
3217 },
3218 Score {
3219 value: 85,
3220 id: "g".to_string(),
3221 },
3222 ],
3223 ];
3224
3225 let result = merge.merge(input);
3226
3227 assert_eq!(result.len(), 3);
3228 assert_eq!(
3229 result[0],
3230 Score {
3231 value: 100,
3232 id: "a".to_string()
3233 }
3234 );
3235 assert_eq!(
3236 result[1],
3237 Score {
3238 value: 95,
3239 id: "f".to_string()
3240 }
3241 );
3242 assert_eq!(
3243 result[2],
3244 Score {
3245 value: 90,
3246 id: "d".to_string()
3247 }
3248 );
3249 }
3250
3251 #[test]
3252 fn test_merge_preserves_order() {
3253 use std::cmp::Reverse;
3254
3255 let merge = Merge { k: 10 };
3256
3257 let input = vec![
3260 vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
3261 vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
3262 vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
3263 ];
3264
3265 let result = merge.merge(input);
3266
3267 for i in 1..result.len() {
3270 assert!(
3271 result[i - 1] >= result[i],
3272 "Output should be in descending Reverse order"
3273 );
3274 assert!(
3275 result[i - 1].0 <= result[i].0,
3276 "Inner values should be in ascending order"
3277 );
3278 }
3279
3280 assert_eq!(
3282 result,
3283 vec![
3284 Reverse(1),
3285 Reverse(2),
3286 Reverse(3),
3287 Reverse(4),
3288 Reverse(5),
3289 Reverse(6),
3290 Reverse(7),
3291 Reverse(8),
3292 Reverse(9),
3293 Reverse(10)
3294 ]
3295 );
3296 }
3297
3298 #[test]
3299 fn test_merge_single_vector() {
3300 let merge = Merge { k: 3 };
3301
3302 let input = vec![vec![1000u64, 800, 600, 400, 200]];
3304
3305 let result = merge.merge(input);
3306
3307 assert_eq!(result, vec![1000, 800, 600]);
3308 }
3309
3310 #[test]
3311 fn test_merge_all_same_values() {
3312 let merge = Merge { k: 5 };
3313
3314 let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];
3316
3317 let result = merge.merge(input);
3318
3319 assert_eq!(result, vec![42]);
3321 }
3322
3323 #[test]
3324 fn test_merge_mixed_types_sizes() {
3325 let merge = Merge { k: 4 };
3327 let input = vec![
3328 vec![1000usize, 500, 100],
3329 vec![800, 300],
3330 vec![900, 600, 200],
3331 ];
3332 let result = merge.merge(input);
3333 assert_eq!(result, vec![1000, 900, 800, 600]);
3334
3335 let merge = Merge { k: 5 };
3337 let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
3338 let result = merge.merge(input);
3339 assert_eq!(result, vec![15, 10, 5, 0, -5]);
3340 }
3341}