chroma_types/execution/operator.rs

1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4    cmp::Ordering,
5    collections::{BinaryHeap, HashSet},
6    fmt,
7    hash::Hash,
8    ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13    chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14    DocumentExpression, DocumentOperator, Metadata, MetadataComparison, MetadataExpression,
15    MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding, SetOperator, SparseVector,
16    Where,
17};
18
19use super::error::QueryConversionError;
20
21pub type InitialInput = ();
22
/// The `Scan` operator pins the data used by all downstream operators
///
/// # Parameters
/// - `collection_and_segments`: The consistent snapshot of the collection and its segments
27#[derive(Clone, Debug)]
28pub struct Scan {
29    pub collection_and_segments: CollectionAndSegments,
30}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33    type Error = QueryConversionError;
34
35    fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36        Ok(Self {
37            collection_and_segments: CollectionAndSegments {
38                collection: value
39                    .collection
40                    .ok_or(QueryConversionError::field("collection"))?
41                    .try_into()?,
42                metadata_segment: value
43                    .metadata
44                    .ok_or(QueryConversionError::field("metadata segment"))?
45                    .try_into()?,
46                record_segment: value
47                    .record
48                    .ok_or(QueryConversionError::field("record segment"))?
49                    .try_into()?,
50                vector_segment: value
51                    .knn
52                    .ok_or(QueryConversionError::field("vector segment"))?
53                    .try_into()?,
54            },
55        })
56    }
57}
58
59#[derive(Debug, Error)]
60pub enum ScanToProtoError {
61    #[error("Could not convert collection to proto")]
62    CollectionToProto(#[from] crate::CollectionToProtoError),
63}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66    type Error = ScanToProtoError;
67
68    fn try_from(value: Scan) -> Result<Self, Self::Error> {
69        Ok(Self {
70            collection: Some(value.collection_and_segments.collection.try_into()?),
71            knn: Some(value.collection_and_segments.vector_segment.into()),
72            metadata: Some(value.collection_and_segments.metadata_segment.into()),
73            record: Some(value.collection_and_segments.record_segment.into()),
74        })
75    }
76}
77
78#[derive(Clone, Debug)]
79pub struct CountResult {
80    pub count: u32,
81    pub pulled_log_bytes: u64,
82}
83
84impl CountResult {
85    pub fn size_bytes(&self) -> u64 {
86        size_of_val(&self.count) as u64
87    }
88}
89
90impl From<chroma_proto::CountResult> for CountResult {
91    fn from(value: chroma_proto::CountResult) -> Self {
92        Self {
93            count: value.count,
94            pulled_log_bytes: value.pulled_log_bytes,
95        }
96    }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100    fn from(value: CountResult) -> Self {
101        Self {
102            count: value.count,
103            pulled_log_bytes: value.pulled_log_bytes,
104        }
105    }
106}
107
108/// The `FetchLog` operator fetches logs from the log service
109///
110/// # Parameters
111/// - `start_log_offset_id`: The offset id of the first log to read
112/// - `maximum_fetch_count`: The maximum number of logs to fetch in total
113/// - `collection_uuid`: The uuid of the collection where the fetched logs should belong
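///
/// # Example
///
/// A minimal construction sketch; it assumes `CollectionUuid::new()` yields a fresh collection id.
///
/// ```
/// use chroma_types::CollectionUuid;
/// use chroma_types::operator::FetchLog;
///
/// // Fetch at most 100 logs starting from offset 0.
/// let fetch_log = FetchLog {
///     collection_uuid: CollectionUuid::new(), // assumed constructor for a fresh id
///     maximum_fetch_count: Some(100),
///     start_log_offset_id: 0,
/// };
/// ```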
114#[derive(Clone, Debug)]
115pub struct FetchLog {
116    pub collection_uuid: CollectionUuid,
117    pub maximum_fetch_count: Option<u32>,
118    pub start_log_offset_id: u32,
119}
120
121/// Filter the search results.
122///
123/// Combines document ID filtering with metadata and document content predicates.
124/// For the Search API, use `where_clause` with Key expressions.
125///
126/// # Fields
127///
128/// * `query_ids` - Optional list of document IDs to filter (legacy, prefer Where expressions)
129/// * `where_clause` - Predicate on document metadata, content, or IDs
130///
131/// # Examples
132///
133/// ## Simple metadata filter
134///
135/// ```
136/// use chroma_types::operator::{Filter, Key};
137///
138/// let filter = Filter {
139///     query_ids: None,
140///     where_clause: Some(Key::field("status").eq("published")),
141/// };
142/// ```
143///
144/// ## Combined filters
145///
146/// ```
147/// use chroma_types::operator::{Filter, Key};
148///
149/// let filter = Filter {
150///     query_ids: None,
151///     where_clause: Some(
152///         Key::field("status").eq("published")
153///             & Key::field("year").gte(2020)
154///             & Key::field("category").is_in(vec!["tech", "science"])
155///     ),
156/// };
157/// ```
158///
159/// ## Document content filter
160///
161/// ```
162/// use chroma_types::operator::{Filter, Key};
163///
164/// let filter = Filter {
165///     query_ids: None,
166///     where_clause: Some(Key::Document.contains("machine learning")),
167/// };
168/// ```
169#[derive(Clone, Debug, Default)]
170pub struct Filter {
171    pub query_ids: Option<Vec<String>>,
172    pub where_clause: Option<Where>,
173}
174
175impl Serialize for Filter {
176    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177    where
178        S: Serializer,
179    {
180        // For the search API, serialize directly as the where clause (or empty object if None)
181        // If query_ids are present, they should be combined with the where_clause as Key::ID.is_in([...])
182
183        match (&self.query_ids, &self.where_clause) {
184            (None, None) => {
185                // No filter at all - serialize empty object
186                let map = serializer.serialize_map(Some(0))?;
187                map.end()
188            }
189            (None, Some(where_clause)) => {
190                // Only where clause - serialize it directly
191                where_clause.serialize(serializer)
192            }
193            (Some(ids), None) => {
194                // Only query_ids - create Where clause: Key::ID.is_in(ids)
195                let id_where = Where::Metadata(MetadataExpression {
196                    key: "#id".to_string(),
197                    comparison: MetadataComparison::Set(
198                        SetOperator::In,
199                        MetadataSetValue::Str(ids.clone()),
200                    ),
201                });
202                id_where.serialize(serializer)
203            }
204            (Some(ids), Some(where_clause)) => {
205                // Both present - combine with AND: Key::ID.is_in(ids) & where_clause
206                let id_where = Where::Metadata(MetadataExpression {
207                    key: "#id".to_string(),
208                    comparison: MetadataComparison::Set(
209                        SetOperator::In,
210                        MetadataSetValue::Str(ids.clone()),
211                    ),
212                });
213                let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214                combined.serialize(serializer)
215            }
216        }
217    }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222    where
223        D: Deserializer<'de>,
224    {
225        // For the new search API, the entire JSON is the where clause
226        let where_json = Value::deserialize(deserializer)?;
227        let where_clause =
228            if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229                None
230            } else {
231                Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232            };
233
234        Ok(Filter {
235            query_ids: None, // Always None for new search API
236            where_clause,
237        })
238    }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242    type Error = QueryConversionError;
243
244    fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245        let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246        let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247        let where_clause = match (where_metadata, where_document) {
248            (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249            (Some(w), None) | (None, Some(w)) => Some(w),
250            _ => None,
251        };
252
253        Ok(Self {
254            query_ids: value.ids.map(|uids| uids.ids),
255            where_clause,
256        })
257    }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261    type Error = QueryConversionError;
262
263    fn try_from(value: Filter) -> Result<Self, Self::Error> {
264        Ok(Self {
265            ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266            r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267            where_document: None,
268        })
269    }
270}
271
/// The `Knn` operator searches for the nearest neighbours of the specified embedding. It is intended for use by the executor.
273///
274/// # Parameters
275/// - `embedding`: The target embedding to search around
276/// - `fetch`: The number of records to fetch around the target
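///
/// # Example
///
/// A minimal construction sketch:
///
/// ```
/// use chroma_types::operator::Knn;
///
/// // Fetch the 10 records nearest to the target embedding.
/// let knn = Knn {
///     embedding: vec![0.1, 0.2, 0.3],
///     fetch: 10,
/// };
/// ```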
277#[derive(Clone, Debug)]
278pub struct Knn {
279    pub embedding: Vec<f32>,
280    pub fetch: u32,
281}
282
283impl From<KnnBatch> for Vec<Knn> {
284    fn from(value: KnnBatch) -> Self {
285        value
286            .embeddings
287            .into_iter()
288            .map(|embedding| Knn {
289                embedding,
290                fetch: value.fetch,
291            })
292            .collect()
293    }
294}
295
/// The `KnnBatch` operator searches for the nearest neighbours of each of the specified embeddings. It is intended for use by the frontend.
///
/// # Parameters
/// - `embeddings`: The target embeddings to search around
/// - `fetch`: The number of records to fetch around each target
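///
/// # Example
///
/// A minimal construction sketch with two query embeddings:
///
/// ```
/// use chroma_types::operator::KnnBatch;
///
/// let knn_batch = KnnBatch {
///     embeddings: vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
///     fetch: 10,
/// };
/// ```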
301#[derive(Clone, Debug)]
302pub struct KnnBatch {
303    pub embeddings: Vec<Vec<f32>>,
304    pub fetch: u32,
305}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308    type Error = QueryConversionError;
309
310    fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311        Ok(Self {
312            embeddings: value
313                .embeddings
314                .into_iter()
315                .map(|vec| vec.try_into().map(|(v, _)| v))
316                .collect::<Result<_, _>>()?,
317            fetch: value.fetch,
318        })
319    }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323    type Error = QueryConversionError;
324
325    fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326        Ok(Self {
327            embeddings: value
328                .embeddings
329                .into_iter()
330                .map(|embedding| {
331                    let dim = embedding.len();
332                    chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333                })
334                .collect::<Result<_, _>>()?,
335            fetch: value.fetch,
336        })
337    }
338}
339
340/// Pagination control for search results.
341///
342/// Controls how many results to return and how many to skip for pagination.
343///
344/// # Fields
345///
346/// * `offset` - Number of results to skip (default: 0)
347/// * `limit` - Maximum results to return (None = no limit)
348///
349/// # Examples
350///
351/// ```
352/// use chroma_types::operator::Limit;
353///
354/// // First page: results 0-9
355/// let limit = Limit {
356///     offset: 0,
357///     limit: Some(10),
358/// };
359///
360/// // Second page: results 10-19
361/// let limit = Limit {
362///     offset: 10,
363///     limit: Some(10),
364/// };
365///
366/// // No limit: all results
367/// let limit = Limit {
368///     offset: 0,
369///     limit: None,
370/// };
371/// ```
372#[derive(Clone, Debug, Default, Deserialize, Serialize)]
373pub struct Limit {
374    #[serde(default)]
375    pub offset: u32,
376    #[serde(default)]
377    pub limit: Option<u32>,
378}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381    fn from(value: chroma_proto::LimitOperator) -> Self {
382        Self {
383            offset: value.offset,
384            limit: value.limit,
385        }
386    }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390    fn from(value: Limit) -> Self {
391        Self {
392            offset: value.offset,
393            limit: value.limit,
394        }
395    }
396}
397
/// A `RecordMeasure` represents the measure of an embedding (identified by `offset_id`) with respect to the query embedding
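///
/// Records are ordered by `measure` first and `offset_id` second, while equality
/// considers only `offset_id`. A minimal sketch:
///
/// ```
/// use chroma_types::operator::RecordMeasure;
///
/// let closer = RecordMeasure { offset_id: 1, measure: 0.2 };
/// let farther = RecordMeasure { offset_id: 2, measure: 0.8 };
/// assert!(closer < farther); // compared by `measure`, then by `offset_id`
/// ```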
399#[derive(Clone, Copy, Debug)]
400pub struct RecordMeasure {
401    pub offset_id: u32,
402    pub measure: f32,
403}
404
405impl PartialEq for RecordMeasure {
406    fn eq(&self, other: &Self) -> bool {
407        self.offset_id.eq(&other.offset_id)
408    }
409}
410
411impl Eq for RecordMeasure {}
412
413impl Ord for RecordMeasure {
414    fn cmp(&self, other: &Self) -> Ordering {
415        self.measure
416            .total_cmp(&other.measure)
417            .then_with(|| self.offset_id.cmp(&other.offset_id))
418    }
419}
420
421impl PartialOrd for RecordMeasure {
422    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
423        Some(self.cmp(other))
424    }
425}
426
427#[derive(Debug, Default)]
428pub struct KnnOutput {
429    pub distances: Vec<RecordMeasure>,
430}
431
/// The `Merge` operator selects the top records from a batch of record vectors,
/// each of which is sorted in descending order. If the same record occurs multiple
/// times, only one copy remains in the final result.
435///
436/// # Parameters
437/// - `k`: The total number of records to take after merge
438///
439/// # Usage
440/// It can be used to merge the query results from different operators
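///
/// # Example
///
/// A minimal sketch over plain integers (any `Eq + Ord` item type works):
///
/// ```
/// use chroma_types::operator::Merge;
///
/// let merge = Merge { k: 4 };
/// // Each inner vector must already be sorted in descending order.
/// let merged = merge.merge(vec![vec![9, 7, 3], vec![8, 7, 2]]);
/// assert_eq!(merged, vec![9, 8, 7, 3]); // duplicates collapse to one copy
/// ```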
441#[derive(Clone, Debug)]
442pub struct Merge {
443    pub k: u32,
444}
445
446impl Merge {
447    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
448        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();
449
450        let mut max_heap = batch_iters
451            .iter_mut()
452            .enumerate()
453            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
454            .collect::<BinaryHeap<_>>();
455
456        let mut fusion = Vec::with_capacity(self.k as usize);
457        while let Some((m, idx)) = max_heap.pop() {
458            if self.k <= fusion.len() as u32 {
459                break;
460            }
461            if let Some(next_m) = batch_iters[idx].next() {
462                max_heap.push((next_m, idx));
463            }
464            if fusion.last().is_some_and(|tail| tail == &m) {
465                continue;
466            }
467            fusion.push(m);
468        }
469        fusion
470    }
471}
472
473/// The `Projection` operator retrieves record content by offset ids
474///
475/// # Parameters
476/// - `document`: Whether to retrieve document
477/// - `embedding`: Whether to retrieve embedding
478/// - `metadata`: Whether to retrieve metadata
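///
/// # Example
///
/// A minimal construction sketch:
///
/// ```
/// use chroma_types::operator::Projection;
///
/// // Return documents and metadata, but skip embeddings.
/// let projection = Projection {
///     document: true,
///     embedding: false,
///     metadata: true,
/// };
/// ```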
479#[derive(Clone, Debug, Default)]
480pub struct Projection {
481    pub document: bool,
482    pub embedding: bool,
483    pub metadata: bool,
484}
485
486impl From<chroma_proto::ProjectionOperator> for Projection {
487    fn from(value: chroma_proto::ProjectionOperator) -> Self {
488        Self {
489            document: value.document,
490            embedding: value.embedding,
491            metadata: value.metadata,
492        }
493    }
494}
495
496impl From<Projection> for chroma_proto::ProjectionOperator {
497    fn from(value: Projection) -> Self {
498        Self {
499            document: value.document,
500            embedding: value.embedding,
501            metadata: value.metadata,
502        }
503    }
504}
505
506#[derive(Clone, Debug, PartialEq)]
507pub struct ProjectionRecord {
508    pub id: String,
509    pub document: Option<String>,
510    pub embedding: Option<Vec<f32>>,
511    pub metadata: Option<Metadata>,
512}
513
514impl ProjectionRecord {
515    pub fn size_bytes(&self) -> u64 {
516        (self.id.len()
517            + self
518                .document
519                .as_ref()
520                .map(|doc| doc.len())
521                .unwrap_or_default()
522            + self
523                .embedding
524                .as_ref()
525                .map(|emb| size_of_val(&emb[..]))
526                .unwrap_or_default()
527            + self
528                .metadata
529                .as_ref()
530                .map(logical_size_of_metadata)
531                .unwrap_or_default()) as u64
532    }
533}
534
535impl Eq for ProjectionRecord {}
536
537impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
538    type Error = QueryConversionError;
539
540    fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
541        Ok(Self {
542            id: value.id,
543            document: value.document,
544            embedding: value
545                .embedding
546                .map(|vec| vec.try_into().map(|(v, _)| v))
547                .transpose()?,
548            metadata: value.metadata.map(TryInto::try_into).transpose()?,
549        })
550    }
551}
552
553impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
554    type Error = QueryConversionError;
555
556    fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
557        Ok(Self {
558            id: value.id,
559            document: value.document,
560            embedding: value
561                .embedding
562                .map(|embedding| {
563                    let embedding_dimension = embedding.len();
564                    chroma_proto::Vector::try_from((
565                        embedding,
566                        ScalarEncoding::FLOAT32,
567                        embedding_dimension,
568                    ))
569                })
570                .transpose()?,
571            metadata: value.metadata.map(|metadata| metadata.into()),
572        })
573    }
574}
575
576#[derive(Clone, Debug, Eq, PartialEq)]
577pub struct ProjectionOutput {
578    pub records: Vec<ProjectionRecord>,
579}
580
581#[derive(Clone, Debug, Eq, PartialEq)]
582pub struct GetResult {
583    pub pulled_log_bytes: u64,
584    pub result: ProjectionOutput,
585}
586
587impl GetResult {
588    pub fn size_bytes(&self) -> u64 {
589        self.result
590            .records
591            .iter()
592            .map(ProjectionRecord::size_bytes)
593            .sum()
594    }
595}
596
597impl TryFrom<chroma_proto::GetResult> for GetResult {
598    type Error = QueryConversionError;
599
600    fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
601        Ok(Self {
602            pulled_log_bytes: value.pulled_log_bytes,
603            result: ProjectionOutput {
604                records: value
605                    .records
606                    .into_iter()
607                    .map(TryInto::try_into)
608                    .collect::<Result<_, _>>()?,
609            },
610        })
611    }
612}
613
614impl TryFrom<GetResult> for chroma_proto::GetResult {
615    type Error = QueryConversionError;
616
617    fn try_from(value: GetResult) -> Result<Self, Self::Error> {
618        Ok(Self {
619            pulled_log_bytes: value.pulled_log_bytes,
620            records: value
621                .result
622                .records
623                .into_iter()
624                .map(TryInto::try_into)
625                .collect::<Result<_, _>>()?,
626        })
627    }
628}
629
/// The `KnnProjection` operator retrieves record content by offset ids.
/// It builds on the `Projection` operator and additionally attaches each
/// record's distance to the target embedding.
633///
634/// # Parameters
635/// - `projection`: The parameters of the `ProjectionOperator`
636/// - `distance`: Whether to attach distance information
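///
/// # Example
///
/// A minimal construction sketch reusing the `Projection` parameters:
///
/// ```
/// use chroma_types::operator::{KnnProjection, Projection};
///
/// let knn_projection = KnnProjection {
///     projection: Projection {
///         document: true,
///         embedding: false,
///         metadata: true,
///     },
///     distance: true, // attach each record's distance to the target embedding
/// };
/// ```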
637#[derive(Clone, Debug)]
638pub struct KnnProjection {
639    pub projection: Projection,
640    pub distance: bool,
641}
642
643impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
644    type Error = QueryConversionError;
645
646    fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
647        Ok(Self {
648            projection: value
649                .projection
650                .ok_or(QueryConversionError::field("projection"))?
651                .into(),
652            distance: value.distance,
653        })
654    }
655}
656
657impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
658    fn from(value: KnnProjection) -> Self {
659        Self {
660            projection: Some(value.projection.into()),
661            distance: value.distance,
662        }
663    }
664}
665
666#[derive(Clone, Debug)]
667pub struct KnnProjectionRecord {
668    pub record: ProjectionRecord,
669    pub distance: Option<f32>,
670}
671
672impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
673    type Error = QueryConversionError;
674
675    fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
676        Ok(Self {
677            record: value
678                .record
679                .ok_or(QueryConversionError::field("record"))?
680                .try_into()?,
681            distance: value.distance,
682        })
683    }
684}
685
686impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
687    type Error = QueryConversionError;
688
689    fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
690        Ok(Self {
691            record: Some(value.record.try_into()?),
692            distance: value.distance,
693        })
694    }
695}
696
697#[derive(Clone, Debug, Default)]
698pub struct KnnProjectionOutput {
699    pub records: Vec<KnnProjectionRecord>,
700}
701
702impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
703    type Error = QueryConversionError;
704
705    fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
706        Ok(Self {
707            records: value
708                .records
709                .into_iter()
710                .map(TryInto::try_into)
711                .collect::<Result<_, _>>()?,
712        })
713    }
714}
715
716impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
717    type Error = QueryConversionError;
718
719    fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
720        Ok(Self {
721            records: value
722                .records
723                .into_iter()
724                .map(TryInto::try_into)
725                .collect::<Result<_, _>>()?,
726        })
727    }
728}
729
730#[derive(Clone, Debug, Default)]
731pub struct KnnBatchResult {
732    pub pulled_log_bytes: u64,
733    pub results: Vec<KnnProjectionOutput>,
734}
735
736impl KnnBatchResult {
737    pub fn size_bytes(&self) -> u64 {
738        self.results
739            .iter()
740            .flat_map(|res| {
741                res.records
742                    .iter()
743                    .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
744            })
745            .sum()
746    }
747}
748
749impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
750    type Error = QueryConversionError;
751
752    fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
753        Ok(Self {
754            pulled_log_bytes: value.pulled_log_bytes,
755            results: value
756                .results
757                .into_iter()
758                .map(TryInto::try_into)
759                .collect::<Result<_, _>>()?,
760        })
761    }
762}
763
764impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
765    type Error = QueryConversionError;
766
767    fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
768        Ok(Self {
769            pulled_log_bytes: value.pulled_log_bytes,
770            results: value
771                .results
772                .into_iter()
773                .map(TryInto::try_into)
774                .collect::<Result<_, _>>()?,
775        })
776    }
777}
778
779/// A query vector for KNN search.
780///
781/// Supports both dense and sparse vector formats.
782///
783/// # Variants
784///
785/// ## Dense
786///
787/// Standard dense embeddings as a vector of floats.
788///
789/// ```
790/// use chroma_types::operator::QueryVector;
791///
792/// let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3, 0.4]);
793/// ```
794///
795/// ## Sparse
796///
797/// Sparse vectors with explicit indices and values.
798///
799/// ```
800/// use chroma_types::operator::QueryVector;
801/// use chroma_types::SparseVector;
802///
803/// let sparse = QueryVector::Sparse(SparseVector::new(
804///     vec![0, 5, 10, 50],      // indices
805///     vec![0.5, 0.3, 0.8, 0.2], // values
806/// ).unwrap());
807/// ```
808///
809/// # Examples
810///
811/// ## Dense vector in KNN
812///
813/// ```
814/// use chroma_types::operator::{RankExpr, QueryVector, Key};
815///
816/// let rank = RankExpr::Knn {
817///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
818///     key: Key::Embedding,
819///     limit: 100,
820///     default: None,
821///     return_rank: false,
822/// };
823/// ```
824///
825/// ## Sparse vector in KNN
826///
827/// ```
828/// use chroma_types::operator::{RankExpr, QueryVector, Key};
829/// use chroma_types::SparseVector;
830///
831/// let rank = RankExpr::Knn {
832///     query: QueryVector::Sparse(SparseVector::new(
833///         vec![1, 5, 10],
834///         vec![0.5, 0.3, 0.8],
835///     ).unwrap()),
836///     key: Key::field("sparse_embedding"),
837///     limit: 100,
838///     default: None,
839///     return_rank: false,
840/// };
841/// ```
842#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
843#[serde(untagged)]
844pub enum QueryVector {
845    Dense(Vec<f32>),
846    Sparse(SparseVector),
847}
848
849impl TryFrom<chroma_proto::QueryVector> for QueryVector {
850    type Error = QueryConversionError;
851
852    fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
853        let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
854        match vector {
855            chroma_proto::query_vector::Vector::Dense(dense) => {
856                Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
857            }
858            chroma_proto::query_vector::Vector::Sparse(sparse) => {
859                Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| {
860                    QueryConversionError::validation("sparse vector length mismatch")
861                })?))
862            }
863        }
864    }
865}
866
867impl TryFrom<QueryVector> for chroma_proto::QueryVector {
868    type Error = QueryConversionError;
869
870    fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
871        match value {
872            QueryVector::Dense(vec) => {
873                let dim = vec.len();
874                Ok(chroma_proto::QueryVector {
875                    vector: Some(chroma_proto::query_vector::Vector::Dense(
876                        chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
877                    )),
878                })
879            }
880            QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
881                vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
882            }),
883        }
884    }
885}
886
887impl From<Vec<f32>> for QueryVector {
888    fn from(vec: Vec<f32>) -> Self {
889        QueryVector::Dense(vec)
890    }
891}
892
893impl From<SparseVector> for QueryVector {
894    fn from(sparse: SparseVector) -> Self {
895        QueryVector::Sparse(sparse)
896    }
897}
898
899#[derive(Clone, Debug, PartialEq)]
900pub struct KnnQuery {
901    pub query: QueryVector,
902    pub key: Key,
903    pub limit: u32,
904}
905
906/// Wrapper for ranking expressions in search queries.
907///
908/// Contains an optional ranking expression. When None, results are returned in
909/// natural storage order without scoring.
910///
911/// # Fields
912///
913/// * `expr` - The ranking expression (None = no ranking)
914///
915/// # Examples
916///
917/// ```
918/// use chroma_types::operator::{Rank, RankExpr, QueryVector, Key};
919///
920/// // With ranking
921/// let rank = Rank {
922///     expr: Some(RankExpr::Knn {
923///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
924///         key: Key::Embedding,
925///         limit: 100,
926///         default: None,
927///         return_rank: false,
928///     }),
929/// };
930///
931/// // No ranking (natural order)
932/// let rank = Rank {
933///     expr: None,
934/// };
935/// ```
936#[derive(Clone, Debug, Default, Deserialize, Serialize)]
937#[serde(transparent)]
938pub struct Rank {
939    pub expr: Option<RankExpr>,
940}
941
942impl Rank {
943    pub fn knn_queries(&self) -> Vec<KnnQuery> {
944        self.expr
945            .as_ref()
946            .map(RankExpr::knn_queries)
947            .unwrap_or_default()
948    }
949}
950
951impl TryFrom<chroma_proto::RankOperator> for Rank {
952    type Error = QueryConversionError;
953
954    fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
955        Ok(Rank {
956            expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
957        })
958    }
959}
960
961impl TryFrom<Rank> for chroma_proto::RankOperator {
962    type Error = QueryConversionError;
963
964    fn try_from(rank: Rank) -> Result<Self, Self::Error> {
965        Ok(chroma_proto::RankOperator {
966            expr: rank.expr.map(TryInto::try_into).transpose()?,
967        })
968    }
969}
970
971/// A ranking expression for scoring and ordering search results.
972///
973/// Ranking expressions determine which documents appear in results and their order.
974/// Lower scores indicate better matches (distance-based scoring).
975///
976/// # Variants
977///
978/// ## Knn - K-Nearest Neighbor Search
979///
980/// The primary ranking method for vector similarity search.
981///
982/// ```
983/// use chroma_types::operator::{RankExpr, QueryVector, Key};
984///
985/// let rank = RankExpr::Knn {
986///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
987///     key: Key::Embedding,
988///     limit: 100,        // Consider top 100 candidates
989///     default: None,     // No default score for missing documents
990///     return_rank: false, // Return distances, not rank positions
991/// };
992/// ```
993///
994/// ## Value - Constant
995///
996/// Represents a constant score.
997///
998/// ```
999/// use chroma_types::operator::RankExpr;
1000///
1001/// let rank = RankExpr::Value(0.5);
1002/// ```
1003///
1004/// ## Arithmetic Operations
1005///
1006/// Combine ranking expressions using standard operators (+, -, *, /).
1007///
1008/// ```
1009/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1010///
1011/// let knn1 = RankExpr::Knn {
1012///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1013///     key: Key::Embedding,
1014///     limit: 100,
1015///     default: None,
1016///     return_rank: false,
1017/// };
1018///
1019/// let knn2 = RankExpr::Knn {
1020///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
1021///     key: Key::field("other_embedding"),
1022///     limit: 100,
1023///     default: None,
1024///     return_rank: false,
1025/// };
1026///
1027/// // Weighted combination: 70% knn1 + 30% knn2
1028/// let combined = knn1 * 0.7 + knn2 * 0.3;
1029///
1030/// // Normalized
1031/// let normalized = combined / 2.0;
1032/// ```
1033///
1034/// ## Mathematical Functions
1035///
1036/// Apply mathematical transformations to scores.
1037///
1038/// ```
1039/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1040///
1041/// let knn = RankExpr::Knn {
1042///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1043///     key: Key::Embedding,
1044///     limit: 100,
1045///     default: None,
1046///     return_rank: false,
1047/// };
1048///
1049/// // Exponential - amplifies differences
1050/// let amplified = knn.clone().exp();
1051///
1052/// // Logarithm - compresses range (add constant to avoid log(0))
1053/// let compressed = (knn.clone() + 1.0).log();
1054///
1055/// // Absolute value
1056/// let absolute = knn.clone().abs();
1057///
1058/// // Min/Max - clamping
1059/// let clamped = knn.min(1.0).max(0.0);
1060/// ```
1061///
1062/// # Examples
1063///
1064/// ## Basic vector search
1065///
1066/// ```
1067/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1068///
1069/// let rank = RankExpr::Knn {
1070///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1071///     key: Key::Embedding,
1072///     limit: 100,
1073///     default: None,
1074///     return_rank: false,
1075/// };
1076/// ```
1077///
1078/// ## Hybrid search with weighted combination
1079///
1080/// ```
1081/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1082///
1083/// let dense = RankExpr::Knn {
1084///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1085///     key: Key::Embedding,
1086///     limit: 200,
1087///     default: None,
1088///     return_rank: false,
1089/// };
1090///
1091/// let sparse = RankExpr::Knn {
1092///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]), // Use sparse in practice
1093///     key: Key::field("sparse_embedding"),
1094///     limit: 200,
1095///     default: None,
1096///     return_rank: false,
1097/// };
1098///
1099/// // 70% semantic + 30% keyword
1100/// let hybrid = dense * 0.7 + sparse * 0.3;
1101/// ```
1102///
1103/// ## Reciprocal Rank Fusion (RRF)
1104///
1105/// Use the `rrf()` function for combining rankings with different score scales.
1106///
1107/// ```
1108/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
1109///
1110/// let dense = RankExpr::Knn {
1111///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1112///     key: Key::Embedding,
1113///     limit: 200,
1114///     default: None,
1115///     return_rank: true, // RRF requires rank positions
1116/// };
1117///
1118/// let sparse = RankExpr::Knn {
1119///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1120///     key: Key::field("sparse_embedding"),
1121///     limit: 200,
1122///     default: None,
1123///     return_rank: true, // RRF requires rank positions
1124/// };
1125///
1126/// let rrf_rank = rrf(
1127///     vec![dense, sparse],
1128///     Some(60),           // k parameter (smoothing)
1129///     Some(vec![0.7, 0.3]), // weights
1130///     false,              // normalize weights
1131/// ).unwrap();
1132/// ```
1133#[derive(Clone, Debug, Deserialize, Serialize)]
1134pub enum RankExpr {
1135    #[serde(rename = "$abs")]
1136    Absolute(Box<RankExpr>),
1137    #[serde(rename = "$div")]
1138    Division {
1139        left: Box<RankExpr>,
1140        right: Box<RankExpr>,
1141    },
1142    #[serde(rename = "$exp")]
1143    Exponentiation(Box<RankExpr>),
1144    #[serde(rename = "$knn")]
1145    Knn {
1146        query: QueryVector,
1147        #[serde(default = "RankExpr::default_knn_key")]
1148        key: Key,
1149        #[serde(default = "RankExpr::default_knn_limit")]
1150        limit: u32,
1151        #[serde(default)]
1152        default: Option<f32>,
1153        #[serde(default)]
1154        return_rank: bool,
1155    },
1156    #[serde(rename = "$log")]
1157    Logarithm(Box<RankExpr>),
1158    #[serde(rename = "$max")]
1159    Maximum(Vec<RankExpr>),
1160    #[serde(rename = "$min")]
1161    Minimum(Vec<RankExpr>),
1162    #[serde(rename = "$mul")]
1163    Multiplication(Vec<RankExpr>),
1164    #[serde(rename = "$sub")]
1165    Subtraction {
1166        left: Box<RankExpr>,
1167        right: Box<RankExpr>,
1168    },
1169    #[serde(rename = "$sum")]
1170    Summation(Vec<RankExpr>),
1171    #[serde(rename = "$val")]
1172    Value(f32),
1173}
1174
1175impl RankExpr {
1176    pub fn default_knn_key() -> Key {
1177        Key::Embedding
1178    }
1179
1180    pub fn default_knn_limit() -> u32 {
1181        16
1182    }
1183
1184    pub fn knn_queries(&self) -> Vec<KnnQuery> {
1185        match self {
1186            RankExpr::Absolute(expr)
1187            | RankExpr::Exponentiation(expr)
1188            | RankExpr::Logarithm(expr) => expr.knn_queries(),
1189            RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1190                .knn_queries()
1191                .into_iter()
1192                .chain(right.knn_queries())
1193                .collect(),
1194            RankExpr::Maximum(exprs)
1195            | RankExpr::Minimum(exprs)
1196            | RankExpr::Multiplication(exprs)
1197            | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1198            RankExpr::Value(_) => Vec::new(),
1199            RankExpr::Knn {
1200                query,
1201                key,
1202                limit,
1203                default: _,
1204                return_rank: _,
1205            } => vec![KnnQuery {
1206                query: query.clone(),
1207                key: key.clone(),
1208                limit: *limit,
1209            }],
1210        }
1211    }
1212
1213    /// Applies exponential transformation: e^rank.
1214    ///
1215    /// Amplifies differences between scores.
1216    ///
1217    /// # Examples
1218    ///
1219    /// ```
1220    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1221    ///
1222    /// let knn = RankExpr::Knn {
1223    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1224    ///     key: Key::Embedding,
1225    ///     limit: 100,
1226    ///     default: None,
1227    ///     return_rank: false,
1228    /// };
1229    ///
1230    /// let amplified = knn.exp();
1231    /// ```
1232    pub fn exp(self) -> Self {
1233        RankExpr::Exponentiation(Box::new(self))
1234    }
1235
1236    /// Applies natural logarithm transformation: ln(rank).
1237    ///
1238    /// Compresses the score range. Add a constant to avoid log(0).
1239    ///
1240    /// # Examples
1241    ///
1242    /// ```
1243    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1244    ///
1245    /// let knn = RankExpr::Knn {
1246    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1247    ///     key: Key::Embedding,
1248    ///     limit: 100,
1249    ///     default: None,
1250    ///     return_rank: false,
1251    /// };
1252    ///
1253    /// // Add constant to avoid log(0)
1254    /// let compressed = (knn + 1.0).log();
1255    /// ```
1256    pub fn log(self) -> Self {
1257        RankExpr::Logarithm(Box::new(self))
1258    }
1259
1260    /// Takes absolute value of the ranking expression.
1261    ///
1262    /// # Examples
1263    ///
1264    /// ```
1265    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1266    ///
1267    /// let knn1 = RankExpr::Knn {
1268    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1269    ///     key: Key::Embedding,
1270    ///     limit: 100,
1271    ///     default: None,
1272    ///     return_rank: false,
1273    /// };
1274    ///
1275    /// let knn2 = RankExpr::Knn {
1276    ///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
1277    ///     key: Key::field("other"),
1278    ///     limit: 100,
1279    ///     default: None,
1280    ///     return_rank: false,
1281    /// };
1282    ///
1283    /// // Absolute difference
1284    /// let diff = (knn1 - knn2).abs();
1285    /// ```
1286    pub fn abs(self) -> Self {
1287        RankExpr::Absolute(Box::new(self))
1288    }
1289
1290    /// Returns maximum of this expression and another.
1291    ///
    /// Taking the maximum with a constant enforces a lower bound (floor) on the score.
1293    ///
1294    /// # Examples
1295    ///
1296    /// ```
1297    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1298    ///
1299    /// let knn = RankExpr::Knn {
1300    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1301    ///     key: Key::Embedding,
1302    ///     limit: 100,
1303    ///     default: None,
1304    ///     return_rank: false,
1305    /// };
1306    ///
    /// // Enforce a lower bound of 1.0
    /// let floored = knn.clone().max(1.0);
    ///
    /// // Clamp to the range [0.0, 1.0]
    /// let range_clamped = knn.min(1.0).max(0.0);
1312    /// ```
1313    pub fn max(self, other: impl Into<RankExpr>) -> Self {
1314        let other = other.into();
1315
1316        match self {
1317            RankExpr::Maximum(mut exprs) => match other {
1318                RankExpr::Maximum(other_exprs) => {
1319                    exprs.extend(other_exprs);
1320                    RankExpr::Maximum(exprs)
1321                }
1322                _ => {
1323                    exprs.push(other);
1324                    RankExpr::Maximum(exprs)
1325                }
1326            },
1327            _ => match other {
1328                RankExpr::Maximum(mut exprs) => {
1329                    exprs.insert(0, self);
1330                    RankExpr::Maximum(exprs)
1331                }
1332                _ => RankExpr::Maximum(vec![self, other]),
1333            },
1334        }
1335    }
1336
1337    /// Returns minimum of this expression and another.
1338    ///
    /// Taking the minimum with a constant enforces an upper bound (cap) on the score.
1340    ///
1341    /// # Examples
1342    ///
1343    /// ```
1344    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1345    ///
1346    /// let knn = RankExpr::Knn {
1347    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1348    ///     key: Key::Embedding,
1349    ///     limit: 100,
1350    ///     default: None,
1351    ///     return_rank: false,
1352    /// };
1353    ///
    /// // Cap the score at 1.0
    /// let capped = knn.clone().min(1.0);
    ///
    /// // Clamp to the range [0.0, 1.0]
    /// let range_clamped = knn.min(1.0).max(0.0);
1359    /// ```
1360    pub fn min(self, other: impl Into<RankExpr>) -> Self {
1361        let other = other.into();
1362
1363        match self {
1364            RankExpr::Minimum(mut exprs) => match other {
1365                RankExpr::Minimum(other_exprs) => {
1366                    exprs.extend(other_exprs);
1367                    RankExpr::Minimum(exprs)
1368                }
1369                _ => {
1370                    exprs.push(other);
1371                    RankExpr::Minimum(exprs)
1372                }
1373            },
1374            _ => match other {
1375                RankExpr::Minimum(mut exprs) => {
1376                    exprs.insert(0, self);
1377                    RankExpr::Minimum(exprs)
1378                }
1379                _ => RankExpr::Minimum(vec![self, other]),
1380            },
1381        }
1382    }
1383}
1384
1385impl Add for RankExpr {
1386    type Output = RankExpr;
1387
1388    fn add(self, rhs: Self) -> Self::Output {
1389        match self {
1390            RankExpr::Summation(mut exprs) => match rhs {
1391                RankExpr::Summation(rhs_exprs) => {
1392                    exprs.extend(rhs_exprs);
1393                    RankExpr::Summation(exprs)
1394                }
1395                _ => {
1396                    exprs.push(rhs);
1397                    RankExpr::Summation(exprs)
1398                }
1399            },
1400            _ => match rhs {
1401                RankExpr::Summation(mut exprs) => {
1402                    exprs.insert(0, self);
1403                    RankExpr::Summation(exprs)
1404                }
1405                _ => RankExpr::Summation(vec![self, rhs]),
1406            },
1407        }
1408    }
1409}
1410
1411impl Add<f32> for RankExpr {
1412    type Output = RankExpr;
1413
1414    fn add(self, rhs: f32) -> Self::Output {
1415        self + RankExpr::Value(rhs)
1416    }
1417}
1418
1419impl Add<RankExpr> for f32 {
1420    type Output = RankExpr;
1421
1422    fn add(self, rhs: RankExpr) -> Self::Output {
1423        RankExpr::Value(self) + rhs
1424    }
1425}
1426
1427impl Sub for RankExpr {
1428    type Output = RankExpr;
1429
1430    fn sub(self, rhs: Self) -> Self::Output {
1431        RankExpr::Subtraction {
1432            left: Box::new(self),
1433            right: Box::new(rhs),
1434        }
1435    }
1436}
1437
1438impl Sub<f32> for RankExpr {
1439    type Output = RankExpr;
1440
1441    fn sub(self, rhs: f32) -> Self::Output {
1442        self - RankExpr::Value(rhs)
1443    }
1444}
1445
1446impl Sub<RankExpr> for f32 {
1447    type Output = RankExpr;
1448
1449    fn sub(self, rhs: RankExpr) -> Self::Output {
1450        RankExpr::Value(self) - rhs
1451    }
1452}
1453
1454impl Mul for RankExpr {
1455    type Output = RankExpr;
1456
1457    fn mul(self, rhs: Self) -> Self::Output {
1458        match self {
1459            RankExpr::Multiplication(mut exprs) => match rhs {
1460                RankExpr::Multiplication(rhs_exprs) => {
1461                    exprs.extend(rhs_exprs);
1462                    RankExpr::Multiplication(exprs)
1463                }
1464                _ => {
1465                    exprs.push(rhs);
1466                    RankExpr::Multiplication(exprs)
1467                }
1468            },
1469            _ => match rhs {
1470                RankExpr::Multiplication(mut exprs) => {
1471                    exprs.insert(0, self);
1472                    RankExpr::Multiplication(exprs)
1473                }
1474                _ => RankExpr::Multiplication(vec![self, rhs]),
1475            },
1476        }
1477    }
1478}
1479
1480impl Mul<f32> for RankExpr {
1481    type Output = RankExpr;
1482
1483    fn mul(self, rhs: f32) -> Self::Output {
1484        self * RankExpr::Value(rhs)
1485    }
1486}
1487
1488impl Mul<RankExpr> for f32 {
1489    type Output = RankExpr;
1490
1491    fn mul(self, rhs: RankExpr) -> Self::Output {
1492        RankExpr::Value(self) * rhs
1493    }
1494}
1495
1496impl Div for RankExpr {
1497    type Output = RankExpr;
1498
1499    fn div(self, rhs: Self) -> Self::Output {
1500        RankExpr::Division {
1501            left: Box::new(self),
1502            right: Box::new(rhs),
1503        }
1504    }
1505}
1506
1507impl Div<f32> for RankExpr {
1508    type Output = RankExpr;
1509
1510    fn div(self, rhs: f32) -> Self::Output {
1511        self / RankExpr::Value(rhs)
1512    }
1513}
1514
1515impl Div<RankExpr> for f32 {
1516    type Output = RankExpr;
1517
1518    fn div(self, rhs: RankExpr) -> Self::Output {
1519        RankExpr::Value(self) / rhs
1520    }
1521}
1522
1523impl Neg for RankExpr {
1524    type Output = RankExpr;
1525
1526    fn neg(self) -> Self::Output {
1527        RankExpr::Value(-1.0) * self
1528    }
1529}
1530
1531impl From<f32> for RankExpr {
1532    fn from(v: f32) -> Self {
1533        RankExpr::Value(v)
1534    }
1535}
1536
1537impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1538    type Error = QueryConversionError;
1539
1540    fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1541        match proto_expr.rank {
1542            Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1543                Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1544            }
1545            Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1546                let left = div.left.ok_or(QueryConversionError::field("left"))?;
1547                let right = div.right.ok_or(QueryConversionError::field("right"))?;
1548                Ok(RankExpr::Division {
1549                    left: Box::new(RankExpr::try_from(*left)?),
1550                    right: Box::new(RankExpr::try_from(*right)?),
1551                })
1552            }
1553            Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1554                RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1555            ),
1556            Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1557                let query = knn
1558                    .query
1559                    .ok_or(QueryConversionError::field("query"))?
1560                    .try_into()?;
1561                Ok(RankExpr::Knn {
1562                    query,
1563                    key: Key::from(knn.key),
1564                    limit: knn.limit,
1565                    default: knn.default,
1566                    return_rank: knn.return_rank,
1567                })
1568            }
1569            Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1570                Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1571            }
1572            Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1573                let exprs = max
1574                    .exprs
1575                    .into_iter()
1576                    .map(RankExpr::try_from)
1577                    .collect::<Result<Vec<_>, _>>()?;
1578                Ok(RankExpr::Maximum(exprs))
1579            }
1580            Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1581                let exprs = min
1582                    .exprs
1583                    .into_iter()
1584                    .map(RankExpr::try_from)
1585                    .collect::<Result<Vec<_>, _>>()?;
1586                Ok(RankExpr::Minimum(exprs))
1587            }
1588            Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1589                let exprs = mul
1590                    .exprs
1591                    .into_iter()
1592                    .map(RankExpr::try_from)
1593                    .collect::<Result<Vec<_>, _>>()?;
1594                Ok(RankExpr::Multiplication(exprs))
1595            }
1596            Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1597                let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1598                let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1599                Ok(RankExpr::Subtraction {
1600                    left: Box::new(RankExpr::try_from(*left)?),
1601                    right: Box::new(RankExpr::try_from(*right)?),
1602                })
1603            }
1604            Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1605                let exprs = sum
1606                    .exprs
1607                    .into_iter()
1608                    .map(RankExpr::try_from)
1609                    .collect::<Result<Vec<_>, _>>()?;
1610                Ok(RankExpr::Summation(exprs))
1611            }
1612            Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1613            None => Err(QueryConversionError::field("rank")),
1614        }
1615    }
1616}
1617
1618impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1619    type Error = QueryConversionError;
1620
1621    fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1622        let proto_rank = match rank_expr {
1623            RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1624                chroma_proto::RankExpr::try_from(*expr)?,
1625            )),
1626            RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1627                Box::new(chroma_proto::rank_expr::RankPair {
1628                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1629                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1630                }),
1631            ),
1632            RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1633                Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1634            ),
1635            RankExpr::Knn {
1636                query,
1637                key,
1638                limit,
1639                default,
1640                return_rank,
1641            } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1642                query: Some(query.try_into()?),
1643                key: key.to_string(),
1644                limit,
1645                default,
1646                return_rank,
1647            }),
1648            RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1649                chroma_proto::RankExpr::try_from(*expr)?,
1650            )),
1651            RankExpr::Maximum(exprs) => {
1652                let proto_exprs = exprs
1653                    .into_iter()
1654                    .map(chroma_proto::RankExpr::try_from)
1655                    .collect::<Result<Vec<_>, _>>()?;
1656                chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1657                    exprs: proto_exprs,
1658                })
1659            }
1660            RankExpr::Minimum(exprs) => {
1661                let proto_exprs = exprs
1662                    .into_iter()
1663                    .map(chroma_proto::RankExpr::try_from)
1664                    .collect::<Result<Vec<_>, _>>()?;
1665                chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1666                    exprs: proto_exprs,
1667                })
1668            }
1669            RankExpr::Multiplication(exprs) => {
1670                let proto_exprs = exprs
1671                    .into_iter()
1672                    .map(chroma_proto::RankExpr::try_from)
1673                    .collect::<Result<Vec<_>, _>>()?;
1674                chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1675                    exprs: proto_exprs,
1676                })
1677            }
1678            RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1679                Box::new(chroma_proto::rank_expr::RankPair {
1680                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1681                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1682                }),
1683            ),
1684            RankExpr::Summation(exprs) => {
1685                let proto_exprs = exprs
1686                    .into_iter()
1687                    .map(chroma_proto::RankExpr::try_from)
1688                    .collect::<Result<Vec<_>, _>>()?;
1689                chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1690                    exprs: proto_exprs,
1691                })
1692            }
1693            RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1694        };
1695
1696        Ok(chroma_proto::RankExpr {
1697            rank: Some(proto_rank),
1698        })
1699    }
1700}
1701
1702/// Represents a field key in search queries.
1703///
1704/// Used for both selecting fields to return and building filter expressions.
1705/// Predefined keys access special fields, while custom keys access metadata.
1706///
1707/// # Predefined Keys
1708///
1709/// - `Key::Document` - Document text content (`#document`)
1710/// - `Key::Embedding` - Vector embeddings (`#embedding`)
1711/// - `Key::Metadata` - All metadata fields (`#metadata`)
1712/// - `Key::Score` - Search scores (`#score`)
1713///
1714/// # Custom Keys
1715///
1716/// Use `Key::field()` or `Key::from()` to reference metadata fields:
1717///
1718/// ```
1719/// use chroma_types::operator::Key;
1720///
1721/// let key = Key::field("author");
1722/// let key = Key::from("title");
1723/// ```
1724///
1725/// # Examples
1726///
1727/// ## Building filters
1728///
1729/// ```
1730/// use chroma_types::operator::Key;
1731///
1732/// // Equality
1733/// let filter = Key::field("status").eq("published");
1734///
1735/// // Comparisons
1736/// let filter = Key::field("year").gte(2020);
1737/// let filter = Key::field("score").lt(0.9);
1738///
1739/// // Set operations
1740/// let filter = Key::field("category").is_in(vec!["tech", "science"]);
1741/// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
1742///
1743/// // Document content
1744/// let filter = Key::Document.contains("machine learning");
1745/// let filter = Key::Document.regex(r"\bAPI\b");
1746///
1747/// // Combining filters
1748/// let filter = Key::field("status").eq("published")
1749///     & Key::field("year").gte(2020);
1750/// ```
1751///
1752/// ## Selecting fields
1753///
1754/// ```
1755/// use chroma_types::plan::SearchPayload;
1756/// use chroma_types::operator::Key;
1757///
1758/// let search = SearchPayload::default()
1759///     .select([
1760///         Key::Document,
1761///         Key::Score,
1762///         Key::field("title"),
1763///         Key::field("author"),
1764///     ]);
1765/// ```
1766#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
1767#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
1768pub enum Key {
1769    // Predefined keys
1770    Document,
1771    Embedding,
1772    Metadata,
1773    Score,
1774    MetadataField(String),
1775}
1776
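// `Key` serializes to its plain string form: the predefined keys use the
// `#`-prefixed sentinels ("#document", "#embedding", "#metadata", "#score"),
// while metadata field keys pass through unchanged. `Deserialize` reverses the
// mapping via `Key::from`.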
1777impl Serialize for Key {
1778    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1779    where
1780        S: serde::Serializer,
1781    {
1782        match self {
1783            Key::Document => serializer.serialize_str("#document"),
1784            Key::Embedding => serializer.serialize_str("#embedding"),
1785            Key::Metadata => serializer.serialize_str("#metadata"),
1786            Key::Score => serializer.serialize_str("#score"),
1787            Key::MetadataField(field) => serializer.serialize_str(field),
1788        }
1789    }
1790}
1791
1792impl<'de> Deserialize<'de> for Key {
1793    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1794    where
1795        D: Deserializer<'de>,
1796    {
1797        let s = String::deserialize(deserializer)?;
1798        Ok(Key::from(s))
1799    }
1800}
1801
1802impl fmt::Display for Key {
1803    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1804        match self {
1805            Key::Document => write!(f, "#document"),
1806            Key::Embedding => write!(f, "#embedding"),
1807            Key::Metadata => write!(f, "#metadata"),
1808            Key::Score => write!(f, "#score"),
1809            Key::MetadataField(field) => write!(f, "{}", field),
1810        }
1811    }
1812}
1813
1814impl From<&str> for Key {
1815    fn from(s: &str) -> Self {
1816        match s {
1817            "#document" => Key::Document,
1818            "#embedding" => Key::Embedding,
1819            "#metadata" => Key::Metadata,
1820            "#score" => Key::Score,
1821            // Any other string is treated as a metadata field key
1822            field => Key::MetadataField(field.to_string()),
1823        }
1824    }
1825}
1826
1827impl From<String> for Key {
1828    fn from(s: String) -> Self {
1829        Key::from(s.as_str())
1830    }
1831}
1832
1833impl Key {
1834    /// Creates a Key for a custom metadata field.
1835    ///
1836    /// # Examples
1837    ///
1838    /// ```
1839    /// use chroma_types::operator::Key;
1840    ///
1841    /// let status = Key::field("status");
1842    /// let year = Key::field("year");
1843    /// let author = Key::field("author");
1844    /// ```
1845    pub fn field(name: impl Into<String>) -> Self {
1846        Key::MetadataField(name.into())
1847    }
1848
1849    /// Creates an equality filter: `field == value`.
1850    ///
1851    /// # Examples
1852    ///
1853    /// ```
1854    /// use chroma_types::operator::Key;
1855    ///
1856    /// // String equality
1857    /// let filter = Key::field("status").eq("published");
1858    ///
1859    /// // Numeric equality
1860    /// let filter = Key::field("count").eq(42);
1861    ///
1862    /// // Boolean equality
1863    /// let filter = Key::field("featured").eq(true);
1864    /// ```
1865    pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1866        Where::Metadata(MetadataExpression {
1867            key: self.to_string(),
1868            comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1869        })
1870    }
1871
1872    /// Creates an inequality filter: `field != value`.
1873    ///
1874    /// # Examples
1875    ///
1876    /// ```
1877    /// use chroma_types::operator::Key;
1878    ///
1879    /// let filter = Key::field("status").ne("deleted");
1880    /// let filter = Key::field("count").ne(0);
1881    /// ```
1882    pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1883        Where::Metadata(MetadataExpression {
1884            key: self.to_string(),
1885            comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1886        })
1887    }
1888
1889    /// Creates a greater-than filter: `field > value` (numeric only).
1890    ///
1891    /// # Examples
1892    ///
1893    /// ```
1894    /// use chroma_types::operator::Key;
1895    ///
1896    /// let filter = Key::field("score").gt(0.5);
1897    /// let filter = Key::field("year").gt(2020);
1898    /// ```
1899    pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1900        Where::Metadata(MetadataExpression {
1901            key: self.to_string(),
1902            comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1903        })
1904    }
1905
1906    /// Creates a greater-than-or-equal filter: `field >= value` (numeric only).
1907    ///
1908    /// # Examples
1909    ///
1910    /// ```
1911    /// use chroma_types::operator::Key;
1912    ///
1913    /// let filter = Key::field("score").gte(0.5);
1914    /// let filter = Key::field("year").gte(2020);
1915    /// ```
1916    pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1917        Where::Metadata(MetadataExpression {
1918            key: self.to_string(),
1919            comparison: MetadataComparison::Primitive(
1920                PrimitiveOperator::GreaterThanOrEqual,
1921                value.into(),
1922            ),
1923        })
1924    }
1925
1926    /// Creates a less-than filter: `field < value` (numeric only).
1927    ///
1928    /// # Examples
1929    ///
1930    /// ```
1931    /// use chroma_types::operator::Key;
1932    ///
1933    /// let filter = Key::field("score").lt(0.9);
1934    /// let filter = Key::field("year").lt(2025);
1935    /// ```
1936    pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1937        Where::Metadata(MetadataExpression {
1938            key: self.to_string(),
1939            comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1940        })
1941    }
1942
1943    /// Creates a less-than-or-equal filter: `field <= value` (numeric only).
1944    ///
1945    /// # Examples
1946    ///
1947    /// ```
1948    /// use chroma_types::operator::Key;
1949    ///
1950    /// let filter = Key::field("score").lte(0.9);
1951    /// let filter = Key::field("year").lte(2024);
1952    /// ```
1953    pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1954        Where::Metadata(MetadataExpression {
1955            key: self.to_string(),
1956            comparison: MetadataComparison::Primitive(
1957                PrimitiveOperator::LessThanOrEqual,
1958                value.into(),
1959            ),
1960        })
1961    }
1962
1963    /// Creates a set membership filter: `field IN values`.
1964    ///
1965    /// Accepts any iterator of owned values (e.g., a `Vec` or an array).
1966    ///
1967    /// # Examples
1968    ///
1969    /// ```
1970    /// use chroma_types::operator::Key;
1971    ///
1972    /// // With Vec
1973    /// let filter = Key::field("year").is_in(vec![2023, 2024, 2025]);
1974    ///
1975    /// // With array
1976    /// let filter = Key::field("category").is_in(["tech", "science", "math"]);
1977    ///
1978    /// // With owned strings
1979    /// let categories = vec!["tech".to_string(), "science".to_string()];
1980    /// let filter = Key::field("category").is_in(categories);
1981    /// ```
1982    pub fn is_in<I, T>(self, values: I) -> Where
1983    where
1984        I: IntoIterator<Item = T>,
1985        Vec<T>: Into<MetadataSetValue>,
1986    {
1987        let vec: Vec<T> = values.into_iter().collect();
1988        Where::Metadata(MetadataExpression {
1989            key: self.to_string(),
1990            comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1991        })
1992    }
1993
1994    /// Creates a set exclusion filter: `field NOT IN values`.
1995    ///
1996    /// Accepts any iterator of owned values (e.g., a `Vec` or an array).
1997    ///
1998    /// # Examples
1999    ///
2000    /// ```
2001    /// use chroma_types::operator::Key;
2002    ///
2003    /// // Exclude deleted and archived
2004    /// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
2005    ///
2006    /// // Exclude specific years
2007    /// let filter = Key::field("year").not_in(vec![2019, 2020]);
2008    /// ```
2009    pub fn not_in<I, T>(self, values: I) -> Where
2010    where
2011        I: IntoIterator<Item = T>,
2012        Vec<T>: Into<MetadataSetValue>,
2013    {
2014        let vec: Vec<T> = values.into_iter().collect();
2015        Where::Metadata(MetadataExpression {
2016            key: self.to_string(),
2017            comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2018        })
2019    }
2020
2021    /// Creates a substring filter (case-sensitive, document content only).
2022    ///
2023    /// Note: Currently only works with `Key::Document`. Pattern must have at least
2024    /// 3 literal characters for accurate results.
2025    ///
2026    /// # Examples
2027    ///
2028    /// ```
2029    /// use chroma_types::operator::Key;
2030    ///
2031    /// let filter = Key::Document.contains("machine learning");
2032    /// let filter = Key::Document.contains("API");
2033    /// ```
2034    pub fn contains<S: Into<String>>(self, text: S) -> Where {
2035        Where::Document(DocumentExpression {
2036            operator: DocumentOperator::Contains,
2037            pattern: text.into(),
2038        })
2039    }
2040
2041    /// Creates a negative substring filter (case-sensitive, document content only).
2042    ///
2043    /// Note: Currently only works with `Key::Document`.
2044    ///
2045    /// # Examples
2046    ///
2047    /// ```
2048    /// use chroma_types::operator::Key;
2049    ///
2050    /// let filter = Key::Document.not_contains("deprecated");
2051    /// let filter = Key::Document.not_contains("beta");
2052    /// ```
2053    pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2054        Where::Document(DocumentExpression {
2055            operator: DocumentOperator::NotContains,
2056            pattern: text.into(),
2057        })
2058    }
2059
2060    /// Creates a regex filter (case-sensitive, document content only).
2061    ///
2062    /// Note: Currently only works with `Key::Document`. Pattern must have at least
2063    /// 3 literal characters for accurate results.
2064    ///
2065    /// # Examples
2066    ///
2067    /// ```
2068    /// use chroma_types::operator::Key;
2069    ///
2070    /// // Match whole word "API"
2071    /// let filter = Key::Document.regex(r"\bAPI\b");
2072    ///
2073    /// // Match version pattern
2074    /// let filter = Key::Document.regex(r"v\d+\.\d+\.\d+");
2075    /// ```
2076    pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2077        Where::Document(DocumentExpression {
2078            operator: DocumentOperator::Regex,
2079            pattern: pattern.into(),
2080        })
2081    }
2082
2083    /// Creates a negative regex filter (case-sensitive, document content only).
2084    ///
2085    /// Note: Currently only works with `Key::Document`.
2086    ///
2087    /// # Examples
2088    ///
2089    /// ```
2090    /// use chroma_types::operator::Key;
2091    ///
2092    /// // Exclude beta versions
2093    /// let filter = Key::Document.not_regex(r"beta");
2094    ///
2095    /// // Exclude test documents
2096    /// let filter = Key::Document.not_regex(r"\btest\b");
2097    /// ```
2098    pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2099        Where::Document(DocumentExpression {
2100            operator: DocumentOperator::NotRegex,
2101            pattern: pattern.into(),
2102        })
2103    }
2104}
2105
2106/// Field selection for search results.
2107///
2108/// Specifies which fields to include in the results. IDs are always included.
2109///
2110/// # Fields
2111///
2112/// * `keys` - Set of keys to include in results
2113///
2114/// # Available Keys
2115///
2116/// * `Key::Document` - Document text content
2117/// * `Key::Embedding` - Vector embeddings
2118/// * `Key::Metadata` - All metadata fields
2119/// * `Key::Score` - Search scores
2120/// * `Key::field("name")` - Specific metadata field
2121///
2122/// # Performance
2123///
2124/// Selecting fewer fields improves performance by reducing data transfer:
2125/// - Minimal: IDs only (default, fastest)
2126/// - Moderate: Scores + specific metadata fields
2127/// - Heavy: Documents + embeddings (larger payloads)
2128///
2129/// # Examples
2130///
2131/// ```
2132/// use chroma_types::operator::{Select, Key};
2133/// use std::collections::HashSet;
2134///
2135/// // Select predefined fields
2136/// let select = Select {
2137///     keys: [Key::Document, Key::Score].into_iter().collect(),
2138/// };
2139///
2140/// // Select specific metadata fields
2141/// let select = Select {
2142///     keys: [
2143///         Key::field("title"),
2144///         Key::field("author"),
2145///         Key::Score,
2146///     ].into_iter().collect(),
2147/// };
2148///
2149/// // Select everything
2150/// let select = Select {
2151///     keys: [
2152///         Key::Document,
2153///         Key::Embedding,
2154///         Key::Metadata,
2155///         Key::Score,
2156///     ].into_iter().collect(),
2157/// };
2158/// ```
2159#[derive(Clone, Debug, Default, Deserialize, Serialize)]
2160pub struct Select {
2161    #[serde(default)]
2162    pub keys: HashSet<Key>,
2163}
2164
2165impl TryFrom<chroma_proto::SelectOperator> for Select {
2166    type Error = QueryConversionError;
2167
2168    fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2169        let keys = value
2170            .keys
2171            .into_iter()
2172            .map(|key| {
2173                // Try to deserialize each string as a Key
2174                serde_json::from_value(serde_json::Value::String(key))
2175                    .map_err(|_| QueryConversionError::field("keys"))
2176            })
2177            .collect::<Result<HashSet<_>, _>>()?;
2178
2179        Ok(Self { keys })
2180    }
2181}
2182
2183impl TryFrom<Select> for chroma_proto::SelectOperator {
2184    type Error = QueryConversionError;
2185
2186    fn try_from(value: Select) -> Result<Self, Self::Error> {
2187        let keys = value
2188            .keys
2189            .into_iter()
2190            .map(|key| {
2191                // Serialize each Key back to string
2192                serde_json::to_value(&key)
2193                    .ok()
2194                    .and_then(|v| v.as_str().map(String::from))
2195                    .ok_or(QueryConversionError::field("keys"))
2196            })
2197            .collect::<Result<Vec<_>, _>>()?;
2198
2199        Ok(Self { keys })
2200    }
2201}
2202
2203/// Aggregation function applied within each group.
2204///
2205/// Determines which records to keep from each group and their ordering.
2206///
2207/// # Variants
2208///
2209/// * `MinK` - Returns k records with minimum values (ascending order).
2210///   Use with `Key::Score` to get best matches (lower score = better in Chroma).
2211/// * `MaxK` - Returns k records with maximum values (descending order).
2212///
2213/// # Multi-level Ordering
2214///
2215/// The `keys` field supports multi-level ordering. Records are sorted by
2216/// the first key, then by the second key for ties, and so on.
2217///
2218/// # Examples
2219///
2220/// ```
2221/// use chroma_types::operator::{Aggregate, Key};
2222///
2223/// // Best 3 by score per group
2224/// let agg = Aggregate::MinK {
2225///     keys: vec![Key::Score],
2226///     k: 3,
2227/// };
2228///
2229/// // Best 3 by score, then by date for ties
2230/// let agg = Aggregate::MinK {
2231///     keys: vec![Key::Score, Key::field("date")],
2232///     k: 3,
2233/// };
2234///
2235/// // Top 5 by recency (highest date first)
2236/// let agg = Aggregate::MaxK {
2237///     keys: vec![Key::field("date")],
2238///     k: 5,
2239/// };
2240/// ```
2241#[derive(Clone, Debug, Deserialize, Serialize)]
2242pub enum Aggregate {
2243    /// Returns k records with minimum values (ascending order)
2244    #[serde(rename = "$min_k")]
2245    MinK {
2246        /// Keys for multi-level ordering
2247        keys: Vec<Key>,
2248        /// Number of records to return per group
2249        k: u32,
2250    },
2251    /// Returns k records with maximum values (descending order)
2252    #[serde(rename = "$max_k")]
2253    MaxK {
2254        /// Keys for multi-level ordering
2255        keys: Vec<Key>,
2256        /// Number of records to return per group
2257        k: u32,
2258    },
2259}
2260
2261/// Groups results by metadata keys and aggregates within each group.
2262///
2263/// Results are grouped by the specified metadata keys (like SQL GROUP BY),
2264/// then aggregated within each group using MinK or MaxK ordering.
2265/// The final output is flattened and sorted by score.
2266///
2267/// # Fields
2268///
2269/// * `keys` - Metadata keys to group by (composite grouping)
2270/// * `aggregate` - Aggregation function to apply within each group
2271///
2272/// # Behavior
2273///
2274/// * Missing metadata keys are treated as Null (forming their own group)
2275/// * Empty groups are omitted from results
2276/// * Final output is flattened (group structure not preserved)
2277/// * Results are sorted by score after aggregation
2278///
2279/// # Examples
2280///
2281/// ```
2282/// use chroma_types::operator::{GroupBy, Aggregate, Key};
2283///
2284/// // Top 3 documents per category
2285/// let group_by = GroupBy {
2286///     keys: vec![Key::field("category")],
2287///     aggregate: Some(Aggregate::MinK {
2288///         keys: vec![Key::Score],
2289///         k: 3,
2290///     }),
2291/// };
2292///
2293/// // Top 2 per (category, author) combination
2294/// let group_by = GroupBy {
2295///     keys: vec![Key::field("category"), Key::field("author")],
2296///     aggregate: Some(Aggregate::MinK {
2297///         keys: vec![Key::Score, Key::field("date")],
2298///         k: 2,
2299///     }),
2300/// };
2301/// ```
2302#[derive(Clone, Debug, Default, Deserialize, Serialize)]
2303pub struct GroupBy {
2304    /// Metadata keys to group by
2305    #[serde(default)]
2306    pub keys: Vec<Key>,
2307    /// Aggregation to apply within each group (required when keys is non-empty)
2308    #[serde(default)]
2309    pub aggregate: Option<Aggregate>,
2310}
2311
2312impl TryFrom<chroma_proto::Aggregate> for Aggregate {
2313    type Error = QueryConversionError;
2314
2315    fn try_from(value: chroma_proto::Aggregate) -> Result<Self, Self::Error> {
2316        match value
2317            .aggregate
2318            .ok_or(QueryConversionError::field("aggregate"))?
2319        {
2320            chroma_proto::aggregate::Aggregate::MinK(min_k) => {
2321                let keys = min_k.keys.into_iter().map(Key::from).collect();
2322                Ok(Aggregate::MinK { keys, k: min_k.k })
2323            }
2324            chroma_proto::aggregate::Aggregate::MaxK(max_k) => {
2325                let keys = max_k.keys.into_iter().map(Key::from).collect();
2326                Ok(Aggregate::MaxK { keys, k: max_k.k })
2327            }
2328        }
2329    }
2330}
2331
2332impl From<Aggregate> for chroma_proto::Aggregate {
2333    fn from(value: Aggregate) -> Self {
2334        let aggregate = match value {
2335            Aggregate::MinK { keys, k } => {
2336                chroma_proto::aggregate::Aggregate::MinK(chroma_proto::aggregate::MinK {
2337                    keys: keys.into_iter().map(|k| k.to_string()).collect(),
2338                    k,
2339                })
2340            }
2341            Aggregate::MaxK { keys, k } => {
2342                chroma_proto::aggregate::Aggregate::MaxK(chroma_proto::aggregate::MaxK {
2343                    keys: keys.into_iter().map(|k| k.to_string()).collect(),
2344                    k,
2345                })
2346            }
2347        };
2348
2349        chroma_proto::Aggregate {
2350            aggregate: Some(aggregate),
2351        }
2352    }
2353}
2354
2355impl TryFrom<chroma_proto::GroupByOperator> for GroupBy {
2356    type Error = QueryConversionError;
2357
2358    fn try_from(value: chroma_proto::GroupByOperator) -> Result<Self, Self::Error> {
2359        let keys = value.keys.into_iter().map(Key::from).collect();
2360        let aggregate = value.aggregate.map(TryInto::try_into).transpose()?;
2361
2362        Ok(Self { keys, aggregate })
2363    }
2364}
2365
2366impl TryFrom<GroupBy> for chroma_proto::GroupByOperator {
2367    type Error = QueryConversionError;
2368
2369    fn try_from(value: GroupBy) -> Result<Self, Self::Error> {
2370        let keys = value.keys.into_iter().map(|k| k.to_string()).collect();
2371        let aggregate = value.aggregate.map(Into::into);
2372
2373        Ok(Self { keys, aggregate })
2374    }
2375}
2376
2377/// A single search result record.
2378///
2379/// Contains the document ID and optionally document content, embeddings, metadata,
2380/// and search score based on what was selected in the search query.
2381///
2382/// # Fields
2383///
2384/// * `id` - Document ID (always present)
2385/// * `document` - Document text content (if selected)
2386/// * `embedding` - Vector embedding (if selected)
2387/// * `metadata` - Document metadata (if selected)
2388/// * `score` - Search score (present when ranking is used, lower = better match)
2389///
2390/// # Examples
2391///
2392/// ```
2393/// use chroma_types::operator::SearchRecord;
2394///
2395/// fn process_results(records: Vec<SearchRecord>) {
2396///     for record in records {
2397///         println!("ID: {}", record.id);
2398///         
2399///         if let Some(score) = record.score {
2400///             println!("  Score: {:.3}", score);
2401///         }
2402///         
2403///         if let Some(doc) = record.document {
2404///             println!("  Document: {}", doc);
2405///         }
2406///         
2407///         if let Some(meta) = record.metadata {
2408///             println!("  Metadata: {:?}", meta);
2409///         }
2410///     }
2411/// }
2412/// ```
2413#[derive(Clone, Debug, Deserialize, Serialize)]
2414#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
2415pub struct SearchRecord {
2416    pub id: String,
2417    pub document: Option<String>,
2418    pub embedding: Option<Vec<f32>>,
2419    pub metadata: Option<Metadata>,
2420    pub score: Option<f32>,
2421}
2422
2423impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2424    type Error = QueryConversionError;
2425
2426    fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2427        Ok(Self {
2428            id: value.id,
2429            document: value.document,
2430            embedding: value
2431                .embedding
2432                .map(|vec| vec.try_into().map(|(v, _)| v))
2433                .transpose()?,
2434            metadata: value.metadata.map(TryInto::try_into).transpose()?,
2435            score: value.score,
2436        })
2437    }
2438}
2439
2440impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2441    type Error = QueryConversionError;
2442
2443    fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2444        Ok(Self {
2445            id: value.id,
2446            document: value.document,
2447            embedding: value
2448                .embedding
2449                .map(|embedding| {
2450                    let embedding_dimension = embedding.len();
2451                    chroma_proto::Vector::try_from((
2452                        embedding,
2453                        ScalarEncoding::FLOAT32,
2454                        embedding_dimension,
2455                    ))
2456                })
2457                .transpose()?,
2458            metadata: value.metadata.map(Into::into),
2459            score: value.score,
2460        })
2461    }
2462}
2463
2464/// Results for a single search payload.
2465///
2466/// Contains all matching records for one search query.
2467///
2468/// # Fields
2469///
2470/// * `records` - Vector of search records, ordered by score (ascending)
2471///
2472/// # Examples
2473///
2474/// ```
2475/// use chroma_types::operator::{SearchPayloadResult, SearchRecord};
2476///
2477/// fn process_search_result(result: SearchPayloadResult) {
2478///     println!("Found {} results", result.records.len());
2479///     
2480///     for (i, record) in result.records.iter().enumerate() {
2481///         println!("{}. {} (score: {:?})", i + 1, record.id, record.score);
2482///     }
2483/// }
2484/// ```
2485#[derive(Clone, Debug, Default)]
2486pub struct SearchPayloadResult {
2487    pub records: Vec<SearchRecord>,
2488}
2489
2490impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2491    type Error = QueryConversionError;
2492
2493    fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2494        Ok(Self {
2495            records: value
2496                .records
2497                .into_iter()
2498                .map(TryInto::try_into)
2499                .collect::<Result<_, _>>()?,
2500        })
2501    }
2502}
2503
2504impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2505    type Error = QueryConversionError;
2506
2507    fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2508        Ok(Self {
2509            records: value
2510                .records
2511                .into_iter()
2512                .map(TryInto::try_into)
2513                .collect::<Result<Vec<_>, _>>()?,
2514        })
2515    }
2516}
2517
2518/// Results from a batch search operation.
2519///
2520/// Contains results for each search payload in the batch, maintaining the same order
2521/// as the input searches.
2522///
2523/// # Fields
2524///
2525/// * `results` - Results for each search payload (indexed by search position)
2526/// * `pulled_log_bytes` - Total bytes pulled from log (for internal metrics)
2527///
2528/// # Examples
2529///
2530/// ## Single search
2531///
2532/// ```
2533/// use chroma_types::operator::SearchResult;
2534///
2535/// fn process_single_search(result: SearchResult) {
2536///     // Single search, so results[0] contains our records
2537///     let records = &result.results[0].records;
2538///     
2539///     for record in records {
2540///         println!("{}: score={:?}", record.id, record.score);
2541///     }
2542/// }
2543/// ```
2544///
2545/// ## Batch search
2546///
2547/// ```
2548/// use chroma_types::operator::SearchResult;
2549///
2550/// fn process_batch_search(result: SearchResult) {
2551///     // Multiple searches in batch
2552///     for (i, search_result) in result.results.iter().enumerate() {
2553///         println!("\nSearch {}:", i + 1);
2554///         for record in &search_result.records {
2555///             println!("  {}: score={:?}", record.id, record.score);
2556///         }
2557///     }
2558/// }
2559/// ```
2560#[derive(Clone, Debug)]
2561pub struct SearchResult {
2562    pub results: Vec<SearchPayloadResult>,
2563    pub pulled_log_bytes: u64,
2564}
2565
2566impl SearchResult {
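    /// Approximate logical size of the result in bytes: sums the id, document,
    /// embedding, metadata, and score sizes of every returned record. This is a
    /// rough logical accounting rather than an exact in-memory measurement.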
2567    pub fn size_bytes(&self) -> u64 {
2568        self.results
2569            .iter()
2570            .flat_map(|result| {
2571                result.records.iter().map(|record| {
2572                    (record.id.len()
2573                        + record
2574                            .document
2575                            .as_ref()
2576                            .map(|doc| doc.len())
2577                            .unwrap_or_default()
2578                        + record
2579                            .embedding
2580                            .as_ref()
2581                            .map(|emb| size_of_val(&emb[..]))
2582                            .unwrap_or_default()
2583                        + record
2584                            .metadata
2585                            .as_ref()
2586                            .map(logical_size_of_metadata)
2587                            .unwrap_or_default()
2588                        + record.score.as_ref().map(size_of_val).unwrap_or_default())
2589                        as u64
2590                })
2591            })
2592            .sum()
2593    }
2594}
2595
2596impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2597    type Error = QueryConversionError;
2598
2599    fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2600        Ok(Self {
2601            results: value
2602                .results
2603                .into_iter()
2604                .map(TryInto::try_into)
2605                .collect::<Result<_, _>>()?,
2606            pulled_log_bytes: value.pulled_log_bytes,
2607        })
2608    }
2609}
2610
2611impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2612    type Error = QueryConversionError;
2613
2614    fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2615        Ok(Self {
2616            results: value
2617                .results
2618                .into_iter()
2619                .map(TryInto::try_into)
2620                .collect::<Result<Vec<_>, _>>()?,
2621            pulled_log_bytes: value.pulled_log_bytes,
2622        })
2623    }
2624}
2625
2626/// Reciprocal Rank Fusion (RRF) - combines multiple ranking strategies.
2627///
2628/// RRF is ideal for hybrid search where you want to merge results from different
2629/// ranking methods (e.g., dense and sparse embeddings) with different score scales.
2630/// It uses rank positions instead of raw scores, making it scale-agnostic.
2631///
2632/// # Formula
2633///
2634/// ```text
2635/// score = -Σ(weight_i / (k + rank_i))
2636/// ```
2637///
2638/// Where:
2639/// - `weight_i` = weight for ranking i (default: 1.0)
2640/// - `rank_i` = rank position from ranking i (0, 1, 2...)
2641/// - `k` = smoothing parameter (default: 60)
2642///
2643/// Score is negative because Chroma uses ascending order (lower = better).
2644///
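/// As a small worked illustration (hypothetical ranks, unit weights, default
/// `k = 60`): a record placed at rank 0 by one ranking and rank 2 by another
/// receives
///
/// ```text
/// score = -(1.0 / (60 + 0) + 1.0 / (60 + 2)) ≈ -0.0328
/// ```
///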
2645/// # Arguments
2646///
2647/// * `ranks` - List of ranking expressions (must have `return_rank=true`)
2648/// * `k` - Smoothing parameter (None = 60). Higher values reduce emphasis on top ranks.
2649/// * `weights` - Weight for each ranking (None = all 1.0)
2650/// * `normalize` - If true, normalize weights to sum to 1.0
2651///
2652/// # Returns
2653///
2654/// A combined RankExpr or an error if:
2655/// - `ranks` is empty
2656/// - `weights` length doesn't match `ranks` length
2657/// - `weights` sum to zero when normalizing
2658///
2659/// # Examples
2660///
2661/// ## Basic RRF with default parameters
2662///
2663/// ```
2664/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2665///
2666/// let dense = RankExpr::Knn {
2667///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2668///     key: Key::Embedding,
2669///     limit: 200,
2670///     default: None,
2671///     return_rank: true, // Required for RRF
2672/// };
2673///
2674/// let sparse = RankExpr::Knn {
2675///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]), // dense query shown for brevity
2676///     key: Key::field("sparse_embedding"),
2677///     limit: 200,
2678///     default: None,
2679///     return_rank: true, // Required for RRF
2680/// };
2681///
2682/// // Equal weights, k=60 (defaults)
2683/// let combined = rrf(vec![dense, sparse], None, None, false).unwrap();
2684/// ```
2685///
2686/// ## RRF with custom weights
2687///
2688/// ```
2689/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2690///
2691/// # let dense = RankExpr::Knn {
2692/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2693/// #     key: Key::Embedding,
2694/// #     limit: 200,
2695/// #     default: None,
2696/// #     return_rank: true,
2697/// # };
2698/// # let sparse = RankExpr::Knn {
2699/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2700/// #     key: Key::field("sparse_embedding"),
2701/// #     limit: 200,
2702/// #     default: None,
2703/// #     return_rank: true,
2704/// # };
2705/// // 70% dense, 30% sparse
2706/// let combined = rrf(
2707///     vec![dense, sparse],
2708///     Some(60),
2709///     Some(vec![0.7, 0.3]),
2710///     false,
2711/// ).unwrap();
2712/// ```
2713///
2714/// ## RRF with normalized weights
2715///
2716/// ```
2717/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2718///
2719/// # let dense = RankExpr::Knn {
2720/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2721/// #     key: Key::Embedding,
2722/// #     limit: 200,
2723/// #     default: None,
2724/// #     return_rank: true,
2725/// # };
2726/// # let sparse = RankExpr::Knn {
2727/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2728/// #     key: Key::field("sparse_embedding"),
2729/// #     limit: 200,
2730/// #     default: None,
2731/// #     return_rank: true,
2732/// # };
2733/// // Weights [75, 25] normalized to [0.75, 0.25]
2734/// let combined = rrf(
2735///     vec![dense, sparse],
2736///     Some(60),
2737///     Some(vec![75.0, 25.0]),
2738///     true, // normalize
2739/// ).unwrap();
2740/// ```
2741///
2742/// ## Adjusting the k parameter
2743///
2744/// ```
2745/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2746///
2747/// # let dense = RankExpr::Knn {
2748/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2749/// #     key: Key::Embedding,
2750/// #     limit: 200,
2751/// #     default: None,
2752/// #     return_rank: true,
2753/// # };
2754/// # let sparse = RankExpr::Knn {
2755/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2756/// #     key: Key::field("sparse_embedding"),
2757/// #     limit: 200,
2758/// #     default: None,
2759/// #     return_rank: true,
2760/// # };
2761/// // Small k (10) = heavy emphasis on top ranks
2762/// let top_heavy = rrf(vec![dense.clone(), sparse.clone()], Some(10), None, false).unwrap();
2763///
2764/// // Default k (60) = balanced
2765/// let balanced = rrf(vec![dense.clone(), sparse.clone()], Some(60), None, false).unwrap();
2766///
2767/// // Large k (200) = more uniform weighting
2768/// let uniform = rrf(vec![dense, sparse], Some(200), None, false).unwrap();
2769/// ```
2770pub fn rrf(
2771    ranks: Vec<RankExpr>,
2772    k: Option<u32>,
2773    weights: Option<Vec<f32>>,
2774    normalize: bool,
2775) -> Result<RankExpr, QueryConversionError> {
2776    let k = k.unwrap_or(60);
2777
2778    if ranks.is_empty() {
2779        return Err(QueryConversionError::validation(
2780            "RRF requires at least one rank expression",
2781        ));
2782    }
2783
2784    let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2785
2786    if weights.len() != ranks.len() {
2787        return Err(QueryConversionError::validation(format!(
2788            "RRF weights length ({}) must match ranks length ({})",
2789            weights.len(),
2790            ranks.len()
2791        )));
2792    }
2793
2794    let weights = if normalize {
2795        let sum: f32 = weights.iter().sum();
2796        if sum == 0.0 {
2797            return Err(QueryConversionError::validation(
2798                "RRF weights sum to zero, cannot normalize",
2799            ));
2800        }
2801        weights.into_iter().map(|w| w / sum).collect()
2802    } else {
2803        weights
2804    };
2805
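    // Build one `weight_i / (k + rank_i)` term per ranking, mirroring the RRF
    // formula in the doc comment above.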
2806    let terms: Vec<RankExpr> = weights
2807        .into_iter()
2808        .zip(ranks)
2809        .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2810        .collect();
2811
2812    // Safe: ranks is validated as non-empty above, so terms cannot be empty.
2813    // Using unwrap_or as a defensive fallback to avoid any possibility of panic.
2814    let sum = terms
2815        .into_iter()
2816        .reduce(|a, b| a + b)
2817        .unwrap_or(RankExpr::Value(0.0));
2818    Ok(-sum)
2819}
2820
2821#[cfg(test)]
2822mod tests {
2823    use super::*;
2824
2825    #[test]
2826    fn test_key_from_string() {
2827        // Test predefined keys
2828        assert_eq!(Key::from("#document"), Key::Document);
2829        assert_eq!(Key::from("#embedding"), Key::Embedding);
2830        assert_eq!(Key::from("#metadata"), Key::Metadata);
2831        assert_eq!(Key::from("#score"), Key::Score);
2832
2833        // Test metadata field keys
2834        assert_eq!(
2835            Key::from("custom_field"),
2836            Key::MetadataField("custom_field".to_string())
2837        );
2838        assert_eq!(
2839            Key::from("author"),
2840            Key::MetadataField("author".to_string())
2841        );
2842
2843        // Test String variant
2844        assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
2845        assert_eq!(
2846            Key::from("year".to_string()),
2847            Key::MetadataField("year".to_string())
2848        );
2849    }
2850
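    // A minimal sketch exercising only the documented validation rules of `rrf`;
    // the rank expressions used here are arbitrary constant values.
    #[test]
    fn test_rrf_validation_rules() {
        // Empty rank list is rejected.
        assert!(rrf(vec![], None, None, false).is_err());

        // Weights length must match ranks length.
        let mismatched = rrf(
            vec![RankExpr::Value(1.0)],
            None,
            Some(vec![0.5, 0.5]),
            false,
        );
        assert!(mismatched.is_err());

        // Weights summing to zero cannot be normalized.
        assert!(rrf(vec![RankExpr::Value(1.0)], None, Some(vec![0.0]), true).is_err());

        // A single rank with default parameters is accepted.
        assert!(rrf(vec![RankExpr::Value(1.0)], None, None, false).is_ok());
    }
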
2851    #[test]
2852    fn test_query_vector_dense_proto_conversion() {
2853        let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2854        let query_vector = QueryVector::Dense(dense_vec.clone());
2855
2856        // Convert to proto
2857        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2858
2859        // Convert back
2860        let converted: QueryVector = proto.try_into().unwrap();
2861
2862        assert_eq!(converted, query_vector);
2863        if let QueryVector::Dense(v) = converted {
2864            assert_eq!(v, dense_vec);
2865        } else {
2866            panic!("Expected dense vector");
2867        }
2868    }
2869
2870    #[test]
2871    fn test_query_vector_sparse_proto_conversion() {
2872        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2873        let query_vector = QueryVector::Sparse(sparse.clone());
2874
2875        // Convert to proto
2876        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2877
2878        // Convert back
2879        let converted: QueryVector = proto.try_into().unwrap();
2880
2881        assert_eq!(converted, query_vector);
2882        if let QueryVector::Sparse(s) = converted {
2883            assert_eq!(s, sparse);
2884        } else {
2885            panic!("Expected sparse vector");
2886        }
2887    }
2888
2889    #[test]
2890    fn test_filter_json_deserialization() {
2891        // For the new search API, deserialization treats the entire JSON as a where clause
2892
2893        // Test 1: Simple direct metadata comparison
2894        let simple_where = r#"{"author": "John Doe"}"#;
2895        let filter: Filter = serde_json::from_str(simple_where).unwrap();
2896        assert_eq!(filter.query_ids, None);
2897        assert!(filter.where_clause.is_some());
2898
2899        // Test 2: ID filter using #id with $in operator
2900        let id_filter_json = serde_json::json!({
2901            "#id": {
2902                "$in": ["doc1", "doc2", "doc3"]
2903            }
2904        });
2905        let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
2906        assert_eq!(filter.query_ids, None);
2907        assert!(filter.where_clause.is_some());
2908
2909        // Test 3: Complex nested expression with AND, OR, and various operators
2910        let complex_json = serde_json::json!({
2911            "$and": [
2912                {
2913                    "#id": {
2914                        "$in": ["doc1", "doc2", "doc3"]
2915                    }
2916                },
2917                {
2918                    "$or": [
2919                        {
2920                            "author": {
2921                                "$eq": "John Doe"
2922                            }
2923                        },
2924                        {
2925                            "author": {
2926                                "$eq": "Jane Smith"
2927                            }
2928                        }
2929                    ]
2930                },
2931                {
2932                    "year": {
2933                        "$gte": 2020
2934                    }
2935                },
2936                {
2937                    "tags": {
2938                        "$contains": "machine-learning"
2939                    }
2940                }
2941            ]
2942        });
2943
2944        let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
2945        assert_eq!(filter.query_ids, None);
2946        assert!(filter.where_clause.is_some());
2947
2948        // Verify the structure
2949        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2950            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2951            assert_eq!(composite.children.len(), 4);
2952
2953            // Check that the second child is an OR
2954            if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
2955                assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
2956                assert_eq!(or_composite.children.len(), 2);
2957            } else {
2958                panic!("Expected OR composite in second child");
2959            }
2960        } else {
2961            panic!("Expected AND composite where clause");
2962        }
2963
2964        // Test 4: Mixed operators - $ne, $lt, $gt, $lte
2965        let mixed_operators_json = serde_json::json!({
2966            "$and": [
2967                {
2968                    "status": {
2969                        "$ne": "deleted"
2970                    }
2971                },
2972                {
2973                    "score": {
2974                        "$gt": 0.5
2975                    }
2976                },
2977                {
2978                    "score": {
2979                        "$lt": 0.9
2980                    }
2981                },
2982                {
2983                    "priority": {
2984                        "$lte": 10
2985                    }
2986                }
2987            ]
2988        });
2989
2990        let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
2991        assert_eq!(filter.query_ids, None);
2992        assert!(filter.where_clause.is_some());
2993
2994        // Test 5: Deeply nested expression
2995        let deeply_nested_json = serde_json::json!({
2996            "$or": [
2997                {
2998                    "$and": [
2999                        {
3000                            "#id": {
3001                                "$in": ["id1", "id2"]
3002                            }
3003                        },
3004                        {
3005                            "$or": [
3006                                {
3007                                    "category": "tech"
3008                                },
3009                                {
3010                                    "category": "science"
3011                                }
3012                            ]
3013                        }
3014                    ]
3015                },
3016                {
3017                    "$and": [
3018                        {
3019                            "author": "Admin"
3020                        },
3021                        {
3022                            "published": true
3023                        }
3024                    ]
3025                }
3026            ]
3027        });
3028
3029        let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
3030        assert_eq!(filter.query_ids, None);
3031        assert!(filter.where_clause.is_some());
3032
3033        // Verify it's an OR at the top level
3034        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3035            assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
3036            assert_eq!(composite.children.len(), 2);
3037
3038            // Both children should be AND composites
3039            for child in &composite.children {
3040                if let crate::metadata::Where::Composite(and_composite) = child {
3041                    assert_eq!(
3042                        and_composite.operator,
3043                        crate::metadata::BooleanOperator::And
3044                    );
3045                } else {
3046                    panic!("Expected AND composite in OR children");
3047                }
3048            }
3049        } else {
3050            panic!("Expected OR composite at top level");
3051        }
3052
3053        // Test 6: Single ID filter (edge case)
3054        let single_id_json = serde_json::json!({
3055            "#id": {
3056                "$eq": "single-doc-id"
3057            }
3058        });
3059
3060        let filter: Filter = serde_json::from_value(single_id_json).unwrap();
3061        assert_eq!(filter.query_ids, None);
3062        assert!(filter.where_clause.is_some());
3063
3064        // Test 7: Empty object should create empty filter
3065        let empty_json = serde_json::json!({});
3066        let filter: Filter = serde_json::from_value(empty_json).unwrap();
3067        assert_eq!(filter.query_ids, None);
3068        // Empty object results in None where_clause
3069        assert_eq!(filter.where_clause, None);
3070
3071        // Test 8: Combining #id filter with $not_contains and numeric comparisons
3072        let advanced_json = serde_json::json!({
3073            "$and": [
3074                {
3075                    "#id": {
3076                        "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
3077                    }
3078                },
3079                {
3080                    "tags": {
3081                        "$not_contains": "deprecated"
3082                    }
3083                },
3084                {
3085                    "$or": [
3086                        {
3087                            "$and": [
3088                                {
3089                                    "confidence": {
3090                                        "$gte": 0.8
3091                                    }
3092                                },
3093                                {
3094                                    "verified": true
3095                                }
3096                            ]
3097                        },
3098                        {
3099                            "$and": [
3100                                {
3101                                    "confidence": {
3102                                        "$gte": 0.6
3103                                    }
3104                                },
3105                                {
3106                                    "confidence": {
3107                                        "$lt": 0.8
3108                                    }
3109                                },
3110                                {
3111                                    "reviews": {
3112                                        "$gte": 5
3113                                    }
3114                                }
3115                            ]
3116                        }
3117                    ]
3118                }
3119            ]
3120        });
3121
3122        let filter: Filter = serde_json::from_value(advanced_json).unwrap();
3123        assert_eq!(filter.query_ids, None);
3124        assert!(filter.where_clause.is_some());
3125
3126        // Verify top-level structure
3127        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3128            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
3129            assert_eq!(composite.children.len(), 3);
3130        } else {
3131            panic!("Expected AND composite at top level");
3132        }
3133    }
3134
3135    #[test]
3136    fn test_limit_json_serialization() {
3137        let limit = Limit {
3138            offset: 10,
3139            limit: Some(20),
3140        };
3141
3142        let json = serde_json::to_string(&limit).unwrap();
3143        let deserialized: Limit = serde_json::from_str(&json).unwrap();
3144
3145        assert_eq!(deserialized.offset, limit.offset);
3146        assert_eq!(deserialized.limit, limit.limit);
3147    }
3148
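    // A hedged round-trip sketch for the SearchRecord <-> proto conversions above;
    // metadata is left out to keep the example focused on the embedding path.
    #[test]
    fn test_search_record_proto_round_trip() {
        let embedding: Vec<f32> = vec![0.1, 0.2, 0.3];
        let record = SearchRecord {
            id: "doc1".to_string(),
            document: Some("hello world".to_string()),
            embedding: Some(embedding.clone()),
            metadata: None,
            score: Some(0.5),
        };

        let proto: chroma_proto::SearchRecord = record.try_into().unwrap();
        let round_trip: SearchRecord = proto.try_into().unwrap();

        assert_eq!(round_trip.id, "doc1");
        assert_eq!(round_trip.document, Some("hello world".to_string()));
        assert_eq!(round_trip.embedding, Some(embedding));
        assert!(round_trip.metadata.is_none());
        assert_eq!(round_trip.score, Some(0.5));
    }
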
3149    #[test]
3150    fn test_query_vector_json_serialization() {
3151        // Test dense vector
3152        let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
3153        let json = serde_json::to_string(&dense).unwrap();
3154        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3155        assert_eq!(deserialized, dense);
3156
3157        // Test sparse vector
3158        let sparse =
3159            QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap());
3160        let json = serde_json::to_string(&sparse).unwrap();
3161        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3162        assert_eq!(deserialized, sparse);
3163    }
3164
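    // A hedged round-trip sketch for the GroupBy/Aggregate <-> proto conversions;
    // the grouping keys and k value are arbitrary.
    #[test]
    fn test_group_by_proto_round_trip() {
        let group_by = GroupBy {
            keys: vec![Key::field("category")],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::Score, Key::field("date")],
                k: 3,
            }),
        };

        let proto: chroma_proto::GroupByOperator = group_by.try_into().unwrap();
        assert_eq!(proto.keys, vec!["category".to_string()]);

        let round_trip: GroupBy = proto.try_into().unwrap();
        assert_eq!(round_trip.keys, vec![Key::field("category")]);
        match round_trip.aggregate {
            Some(Aggregate::MinK { keys, k }) => {
                assert_eq!(keys, vec![Key::Score, Key::field("date")]);
                assert_eq!(k, 3);
            }
            other => panic!("expected MinK aggregate, got {:?}", other),
        }
    }
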
3165    #[test]
3166    fn test_select_key_json_serialization() {
3167        use std::collections::HashSet;
3168
3169        // Test predefined keys
3170        let doc_key = Key::Document;
3171        assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
3172
3173        let embed_key = Key::Embedding;
3174        assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
3175
3176        let meta_key = Key::Metadata;
3177        assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
3178
3179        let score_key = Key::Score;
3180        assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
3181
3182        // Test metadata key
3183        let custom_key = Key::MetadataField("custom_key".to_string());
3184        assert_eq!(
3185            serde_json::to_string(&custom_key).unwrap(),
3186            "\"custom_key\""
3187        );
3188
3189        // Test deserialization
3190        let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
3191        assert!(matches!(deserialized, Key::Document));
3192
3193        let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
3194        assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
3195
3196        // Test Select struct with multiple keys
3197        let mut keys = HashSet::new();
3198        keys.insert(Key::Document);
3199        keys.insert(Key::Embedding);
3200        keys.insert(Key::MetadataField("author".to_string()));
3201
3202        let select = Select { keys };
3203        let json = serde_json::to_string(&select).unwrap();
3204        let deserialized: Select = serde_json::from_str(&json).unwrap();
3205
3206        assert_eq!(deserialized.keys.len(), 3);
3207        assert!(deserialized.keys.contains(&Key::Document));
3208        assert!(deserialized.keys.contains(&Key::Embedding));
3209        assert!(deserialized
3210            .keys
3211            .contains(&Key::MetadataField("author".to_string())));
3212    }
3213
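    // A hedged round-trip sketch for the Select <-> proto conversions above; the
    // proto operator carries only the string form of each selected key.
    #[test]
    fn test_select_proto_round_trip() {
        let select = Select {
            keys: [Key::Document, Key::Score, Key::field("author")]
                .into_iter()
                .collect(),
        };

        let proto: chroma_proto::SelectOperator = select.try_into().unwrap();
        assert_eq!(proto.keys.len(), 3);
        assert!(proto.keys.contains(&"#document".to_string()));
        assert!(proto.keys.contains(&"#score".to_string()));
        assert!(proto.keys.contains(&"author".to_string()));

        let round_trip: Select = proto.try_into().unwrap();
        assert_eq!(round_trip.keys.len(), 3);
        assert!(round_trip.keys.contains(&Key::Document));
        assert!(round_trip.keys.contains(&Key::Score));
        assert!(round_trip.keys.contains(&Key::field("author")));
    }
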
3214    #[test]
3215    fn test_merge_basic_integers() {
3216        use std::cmp::Reverse;
3217
3218        let merge = Merge { k: 5 };
3219
3220        // Input: sorted vectors of Reverse(u32) - ascending order of inner values
3221        let input = vec![
3222            vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
3223            vec![Reverse(2), Reverse(5), Reverse(8)],
3224            vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
3225        ];
3226
3227        let result = merge.merge(input);
3228
3229        // Should get top-5 smallest values (largest Reverse values)
3230        assert_eq!(result.len(), 5);
3231        assert_eq!(
3232            result,
3233            vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
3234        );
3235    }
3236
3237    #[test]
3238    fn test_merge_u32_descending() {
3239        let merge = Merge { k: 6 };
3240
3241        // Regular u32 in descending order (largest first)
3242        let input = vec![
3243            vec![100u32, 75, 50, 25],
3244            vec![90, 60, 30],
3245            vec![95, 85, 70, 40, 10],
3246        ];
3247
3248        let result = merge.merge(input);
3249
3250        // Should get top-6 largest u32 values
3251        assert_eq!(result.len(), 6);
3252        assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
3253    }
3254
3255    #[test]
3256    fn test_merge_i32_descending() {
3257        let merge = Merge { k: 5 };
3258
3259        // i32 values in descending order (including negatives)
3260        let input = vec![
3261            vec![50i32, 10, -10, -50],
3262            vec![30, 0, -30],
3263            vec![40, 20, -20, -40],
3264        ];
3265
3266        let result = merge.merge(input);
3267
3268        // Should get top-5 largest i32 values
3269        assert_eq!(result.len(), 5);
3270        assert_eq!(result, vec![50, 40, 30, 20, 10]);
3271    }
3272
3273    #[test]
3274    fn test_merge_with_duplicates() {
3275        let merge = Merge { k: 10 };
3276
3277        // Input with duplicates using regular u32 in descending order
3278        let input = vec![
3279            vec![100u32, 80, 80, 60, 40],
3280            vec![90, 80, 50, 30],
3281            vec![100, 70, 60, 20],
3282        ];
3283
3284        let result = merge.merge(input);
3285
3286        // Duplicates should be removed
3287        assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
3288    }
3289
3290    #[test]
3291    fn test_merge_empty_vectors() {
3292        let merge = Merge { k: 5 };
3293
3294        // All empty with u32
3295        let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
3296        let result = merge.merge(input);
3297        assert_eq!(result.len(), 0);
3298
3299        // Some empty, some with data (u64)
3300        let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
3301        let result = merge.merge(input);
3302        assert_eq!(result, vec![1000, 850, 750, 600, 500]);
3303
3304        // Single non-empty vector (i32)
3305        let input = vec![vec![], vec![100i32, 50, 25], vec![]];
3306        let result = merge.merge(input);
3307        assert_eq!(result, vec![100, 50, 25]);
3308    }
3309
3310    #[test]
3311    fn test_merge_k_boundary_conditions() {
3312        // k = 0 with u32
3313        let merge = Merge { k: 0 };
3314        let input = vec![vec![100u32, 50], vec![75, 25]];
3315        let result = merge.merge(input);
3316        assert_eq!(result.len(), 0);
3317
3318        // k = 1 with i64
3319        let merge = Merge { k: 1 };
3320        let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
3321        let result = merge.merge(input);
3322        assert_eq!(result, vec![1000]);
3323
3324        // k larger than total unique elements with u128
3325        let merge = Merge { k: 100 };
3326        let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
3327        let result = merge.merge(input);
3328        assert_eq!(result, vec![10000, 8000, 5000, 3000]);
3329    }
3330
3331    #[test]
3332    fn test_merge_with_strings() {
3333        let merge = Merge { k: 4 };
3334
3335        // Strings must be sorted in descending order (largest first) for the max heap merge
3336        let input = vec![
3337            vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
3338            vec!["elephant".to_string(), "banana".to_string()],
3339            vec!["fish".to_string(), "cat".to_string()],
3340        ];
3341
3342        let result = merge.merge(input);
3343
3344        // Should get top-4 lexicographically largest strings
3345        assert_eq!(
3346            result,
3347            vec![
3348                "zebra".to_string(),
3349                "fish".to_string(),
3350                "elephant".to_string(),
3351                "dog".to_string()
3352            ]
3353        );
3354    }
3355
3356    #[test]
3357    fn test_merge_with_custom_struct() {
3358        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
3359        struct Score {
3360            value: i32,
3361            id: String,
3362        }

        let merge = Merge { k: 3 };

        // Custom structs sorted by value (descending), then by id
        let input = vec![
            vec![
                Score {
                    value: 100,
                    id: "a".to_string(),
                },
                Score {
                    value: 80,
                    id: "b".to_string(),
                },
                Score {
                    value: 60,
                    id: "c".to_string(),
                },
            ],
            vec![
                Score {
                    value: 90,
                    id: "d".to_string(),
                },
                Score {
                    value: 70,
                    id: "e".to_string(),
                },
            ],
            vec![
                Score {
                    value: 95,
                    id: "f".to_string(),
                },
                Score {
                    value: 85,
                    id: "g".to_string(),
                },
            ],
        ];

        let result = merge.merge(input);

        assert_eq!(result.len(), 3);
        assert_eq!(
            result[0],
            Score {
                value: 100,
                id: "a".to_string()
            }
        );
        assert_eq!(
            result[1],
            Score {
                value: 95,
                id: "f".to_string()
            }
        );
        assert_eq!(
            result[2],
            Score {
                value: 90,
                id: "d".to_string()
            }
        );
    }

    #[test]
    fn test_merge_preserves_order() {
        use std::cmp::Reverse;

        let merge = Merge { k: 10 };

        // For Reverse, smaller inner values are "larger" in ordering
        // So vectors should be sorted with smallest inner values first
        let input = vec![
            vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
            vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
            vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
        ];

        let result = merge.merge(input);

        // Verify output maintains order - should be sorted by Reverse ordering
        // which means ascending inner values
        for i in 1..result.len() {
            assert!(
                result[i - 1] >= result[i],
                "Output should be in descending Reverse order"
            );
            assert!(
                result[i - 1].0 <= result[i].0,
                "Inner values should be in ascending order"
            );
        }

        // Check we got the right elements
        assert_eq!(
            result,
            vec![
                Reverse(1),
                Reverse(2),
                Reverse(3),
                Reverse(4),
                Reverse(5),
                Reverse(6),
                Reverse(7),
                Reverse(8),
                Reverse(9),
                Reverse(10)
            ]
        );
    }

    #[test]
    fn test_merge_single_vector() {
        let merge = Merge { k: 3 };

        // Single vector input with u64
        let input = vec![vec![1000u64, 800, 600, 400, 200]];

        let result = merge.merge(input);

        assert_eq!(result, vec![1000, 800, 600]);
    }

    #[test]
    fn test_merge_all_same_values() {
        let merge = Merge { k: 5 };

        // All vectors contain the same value (using i16)
        let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];

        let result = merge.merge(input);

        // Should deduplicate to single value
        assert_eq!(result, vec![42]);
    }

    #[test]
    fn test_merge_mixed_types_sizes() {
        // Test with usize (common in real usage)
        let merge = Merge { k: 4 };
        let input = vec![
            vec![1000usize, 500, 100],
            vec![800, 300],
            vec![900, 600, 200],
        ];
        let result = merge.merge(input);
        assert_eq!(result, vec![1000, 900, 800, 600]);

        // Test with negative integers (i32)
        let merge = Merge { k: 5 };
        let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
        let result = merge.merge(input);
        assert_eq!(result, vec![15, 10, 5, 0, -5]);
    }

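    // Hedged sketch: this cross-checks `Merge` against a naive max-heap
    // reference written inline below. It assumes only what the tests above
    // already exercise: inputs are individually sorted in descending order,
    // and the output is the deduplicated top-k in descending order. The
    // reference helper is illustrative, not the operator's implementation.
    #[test]
    fn test_merge_matches_max_heap_reference_sketch() {
        use std::collections::BinaryHeap;

        // Naive reference: pool everything into a max-heap, pop in descending
        // order, skip duplicates, and stop after k elements.
        fn reference_merge<T: Ord>(input: Vec<Vec<T>>, k: usize) -> Vec<T> {
            let mut heap: BinaryHeap<T> = input.into_iter().flatten().collect();
            let mut out = Vec::with_capacity(k);
            while out.len() < k {
                match heap.pop() {
                    Some(next) if out.last() != Some(&next) => out.push(next),
                    Some(_) => continue, // duplicate of the previous element
                    None => break,       // heap exhausted before reaching k
                }
            }
            out
        }

        let merge = Merge { k: 4 };
        let input = vec![vec![9u32, 7, 5], vec![8, 7, 3], vec![9, 6, 2]];

        assert_eq!(
            merge.merge(input.clone()),
            reference_merge(input, 4),
            "Merge should agree with the naive max-heap reference"
        );
    }
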
    #[test]
    fn test_aggregate_json_serialization() {
        // Test MinK serialization
        let min_k = Aggregate::MinK {
            keys: vec![Key::Score, Key::field("date")],
            k: 3,
        };
        let json = serde_json::to_value(&min_k).unwrap();
        assert!(json.get("$min_k").is_some());
        assert_eq!(json["$min_k"]["k"], 3);

        // Test MinK deserialization
        let min_k_json = serde_json::json!({
            "$min_k": {
                "keys": ["#score", "date"],
                "k": 5
            }
        });
        let deserialized: Aggregate = serde_json::from_value(min_k_json).unwrap();
        match deserialized {
            Aggregate::MinK { keys, k } => {
                assert_eq!(k, 5);
                assert_eq!(keys.len(), 2);
                assert_eq!(keys[0], Key::Score);
                assert_eq!(keys[1], Key::field("date"));
            }
            _ => panic!("Expected MinK"),
        }

        // Test MaxK serialization
        let max_k = Aggregate::MaxK {
            keys: vec![Key::field("timestamp")],
            k: 10,
        };
        let json = serde_json::to_value(&max_k).unwrap();
        assert!(json.get("$max_k").is_some());
        assert_eq!(json["$max_k"]["k"], 10);

        // Test MaxK deserialization
        let max_k_json = serde_json::json!({
            "$max_k": {
                "keys": ["timestamp"],
                "k": 2
            }
        });
        let deserialized: Aggregate = serde_json::from_value(max_k_json).unwrap();
        match deserialized {
            Aggregate::MaxK { keys, k } => {
                assert_eq!(k, 2);
                assert_eq!(keys.len(), 1);
                assert_eq!(keys[0], Key::field("timestamp"));
            }
            _ => panic!("Expected MaxK"),
        }
    }

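    // Hedged sketch: the local enum below is *not* this crate's `Aggregate`
    // type. It only illustrates, assuming serde's default externally tagged
    // representation with renamed variants, how the `{"$min_k": {...}}` /
    // `{"$max_k": {...}}` shapes asserted above can arise. Keys are
    // simplified to plain strings here.
    #[test]
    fn test_aggregate_wire_shape_sketch() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        enum AggregateSketch {
            #[serde(rename = "$min_k")]
            MinK { keys: Vec<String>, k: u32 },
            #[serde(rename = "$max_k")]
            MaxK { keys: Vec<String>, k: u32 },
        }

        let sketch = AggregateSketch::MinK {
            keys: vec!["#score".to_string(), "date".to_string()],
            k: 3,
        };

        // Externally tagged: the variant name becomes the single top-level key.
        let json = serde_json::to_value(&sketch).unwrap();
        assert_eq!(
            json,
            serde_json::json!({"$min_k": {"keys": ["#score", "date"], "k": 3}})
        );

        // And the same shape deserializes back to the variant.
        let back: AggregateSketch = serde_json::from_value(json).unwrap();
        assert_eq!(back, sketch);
    }
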
    #[test]
    fn test_group_by_json_serialization() {
        // Test GroupBy with MinK
        let group_by = GroupBy {
            keys: vec![Key::field("category"), Key::field("author")],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::Score],
                k: 3,
            }),
        };

        let json = serde_json::to_value(&group_by).unwrap();
        assert_eq!(json["keys"].as_array().unwrap().len(), 2);
        assert!(json["aggregate"]["$min_k"].is_object());

        // Test roundtrip
        let deserialized: GroupBy = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized.keys.len(), 2);
        assert_eq!(deserialized.keys[0], Key::field("category"));
        assert_eq!(deserialized.keys[1], Key::field("author"));
        assert!(deserialized.aggregate.is_some());

        // Test empty GroupBy
        let empty_group_by = GroupBy::default();
        let json = serde_json::to_value(&empty_group_by).unwrap();
        let deserialized: GroupBy = serde_json::from_value(json).unwrap();
        assert!(deserialized.keys.is_empty());
        assert!(deserialized.aggregate.is_none());

        // Test deserialization from JSON
        let json = serde_json::json!({
            "keys": ["category"],
            "aggregate": {
                "$max_k": {
                    "keys": ["#score", "priority"],
                    "k": 5
                }
            }
        });
        let group_by: GroupBy = serde_json::from_value(json).unwrap();
        assert_eq!(group_by.keys.len(), 1);
        assert_eq!(group_by.keys[0], Key::field("category"));
        match group_by.aggregate {
            Some(Aggregate::MaxK { keys, k }) => {
                assert_eq!(k, 5);
                assert_eq!(keys.len(), 2);
                assert_eq!(keys[0], Key::Score);
            }
            _ => panic!("Expected MaxK aggregate"),
        }
    }
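
    // Hedged sketch: a compact end-to-end roundtrip of the `GroupBy` wire
    // format exercised piecewise above, this time going through a JSON string
    // rather than a `serde_json::Value`. The field names and the "$max_k" tag
    // are taken from the assertions in the test above.
    #[test]
    fn test_group_by_string_roundtrip_sketch() {
        let group_by: GroupBy = serde_json::from_str(
            r#"{"keys": ["category"], "aggregate": {"$max_k": {"keys": ["#score"], "k": 2}}}"#,
        )
        .unwrap();
        assert_eq!(group_by.keys, vec![Key::field("category")]);

        // Serializing and re-parsing should preserve the structure.
        let text = serde_json::to_string(&group_by).unwrap();
        let reparsed: GroupBy = serde_json::from_str(&text).unwrap();
        assert_eq!(reparsed.keys, group_by.keys);
        match reparsed.aggregate {
            Some(Aggregate::MaxK { keys, k }) => {
                assert_eq!(k, 2);
                assert_eq!(keys, vec![Key::Score]);
            }
            _ => panic!("Expected MaxK aggregate"),
        }
    }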
}