Skip to main content

chroma_types/execution/
operator.rs

1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4    cmp::Ordering,
5    collections::{BinaryHeap, HashSet},
6    fmt,
7    hash::Hash,
8    ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13    chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14    ContainsOperator, DocumentExpression, DocumentOperator, Metadata, MetadataComparison,
15    MetadataExpression, MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding,
16    SetOperator, SparseVector, Where,
17};
18
19use super::error::QueryConversionError;
20
/// Alias for the unit type `()`.
///
/// NOTE(review): presumably the input of operators that start a pipeline and
/// therefore take no data — confirm against the operator implementations.
pub type InitialInput = ();
22
23/// The `Scan` opeartor pins the data used by all downstream operators
24///
25/// # Parameters
26/// - `collection_and_segments`: The consistent snapshot of collection
27#[derive(Clone, Debug)]
28pub struct Scan {
29    pub collection_and_segments: CollectionAndSegments,
30}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33    type Error = QueryConversionError;
34
35    fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36        Ok(Self {
37            collection_and_segments: CollectionAndSegments {
38                collection: value
39                    .collection
40                    .ok_or(QueryConversionError::field("collection"))?
41                    .try_into()?,
42                metadata_segment: value
43                    .metadata
44                    .ok_or(QueryConversionError::field("metadata segment"))?
45                    .try_into()?,
46                record_segment: value
47                    .record
48                    .ok_or(QueryConversionError::field("record segment"))?
49                    .try_into()?,
50                vector_segment: value
51                    .knn
52                    .ok_or(QueryConversionError::field("vector segment"))?
53                    .try_into()?,
54            },
55        })
56    }
57}
58
59#[derive(Debug, Error)]
60pub enum ScanToProtoError {
61    #[error("Could not convert collection to proto")]
62    CollectionToProto(#[from] crate::CollectionToProtoError),
63}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66    type Error = ScanToProtoError;
67
68    fn try_from(value: Scan) -> Result<Self, Self::Error> {
69        Ok(Self {
70            collection: Some(value.collection_and_segments.collection.try_into()?),
71            knn: Some(value.collection_and_segments.vector_segment.into()),
72            metadata: Some(value.collection_and_segments.metadata_segment.into()),
73            record: Some(value.collection_and_segments.record_segment.into()),
74        })
75    }
76}
77
/// Result of a count query.
///
/// # Fields
/// - `count`: The count value
/// - `pulled_log_bytes`: Byte counter carried alongside the count
#[derive(Debug, Clone)]
pub struct CountResult {
    pub count: u32,
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Returns the in-memory size of the `count` field in bytes.
    ///
    /// `pulled_log_bytes` is not part of the reported size.
    pub fn size_bytes(&self) -> u64 {
        let count_width = std::mem::size_of_val(&self.count);
        count_width as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91    fn from(value: chroma_proto::CountResult) -> Self {
92        Self {
93            count: value.count,
94            pulled_log_bytes: value.pulled_log_bytes,
95        }
96    }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100    fn from(value: CountResult) -> Self {
101        Self {
102            count: value.count,
103            pulled_log_bytes: value.pulled_log_bytes,
104        }
105    }
106}
107
108/// The `FetchLog` operator fetches logs from the log service
109///
110/// # Parameters
111/// - `start_log_offset_id`: The offset id of the first log to read
112/// - `maximum_fetch_count`: The maximum number of logs to fetch in total
113/// - `collection_uuid`: The uuid of the collection where the fetched logs should belong
114#[derive(Clone, Debug)]
115pub struct FetchLog {
116    pub collection_uuid: CollectionUuid,
117    pub maximum_fetch_count: Option<u32>,
118    pub start_log_offset_id: u32,
119}
120
121/// Filter the search results.
122///
123/// Combines document ID filtering with metadata and document content predicates.
124/// For the Search API, use `where_clause` with Key expressions.
125///
126/// # Fields
127///
128/// * `query_ids` - Optional list of document IDs to filter (legacy, prefer Where expressions)
129/// * `where_clause` - Predicate on document metadata, content, or IDs
130///
131/// # Examples
132///
133/// ## Simple metadata filter
134///
135/// ```
136/// use chroma_types::operator::{Filter, Key};
137///
138/// let filter = Filter {
139///     query_ids: None,
140///     where_clause: Some(Key::field("status").eq("published")),
141/// };
142/// ```
143///
144/// ## Combined filters
145///
146/// ```
147/// use chroma_types::operator::{Filter, Key};
148///
149/// let filter = Filter {
150///     query_ids: None,
151///     where_clause: Some(
152///         Key::field("status").eq("published")
153///             & Key::field("year").gte(2020)
154///             & Key::field("category").is_in(vec!["tech", "science"])
155///     ),
156/// };
157/// ```
158///
159/// ## Document content filter
160///
161/// ```
162/// use chroma_types::operator::{Filter, Key};
163///
164/// let filter = Filter {
165///     query_ids: None,
166///     where_clause: Some(Key::Document.contains("machine learning")),
167/// };
168/// ```
169#[derive(Clone, Debug, Default)]
170pub struct Filter {
171    pub query_ids: Option<Vec<String>>,
172    pub where_clause: Option<Where>,
173}
174
175impl Serialize for Filter {
176    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177    where
178        S: Serializer,
179    {
180        // For the search API, serialize directly as the where clause (or empty object if None)
181        // If query_ids are present, they should be combined with the where_clause as Key::ID.is_in([...])
182
183        match (&self.query_ids, &self.where_clause) {
184            (None, None) => {
185                // No filter at all - serialize empty object
186                let map = serializer.serialize_map(Some(0))?;
187                map.end()
188            }
189            (None, Some(where_clause)) => {
190                // Only where clause - serialize it directly
191                where_clause.serialize(serializer)
192            }
193            (Some(ids), None) => {
194                // Only query_ids - create Where clause: Key::ID.is_in(ids)
195                let id_where = Where::Metadata(MetadataExpression {
196                    key: "#id".to_string(),
197                    comparison: MetadataComparison::Set(
198                        SetOperator::In,
199                        MetadataSetValue::Str(ids.clone()),
200                    ),
201                });
202                id_where.serialize(serializer)
203            }
204            (Some(ids), Some(where_clause)) => {
205                // Both present - combine with AND: Key::ID.is_in(ids) & where_clause
206                let id_where = Where::Metadata(MetadataExpression {
207                    key: "#id".to_string(),
208                    comparison: MetadataComparison::Set(
209                        SetOperator::In,
210                        MetadataSetValue::Str(ids.clone()),
211                    ),
212                });
213                let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214                combined.serialize(serializer)
215            }
216        }
217    }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222    where
223        D: Deserializer<'de>,
224    {
225        // For the new search API, the entire JSON is the where clause
226        let where_json = Value::deserialize(deserializer)?;
227        let where_clause =
228            if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229                None
230            } else {
231                Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232            };
233
234        Ok(Filter {
235            query_ids: None, // Always None for new search API
236            where_clause,
237        })
238    }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242    type Error = QueryConversionError;
243
244    fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245        let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246        let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247        let where_clause = match (where_metadata, where_document) {
248            (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249            (Some(w), None) | (None, Some(w)) => Some(w),
250            _ => None,
251        };
252
253        Ok(Self {
254            query_ids: value.ids.map(|uids| uids.ids),
255            where_clause,
256        })
257    }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261    type Error = QueryConversionError;
262
263    fn try_from(value: Filter) -> Result<Self, Self::Error> {
264        Ok(Self {
265            ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266            r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267            where_document: None,
268        })
269    }
270}
271
/// The `Knn` operator searches for the nearest neighbours of the specified embedding.
/// It is intended to be used by the executor.
///
/// # Parameters
/// - `embedding`: The target embedding to search around
/// - `fetch`: The number of records to fetch around the target
#[derive(Debug, Clone)]
pub struct Knn {
    pub embedding: Vec<f32>,
    pub fetch: u32,
}
282
283impl From<KnnBatch> for Vec<Knn> {
284    fn from(value: KnnBatch) -> Self {
285        value
286            .embeddings
287            .into_iter()
288            .map(|embedding| Knn {
289                embedding,
290                fetch: value.fetch,
291            })
292            .collect()
293    }
294}
295
/// The `KnnBatch` operator searches for the nearest neighbours of each of the
/// specified embeddings. It is intended to be used by the frontend.
///
/// # Parameters
/// - `embeddings`: The target embeddings to search around
/// - `fetch`: The number of records to fetch around each target
#[derive(Debug, Clone)]
pub struct KnnBatch {
    pub embeddings: Vec<Vec<f32>>,
    pub fetch: u32,
}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308    type Error = QueryConversionError;
309
310    fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311        Ok(Self {
312            embeddings: value
313                .embeddings
314                .into_iter()
315                .map(|vec| vec.try_into().map(|(v, _)| v))
316                .collect::<Result<_, _>>()?,
317            fetch: value.fetch,
318        })
319    }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323    type Error = QueryConversionError;
324
325    fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326        Ok(Self {
327            embeddings: value
328                .embeddings
329                .into_iter()
330                .map(|embedding| {
331                    let dim = embedding.len();
332                    chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333                })
334                .collect::<Result<_, _>>()?,
335            fetch: value.fetch,
336        })
337    }
338}
339
340/// Pagination control for search results.
341///
342/// Controls how many results to return and how many to skip for pagination.
343///
344/// # Fields
345///
346/// * `offset` - Number of results to skip (default: 0)
347/// * `limit` - Maximum results to return (None = no limit)
348///
349/// # Examples
350///
351/// ```
352/// use chroma_types::operator::Limit;
353///
354/// // First page: results 0-9
355/// let limit = Limit {
356///     offset: 0,
357///     limit: Some(10),
358/// };
359///
360/// // Second page: results 10-19
361/// let limit = Limit {
362///     offset: 10,
363///     limit: Some(10),
364/// };
365///
366/// // No limit: all results
367/// let limit = Limit {
368///     offset: 0,
369///     limit: None,
370/// };
371/// ```
372#[derive(Clone, Debug, Default, Deserialize, Serialize)]
373pub struct Limit {
374    #[serde(default)]
375    pub offset: u32,
376    #[serde(default)]
377    pub limit: Option<u32>,
378}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381    fn from(value: chroma_proto::LimitOperator) -> Self {
382        Self {
383            offset: value.offset,
384            limit: value.limit,
385        }
386    }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390    fn from(value: Limit) -> Self {
391        Self {
392            offset: value.offset,
393            limit: value.limit,
394        }
395    }
396}
397
/// A `RecordMeasure` is the measure of an embedding (identified by `offset_id`)
/// with respect to the query embedding.
///
/// Equality and ordering deliberately look at different things: two values are
/// *equal* when they share an `offset_id`, while they are *ordered* by `measure`
/// first and `offset_id` second.
#[derive(Debug, Clone, Copy)]
pub struct RecordMeasure {
    pub offset_id: u32,
    pub measure: f32,
}

impl PartialEq for RecordMeasure {
    /// Two measures are equal when they refer to the same record; `measure` is ignored.
    fn eq(&self, other: &Self) -> bool {
        self.offset_id == other.offset_id
    }
}

impl Eq for RecordMeasure {}

impl Ord for RecordMeasure {
    /// Orders by `measure` using the total order on `f32`, breaking ties by `offset_id`.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.measure.total_cmp(&other.measure) {
            Ordering::Equal => self.offset_id.cmp(&other.offset_id),
            by_measure => by_measure,
        }
    }
}

impl PartialOrd for RecordMeasure {
    /// Always `Some`, delegating to the total order defined by [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
426
427#[derive(Debug, Default)]
428pub struct KnnOutput {
429    pub distances: Vec<RecordMeasure>,
430}
431
/// The `Merge` operator selects the top records from the batch vectors of records
/// which are all sorted in descending order. If the same record occurs multiple times
/// only one copy will remain in the final result.
///
/// # Parameters
/// - `k`: The total number of records to take after merge
///
/// # Usage
/// It can be used to merge the query results from different operators
#[derive(Clone, Debug)]
pub struct Merge {
    pub k: u32,
}

impl Merge {
    /// Merges the batches into a single vector of at most `k` records.
    ///
    /// Each batch is consumed front to back through a max-heap, so the output is in
    /// descending order when every batch is. A record is dropped when it compares
    /// equal (`==`) to the record emitted immediately before it.
    ///
    /// # Parameters
    /// - `input`: The batches to merge; empty batches are ignored
    ///
    /// # Returns
    /// At most `k` records; empty when `k` is zero or every batch is empty.
    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = usize::try_from(self.k).unwrap_or(usize::MAX);
        // The result can never be longer than the input, so never reserve more than
        // that: `k` is caller-supplied and may be far larger than the data.
        let available: usize = input.iter().map(Vec::len).sum();

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the head of every non-empty batch, tagged with its batch index.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut fusion = Vec::with_capacity(limit.min(available));
        while let Some((m, idx)) = max_heap.pop() {
            if fusion.len() >= limit {
                break;
            }
            // Refill from the batch that just yielded the popped record.
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            // Skip a repeat of the record emitted last.
            if fusion.last().is_some_and(|tail| tail == &m) {
                continue;
            }
            fusion.push(m);
        }
        fusion
    }
}
472
/// The `Projection` operator retrieves record content by offset ids.
///
/// Every flag defaults to `false`.
///
/// # Parameters
/// - `document`: Whether to retrieve the document
/// - `embedding`: Whether to retrieve the embedding
/// - `metadata`: Whether to retrieve the metadata
#[derive(Debug, Clone, Default)]
pub struct Projection {
    pub document: bool,
    pub embedding: bool,
    pub metadata: bool,
}
485
486impl From<chroma_proto::ProjectionOperator> for Projection {
487    fn from(value: chroma_proto::ProjectionOperator) -> Self {
488        Self {
489            document: value.document,
490            embedding: value.embedding,
491            metadata: value.metadata,
492        }
493    }
494}
495
496impl From<Projection> for chroma_proto::ProjectionOperator {
497    fn from(value: Projection) -> Self {
498        Self {
499            document: value.document,
500            embedding: value.embedding,
501            metadata: value.metadata,
502        }
503    }
504}
505
506#[derive(Clone, Debug, PartialEq)]
507pub struct ProjectionRecord {
508    pub id: String,
509    pub document: Option<String>,
510    pub embedding: Option<Vec<f32>>,
511    pub metadata: Option<Metadata>,
512}
513
514impl ProjectionRecord {
515    pub fn size_bytes(&self) -> u64 {
516        (self.id.len()
517            + self
518                .document
519                .as_ref()
520                .map(|doc| doc.len())
521                .unwrap_or_default()
522            + self
523                .embedding
524                .as_ref()
525                .map(|emb| size_of_val(&emb[..]))
526                .unwrap_or_default()
527            + self
528                .metadata
529                .as_ref()
530                .map(logical_size_of_metadata)
531                .unwrap_or_default()) as u64
532    }
533}
534
535impl Eq for ProjectionRecord {}
536
537impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
538    type Error = QueryConversionError;
539
540    fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
541        Ok(Self {
542            id: value.id,
543            document: value.document,
544            embedding: value
545                .embedding
546                .map(|vec| vec.try_into().map(|(v, _)| v))
547                .transpose()?,
548            metadata: value.metadata.map(TryInto::try_into).transpose()?,
549        })
550    }
551}
552
553impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
554    type Error = QueryConversionError;
555
556    fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
557        Ok(Self {
558            id: value.id,
559            document: value.document,
560            embedding: value
561                .embedding
562                .map(|embedding| {
563                    let embedding_dimension = embedding.len();
564                    chroma_proto::Vector::try_from((
565                        embedding,
566                        ScalarEncoding::FLOAT32,
567                        embedding_dimension,
568                    ))
569                })
570                .transpose()?,
571            metadata: value.metadata.map(|metadata| metadata.into()),
572        })
573    }
574}
575
576#[derive(Clone, Debug, Eq, PartialEq)]
577pub struct ProjectionOutput {
578    pub records: Vec<ProjectionRecord>,
579}
580
581#[derive(Clone, Debug, Eq, PartialEq)]
582pub struct GetResult {
583    pub pulled_log_bytes: u64,
584    pub result: ProjectionOutput,
585}
586
587impl GetResult {
588    pub fn size_bytes(&self) -> u64 {
589        self.result
590            .records
591            .iter()
592            .map(ProjectionRecord::size_bytes)
593            .sum()
594    }
595}
596
597impl TryFrom<chroma_proto::GetResult> for GetResult {
598    type Error = QueryConversionError;
599
600    fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
601        Ok(Self {
602            pulled_log_bytes: value.pulled_log_bytes,
603            result: ProjectionOutput {
604                records: value
605                    .records
606                    .into_iter()
607                    .map(TryInto::try_into)
608                    .collect::<Result<_, _>>()?,
609            },
610        })
611    }
612}
613
614impl TryFrom<GetResult> for chroma_proto::GetResult {
615    type Error = QueryConversionError;
616
617    fn try_from(value: GetResult) -> Result<Self, Self::Error> {
618        Ok(Self {
619            pulled_log_bytes: value.pulled_log_bytes,
620            records: value
621                .result
622                .records
623                .into_iter()
624                .map(TryInto::try_into)
625                .collect::<Result<_, _>>()?,
626        })
627    }
628}
629
630/// The `KnnProjection` operator retrieves record content by offset ids
631/// It is based on `ProjectionOperator`, and it attaches the distance
632/// of the records to the target embedding to the record content
633///
634/// # Parameters
635/// - `projection`: The parameters of the `ProjectionOperator`
636/// - `distance`: Whether to attach distance information
637#[derive(Clone, Debug)]
638pub struct KnnProjection {
639    pub projection: Projection,
640    pub distance: bool,
641}
642
643impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
644    type Error = QueryConversionError;
645
646    fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
647        Ok(Self {
648            projection: value
649                .projection
650                .ok_or(QueryConversionError::field("projection"))?
651                .into(),
652            distance: value.distance,
653        })
654    }
655}
656
657impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
658    fn from(value: KnnProjection) -> Self {
659        Self {
660            projection: Some(value.projection.into()),
661            distance: value.distance,
662        }
663    }
664}
665
666#[derive(Clone, Debug)]
667pub struct KnnProjectionRecord {
668    pub record: ProjectionRecord,
669    pub distance: Option<f32>,
670}
671
672impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
673    type Error = QueryConversionError;
674
675    fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
676        Ok(Self {
677            record: value
678                .record
679                .ok_or(QueryConversionError::field("record"))?
680                .try_into()?,
681            distance: value.distance,
682        })
683    }
684}
685
686impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
687    type Error = QueryConversionError;
688
689    fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
690        Ok(Self {
691            record: Some(value.record.try_into()?),
692            distance: value.distance,
693        })
694    }
695}
696
697#[derive(Clone, Debug, Default)]
698pub struct KnnProjectionOutput {
699    pub records: Vec<KnnProjectionRecord>,
700}
701
702impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
703    type Error = QueryConversionError;
704
705    fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
706        Ok(Self {
707            records: value
708                .records
709                .into_iter()
710                .map(TryInto::try_into)
711                .collect::<Result<_, _>>()?,
712        })
713    }
714}
715
716impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
717    type Error = QueryConversionError;
718
719    fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
720        Ok(Self {
721            records: value
722                .records
723                .into_iter()
724                .map(TryInto::try_into)
725                .collect::<Result<_, _>>()?,
726        })
727    }
728}
729
730#[derive(Clone, Debug, Default)]
731pub struct KnnBatchResult {
732    pub pulled_log_bytes: u64,
733    pub results: Vec<KnnProjectionOutput>,
734}
735
736impl KnnBatchResult {
737    pub fn size_bytes(&self) -> u64 {
738        self.results
739            .iter()
740            .flat_map(|res| {
741                res.records
742                    .iter()
743                    .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
744            })
745            .sum()
746    }
747}
748
749impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
750    type Error = QueryConversionError;
751
752    fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
753        Ok(Self {
754            pulled_log_bytes: value.pulled_log_bytes,
755            results: value
756                .results
757                .into_iter()
758                .map(TryInto::try_into)
759                .collect::<Result<_, _>>()?,
760        })
761    }
762}
763
764impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
765    type Error = QueryConversionError;
766
767    fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
768        Ok(Self {
769            pulled_log_bytes: value.pulled_log_bytes,
770            results: value
771                .results
772                .into_iter()
773                .map(TryInto::try_into)
774                .collect::<Result<_, _>>()?,
775        })
776    }
777}
778
779/// A query vector for KNN search.
780///
781/// Supports both dense and sparse vector formats.
782///
783/// # Variants
784///
785/// ## Dense
786///
787/// Standard dense embeddings as a vector of floats.
788///
789/// ```
790/// use chroma_types::operator::QueryVector;
791///
792/// let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3, 0.4]);
793/// ```
794///
795/// ## Sparse
796///
797/// Sparse vectors with explicit indices and values.
798///
799/// ```
800/// use chroma_types::operator::QueryVector;
801/// use chroma_types::SparseVector;
802///
803/// let sparse = QueryVector::Sparse(SparseVector::new(
804///     vec![0, 5, 10, 50],      // indices
805///     vec![0.5, 0.3, 0.8, 0.2], // values
806/// ).unwrap());
807/// ```
808///
809/// # Examples
810///
811/// ## Dense vector in KNN
812///
813/// ```
814/// use chroma_types::operator::{RankExpr, QueryVector, Key};
815///
816/// let rank = RankExpr::Knn {
817///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
818///     key: Key::Embedding,
819///     limit: 100,
820///     default: None,
821///     return_rank: false,
822/// };
823/// ```
824///
825/// ## Sparse vector in KNN
826///
827/// ```
828/// use chroma_types::operator::{RankExpr, QueryVector, Key};
829/// use chroma_types::SparseVector;
830///
831/// let rank = RankExpr::Knn {
832///     query: QueryVector::Sparse(SparseVector::new(
833///         vec![1, 5, 10],
834///         vec![0.5, 0.3, 0.8],
835///     ).unwrap()),
836///     key: Key::field("sparse_embedding"),
837///     limit: 100,
838///     default: None,
839///     return_rank: false,
840/// };
841/// ```
842#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
843#[serde(untagged)]
844pub enum QueryVector {
845    Dense(Vec<f32>),
846    Sparse(SparseVector),
847}
848
849impl TryFrom<chroma_proto::QueryVector> for QueryVector {
850    type Error = QueryConversionError;
851
852    fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
853        let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
854        match vector {
855            chroma_proto::query_vector::Vector::Dense(dense) => {
856                Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
857            }
858            chroma_proto::query_vector::Vector::Sparse(sparse) => {
859                Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| {
860                    QueryConversionError::validation("sparse vector length mismatch")
861                })?))
862            }
863        }
864    }
865}
866
867impl TryFrom<QueryVector> for chroma_proto::QueryVector {
868    type Error = QueryConversionError;
869
870    fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
871        match value {
872            QueryVector::Dense(vec) => {
873                let dim = vec.len();
874                Ok(chroma_proto::QueryVector {
875                    vector: Some(chroma_proto::query_vector::Vector::Dense(
876                        chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
877                    )),
878                })
879            }
880            QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
881                vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
882            }),
883        }
884    }
885}
886
887impl From<Vec<f32>> for QueryVector {
888    fn from(vec: Vec<f32>) -> Self {
889        QueryVector::Dense(vec)
890    }
891}
892
893impl From<SparseVector> for QueryVector {
894    fn from(sparse: SparseVector) -> Self {
895        QueryVector::Sparse(sparse)
896    }
897}
898
899#[derive(Clone, Debug, PartialEq)]
900pub struct KnnQuery {
901    pub query: QueryVector,
902    pub key: Key,
903    pub limit: u32,
904}
905
906/// Wrapper for ranking expressions in search queries.
907///
908/// Contains an optional ranking expression. When None, results are returned in
909/// natural storage order without scoring.
910///
911/// # Fields
912///
913/// * `expr` - The ranking expression (None = no ranking)
914///
915/// # Examples
916///
917/// ```
918/// use chroma_types::operator::{Rank, RankExpr, QueryVector, Key};
919///
920/// // With ranking
921/// let rank = Rank {
922///     expr: Some(RankExpr::Knn {
923///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
924///         key: Key::Embedding,
925///         limit: 100,
926///         default: None,
927///         return_rank: false,
928///     }),
929/// };
930///
931/// // No ranking (natural order)
932/// let rank = Rank {
933///     expr: None,
934/// };
935/// ```
936#[derive(Clone, Debug, Default, Deserialize, Serialize)]
937#[serde(transparent)]
938pub struct Rank {
939    pub expr: Option<RankExpr>,
940}
941
942impl Rank {
943    pub fn knn_queries(&self) -> Vec<KnnQuery> {
944        self.expr
945            .as_ref()
946            .map(RankExpr::knn_queries)
947            .unwrap_or_default()
948    }
949}
950
951impl TryFrom<chroma_proto::RankOperator> for Rank {
952    type Error = QueryConversionError;
953
954    fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
955        Ok(Rank {
956            expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
957        })
958    }
959}
960
961impl TryFrom<Rank> for chroma_proto::RankOperator {
962    type Error = QueryConversionError;
963
964    fn try_from(rank: Rank) -> Result<Self, Self::Error> {
965        Ok(chroma_proto::RankOperator {
966            expr: rank.expr.map(TryInto::try_into).transpose()?,
967        })
968    }
969}
970
971/// A ranking expression for scoring and ordering search results.
972///
973/// Ranking expressions determine which documents appear in results and their order.
974/// Lower scores indicate better matches (distance-based scoring).
975///
976/// # Variants
977///
978/// ## Knn - K-Nearest Neighbor Search
979///
980/// The primary ranking method for vector similarity search.
981///
982/// ```
983/// use chroma_types::operator::{RankExpr, QueryVector, Key};
984///
985/// let rank = RankExpr::Knn {
986///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
987///     key: Key::Embedding,
988///     limit: 100,        // Consider top 100 candidates
989///     default: None,     // No default score for missing documents
990///     return_rank: false, // Return distances, not rank positions
991/// };
992/// ```
993///
994/// ## Value - Constant
995///
996/// Represents a constant score.
997///
998/// ```
999/// use chroma_types::operator::RankExpr;
1000///
1001/// let rank = RankExpr::Value(0.5);
1002/// ```
1003///
1004/// ## Arithmetic Operations
1005///
1006/// Combine ranking expressions using standard operators (+, -, *, /).
1007///
1008/// ```
1009/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1010///
1011/// let knn1 = RankExpr::Knn {
1012///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1013///     key: Key::Embedding,
1014///     limit: 100,
1015///     default: None,
1016///     return_rank: false,
1017/// };
1018///
1019/// let knn2 = RankExpr::Knn {
1020///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
1021///     key: Key::field("other_embedding"),
1022///     limit: 100,
1023///     default: None,
1024///     return_rank: false,
1025/// };
1026///
1027/// // Weighted combination: 70% knn1 + 30% knn2
1028/// let combined = knn1 * 0.7 + knn2 * 0.3;
1029///
1030/// // Normalized
1031/// let normalized = combined / 2.0;
1032/// ```
1033///
1034/// ## Mathematical Functions
1035///
1036/// Apply mathematical transformations to scores.
1037///
1038/// ```
1039/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1040///
1041/// let knn = RankExpr::Knn {
1042///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1043///     key: Key::Embedding,
1044///     limit: 100,
1045///     default: None,
1046///     return_rank: false,
1047/// };
1048///
1049/// // Exponential - amplifies differences
1050/// let amplified = knn.clone().exp();
1051///
1052/// // Logarithm - compresses range (add constant to avoid log(0))
1053/// let compressed = (knn.clone() + 1.0).log();
1054///
1055/// // Absolute value
1056/// let absolute = knn.clone().abs();
1057///
1058/// // Min/Max - clamping
1059/// let clamped = knn.min(1.0).max(0.0);
1060/// ```
1061///
1062/// # Examples
1063///
1064/// ## Basic vector search
1065///
1066/// ```
1067/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1068///
1069/// let rank = RankExpr::Knn {
1070///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1071///     key: Key::Embedding,
1072///     limit: 100,
1073///     default: None,
1074///     return_rank: false,
1075/// };
1076/// ```
1077///
1078/// ## Hybrid search with weighted combination
1079///
1080/// ```
1081/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1082///
1083/// let dense = RankExpr::Knn {
1084///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1085///     key: Key::Embedding,
1086///     limit: 200,
1087///     default: None,
1088///     return_rank: false,
1089/// };
1090///
1091/// let sparse = RankExpr::Knn {
1092///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]), // Use sparse in practice
1093///     key: Key::field("sparse_embedding"),
1094///     limit: 200,
1095///     default: None,
1096///     return_rank: false,
1097/// };
1098///
1099/// // 70% semantic + 30% keyword
1100/// let hybrid = dense * 0.7 + sparse * 0.3;
1101/// ```
1102///
1103/// ## Reciprocal Rank Fusion (RRF)
1104///
1105/// Use the `rrf()` function for combining rankings with different score scales.
1106///
1107/// ```
1108/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
1109///
1110/// let dense = RankExpr::Knn {
1111///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1112///     key: Key::Embedding,
1113///     limit: 200,
1114///     default: None,
1115///     return_rank: true, // RRF requires rank positions
1116/// };
1117///
1118/// let sparse = RankExpr::Knn {
1119///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1120///     key: Key::field("sparse_embedding"),
1121///     limit: 200,
1122///     default: None,
1123///     return_rank: true, // RRF requires rank positions
1124/// };
1125///
1126/// let rrf_rank = rrf(
1127///     vec![dense, sparse],
1128///     Some(60),           // k parameter (smoothing)
1129///     Some(vec![0.7, 0.3]), // weights
1130///     false,              // normalize weights
1131/// ).unwrap();
1132/// ```
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RankExpr {
    /// Absolute value of the inner expression (serialized as `$abs`).
    #[serde(rename = "$abs")]
    Absolute(Box<RankExpr>),
    /// Quotient `left / right` (serialized as `$div`).
    #[serde(rename = "$div")]
    Division {
        /// Dividend.
        left: Box<RankExpr>,
        /// Divisor.
        right: Box<RankExpr>,
    },
    /// Exponential `e^rank` of the inner expression (serialized as `$exp`).
    #[serde(rename = "$exp")]
    Exponentiation(Box<RankExpr>),
    /// K-nearest-neighbor search (serialized as `$knn`).
    #[serde(rename = "$knn")]
    Knn {
        /// Vector to search with.
        query: QueryVector,
        /// Field to search against; [`Key::Embedding`] when omitted from the input.
        #[serde(default = "RankExpr::default_knn_key")]
        key: Key,
        /// Number of top candidates to consider; `16` when omitted from the input.
        #[serde(default = "RankExpr::default_knn_limit")]
        limit: u32,
        /// Score given to documents missing from the search; `None` means no default score.
        #[serde(default)]
        default: Option<f32>,
        /// When `true`, yields rank positions instead of distances; `false` when omitted.
        #[serde(default)]
        return_rank: bool,
    },
    /// Natural logarithm of the inner expression (serialized as `$log`).
    #[serde(rename = "$log")]
    Logarithm(Box<RankExpr>),
    /// Maximum over all operands (serialized as `$max`).
    #[serde(rename = "$max")]
    Maximum(Vec<RankExpr>),
    /// Minimum over all operands (serialized as `$min`).
    #[serde(rename = "$min")]
    Minimum(Vec<RankExpr>),
    /// Product of all operands (serialized as `$mul`); built by the `*` operator.
    #[serde(rename = "$mul")]
    Multiplication(Vec<RankExpr>),
    /// Difference `left - right` (serialized as `$sub`).
    #[serde(rename = "$sub")]
    Subtraction {
        /// Minuend.
        left: Box<RankExpr>,
        /// Subtrahend.
        right: Box<RankExpr>,
    },
    /// Sum of all operands (serialized as `$sum`); built by the `+` operator.
    #[serde(rename = "$sum")]
    Summation(Vec<RankExpr>),
    /// Constant score (serialized as `$val`).
    #[serde(rename = "$val")]
    Value(f32),
}
1174
1175impl RankExpr {
1176    pub fn default_knn_key() -> Key {
1177        Key::Embedding
1178    }
1179
1180    pub fn default_knn_limit() -> u32 {
1181        16
1182    }
1183
1184    pub fn knn_queries(&self) -> Vec<KnnQuery> {
1185        match self {
1186            RankExpr::Absolute(expr)
1187            | RankExpr::Exponentiation(expr)
1188            | RankExpr::Logarithm(expr) => expr.knn_queries(),
1189            RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1190                .knn_queries()
1191                .into_iter()
1192                .chain(right.knn_queries())
1193                .collect(),
1194            RankExpr::Maximum(exprs)
1195            | RankExpr::Minimum(exprs)
1196            | RankExpr::Multiplication(exprs)
1197            | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1198            RankExpr::Value(_) => Vec::new(),
1199            RankExpr::Knn {
1200                query,
1201                key,
1202                limit,
1203                default: _,
1204                return_rank: _,
1205            } => vec![KnnQuery {
1206                query: query.clone(),
1207                key: key.clone(),
1208                limit: *limit,
1209            }],
1210        }
1211    }
1212
1213    /// Applies exponential transformation: e^rank.
1214    ///
1215    /// Amplifies differences between scores.
1216    ///
1217    /// # Examples
1218    ///
1219    /// ```
1220    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1221    ///
1222    /// let knn = RankExpr::Knn {
1223    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1224    ///     key: Key::Embedding,
1225    ///     limit: 100,
1226    ///     default: None,
1227    ///     return_rank: false,
1228    /// };
1229    ///
1230    /// let amplified = knn.exp();
1231    /// ```
1232    pub fn exp(self) -> Self {
1233        RankExpr::Exponentiation(Box::new(self))
1234    }
1235
1236    /// Applies natural logarithm transformation: ln(rank).
1237    ///
1238    /// Compresses the score range. Add a constant to avoid log(0).
1239    ///
1240    /// # Examples
1241    ///
1242    /// ```
1243    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1244    ///
1245    /// let knn = RankExpr::Knn {
1246    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1247    ///     key: Key::Embedding,
1248    ///     limit: 100,
1249    ///     default: None,
1250    ///     return_rank: false,
1251    /// };
1252    ///
1253    /// // Add constant to avoid log(0)
1254    /// let compressed = (knn + 1.0).log();
1255    /// ```
1256    pub fn log(self) -> Self {
1257        RankExpr::Logarithm(Box::new(self))
1258    }
1259
1260    /// Takes absolute value of the ranking expression.
1261    ///
1262    /// # Examples
1263    ///
1264    /// ```
1265    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1266    ///
1267    /// let knn1 = RankExpr::Knn {
1268    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1269    ///     key: Key::Embedding,
1270    ///     limit: 100,
1271    ///     default: None,
1272    ///     return_rank: false,
1273    /// };
1274    ///
1275    /// let knn2 = RankExpr::Knn {
1276    ///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
1277    ///     key: Key::field("other"),
1278    ///     limit: 100,
1279    ///     default: None,
1280    ///     return_rank: false,
1281    /// };
1282    ///
1283    /// // Absolute difference
1284    /// let diff = (knn1 - knn2).abs();
1285    /// ```
1286    pub fn abs(self) -> Self {
1287        RankExpr::Absolute(Box::new(self))
1288    }
1289
1290    /// Returns maximum of this expression and another.
1291    ///
1292    /// Can be chained to clamp scores to a maximum value.
1293    ///
1294    /// # Examples
1295    ///
1296    /// ```
1297    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1298    ///
1299    /// let knn = RankExpr::Knn {
1300    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1301    ///     key: Key::Embedding,
1302    ///     limit: 100,
1303    ///     default: None,
1304    ///     return_rank: false,
1305    /// };
1306    ///
1307    /// // Clamp to maximum of 1.0
1308    /// let clamped = knn.clone().max(1.0);
1309    ///
1310    /// // Clamp to range [0.0, 1.0]
1311    /// let range_clamped = knn.min(0.0).max(1.0);
1312    /// ```
1313    pub fn max(self, other: impl Into<RankExpr>) -> Self {
1314        let other = other.into();
1315
1316        match self {
1317            RankExpr::Maximum(mut exprs) => match other {
1318                RankExpr::Maximum(other_exprs) => {
1319                    exprs.extend(other_exprs);
1320                    RankExpr::Maximum(exprs)
1321                }
1322                _ => {
1323                    exprs.push(other);
1324                    RankExpr::Maximum(exprs)
1325                }
1326            },
1327            _ => match other {
1328                RankExpr::Maximum(mut exprs) => {
1329                    exprs.insert(0, self);
1330                    RankExpr::Maximum(exprs)
1331                }
1332                _ => RankExpr::Maximum(vec![self, other]),
1333            },
1334        }
1335    }
1336
1337    /// Returns minimum of this expression and another.
1338    ///
1339    /// Can be chained to clamp scores to a minimum value.
1340    ///
1341    /// # Examples
1342    ///
1343    /// ```
1344    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1345    ///
1346    /// let knn = RankExpr::Knn {
1347    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1348    ///     key: Key::Embedding,
1349    ///     limit: 100,
1350    ///     default: None,
1351    ///     return_rank: false,
1352    /// };
1353    ///
1354    /// // Clamp to minimum of 0.0 (ensure non-negative)
1355    /// let clamped = knn.clone().min(0.0);
1356    ///
1357    /// // Clamp to range [0.0, 1.0]
1358    /// let range_clamped = knn.min(0.0).max(1.0);
1359    /// ```
1360    pub fn min(self, other: impl Into<RankExpr>) -> Self {
1361        let other = other.into();
1362
1363        match self {
1364            RankExpr::Minimum(mut exprs) => match other {
1365                RankExpr::Minimum(other_exprs) => {
1366                    exprs.extend(other_exprs);
1367                    RankExpr::Minimum(exprs)
1368                }
1369                _ => {
1370                    exprs.push(other);
1371                    RankExpr::Minimum(exprs)
1372                }
1373            },
1374            _ => match other {
1375                RankExpr::Minimum(mut exprs) => {
1376                    exprs.insert(0, self);
1377                    RankExpr::Minimum(exprs)
1378                }
1379                _ => RankExpr::Minimum(vec![self, other]),
1380            },
1381        }
1382    }
1383}
1384
1385impl Add for RankExpr {
1386    type Output = RankExpr;
1387
1388    fn add(self, rhs: Self) -> Self::Output {
1389        match self {
1390            RankExpr::Summation(mut exprs) => match rhs {
1391                RankExpr::Summation(rhs_exprs) => {
1392                    exprs.extend(rhs_exprs);
1393                    RankExpr::Summation(exprs)
1394                }
1395                _ => {
1396                    exprs.push(rhs);
1397                    RankExpr::Summation(exprs)
1398                }
1399            },
1400            _ => match rhs {
1401                RankExpr::Summation(mut exprs) => {
1402                    exprs.insert(0, self);
1403                    RankExpr::Summation(exprs)
1404                }
1405                _ => RankExpr::Summation(vec![self, rhs]),
1406            },
1407        }
1408    }
1409}
1410
1411impl Add<f32> for RankExpr {
1412    type Output = RankExpr;
1413
1414    fn add(self, rhs: f32) -> Self::Output {
1415        self + RankExpr::Value(rhs)
1416    }
1417}
1418
1419impl Add<RankExpr> for f32 {
1420    type Output = RankExpr;
1421
1422    fn add(self, rhs: RankExpr) -> Self::Output {
1423        RankExpr::Value(self) + rhs
1424    }
1425}
1426
1427impl Sub for RankExpr {
1428    type Output = RankExpr;
1429
1430    fn sub(self, rhs: Self) -> Self::Output {
1431        RankExpr::Subtraction {
1432            left: Box::new(self),
1433            right: Box::new(rhs),
1434        }
1435    }
1436}
1437
1438impl Sub<f32> for RankExpr {
1439    type Output = RankExpr;
1440
1441    fn sub(self, rhs: f32) -> Self::Output {
1442        self - RankExpr::Value(rhs)
1443    }
1444}
1445
1446impl Sub<RankExpr> for f32 {
1447    type Output = RankExpr;
1448
1449    fn sub(self, rhs: RankExpr) -> Self::Output {
1450        RankExpr::Value(self) - rhs
1451    }
1452}
1453
1454impl Mul for RankExpr {
1455    type Output = RankExpr;
1456
1457    fn mul(self, rhs: Self) -> Self::Output {
1458        match self {
1459            RankExpr::Multiplication(mut exprs) => match rhs {
1460                RankExpr::Multiplication(rhs_exprs) => {
1461                    exprs.extend(rhs_exprs);
1462                    RankExpr::Multiplication(exprs)
1463                }
1464                _ => {
1465                    exprs.push(rhs);
1466                    RankExpr::Multiplication(exprs)
1467                }
1468            },
1469            _ => match rhs {
1470                RankExpr::Multiplication(mut exprs) => {
1471                    exprs.insert(0, self);
1472                    RankExpr::Multiplication(exprs)
1473                }
1474                _ => RankExpr::Multiplication(vec![self, rhs]),
1475            },
1476        }
1477    }
1478}
1479
1480impl Mul<f32> for RankExpr {
1481    type Output = RankExpr;
1482
1483    fn mul(self, rhs: f32) -> Self::Output {
1484        self * RankExpr::Value(rhs)
1485    }
1486}
1487
1488impl Mul<RankExpr> for f32 {
1489    type Output = RankExpr;
1490
1491    fn mul(self, rhs: RankExpr) -> Self::Output {
1492        RankExpr::Value(self) * rhs
1493    }
1494}
1495
1496impl Div for RankExpr {
1497    type Output = RankExpr;
1498
1499    fn div(self, rhs: Self) -> Self::Output {
1500        RankExpr::Division {
1501            left: Box::new(self),
1502            right: Box::new(rhs),
1503        }
1504    }
1505}
1506
1507impl Div<f32> for RankExpr {
1508    type Output = RankExpr;
1509
1510    fn div(self, rhs: f32) -> Self::Output {
1511        self / RankExpr::Value(rhs)
1512    }
1513}
1514
1515impl Div<RankExpr> for f32 {
1516    type Output = RankExpr;
1517
1518    fn div(self, rhs: RankExpr) -> Self::Output {
1519        RankExpr::Value(self) / rhs
1520    }
1521}
1522
1523impl Neg for RankExpr {
1524    type Output = RankExpr;
1525
1526    fn neg(self) -> Self::Output {
1527        RankExpr::Value(-1.0) * self
1528    }
1529}
1530
1531impl From<f32> for RankExpr {
1532    fn from(v: f32) -> Self {
1533        RankExpr::Value(v)
1534    }
1535}
1536
1537impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1538    type Error = QueryConversionError;
1539
1540    fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1541        match proto_expr.rank {
1542            Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1543                Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1544            }
1545            Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1546                let left = div.left.ok_or(QueryConversionError::field("left"))?;
1547                let right = div.right.ok_or(QueryConversionError::field("right"))?;
1548                Ok(RankExpr::Division {
1549                    left: Box::new(RankExpr::try_from(*left)?),
1550                    right: Box::new(RankExpr::try_from(*right)?),
1551                })
1552            }
1553            Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1554                RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1555            ),
1556            Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1557                let query = knn
1558                    .query
1559                    .ok_or(QueryConversionError::field("query"))?
1560                    .try_into()?;
1561                Ok(RankExpr::Knn {
1562                    query,
1563                    key: Key::from(knn.key),
1564                    limit: knn.limit,
1565                    default: knn.default,
1566                    return_rank: knn.return_rank,
1567                })
1568            }
1569            Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1570                Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1571            }
1572            Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1573                let exprs = max
1574                    .exprs
1575                    .into_iter()
1576                    .map(RankExpr::try_from)
1577                    .collect::<Result<Vec<_>, _>>()?;
1578                Ok(RankExpr::Maximum(exprs))
1579            }
1580            Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1581                let exprs = min
1582                    .exprs
1583                    .into_iter()
1584                    .map(RankExpr::try_from)
1585                    .collect::<Result<Vec<_>, _>>()?;
1586                Ok(RankExpr::Minimum(exprs))
1587            }
1588            Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1589                let exprs = mul
1590                    .exprs
1591                    .into_iter()
1592                    .map(RankExpr::try_from)
1593                    .collect::<Result<Vec<_>, _>>()?;
1594                Ok(RankExpr::Multiplication(exprs))
1595            }
1596            Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1597                let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1598                let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1599                Ok(RankExpr::Subtraction {
1600                    left: Box::new(RankExpr::try_from(*left)?),
1601                    right: Box::new(RankExpr::try_from(*right)?),
1602                })
1603            }
1604            Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1605                let exprs = sum
1606                    .exprs
1607                    .into_iter()
1608                    .map(RankExpr::try_from)
1609                    .collect::<Result<Vec<_>, _>>()?;
1610                Ok(RankExpr::Summation(exprs))
1611            }
1612            Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1613            None => Err(QueryConversionError::field("rank")),
1614        }
1615    }
1616}
1617
1618impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1619    type Error = QueryConversionError;
1620
1621    fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1622        let proto_rank = match rank_expr {
1623            RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1624                chroma_proto::RankExpr::try_from(*expr)?,
1625            )),
1626            RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1627                Box::new(chroma_proto::rank_expr::RankPair {
1628                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1629                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1630                }),
1631            ),
1632            RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1633                Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1634            ),
1635            RankExpr::Knn {
1636                query,
1637                key,
1638                limit,
1639                default,
1640                return_rank,
1641            } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1642                query: Some(query.try_into()?),
1643                key: key.to_string(),
1644                limit,
1645                default,
1646                return_rank,
1647            }),
1648            RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1649                chroma_proto::RankExpr::try_from(*expr)?,
1650            )),
1651            RankExpr::Maximum(exprs) => {
1652                let proto_exprs = exprs
1653                    .into_iter()
1654                    .map(chroma_proto::RankExpr::try_from)
1655                    .collect::<Result<Vec<_>, _>>()?;
1656                chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1657                    exprs: proto_exprs,
1658                })
1659            }
1660            RankExpr::Minimum(exprs) => {
1661                let proto_exprs = exprs
1662                    .into_iter()
1663                    .map(chroma_proto::RankExpr::try_from)
1664                    .collect::<Result<Vec<_>, _>>()?;
1665                chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1666                    exprs: proto_exprs,
1667                })
1668            }
1669            RankExpr::Multiplication(exprs) => {
1670                let proto_exprs = exprs
1671                    .into_iter()
1672                    .map(chroma_proto::RankExpr::try_from)
1673                    .collect::<Result<Vec<_>, _>>()?;
1674                chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1675                    exprs: proto_exprs,
1676                })
1677            }
1678            RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1679                Box::new(chroma_proto::rank_expr::RankPair {
1680                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1681                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1682                }),
1683            ),
1684            RankExpr::Summation(exprs) => {
1685                let proto_exprs = exprs
1686                    .into_iter()
1687                    .map(chroma_proto::RankExpr::try_from)
1688                    .collect::<Result<Vec<_>, _>>()?;
1689                chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1690                    exprs: proto_exprs,
1691                })
1692            }
1693            RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1694        };
1695
1696        Ok(chroma_proto::RankExpr {
1697            rank: Some(proto_rank),
1698        })
1699    }
1700}
1701
1702/// Represents a field key in search queries.
1703///
1704/// Used for both selecting fields to return and building filter expressions.
1705/// Predefined keys access special fields, while custom keys access metadata.
1706///
1707/// # Predefined Keys
1708///
1709/// - `Key::Document` - Document text content (`#document`)
1710/// - `Key::Embedding` - Vector embeddings (`#embedding`)
1711/// - `Key::Metadata` - All metadata fields (`#metadata`)
1712/// - `Key::Score` - Search scores (`#score`)
1713///
1714/// # Custom Keys
1715///
1716/// Use `Key::field()` or `Key::from()` to reference metadata fields:
1717///
1718/// ```
1719/// use chroma_types::operator::Key;
1720///
1721/// let key = Key::field("author");
1722/// let key = Key::from("title");
1723/// ```
1724///
1725/// # Examples
1726///
1727/// ## Building filters
1728///
1729/// ```
1730/// use chroma_types::operator::Key;
1731///
1732/// // Equality
1733/// let filter = Key::field("status").eq("published");
1734///
1735/// // Comparisons
1736/// let filter = Key::field("year").gte(2020);
1737/// let filter = Key::field("score").lt(0.9);
1738///
1739/// // Set operations
1740/// let filter = Key::field("category").is_in(vec!["tech", "science"]);
1741/// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
1742///
1743/// // Document content
1744/// let filter = Key::Document.contains("machine learning");
1745/// let filter = Key::Document.regex(r"\bAPI\b");
1746///
1747/// // Combining filters
1748/// let filter = Key::field("status").eq("published")
1749///     & Key::field("year").gte(2020);
1750/// ```
1751///
1752/// ## Selecting fields
1753///
1754/// ```
1755/// use chroma_types::plan::SearchPayload;
1756/// use chroma_types::operator::Key;
1757///
1758/// let search = SearchPayload::default()
1759///     .select([
1760///         Key::Document,
1761///         Key::Score,
1762///         Key::field("title"),
1763///         Key::field("author"),
1764///     ]);
1765/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Key {
    // Predefined keys
    /// Document text content; written as `#document`.
    Document,
    /// Vector embeddings; written as `#embedding`.
    Embedding,
    /// All metadata fields; written as `#metadata`.
    Metadata,
    /// Search scores; written as `#score`.
    Score,
    /// A custom metadata field, written as the bare field name.
    MetadataField(String),
}
1776
1777impl Serialize for Key {
1778    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1779    where
1780        S: serde::Serializer,
1781    {
1782        match self {
1783            Key::Document => serializer.serialize_str("#document"),
1784            Key::Embedding => serializer.serialize_str("#embedding"),
1785            Key::Metadata => serializer.serialize_str("#metadata"),
1786            Key::Score => serializer.serialize_str("#score"),
1787            Key::MetadataField(field) => serializer.serialize_str(field),
1788        }
1789    }
1790}
1791
1792impl<'de> Deserialize<'de> for Key {
1793    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1794    where
1795        D: Deserializer<'de>,
1796    {
1797        let s = String::deserialize(deserializer)?;
1798        Ok(Key::from(s))
1799    }
1800}
1801
1802impl fmt::Display for Key {
1803    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1804        match self {
1805            Key::Document => write!(f, "#document"),
1806            Key::Embedding => write!(f, "#embedding"),
1807            Key::Metadata => write!(f, "#metadata"),
1808            Key::Score => write!(f, "#score"),
1809            Key::MetadataField(field) => write!(f, "{}", field),
1810        }
1811    }
1812}
1813
1814impl From<&str> for Key {
1815    fn from(s: &str) -> Self {
1816        match s {
1817            "#document" => Key::Document,
1818            "#embedding" => Key::Embedding,
1819            "#metadata" => Key::Metadata,
1820            "#score" => Key::Score,
1821            // Any other string is treated as a metadata field key
1822            field => Key::MetadataField(field.to_string()),
1823        }
1824    }
1825}
1826
1827impl From<String> for Key {
1828    fn from(s: String) -> Self {
1829        Key::from(s.as_str())
1830    }
1831}
1832
1833impl Key {
1834    /// Creates a Key for a custom metadata field.
1835    ///
1836    /// # Examples
1837    ///
1838    /// ```
1839    /// use chroma_types::operator::Key;
1840    ///
1841    /// let status = Key::field("status");
1842    /// let year = Key::field("year");
1843    /// let author = Key::field("author");
1844    /// ```
1845    pub fn field(name: impl Into<String>) -> Self {
1846        Key::MetadataField(name.into())
1847    }
1848
1849    /// Creates an equality filter: `field == value`.
1850    ///
1851    /// # Examples
1852    ///
1853    /// ```
1854    /// use chroma_types::operator::Key;
1855    ///
1856    /// // String equality
1857    /// let filter = Key::field("status").eq("published");
1858    ///
1859    /// // Numeric equality
1860    /// let filter = Key::field("count").eq(42);
1861    ///
1862    /// // Boolean equality
1863    /// let filter = Key::field("featured").eq(true);
1864    /// ```
1865    pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1866        Where::Metadata(MetadataExpression {
1867            key: self.to_string(),
1868            comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1869        })
1870    }
1871
1872    /// Creates an inequality filter: `field != value`.
1873    ///
1874    /// # Examples
1875    ///
1876    /// ```
1877    /// use chroma_types::operator::Key;
1878    ///
1879    /// let filter = Key::field("status").ne("deleted");
1880    /// let filter = Key::field("count").ne(0);
1881    /// ```
1882    pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1883        Where::Metadata(MetadataExpression {
1884            key: self.to_string(),
1885            comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1886        })
1887    }
1888
1889    /// Creates a greater-than filter: `field > value` (numeric only).
1890    ///
1891    /// # Examples
1892    ///
1893    /// ```
1894    /// use chroma_types::operator::Key;
1895    ///
1896    /// let filter = Key::field("score").gt(0.5);
1897    /// let filter = Key::field("year").gt(2020);
1898    /// ```
1899    pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1900        Where::Metadata(MetadataExpression {
1901            key: self.to_string(),
1902            comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1903        })
1904    }
1905
1906    /// Creates a greater-than-or-equal filter: `field >= value` (numeric only).
1907    ///
1908    /// # Examples
1909    ///
1910    /// ```
1911    /// use chroma_types::operator::Key;
1912    ///
1913    /// let filter = Key::field("score").gte(0.5);
1914    /// let filter = Key::field("year").gte(2020);
1915    /// ```
1916    pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1917        Where::Metadata(MetadataExpression {
1918            key: self.to_string(),
1919            comparison: MetadataComparison::Primitive(
1920                PrimitiveOperator::GreaterThanOrEqual,
1921                value.into(),
1922            ),
1923        })
1924    }
1925
1926    /// Creates a less-than filter: `field < value` (numeric only).
1927    ///
1928    /// # Examples
1929    ///
1930    /// ```
1931    /// use chroma_types::operator::Key;
1932    ///
1933    /// let filter = Key::field("score").lt(0.9);
1934    /// let filter = Key::field("year").lt(2025);
1935    /// ```
1936    pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1937        Where::Metadata(MetadataExpression {
1938            key: self.to_string(),
1939            comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1940        })
1941    }
1942
1943    /// Creates a less-than-or-equal filter: `field <= value` (numeric only).
1944    ///
1945    /// # Examples
1946    ///
1947    /// ```
1948    /// use chroma_types::operator::Key;
1949    ///
1950    /// let filter = Key::field("score").lte(0.9);
1951    /// let filter = Key::field("year").lte(2024);
1952    /// ```
1953    pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1954        Where::Metadata(MetadataExpression {
1955            key: self.to_string(),
1956            comparison: MetadataComparison::Primitive(
1957                PrimitiveOperator::LessThanOrEqual,
1958                value.into(),
1959            ),
1960        })
1961    }
1962
1963    /// Creates a set membership filter: `field IN values`.
1964    ///
1965    /// Accepts any iterator (Vec, array, slice, etc.).
1966    ///
1967    /// # Examples
1968    ///
1969    /// ```
1970    /// use chroma_types::operator::Key;
1971    ///
1972    /// // With Vec
1973    /// let filter = Key::field("year").is_in(vec![2023, 2024, 2025]);
1974    ///
1975    /// // With array
1976    /// let filter = Key::field("category").is_in(["tech", "science", "math"]);
1977    ///
1978    /// // With owned strings
1979    /// let categories = vec!["tech".to_string(), "science".to_string()];
1980    /// let filter = Key::field("category").is_in(categories);
1981    /// ```
1982    pub fn is_in<I, T>(self, values: I) -> Where
1983    where
1984        I: IntoIterator<Item = T>,
1985        Vec<T>: Into<MetadataSetValue>,
1986    {
1987        let vec: Vec<T> = values.into_iter().collect();
1988        Where::Metadata(MetadataExpression {
1989            key: self.to_string(),
1990            comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1991        })
1992    }
1993
1994    /// Creates a set exclusion filter: `field NOT IN values`.
1995    ///
1996    /// Accepts any iterator (Vec, array, slice, etc.).
1997    ///
1998    /// # Examples
1999    ///
2000    /// ```
2001    /// use chroma_types::operator::Key;
2002    ///
2003    /// // Exclude deleted and archived
2004    /// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
2005    ///
2006    /// // Exclude specific years
2007    /// let filter = Key::field("year").not_in(vec![2019, 2020]);
2008    /// ```
2009    pub fn not_in<I, T>(self, values: I) -> Where
2010    where
2011        I: IntoIterator<Item = T>,
2012        Vec<T>: Into<MetadataSetValue>,
2013    {
2014        let vec: Vec<T> = values.into_iter().collect();
2015        Where::Metadata(MetadataExpression {
2016            key: self.to_string(),
2017            comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2018        })
2019    }
2020
2021    /// Creates a document substring filter (case-sensitive).
2022    ///
2023    /// Only valid on `Key::Document`. Pattern must have at least 3 literal
2024    /// characters for accurate results.
2025    ///
2026    /// For metadata array contains, use [`contains_value`](Key::contains_value).
2027    ///
2028    /// # Examples
2029    ///
2030    /// ```
2031    /// use chroma_types::operator::Key;
2032    ///
2033    /// let filter = Key::Document.contains("machine learning");
2034    /// let filter = Key::Document.contains("API");
2035    /// ```
2036    pub fn contains<S: Into<String>>(self, text: S) -> Where {
2037        Where::Document(DocumentExpression {
2038            operator: DocumentOperator::Contains,
2039            pattern: text.into(),
2040        })
2041    }
2042
2043    /// Creates a negative document substring filter (case-sensitive).
2044    ///
2045    /// Only valid on `Key::Document`.
2046    ///
2047    /// For metadata array not-contains, use
2048    /// [`not_contains_value`](Key::not_contains_value).
2049    ///
2050    /// # Examples
2051    ///
2052    /// ```
2053    /// use chroma_types::operator::Key;
2054    ///
2055    /// let filter = Key::Document.not_contains("deprecated");
2056    /// let filter = Key::Document.not_contains("beta");
2057    /// ```
2058    pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2059        Where::Document(DocumentExpression {
2060            operator: DocumentOperator::NotContains,
2061            pattern: text.into(),
2062        })
2063    }
2064
2065    /// Checks whether a metadata array field contains the given scalar value.
2066    ///
2067    /// # Examples
2068    ///
2069    /// ```
2070    /// use chroma_types::operator::Key;
2071    ///
2072    /// let filter = Key::field("tags").contains_value("action");
2073    /// let filter = Key::field("scores").contains_value(42);
2074    /// let filter = Key::field("ratings").contains_value(4.5);
2075    /// let filter = Key::field("flags").contains_value(true);
2076    /// ```
2077    pub fn contains_value<T: Into<MetadataValue>>(self, value: T) -> Where {
2078        Where::Metadata(MetadataExpression {
2079            key: self.to_string(),
2080            comparison: MetadataComparison::ArrayContains(ContainsOperator::Contains, value.into()),
2081        })
2082    }
2083
2084    /// Checks that a metadata array field does **not** contain the given scalar
2085    /// value.
2086    ///
2087    /// # Examples
2088    ///
2089    /// ```
2090    /// use chroma_types::operator::Key;
2091    ///
2092    /// let filter = Key::field("tags").not_contains_value("draft");
2093    /// let filter = Key::field("scores").not_contains_value(0);
2094    /// ```
2095    pub fn not_contains_value<T: Into<MetadataValue>>(self, value: T) -> Where {
2096        Where::Metadata(MetadataExpression {
2097            key: self.to_string(),
2098            comparison: MetadataComparison::ArrayContains(
2099                ContainsOperator::NotContains,
2100                value.into(),
2101            ),
2102        })
2103    }
2104
2105    /// Creates a regex filter (case-sensitive, document content only).
2106    ///
2107    /// Note: Currently only works with `Key::Document`. Pattern must have at least
2108    /// 3 literal characters for accurate results.
2109    ///
2110    /// # Examples
2111    ///
2112    /// ```
2113    /// use chroma_types::operator::Key;
2114    ///
2115    /// // Match whole word "API"
2116    /// let filter = Key::Document.regex(r"\bAPI\b");
2117    ///
2118    /// // Match version pattern
2119    /// let filter = Key::Document.regex(r"v\d+\.\d+\.\d+");
2120    /// ```
2121    pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2122        Where::Document(DocumentExpression {
2123            operator: DocumentOperator::Regex,
2124            pattern: pattern.into(),
2125        })
2126    }
2127
2128    /// Creates a negative regex filter (case-sensitive, document content only).
2129    ///
2130    /// Note: Currently only works with `Key::Document`.
2131    ///
2132    /// # Examples
2133    ///
2134    /// ```
2135    /// use chroma_types::operator::Key;
2136    ///
2137    /// // Exclude beta versions
2138    /// let filter = Key::Document.not_regex(r"beta");
2139    ///
2140    /// // Exclude test documents
2141    /// let filter = Key::Document.not_regex(r"\btest\b");
2142    /// ```
2143    pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2144        Where::Document(DocumentExpression {
2145            operator: DocumentOperator::NotRegex,
2146            pattern: pattern.into(),
2147        })
2148    }
2149}
2150
2151/// Field selection for search results.
2152///
2153/// Specifies which fields to include in the results. IDs are always included.
2154///
2155/// # Fields
2156///
2157/// * `keys` - Set of keys to include in results
2158///
2159/// # Available Keys
2160///
2161/// * `Key::Document` - Document text content
2162/// * `Key::Embedding` - Vector embeddings
2163/// * `Key::Metadata` - All metadata fields
2164/// * `Key::Score` - Search scores
2165/// * `Key::field("name")` - Specific metadata field
2166///
2167/// # Performance
2168///
2169/// Selecting fewer fields improves performance by reducing data transfer:
2170/// - Minimal: IDs only (default, fastest)
2171/// - Moderate: Scores + specific metadata fields
2172/// - Heavy: Documents + embeddings (larger payloads)
2173///
2174/// # Examples
2175///
2176/// ```
2177/// use chroma_types::operator::{Select, Key};
2178/// use std::collections::HashSet;
2179///
2180/// // Select predefined fields
2181/// let select = Select {
2182///     keys: [Key::Document, Key::Score].into_iter().collect(),
2183/// };
2184///
2185/// // Select specific metadata fields
2186/// let select = Select {
2187///     keys: [
2188///         Key::field("title"),
2189///         Key::field("author"),
2190///         Key::Score,
2191///     ].into_iter().collect(),
2192/// };
2193///
2194/// // Select everything
2195/// let select = Select {
2196///     keys: [
2197///         Key::Document,
2198///         Key::Embedding,
2199///         Key::Metadata,
2200///         Key::Score,
2201///     ].into_iter().collect(),
2202/// };
2203/// ```
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Select {
    /// Keys to include in the results. When the field is missing from the
    /// serialized input it falls back to the empty set.
    #[serde(default)]
    pub keys: HashSet<Key>,
}
2209
2210impl TryFrom<chroma_proto::SelectOperator> for Select {
2211    type Error = QueryConversionError;
2212
2213    fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2214        let keys = value
2215            .keys
2216            .into_iter()
2217            .map(|key| {
2218                // Try to deserialize each string as a Key
2219                serde_json::from_value(serde_json::Value::String(key))
2220                    .map_err(|_| QueryConversionError::field("keys"))
2221            })
2222            .collect::<Result<HashSet<_>, _>>()?;
2223
2224        Ok(Self { keys })
2225    }
2226}
2227
2228impl TryFrom<Select> for chroma_proto::SelectOperator {
2229    type Error = QueryConversionError;
2230
2231    fn try_from(value: Select) -> Result<Self, Self::Error> {
2232        let keys = value
2233            .keys
2234            .into_iter()
2235            .map(|key| {
2236                // Serialize each Key back to string
2237                serde_json::to_value(&key)
2238                    .ok()
2239                    .and_then(|v| v.as_str().map(String::from))
2240                    .ok_or(QueryConversionError::field("keys"))
2241            })
2242            .collect::<Result<Vec<_>, _>>()?;
2243
2244        Ok(Self { keys })
2245    }
2246}
2247
2248/// Aggregation function applied within each group.
2249///
2250/// Determines which records to keep from each group and their ordering.
2251///
2252/// # Variants
2253///
2254/// * `MinK` - Returns k records with minimum values (ascending order).
2255///   Use with `Key::Score` to get best matches (lower score = better in Chroma).
2256/// * `MaxK` - Returns k records with maximum values (descending order).
2257///
2258/// # Multi-level Ordering
2259///
2260/// The `keys` field supports multi-level ordering. Records are sorted by
2261/// the first key, then by the second key for ties, and so on.
2262///
2263/// # Examples
2264///
2265/// ```
2266/// use chroma_types::operator::{Aggregate, Key};
2267///
2268/// // Best 3 by score per group
2269/// let agg = Aggregate::MinK {
2270///     keys: vec![Key::Score],
2271///     k: 3,
2272/// };
2273///
2274/// // Best 3 by score, then by date for ties
2275/// let agg = Aggregate::MinK {
2276///     keys: vec![Key::Score, Key::field("date")],
2277///     k: 3,
2278/// };
2279///
2280/// // Top 5 by recency (highest date first)
2281/// let agg = Aggregate::MaxK {
2282///     keys: vec![Key::field("date")],
2283///     k: 5,
2284/// };
2285/// ```
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum Aggregate {
    /// Returns k records with minimum values (ascending order).
    ///
    /// Appears under the name `$min_k` in serialized form.
    #[serde(rename = "$min_k")]
    MinK {
        /// Keys for multi-level ordering: the first key sorts, later keys break ties
        keys: Vec<Key>,
        /// Number of records to return per group
        k: u32,
    },
    /// Returns k records with maximum values (descending order).
    ///
    /// Appears under the name `$max_k` in serialized form.
    #[serde(rename = "$max_k")]
    MaxK {
        /// Keys for multi-level ordering: the first key sorts, later keys break ties
        keys: Vec<Key>,
        /// Number of records to return per group
        k: u32,
    },
}
2305
2306/// Groups results by metadata keys and aggregates within each group.
2307///
2308/// Results are grouped by the specified metadata keys (like SQL GROUP BY),
2309/// then aggregated within each group using MinK or MaxK ordering.
2310/// The final output is flattened and sorted by score.
2311///
2312/// # Fields
2313///
2314/// * `keys` - Metadata keys to group by (composite grouping)
2315/// * `aggregate` - Aggregation function to apply within each group
2316///
2317/// # Behavior
2318///
2319/// * Missing metadata keys are treated as Null (forming their own group)
2320/// * Empty groups are omitted from results
2321/// * Final output is flattened (group structure not preserved)
2322/// * Results are sorted by score after aggregation
2323///
2324/// # Examples
2325///
2326/// ```
2327/// use chroma_types::operator::{GroupBy, Aggregate, Key};
2328///
2329/// // Top 3 documents per category
2330/// let group_by = GroupBy {
2331///     keys: vec![Key::field("category")],
2332///     aggregate: Some(Aggregate::MinK {
2333///         keys: vec![Key::Score],
2334///         k: 3,
2335///     }),
2336/// };
2337///
2338/// // Top 2 per (category, author) combination
2339/// let group_by = GroupBy {
2340///     keys: vec![Key::field("category"), Key::field("author")],
2341///     aggregate: Some(Aggregate::MinK {
2342///         keys: vec![Key::Score, Key::field("date")],
2343///         k: 2,
2344///     }),
2345/// };
2346/// ```
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct GroupBy {
    /// Metadata keys to group by. Falls back to an empty list when the field
    /// is missing from the serialized input.
    #[serde(default)]
    pub keys: Vec<Key>,
    /// Aggregation to apply within each group (required when keys is non-empty).
    /// Falls back to `None` when the field is missing from the serialized input.
    // NOTE(review): the "required when keys is non-empty" rule is not enforced
    // by this type; confirm where it is validated.
    #[serde(default)]
    pub aggregate: Option<Aggregate>,
}
2356
2357impl TryFrom<chroma_proto::Aggregate> for Aggregate {
2358    type Error = QueryConversionError;
2359
2360    fn try_from(value: chroma_proto::Aggregate) -> Result<Self, Self::Error> {
2361        match value
2362            .aggregate
2363            .ok_or(QueryConversionError::field("aggregate"))?
2364        {
2365            chroma_proto::aggregate::Aggregate::MinK(min_k) => {
2366                let keys = min_k.keys.into_iter().map(Key::from).collect();
2367                Ok(Aggregate::MinK { keys, k: min_k.k })
2368            }
2369            chroma_proto::aggregate::Aggregate::MaxK(max_k) => {
2370                let keys = max_k.keys.into_iter().map(Key::from).collect();
2371                Ok(Aggregate::MaxK { keys, k: max_k.k })
2372            }
2373        }
2374    }
2375}
2376
2377impl From<Aggregate> for chroma_proto::Aggregate {
2378    fn from(value: Aggregate) -> Self {
2379        let aggregate = match value {
2380            Aggregate::MinK { keys, k } => {
2381                chroma_proto::aggregate::Aggregate::MinK(chroma_proto::aggregate::MinK {
2382                    keys: keys.into_iter().map(|k| k.to_string()).collect(),
2383                    k,
2384                })
2385            }
2386            Aggregate::MaxK { keys, k } => {
2387                chroma_proto::aggregate::Aggregate::MaxK(chroma_proto::aggregate::MaxK {
2388                    keys: keys.into_iter().map(|k| k.to_string()).collect(),
2389                    k,
2390                })
2391            }
2392        };
2393
2394        chroma_proto::Aggregate {
2395            aggregate: Some(aggregate),
2396        }
2397    }
2398}
2399
2400impl TryFrom<chroma_proto::GroupByOperator> for GroupBy {
2401    type Error = QueryConversionError;
2402
2403    fn try_from(value: chroma_proto::GroupByOperator) -> Result<Self, Self::Error> {
2404        let keys = value.keys.into_iter().map(Key::from).collect();
2405        let aggregate = value.aggregate.map(TryInto::try_into).transpose()?;
2406
2407        Ok(Self { keys, aggregate })
2408    }
2409}
2410
2411impl TryFrom<GroupBy> for chroma_proto::GroupByOperator {
2412    type Error = QueryConversionError;
2413
2414    fn try_from(value: GroupBy) -> Result<Self, Self::Error> {
2415        let keys = value.keys.into_iter().map(|k| k.to_string()).collect();
2416        let aggregate = value.aggregate.map(Into::into);
2417
2418        Ok(Self { keys, aggregate })
2419    }
2420}
2421
2422/// A single search result record.
2423///
2424/// Contains the document ID and optionally document content, embeddings, metadata,
2425/// and search score based on what was selected in the search query.
2426///
2427/// # Fields
2428///
2429/// * `id` - Document ID (always present)
2430/// * `document` - Document text content (if selected)
2431/// * `embedding` - Vector embedding (if selected)
2432/// * `metadata` - Document metadata (if selected)
2433/// * `score` - Search score (present when ranking is used, lower = better match)
2434///
2435/// # Examples
2436///
2437/// ```
2438/// use chroma_types::operator::SearchRecord;
2439///
2440/// fn process_results(records: Vec<SearchRecord>) {
2441///     for record in records {
2442///         println!("ID: {}", record.id);
2443///
2444///         if let Some(score) = record.score {
2445///             println!("  Score: {:.3}", score);
2446///         }
2447///
2448///         if let Some(doc) = record.document {
2449///             println!("  Document: {}", doc);
2450///         }
2451///
2452///         if let Some(meta) = record.metadata {
2453///             println!("  Metadata: {:?}", meta);
2454///         }
2455///     }
2456/// }
2457/// ```
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SearchRecord {
    // Document ID (always present).
    pub id: String,
    // Document text content, if it was selected.
    pub document: Option<String>,
    // Vector embedding, if it was selected.
    pub embedding: Option<Vec<f32>>,
    // Document metadata, if it was selected.
    pub metadata: Option<Metadata>,
    // Search score, present when ranking is used (lower = better match).
    pub score: Option<f32>,
}
2467
2468impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2469    type Error = QueryConversionError;
2470
2471    fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2472        Ok(Self {
2473            id: value.id,
2474            document: value.document,
2475            embedding: value
2476                .embedding
2477                .map(|vec| vec.try_into().map(|(v, _)| v))
2478                .transpose()?,
2479            metadata: value.metadata.map(TryInto::try_into).transpose()?,
2480            score: value.score,
2481        })
2482    }
2483}
2484
2485impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2486    type Error = QueryConversionError;
2487
2488    fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2489        Ok(Self {
2490            id: value.id,
2491            document: value.document,
2492            embedding: value
2493                .embedding
2494                .map(|embedding| {
2495                    let embedding_dimension = embedding.len();
2496                    chroma_proto::Vector::try_from((
2497                        embedding,
2498                        ScalarEncoding::FLOAT32,
2499                        embedding_dimension,
2500                    ))
2501                })
2502                .transpose()?,
2503            metadata: value.metadata.map(Into::into),
2504            score: value.score,
2505        })
2506    }
2507}
2508
2509/// Results for a single search payload.
2510///
2511/// Contains all matching records for one search query.
2512///
2513/// # Fields
2514///
2515/// * `records` - Vector of search records, ordered by score (ascending)
2516///
2517/// # Examples
2518///
2519/// ```
2520/// use chroma_types::operator::{SearchPayloadResult, SearchRecord};
2521///
2522/// fn process_search_result(result: SearchPayloadResult) {
2523///     println!("Found {} results", result.records.len());
2524///
2525///     for (i, record) in result.records.iter().enumerate() {
2526///         println!("{}. {} (score: {:?})", i + 1, record.id, record.score);
2527///     }
2528/// }
2529/// ```
#[derive(Clone, Debug, Default)]
pub struct SearchPayloadResult {
    /// Matching records for this payload, ordered by score (ascending)
    pub records: Vec<SearchRecord>,
}
2534
2535impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2536    type Error = QueryConversionError;
2537
2538    fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2539        Ok(Self {
2540            records: value
2541                .records
2542                .into_iter()
2543                .map(TryInto::try_into)
2544                .collect::<Result<_, _>>()?,
2545        })
2546    }
2547}
2548
2549impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2550    type Error = QueryConversionError;
2551
2552    fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2553        Ok(Self {
2554            records: value
2555                .records
2556                .into_iter()
2557                .map(TryInto::try_into)
2558                .collect::<Result<Vec<_>, _>>()?,
2559        })
2560    }
2561}
2562
2563/// Results from a batch search operation.
2564///
2565/// Contains results for each search payload in the batch, maintaining the same order
2566/// as the input searches.
2567///
2568/// # Fields
2569///
2570/// * `results` - Results for each search payload (indexed by search position)
2571/// * `pulled_log_bytes` - Total bytes pulled from log (for internal metrics)
2572///
2573/// # Examples
2574///
2575/// ## Single search
2576///
2577/// ```
2578/// use chroma_types::operator::SearchResult;
2579///
2580/// fn process_single_search(result: SearchResult) {
2581///     // Single search, so results[0] contains our records
2582///     let records = &result.results[0].records;
2583///
2584///     for record in records {
2585///         println!("{}: score={:?}", record.id, record.score);
2586///     }
2587/// }
2588/// ```
2589///
2590/// ## Batch search
2591///
2592/// ```
2593/// use chroma_types::operator::SearchResult;
2594///
2595/// fn process_batch_search(result: SearchResult) {
2596///     // Multiple searches in batch
2597///     for (i, search_result) in result.results.iter().enumerate() {
2598///         println!("\nSearch {}:", i + 1);
2599///         for record in &search_result.records {
2600///             println!("  {}: score={:?}", record.id, record.score);
2601///         }
2602///     }
2603/// }
2604/// ```
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Results for each search payload, indexed by search position
    pub results: Vec<SearchPayloadResult>,
    /// Total bytes pulled from the log (for internal metrics)
    pub pulled_log_bytes: u64,
}
2610
2611impl SearchResult {
2612    pub fn size_bytes(&self) -> u64 {
2613        self.results
2614            .iter()
2615            .flat_map(|result| {
2616                result.records.iter().map(|record| {
2617                    (record.id.len()
2618                        + record
2619                            .document
2620                            .as_ref()
2621                            .map(|doc| doc.len())
2622                            .unwrap_or_default()
2623                        + record
2624                            .embedding
2625                            .as_ref()
2626                            .map(|emb| size_of_val(&emb[..]))
2627                            .unwrap_or_default()
2628                        + record
2629                            .metadata
2630                            .as_ref()
2631                            .map(logical_size_of_metadata)
2632                            .unwrap_or_default()
2633                        + record.score.as_ref().map(size_of_val).unwrap_or_default())
2634                        as u64
2635                })
2636            })
2637            .sum()
2638    }
2639}
2640
2641impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2642    type Error = QueryConversionError;
2643
2644    fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2645        Ok(Self {
2646            results: value
2647                .results
2648                .into_iter()
2649                .map(TryInto::try_into)
2650                .collect::<Result<_, _>>()?,
2651            pulled_log_bytes: value.pulled_log_bytes,
2652        })
2653    }
2654}
2655
2656impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2657    type Error = QueryConversionError;
2658
2659    fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2660        Ok(Self {
2661            results: value
2662                .results
2663                .into_iter()
2664                .map(TryInto::try_into)
2665                .collect::<Result<Vec<_>, _>>()?,
2666            pulled_log_bytes: value.pulled_log_bytes,
2667        })
2668    }
2669}
2670
2671/// Reciprocal Rank Fusion (RRF) - combines multiple ranking strategies.
2672///
2673/// RRF is ideal for hybrid search where you want to merge results from different
2674/// ranking methods (e.g., dense and sparse embeddings) with different score scales.
2675/// It uses rank positions instead of raw scores, making it scale-agnostic.
2676///
2677/// # Formula
2678///
2679/// ```text
2680/// score = -Σ(weight_i / (k + rank_i))
2681/// ```
2682///
2683/// Where:
2684/// - `weight_i` = weight for ranking i (default: 1.0)
2685/// - `rank_i` = rank position from ranking i (0, 1, 2...)
2686/// - `k` = smoothing parameter (default: 60)
2687///
2688/// Score is negative because Chroma uses ascending order (lower = better).
2689///
2690/// # Arguments
2691///
2692/// * `ranks` - List of ranking expressions (must have `return_rank=true`)
2693/// * `k` - Smoothing parameter (None = 60). Higher values reduce emphasis on top ranks.
2694/// * `weights` - Weight for each ranking (None = all 1.0)
2695/// * `normalize` - If true, normalize weights to sum to 1.0
2696///
2697/// # Returns
2698///
2699/// A combined RankExpr or an error if:
2700/// - `ranks` is empty
2701/// - `weights` length doesn't match `ranks` length
2702/// - `weights` sum to zero when normalizing
2703///
2704/// # Examples
2705///
2706/// ## Basic RRF with default parameters
2707///
2708/// ```
2709/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2710///
2711/// let dense = RankExpr::Knn {
2712///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2713///     key: Key::Embedding,
2714///     limit: 200,
2715///     default: None,
2716///     return_rank: true, // Required for RRF
2717/// };
2718///
2719/// let sparse = RankExpr::Knn {
2720///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2721///     key: Key::field("sparse_embedding"),
2722///     limit: 200,
2723///     default: None,
2724///     return_rank: true, // Required for RRF
2725/// };
2726///
2727/// // Equal weights, k=60 (defaults)
2728/// let combined = rrf(vec![dense, sparse], None, None, false).unwrap();
2729/// ```
2730///
2731/// ## RRF with custom weights
2732///
2733/// ```
2734/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2735///
2736/// # let dense = RankExpr::Knn {
2737/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2738/// #     key: Key::Embedding,
2739/// #     limit: 200,
2740/// #     default: None,
2741/// #     return_rank: true,
2742/// # };
2743/// # let sparse = RankExpr::Knn {
2744/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2745/// #     key: Key::field("sparse_embedding"),
2746/// #     limit: 200,
2747/// #     default: None,
2748/// #     return_rank: true,
2749/// # };
2750/// // 70% dense, 30% sparse
2751/// let combined = rrf(
2752///     vec![dense, sparse],
2753///     Some(60),
2754///     Some(vec![0.7, 0.3]),
2755///     false,
2756/// ).unwrap();
2757/// ```
2758///
2759/// ## RRF with normalized weights
2760///
2761/// ```
2762/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2763///
2764/// # let dense = RankExpr::Knn {
2765/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2766/// #     key: Key::Embedding,
2767/// #     limit: 200,
2768/// #     default: None,
2769/// #     return_rank: true,
2770/// # };
2771/// # let sparse = RankExpr::Knn {
2772/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2773/// #     key: Key::field("sparse_embedding"),
2774/// #     limit: 200,
2775/// #     default: None,
2776/// #     return_rank: true,
2777/// # };
2778/// // Weights [75, 25] normalized to [0.75, 0.25]
2779/// let combined = rrf(
2780///     vec![dense, sparse],
2781///     Some(60),
2782///     Some(vec![75.0, 25.0]),
2783///     true, // normalize
2784/// ).unwrap();
2785/// ```
2786///
2787/// ## Adjusting the k parameter
2788///
2789/// ```
2790/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2791///
2792/// # let dense = RankExpr::Knn {
2793/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2794/// #     key: Key::Embedding,
2795/// #     limit: 200,
2796/// #     default: None,
2797/// #     return_rank: true,
2798/// # };
2799/// # let sparse = RankExpr::Knn {
2800/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2801/// #     key: Key::field("sparse_embedding"),
2802/// #     limit: 200,
2803/// #     default: None,
2804/// #     return_rank: true,
2805/// # };
2806/// // Small k (10) = heavy emphasis on top ranks
2807/// let top_heavy = rrf(vec![dense.clone(), sparse.clone()], Some(10), None, false).unwrap();
2808///
2809/// // Default k (60) = balanced
2810/// let balanced = rrf(vec![dense.clone(), sparse.clone()], Some(60), None, false).unwrap();
2811///
2812/// // Large k (200) = more uniform weighting
2813/// let uniform = rrf(vec![dense, sparse], Some(200), None, false).unwrap();
2814/// ```
2815pub fn rrf(
2816    ranks: Vec<RankExpr>,
2817    k: Option<u32>,
2818    weights: Option<Vec<f32>>,
2819    normalize: bool,
2820) -> Result<RankExpr, QueryConversionError> {
2821    let k = k.unwrap_or(60);
2822
2823    if ranks.is_empty() {
2824        return Err(QueryConversionError::validation(
2825            "RRF requires at least one rank expression",
2826        ));
2827    }
2828
2829    let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2830
2831    if weights.len() != ranks.len() {
2832        return Err(QueryConversionError::validation(format!(
2833            "RRF weights length ({}) must match ranks length ({})",
2834            weights.len(),
2835            ranks.len()
2836        )));
2837    }
2838
2839    let weights = if normalize {
2840        let sum: f32 = weights.iter().sum();
2841        if sum == 0.0 {
2842            return Err(QueryConversionError::validation(
2843                "RRF weights sum to zero, cannot normalize",
2844            ));
2845        }
2846        weights.into_iter().map(|w| w / sum).collect()
2847    } else {
2848        weights
2849    };
2850
2851    let terms: Vec<RankExpr> = weights
2852        .into_iter()
2853        .zip(ranks)
2854        .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2855        .collect();
2856
2857    // Safe: ranks is validated as non-empty above, so terms cannot be empty.
2858    // Using unwrap_or_else as defensive programming to avoid panic.
2859    let sum = terms
2860        .into_iter()
2861        .reduce(|a, b| a + b)
2862        .unwrap_or(RankExpr::Value(0.0));
2863    Ok(-sum)
2864}
2865
2866#[cfg(test)]
2867mod tests {
2868    use super::*;
2869
2870    #[test]
2871    fn test_key_from_string() {
2872        // Test predefined keys
2873        assert_eq!(Key::from("#document"), Key::Document);
2874        assert_eq!(Key::from("#embedding"), Key::Embedding);
2875        assert_eq!(Key::from("#metadata"), Key::Metadata);
2876        assert_eq!(Key::from("#score"), Key::Score);
2877
2878        // Test metadata field keys
2879        assert_eq!(
2880            Key::from("custom_field"),
2881            Key::MetadataField("custom_field".to_string())
2882        );
2883        assert_eq!(
2884            Key::from("author"),
2885            Key::MetadataField("author".to_string())
2886        );
2887
2888        // Test String variant
2889        assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
2890        assert_eq!(
2891            Key::from("year".to_string()),
2892            Key::MetadataField("year".to_string())
2893        );
2894    }
2895
2896    #[test]
2897    fn test_query_vector_dense_proto_conversion() {
2898        let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2899        let query_vector = QueryVector::Dense(dense_vec.clone());
2900
2901        // Convert to proto
2902        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2903
2904        // Convert back
2905        let converted: QueryVector = proto.try_into().unwrap();
2906
2907        assert_eq!(converted, query_vector);
2908        if let QueryVector::Dense(v) = converted {
2909            assert_eq!(v, dense_vec);
2910        } else {
2911            panic!("Expected dense vector");
2912        }
2913    }
2914
2915    #[test]
2916    fn test_query_vector_sparse_proto_conversion() {
2917        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2918        let query_vector = QueryVector::Sparse(sparse.clone());
2919
2920        // Convert to proto
2921        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2922
2923        // Convert back
2924        let converted: QueryVector = proto.try_into().unwrap();
2925
2926        assert_eq!(converted, query_vector);
2927        if let QueryVector::Sparse(s) = converted {
2928            assert_eq!(s, sparse);
2929        } else {
2930            panic!("Expected sparse vector");
2931        }
2932    }
2933
2934    #[test]
2935    fn test_filter_json_deserialization() {
2936        // For the new search API, deserialization treats the entire JSON as a where clause
2937
2938        // Test 1: Simple direct metadata comparison
2939        let simple_where = r#"{"author": "John Doe"}"#;
2940        let filter: Filter = serde_json::from_str(simple_where).unwrap();
2941        assert_eq!(filter.query_ids, None);
2942        assert!(filter.where_clause.is_some());
2943
2944        // Test 2: ID filter using #id with $in operator
2945        let id_filter_json = serde_json::json!({
2946            "#id": {
2947                "$in": ["doc1", "doc2", "doc3"]
2948            }
2949        });
2950        let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
2951        assert_eq!(filter.query_ids, None);
2952        assert!(filter.where_clause.is_some());
2953
2954        // Test 3: Complex nested expression with AND, OR, and various operators
2955        let complex_json = serde_json::json!({
2956            "$and": [
2957                {
2958                    "#id": {
2959                        "$in": ["doc1", "doc2", "doc3"]
2960                    }
2961                },
2962                {
2963                    "$or": [
2964                        {
2965                            "author": {
2966                                "$eq": "John Doe"
2967                            }
2968                        },
2969                        {
2970                            "author": {
2971                                "$eq": "Jane Smith"
2972                            }
2973                        }
2974                    ]
2975                },
2976                {
2977                    "year": {
2978                        "$gte": 2020
2979                    }
2980                },
2981                {
2982                    "tags": {
2983                        "$contains": "machine-learning"
2984                    }
2985                }
2986            ]
2987        });
2988
2989        let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
2990        assert_eq!(filter.query_ids, None);
2991        assert!(filter.where_clause.is_some());
2992
2993        // Verify the structure
2994        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2995            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2996            assert_eq!(composite.children.len(), 4);
2997
2998            // Check that the second child is an OR
2999            if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
3000                assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
3001                assert_eq!(or_composite.children.len(), 2);
3002            } else {
3003                panic!("Expected OR composite in second child");
3004            }
3005        } else {
3006            panic!("Expected AND composite where clause");
3007        }
3008
3009        // Test 4: Mixed operators - $ne, $lt, $gt, $lte
3010        let mixed_operators_json = serde_json::json!({
3011            "$and": [
3012                {
3013                    "status": {
3014                        "$ne": "deleted"
3015                    }
3016                },
3017                {
3018                    "score": {
3019                        "$gt": 0.5
3020                    }
3021                },
3022                {
3023                    "score": {
3024                        "$lt": 0.9
3025                    }
3026                },
3027                {
3028                    "priority": {
3029                        "$lte": 10
3030                    }
3031                }
3032            ]
3033        });
3034
3035        let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
3036        assert_eq!(filter.query_ids, None);
3037        assert!(filter.where_clause.is_some());
3038
3039        // Test 5: Deeply nested expression
3040        let deeply_nested_json = serde_json::json!({
3041            "$or": [
3042                {
3043                    "$and": [
3044                        {
3045                            "#id": {
3046                                "$in": ["id1", "id2"]
3047                            }
3048                        },
3049                        {
3050                            "$or": [
3051                                {
3052                                    "category": "tech"
3053                                },
3054                                {
3055                                    "category": "science"
3056                                }
3057                            ]
3058                        }
3059                    ]
3060                },
3061                {
3062                    "$and": [
3063                        {
3064                            "author": "Admin"
3065                        },
3066                        {
3067                            "published": true
3068                        }
3069                    ]
3070                }
3071            ]
3072        });
3073
3074        let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
3075        assert_eq!(filter.query_ids, None);
3076        assert!(filter.where_clause.is_some());
3077
3078        // Verify it's an OR at the top level
3079        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3080            assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
3081            assert_eq!(composite.children.len(), 2);
3082
3083            // Both children should be AND composites
3084            for child in &composite.children {
3085                if let crate::metadata::Where::Composite(and_composite) = child {
3086                    assert_eq!(
3087                        and_composite.operator,
3088                        crate::metadata::BooleanOperator::And
3089                    );
3090                } else {
3091                    panic!("Expected AND composite in OR children");
3092                }
3093            }
3094        } else {
3095            panic!("Expected OR composite at top level");
3096        }
3097
3098        // Test 6: Single ID filter (edge case)
3099        let single_id_json = serde_json::json!({
3100            "#id": {
3101                "$eq": "single-doc-id"
3102            }
3103        });
3104
3105        let filter: Filter = serde_json::from_value(single_id_json).unwrap();
3106        assert_eq!(filter.query_ids, None);
3107        assert!(filter.where_clause.is_some());
3108
3109        // Test 7: Empty object should create empty filter
3110        let empty_json = serde_json::json!({});
3111        let filter: Filter = serde_json::from_value(empty_json).unwrap();
3112        assert_eq!(filter.query_ids, None);
3113        // Empty object results in None where_clause
3114        assert_eq!(filter.where_clause, None);
3115
3116        // Test 8: Combining #id filter with $not_contains and numeric comparisons
3117        let advanced_json = serde_json::json!({
3118            "$and": [
3119                {
3120                    "#id": {
3121                        "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
3122                    }
3123                },
3124                {
3125                    "tags": {
3126                        "$not_contains": "deprecated"
3127                    }
3128                },
3129                {
3130                    "$or": [
3131                        {
3132                            "$and": [
3133                                {
3134                                    "confidence": {
3135                                        "$gte": 0.8
3136                                    }
3137                                },
3138                                {
3139                                    "verified": true
3140                                }
3141                            ]
3142                        },
3143                        {
3144                            "$and": [
3145                                {
3146                                    "confidence": {
3147                                        "$gte": 0.6
3148                                    }
3149                                },
3150                                {
3151                                    "confidence": {
3152                                        "$lt": 0.8
3153                                    }
3154                                },
3155                                {
3156                                    "reviews": {
3157                                        "$gte": 5
3158                                    }
3159                                }
3160                            ]
3161                        }
3162                    ]
3163                }
3164            ]
3165        });
3166
3167        let filter: Filter = serde_json::from_value(advanced_json).unwrap();
3168        assert_eq!(filter.query_ids, None);
3169        assert!(filter.where_clause.is_some());
3170
3171        // Verify top-level structure
3172        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3173            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
3174            assert_eq!(composite.children.len(), 3);
3175        } else {
3176            panic!("Expected AND composite at top level");
3177        }
3178    }
3179
3180    #[test]
3181    fn test_limit_json_serialization() {
3182        let limit = Limit {
3183            offset: 10,
3184            limit: Some(20),
3185        };
3186
3187        let json = serde_json::to_string(&limit).unwrap();
3188        let deserialized: Limit = serde_json::from_str(&json).unwrap();
3189
3190        assert_eq!(deserialized.offset, limit.offset);
3191        assert_eq!(deserialized.limit, limit.limit);
3192    }
3193
3194    #[test]
3195    fn test_query_vector_json_serialization() {
3196        // Test dense vector
3197        let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
3198        let json = serde_json::to_string(&dense).unwrap();
3199        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3200        assert_eq!(deserialized, dense);
3201
3202        // Test sparse vector
3203        let sparse =
3204            QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap());
3205        let json = serde_json::to_string(&sparse).unwrap();
3206        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3207        assert_eq!(deserialized, sparse);
3208    }
3209
3210    #[test]
3211    fn test_select_key_json_serialization() {
3212        use std::collections::HashSet;
3213
3214        // Test predefined keys
3215        let doc_key = Key::Document;
3216        assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
3217
3218        let embed_key = Key::Embedding;
3219        assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
3220
3221        let meta_key = Key::Metadata;
3222        assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
3223
3224        let score_key = Key::Score;
3225        assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
3226
3227        // Test metadata key
3228        let custom_key = Key::MetadataField("custom_key".to_string());
3229        assert_eq!(
3230            serde_json::to_string(&custom_key).unwrap(),
3231            "\"custom_key\""
3232        );
3233
3234        // Test deserialization
3235        let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
3236        assert!(matches!(deserialized, Key::Document));
3237
3238        let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
3239        assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
3240
3241        // Test Select struct with multiple keys
3242        let mut keys = HashSet::new();
3243        keys.insert(Key::Document);
3244        keys.insert(Key::Embedding);
3245        keys.insert(Key::MetadataField("author".to_string()));
3246
3247        let select = Select { keys };
3248        let json = serde_json::to_string(&select).unwrap();
3249        let deserialized: Select = serde_json::from_str(&json).unwrap();
3250
3251        assert_eq!(deserialized.keys.len(), 3);
3252        assert!(deserialized.keys.contains(&Key::Document));
3253        assert!(deserialized.keys.contains(&Key::Embedding));
3254        assert!(deserialized
3255            .keys
3256            .contains(&Key::MetadataField("author".to_string())));
3257    }
3258
3259    #[test]
3260    fn test_merge_basic_integers() {
3261        use std::cmp::Reverse;
3262
3263        let merge = Merge { k: 5 };
3264
3265        // Input: sorted vectors of Reverse(u32) - ascending order of inner values
3266        let input = vec![
3267            vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
3268            vec![Reverse(2), Reverse(5), Reverse(8)],
3269            vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
3270        ];
3271
3272        let result = merge.merge(input);
3273
3274        // Should get top-5 smallest values (largest Reverse values)
3275        assert_eq!(result.len(), 5);
3276        assert_eq!(
3277            result,
3278            vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
3279        );
3280    }
3281
3282    #[test]
3283    fn test_merge_u32_descending() {
3284        let merge = Merge { k: 6 };
3285
3286        // Regular u32 in descending order (largest first)
3287        let input = vec![
3288            vec![100u32, 75, 50, 25],
3289            vec![90, 60, 30],
3290            vec![95, 85, 70, 40, 10],
3291        ];
3292
3293        let result = merge.merge(input);
3294
3295        // Should get top-6 largest u32 values
3296        assert_eq!(result.len(), 6);
3297        assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
3298    }
3299
3300    #[test]
3301    fn test_merge_i32_descending() {
3302        let merge = Merge { k: 5 };
3303
3304        // i32 values in descending order (including negatives)
3305        let input = vec![
3306            vec![50i32, 10, -10, -50],
3307            vec![30, 0, -30],
3308            vec![40, 20, -20, -40],
3309        ];
3310
3311        let result = merge.merge(input);
3312
3313        // Should get top-5 largest i32 values
3314        assert_eq!(result.len(), 5);
3315        assert_eq!(result, vec![50, 40, 30, 20, 10]);
3316    }
3317
3318    #[test]
3319    fn test_merge_with_duplicates() {
3320        let merge = Merge { k: 10 };
3321
3322        // Input with duplicates using regular u32 in descending order
3323        let input = vec![
3324            vec![100u32, 80, 80, 60, 40],
3325            vec![90, 80, 50, 30],
3326            vec![100, 70, 60, 20],
3327        ];
3328
3329        let result = merge.merge(input);
3330
3331        // Duplicates should be removed
3332        assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
3333    }
3334
3335    #[test]
3336    fn test_merge_empty_vectors() {
3337        let merge = Merge { k: 5 };
3338
3339        // All empty with u32
3340        let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
3341        let result = merge.merge(input);
3342        assert_eq!(result.len(), 0);
3343
3344        // Some empty, some with data (u64)
3345        let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
3346        let result = merge.merge(input);
3347        assert_eq!(result, vec![1000, 850, 750, 600, 500]);
3348
3349        // Single non-empty vector (i32)
3350        let input = vec![vec![], vec![100i32, 50, 25], vec![]];
3351        let result = merge.merge(input);
3352        assert_eq!(result, vec![100, 50, 25]);
3353    }
3354
3355    #[test]
3356    fn test_merge_k_boundary_conditions() {
3357        // k = 0 with u32
3358        let merge = Merge { k: 0 };
3359        let input = vec![vec![100u32, 50], vec![75, 25]];
3360        let result = merge.merge(input);
3361        assert_eq!(result.len(), 0);
3362
3363        // k = 1 with i64
3364        let merge = Merge { k: 1 };
3365        let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
3366        let result = merge.merge(input);
3367        assert_eq!(result, vec![1000]);
3368
3369        // k larger than total unique elements with u128
3370        let merge = Merge { k: 100 };
3371        let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
3372        let result = merge.merge(input);
3373        assert_eq!(result, vec![10000, 8000, 5000, 3000]);
3374    }
3375
3376    #[test]
3377    fn test_merge_with_strings() {
3378        let merge = Merge { k: 4 };
3379
3380        // Strings must be sorted in descending order (largest first) for the max heap merge
3381        let input = vec![
3382            vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
3383            vec!["elephant".to_string(), "banana".to_string()],
3384            vec!["fish".to_string(), "cat".to_string()],
3385        ];
3386
3387        let result = merge.merge(input);
3388
3389        // Should get top-4 lexicographically largest strings
3390        assert_eq!(
3391            result,
3392            vec![
3393                "zebra".to_string(),
3394                "fish".to_string(),
3395                "elephant".to_string(),
3396                "dog".to_string()
3397            ]
3398        );
3399    }
3400
3401    #[test]
3402    fn test_merge_with_custom_struct() {
3403        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
3404        struct Score {
3405            value: i32,
3406            id: String,
3407        }
3408
3409        let merge = Merge { k: 3 };
3410
3411        // Custom structs sorted by value (descending), then by id
3412        let input = vec![
3413            vec![
3414                Score {
3415                    value: 100,
3416                    id: "a".to_string(),
3417                },
3418                Score {
3419                    value: 80,
3420                    id: "b".to_string(),
3421                },
3422                Score {
3423                    value: 60,
3424                    id: "c".to_string(),
3425                },
3426            ],
3427            vec![
3428                Score {
3429                    value: 90,
3430                    id: "d".to_string(),
3431                },
3432                Score {
3433                    value: 70,
3434                    id: "e".to_string(),
3435                },
3436            ],
3437            vec![
3438                Score {
3439                    value: 95,
3440                    id: "f".to_string(),
3441                },
3442                Score {
3443                    value: 85,
3444                    id: "g".to_string(),
3445                },
3446            ],
3447        ];
3448
3449        let result = merge.merge(input);
3450
3451        assert_eq!(result.len(), 3);
3452        assert_eq!(
3453            result[0],
3454            Score {
3455                value: 100,
3456                id: "a".to_string()
3457            }
3458        );
3459        assert_eq!(
3460            result[1],
3461            Score {
3462                value: 95,
3463                id: "f".to_string()
3464            }
3465        );
3466        assert_eq!(
3467            result[2],
3468            Score {
3469                value: 90,
3470                id: "d".to_string()
3471            }
3472        );
3473    }
3474
3475    #[test]
3476    fn test_merge_preserves_order() {
3477        use std::cmp::Reverse;
3478
3479        let merge = Merge { k: 10 };
3480
3481        // For Reverse, smaller inner values are "larger" in ordering
3482        // So vectors should be sorted with smallest inner values first
3483        let input = vec![
3484            vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
3485            vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
3486            vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
3487        ];
3488
3489        let result = merge.merge(input);
3490
3491        // Verify output maintains order - should be sorted by Reverse ordering
3492        // which means ascending inner values
3493        for i in 1..result.len() {
3494            assert!(
3495                result[i - 1] >= result[i],
3496                "Output should be in descending Reverse order"
3497            );
3498            assert!(
3499                result[i - 1].0 <= result[i].0,
3500                "Inner values should be in ascending order"
3501            );
3502        }
3503
3504        // Check we got the right elements
3505        assert_eq!(
3506            result,
3507            vec![
3508                Reverse(1),
3509                Reverse(2),
3510                Reverse(3),
3511                Reverse(4),
3512                Reverse(5),
3513                Reverse(6),
3514                Reverse(7),
3515                Reverse(8),
3516                Reverse(9),
3517                Reverse(10)
3518            ]
3519        );
3520    }
3521
3522    #[test]
3523    fn test_merge_single_vector() {
3524        let merge = Merge { k: 3 };
3525
3526        // Single vector input with u64
3527        let input = vec![vec![1000u64, 800, 600, 400, 200]];
3528
3529        let result = merge.merge(input);
3530
3531        assert_eq!(result, vec![1000, 800, 600]);
3532    }
3533
3534    #[test]
3535    fn test_merge_all_same_values() {
3536        let merge = Merge { k: 5 };
3537
3538        // All vectors contain the same value (using i16)
3539        let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];
3540
3541        let result = merge.merge(input);
3542
3543        // Should deduplicate to single value
3544        assert_eq!(result, vec![42]);
3545    }
3546
3547    #[test]
3548    fn test_merge_mixed_types_sizes() {
3549        // Test with usize (common in real usage)
3550        let merge = Merge { k: 4 };
3551        let input = vec![
3552            vec![1000usize, 500, 100],
3553            vec![800, 300],
3554            vec![900, 600, 200],
3555        ];
3556        let result = merge.merge(input);
3557        assert_eq!(result, vec![1000, 900, 800, 600]);
3558
3559        // Test with negative integers (i32)
3560        let merge = Merge { k: 5 };
3561        let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
3562        let result = merge.merge(input);
3563        assert_eq!(result, vec![15, 10, 5, 0, -5]);
3564    }
3565
3566    #[test]
3567    fn test_aggregate_json_serialization() {
3568        // Test MinK serialization
3569        let min_k = Aggregate::MinK {
3570            keys: vec![Key::Score, Key::field("date")],
3571            k: 3,
3572        };
3573        let json = serde_json::to_value(&min_k).unwrap();
3574        assert!(json.get("$min_k").is_some());
3575        assert_eq!(json["$min_k"]["k"], 3);
3576
3577        // Test MinK deserialization
3578        let min_k_json = serde_json::json!({
3579            "$min_k": {
3580                "keys": ["#score", "date"],
3581                "k": 5
3582            }
3583        });
3584        let deserialized: Aggregate = serde_json::from_value(min_k_json).unwrap();
3585        match deserialized {
3586            Aggregate::MinK { keys, k } => {
3587                assert_eq!(k, 5);
3588                assert_eq!(keys.len(), 2);
3589                assert_eq!(keys[0], Key::Score);
3590                assert_eq!(keys[1], Key::field("date"));
3591            }
3592            _ => panic!("Expected MinK"),
3593        }
3594
3595        // Test MaxK serialization
3596        let max_k = Aggregate::MaxK {
3597            keys: vec![Key::field("timestamp")],
3598            k: 10,
3599        };
3600        let json = serde_json::to_value(&max_k).unwrap();
3601        assert!(json.get("$max_k").is_some());
3602        assert_eq!(json["$max_k"]["k"], 10);
3603
3604        // Test MaxK deserialization
3605        let max_k_json = serde_json::json!({
3606            "$max_k": {
3607                "keys": ["timestamp"],
3608                "k": 2
3609            }
3610        });
3611        let deserialized: Aggregate = serde_json::from_value(max_k_json).unwrap();
3612        match deserialized {
3613            Aggregate::MaxK { keys, k } => {
3614                assert_eq!(k, 2);
3615                assert_eq!(keys.len(), 1);
3616                assert_eq!(keys[0], Key::field("timestamp"));
3617            }
3618            _ => panic!("Expected MaxK"),
3619        }
3620    }
3621
3622    #[test]
3623    fn test_group_by_json_serialization() {
3624        // Test GroupBy with MinK
3625        let group_by = GroupBy {
3626            keys: vec![Key::field("category"), Key::field("author")],
3627            aggregate: Some(Aggregate::MinK {
3628                keys: vec![Key::Score],
3629                k: 3,
3630            }),
3631        };
3632
3633        let json = serde_json::to_value(&group_by).unwrap();
3634        assert_eq!(json["keys"].as_array().unwrap().len(), 2);
3635        assert!(json["aggregate"]["$min_k"].is_object());
3636
3637        // Test roundtrip
3638        let deserialized: GroupBy = serde_json::from_value(json).unwrap();
3639        assert_eq!(deserialized.keys.len(), 2);
3640        assert_eq!(deserialized.keys[0], Key::field("category"));
3641        assert_eq!(deserialized.keys[1], Key::field("author"));
3642        assert!(deserialized.aggregate.is_some());
3643
3644        // Test empty GroupBy
3645        let empty_group_by = GroupBy::default();
3646        let json = serde_json::to_value(&empty_group_by).unwrap();
3647        let deserialized: GroupBy = serde_json::from_value(json).unwrap();
3648        assert!(deserialized.keys.is_empty());
3649        assert!(deserialized.aggregate.is_none());
3650
3651        // Test deserialization from JSON
3652        let json = serde_json::json!({
3653            "keys": ["category"],
3654            "aggregate": {
3655                "$max_k": {
3656                    "keys": ["#score", "priority"],
3657                    "k": 5
3658                }
3659            }
3660        });
3661        let group_by: GroupBy = serde_json::from_value(json).unwrap();
3662        assert_eq!(group_by.keys.len(), 1);
3663        assert_eq!(group_by.keys[0], Key::field("category"));
3664        match group_by.aggregate {
3665            Some(Aggregate::MaxK { keys, k }) => {
3666                assert_eq!(k, 5);
3667                assert_eq!(keys.len(), 2);
3668                assert_eq!(keys[0], Key::Score);
3669            }
3670            _ => panic!("Expected MaxK aggregate"),
3671        }
3672    }
3673}