chroma_types/execution/
operator.rs

1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4    cmp::Ordering,
5    collections::{BinaryHeap, HashSet},
6    fmt,
7    hash::Hash,
8    ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13    chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14    DocumentExpression, DocumentOperator, Metadata, MetadataComparison, MetadataExpression,
15    MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding, SetOperator, SparseVector,
16    Where,
17};
18
19use super::error::QueryConversionError;
20
/// Alias for the unit type.
/// NOTE(review): presumably the input of operators that take no upstream data — confirm.
pub type InitialInput = ();
22
/// The `Scan` operator pins the data used by all downstream operators
///
/// # Parameters
/// - `collection_and_segments`: The consistent snapshot of the collection and its segments
#[derive(Clone, Debug)]
pub struct Scan {
    pub collection_and_segments: CollectionAndSegments,
}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33    type Error = QueryConversionError;
34
35    fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36        Ok(Self {
37            collection_and_segments: CollectionAndSegments {
38                collection: value
39                    .collection
40                    .ok_or(QueryConversionError::field("collection"))?
41                    .try_into()?,
42                metadata_segment: value
43                    .metadata
44                    .ok_or(QueryConversionError::field("metadata segment"))?
45                    .try_into()?,
46                record_segment: value
47                    .record
48                    .ok_or(QueryConversionError::field("record segment"))?
49                    .try_into()?,
50                vector_segment: value
51                    .knn
52                    .ok_or(QueryConversionError::field("vector segment"))?
53                    .try_into()?,
54            },
55        })
56    }
57}
58
/// Error returned when a [`Scan`] cannot be converted to its protobuf form.
#[derive(Debug, Error)]
pub enum ScanToProtoError {
    /// The collection could not be converted; wraps the underlying conversion error.
    #[error("Could not convert collection to proto")]
    CollectionToProto(#[from] crate::CollectionToProtoError),
}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66    type Error = ScanToProtoError;
67
68    fn try_from(value: Scan) -> Result<Self, Self::Error> {
69        Ok(Self {
70            collection: Some(value.collection_and_segments.collection.try_into()?),
71            knn: Some(value.collection_and_segments.vector_segment.into()),
72            metadata: Some(value.collection_and_segments.metadata_segment.into()),
73            record: Some(value.collection_and_segments.record_segment.into()),
74        })
75    }
76}
77
/// Result of a count query.
#[derive(Clone, Debug)]
pub struct CountResult {
    /// The count value.
    pub count: u32,
    /// Number of log bytes pulled.
    /// NOTE(review): presumably while serving this query — confirm.
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Size in bytes of the `count` field alone; `pulled_log_bytes` is not included.
    pub fn size_bytes(&self) -> u64 {
        let Self { count, .. } = self;
        size_of_val(count) as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91    fn from(value: chroma_proto::CountResult) -> Self {
92        Self {
93            count: value.count,
94            pulled_log_bytes: value.pulled_log_bytes,
95        }
96    }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100    fn from(value: CountResult) -> Self {
101        Self {
102            count: value.count,
103            pulled_log_bytes: value.pulled_log_bytes,
104        }
105    }
106}
107
/// The `FetchLog` operator fetches logs from the log service
///
/// # Parameters
/// - `start_log_offset_id`: The offset id of the first log to read
/// - `maximum_fetch_count`: The maximum number of logs to fetch in total
///   (NOTE(review): `None` presumably means no cap — confirm)
/// - `collection_uuid`: The uuid of the collection where the fetched logs should belong
#[derive(Clone, Debug)]
pub struct FetchLog {
    pub collection_uuid: CollectionUuid,
    pub maximum_fetch_count: Option<u32>,
    pub start_log_offset_id: u32,
}
120
/// The `Filter` operator filters the collection with specified criteria
///
/// # Parameters
/// - `query_ids`: The user provided ids, which specify the domain of the filter if provided
/// - `where_clause`: The predicate on individual records
///
/// Both fields are optional; the derived `Default` leaves both as `None`.
#[derive(Clone, Debug, Default)]
pub struct Filter {
    pub query_ids: Option<Vec<String>>,
    pub where_clause: Option<Where>,
}
131
132impl Serialize for Filter {
133    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134    where
135        S: Serializer,
136    {
137        // For the search API, serialize directly as the where clause (or empty object if None)
138        // If query_ids are present, they should be combined with the where_clause as Key::ID.is_in([...])
139
140        match (&self.query_ids, &self.where_clause) {
141            (None, None) => {
142                // No filter at all - serialize empty object
143                let map = serializer.serialize_map(Some(0))?;
144                map.end()
145            }
146            (None, Some(where_clause)) => {
147                // Only where clause - serialize it directly
148                where_clause.serialize(serializer)
149            }
150            (Some(ids), None) => {
151                // Only query_ids - create Where clause: Key::ID.is_in(ids)
152                let id_where = Where::Metadata(MetadataExpression {
153                    key: "#id".to_string(),
154                    comparison: MetadataComparison::Set(
155                        SetOperator::In,
156                        MetadataSetValue::Str(ids.clone()),
157                    ),
158                });
159                id_where.serialize(serializer)
160            }
161            (Some(ids), Some(where_clause)) => {
162                // Both present - combine with AND: Key::ID.is_in(ids) & where_clause
163                let id_where = Where::Metadata(MetadataExpression {
164                    key: "#id".to_string(),
165                    comparison: MetadataComparison::Set(
166                        SetOperator::In,
167                        MetadataSetValue::Str(ids.clone()),
168                    ),
169                });
170                let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
171                combined.serialize(serializer)
172            }
173        }
174    }
175}
176
177impl<'de> Deserialize<'de> for Filter {
178    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
179    where
180        D: Deserializer<'de>,
181    {
182        // For the new search API, the entire JSON is the where clause
183        let where_json = Value::deserialize(deserializer)?;
184        let where_clause =
185            if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
186                None
187            } else {
188                Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
189            };
190
191        Ok(Filter {
192            query_ids: None, // Always None for new search API
193            where_clause,
194        })
195    }
196}
197
198impl TryFrom<chroma_proto::FilterOperator> for Filter {
199    type Error = QueryConversionError;
200
201    fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
202        let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
203        let where_document = value.where_document.map(TryInto::try_into).transpose()?;
204        let where_clause = match (where_metadata, where_document) {
205            (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
206            (Some(w), None) | (None, Some(w)) => Some(w),
207            _ => None,
208        };
209
210        Ok(Self {
211            query_ids: value.ids.map(|uids| uids.ids),
212            where_clause,
213        })
214    }
215}
216
217impl TryFrom<Filter> for chroma_proto::FilterOperator {
218    type Error = QueryConversionError;
219
220    fn try_from(value: Filter) -> Result<Self, Self::Error> {
221        Ok(Self {
222            ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
223            r#where: value.where_clause.map(TryInto::try_into).transpose()?,
224            where_document: None,
225        })
226    }
227}
228
/// The `Knn` operator searches for the nearest neighbours of the specified embedding.
/// This is intended to be used by the executor
///
/// # Parameters
/// - `embedding`: The target embedding to search around
/// - `fetch`: The number of records to fetch around the target
#[derive(Clone, Debug)]
pub struct Knn {
    pub embedding: Vec<f32>,
    pub fetch: u32,
}
239
240impl From<KnnBatch> for Vec<Knn> {
241    fn from(value: KnnBatch) -> Self {
242        value
243            .embeddings
244            .into_iter()
245            .map(|embedding| Knn {
246                embedding,
247                fetch: value.fetch,
248            })
249            .collect()
250    }
251}
252
/// The `KnnBatch` operator searches for the nearest neighbours of the specified embeddings.
/// This is intended to be used by the frontend
///
/// # Parameters
/// - `embeddings`: The target embeddings to search around
/// - `fetch`: The number of records to fetch around each target
#[derive(Clone, Debug)]
pub struct KnnBatch {
    pub embeddings: Vec<Vec<f32>>,
    pub fetch: u32,
}
263
264impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
265    type Error = QueryConversionError;
266
267    fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
268        Ok(Self {
269            embeddings: value
270                .embeddings
271                .into_iter()
272                .map(|vec| vec.try_into().map(|(v, _)| v))
273                .collect::<Result<_, _>>()?,
274            fetch: value.fetch,
275        })
276    }
277}
278
279impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
280    type Error = QueryConversionError;
281
282    fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
283        Ok(Self {
284            embeddings: value
285                .embeddings
286                .into_iter()
287                .map(|embedding| {
288                    let dim = embedding.len();
289                    chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
290                })
291                .collect::<Result<_, _>>()?,
292            fetch: value.fetch,
293        })
294    }
295}
296
/// The `Limit` operator selects a range of records sorted by their offset ids
///
/// # Parameters
/// - `offset`: The number of records to skip in the beginning (defaults to `0` when absent)
/// - `limit`: The number of records to fetch after `offset` (defaults to `None` when absent;
///   NOTE(review): `None` presumably means no upper bound — confirm)
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Limit {
    #[serde(default)]
    pub offset: u32,
    #[serde(default)]
    pub limit: Option<u32>,
}
309
310impl From<chroma_proto::LimitOperator> for Limit {
311    fn from(value: chroma_proto::LimitOperator) -> Self {
312        Self {
313            offset: value.offset,
314            limit: value.limit,
315        }
316    }
317}
318
319impl From<Limit> for chroma_proto::LimitOperator {
320    fn from(value: Limit) -> Self {
321        Self {
322            offset: value.offset,
323            limit: value.limit,
324        }
325    }
326}
327
/// The `RecordMeasure` represents a measure of an embedding (identified by `offset_id`)
/// with respect to the query embedding.
///
/// Equality and ordering deliberately look at different fields: two values are
/// equal when their `offset_id`s match, while ordering compares `measure` only.
#[derive(Clone, Debug)]
pub struct RecordMeasure {
    pub offset_id: u32,
    pub measure: f32,
}

/// Two measures are equal when they refer to the same `offset_id`; `measure` is ignored.
impl PartialEq for RecordMeasure {
    fn eq(&self, other: &Self) -> bool {
        self.offset_id == other.offset_id
    }
}

impl Eq for RecordMeasure {}

/// Ordering uses the total order of `measure` (`f32::total_cmp`); `offset_id` is ignored.
impl Ord for RecordMeasure {
    fn cmp(&self, other: &Self) -> Ordering {
        f32::total_cmp(&self.measure, &other.measure)
    }
}

/// Always `Some`, delegating to the total order above.
impl PartialOrd for RecordMeasure {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
354
/// Output of a knn search: a list of [`RecordMeasure`]s.
#[derive(Debug, Default)]
pub struct KnnOutput {
    /// One entry per matched record.
    /// NOTE(review): ordering of this list is not established here — confirm with the producer.
    pub distances: Vec<RecordMeasure>,
}
359
/// The `Merge` operator selects the top records from the batch vectors of records
/// which are all sorted in descending order. If the same record occurs multiple times
/// only one copy will remain in the final result.
///
/// # Parameters
/// - `k`: The total number of records to take after merge
///
/// # Usage
/// It can be used to merge the query results from different operators
#[derive(Clone, Debug)]
pub struct Merge {
    pub k: u32,
}

impl Merge {
    /// Performs a k-way merge of `input` and returns at most `k` records in
    /// descending order.
    ///
    /// Every batch must already be sorted in descending order. A record is
    /// dropped when it is equal (`==`) to the record emitted immediately before
    /// it, so only duplicates that surface consecutively are removed.
    ///
    /// The output buffer is sized by the smaller of `k` and the number of input
    /// records, so a very large `k` does not trigger a huge up-front allocation.
    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = self.k as usize;
        // Upper bound on how many records can ever be emitted.
        let available = input
            .iter()
            .fold(0usize, |acc, batch| acc.saturating_add(batch.len()));

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the head of every non-empty batch; the batch index
        // travels with the record so the batch can be advanced when it is popped.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut fusion = Vec::with_capacity(limit.min(available));
        while let Some((m, idx)) = max_heap.pop() {
            if fusion.len() >= limit {
                break;
            }
            // Refill from the batch the popped record came from.
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            // Skip a record equal to the one just emitted.
            if fusion.last().is_some_and(|tail| tail == &m) {
                continue;
            }
            fusion.push(m);
        }
        fusion
    }
}
400
/// The `Projection` operator retrieves record content by offset ids
///
/// # Parameters
/// - `document`: Whether to retrieve document
/// - `embedding`: Whether to retrieve embedding
/// - `metadata`: Whether to retrieve metadata
///
/// The derived `Default` sets all three flags to `false`.
#[derive(Clone, Debug, Default)]
pub struct Projection {
    pub document: bool,
    pub embedding: bool,
    pub metadata: bool,
}
413
414impl From<chroma_proto::ProjectionOperator> for Projection {
415    fn from(value: chroma_proto::ProjectionOperator) -> Self {
416        Self {
417            document: value.document,
418            embedding: value.embedding,
419            metadata: value.metadata,
420        }
421    }
422}
423
424impl From<Projection> for chroma_proto::ProjectionOperator {
425    fn from(value: Projection) -> Self {
426        Self {
427            document: value.document,
428            embedding: value.embedding,
429            metadata: value.metadata,
430        }
431    }
432}
433
/// A record's content: its id plus optional document, embedding and metadata.
/// NOTE(review): the optional parts are presumably `None` when the matching
/// `Projection` flag is off — confirm against the projection operator.
#[derive(Clone, Debug, PartialEq)]
pub struct ProjectionRecord {
    pub id: String,
    pub document: Option<String>,
    pub embedding: Option<Vec<f32>>,
    pub metadata: Option<Metadata>,
}
441
442impl ProjectionRecord {
443    pub fn size_bytes(&self) -> u64 {
444        (self.id.len()
445            + self
446                .document
447                .as_ref()
448                .map(|doc| doc.len())
449                .unwrap_or_default()
450            + self
451                .embedding
452                .as_ref()
453                .map(|emb| size_of_val(&emb[..]))
454                .unwrap_or_default()
455            + self
456                .metadata
457                .as_ref()
458                .map(logical_size_of_metadata)
459                .unwrap_or_default()) as u64
460    }
461}
462
// Marker impl on top of the derived `PartialEq`.
// NOTE(review): `embedding` holds `f32`s, so a NaN component makes a record
// compare unequal to itself — confirm this is acceptable for `Eq` users.
impl Eq for ProjectionRecord {}
464
465impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
466    type Error = QueryConversionError;
467
468    fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
469        Ok(Self {
470            id: value.id,
471            document: value.document,
472            embedding: value
473                .embedding
474                .map(|vec| vec.try_into().map(|(v, _)| v))
475                .transpose()?,
476            metadata: value.metadata.map(TryInto::try_into).transpose()?,
477        })
478    }
479}
480
481impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
482    type Error = QueryConversionError;
483
484    fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
485        Ok(Self {
486            id: value.id,
487            document: value.document,
488            embedding: value
489                .embedding
490                .map(|embedding| {
491                    let embedding_dimension = embedding.len();
492                    chroma_proto::Vector::try_from((
493                        embedding,
494                        ScalarEncoding::FLOAT32,
495                        embedding_dimension,
496                    ))
497                })
498                .transpose()?,
499            metadata: value.metadata.map(|metadata| metadata.into()),
500        })
501    }
502}
503
/// Output of the projection step: the list of projected records.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProjectionOutput {
    pub records: Vec<ProjectionRecord>,
}
508
/// Result of a get query: the projected records plus a pulled-log byte count.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GetResult {
    /// Number of log bytes pulled.
    /// NOTE(review): presumably while serving this query — confirm.
    pub pulled_log_bytes: u64,
    pub result: ProjectionOutput,
}
514
515impl GetResult {
516    pub fn size_bytes(&self) -> u64 {
517        self.result
518            .records
519            .iter()
520            .map(ProjectionRecord::size_bytes)
521            .sum()
522    }
523}
524
525impl TryFrom<chroma_proto::GetResult> for GetResult {
526    type Error = QueryConversionError;
527
528    fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
529        Ok(Self {
530            pulled_log_bytes: value.pulled_log_bytes,
531            result: ProjectionOutput {
532                records: value
533                    .records
534                    .into_iter()
535                    .map(TryInto::try_into)
536                    .collect::<Result<_, _>>()?,
537            },
538        })
539    }
540}
541
542impl TryFrom<GetResult> for chroma_proto::GetResult {
543    type Error = QueryConversionError;
544
545    fn try_from(value: GetResult) -> Result<Self, Self::Error> {
546        Ok(Self {
547            pulled_log_bytes: value.pulled_log_bytes,
548            records: value
549                .result
550                .records
551                .into_iter()
552                .map(TryInto::try_into)
553                .collect::<Result<_, _>>()?,
554        })
555    }
556}
557
/// The `KnnProjection` operator retrieves record content by offset ids.
/// It is based on the `Projection` operator, and it attaches the distance
/// between each record and the target embedding to the record content
///
/// # Parameters
/// - `projection`: The parameters of the `Projection` operator
/// - `distance`: Whether to attach distance information
#[derive(Clone, Debug)]
pub struct KnnProjection {
    pub projection: Projection,
    pub distance: bool,
}
570
571impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
572    type Error = QueryConversionError;
573
574    fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
575        Ok(Self {
576            projection: value
577                .projection
578                .ok_or(QueryConversionError::field("projection"))?
579                .into(),
580            distance: value.distance,
581        })
582    }
583}
584
585impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
586    fn from(value: KnnProjection) -> Self {
587        Self {
588            projection: Some(value.projection.into()),
589            distance: value.distance,
590        }
591    }
592}
593
/// A projected record together with an optional distance.
/// NOTE(review): `distance` is presumably `None` when `KnnProjection::distance`
/// is `false` — confirm against the knn projection operator.
#[derive(Clone, Debug)]
pub struct KnnProjectionRecord {
    pub record: ProjectionRecord,
    pub distance: Option<f32>,
}
599
600impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
601    type Error = QueryConversionError;
602
603    fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
604        Ok(Self {
605            record: value
606                .record
607                .ok_or(QueryConversionError::field("record"))?
608                .try_into()?,
609            distance: value.distance,
610        })
611    }
612}
613
614impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
615    type Error = QueryConversionError;
616
617    fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
618        Ok(Self {
619            record: Some(value.record.try_into()?),
620            distance: value.distance,
621        })
622    }
623}
624
/// Output of a knn projection: the projected records with their optional distances.
/// NOTE(review): presumably holds the results for one query embedding — confirm.
#[derive(Clone, Debug, Default)]
pub struct KnnProjectionOutput {
    pub records: Vec<KnnProjectionRecord>,
}
629
630impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
631    type Error = QueryConversionError;
632
633    fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
634        Ok(Self {
635            records: value
636                .records
637                .into_iter()
638                .map(TryInto::try_into)
639                .collect::<Result<_, _>>()?,
640        })
641    }
642}
643
644impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
645    type Error = QueryConversionError;
646
647    fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
648        Ok(Self {
649            records: value
650                .records
651                .into_iter()
652                .map(TryInto::try_into)
653                .collect::<Result<_, _>>()?,
654        })
655    }
656}
657
/// Result of a batched knn query: a list of [`KnnProjectionOutput`]s plus a
/// pulled-log byte count.
#[derive(Clone, Debug, Default)]
pub struct KnnBatchResult {
    /// Number of log bytes pulled.
    /// NOTE(review): presumably while serving this query — confirm.
    pub pulled_log_bytes: u64,
    /// NOTE(review): presumably one output per query embedding — confirm.
    pub results: Vec<KnnProjectionOutput>,
}
663
664impl KnnBatchResult {
665    pub fn size_bytes(&self) -> u64 {
666        self.results
667            .iter()
668            .flat_map(|res| {
669                res.records
670                    .iter()
671                    .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
672            })
673            .sum()
674    }
675}
676
677impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
678    type Error = QueryConversionError;
679
680    fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
681        Ok(Self {
682            pulled_log_bytes: value.pulled_log_bytes,
683            results: value
684                .results
685                .into_iter()
686                .map(TryInto::try_into)
687                .collect::<Result<_, _>>()?,
688        })
689    }
690}
691
692impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
693    type Error = QueryConversionError;
694
695    fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
696        Ok(Self {
697            pulled_log_bytes: value.pulled_log_bytes,
698            results: value
699                .results
700                .into_iter()
701                .map(TryInto::try_into)
702                .collect::<Result<_, _>>()?,
703        })
704    }
705}
706
/// A query vector, either dense or sparse.
///
/// Serialized untagged: the payload is written without a variant name, and on
/// deserialization the variants are tried in declaration order (`Dense` first).
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum QueryVector {
    /// Dense vector of `f32` components.
    Dense(Vec<f32>),
    /// Sparse vector.
    Sparse(SparseVector),
}
713
714impl TryFrom<chroma_proto::QueryVector> for QueryVector {
715    type Error = QueryConversionError;
716
717    fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
718        let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
719        match vector {
720            chroma_proto::query_vector::Vector::Dense(dense) => {
721                Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
722            }
723            chroma_proto::query_vector::Vector::Sparse(sparse) => {
724                Ok(QueryVector::Sparse(sparse.into()))
725            }
726        }
727    }
728}
729
730impl TryFrom<QueryVector> for chroma_proto::QueryVector {
731    type Error = QueryConversionError;
732
733    fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
734        match value {
735            QueryVector::Dense(vec) => {
736                let dim = vec.len();
737                Ok(chroma_proto::QueryVector {
738                    vector: Some(chroma_proto::query_vector::Vector::Dense(
739                        chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
740                    )),
741                })
742            }
743            QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
744                vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
745            }),
746        }
747    }
748}
749
750impl From<Vec<f32>> for QueryVector {
751    fn from(vec: Vec<f32>) -> Self {
752        QueryVector::Dense(vec)
753    }
754}
755
756impl From<SparseVector> for QueryVector {
757    fn from(sparse: SparseVector) -> Self {
758        QueryVector::Sparse(sparse)
759    }
760}
761
/// A single knn query extracted from a rank expression (see `RankExpr::knn_queries`).
#[derive(Clone, Debug, PartialEq)]
pub struct KnnQuery {
    /// The query vector to search with.
    pub query: QueryVector,
    /// NOTE(review): presumably the field that is searched — confirm.
    pub key: Key,
    /// NOTE(review): presumably the number of neighbours to fetch — confirm.
    pub limit: u32,
}
768
/// Optional ranking expression.
///
/// Serialized transparently, i.e. exactly as the inner `Option<RankExpr>`.
/// The derived `Default` has no expression.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Rank {
    pub expr: Option<RankExpr>,
}
774
775impl Rank {
776    pub fn knn_queries(&self) -> Vec<KnnQuery> {
777        self.expr
778            .as_ref()
779            .map(RankExpr::knn_queries)
780            .unwrap_or_default()
781    }
782}
783
784impl TryFrom<chroma_proto::RankOperator> for Rank {
785    type Error = QueryConversionError;
786
787    fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
788        Ok(Rank {
789            expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
790        })
791    }
792}
793
794impl TryFrom<Rank> for chroma_proto::RankOperator {
795    type Error = QueryConversionError;
796
797    fn try_from(rank: Rank) -> Result<Self, Self::Error> {
798        Ok(chroma_proto::RankOperator {
799            expr: rank.expr.map(TryInto::try_into).transpose()?,
800        })
801    }
802}
803
/// A ranking expression tree.
///
/// Each variant is serialized under the `$`-prefixed name given by its
/// `serde(rename)` attribute.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RankExpr {
    /// Absolute value of the inner expression (`$abs`).
    #[serde(rename = "$abs")]
    Absolute(Box<RankExpr>),
    /// Division of `left` by `right` (`$div`).
    #[serde(rename = "$div")]
    Division {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// Exponential of the inner expression (`$exp`).
    #[serde(rename = "$exp")]
    Exponentiation(Box<RankExpr>),
    /// A knn search leaf (`$knn`).
    ///
    /// When absent from the input, `key` falls back to `RankExpr::default_knn_key`,
    /// `limit` to `RankExpr::default_knn_limit`, `default` to `None` and
    /// `return_rank` to `false`.
    /// NOTE(review): the exact meaning of `default` and `return_rank` is not
    /// visible here — confirm with the rank evaluator.
    #[serde(rename = "$knn")]
    Knn {
        query: QueryVector,
        #[serde(default = "RankExpr::default_knn_key")]
        key: Key,
        #[serde(default = "RankExpr::default_knn_limit")]
        limit: u32,
        #[serde(default)]
        default: Option<f32>,
        #[serde(default)]
        return_rank: bool,
    },
    /// Natural logarithm of the inner expression (`$log`).
    #[serde(rename = "$log")]
    Logarithm(Box<RankExpr>),
    /// Maximum over the operands (`$max`).
    #[serde(rename = "$max")]
    Maximum(Vec<RankExpr>),
    /// Minimum over the operands (`$min`).
    #[serde(rename = "$min")]
    Minimum(Vec<RankExpr>),
    /// Product of the operands (`$mul`).
    #[serde(rename = "$mul")]
    Multiplication(Vec<RankExpr>),
    /// Subtraction of `right` from `left` (`$sub`).
    #[serde(rename = "$sub")]
    Subtraction {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// Sum of the operands (`$sum`).
    #[serde(rename = "$sum")]
    Summation(Vec<RankExpr>),
    /// Constant value (`$val`).
    #[serde(rename = "$val")]
    Value(f32),
}
845
846impl RankExpr {
847    pub fn default_knn_key() -> Key {
848        Key::Embedding
849    }
850
851    pub fn default_knn_limit() -> u32 {
852        16
853    }
854
855    pub fn knn_queries(&self) -> Vec<KnnQuery> {
856        match self {
857            RankExpr::Absolute(expr)
858            | RankExpr::Exponentiation(expr)
859            | RankExpr::Logarithm(expr) => expr.knn_queries(),
860            RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
861                .knn_queries()
862                .into_iter()
863                .chain(right.knn_queries())
864                .collect(),
865            RankExpr::Maximum(exprs)
866            | RankExpr::Minimum(exprs)
867            | RankExpr::Multiplication(exprs)
868            | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
869            RankExpr::Value(_) => Vec::new(),
870            RankExpr::Knn {
871                query,
872                key,
873                limit,
874                default: _,
875                return_rank: _,
876            } => vec![KnnQuery {
877                query: query.clone(),
878                key: key.clone(),
879                limit: *limit,
880            }],
881        }
882    }
883
884    /// Exponential: rank.exp()
885    pub fn exp(self) -> Self {
886        RankExpr::Exponentiation(Box::new(self))
887    }
888
889    /// Natural logarithm: rank.log()
890    pub fn log(self) -> Self {
891        RankExpr::Logarithm(Box::new(self))
892    }
893
894    /// Absolute value: rank.abs()
895    pub fn abs(self) -> Self {
896        RankExpr::Absolute(Box::new(self))
897    }
898
899    /// Maximum: rank.max(other)
900    pub fn max(self, other: impl Into<RankExpr>) -> Self {
901        let other = other.into();
902
903        match self {
904            RankExpr::Maximum(mut exprs) => match other {
905                RankExpr::Maximum(other_exprs) => {
906                    exprs.extend(other_exprs);
907                    RankExpr::Maximum(exprs)
908                }
909                _ => {
910                    exprs.push(other);
911                    RankExpr::Maximum(exprs)
912                }
913            },
914            _ => match other {
915                RankExpr::Maximum(mut exprs) => {
916                    exprs.insert(0, self);
917                    RankExpr::Maximum(exprs)
918                }
919                _ => RankExpr::Maximum(vec![self, other]),
920            },
921        }
922    }
923
924    /// Minimum: rank.min(other)
925    pub fn min(self, other: impl Into<RankExpr>) -> Self {
926        let other = other.into();
927
928        match self {
929            RankExpr::Minimum(mut exprs) => match other {
930                RankExpr::Minimum(other_exprs) => {
931                    exprs.extend(other_exprs);
932                    RankExpr::Minimum(exprs)
933                }
934                _ => {
935                    exprs.push(other);
936                    RankExpr::Minimum(exprs)
937                }
938            },
939            _ => match other {
940                RankExpr::Minimum(mut exprs) => {
941                    exprs.insert(0, self);
942                    RankExpr::Minimum(exprs)
943                }
944                _ => RankExpr::Minimum(vec![self, other]),
945            },
946        }
947    }
948}
949
950impl Add for RankExpr {
951    type Output = RankExpr;
952
953    fn add(self, rhs: Self) -> Self::Output {
954        match self {
955            RankExpr::Summation(mut exprs) => match rhs {
956                RankExpr::Summation(rhs_exprs) => {
957                    exprs.extend(rhs_exprs);
958                    RankExpr::Summation(exprs)
959                }
960                _ => {
961                    exprs.push(rhs);
962                    RankExpr::Summation(exprs)
963                }
964            },
965            _ => match rhs {
966                RankExpr::Summation(mut exprs) => {
967                    exprs.insert(0, self);
968                    RankExpr::Summation(exprs)
969                }
970                _ => RankExpr::Summation(vec![self, rhs]),
971            },
972        }
973    }
974}
975
976impl Add<f32> for RankExpr {
977    type Output = RankExpr;
978
979    fn add(self, rhs: f32) -> Self::Output {
980        self + RankExpr::Value(rhs)
981    }
982}
983
984impl Add<RankExpr> for f32 {
985    type Output = RankExpr;
986
987    fn add(self, rhs: RankExpr) -> Self::Output {
988        RankExpr::Value(self) + rhs
989    }
990}
991
992impl Sub for RankExpr {
993    type Output = RankExpr;
994
995    fn sub(self, rhs: Self) -> Self::Output {
996        RankExpr::Subtraction {
997            left: Box::new(self),
998            right: Box::new(rhs),
999        }
1000    }
1001}
1002
1003impl Sub<f32> for RankExpr {
1004    type Output = RankExpr;
1005
1006    fn sub(self, rhs: f32) -> Self::Output {
1007        self - RankExpr::Value(rhs)
1008    }
1009}
1010
1011impl Sub<RankExpr> for f32 {
1012    type Output = RankExpr;
1013
1014    fn sub(self, rhs: RankExpr) -> Self::Output {
1015        RankExpr::Value(self) - rhs
1016    }
1017}
1018
1019impl Mul for RankExpr {
1020    type Output = RankExpr;
1021
1022    fn mul(self, rhs: Self) -> Self::Output {
1023        match self {
1024            RankExpr::Multiplication(mut exprs) => match rhs {
1025                RankExpr::Multiplication(rhs_exprs) => {
1026                    exprs.extend(rhs_exprs);
1027                    RankExpr::Multiplication(exprs)
1028                }
1029                _ => {
1030                    exprs.push(rhs);
1031                    RankExpr::Multiplication(exprs)
1032                }
1033            },
1034            _ => match rhs {
1035                RankExpr::Multiplication(mut exprs) => {
1036                    exprs.insert(0, self);
1037                    RankExpr::Multiplication(exprs)
1038                }
1039                _ => RankExpr::Multiplication(vec![self, rhs]),
1040            },
1041        }
1042    }
1043}
1044
1045impl Mul<f32> for RankExpr {
1046    type Output = RankExpr;
1047
1048    fn mul(self, rhs: f32) -> Self::Output {
1049        self * RankExpr::Value(rhs)
1050    }
1051}
1052
1053impl Mul<RankExpr> for f32 {
1054    type Output = RankExpr;
1055
1056    fn mul(self, rhs: RankExpr) -> Self::Output {
1057        RankExpr::Value(self) * rhs
1058    }
1059}
1060
1061impl Div for RankExpr {
1062    type Output = RankExpr;
1063
1064    fn div(self, rhs: Self) -> Self::Output {
1065        RankExpr::Division {
1066            left: Box::new(self),
1067            right: Box::new(rhs),
1068        }
1069    }
1070}
1071
1072impl Div<f32> for RankExpr {
1073    type Output = RankExpr;
1074
1075    fn div(self, rhs: f32) -> Self::Output {
1076        self / RankExpr::Value(rhs)
1077    }
1078}
1079
1080impl Div<RankExpr> for f32 {
1081    type Output = RankExpr;
1082
1083    fn div(self, rhs: RankExpr) -> Self::Output {
1084        RankExpr::Value(self) / rhs
1085    }
1086}
1087
1088impl Neg for RankExpr {
1089    type Output = RankExpr;
1090
1091    fn neg(self) -> Self::Output {
1092        RankExpr::Value(-1.0) * self
1093    }
1094}
1095
1096impl From<f32> for RankExpr {
1097    fn from(v: f32) -> Self {
1098        RankExpr::Value(v)
1099    }
1100}
1101
1102impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1103    type Error = QueryConversionError;
1104
1105    fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1106        match proto_expr.rank {
1107            Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1108                Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1109            }
1110            Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1111                let left = div.left.ok_or(QueryConversionError::field("left"))?;
1112                let right = div.right.ok_or(QueryConversionError::field("right"))?;
1113                Ok(RankExpr::Division {
1114                    left: Box::new(RankExpr::try_from(*left)?),
1115                    right: Box::new(RankExpr::try_from(*right)?),
1116                })
1117            }
1118            Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1119                RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1120            ),
1121            Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1122                let query = knn
1123                    .query
1124                    .ok_or(QueryConversionError::field("query"))?
1125                    .try_into()?;
1126                Ok(RankExpr::Knn {
1127                    query,
1128                    key: Key::from(knn.key),
1129                    limit: knn.limit,
1130                    default: knn.default,
1131                    return_rank: knn.return_rank,
1132                })
1133            }
1134            Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1135                Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1136            }
1137            Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1138                let exprs = max
1139                    .exprs
1140                    .into_iter()
1141                    .map(RankExpr::try_from)
1142                    .collect::<Result<Vec<_>, _>>()?;
1143                Ok(RankExpr::Maximum(exprs))
1144            }
1145            Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1146                let exprs = min
1147                    .exprs
1148                    .into_iter()
1149                    .map(RankExpr::try_from)
1150                    .collect::<Result<Vec<_>, _>>()?;
1151                Ok(RankExpr::Minimum(exprs))
1152            }
1153            Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1154                let exprs = mul
1155                    .exprs
1156                    .into_iter()
1157                    .map(RankExpr::try_from)
1158                    .collect::<Result<Vec<_>, _>>()?;
1159                Ok(RankExpr::Multiplication(exprs))
1160            }
1161            Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1162                let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1163                let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1164                Ok(RankExpr::Subtraction {
1165                    left: Box::new(RankExpr::try_from(*left)?),
1166                    right: Box::new(RankExpr::try_from(*right)?),
1167                })
1168            }
1169            Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1170                let exprs = sum
1171                    .exprs
1172                    .into_iter()
1173                    .map(RankExpr::try_from)
1174                    .collect::<Result<Vec<_>, _>>()?;
1175                Ok(RankExpr::Summation(exprs))
1176            }
1177            Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1178            None => Err(QueryConversionError::field("rank")),
1179        }
1180    }
1181}
1182
1183impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1184    type Error = QueryConversionError;
1185
1186    fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1187        let proto_rank = match rank_expr {
1188            RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1189                chroma_proto::RankExpr::try_from(*expr)?,
1190            )),
1191            RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1192                Box::new(chroma_proto::rank_expr::RankPair {
1193                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1194                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1195                }),
1196            ),
1197            RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1198                Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1199            ),
1200            RankExpr::Knn {
1201                query,
1202                key,
1203                limit,
1204                default,
1205                return_rank,
1206            } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1207                query: Some(query.try_into()?),
1208                key: key.to_string(),
1209                limit,
1210                default,
1211                return_rank,
1212            }),
1213            RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1214                chroma_proto::RankExpr::try_from(*expr)?,
1215            )),
1216            RankExpr::Maximum(exprs) => {
1217                let proto_exprs = exprs
1218                    .into_iter()
1219                    .map(chroma_proto::RankExpr::try_from)
1220                    .collect::<Result<Vec<_>, _>>()?;
1221                chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1222                    exprs: proto_exprs,
1223                })
1224            }
1225            RankExpr::Minimum(exprs) => {
1226                let proto_exprs = exprs
1227                    .into_iter()
1228                    .map(chroma_proto::RankExpr::try_from)
1229                    .collect::<Result<Vec<_>, _>>()?;
1230                chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1231                    exprs: proto_exprs,
1232                })
1233            }
1234            RankExpr::Multiplication(exprs) => {
1235                let proto_exprs = exprs
1236                    .into_iter()
1237                    .map(chroma_proto::RankExpr::try_from)
1238                    .collect::<Result<Vec<_>, _>>()?;
1239                chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1240                    exprs: proto_exprs,
1241                })
1242            }
1243            RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1244                Box::new(chroma_proto::rank_expr::RankPair {
1245                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1246                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1247                }),
1248            ),
1249            RankExpr::Summation(exprs) => {
1250                let proto_exprs = exprs
1251                    .into_iter()
1252                    .map(chroma_proto::RankExpr::try_from)
1253                    .collect::<Result<Vec<_>, _>>()?;
1254                chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1255                    exprs: proto_exprs,
1256                })
1257            }
1258            RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1259        };
1260
1261        Ok(chroma_proto::RankExpr {
1262            rank: Some(proto_rank),
1263        })
1264    }
1265}
1266
/// Names a part of a record: one of four predefined keys or an arbitrary
/// metadata field.
///
/// In string form the predefined keys are written with a leading `#`
/// (`#document`, `#embedding`, `#metadata`, `#score`); a metadata field is
/// written as its bare name.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Key {
    // Predefined keys
    /// Written as `#document`.
    Document,
    /// Written as `#embedding`.
    Embedding,
    /// Written as `#metadata`.
    Metadata,
    /// Written as `#score`.
    Score,
    // User-defined key
    /// A metadata field, identified by its name.
    MetadataField(String),
}
1277
1278impl Serialize for Key {
1279    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1280    where
1281        S: serde::Serializer,
1282    {
1283        match self {
1284            Key::Document => serializer.serialize_str("#document"),
1285            Key::Embedding => serializer.serialize_str("#embedding"),
1286            Key::Metadata => serializer.serialize_str("#metadata"),
1287            Key::Score => serializer.serialize_str("#score"),
1288            Key::MetadataField(field) => serializer.serialize_str(field),
1289        }
1290    }
1291}
1292
1293impl<'de> Deserialize<'de> for Key {
1294    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1295    where
1296        D: Deserializer<'de>,
1297    {
1298        let s = String::deserialize(deserializer)?;
1299        Ok(Key::from(s))
1300    }
1301}
1302
1303impl fmt::Display for Key {
1304    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1305        match self {
1306            Key::Document => write!(f, "#document"),
1307            Key::Embedding => write!(f, "#embedding"),
1308            Key::Metadata => write!(f, "#metadata"),
1309            Key::Score => write!(f, "#score"),
1310            Key::MetadataField(field) => write!(f, "{}", field),
1311        }
1312    }
1313}
1314
1315impl From<&str> for Key {
1316    fn from(s: &str) -> Self {
1317        match s {
1318            "#document" => Key::Document,
1319            "#embedding" => Key::Embedding,
1320            "#metadata" => Key::Metadata,
1321            "#score" => Key::Score,
1322            // Any other string is treated as a metadata field key
1323            field => Key::MetadataField(field.to_string()),
1324        }
1325    }
1326}
1327
1328impl From<String> for Key {
1329    fn from(s: String) -> Self {
1330        Key::from(s.as_str())
1331    }
1332}
1333
1334impl Key {
1335    /// Create a Key for a metadata field
1336    pub fn field(name: impl Into<String>) -> Self {
1337        Key::MetadataField(name.into())
1338    }
1339
1340    /// Equality: Key::field("status").eq("active")
1341    pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1342        Where::Metadata(MetadataExpression {
1343            key: self.to_string(),
1344            comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1345        })
1346    }
1347
1348    /// Not equal: Key::field("status").ne("deleted")
1349    pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1350        Where::Metadata(MetadataExpression {
1351            key: self.to_string(),
1352            comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1353        })
1354    }
1355
1356    /// Greater than: Key::field("score").gt(0.5)
1357    pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1358        Where::Metadata(MetadataExpression {
1359            key: self.to_string(),
1360            comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1361        })
1362    }
1363
1364    /// Greater than or equal: Key::field("score").gte(0.5)
1365    pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1366        Where::Metadata(MetadataExpression {
1367            key: self.to_string(),
1368            comparison: MetadataComparison::Primitive(
1369                PrimitiveOperator::GreaterThanOrEqual,
1370                value.into(),
1371            ),
1372        })
1373    }
1374
1375    /// Less than: Key::field("score").lt(0.9)
1376    pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1377        Where::Metadata(MetadataExpression {
1378            key: self.to_string(),
1379            comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1380        })
1381    }
1382
1383    /// Less than or equal: Key::field("score").lte(0.9)
1384    pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1385        Where::Metadata(MetadataExpression {
1386            key: self.to_string(),
1387            comparison: MetadataComparison::Primitive(
1388                PrimitiveOperator::LessThanOrEqual,
1389                value.into(),
1390            ),
1391        })
1392    }
1393
1394    /// In set: Key::field("year").is_in(vec![2023, 2024, 2025])
1395    /// Also accepts arrays, slices, and any iterator
1396    pub fn is_in<I, T>(self, values: I) -> Where
1397    where
1398        I: IntoIterator<Item = T>,
1399        Vec<T>: Into<MetadataSetValue>,
1400    {
1401        let vec: Vec<T> = values.into_iter().collect();
1402        Where::Metadata(MetadataExpression {
1403            key: self.to_string(),
1404            comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1405        })
1406    }
1407
1408    /// Not in set: Key::field("status").not_in(vec!["deleted", "archived"])
1409    /// Also accepts arrays, slices, and any iterator
1410    pub fn not_in<I, T>(self, values: I) -> Where
1411    where
1412        I: IntoIterator<Item = T>,
1413        Vec<T>: Into<MetadataSetValue>,
1414    {
1415        let vec: Vec<T> = values.into_iter().collect();
1416        Where::Metadata(MetadataExpression {
1417            key: self.to_string(),
1418            comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
1419        })
1420    }
1421
1422    /// Contains text: Key::Document.contains("search term")
1423    pub fn contains<S: Into<String>>(self, text: S) -> Where {
1424        Where::Document(DocumentExpression {
1425            operator: DocumentOperator::Contains,
1426            pattern: text.into(),
1427        })
1428    }
1429
1430    /// Does not contain text: Key::Document.not_contains("exclude term")
1431    pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
1432        Where::Document(DocumentExpression {
1433            operator: DocumentOperator::NotContains,
1434            pattern: text.into(),
1435        })
1436    }
1437
1438    /// Regex match: Key::field("email").regex(r"^.*@example\.com$")
1439    pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
1440        Where::Document(DocumentExpression {
1441            operator: DocumentOperator::Regex,
1442            pattern: pattern.into(),
1443        })
1444    }
1445
1446    /// Negative regex match: Key::field("email").not_regex(r"^.*@spam\.com$")
1447    pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
1448        Where::Document(DocumentExpression {
1449            operator: DocumentOperator::NotRegex,
1450            pattern: pattern.into(),
1451        })
1452    }
1453}
1454
1455#[derive(Clone, Debug, Default, Deserialize, Serialize)]
1456pub struct Select {
1457    #[serde(default)]
1458    pub keys: HashSet<Key>,
1459}
1460
1461impl TryFrom<chroma_proto::SelectOperator> for Select {
1462    type Error = QueryConversionError;
1463
1464    fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
1465        let keys = value
1466            .keys
1467            .into_iter()
1468            .map(|key| {
1469                // Try to deserialize each string as a Key
1470                serde_json::from_value(serde_json::Value::String(key))
1471                    .map_err(|_| QueryConversionError::field("keys"))
1472            })
1473            .collect::<Result<HashSet<_>, _>>()?;
1474
1475        Ok(Self { keys })
1476    }
1477}
1478
1479impl TryFrom<Select> for chroma_proto::SelectOperator {
1480    type Error = QueryConversionError;
1481
1482    fn try_from(value: Select) -> Result<Self, Self::Error> {
1483        let keys = value
1484            .keys
1485            .into_iter()
1486            .map(|key| {
1487                // Serialize each Key back to string
1488                serde_json::to_value(&key)
1489                    .ok()
1490                    .and_then(|v| v.as_str().map(String::from))
1491                    .ok_or(QueryConversionError::field("keys"))
1492            })
1493            .collect::<Result<Vec<_>, _>>()?;
1494
1495        Ok(Self { keys })
1496    }
1497}
1498
1499#[derive(Clone, Debug, Deserialize, Serialize)]
1500#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1501pub struct SearchRecord {
1502    pub id: String,
1503    pub document: Option<String>,
1504    pub embedding: Option<Vec<f32>>,
1505    pub metadata: Option<Metadata>,
1506    pub score: Option<f32>,
1507}
1508
1509impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
1510    type Error = QueryConversionError;
1511
1512    fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
1513        Ok(Self {
1514            id: value.id,
1515            document: value.document,
1516            embedding: value
1517                .embedding
1518                .map(|vec| vec.try_into().map(|(v, _)| v))
1519                .transpose()?,
1520            metadata: value.metadata.map(TryInto::try_into).transpose()?,
1521            score: value.score,
1522        })
1523    }
1524}
1525
1526impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
1527    type Error = QueryConversionError;
1528
1529    fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
1530        Ok(Self {
1531            id: value.id,
1532            document: value.document,
1533            embedding: value
1534                .embedding
1535                .map(|embedding| {
1536                    let embedding_dimension = embedding.len();
1537                    chroma_proto::Vector::try_from((
1538                        embedding,
1539                        ScalarEncoding::FLOAT32,
1540                        embedding_dimension,
1541                    ))
1542                })
1543                .transpose()?,
1544            metadata: value.metadata.map(Into::into),
1545            score: value.score,
1546        })
1547    }
1548}
1549
1550#[derive(Clone, Debug, Default)]
1551pub struct SearchPayloadResult {
1552    pub records: Vec<SearchRecord>,
1553}
1554
1555impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
1556    type Error = QueryConversionError;
1557
1558    fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
1559        Ok(Self {
1560            records: value
1561                .records
1562                .into_iter()
1563                .map(TryInto::try_into)
1564                .collect::<Result<_, _>>()?,
1565        })
1566    }
1567}
1568
1569impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
1570    type Error = QueryConversionError;
1571
1572    fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
1573        Ok(Self {
1574            records: value
1575                .records
1576                .into_iter()
1577                .map(TryInto::try_into)
1578                .collect::<Result<Vec<_>, _>>()?,
1579        })
1580    }
1581}
1582
1583#[derive(Clone, Debug)]
1584pub struct SearchResult {
1585    pub results: Vec<SearchPayloadResult>,
1586    pub pulled_log_bytes: u64,
1587}
1588
1589impl SearchResult {
1590    pub fn size_bytes(&self) -> u64 {
1591        self.results
1592            .iter()
1593            .flat_map(|result| {
1594                result.records.iter().map(|record| {
1595                    (record.id.len()
1596                        + record
1597                            .document
1598                            .as_ref()
1599                            .map(|doc| doc.len())
1600                            .unwrap_or_default()
1601                        + record
1602                            .embedding
1603                            .as_ref()
1604                            .map(|emb| size_of_val(&emb[..]))
1605                            .unwrap_or_default()
1606                        + record
1607                            .metadata
1608                            .as_ref()
1609                            .map(logical_size_of_metadata)
1610                            .unwrap_or_default()
1611                        + record.score.as_ref().map(size_of_val).unwrap_or_default())
1612                        as u64
1613                })
1614            })
1615            .sum()
1616    }
1617}
1618
1619impl TryFrom<chroma_proto::SearchResult> for SearchResult {
1620    type Error = QueryConversionError;
1621
1622    fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
1623        Ok(Self {
1624            results: value
1625                .results
1626                .into_iter()
1627                .map(TryInto::try_into)
1628                .collect::<Result<_, _>>()?,
1629            pulled_log_bytes: value.pulled_log_bytes,
1630        })
1631    }
1632}
1633
1634impl TryFrom<SearchResult> for chroma_proto::SearchResult {
1635    type Error = QueryConversionError;
1636
1637    fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
1638        Ok(Self {
1639            results: value
1640                .results
1641                .into_iter()
1642                .map(TryInto::try_into)
1643                .collect::<Result<Vec<_>, _>>()?,
1644            pulled_log_bytes: value.pulled_log_bytes,
1645        })
1646    }
1647}
1648
1649/// Reciprocal Rank Fusion: combines multiple rank expressions
1650/// Formula: -sum(weight_i / (k + rank_i))
1651pub fn rrf(
1652    ranks: Vec<RankExpr>,
1653    k: Option<u32>,
1654    weights: Option<Vec<f32>>,
1655    normalize: bool,
1656) -> Result<RankExpr, QueryConversionError> {
1657    let k = k.unwrap_or(60);
1658
1659    if ranks.is_empty() {
1660        return Err(QueryConversionError::validation(
1661            "RRF requires at least one rank expression",
1662        ));
1663    }
1664
1665    let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
1666
1667    if weights.len() != ranks.len() {
1668        return Err(QueryConversionError::validation(format!(
1669            "RRF weights length ({}) must match ranks length ({})",
1670            weights.len(),
1671            ranks.len()
1672        )));
1673    }
1674
1675    let weights = if normalize {
1676        let sum: f32 = weights.iter().sum();
1677        if sum == 0.0 {
1678            return Err(QueryConversionError::validation(
1679                "RRF weights sum to zero, cannot normalize",
1680            ));
1681        }
1682        weights.into_iter().map(|w| w / sum).collect()
1683    } else {
1684        weights
1685    };
1686
1687    let terms: Vec<RankExpr> = weights
1688        .into_iter()
1689        .zip(ranks)
1690        .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
1691        .collect();
1692
1693    // Safe: ranks is validated as non-empty above, so terms cannot be empty.
1694    // Using unwrap_or_else as defensive programming to avoid panic.
1695    let sum = terms
1696        .into_iter()
1697        .reduce(|a, b| a + b)
1698        .unwrap_or(RankExpr::Value(0.0));
1699    Ok(-sum)
1700}
1701
1702#[cfg(test)]
1703mod tests {
1704    use super::*;
1705
1706    #[test]
1707    fn test_key_from_string() {
1708        // Test predefined keys
1709        assert_eq!(Key::from("#document"), Key::Document);
1710        assert_eq!(Key::from("#embedding"), Key::Embedding);
1711        assert_eq!(Key::from("#metadata"), Key::Metadata);
1712        assert_eq!(Key::from("#score"), Key::Score);
1713
1714        // Test metadata field keys
1715        assert_eq!(
1716            Key::from("custom_field"),
1717            Key::MetadataField("custom_field".to_string())
1718        );
1719        assert_eq!(
1720            Key::from("author"),
1721            Key::MetadataField("author".to_string())
1722        );
1723
1724        // Test String variant
1725        assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
1726        assert_eq!(
1727            Key::from("year".to_string()),
1728            Key::MetadataField("year".to_string())
1729        );
1730    }
1731
1732    #[test]
1733    fn test_query_vector_dense_proto_conversion() {
1734        let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
1735        let query_vector = QueryVector::Dense(dense_vec.clone());
1736
1737        // Convert to proto
1738        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
1739
1740        // Convert back
1741        let converted: QueryVector = proto.try_into().unwrap();
1742
1743        assert_eq!(converted, query_vector);
1744        if let QueryVector::Dense(v) = converted {
1745            assert_eq!(v, dense_vec);
1746        } else {
1747            panic!("Expected dense vector");
1748        }
1749    }
1750
1751    #[test]
1752    fn test_query_vector_sparse_proto_conversion() {
1753        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
1754        let query_vector = QueryVector::Sparse(sparse.clone());
1755
1756        // Convert to proto
1757        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
1758
1759        // Convert back
1760        let converted: QueryVector = proto.try_into().unwrap();
1761
1762        assert_eq!(converted, query_vector);
1763        if let QueryVector::Sparse(s) = converted {
1764            assert_eq!(s, sparse);
1765        } else {
1766            panic!("Expected sparse vector");
1767        }
1768    }
1769
1770    #[test]
1771    fn test_filter_json_deserialization() {
1772        // For the new search API, deserialization treats the entire JSON as a where clause
1773
1774        // Test 1: Simple direct metadata comparison
1775        let simple_where = r#"{"author": "John Doe"}"#;
1776        let filter: Filter = serde_json::from_str(simple_where).unwrap();
1777        assert_eq!(filter.query_ids, None);
1778        assert!(filter.where_clause.is_some());
1779
1780        // Test 2: ID filter using #id with $in operator
1781        let id_filter_json = serde_json::json!({
1782            "#id": {
1783                "$in": ["doc1", "doc2", "doc3"]
1784            }
1785        });
1786        let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
1787        assert_eq!(filter.query_ids, None);
1788        assert!(filter.where_clause.is_some());
1789
1790        // Test 3: Complex nested expression with AND, OR, and various operators
1791        let complex_json = serde_json::json!({
1792            "$and": [
1793                {
1794                    "#id": {
1795                        "$in": ["doc1", "doc2", "doc3"]
1796                    }
1797                },
1798                {
1799                    "$or": [
1800                        {
1801                            "author": {
1802                                "$eq": "John Doe"
1803                            }
1804                        },
1805                        {
1806                            "author": {
1807                                "$eq": "Jane Smith"
1808                            }
1809                        }
1810                    ]
1811                },
1812                {
1813                    "year": {
1814                        "$gte": 2020
1815                    }
1816                },
1817                {
1818                    "tags": {
1819                        "$contains": "machine-learning"
1820                    }
1821                }
1822            ]
1823        });
1824
1825        let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
1826        assert_eq!(filter.query_ids, None);
1827        assert!(filter.where_clause.is_some());
1828
1829        // Verify the structure
1830        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
1831            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
1832            assert_eq!(composite.children.len(), 4);
1833
1834            // Check that the second child is an OR
1835            if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
1836                assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
1837                assert_eq!(or_composite.children.len(), 2);
1838            } else {
1839                panic!("Expected OR composite in second child");
1840            }
1841        } else {
1842            panic!("Expected AND composite where clause");
1843        }
1844
1845        // Test 4: Mixed operators - $ne, $lt, $gt, $lte
1846        let mixed_operators_json = serde_json::json!({
1847            "$and": [
1848                {
1849                    "status": {
1850                        "$ne": "deleted"
1851                    }
1852                },
1853                {
1854                    "score": {
1855                        "$gt": 0.5
1856                    }
1857                },
1858                {
1859                    "score": {
1860                        "$lt": 0.9
1861                    }
1862                },
1863                {
1864                    "priority": {
1865                        "$lte": 10
1866                    }
1867                }
1868            ]
1869        });
1870
1871        let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
1872        assert_eq!(filter.query_ids, None);
1873        assert!(filter.where_clause.is_some());
1874
1875        // Test 5: Deeply nested expression
1876        let deeply_nested_json = serde_json::json!({
1877            "$or": [
1878                {
1879                    "$and": [
1880                        {
1881                            "#id": {
1882                                "$in": ["id1", "id2"]
1883                            }
1884                        },
1885                        {
1886                            "$or": [
1887                                {
1888                                    "category": "tech"
1889                                },
1890                                {
1891                                    "category": "science"
1892                                }
1893                            ]
1894                        }
1895                    ]
1896                },
1897                {
1898                    "$and": [
1899                        {
1900                            "author": "Admin"
1901                        },
1902                        {
1903                            "published": true
1904                        }
1905                    ]
1906                }
1907            ]
1908        });
1909
1910        let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
1911        assert_eq!(filter.query_ids, None);
1912        assert!(filter.where_clause.is_some());
1913
1914        // Verify it's an OR at the top level
1915        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
1916            assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
1917            assert_eq!(composite.children.len(), 2);
1918
1919            // Both children should be AND composites
1920            for child in &composite.children {
1921                if let crate::metadata::Where::Composite(and_composite) = child {
1922                    assert_eq!(
1923                        and_composite.operator,
1924                        crate::metadata::BooleanOperator::And
1925                    );
1926                } else {
1927                    panic!("Expected AND composite in OR children");
1928                }
1929            }
1930        } else {
1931            panic!("Expected OR composite at top level");
1932        }
1933
1934        // Test 6: Single ID filter (edge case)
1935        let single_id_json = serde_json::json!({
1936            "#id": {
1937                "$eq": "single-doc-id"
1938            }
1939        });
1940
1941        let filter: Filter = serde_json::from_value(single_id_json).unwrap();
1942        assert_eq!(filter.query_ids, None);
1943        assert!(filter.where_clause.is_some());
1944
1945        // Test 7: Empty object should create empty filter
1946        let empty_json = serde_json::json!({});
1947        let filter: Filter = serde_json::from_value(empty_json).unwrap();
1948        assert_eq!(filter.query_ids, None);
1949        // Empty object results in None where_clause
1950        assert_eq!(filter.where_clause, None);
1951
1952        // Test 8: Combining #id filter with $not_contains and numeric comparisons
1953        let advanced_json = serde_json::json!({
1954            "$and": [
1955                {
1956                    "#id": {
1957                        "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
1958                    }
1959                },
1960                {
1961                    "tags": {
1962                        "$not_contains": "deprecated"
1963                    }
1964                },
1965                {
1966                    "$or": [
1967                        {
1968                            "$and": [
1969                                {
1970                                    "confidence": {
1971                                        "$gte": 0.8
1972                                    }
1973                                },
1974                                {
1975                                    "verified": true
1976                                }
1977                            ]
1978                        },
1979                        {
1980                            "$and": [
1981                                {
1982                                    "confidence": {
1983                                        "$gte": 0.6
1984                                    }
1985                                },
1986                                {
1987                                    "confidence": {
1988                                        "$lt": 0.8
1989                                    }
1990                                },
1991                                {
1992                                    "reviews": {
1993                                        "$gte": 5
1994                                    }
1995                                }
1996                            ]
1997                        }
1998                    ]
1999                }
2000            ]
2001        });
2002
2003        let filter: Filter = serde_json::from_value(advanced_json).unwrap();
2004        assert_eq!(filter.query_ids, None);
2005        assert!(filter.where_clause.is_some());
2006
2007        // Verify top-level structure
2008        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2009            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2010            assert_eq!(composite.children.len(), 3);
2011        } else {
2012            panic!("Expected AND composite at top level");
2013        }
2014    }
2015
2016    #[test]
2017    fn test_limit_json_serialization() {
2018        let limit = Limit {
2019            offset: 10,
2020            limit: Some(20),
2021        };
2022
2023        let json = serde_json::to_string(&limit).unwrap();
2024        let deserialized: Limit = serde_json::from_str(&json).unwrap();
2025
2026        assert_eq!(deserialized.offset, limit.offset);
2027        assert_eq!(deserialized.limit, limit.limit);
2028    }
2029
2030    #[test]
2031    fn test_query_vector_json_serialization() {
2032        // Test dense vector
2033        let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
2034        let json = serde_json::to_string(&dense).unwrap();
2035        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
2036        assert_eq!(deserialized, dense);
2037
2038        // Test sparse vector
2039        let sparse = QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]));
2040        let json = serde_json::to_string(&sparse).unwrap();
2041        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
2042        assert_eq!(deserialized, sparse);
2043    }
2044
2045    #[test]
2046    fn test_select_key_json_serialization() {
2047        use std::collections::HashSet;
2048
2049        // Test predefined keys
2050        let doc_key = Key::Document;
2051        assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
2052
2053        let embed_key = Key::Embedding;
2054        assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
2055
2056        let meta_key = Key::Metadata;
2057        assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
2058
2059        let score_key = Key::Score;
2060        assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
2061
2062        // Test metadata key
2063        let custom_key = Key::MetadataField("custom_key".to_string());
2064        assert_eq!(
2065            serde_json::to_string(&custom_key).unwrap(),
2066            "\"custom_key\""
2067        );
2068
2069        // Test deserialization
2070        let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
2071        assert!(matches!(deserialized, Key::Document));
2072
2073        let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
2074        assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
2075
2076        // Test Select struct with multiple keys
2077        let mut keys = HashSet::new();
2078        keys.insert(Key::Document);
2079        keys.insert(Key::Embedding);
2080        keys.insert(Key::MetadataField("author".to_string()));
2081
2082        let select = Select { keys };
2083        let json = serde_json::to_string(&select).unwrap();
2084        let deserialized: Select = serde_json::from_str(&json).unwrap();
2085
2086        assert_eq!(deserialized.keys.len(), 3);
2087        assert!(deserialized.keys.contains(&Key::Document));
2088        assert!(deserialized.keys.contains(&Key::Embedding));
2089        assert!(deserialized
2090            .keys
2091            .contains(&Key::MetadataField("author".to_string())));
2092    }
2093
2094    #[test]
2095    fn test_merge_basic_integers() {
2096        use std::cmp::Reverse;
2097
2098        let merge = Merge { k: 5 };
2099
2100        // Input: sorted vectors of Reverse(u32) - ascending order of inner values
2101        let input = vec![
2102            vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
2103            vec![Reverse(2), Reverse(5), Reverse(8)],
2104            vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
2105        ];
2106
2107        let result = merge.merge(input);
2108
2109        // Should get top-5 smallest values (largest Reverse values)
2110        assert_eq!(result.len(), 5);
2111        assert_eq!(
2112            result,
2113            vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
2114        );
2115    }
2116
2117    #[test]
2118    fn test_merge_u32_descending() {
2119        let merge = Merge { k: 6 };
2120
2121        // Regular u32 in descending order (largest first)
2122        let input = vec![
2123            vec![100u32, 75, 50, 25],
2124            vec![90, 60, 30],
2125            vec![95, 85, 70, 40, 10],
2126        ];
2127
2128        let result = merge.merge(input);
2129
2130        // Should get top-6 largest u32 values
2131        assert_eq!(result.len(), 6);
2132        assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
2133    }
2134
2135    #[test]
2136    fn test_merge_i32_descending() {
2137        let merge = Merge { k: 5 };
2138
2139        // i32 values in descending order (including negatives)
2140        let input = vec![
2141            vec![50i32, 10, -10, -50],
2142            vec![30, 0, -30],
2143            vec![40, 20, -20, -40],
2144        ];
2145
2146        let result = merge.merge(input);
2147
2148        // Should get top-5 largest i32 values
2149        assert_eq!(result.len(), 5);
2150        assert_eq!(result, vec![50, 40, 30, 20, 10]);
2151    }
2152
2153    #[test]
2154    fn test_merge_with_duplicates() {
2155        let merge = Merge { k: 10 };
2156
2157        // Input with duplicates using regular u32 in descending order
2158        let input = vec![
2159            vec![100u32, 80, 80, 60, 40],
2160            vec![90, 80, 50, 30],
2161            vec![100, 70, 60, 20],
2162        ];
2163
2164        let result = merge.merge(input);
2165
2166        // Duplicates should be removed
2167        assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
2168    }
2169
2170    #[test]
2171    fn test_merge_empty_vectors() {
2172        let merge = Merge { k: 5 };
2173
2174        // All empty with u32
2175        let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
2176        let result = merge.merge(input);
2177        assert_eq!(result.len(), 0);
2178
2179        // Some empty, some with data (u64)
2180        let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
2181        let result = merge.merge(input);
2182        assert_eq!(result, vec![1000, 850, 750, 600, 500]);
2183
2184        // Single non-empty vector (i32)
2185        let input = vec![vec![], vec![100i32, 50, 25], vec![]];
2186        let result = merge.merge(input);
2187        assert_eq!(result, vec![100, 50, 25]);
2188    }
2189
2190    #[test]
2191    fn test_merge_k_boundary_conditions() {
2192        // k = 0 with u32
2193        let merge = Merge { k: 0 };
2194        let input = vec![vec![100u32, 50], vec![75, 25]];
2195        let result = merge.merge(input);
2196        assert_eq!(result.len(), 0);
2197
2198        // k = 1 with i64
2199        let merge = Merge { k: 1 };
2200        let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
2201        let result = merge.merge(input);
2202        assert_eq!(result, vec![1000]);
2203
2204        // k larger than total unique elements with u128
2205        let merge = Merge { k: 100 };
2206        let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
2207        let result = merge.merge(input);
2208        assert_eq!(result, vec![10000, 8000, 5000, 3000]);
2209    }
2210
2211    #[test]
2212    fn test_merge_with_strings() {
2213        let merge = Merge { k: 4 };
2214
2215        // Strings must be sorted in descending order (largest first) for the max heap merge
2216        let input = vec![
2217            vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
2218            vec!["elephant".to_string(), "banana".to_string()],
2219            vec!["fish".to_string(), "cat".to_string()],
2220        ];
2221
2222        let result = merge.merge(input);
2223
2224        // Should get top-4 lexicographically largest strings
2225        assert_eq!(
2226            result,
2227            vec![
2228                "zebra".to_string(),
2229                "fish".to_string(),
2230                "elephant".to_string(),
2231                "dog".to_string()
2232            ]
2233        );
2234    }
2235
2236    #[test]
2237    fn test_merge_with_custom_struct() {
2238        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
2239        struct Score {
2240            value: i32,
2241            id: String,
2242        }
2243
2244        let merge = Merge { k: 3 };
2245
2246        // Custom structs sorted by value (descending), then by id
2247        let input = vec![
2248            vec![
2249                Score {
2250                    value: 100,
2251                    id: "a".to_string(),
2252                },
2253                Score {
2254                    value: 80,
2255                    id: "b".to_string(),
2256                },
2257                Score {
2258                    value: 60,
2259                    id: "c".to_string(),
2260                },
2261            ],
2262            vec![
2263                Score {
2264                    value: 90,
2265                    id: "d".to_string(),
2266                },
2267                Score {
2268                    value: 70,
2269                    id: "e".to_string(),
2270                },
2271            ],
2272            vec![
2273                Score {
2274                    value: 95,
2275                    id: "f".to_string(),
2276                },
2277                Score {
2278                    value: 85,
2279                    id: "g".to_string(),
2280                },
2281            ],
2282        ];
2283
2284        let result = merge.merge(input);
2285
2286        assert_eq!(result.len(), 3);
2287        assert_eq!(
2288            result[0],
2289            Score {
2290                value: 100,
2291                id: "a".to_string()
2292            }
2293        );
2294        assert_eq!(
2295            result[1],
2296            Score {
2297                value: 95,
2298                id: "f".to_string()
2299            }
2300        );
2301        assert_eq!(
2302            result[2],
2303            Score {
2304                value: 90,
2305                id: "d".to_string()
2306            }
2307        );
2308    }
2309
2310    #[test]
2311    fn test_merge_preserves_order() {
2312        use std::cmp::Reverse;
2313
2314        let merge = Merge { k: 10 };
2315
2316        // For Reverse, smaller inner values are "larger" in ordering
2317        // So vectors should be sorted with smallest inner values first
2318        let input = vec![
2319            vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
2320            vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
2321            vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
2322        ];
2323
2324        let result = merge.merge(input);
2325
2326        // Verify output maintains order - should be sorted by Reverse ordering
2327        // which means ascending inner values
2328        for i in 1..result.len() {
2329            assert!(
2330                result[i - 1] >= result[i],
2331                "Output should be in descending Reverse order"
2332            );
2333            assert!(
2334                result[i - 1].0 <= result[i].0,
2335                "Inner values should be in ascending order"
2336            );
2337        }
2338
2339        // Check we got the right elements
2340        assert_eq!(
2341            result,
2342            vec![
2343                Reverse(1),
2344                Reverse(2),
2345                Reverse(3),
2346                Reverse(4),
2347                Reverse(5),
2348                Reverse(6),
2349                Reverse(7),
2350                Reverse(8),
2351                Reverse(9),
2352                Reverse(10)
2353            ]
2354        );
2355    }
2356
2357    #[test]
2358    fn test_merge_single_vector() {
2359        let merge = Merge { k: 3 };
2360
2361        // Single vector input with u64
2362        let input = vec![vec![1000u64, 800, 600, 400, 200]];
2363
2364        let result = merge.merge(input);
2365
2366        assert_eq!(result, vec![1000, 800, 600]);
2367    }
2368
2369    #[test]
2370    fn test_merge_all_same_values() {
2371        let merge = Merge { k: 5 };
2372
2373        // All vectors contain the same value (using i16)
2374        let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];
2375
2376        let result = merge.merge(input);
2377
2378        // Should deduplicate to single value
2379        assert_eq!(result, vec![42]);
2380    }
2381
2382    #[test]
2383    fn test_merge_mixed_types_sizes() {
2384        // Test with usize (common in real usage)
2385        let merge = Merge { k: 4 };
2386        let input = vec![
2387            vec![1000usize, 500, 100],
2388            vec![800, 300],
2389            vec![900, 600, 200],
2390        ];
2391        let result = merge.merge(input);
2392        assert_eq!(result, vec![1000, 900, 800, 600]);
2393
2394        // Test with negative integers (i32)
2395        let merge = Merge { k: 5 };
2396        let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
2397        let result = merge.merge(input);
2398        assert_eq!(result, vec![15, 10, 5, 0, -5]);
2399    }
2400}