Skip to main content

chroma_types/execution/
operator.rs

1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4    cmp::Ordering,
5    collections::{BinaryHeap, HashSet},
6    fmt,
7    hash::{Hash, Hasher},
8    ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13    chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14    ContainsOperator, DocumentExpression, DocumentOperator, Metadata, MetadataComparison,
15    MetadataExpression, MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding,
16    SetOperator, SparseVector, Where,
17};
18
19use super::error::QueryConversionError;
20
/// Input type that carries no data (the unit type).
///
/// NOTE(review): presumably used as the input of source operators that do not consume an
/// upstream result — confirm against the orchestrator code.
pub type InitialInput = ();
22
/// The `Scan` operator pins the data used by all downstream operators.
///
/// # Parameters
/// - `collection_and_segments`: A consistent snapshot of the collection together with its
///   metadata, record, and vector segments
#[derive(Clone, Debug)]
pub struct Scan {
    pub collection_and_segments: CollectionAndSegments,
}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33    type Error = QueryConversionError;
34
35    fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36        Ok(Self {
37            collection_and_segments: CollectionAndSegments {
38                collection: value
39                    .collection
40                    .ok_or(QueryConversionError::field("collection"))?
41                    .try_into()?,
42                metadata_segment: value
43                    .metadata
44                    .ok_or(QueryConversionError::field("metadata segment"))?
45                    .try_into()?,
46                record_segment: value
47                    .record
48                    .ok_or(QueryConversionError::field("record segment"))?
49                    .try_into()?,
50                vector_segment: value
51                    .knn
52                    .ok_or(QueryConversionError::field("vector segment"))?
53                    .try_into()?,
54            },
55        })
56    }
57}
58
/// Errors raised while converting a `Scan` into `chroma_proto::ScanOperator`.
#[derive(Debug, Error)]
pub enum ScanToProtoError {
    /// The collection could not be converted; wraps the underlying conversion error.
    #[error("Could not convert collection to proto")]
    CollectionToProto(#[from] crate::CollectionToProtoError),
}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66    type Error = ScanToProtoError;
67
68    fn try_from(value: Scan) -> Result<Self, Self::Error> {
69        Ok(Self {
70            collection: Some(value.collection_and_segments.collection.try_into()?),
71            knn: Some(value.collection_and_segments.vector_segment.into()),
72            metadata: Some(value.collection_and_segments.metadata_segment.into()),
73            record: Some(value.collection_and_segments.record_segment.into()),
74        })
75    }
76}
77
/// Result of a count query.
///
/// - `count`: the number of records counted
/// - `pulled_log_bytes`: NOTE(review): presumably the volume of log data read while
///   answering the query — confirm
#[derive(Clone, Debug)]
pub struct CountResult {
    pub count: u32,
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Logical size of this result in bytes.
    ///
    /// Only the `count` payload is measured; `pulled_log_bytes` is not included.
    pub fn size_bytes(&self) -> u64 {
        let payload_bytes = size_of_val(&self.count);
        payload_bytes as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91    fn from(value: chroma_proto::CountResult) -> Self {
92        Self {
93            count: value.count,
94            pulled_log_bytes: value.pulled_log_bytes,
95        }
96    }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100    fn from(value: CountResult) -> Self {
101        Self {
102            count: value.count,
103            pulled_log_bytes: value.pulled_log_bytes,
104        }
105    }
106}
107
/// The `FetchLog` operator fetches logs from the log service.
///
/// # Parameters
/// - `start_log_offset_id`: The offset id of the first log to read
/// - `maximum_fetch_count`: The maximum number of logs to fetch in total
///   (NOTE(review): `None` presumably means no cap — confirm)
/// - `collection_uuid`: The uuid of the collection the fetched logs belong to
#[derive(Clone, Debug)]
pub struct FetchLog {
    pub collection_uuid: CollectionUuid,
    pub maximum_fetch_count: Option<u32>,
    pub start_log_offset_id: u32,
}
120
/// Filter the search results.
///
/// Combines document ID filtering with metadata and document content predicates.
/// For the Search API, use `where_clause` with Key expressions.
///
/// # Fields
///
/// * `query_ids` - Optional list of document IDs to filter (legacy, prefer Where expressions)
/// * `where_clause` - Predicate on document metadata, content, or IDs
///
/// When serialized, `query_ids` is folded into the where clause as `#id IN [...]`, and it is
/// ANDed with `where_clause` if both are present. Deserialization always yields
/// `query_ids: None`, with any id constraint living inside `where_clause`.
///
/// # Examples
///
/// ## Simple metadata filter
///
/// ```
/// use chroma_types::operator::{Filter, Key};
///
/// let filter = Filter {
///     query_ids: None,
///     where_clause: Some(Key::field("status").eq("published")),
/// };
/// ```
///
/// ## Combined filters
///
/// ```
/// use chroma_types::operator::{Filter, Key};
///
/// let filter = Filter {
///     query_ids: None,
///     where_clause: Some(
///         Key::field("status").eq("published")
///             & Key::field("year").gte(2020)
///             & Key::field("category").is_in(vec!["tech", "science"])
///     ),
/// };
/// ```
///
/// ## Document content filter
///
/// ```
/// use chroma_types::operator::{Filter, Key};
///
/// let filter = Filter {
///     query_ids: None,
///     where_clause: Some(Key::Document.contains("machine learning")),
/// };
/// ```
#[derive(Clone, Debug, Default)]
pub struct Filter {
    /// Optional allow-list of record ids (legacy form of an id filter).
    pub query_ids: Option<Vec<String>>,
    /// Optional predicate over metadata, document content, or ids.
    pub where_clause: Option<Where>,
}
174
175impl Serialize for Filter {
176    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177    where
178        S: Serializer,
179    {
180        // For the search API, serialize directly as the where clause (or empty object if None)
181        // If query_ids are present, they should be combined with the where_clause as Key::ID.is_in([...])
182
183        match (&self.query_ids, &self.where_clause) {
184            (None, None) => {
185                // No filter at all - serialize empty object
186                let map = serializer.serialize_map(Some(0))?;
187                map.end()
188            }
189            (None, Some(where_clause)) => {
190                // Only where clause - serialize it directly
191                where_clause.serialize(serializer)
192            }
193            (Some(ids), None) => {
194                // Only query_ids - create Where clause: Key::ID.is_in(ids)
195                let id_where = Where::Metadata(MetadataExpression {
196                    key: "#id".to_string(),
197                    comparison: MetadataComparison::Set(
198                        SetOperator::In,
199                        MetadataSetValue::Str(ids.clone()),
200                    ),
201                });
202                id_where.serialize(serializer)
203            }
204            (Some(ids), Some(where_clause)) => {
205                // Both present - combine with AND: Key::ID.is_in(ids) & where_clause
206                let id_where = Where::Metadata(MetadataExpression {
207                    key: "#id".to_string(),
208                    comparison: MetadataComparison::Set(
209                        SetOperator::In,
210                        MetadataSetValue::Str(ids.clone()),
211                    ),
212                });
213                let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214                combined.serialize(serializer)
215            }
216        }
217    }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222    where
223        D: Deserializer<'de>,
224    {
225        // For the new search API, the entire JSON is the where clause
226        let where_json = Value::deserialize(deserializer)?;
227        let where_clause =
228            if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229                None
230            } else {
231                Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232            };
233
234        Ok(Filter {
235            query_ids: None, // Always None for new search API
236            where_clause,
237        })
238    }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242    type Error = QueryConversionError;
243
244    fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245        let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246        let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247        let where_clause = match (where_metadata, where_document) {
248            (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249            (Some(w), None) | (None, Some(w)) => Some(w),
250            _ => None,
251        };
252
253        Ok(Self {
254            query_ids: value.ids.map(|uids| uids.ids),
255            where_clause,
256        })
257    }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261    type Error = QueryConversionError;
262
263    fn try_from(value: Filter) -> Result<Self, Self::Error> {
264        Ok(Self {
265            ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266            r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267            where_document: None,
268        })
269    }
270}
271
/// The `Knn` operator searches for the nearest neighbours of the specified embedding.
/// This is intended to be used by the executor.
///
/// # Parameters
/// - `embedding`: The target embedding to search around
/// - `fetch`: The number of records to fetch around the target
#[derive(Clone, Debug)]
pub struct Knn {
    pub embedding: Vec<f32>,
    pub fetch: u32,
}
282
283impl From<KnnBatch> for Vec<Knn> {
284    fn from(value: KnnBatch) -> Self {
285        value
286            .embeddings
287            .into_iter()
288            .map(|embedding| Knn {
289                embedding,
290                fetch: value.fetch,
291            })
292            .collect()
293    }
294}
295
/// The `KnnBatch` operator searches for the nearest neighbours of each of the specified
/// embeddings. This is intended to be used by the frontend; it converts into one `Knn` per
/// embedding, all sharing the same `fetch`.
///
/// # Parameters
/// - `embeddings`: The target embeddings to search around
/// - `fetch`: The number of records to fetch around each target
#[derive(Clone, Debug)]
pub struct KnnBatch {
    pub embeddings: Vec<Vec<f32>>,
    pub fetch: u32,
}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308    type Error = QueryConversionError;
309
310    fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311        Ok(Self {
312            embeddings: value
313                .embeddings
314                .into_iter()
315                .map(|vec| vec.try_into().map(|(v, _)| v))
316                .collect::<Result<_, _>>()?,
317            fetch: value.fetch,
318        })
319    }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323    type Error = QueryConversionError;
324
325    fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326        Ok(Self {
327            embeddings: value
328                .embeddings
329                .into_iter()
330                .map(|embedding| {
331                    let dim = embedding.len();
332                    chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333                })
334                .collect::<Result<_, _>>()?,
335            fetch: value.fetch,
336        })
337    }
338}
339
/// Pagination control for search results.
///
/// Controls how many results to return and how many to skip for pagination.
///
/// # Fields
///
/// * `offset` - Number of results to skip (default: 0)
/// * `limit` - Maximum results to return (None = no limit)
///
/// Both fields may be omitted when deserializing; they then take their defaults
/// (`0` and `None`).
///
/// # Examples
///
/// ```
/// use chroma_types::operator::Limit;
///
/// // First page: results 0-9
/// let limit = Limit {
///     offset: 0,
///     limit: Some(10),
/// };
///
/// // Second page: results 10-19
/// let limit = Limit {
///     offset: 10,
///     limit: Some(10),
/// };
///
/// // No limit: all results
/// let limit = Limit {
///     offset: 0,
///     limit: None,
/// };
/// ```
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Limit {
    #[serde(default)]
    pub offset: u32,
    #[serde(default)]
    pub limit: Option<u32>,
}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381    fn from(value: chroma_proto::LimitOperator) -> Self {
382        Self {
383            offset: value.offset,
384            limit: value.limit,
385        }
386    }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390    fn from(value: Limit) -> Self {
391        Self {
392            offset: value.offset,
393            limit: value.limit,
394        }
395    }
396}
397
/// A `RecordMeasure` pairs a record (identified by `offset_id`) with a scalar `measure` of its
/// embedding relative to a query embedding.
///
/// Equality and hashing look only at `offset_id`, so two measures of the same record are
/// equal regardless of `measure`. Ordering compares `measure` first (using `f32::total_cmp`)
/// and then `offset_id`.
///
/// NOTE(review): `Ord` is therefore not consistent with `Eq` when one `offset_id` carries
/// different measures. This is presumably deliberate, so that hash-based de-duplication
/// (as in `Merge::merge`) works per record — confirm before changing either impl.
#[derive(Clone, Copy, Debug)]
pub struct RecordMeasure {
    pub offset_id: u32,
    pub measure: f32,
}

impl PartialEq for RecordMeasure {
    fn eq(&self, other: &Self) -> bool {
        // Identity is the record, not its score.
        self.offset_id == other.offset_id
    }
}

impl Eq for RecordMeasure {}

impl Hash for RecordMeasure {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Must agree with `PartialEq`: hash the record id only.
        Hash::hash(&self.offset_id, state);
    }
}

impl Ord for RecordMeasure {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.measure.total_cmp(&other.measure) {
            Ordering::Equal => self.offset_id.cmp(&other.offset_id),
            decided => decided,
        }
    }
}

impl PartialOrd for RecordMeasure {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
432
/// Output of a KNN search: one `RecordMeasure` per matched record.
///
/// NOTE(review): presumably ordered by distance — confirm against the producing operator.
#[derive(Debug, Default)]
pub struct KnnOutput {
    pub distances: Vec<RecordMeasure>,
}
437
/// The `Merge` operator selects the top records from a batch of record vectors, each of
/// which is expected to be sorted in descending order.
///
/// If the same record occurs multiple times (according to its `Eq`/`Hash` implementation),
/// only the first copy popped remains in the final result. That copy is the largest by
/// `Ord`.
///
/// # Parameters
/// - `k`: The total number of records to take after merge
///
/// # Usage
/// It can be used to merge the query results from different operators
#[derive(Clone, Debug)]
pub struct Merge {
    pub k: u32,
}

impl Merge {
    /// K-way merges `input` and returns at most `k` distinct records in descending order.
    ///
    /// The output is only globally descending if every inner vector is itself sorted in
    /// descending order, because the heap only ever compares the current head of each vector.
    ///
    /// Pre-allocation is bounded by the number of records actually supplied. A very large
    /// `k` (e.g. `u32::MAX`) therefore does not trigger a huge up-front allocation.
    pub fn merge<M: Clone + Eq + Hash + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = self.k as usize;
        // Saturating so that even pathological lengths (zero-sized `M`) cannot overflow.
        let available = input
            .iter()
            .fold(0usize, |acc, batch| acc.saturating_add(batch.len()));
        let capacity = limit.min(available);

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the head of every non-empty batch. The batch index travels with
        // the record so we know which iterator to advance after popping it.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut seen = HashSet::with_capacity(capacity);
        let mut fusion = Vec::with_capacity(capacity);
        while fusion.len() < limit {
            let Some((m, idx)) = max_heap.pop() else {
                break;
            };
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            // Keep only the first (largest) occurrence of each record.
            if seen.insert(m.clone()) {
                fusion.push(m);
            }
        }
        fusion
    }
}
479
/// The `Projection` operator retrieves record content by offset ids.
///
/// Each flag selects one part of the record content to return.
///
/// # Parameters
/// - `document`: Whether to retrieve document
/// - `embedding`: Whether to retrieve embedding
/// - `metadata`: Whether to retrieve metadata
#[derive(Clone, Debug, Default)]
pub struct Projection {
    pub document: bool,
    pub embedding: bool,
    pub metadata: bool,
}
492
493impl From<chroma_proto::ProjectionOperator> for Projection {
494    fn from(value: chroma_proto::ProjectionOperator) -> Self {
495        Self {
496            document: value.document,
497            embedding: value.embedding,
498            metadata: value.metadata,
499        }
500    }
501}
502
503impl From<Projection> for chroma_proto::ProjectionOperator {
504    fn from(value: Projection) -> Self {
505        Self {
506            document: value.document,
507            embedding: value.embedding,
508            metadata: value.metadata,
509        }
510    }
511}
512
/// A single record produced by a projection: its id plus optional content.
///
/// NOTE(review): the optional fields are presumably `None` when the corresponding
/// `Projection` flag is off or the record has no such content — confirm.
#[derive(Clone, Debug, PartialEq)]
pub struct ProjectionRecord {
    pub id: String,
    pub document: Option<String>,
    pub embedding: Option<Vec<f32>>,
    pub metadata: Option<Metadata>,
}
520
521impl ProjectionRecord {
522    pub fn size_bytes(&self) -> u64 {
523        (self.id.len()
524            + self
525                .document
526                .as_ref()
527                .map(|doc| doc.len())
528                .unwrap_or_default()
529            + self
530                .embedding
531                .as_ref()
532                .map(|emb| size_of_val(&emb[..]))
533                .unwrap_or_default()
534            + self
535                .metadata
536                .as_ref()
537                .map(logical_size_of_metadata)
538                .unwrap_or_default()) as u64
539    }
540}
541
// NOTE(review): `ProjectionRecord` holds `f32` embeddings, so the derived `PartialEq` is not
// reflexive when an embedding contains NaN. This marker impl asserts full equivalence anyway;
// it is required by the `Eq` derives on `ProjectionOutput` and `GetResult`.
impl Eq for ProjectionRecord {}
543
544impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
545    type Error = QueryConversionError;
546
547    fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
548        Ok(Self {
549            id: value.id,
550            document: value.document,
551            embedding: value
552                .embedding
553                .map(|vec| vec.try_into().map(|(v, _)| v))
554                .transpose()?,
555            metadata: value.metadata.map(TryInto::try_into).transpose()?,
556        })
557    }
558}
559
560impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
561    type Error = QueryConversionError;
562
563    fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
564        Ok(Self {
565            id: value.id,
566            document: value.document,
567            embedding: value
568                .embedding
569                .map(|embedding| {
570                    let embedding_dimension = embedding.len();
571                    chroma_proto::Vector::try_from((
572                        embedding,
573                        ScalarEncoding::FLOAT32,
574                        embedding_dimension,
575                    ))
576                })
577                .transpose()?,
578            metadata: value.metadata.map(|metadata| metadata.into()),
579        })
580    }
581}
582
/// The records produced by a projection.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProjectionOutput {
    pub records: Vec<ProjectionRecord>,
}

/// Result of a get query: the projected records plus `pulled_log_bytes`.
///
/// NOTE(review): `pulled_log_bytes` is presumably the volume of log data read while
/// answering the query — confirm.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GetResult {
    pub pulled_log_bytes: u64,
    pub result: ProjectionOutput,
}
593
594impl GetResult {
595    pub fn size_bytes(&self) -> u64 {
596        self.result
597            .records
598            .iter()
599            .map(ProjectionRecord::size_bytes)
600            .sum()
601    }
602}
603
604impl TryFrom<chroma_proto::GetResult> for GetResult {
605    type Error = QueryConversionError;
606
607    fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
608        Ok(Self {
609            pulled_log_bytes: value.pulled_log_bytes,
610            result: ProjectionOutput {
611                records: value
612                    .records
613                    .into_iter()
614                    .map(TryInto::try_into)
615                    .collect::<Result<_, _>>()?,
616            },
617        })
618    }
619}
620
621impl TryFrom<GetResult> for chroma_proto::GetResult {
622    type Error = QueryConversionError;
623
624    fn try_from(value: GetResult) -> Result<Self, Self::Error> {
625        Ok(Self {
626            pulled_log_bytes: value.pulled_log_bytes,
627            records: value
628                .result
629                .records
630                .into_iter()
631                .map(TryInto::try_into)
632                .collect::<Result<_, _>>()?,
633        })
634    }
635}
636
/// The `KnnProjection` operator retrieves record content by offset ids.
/// It is based on the `Projection` operator and can additionally attach each record's
/// distance to the target embedding.
///
/// # Parameters
/// - `projection`: The parameters of the underlying `Projection`
/// - `distance`: Whether to attach distance information
#[derive(Clone, Debug)]
pub struct KnnProjection {
    pub projection: Projection,
    pub distance: bool,
}
649
650impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
651    type Error = QueryConversionError;
652
653    fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
654        Ok(Self {
655            projection: value
656                .projection
657                .ok_or(QueryConversionError::field("projection"))?
658                .into(),
659            distance: value.distance,
660        })
661    }
662}
663
664impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
665    fn from(value: KnnProjection) -> Self {
666        Self {
667            projection: Some(value.projection.into()),
668            distance: value.distance,
669        }
670    }
671}
672
/// A projected record together with its optional distance to the query embedding.
///
/// NOTE(review): `distance` is presumably `None` when `KnnProjection::distance` was not
/// requested — confirm.
#[derive(Clone, Debug)]
pub struct KnnProjectionRecord {
    pub record: ProjectionRecord,
    pub distance: Option<f32>,
}
678
679impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
680    type Error = QueryConversionError;
681
682    fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
683        Ok(Self {
684            record: value
685                .record
686                .ok_or(QueryConversionError::field("record"))?
687                .try_into()?,
688            distance: value.distance,
689        })
690    }
691}
692
693impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
694    type Error = QueryConversionError;
695
696    fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
697        Ok(Self {
698            record: Some(value.record.try_into()?),
699            distance: value.distance,
700        })
701    }
702}
703
/// The records produced by a `KnnProjection` for a single query embedding.
#[derive(Clone, Debug, Default)]
pub struct KnnProjectionOutput {
    pub records: Vec<KnnProjectionRecord>,
}
708
709impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
710    type Error = QueryConversionError;
711
712    fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
713        Ok(Self {
714            records: value
715                .records
716                .into_iter()
717                .map(TryInto::try_into)
718                .collect::<Result<_, _>>()?,
719        })
720    }
721}
722
723impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
724    type Error = QueryConversionError;
725
726    fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
727        Ok(Self {
728            records: value
729                .records
730                .into_iter()
731                .map(TryInto::try_into)
732                .collect::<Result<_, _>>()?,
733        })
734    }
735}
736
/// Result of a batched KNN query: one `KnnProjectionOutput` per query embedding, plus
/// `pulled_log_bytes`.
///
/// NOTE(review): `results` is presumably in the same order as the batch's embeddings, and
/// `pulled_log_bytes` is presumably the volume of log data read — confirm both.
#[derive(Clone, Debug, Default)]
pub struct KnnBatchResult {
    pub pulled_log_bytes: u64,
    pub results: Vec<KnnProjectionOutput>,
}
742
743impl KnnBatchResult {
744    pub fn size_bytes(&self) -> u64 {
745        self.results
746            .iter()
747            .flat_map(|res| {
748                res.records
749                    .iter()
750                    .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
751            })
752            .sum()
753    }
754}
755
756impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
757    type Error = QueryConversionError;
758
759    fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
760        Ok(Self {
761            pulled_log_bytes: value.pulled_log_bytes,
762            results: value
763                .results
764                .into_iter()
765                .map(TryInto::try_into)
766                .collect::<Result<_, _>>()?,
767        })
768    }
769}
770
771impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
772    type Error = QueryConversionError;
773
774    fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
775        Ok(Self {
776            pulled_log_bytes: value.pulled_log_bytes,
777            results: value
778                .results
779                .into_iter()
780                .map(TryInto::try_into)
781                .collect::<Result<_, _>>()?,
782        })
783    }
784}
785
/// A query vector for KNN search.
///
/// Supports both dense and sparse vector formats.
///
/// The enum is serialized untagged. When deserializing, serde tries `Dense` (a plain array
/// of floats) first and falls back to `Sparse`.
///
/// # Variants
///
/// ## Dense
///
/// Standard dense embeddings as a vector of floats.
///
/// ```
/// use chroma_types::operator::QueryVector;
///
/// let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3, 0.4]);
/// ```
///
/// ## Sparse
///
/// Sparse vectors with explicit indices and values.
///
/// ```
/// use chroma_types::operator::QueryVector;
/// use chroma_types::SparseVector;
///
/// let sparse = QueryVector::Sparse(SparseVector::new(
///     vec![0, 5, 10, 50],      // indices
///     vec![0.5, 0.3, 0.8, 0.2], // values
/// ).unwrap());
/// ```
///
/// # Examples
///
/// ## Dense vector in KNN
///
/// ```
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
///
/// let rank = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::Embedding,
///     limit: 100,
///     default: None,
///     return_rank: false,
/// };
/// ```
///
/// ## Sparse vector in KNN
///
/// ```
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
/// use chroma_types::SparseVector;
///
/// let rank = RankExpr::Knn {
///     query: QueryVector::Sparse(SparseVector::new(
///         vec![1, 5, 10],
///         vec![0.5, 0.3, 0.8],
///     ).unwrap()),
///     key: Key::field("sparse_embedding"),
///     limit: 100,
///     default: None,
///     return_rank: false,
/// };
/// ```
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum QueryVector {
    Dense(Vec<f32>),
    Sparse(SparseVector),
}
855
856impl TryFrom<chroma_proto::QueryVector> for QueryVector {
857    type Error = QueryConversionError;
858
859    fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
860        let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
861        match vector {
862            chroma_proto::query_vector::Vector::Dense(dense) => {
863                Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
864            }
865            chroma_proto::query_vector::Vector::Sparse(sparse) => {
866                Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| {
867                    QueryConversionError::validation("sparse vector length mismatch")
868                })?))
869            }
870        }
871    }
872}
873
874impl TryFrom<QueryVector> for chroma_proto::QueryVector {
875    type Error = QueryConversionError;
876
877    fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
878        match value {
879            QueryVector::Dense(vec) => {
880                let dim = vec.len();
881                Ok(chroma_proto::QueryVector {
882                    vector: Some(chroma_proto::query_vector::Vector::Dense(
883                        chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
884                    )),
885                })
886            }
887            QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
888                vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
889            }),
890        }
891    }
892}
893
894impl From<Vec<f32>> for QueryVector {
895    fn from(vec: Vec<f32>) -> Self {
896        QueryVector::Dense(vec)
897    }
898}
899
900impl From<SparseVector> for QueryVector {
901    fn from(sparse: SparseVector) -> Self {
902        QueryVector::Sparse(sparse)
903    }
904}
905
/// A single KNN sub-query extracted from a ranking expression (see `Rank::knn_queries`).
///
/// - `query`: the query vector (dense or sparse)
/// - `key`: NOTE(review): presumably the embedding field to search — confirm
/// - `limit`: NOTE(review): presumably the number of candidates to retrieve — confirm
#[derive(Clone, Debug, PartialEq)]
pub struct KnnQuery {
    pub query: QueryVector,
    pub key: Key,
    pub limit: u32,
}
912
/// Wrapper for ranking expressions in search queries.
///
/// Contains an optional ranking expression. When None, results are returned in
/// natural storage order without scoring.
///
/// Serialized transparently as the inner `Option<RankExpr>`, so `None` appears as `null`.
///
/// # Fields
///
/// * `expr` - The ranking expression (None = no ranking)
///
/// # Examples
///
/// ```
/// use chroma_types::operator::{Rank, RankExpr, QueryVector, Key};
///
/// // With ranking
/// let rank = Rank {
///     expr: Some(RankExpr::Knn {
///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///         key: Key::Embedding,
///         limit: 100,
///         default: None,
///         return_rank: false,
///     }),
/// };
///
/// // No ranking (natural order)
/// let rank = Rank {
///     expr: None,
/// };
/// ```
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Rank {
    pub expr: Option<RankExpr>,
}
948
949impl Rank {
950    pub fn knn_queries(&self) -> Vec<KnnQuery> {
951        self.expr
952            .as_ref()
953            .map(RankExpr::knn_queries)
954            .unwrap_or_default()
955    }
956}
957
958impl TryFrom<chroma_proto::RankOperator> for Rank {
959    type Error = QueryConversionError;
960
961    fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
962        Ok(Rank {
963            expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
964        })
965    }
966}
967
968impl TryFrom<Rank> for chroma_proto::RankOperator {
969    type Error = QueryConversionError;
970
971    fn try_from(rank: Rank) -> Result<Self, Self::Error> {
972        Ok(chroma_proto::RankOperator {
973            expr: rank.expr.map(TryInto::try_into).transpose()?,
974        })
975    }
976}
977
/// A ranking expression for scoring and ordering search results.
///
/// Ranking expressions determine which documents appear in results and their order.
/// Lower scores indicate better matches (distance-based scoring).
///
/// Expressions form a tree: leaves are [`RankExpr::Knn`] searches or constant
/// [`RankExpr::Value`]s, and inner nodes combine or transform their children's
/// scores. In serde form each node is tagged by its `$`-prefixed operator name
/// (for example `$knn`, `$sum`, `$val`).
///
/// # Variants
///
/// ## Knn - K-Nearest Neighbor Search
///
/// The primary ranking method for vector similarity search.
///
/// ```
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
///
/// let rank = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::Embedding,
///     limit: 100,        // Consider top 100 candidates
///     default: None,     // No default score for missing documents
///     return_rank: false, // Return distances, not rank positions
/// };
/// ```
///
/// ## Value - Constant
///
/// Represents a constant score.
///
/// ```
/// use chroma_types::operator::RankExpr;
///
/// let rank = RankExpr::Value(0.5);
/// ```
///
/// ## Arithmetic Operations
///
/// Combine ranking expressions using the standard operators (`+`, `-`, `*`, `/`)
/// and unary negation. An `f32` on either side of an operator is treated as a
/// constant [`RankExpr::Value`].
///
/// ```
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
///
/// let knn1 = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::Embedding,
///     limit: 100,
///     default: None,
///     return_rank: false,
/// };
///
/// let knn2 = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
///     key: Key::field("other_embedding"),
///     limit: 100,
///     default: None,
///     return_rank: false,
/// };
///
/// // Weighted combination: 70% knn1 + 30% knn2
/// let combined = knn1 * 0.7 + knn2 * 0.3;
///
/// // Scale the combined score by one half
/// let normalized = combined / 2.0;
/// ```
///
/// ## Mathematical Functions
///
/// Apply mathematical transformations to scores.
///
/// ```
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
///
/// let knn = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::Embedding,
///     limit: 100,
///     default: None,
///     return_rank: false,
/// };
///
/// // Exponential - amplifies differences
/// let amplified = knn.clone().exp();
///
/// // Logarithm - compresses range (add constant to avoid log(0))
/// let compressed = (knn.clone() + 1.0).log();
///
/// // Absolute value
/// let absolute = knn.clone().abs();
///
/// // Clamp to [0.0, 1.0]: `min` caps the score at 1.0, then `max` floors it at 0.0
/// let clamped = knn.min(1.0).max(0.0);
/// ```
///
/// # Examples
///
/// ## Basic vector search
///
/// ```
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
///
/// let rank = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::Embedding,
///     limit: 100,
///     default: None,
///     return_rank: false,
/// };
/// ```
///
/// ## Hybrid search with weighted combination
///
/// ```
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
///
/// let dense = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::Embedding,
///     limit: 200,
///     default: None,
///     return_rank: false,
/// };
///
/// let sparse = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]), // Use sparse in practice
///     key: Key::field("sparse_embedding"),
///     limit: 200,
///     default: None,
///     return_rank: false,
/// };
///
/// // 70% semantic + 30% keyword
/// let hybrid = dense * 0.7 + sparse * 0.3;
/// ```
///
/// ## Reciprocal Rank Fusion (RRF)
///
/// Use the `rrf()` function for combining rankings with different score scales.
///
/// ```
/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
///
/// let dense = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::Embedding,
///     limit: 200,
///     default: None,
///     return_rank: true, // RRF requires rank positions
/// };
///
/// let sparse = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::field("sparse_embedding"),
///     limit: 200,
///     default: None,
///     return_rank: true, // RRF requires rank positions
/// };
///
/// let rrf_rank = rrf(
///     vec![dense, sparse],
///     Some(60),           // k parameter (smoothing)
///     Some(vec![0.7, 0.3]), // weights
///     false,              // normalize weights
/// ).unwrap();
/// ```
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RankExpr {
    /// Absolute value of the inner expression (`$abs`).
    #[serde(rename = "$abs")]
    Absolute(Box<RankExpr>),
    /// `left / right` (`$div`).
    #[serde(rename = "$div")]
    Division {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// Natural exponential `e^expr` (`$exp`).
    #[serde(rename = "$exp")]
    Exponentiation(Box<RankExpr>),
    /// K-nearest-neighbor search scoring documents against `query` (`$knn`).
    #[serde(rename = "$knn")]
    Knn {
        /// Query vector to search with.
        query: QueryVector,
        /// Field whose vectors are searched. When the field is omitted during
        /// deserialization, it defaults to `RankExpr::default_knn_key()` (`Key::Embedding`).
        #[serde(default = "RankExpr::default_knn_key")]
        key: Key,
        /// Number of candidates to consider. When omitted during deserialization,
        /// it defaults to `RankExpr::default_knn_limit()` (`16`).
        #[serde(default = "RankExpr::default_knn_limit")]
        limit: u32,
        /// Score for documents missing from the KNN results; `None` means no
        /// default score. NOTE(review): exact handling lives in the executor — confirm.
        #[serde(default)]
        default: Option<f32>,
        /// When `true`, yields rank positions instead of distances (RRF needs
        /// rank positions). Defaults to `false` when omitted.
        #[serde(default)]
        return_rank: bool,
    },
    /// Natural logarithm `ln(expr)` (`$log`).
    #[serde(rename = "$log")]
    Logarithm(Box<RankExpr>),
    /// Maximum of the operand scores (`$max`).
    #[serde(rename = "$max")]
    Maximum(Vec<RankExpr>),
    /// Minimum of the operand scores (`$min`).
    #[serde(rename = "$min")]
    Minimum(Vec<RankExpr>),
    /// Product of the operand scores (`$mul`).
    #[serde(rename = "$mul")]
    Multiplication(Vec<RankExpr>),
    /// `left - right` (`$sub`).
    #[serde(rename = "$sub")]
    Subtraction {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// Sum of the operand scores (`$sum`).
    #[serde(rename = "$sum")]
    Summation(Vec<RankExpr>),
    /// Constant score (`$val`).
    #[serde(rename = "$val")]
    Value(f32),
}
1181
1182impl RankExpr {
1183    pub fn default_knn_key() -> Key {
1184        Key::Embedding
1185    }
1186
1187    pub fn default_knn_limit() -> u32 {
1188        16
1189    }
1190
1191    pub fn knn_queries(&self) -> Vec<KnnQuery> {
1192        match self {
1193            RankExpr::Absolute(expr)
1194            | RankExpr::Exponentiation(expr)
1195            | RankExpr::Logarithm(expr) => expr.knn_queries(),
1196            RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1197                .knn_queries()
1198                .into_iter()
1199                .chain(right.knn_queries())
1200                .collect(),
1201            RankExpr::Maximum(exprs)
1202            | RankExpr::Minimum(exprs)
1203            | RankExpr::Multiplication(exprs)
1204            | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1205            RankExpr::Value(_) => Vec::new(),
1206            RankExpr::Knn {
1207                query,
1208                key,
1209                limit,
1210                default: _,
1211                return_rank: _,
1212            } => vec![KnnQuery {
1213                query: query.clone(),
1214                key: key.clone(),
1215                limit: *limit,
1216            }],
1217        }
1218    }
1219
1220    /// Applies exponential transformation: e^rank.
1221    ///
1222    /// Amplifies differences between scores.
1223    ///
1224    /// # Examples
1225    ///
1226    /// ```
1227    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1228    ///
1229    /// let knn = RankExpr::Knn {
1230    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1231    ///     key: Key::Embedding,
1232    ///     limit: 100,
1233    ///     default: None,
1234    ///     return_rank: false,
1235    /// };
1236    ///
1237    /// let amplified = knn.exp();
1238    /// ```
1239    pub fn exp(self) -> Self {
1240        RankExpr::Exponentiation(Box::new(self))
1241    }
1242
1243    /// Applies natural logarithm transformation: ln(rank).
1244    ///
1245    /// Compresses the score range. Add a constant to avoid log(0).
1246    ///
1247    /// # Examples
1248    ///
1249    /// ```
1250    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1251    ///
1252    /// let knn = RankExpr::Knn {
1253    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1254    ///     key: Key::Embedding,
1255    ///     limit: 100,
1256    ///     default: None,
1257    ///     return_rank: false,
1258    /// };
1259    ///
1260    /// // Add constant to avoid log(0)
1261    /// let compressed = (knn + 1.0).log();
1262    /// ```
1263    pub fn log(self) -> Self {
1264        RankExpr::Logarithm(Box::new(self))
1265    }
1266
1267    /// Takes absolute value of the ranking expression.
1268    ///
1269    /// # Examples
1270    ///
1271    /// ```
1272    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1273    ///
1274    /// let knn1 = RankExpr::Knn {
1275    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1276    ///     key: Key::Embedding,
1277    ///     limit: 100,
1278    ///     default: None,
1279    ///     return_rank: false,
1280    /// };
1281    ///
1282    /// let knn2 = RankExpr::Knn {
1283    ///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
1284    ///     key: Key::field("other"),
1285    ///     limit: 100,
1286    ///     default: None,
1287    ///     return_rank: false,
1288    /// };
1289    ///
1290    /// // Absolute difference
1291    /// let diff = (knn1 - knn2).abs();
1292    /// ```
1293    pub fn abs(self) -> Self {
1294        RankExpr::Absolute(Box::new(self))
1295    }
1296
1297    /// Returns maximum of this expression and another.
1298    ///
1299    /// Can be chained to clamp scores to a maximum value.
1300    ///
1301    /// # Examples
1302    ///
1303    /// ```
1304    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1305    ///
1306    /// let knn = RankExpr::Knn {
1307    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1308    ///     key: Key::Embedding,
1309    ///     limit: 100,
1310    ///     default: None,
1311    ///     return_rank: false,
1312    /// };
1313    ///
1314    /// // Clamp to maximum of 1.0
1315    /// let clamped = knn.clone().max(1.0);
1316    ///
1317    /// // Clamp to range [0.0, 1.0]
1318    /// let range_clamped = knn.min(0.0).max(1.0);
1319    /// ```
1320    pub fn max(self, other: impl Into<RankExpr>) -> Self {
1321        let other = other.into();
1322
1323        match self {
1324            RankExpr::Maximum(mut exprs) => match other {
1325                RankExpr::Maximum(other_exprs) => {
1326                    exprs.extend(other_exprs);
1327                    RankExpr::Maximum(exprs)
1328                }
1329                _ => {
1330                    exprs.push(other);
1331                    RankExpr::Maximum(exprs)
1332                }
1333            },
1334            _ => match other {
1335                RankExpr::Maximum(mut exprs) => {
1336                    exprs.insert(0, self);
1337                    RankExpr::Maximum(exprs)
1338                }
1339                _ => RankExpr::Maximum(vec![self, other]),
1340            },
1341        }
1342    }
1343
1344    /// Returns minimum of this expression and another.
1345    ///
1346    /// Can be chained to clamp scores to a minimum value.
1347    ///
1348    /// # Examples
1349    ///
1350    /// ```
1351    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1352    ///
1353    /// let knn = RankExpr::Knn {
1354    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1355    ///     key: Key::Embedding,
1356    ///     limit: 100,
1357    ///     default: None,
1358    ///     return_rank: false,
1359    /// };
1360    ///
1361    /// // Clamp to minimum of 0.0 (ensure non-negative)
1362    /// let clamped = knn.clone().min(0.0);
1363    ///
1364    /// // Clamp to range [0.0, 1.0]
1365    /// let range_clamped = knn.min(0.0).max(1.0);
1366    /// ```
1367    pub fn min(self, other: impl Into<RankExpr>) -> Self {
1368        let other = other.into();
1369
1370        match self {
1371            RankExpr::Minimum(mut exprs) => match other {
1372                RankExpr::Minimum(other_exprs) => {
1373                    exprs.extend(other_exprs);
1374                    RankExpr::Minimum(exprs)
1375                }
1376                _ => {
1377                    exprs.push(other);
1378                    RankExpr::Minimum(exprs)
1379                }
1380            },
1381            _ => match other {
1382                RankExpr::Minimum(mut exprs) => {
1383                    exprs.insert(0, self);
1384                    RankExpr::Minimum(exprs)
1385                }
1386                _ => RankExpr::Minimum(vec![self, other]),
1387            },
1388        }
1389    }
1390}
1391
1392impl Add for RankExpr {
1393    type Output = RankExpr;
1394
1395    fn add(self, rhs: Self) -> Self::Output {
1396        match self {
1397            RankExpr::Summation(mut exprs) => match rhs {
1398                RankExpr::Summation(rhs_exprs) => {
1399                    exprs.extend(rhs_exprs);
1400                    RankExpr::Summation(exprs)
1401                }
1402                _ => {
1403                    exprs.push(rhs);
1404                    RankExpr::Summation(exprs)
1405                }
1406            },
1407            _ => match rhs {
1408                RankExpr::Summation(mut exprs) => {
1409                    exprs.insert(0, self);
1410                    RankExpr::Summation(exprs)
1411                }
1412                _ => RankExpr::Summation(vec![self, rhs]),
1413            },
1414        }
1415    }
1416}
1417
1418impl Add<f32> for RankExpr {
1419    type Output = RankExpr;
1420
1421    fn add(self, rhs: f32) -> Self::Output {
1422        self + RankExpr::Value(rhs)
1423    }
1424}
1425
1426impl Add<RankExpr> for f32 {
1427    type Output = RankExpr;
1428
1429    fn add(self, rhs: RankExpr) -> Self::Output {
1430        RankExpr::Value(self) + rhs
1431    }
1432}
1433
1434impl Sub for RankExpr {
1435    type Output = RankExpr;
1436
1437    fn sub(self, rhs: Self) -> Self::Output {
1438        RankExpr::Subtraction {
1439            left: Box::new(self),
1440            right: Box::new(rhs),
1441        }
1442    }
1443}
1444
1445impl Sub<f32> for RankExpr {
1446    type Output = RankExpr;
1447
1448    fn sub(self, rhs: f32) -> Self::Output {
1449        self - RankExpr::Value(rhs)
1450    }
1451}
1452
1453impl Sub<RankExpr> for f32 {
1454    type Output = RankExpr;
1455
1456    fn sub(self, rhs: RankExpr) -> Self::Output {
1457        RankExpr::Value(self) - rhs
1458    }
1459}
1460
1461impl Mul for RankExpr {
1462    type Output = RankExpr;
1463
1464    fn mul(self, rhs: Self) -> Self::Output {
1465        match self {
1466            RankExpr::Multiplication(mut exprs) => match rhs {
1467                RankExpr::Multiplication(rhs_exprs) => {
1468                    exprs.extend(rhs_exprs);
1469                    RankExpr::Multiplication(exprs)
1470                }
1471                _ => {
1472                    exprs.push(rhs);
1473                    RankExpr::Multiplication(exprs)
1474                }
1475            },
1476            _ => match rhs {
1477                RankExpr::Multiplication(mut exprs) => {
1478                    exprs.insert(0, self);
1479                    RankExpr::Multiplication(exprs)
1480                }
1481                _ => RankExpr::Multiplication(vec![self, rhs]),
1482            },
1483        }
1484    }
1485}
1486
1487impl Mul<f32> for RankExpr {
1488    type Output = RankExpr;
1489
1490    fn mul(self, rhs: f32) -> Self::Output {
1491        self * RankExpr::Value(rhs)
1492    }
1493}
1494
1495impl Mul<RankExpr> for f32 {
1496    type Output = RankExpr;
1497
1498    fn mul(self, rhs: RankExpr) -> Self::Output {
1499        RankExpr::Value(self) * rhs
1500    }
1501}
1502
1503impl Div for RankExpr {
1504    type Output = RankExpr;
1505
1506    fn div(self, rhs: Self) -> Self::Output {
1507        RankExpr::Division {
1508            left: Box::new(self),
1509            right: Box::new(rhs),
1510        }
1511    }
1512}
1513
1514impl Div<f32> for RankExpr {
1515    type Output = RankExpr;
1516
1517    fn div(self, rhs: f32) -> Self::Output {
1518        self / RankExpr::Value(rhs)
1519    }
1520}
1521
1522impl Div<RankExpr> for f32 {
1523    type Output = RankExpr;
1524
1525    fn div(self, rhs: RankExpr) -> Self::Output {
1526        RankExpr::Value(self) / rhs
1527    }
1528}
1529
1530impl Neg for RankExpr {
1531    type Output = RankExpr;
1532
1533    fn neg(self) -> Self::Output {
1534        RankExpr::Value(-1.0) * self
1535    }
1536}
1537
1538impl From<f32> for RankExpr {
1539    fn from(v: f32) -> Self {
1540        RankExpr::Value(v)
1541    }
1542}
1543
1544impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1545    type Error = QueryConversionError;
1546
1547    fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1548        match proto_expr.rank {
1549            Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1550                Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1551            }
1552            Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1553                let left = div.left.ok_or(QueryConversionError::field("left"))?;
1554                let right = div.right.ok_or(QueryConversionError::field("right"))?;
1555                Ok(RankExpr::Division {
1556                    left: Box::new(RankExpr::try_from(*left)?),
1557                    right: Box::new(RankExpr::try_from(*right)?),
1558                })
1559            }
1560            Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1561                RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1562            ),
1563            Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1564                let query = knn
1565                    .query
1566                    .ok_or(QueryConversionError::field("query"))?
1567                    .try_into()?;
1568                Ok(RankExpr::Knn {
1569                    query,
1570                    key: Key::from(knn.key),
1571                    limit: knn.limit,
1572                    default: knn.default,
1573                    return_rank: knn.return_rank,
1574                })
1575            }
1576            Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1577                Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1578            }
1579            Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1580                let exprs = max
1581                    .exprs
1582                    .into_iter()
1583                    .map(RankExpr::try_from)
1584                    .collect::<Result<Vec<_>, _>>()?;
1585                Ok(RankExpr::Maximum(exprs))
1586            }
1587            Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1588                let exprs = min
1589                    .exprs
1590                    .into_iter()
1591                    .map(RankExpr::try_from)
1592                    .collect::<Result<Vec<_>, _>>()?;
1593                Ok(RankExpr::Minimum(exprs))
1594            }
1595            Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1596                let exprs = mul
1597                    .exprs
1598                    .into_iter()
1599                    .map(RankExpr::try_from)
1600                    .collect::<Result<Vec<_>, _>>()?;
1601                Ok(RankExpr::Multiplication(exprs))
1602            }
1603            Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1604                let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1605                let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1606                Ok(RankExpr::Subtraction {
1607                    left: Box::new(RankExpr::try_from(*left)?),
1608                    right: Box::new(RankExpr::try_from(*right)?),
1609                })
1610            }
1611            Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1612                let exprs = sum
1613                    .exprs
1614                    .into_iter()
1615                    .map(RankExpr::try_from)
1616                    .collect::<Result<Vec<_>, _>>()?;
1617                Ok(RankExpr::Summation(exprs))
1618            }
1619            Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1620            None => Err(QueryConversionError::field("rank")),
1621        }
1622    }
1623}
1624
1625impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1626    type Error = QueryConversionError;
1627
1628    fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1629        let proto_rank = match rank_expr {
1630            RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1631                chroma_proto::RankExpr::try_from(*expr)?,
1632            )),
1633            RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1634                Box::new(chroma_proto::rank_expr::RankPair {
1635                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1636                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1637                }),
1638            ),
1639            RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1640                Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1641            ),
1642            RankExpr::Knn {
1643                query,
1644                key,
1645                limit,
1646                default,
1647                return_rank,
1648            } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1649                query: Some(query.try_into()?),
1650                key: key.to_string(),
1651                limit,
1652                default,
1653                return_rank,
1654            }),
1655            RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1656                chroma_proto::RankExpr::try_from(*expr)?,
1657            )),
1658            RankExpr::Maximum(exprs) => {
1659                let proto_exprs = exprs
1660                    .into_iter()
1661                    .map(chroma_proto::RankExpr::try_from)
1662                    .collect::<Result<Vec<_>, _>>()?;
1663                chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1664                    exprs: proto_exprs,
1665                })
1666            }
1667            RankExpr::Minimum(exprs) => {
1668                let proto_exprs = exprs
1669                    .into_iter()
1670                    .map(chroma_proto::RankExpr::try_from)
1671                    .collect::<Result<Vec<_>, _>>()?;
1672                chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1673                    exprs: proto_exprs,
1674                })
1675            }
1676            RankExpr::Multiplication(exprs) => {
1677                let proto_exprs = exprs
1678                    .into_iter()
1679                    .map(chroma_proto::RankExpr::try_from)
1680                    .collect::<Result<Vec<_>, _>>()?;
1681                chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1682                    exprs: proto_exprs,
1683                })
1684            }
1685            RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1686                Box::new(chroma_proto::rank_expr::RankPair {
1687                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1688                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1689                }),
1690            ),
1691            RankExpr::Summation(exprs) => {
1692                let proto_exprs = exprs
1693                    .into_iter()
1694                    .map(chroma_proto::RankExpr::try_from)
1695                    .collect::<Result<Vec<_>, _>>()?;
1696                chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1697                    exprs: proto_exprs,
1698                })
1699            }
1700            RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1701        };
1702
1703        Ok(chroma_proto::RankExpr {
1704            rank: Some(proto_rank),
1705        })
1706    }
1707}
1708
/// Represents a field key in search queries.
///
/// Used for both selecting fields to return and building filter expressions.
/// Predefined keys access special fields, while custom keys access metadata.
///
/// # Predefined Keys
///
/// - `Key::Document` - Document text content (`#document`)
/// - `Key::Embedding` - Vector embeddings (`#embedding`)
/// - `Key::Metadata` - All metadata fields (`#metadata`)
/// - `Key::Score` - Search scores (`#score`)
///
/// # Custom Keys
///
/// Use `Key::field()` or `Key::from()` to reference metadata fields:
///
/// ```
/// use chroma_types::operator::Key;
///
/// let key = Key::field("author");
/// let key = Key::from("title");
/// ```
///
/// The two constructors differ for reserved names. `Key::from()` maps the
/// `#`-prefixed names above to their predefined keys (for example,
/// `"#document"` becomes `Key::Document`). `Key::field()` always produces a
/// `Key::MetadataField`, even for a reserved name.
///
/// # Examples
///
/// ## Building filters
///
/// ```
/// use chroma_types::operator::Key;
///
/// // Equality
/// let filter = Key::field("status").eq("published");
///
/// // Comparisons
/// let filter = Key::field("year").gte(2020);
/// let filter = Key::field("score").lt(0.9);
///
/// // Set operations
/// let filter = Key::field("category").is_in(vec!["tech", "science"]);
/// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
///
/// // Document content
/// let filter = Key::Document.contains("machine learning");
/// let filter = Key::Document.regex(r"\bAPI\b");
///
/// // Combining filters
/// let filter = Key::field("status").eq("published")
///     & Key::field("year").gte(2020);
/// ```
///
/// ## Selecting fields
///
/// ```
/// use chroma_types::plan::SearchPayload;
/// use chroma_types::operator::Key;
///
/// let search = SearchPayload::default()
///     .select([
///         Key::Document,
///         Key::Score,
///         Key::field("title"),
///         Key::field("author"),
///     ]);
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Key {
    // Predefined keys
    /// Document text content; serialized and displayed as `#document`.
    Document,
    /// Vector embeddings; serialized and displayed as `#embedding`.
    Embedding,
    /// All metadata fields; serialized and displayed as `#metadata`.
    Metadata,
    /// Search scores; serialized and displayed as `#score`.
    Score,
    /// A user-defined metadata field; serialized and displayed as its bare name.
    MetadataField(String),
}
1783
1784impl Serialize for Key {
1785    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1786    where
1787        S: serde::Serializer,
1788    {
1789        match self {
1790            Key::Document => serializer.serialize_str("#document"),
1791            Key::Embedding => serializer.serialize_str("#embedding"),
1792            Key::Metadata => serializer.serialize_str("#metadata"),
1793            Key::Score => serializer.serialize_str("#score"),
1794            Key::MetadataField(field) => serializer.serialize_str(field),
1795        }
1796    }
1797}
1798
1799impl<'de> Deserialize<'de> for Key {
1800    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1801    where
1802        D: Deserializer<'de>,
1803    {
1804        let s = String::deserialize(deserializer)?;
1805        Ok(Key::from(s))
1806    }
1807}
1808
1809impl fmt::Display for Key {
1810    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1811        match self {
1812            Key::Document => write!(f, "#document"),
1813            Key::Embedding => write!(f, "#embedding"),
1814            Key::Metadata => write!(f, "#metadata"),
1815            Key::Score => write!(f, "#score"),
1816            Key::MetadataField(field) => write!(f, "{}", field),
1817        }
1818    }
1819}
1820
1821impl From<&str> for Key {
1822    fn from(s: &str) -> Self {
1823        match s {
1824            "#document" => Key::Document,
1825            "#embedding" => Key::Embedding,
1826            "#metadata" => Key::Metadata,
1827            "#score" => Key::Score,
1828            // Any other string is treated as a metadata field key
1829            field => Key::MetadataField(field.to_string()),
1830        }
1831    }
1832}
1833
1834impl From<String> for Key {
1835    fn from(s: String) -> Self {
1836        Key::from(s.as_str())
1837    }
1838}
1839
1840impl Key {
1841    /// Creates a Key for a custom metadata field.
1842    ///
1843    /// # Examples
1844    ///
1845    /// ```
1846    /// use chroma_types::operator::Key;
1847    ///
1848    /// let status = Key::field("status");
1849    /// let year = Key::field("year");
1850    /// let author = Key::field("author");
1851    /// ```
1852    pub fn field(name: impl Into<String>) -> Self {
1853        Key::MetadataField(name.into())
1854    }
1855
1856    /// Creates an equality filter: `field == value`.
1857    ///
1858    /// # Examples
1859    ///
1860    /// ```
1861    /// use chroma_types::operator::Key;
1862    ///
1863    /// // String equality
1864    /// let filter = Key::field("status").eq("published");
1865    ///
1866    /// // Numeric equality
1867    /// let filter = Key::field("count").eq(42);
1868    ///
1869    /// // Boolean equality
1870    /// let filter = Key::field("featured").eq(true);
1871    /// ```
1872    pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1873        Where::Metadata(MetadataExpression {
1874            key: self.to_string(),
1875            comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1876        })
1877    }
1878
1879    /// Creates an inequality filter: `field != value`.
1880    ///
1881    /// # Examples
1882    ///
1883    /// ```
1884    /// use chroma_types::operator::Key;
1885    ///
1886    /// let filter = Key::field("status").ne("deleted");
1887    /// let filter = Key::field("count").ne(0);
1888    /// ```
1889    pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1890        Where::Metadata(MetadataExpression {
1891            key: self.to_string(),
1892            comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1893        })
1894    }
1895
1896    /// Creates a greater-than filter: `field > value` (numeric only).
1897    ///
1898    /// # Examples
1899    ///
1900    /// ```
1901    /// use chroma_types::operator::Key;
1902    ///
1903    /// let filter = Key::field("score").gt(0.5);
1904    /// let filter = Key::field("year").gt(2020);
1905    /// ```
1906    pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1907        Where::Metadata(MetadataExpression {
1908            key: self.to_string(),
1909            comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1910        })
1911    }
1912
1913    /// Creates a greater-than-or-equal filter: `field >= value` (numeric only).
1914    ///
1915    /// # Examples
1916    ///
1917    /// ```
1918    /// use chroma_types::operator::Key;
1919    ///
1920    /// let filter = Key::field("score").gte(0.5);
1921    /// let filter = Key::field("year").gte(2020);
1922    /// ```
1923    pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1924        Where::Metadata(MetadataExpression {
1925            key: self.to_string(),
1926            comparison: MetadataComparison::Primitive(
1927                PrimitiveOperator::GreaterThanOrEqual,
1928                value.into(),
1929            ),
1930        })
1931    }
1932
1933    /// Creates a less-than filter: `field < value` (numeric only).
1934    ///
1935    /// # Examples
1936    ///
1937    /// ```
1938    /// use chroma_types::operator::Key;
1939    ///
1940    /// let filter = Key::field("score").lt(0.9);
1941    /// let filter = Key::field("year").lt(2025);
1942    /// ```
1943    pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1944        Where::Metadata(MetadataExpression {
1945            key: self.to_string(),
1946            comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1947        })
1948    }
1949
1950    /// Creates a less-than-or-equal filter: `field <= value` (numeric only).
1951    ///
1952    /// # Examples
1953    ///
1954    /// ```
1955    /// use chroma_types::operator::Key;
1956    ///
1957    /// let filter = Key::field("score").lte(0.9);
1958    /// let filter = Key::field("year").lte(2024);
1959    /// ```
1960    pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1961        Where::Metadata(MetadataExpression {
1962            key: self.to_string(),
1963            comparison: MetadataComparison::Primitive(
1964                PrimitiveOperator::LessThanOrEqual,
1965                value.into(),
1966            ),
1967        })
1968    }
1969
1970    /// Creates a set membership filter: `field IN values`.
1971    ///
1972    /// Accepts any iterator (Vec, array, slice, etc.).
1973    ///
1974    /// # Examples
1975    ///
1976    /// ```
1977    /// use chroma_types::operator::Key;
1978    ///
1979    /// // With Vec
1980    /// let filter = Key::field("year").is_in(vec![2023, 2024, 2025]);
1981    ///
1982    /// // With array
1983    /// let filter = Key::field("category").is_in(["tech", "science", "math"]);
1984    ///
1985    /// // With owned strings
1986    /// let categories = vec!["tech".to_string(), "science".to_string()];
1987    /// let filter = Key::field("category").is_in(categories);
1988    /// ```
1989    pub fn is_in<I, T>(self, values: I) -> Where
1990    where
1991        I: IntoIterator<Item = T>,
1992        Vec<T>: Into<MetadataSetValue>,
1993    {
1994        let vec: Vec<T> = values.into_iter().collect();
1995        Where::Metadata(MetadataExpression {
1996            key: self.to_string(),
1997            comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1998        })
1999    }
2000
2001    /// Creates a set exclusion filter: `field NOT IN values`.
2002    ///
2003    /// Accepts any iterator (Vec, array, slice, etc.).
2004    ///
2005    /// # Examples
2006    ///
2007    /// ```
2008    /// use chroma_types::operator::Key;
2009    ///
2010    /// // Exclude deleted and archived
2011    /// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
2012    ///
2013    /// // Exclude specific years
2014    /// let filter = Key::field("year").not_in(vec![2019, 2020]);
2015    /// ```
2016    pub fn not_in<I, T>(self, values: I) -> Where
2017    where
2018        I: IntoIterator<Item = T>,
2019        Vec<T>: Into<MetadataSetValue>,
2020    {
2021        let vec: Vec<T> = values.into_iter().collect();
2022        Where::Metadata(MetadataExpression {
2023            key: self.to_string(),
2024            comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2025        })
2026    }
2027
2028    /// Creates a document substring filter (case-sensitive).
2029    ///
2030    /// Only valid on `Key::Document`. Pattern must have at least 3 literal
2031    /// characters for accurate results.
2032    ///
2033    /// For metadata array contains, use [`contains_value`](Key::contains_value).
2034    ///
2035    /// # Examples
2036    ///
2037    /// ```
2038    /// use chroma_types::operator::Key;
2039    ///
2040    /// let filter = Key::Document.contains("machine learning");
2041    /// let filter = Key::Document.contains("API");
2042    /// ```
2043    pub fn contains<S: Into<String>>(self, text: S) -> Where {
2044        Where::Document(DocumentExpression {
2045            operator: DocumentOperator::Contains,
2046            pattern: text.into(),
2047        })
2048    }
2049
2050    /// Creates a negative document substring filter (case-sensitive).
2051    ///
2052    /// Only valid on `Key::Document`.
2053    ///
2054    /// For metadata array not-contains, use
2055    /// [`not_contains_value`](Key::not_contains_value).
2056    ///
2057    /// # Examples
2058    ///
2059    /// ```
2060    /// use chroma_types::operator::Key;
2061    ///
2062    /// let filter = Key::Document.not_contains("deprecated");
2063    /// let filter = Key::Document.not_contains("beta");
2064    /// ```
2065    pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2066        Where::Document(DocumentExpression {
2067            operator: DocumentOperator::NotContains,
2068            pattern: text.into(),
2069        })
2070    }
2071
2072    /// Checks whether a metadata array field contains the given scalar value.
2073    ///
2074    /// # Examples
2075    ///
2076    /// ```
2077    /// use chroma_types::operator::Key;
2078    ///
2079    /// let filter = Key::field("tags").contains_value("action");
2080    /// let filter = Key::field("scores").contains_value(42);
2081    /// let filter = Key::field("ratings").contains_value(4.5);
2082    /// let filter = Key::field("flags").contains_value(true);
2083    /// ```
2084    pub fn contains_value<T: Into<MetadataValue>>(self, value: T) -> Where {
2085        Where::Metadata(MetadataExpression {
2086            key: self.to_string(),
2087            comparison: MetadataComparison::ArrayContains(ContainsOperator::Contains, value.into()),
2088        })
2089    }
2090
2091    /// Checks that a metadata array field does **not** contain the given scalar
2092    /// value.
2093    ///
2094    /// # Examples
2095    ///
2096    /// ```
2097    /// use chroma_types::operator::Key;
2098    ///
2099    /// let filter = Key::field("tags").not_contains_value("draft");
2100    /// let filter = Key::field("scores").not_contains_value(0);
2101    /// ```
2102    pub fn not_contains_value<T: Into<MetadataValue>>(self, value: T) -> Where {
2103        Where::Metadata(MetadataExpression {
2104            key: self.to_string(),
2105            comparison: MetadataComparison::ArrayContains(
2106                ContainsOperator::NotContains,
2107                value.into(),
2108            ),
2109        })
2110    }
2111
2112    /// Creates a regex filter (case-sensitive, document content only).
2113    ///
2114    /// Note: Currently only works with `Key::Document`. Pattern must have at least
2115    /// 3 literal characters for accurate results.
2116    ///
2117    /// # Examples
2118    ///
2119    /// ```
2120    /// use chroma_types::operator::Key;
2121    ///
2122    /// // Match whole word "API"
2123    /// let filter = Key::Document.regex(r"\bAPI\b");
2124    ///
2125    /// // Match version pattern
2126    /// let filter = Key::Document.regex(r"v\d+\.\d+\.\d+");
2127    /// ```
2128    pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2129        Where::Document(DocumentExpression {
2130            operator: DocumentOperator::Regex,
2131            pattern: pattern.into(),
2132        })
2133    }
2134
2135    /// Creates a negative regex filter (case-sensitive, document content only).
2136    ///
2137    /// Note: Currently only works with `Key::Document`.
2138    ///
2139    /// # Examples
2140    ///
2141    /// ```
2142    /// use chroma_types::operator::Key;
2143    ///
2144    /// // Exclude beta versions
2145    /// let filter = Key::Document.not_regex(r"beta");
2146    ///
2147    /// // Exclude test documents
2148    /// let filter = Key::Document.not_regex(r"\btest\b");
2149    /// ```
2150    pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2151        Where::Document(DocumentExpression {
2152            operator: DocumentOperator::NotRegex,
2153            pattern: pattern.into(),
2154        })
2155    }
2156}
2157
2158/// Field selection for search results.
2159///
2160/// Specifies which fields to include in the results. IDs are always included.
2161///
2162/// # Fields
2163///
2164/// * `keys` - Set of keys to include in results
2165///
2166/// # Available Keys
2167///
2168/// * `Key::Document` - Document text content
2169/// * `Key::Embedding` - Vector embeddings
2170/// * `Key::Metadata` - All metadata fields
2171/// * `Key::Score` - Search scores
2172/// * `Key::field("name")` - Specific metadata field
2173///
2174/// # Performance
2175///
2176/// Selecting fewer fields improves performance by reducing data transfer:
2177/// - Minimal: IDs only (default, fastest)
2178/// - Moderate: Scores + specific metadata fields
2179/// - Heavy: Documents + embeddings (larger payloads)
2180///
2181/// # Examples
2182///
2183/// ```
2184/// use chroma_types::operator::{Select, Key};
2185/// use std::collections::HashSet;
2186///
2187/// // Select predefined fields
2188/// let select = Select {
2189///     keys: [Key::Document, Key::Score].into_iter().collect(),
2190/// };
2191///
2192/// // Select specific metadata fields
2193/// let select = Select {
2194///     keys: [
2195///         Key::field("title"),
2196///         Key::field("author"),
2197///         Key::Score,
2198///     ].into_iter().collect(),
2199/// };
2200///
2201/// // Select everything
2202/// let select = Select {
2203///     keys: [
2204///         Key::Document,
2205///         Key::Embedding,
2206///         Key::Metadata,
2207///         Key::Score,
2208///     ].into_iter().collect(),
2209/// };
2210/// ```
2211#[derive(Clone, Debug, Default, Deserialize, Serialize)]
2212pub struct Select {
2213    #[serde(default)]
2214    pub keys: HashSet<Key>,
2215}
2216
2217impl TryFrom<chroma_proto::SelectOperator> for Select {
2218    type Error = QueryConversionError;
2219
2220    fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2221        let keys = value
2222            .keys
2223            .into_iter()
2224            .map(|key| {
2225                // Try to deserialize each string as a Key
2226                serde_json::from_value(serde_json::Value::String(key))
2227                    .map_err(|_| QueryConversionError::field("keys"))
2228            })
2229            .collect::<Result<HashSet<_>, _>>()?;
2230
2231        Ok(Self { keys })
2232    }
2233}
2234
2235impl TryFrom<Select> for chroma_proto::SelectOperator {
2236    type Error = QueryConversionError;
2237
2238    fn try_from(value: Select) -> Result<Self, Self::Error> {
2239        let keys = value
2240            .keys
2241            .into_iter()
2242            .map(|key| {
2243                // Serialize each Key back to string
2244                serde_json::to_value(&key)
2245                    .ok()
2246                    .and_then(|v| v.as_str().map(String::from))
2247                    .ok_or(QueryConversionError::field("keys"))
2248            })
2249            .collect::<Result<Vec<_>, _>>()?;
2250
2251        Ok(Self { keys })
2252    }
2253}
2254
2255/// Aggregation function applied within each group.
2256///
2257/// Determines which records to keep from each group and their ordering.
2258///
2259/// # Variants
2260///
2261/// * `MinK` - Returns k records with minimum values (ascending order).
2262///   Use with `Key::Score` to get best matches (lower score = better in Chroma).
2263/// * `MaxK` - Returns k records with maximum values (descending order).
2264///
2265/// # Multi-level Ordering
2266///
2267/// The `keys` field supports multi-level ordering. Records are sorted by
2268/// the first key, then by the second key for ties, and so on.
2269///
2270/// # Examples
2271///
2272/// ```
2273/// use chroma_types::operator::{Aggregate, Key};
2274///
2275/// // Best 3 by score per group
2276/// let agg = Aggregate::MinK {
2277///     keys: vec![Key::Score],
2278///     k: 3,
2279/// };
2280///
2281/// // Best 3 by score, then by date for ties
2282/// let agg = Aggregate::MinK {
2283///     keys: vec![Key::Score, Key::field("date")],
2284///     k: 3,
2285/// };
2286///
2287/// // Top 5 by recency (highest date first)
2288/// let agg = Aggregate::MaxK {
2289///     keys: vec![Key::field("date")],
2290///     k: 5,
2291/// };
2292/// ```
2293#[derive(Clone, Debug, Deserialize, Serialize)]
2294pub enum Aggregate {
2295    /// Returns k records with minimum values (ascending order)
2296    #[serde(rename = "$min_k")]
2297    MinK {
2298        /// Keys for multi-level ordering
2299        keys: Vec<Key>,
2300        /// Number of records to return per group
2301        k: u32,
2302    },
2303    /// Returns k records with maximum values (descending order)
2304    #[serde(rename = "$max_k")]
2305    MaxK {
2306        /// Keys for multi-level ordering
2307        keys: Vec<Key>,
2308        /// Number of records to return per group
2309        k: u32,
2310    },
2311}
2312
2313/// Groups results by metadata keys and aggregates within each group.
2314///
2315/// Results are grouped by the specified metadata keys (like SQL GROUP BY),
2316/// then aggregated within each group using MinK or MaxK ordering.
2317/// The final output is flattened and sorted by score.
2318///
2319/// # Fields
2320///
2321/// * `keys` - Metadata keys to group by (composite grouping)
2322/// * `aggregate` - Aggregation function to apply within each group
2323///
2324/// # Behavior
2325///
2326/// * Missing metadata keys are treated as Null (forming their own group)
2327/// * Empty groups are omitted from results
2328/// * Final output is flattened (group structure not preserved)
2329/// * Results are sorted by score after aggregation
2330///
2331/// # Examples
2332///
2333/// ```
2334/// use chroma_types::operator::{GroupBy, Aggregate, Key};
2335///
2336/// // Top 3 documents per category
2337/// let group_by = GroupBy {
2338///     keys: vec![Key::field("category")],
2339///     aggregate: Some(Aggregate::MinK {
2340///         keys: vec![Key::Score],
2341///         k: 3,
2342///     }),
2343/// };
2344///
2345/// // Top 2 per (category, author) combination
2346/// let group_by = GroupBy {
2347///     keys: vec![Key::field("category"), Key::field("author")],
2348///     aggregate: Some(Aggregate::MinK {
2349///         keys: vec![Key::Score, Key::field("date")],
2350///         k: 2,
2351///     }),
2352/// };
2353/// ```
2354#[derive(Clone, Debug, Default, Deserialize, Serialize)]
2355pub struct GroupBy {
2356    /// Metadata keys to group by
2357    #[serde(default)]
2358    pub keys: Vec<Key>,
2359    /// Aggregation to apply within each group (required when keys is non-empty)
2360    #[serde(default)]
2361    pub aggregate: Option<Aggregate>,
2362}
2363
2364impl TryFrom<chroma_proto::Aggregate> for Aggregate {
2365    type Error = QueryConversionError;
2366
2367    fn try_from(value: chroma_proto::Aggregate) -> Result<Self, Self::Error> {
2368        match value
2369            .aggregate
2370            .ok_or(QueryConversionError::field("aggregate"))?
2371        {
2372            chroma_proto::aggregate::Aggregate::MinK(min_k) => {
2373                let keys = min_k.keys.into_iter().map(Key::from).collect();
2374                Ok(Aggregate::MinK { keys, k: min_k.k })
2375            }
2376            chroma_proto::aggregate::Aggregate::MaxK(max_k) => {
2377                let keys = max_k.keys.into_iter().map(Key::from).collect();
2378                Ok(Aggregate::MaxK { keys, k: max_k.k })
2379            }
2380        }
2381    }
2382}
2383
2384impl From<Aggregate> for chroma_proto::Aggregate {
2385    fn from(value: Aggregate) -> Self {
2386        let aggregate = match value {
2387            Aggregate::MinK { keys, k } => {
2388                chroma_proto::aggregate::Aggregate::MinK(chroma_proto::aggregate::MinK {
2389                    keys: keys.into_iter().map(|k| k.to_string()).collect(),
2390                    k,
2391                })
2392            }
2393            Aggregate::MaxK { keys, k } => {
2394                chroma_proto::aggregate::Aggregate::MaxK(chroma_proto::aggregate::MaxK {
2395                    keys: keys.into_iter().map(|k| k.to_string()).collect(),
2396                    k,
2397                })
2398            }
2399        };
2400
2401        chroma_proto::Aggregate {
2402            aggregate: Some(aggregate),
2403        }
2404    }
2405}
2406
2407impl TryFrom<chroma_proto::GroupByOperator> for GroupBy {
2408    type Error = QueryConversionError;
2409
2410    fn try_from(value: chroma_proto::GroupByOperator) -> Result<Self, Self::Error> {
2411        let keys = value.keys.into_iter().map(Key::from).collect();
2412        let aggregate = value.aggregate.map(TryInto::try_into).transpose()?;
2413
2414        Ok(Self { keys, aggregate })
2415    }
2416}
2417
2418impl TryFrom<GroupBy> for chroma_proto::GroupByOperator {
2419    type Error = QueryConversionError;
2420
2421    fn try_from(value: GroupBy) -> Result<Self, Self::Error> {
2422        let keys = value.keys.into_iter().map(|k| k.to_string()).collect();
2423        let aggregate = value.aggregate.map(Into::into);
2424
2425        Ok(Self { keys, aggregate })
2426    }
2427}
2428
2429/// A single search result record.
2430///
2431/// Contains the document ID and optionally document content, embeddings, metadata,
2432/// and search score based on what was selected in the search query.
2433///
2434/// # Fields
2435///
2436/// * `id` - Document ID (always present)
2437/// * `document` - Document text content (if selected)
2438/// * `embedding` - Vector embedding (if selected)
2439/// * `metadata` - Document metadata (if selected)
2440/// * `score` - Search score (present when ranking is used, lower = better match)
2441///
2442/// # Examples
2443///
2444/// ```
2445/// use chroma_types::operator::SearchRecord;
2446///
2447/// fn process_results(records: Vec<SearchRecord>) {
2448///     for record in records {
2449///         println!("ID: {}", record.id);
2450///
2451///         if let Some(score) = record.score {
2452///             println!("  Score: {:.3}", score);
2453///         }
2454///
2455///         if let Some(doc) = record.document {
2456///             println!("  Document: {}", doc);
2457///         }
2458///
2459///         if let Some(meta) = record.metadata {
2460///             println!("  Metadata: {:?}", meta);
2461///         }
2462///     }
2463/// }
2464/// ```
2465#[derive(Clone, Debug, Deserialize, Serialize)]
2466#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
2467pub struct SearchRecord {
2468    pub id: String,
2469    pub document: Option<String>,
2470    pub embedding: Option<Vec<f32>>,
2471    pub metadata: Option<Metadata>,
2472    pub score: Option<f32>,
2473}
2474
2475impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2476    type Error = QueryConversionError;
2477
2478    fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2479        Ok(Self {
2480            id: value.id,
2481            document: value.document,
2482            embedding: value
2483                .embedding
2484                .map(|vec| vec.try_into().map(|(v, _)| v))
2485                .transpose()?,
2486            metadata: value.metadata.map(TryInto::try_into).transpose()?,
2487            score: value.score,
2488        })
2489    }
2490}
2491
2492impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2493    type Error = QueryConversionError;
2494
2495    fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2496        Ok(Self {
2497            id: value.id,
2498            document: value.document,
2499            embedding: value
2500                .embedding
2501                .map(|embedding| {
2502                    let embedding_dimension = embedding.len();
2503                    chroma_proto::Vector::try_from((
2504                        embedding,
2505                        ScalarEncoding::FLOAT32,
2506                        embedding_dimension,
2507                    ))
2508                })
2509                .transpose()?,
2510            metadata: value.metadata.map(Into::into),
2511            score: value.score,
2512        })
2513    }
2514}
2515
2516/// Results for a single search payload.
2517///
2518/// Contains all matching records for one search query.
2519///
2520/// # Fields
2521///
2522/// * `records` - Vector of search records, ordered by score (ascending)
2523///
2524/// # Examples
2525///
2526/// ```
2527/// use chroma_types::operator::{SearchPayloadResult, SearchRecord};
2528///
2529/// fn process_search_result(result: SearchPayloadResult) {
2530///     println!("Found {} results", result.records.len());
2531///
2532///     for (i, record) in result.records.iter().enumerate() {
2533///         println!("{}. {} (score: {:?})", i + 1, record.id, record.score);
2534///     }
2535/// }
2536/// ```
2537#[derive(Clone, Debug, Default)]
2538pub struct SearchPayloadResult {
2539    pub records: Vec<SearchRecord>,
2540}
2541
2542impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2543    type Error = QueryConversionError;
2544
2545    fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2546        Ok(Self {
2547            records: value
2548                .records
2549                .into_iter()
2550                .map(TryInto::try_into)
2551                .collect::<Result<_, _>>()?,
2552        })
2553    }
2554}
2555
2556impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2557    type Error = QueryConversionError;
2558
2559    fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2560        Ok(Self {
2561            records: value
2562                .records
2563                .into_iter()
2564                .map(TryInto::try_into)
2565                .collect::<Result<Vec<_>, _>>()?,
2566        })
2567    }
2568}
2569
2570/// Results from a batch search operation.
2571///
2572/// Contains results for each search payload in the batch, maintaining the same order
2573/// as the input searches.
2574///
2575/// # Fields
2576///
2577/// * `results` - Results for each search payload (indexed by search position)
2578/// * `pulled_log_bytes` - Total bytes pulled from log (for internal metrics)
2579///
2580/// # Examples
2581///
2582/// ## Single search
2583///
2584/// ```
2585/// use chroma_types::operator::SearchResult;
2586///
2587/// fn process_single_search(result: SearchResult) {
2588///     // Single search, so results[0] contains our records
2589///     let records = &result.results[0].records;
2590///
2591///     for record in records {
2592///         println!("{}: score={:?}", record.id, record.score);
2593///     }
2594/// }
2595/// ```
2596///
2597/// ## Batch search
2598///
2599/// ```
2600/// use chroma_types::operator::SearchResult;
2601///
2602/// fn process_batch_search(result: SearchResult) {
2603///     // Multiple searches in batch
2604///     for (i, search_result) in result.results.iter().enumerate() {
2605///         println!("\nSearch {}:", i + 1);
2606///         for record in &search_result.records {
2607///             println!("  {}: score={:?}", record.id, record.score);
2608///         }
2609///     }
2610/// }
2611/// ```
2612#[derive(Clone, Debug)]
2613pub struct SearchResult {
2614    pub results: Vec<SearchPayloadResult>,
2615    pub pulled_log_bytes: u64,
2616}
2617
2618impl SearchResult {
2619    pub fn size_bytes(&self) -> u64 {
2620        self.results
2621            .iter()
2622            .flat_map(|result| {
2623                result.records.iter().map(|record| {
2624                    (record.id.len()
2625                        + record
2626                            .document
2627                            .as_ref()
2628                            .map(|doc| doc.len())
2629                            .unwrap_or_default()
2630                        + record
2631                            .embedding
2632                            .as_ref()
2633                            .map(|emb| size_of_val(&emb[..]))
2634                            .unwrap_or_default()
2635                        + record
2636                            .metadata
2637                            .as_ref()
2638                            .map(logical_size_of_metadata)
2639                            .unwrap_or_default()
2640                        + record.score.as_ref().map(size_of_val).unwrap_or_default())
2641                        as u64
2642                })
2643            })
2644            .sum()
2645    }
2646}
2647
2648impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2649    type Error = QueryConversionError;
2650
2651    fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2652        Ok(Self {
2653            results: value
2654                .results
2655                .into_iter()
2656                .map(TryInto::try_into)
2657                .collect::<Result<_, _>>()?,
2658            pulled_log_bytes: value.pulled_log_bytes,
2659        })
2660    }
2661}
2662
2663impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2664    type Error = QueryConversionError;
2665
2666    fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2667        Ok(Self {
2668            results: value
2669                .results
2670                .into_iter()
2671                .map(TryInto::try_into)
2672                .collect::<Result<Vec<_>, _>>()?,
2673            pulled_log_bytes: value.pulled_log_bytes,
2674        })
2675    }
2676}
2677
2678/// Reciprocal Rank Fusion (RRF) - combines multiple ranking strategies.
2679///
2680/// RRF is ideal for hybrid search where you want to merge results from different
2681/// ranking methods (e.g., dense and sparse embeddings) with different score scales.
2682/// It uses rank positions instead of raw scores, making it scale-agnostic.
2683///
2684/// # Formula
2685///
2686/// ```text
2687/// score = -Σ(weight_i / (k + rank_i))
2688/// ```
2689///
2690/// Where:
2691/// - `weight_i` = weight for ranking i (default: 1.0)
2692/// - `rank_i` = rank position from ranking i (0, 1, 2...)
2693/// - `k` = smoothing parameter (default: 60)
2694///
2695/// Score is negative because Chroma uses ascending order (lower = better).
2696///
2697/// # Arguments
2698///
2699/// * `ranks` - List of ranking expressions (must have `return_rank=true`)
2700/// * `k` - Smoothing parameter (None = 60). Higher values reduce emphasis on top ranks.
2701/// * `weights` - Weight for each ranking (None = all 1.0)
2702/// * `normalize` - If true, normalize weights to sum to 1.0
2703///
2704/// # Returns
2705///
2706/// A combined RankExpr or an error if:
2707/// - `ranks` is empty
2708/// - `weights` length doesn't match `ranks` length
2709/// - `weights` sum to zero when normalizing
2710///
2711/// # Examples
2712///
2713/// ## Basic RRF with default parameters
2714///
2715/// ```
2716/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2717///
2718/// let dense = RankExpr::Knn {
2719///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2720///     key: Key::Embedding,
2721///     limit: 200,
2722///     default: None,
2723///     return_rank: true, // Required for RRF
2724/// };
2725///
2726/// let sparse = RankExpr::Knn {
2727///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2728///     key: Key::field("sparse_embedding"),
2729///     limit: 200,
2730///     default: None,
2731///     return_rank: true, // Required for RRF
2732/// };
2733///
2734/// // Equal weights, k=60 (defaults)
2735/// let combined = rrf(vec![dense, sparse], None, None, false).unwrap();
2736/// ```
2737///
2738/// ## RRF with custom weights
2739///
2740/// ```
2741/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2742///
2743/// # let dense = RankExpr::Knn {
2744/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2745/// #     key: Key::Embedding,
2746/// #     limit: 200,
2747/// #     default: None,
2748/// #     return_rank: true,
2749/// # };
2750/// # let sparse = RankExpr::Knn {
2751/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2752/// #     key: Key::field("sparse_embedding"),
2753/// #     limit: 200,
2754/// #     default: None,
2755/// #     return_rank: true,
2756/// # };
2757/// // 70% dense, 30% sparse
2758/// let combined = rrf(
2759///     vec![dense, sparse],
2760///     Some(60),
2761///     Some(vec![0.7, 0.3]),
2762///     false,
2763/// ).unwrap();
2764/// ```
2765///
2766/// ## RRF with normalized weights
2767///
2768/// ```
2769/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2770///
2771/// # let dense = RankExpr::Knn {
2772/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2773/// #     key: Key::Embedding,
2774/// #     limit: 200,
2775/// #     default: None,
2776/// #     return_rank: true,
2777/// # };
2778/// # let sparse = RankExpr::Knn {
2779/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2780/// #     key: Key::field("sparse_embedding"),
2781/// #     limit: 200,
2782/// #     default: None,
2783/// #     return_rank: true,
2784/// # };
2785/// // Weights [75, 25] normalized to [0.75, 0.25]
2786/// let combined = rrf(
2787///     vec![dense, sparse],
2788///     Some(60),
2789///     Some(vec![75.0, 25.0]),
2790///     true, // normalize
2791/// ).unwrap();
2792/// ```
2793///
2794/// ## Adjusting the k parameter
2795///
2796/// ```
2797/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2798///
2799/// # let dense = RankExpr::Knn {
2800/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2801/// #     key: Key::Embedding,
2802/// #     limit: 200,
2803/// #     default: None,
2804/// #     return_rank: true,
2805/// # };
2806/// # let sparse = RankExpr::Knn {
2807/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2808/// #     key: Key::field("sparse_embedding"),
2809/// #     limit: 200,
2810/// #     default: None,
2811/// #     return_rank: true,
2812/// # };
2813/// // Small k (10) = heavy emphasis on top ranks
2814/// let top_heavy = rrf(vec![dense.clone(), sparse.clone()], Some(10), None, false).unwrap();
2815///
2816/// // Default k (60) = balanced
2817/// let balanced = rrf(vec![dense.clone(), sparse.clone()], Some(60), None, false).unwrap();
2818///
2819/// // Large k (200) = more uniform weighting
2820/// let uniform = rrf(vec![dense, sparse], Some(200), None, false).unwrap();
2821/// ```
2822pub fn rrf(
2823    ranks: Vec<RankExpr>,
2824    k: Option<u32>,
2825    weights: Option<Vec<f32>>,
2826    normalize: bool,
2827) -> Result<RankExpr, QueryConversionError> {
2828    let k = k.unwrap_or(60);
2829
2830    if ranks.is_empty() {
2831        return Err(QueryConversionError::validation(
2832            "RRF requires at least one rank expression",
2833        ));
2834    }
2835
2836    let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2837
2838    if weights.len() != ranks.len() {
2839        return Err(QueryConversionError::validation(format!(
2840            "RRF weights length ({}) must match ranks length ({})",
2841            weights.len(),
2842            ranks.len()
2843        )));
2844    }
2845
2846    let weights = if normalize {
2847        let sum: f32 = weights.iter().sum();
2848        if sum == 0.0 {
2849            return Err(QueryConversionError::validation(
2850                "RRF weights sum to zero, cannot normalize",
2851            ));
2852        }
2853        weights.into_iter().map(|w| w / sum).collect()
2854    } else {
2855        weights
2856    };
2857
2858    let terms: Vec<RankExpr> = weights
2859        .into_iter()
2860        .zip(ranks)
2861        .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2862        .collect();
2863
2864    // Safe: ranks is validated as non-empty above, so terms cannot be empty.
2865    // Using unwrap_or_else as defensive programming to avoid panic.
2866    let sum = terms
2867        .into_iter()
2868        .reduce(|a, b| a + b)
2869        .unwrap_or(RankExpr::Value(0.0));
2870    Ok(-sum)
2871}
2872
2873#[cfg(test)]
2874mod tests {
2875    use super::*;
2876
2877    #[test]
2878    fn test_key_from_string() {
2879        // Test predefined keys
2880        assert_eq!(Key::from("#document"), Key::Document);
2881        assert_eq!(Key::from("#embedding"), Key::Embedding);
2882        assert_eq!(Key::from("#metadata"), Key::Metadata);
2883        assert_eq!(Key::from("#score"), Key::Score);
2884
2885        // Test metadata field keys
2886        assert_eq!(
2887            Key::from("custom_field"),
2888            Key::MetadataField("custom_field".to_string())
2889        );
2890        assert_eq!(
2891            Key::from("author"),
2892            Key::MetadataField("author".to_string())
2893        );
2894
2895        // Test String variant
2896        assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
2897        assert_eq!(
2898            Key::from("year".to_string()),
2899            Key::MetadataField("year".to_string())
2900        );
2901    }
2902
2903    #[test]
2904    fn test_query_vector_dense_proto_conversion() {
2905        let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2906        let query_vector = QueryVector::Dense(dense_vec.clone());
2907
2908        // Convert to proto
2909        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2910
2911        // Convert back
2912        let converted: QueryVector = proto.try_into().unwrap();
2913
2914        assert_eq!(converted, query_vector);
2915        if let QueryVector::Dense(v) = converted {
2916            assert_eq!(v, dense_vec);
2917        } else {
2918            panic!("Expected dense vector");
2919        }
2920    }
2921
2922    #[test]
2923    fn test_query_vector_sparse_proto_conversion() {
2924        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2925        let query_vector = QueryVector::Sparse(sparse.clone());
2926
2927        // Convert to proto
2928        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2929
2930        // Convert back
2931        let converted: QueryVector = proto.try_into().unwrap();
2932
2933        assert_eq!(converted, query_vector);
2934        if let QueryVector::Sparse(s) = converted {
2935            assert_eq!(s, sparse);
2936        } else {
2937            panic!("Expected sparse vector");
2938        }
2939    }
2940
2941    #[test]
2942    fn test_filter_json_deserialization() {
2943        // For the new search API, deserialization treats the entire JSON as a where clause
2944
2945        // Test 1: Simple direct metadata comparison
2946        let simple_where = r#"{"author": "John Doe"}"#;
2947        let filter: Filter = serde_json::from_str(simple_where).unwrap();
2948        assert_eq!(filter.query_ids, None);
2949        assert!(filter.where_clause.is_some());
2950
2951        // Test 2: ID filter using #id with $in operator
2952        let id_filter_json = serde_json::json!({
2953            "#id": {
2954                "$in": ["doc1", "doc2", "doc3"]
2955            }
2956        });
2957        let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
2958        assert_eq!(filter.query_ids, None);
2959        assert!(filter.where_clause.is_some());
2960
2961        // Test 3: Complex nested expression with AND, OR, and various operators
2962        let complex_json = serde_json::json!({
2963            "$and": [
2964                {
2965                    "#id": {
2966                        "$in": ["doc1", "doc2", "doc3"]
2967                    }
2968                },
2969                {
2970                    "$or": [
2971                        {
2972                            "author": {
2973                                "$eq": "John Doe"
2974                            }
2975                        },
2976                        {
2977                            "author": {
2978                                "$eq": "Jane Smith"
2979                            }
2980                        }
2981                    ]
2982                },
2983                {
2984                    "year": {
2985                        "$gte": 2020
2986                    }
2987                },
2988                {
2989                    "tags": {
2990                        "$contains": "machine-learning"
2991                    }
2992                }
2993            ]
2994        });
2995
2996        let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
2997        assert_eq!(filter.query_ids, None);
2998        assert!(filter.where_clause.is_some());
2999
3000        // Verify the structure
3001        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3002            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
3003            assert_eq!(composite.children.len(), 4);
3004
3005            // Check that the second child is an OR
3006            if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
3007                assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
3008                assert_eq!(or_composite.children.len(), 2);
3009            } else {
3010                panic!("Expected OR composite in second child");
3011            }
3012        } else {
3013            panic!("Expected AND composite where clause");
3014        }
3015
3016        // Test 4: Mixed operators - $ne, $lt, $gt, $lte
3017        let mixed_operators_json = serde_json::json!({
3018            "$and": [
3019                {
3020                    "status": {
3021                        "$ne": "deleted"
3022                    }
3023                },
3024                {
3025                    "score": {
3026                        "$gt": 0.5
3027                    }
3028                },
3029                {
3030                    "score": {
3031                        "$lt": 0.9
3032                    }
3033                },
3034                {
3035                    "priority": {
3036                        "$lte": 10
3037                    }
3038                }
3039            ]
3040        });
3041
3042        let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
3043        assert_eq!(filter.query_ids, None);
3044        assert!(filter.where_clause.is_some());
3045
3046        // Test 5: Deeply nested expression
3047        let deeply_nested_json = serde_json::json!({
3048            "$or": [
3049                {
3050                    "$and": [
3051                        {
3052                            "#id": {
3053                                "$in": ["id1", "id2"]
3054                            }
3055                        },
3056                        {
3057                            "$or": [
3058                                {
3059                                    "category": "tech"
3060                                },
3061                                {
3062                                    "category": "science"
3063                                }
3064                            ]
3065                        }
3066                    ]
3067                },
3068                {
3069                    "$and": [
3070                        {
3071                            "author": "Admin"
3072                        },
3073                        {
3074                            "published": true
3075                        }
3076                    ]
3077                }
3078            ]
3079        });
3080
3081        let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
3082        assert_eq!(filter.query_ids, None);
3083        assert!(filter.where_clause.is_some());
3084
3085        // Verify it's an OR at the top level
3086        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3087            assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
3088            assert_eq!(composite.children.len(), 2);
3089
3090            // Both children should be AND composites
3091            for child in &composite.children {
3092                if let crate::metadata::Where::Composite(and_composite) = child {
3093                    assert_eq!(
3094                        and_composite.operator,
3095                        crate::metadata::BooleanOperator::And
3096                    );
3097                } else {
3098                    panic!("Expected AND composite in OR children");
3099                }
3100            }
3101        } else {
3102            panic!("Expected OR composite at top level");
3103        }
3104
3105        // Test 6: Single ID filter (edge case)
3106        let single_id_json = serde_json::json!({
3107            "#id": {
3108                "$eq": "single-doc-id"
3109            }
3110        });
3111
3112        let filter: Filter = serde_json::from_value(single_id_json).unwrap();
3113        assert_eq!(filter.query_ids, None);
3114        assert!(filter.where_clause.is_some());
3115
3116        // Test 7: Empty object should create empty filter
3117        let empty_json = serde_json::json!({});
3118        let filter: Filter = serde_json::from_value(empty_json).unwrap();
3119        assert_eq!(filter.query_ids, None);
3120        // Empty object results in None where_clause
3121        assert_eq!(filter.where_clause, None);
3122
3123        // Test 8: Combining #id filter with $not_contains and numeric comparisons
3124        let advanced_json = serde_json::json!({
3125            "$and": [
3126                {
3127                    "#id": {
3128                        "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
3129                    }
3130                },
3131                {
3132                    "tags": {
3133                        "$not_contains": "deprecated"
3134                    }
3135                },
3136                {
3137                    "$or": [
3138                        {
3139                            "$and": [
3140                                {
3141                                    "confidence": {
3142                                        "$gte": 0.8
3143                                    }
3144                                },
3145                                {
3146                                    "verified": true
3147                                }
3148                            ]
3149                        },
3150                        {
3151                            "$and": [
3152                                {
3153                                    "confidence": {
3154                                        "$gte": 0.6
3155                                    }
3156                                },
3157                                {
3158                                    "confidence": {
3159                                        "$lt": 0.8
3160                                    }
3161                                },
3162                                {
3163                                    "reviews": {
3164                                        "$gte": 5
3165                                    }
3166                                }
3167                            ]
3168                        }
3169                    ]
3170                }
3171            ]
3172        });
3173
3174        let filter: Filter = serde_json::from_value(advanced_json).unwrap();
3175        assert_eq!(filter.query_ids, None);
3176        assert!(filter.where_clause.is_some());
3177
3178        // Verify top-level structure
3179        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
3180            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
3181            assert_eq!(composite.children.len(), 3);
3182        } else {
3183            panic!("Expected AND composite at top level");
3184        }
3185    }
3186
3187    #[test]
3188    fn test_limit_json_serialization() {
3189        let limit = Limit {
3190            offset: 10,
3191            limit: Some(20),
3192        };
3193
3194        let json = serde_json::to_string(&limit).unwrap();
3195        let deserialized: Limit = serde_json::from_str(&json).unwrap();
3196
3197        assert_eq!(deserialized.offset, limit.offset);
3198        assert_eq!(deserialized.limit, limit.limit);
3199    }
3200
3201    #[test]
3202    fn test_query_vector_json_serialization() {
3203        // Test dense vector
3204        let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
3205        let json = serde_json::to_string(&dense).unwrap();
3206        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3207        assert_eq!(deserialized, dense);
3208
3209        // Test sparse vector
3210        let sparse =
3211            QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap());
3212        let json = serde_json::to_string(&sparse).unwrap();
3213        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
3214        assert_eq!(deserialized, sparse);
3215    }
3216
3217    #[test]
3218    fn test_select_key_json_serialization() {
3219        use std::collections::HashSet;
3220
3221        // Test predefined keys
3222        let doc_key = Key::Document;
3223        assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
3224
3225        let embed_key = Key::Embedding;
3226        assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
3227
3228        let meta_key = Key::Metadata;
3229        assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
3230
3231        let score_key = Key::Score;
3232        assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
3233
3234        // Test metadata key
3235        let custom_key = Key::MetadataField("custom_key".to_string());
3236        assert_eq!(
3237            serde_json::to_string(&custom_key).unwrap(),
3238            "\"custom_key\""
3239        );
3240
3241        // Test deserialization
3242        let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
3243        assert!(matches!(deserialized, Key::Document));
3244
3245        let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
3246        assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
3247
3248        // Test Select struct with multiple keys
3249        let mut keys = HashSet::new();
3250        keys.insert(Key::Document);
3251        keys.insert(Key::Embedding);
3252        keys.insert(Key::MetadataField("author".to_string()));
3253
3254        let select = Select { keys };
3255        let json = serde_json::to_string(&select).unwrap();
3256        let deserialized: Select = serde_json::from_str(&json).unwrap();
3257
3258        assert_eq!(deserialized.keys.len(), 3);
3259        assert!(deserialized.keys.contains(&Key::Document));
3260        assert!(deserialized.keys.contains(&Key::Embedding));
3261        assert!(deserialized
3262            .keys
3263            .contains(&Key::MetadataField("author".to_string())));
3264    }
3265
3266    #[test]
3267    fn test_merge_basic_integers() {
3268        use std::cmp::Reverse;
3269
3270        let merge = Merge { k: 5 };
3271
3272        // Input: sorted vectors of Reverse(u32) - ascending order of inner values
3273        let input = vec![
3274            vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
3275            vec![Reverse(2), Reverse(5), Reverse(8)],
3276            vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
3277        ];
3278
3279        let result = merge.merge(input);
3280
3281        // Should get top-5 smallest values (largest Reverse values)
3282        assert_eq!(result.len(), 5);
3283        assert_eq!(
3284            result,
3285            vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
3286        );
3287    }
3288
3289    #[test]
3290    fn test_merge_u32_descending() {
3291        let merge = Merge { k: 6 };
3292
3293        // Regular u32 in descending order (largest first)
3294        let input = vec![
3295            vec![100u32, 75, 50, 25],
3296            vec![90, 60, 30],
3297            vec![95, 85, 70, 40, 10],
3298        ];
3299
3300        let result = merge.merge(input);
3301
3302        // Should get top-6 largest u32 values
3303        assert_eq!(result.len(), 6);
3304        assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
3305    }
3306
3307    #[test]
3308    fn test_merge_i32_descending() {
3309        let merge = Merge { k: 5 };
3310
3311        // i32 values in descending order (including negatives)
3312        let input = vec![
3313            vec![50i32, 10, -10, -50],
3314            vec![30, 0, -30],
3315            vec![40, 20, -20, -40],
3316        ];
3317
3318        let result = merge.merge(input);
3319
3320        // Should get top-5 largest i32 values
3321        assert_eq!(result.len(), 5);
3322        assert_eq!(result, vec![50, 40, 30, 20, 10]);
3323    }
3324
3325    #[test]
3326    fn test_merge_with_duplicates() {
3327        let merge = Merge { k: 10 };
3328
3329        // Input with duplicates using regular u32 in descending order
3330        let input = vec![
3331            vec![100u32, 80, 80, 60, 40],
3332            vec![90, 80, 50, 30],
3333            vec![100, 70, 60, 20],
3334        ];
3335
3336        let result = merge.merge(input);
3337
3338        // Duplicates should be removed
3339        assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
3340    }
3341
3342    #[test]
3343    fn test_merge_empty_vectors() {
3344        let merge = Merge { k: 5 };
3345
3346        // All empty with u32
3347        let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
3348        let result = merge.merge(input);
3349        assert_eq!(result.len(), 0);
3350
3351        // Some empty, some with data (u64)
3352        let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
3353        let result = merge.merge(input);
3354        assert_eq!(result, vec![1000, 850, 750, 600, 500]);
3355
3356        // Single non-empty vector (i32)
3357        let input = vec![vec![], vec![100i32, 50, 25], vec![]];
3358        let result = merge.merge(input);
3359        assert_eq!(result, vec![100, 50, 25]);
3360    }
3361
3362    #[test]
3363    fn test_merge_k_boundary_conditions() {
3364        // k = 0 with u32
3365        let merge = Merge { k: 0 };
3366        let input = vec![vec![100u32, 50], vec![75, 25]];
3367        let result = merge.merge(input);
3368        assert_eq!(result.len(), 0);
3369
3370        // k = 1 with i64
3371        let merge = Merge { k: 1 };
3372        let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
3373        let result = merge.merge(input);
3374        assert_eq!(result, vec![1000]);
3375
3376        // k larger than total unique elements with u128
3377        let merge = Merge { k: 100 };
3378        let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
3379        let result = merge.merge(input);
3380        assert_eq!(result, vec![10000, 8000, 5000, 3000]);
3381    }
3382
3383    #[test]
3384    fn test_merge_with_strings() {
3385        let merge = Merge { k: 4 };
3386
3387        // Strings must be sorted in descending order (largest first) for the max heap merge
3388        let input = vec![
3389            vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
3390            vec!["elephant".to_string(), "banana".to_string()],
3391            vec!["fish".to_string(), "cat".to_string()],
3392        ];
3393
3394        let result = merge.merge(input);
3395
3396        // Should get top-4 lexicographically largest strings
3397        assert_eq!(
3398            result,
3399            vec![
3400                "zebra".to_string(),
3401                "fish".to_string(),
3402                "elephant".to_string(),
3403                "dog".to_string()
3404            ]
3405        );
3406    }
3407
3408    #[test]
3409    fn test_merge_with_custom_struct() {
3410        #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
3411        struct Score {
3412            value: i32,
3413            id: String,
3414        }
3415
3416        let merge = Merge { k: 3 };
3417
3418        // Custom structs sorted by value (descending), then by id
3419        let input = vec![
3420            vec![
3421                Score {
3422                    value: 100,
3423                    id: "a".to_string(),
3424                },
3425                Score {
3426                    value: 80,
3427                    id: "b".to_string(),
3428                },
3429                Score {
3430                    value: 60,
3431                    id: "c".to_string(),
3432                },
3433            ],
3434            vec![
3435                Score {
3436                    value: 90,
3437                    id: "d".to_string(),
3438                },
3439                Score {
3440                    value: 70,
3441                    id: "e".to_string(),
3442                },
3443            ],
3444            vec![
3445                Score {
3446                    value: 95,
3447                    id: "f".to_string(),
3448                },
3449                Score {
3450                    value: 85,
3451                    id: "g".to_string(),
3452                },
3453            ],
3454        ];
3455
3456        let result = merge.merge(input);
3457
3458        assert_eq!(result.len(), 3);
3459        assert_eq!(
3460            result[0],
3461            Score {
3462                value: 100,
3463                id: "a".to_string()
3464            }
3465        );
3466        assert_eq!(
3467            result[1],
3468            Score {
3469                value: 95,
3470                id: "f".to_string()
3471            }
3472        );
3473        assert_eq!(
3474            result[2],
3475            Score {
3476                value: 90,
3477                id: "d".to_string()
3478            }
3479        );
3480    }
3481
3482    #[test]
3483    fn test_merge_preserves_order() {
3484        use std::cmp::Reverse;
3485
3486        let merge = Merge { k: 10 };
3487
3488        // For Reverse, smaller inner values are "larger" in ordering
3489        // So vectors should be sorted with smallest inner values first
3490        let input = vec![
3491            vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
3492            vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
3493            vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
3494        ];
3495
3496        let result = merge.merge(input);
3497
3498        // Verify output maintains order - should be sorted by Reverse ordering
3499        // which means ascending inner values
3500        for i in 1..result.len() {
3501            assert!(
3502                result[i - 1] >= result[i],
3503                "Output should be in descending Reverse order"
3504            );
3505            assert!(
3506                result[i - 1].0 <= result[i].0,
3507                "Inner values should be in ascending order"
3508            );
3509        }
3510
3511        // Check we got the right elements
3512        assert_eq!(
3513            result,
3514            vec![
3515                Reverse(1),
3516                Reverse(2),
3517                Reverse(3),
3518                Reverse(4),
3519                Reverse(5),
3520                Reverse(6),
3521                Reverse(7),
3522                Reverse(8),
3523                Reverse(9),
3524                Reverse(10)
3525            ]
3526        );
3527    }
3528
3529    #[test]
3530    fn test_merge_single_vector() {
3531        let merge = Merge { k: 3 };
3532
3533        // Single vector input with u64
3534        let input = vec![vec![1000u64, 800, 600, 400, 200]];
3535
3536        let result = merge.merge(input);
3537
3538        assert_eq!(result, vec![1000, 800, 600]);
3539    }
3540
3541    #[test]
3542    fn test_merge_all_same_values() {
3543        let merge = Merge { k: 5 };
3544
3545        // All vectors contain the same value (using i16)
3546        let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];
3547
3548        let result = merge.merge(input);
3549
3550        // Should deduplicate to single value
3551        assert_eq!(result, vec![42]);
3552    }
3553
3554    #[test]
3555    fn test_merge_mixed_types_sizes() {
3556        // Test with usize (common in real usage)
3557        let merge = Merge { k: 4 };
3558        let input = vec![
3559            vec![1000usize, 500, 100],
3560            vec![800, 300],
3561            vec![900, 600, 200],
3562        ];
3563        let result = merge.merge(input);
3564        assert_eq!(result, vec![1000, 900, 800, 600]);
3565
3566        // Test with negative integers (i32)
3567        let merge = Merge { k: 5 };
3568        let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
3569        let result = merge.merge(input);
3570        assert_eq!(result, vec![15, 10, 5, 0, -5]);
3571    }
3572
3573    #[test]
3574    fn test_merge_dedup_same_id_different_scores() {
3575        use std::cmp::Reverse;
3576
3577        // Simulates quantized distance estimation: the same offset_id appears
3578        // in multiple posting lists with different approximate distances.
3579        // Merge should keep only the first (best-scored) occurrence per id.
3580        let merge = Merge { k: 5 };
3581
3582        // Three posting lists with overlapping IDs and different estimated distances.
3583        // Using Reverse<RecordMeasure> to match the real knn_merge call site:
3584        // the heap acts as a min-heap on measure (smallest distance = best).
3585        let input: Vec<Vec<Reverse<RecordMeasure>>> = vec![
3586            vec![
3587                Reverse(RecordMeasure {
3588                    offset_id: 1,
3589                    measure: 0.10,
3590                }),
3591                Reverse(RecordMeasure {
3592                    offset_id: 4,
3593                    measure: 0.50,
3594                }),
3595                Reverse(RecordMeasure {
3596                    offset_id: 5,
3597                    measure: 0.70,
3598                }),
3599            ],
3600            vec![
3601                Reverse(RecordMeasure {
3602                    offset_id: 2,
3603                    measure: 0.20,
3604                }),
3605                Reverse(RecordMeasure {
3606                    offset_id: 1,
3607                    measure: 0.25,
3608                }), // dup id=1, worse
3609                Reverse(RecordMeasure {
3610                    offset_id: 3,
3611                    measure: 0.60,
3612                }),
3613            ],
3614            vec![
3615                Reverse(RecordMeasure {
3616                    offset_id: 3,
3617                    measure: 0.15,
3618                }), // dup id=3, better than 0.60
3619                Reverse(RecordMeasure {
3620                    offset_id: 4,
3621                    measure: 0.35,
3622                }), // dup id=4, better than 0.50
3623                Reverse(RecordMeasure {
3624                    offset_id: 2,
3625                    measure: 0.80,
3626                }), // dup id=2, worse
3627            ],
3628        ];
3629
3630        // Expected merge order (ascending distance via Reverse min-heap):
3631        //   pop 0.10 id=1 → keep
3632        //   pop 0.15 id=3 → keep
3633        //   pop 0.20 id=2 → keep
3634        //   pop 0.25 id=1 → dup, skip
3635        //   pop 0.35 id=4 → keep
3636        //   pop 0.50 id=4 → dup, skip
3637        //   pop 0.60 id=3 → dup, skip
3638        //   pop 0.70 id=5 → keep (5th unique)
3639        let result: Vec<Reverse<RecordMeasure>> = merge.merge(input);
3640        let ids: Vec<u32> = result.iter().map(|Reverse(r)| r.offset_id).collect();
3641        let measures: Vec<f32> = result.iter().map(|Reverse(r)| r.measure).collect();
3642
3643        assert_eq!(ids, vec![1, 3, 2, 4, 5]);
3644        assert_eq!(measures, vec![0.10, 0.15, 0.20, 0.35, 0.70]);
3645    }
3646
3647    #[test]
3648    fn test_aggregate_json_serialization() {
3649        // Test MinK serialization
3650        let min_k = Aggregate::MinK {
3651            keys: vec![Key::Score, Key::field("date")],
3652            k: 3,
3653        };
3654        let json = serde_json::to_value(&min_k).unwrap();
3655        assert!(json.get("$min_k").is_some());
3656        assert_eq!(json["$min_k"]["k"], 3);
3657
3658        // Test MinK deserialization
3659        let min_k_json = serde_json::json!({
3660            "$min_k": {
3661                "keys": ["#score", "date"],
3662                "k": 5
3663            }
3664        });
3665        let deserialized: Aggregate = serde_json::from_value(min_k_json).unwrap();
3666        match deserialized {
3667            Aggregate::MinK { keys, k } => {
3668                assert_eq!(k, 5);
3669                assert_eq!(keys.len(), 2);
3670                assert_eq!(keys[0], Key::Score);
3671                assert_eq!(keys[1], Key::field("date"));
3672            }
3673            _ => panic!("Expected MinK"),
3674        }
3675
3676        // Test MaxK serialization
3677        let max_k = Aggregate::MaxK {
3678            keys: vec![Key::field("timestamp")],
3679            k: 10,
3680        };
3681        let json = serde_json::to_value(&max_k).unwrap();
3682        assert!(json.get("$max_k").is_some());
3683        assert_eq!(json["$max_k"]["k"], 10);
3684
3685        // Test MaxK deserialization
3686        let max_k_json = serde_json::json!({
3687            "$max_k": {
3688                "keys": ["timestamp"],
3689                "k": 2
3690            }
3691        });
3692        let deserialized: Aggregate = serde_json::from_value(max_k_json).unwrap();
3693        match deserialized {
3694            Aggregate::MaxK { keys, k } => {
3695                assert_eq!(k, 2);
3696                assert_eq!(keys.len(), 1);
3697                assert_eq!(keys[0], Key::field("timestamp"));
3698            }
3699            _ => panic!("Expected MaxK"),
3700        }
3701    }
3702
3703    #[test]
3704    fn test_group_by_json_serialization() {
3705        // Test GroupBy with MinK
3706        let group_by = GroupBy {
3707            keys: vec![Key::field("category"), Key::field("author")],
3708            aggregate: Some(Aggregate::MinK {
3709                keys: vec![Key::Score],
3710                k: 3,
3711            }),
3712        };
3713
3714        let json = serde_json::to_value(&group_by).unwrap();
3715        assert_eq!(json["keys"].as_array().unwrap().len(), 2);
3716        assert!(json["aggregate"]["$min_k"].is_object());
3717
3718        // Test roundtrip
3719        let deserialized: GroupBy = serde_json::from_value(json).unwrap();
3720        assert_eq!(deserialized.keys.len(), 2);
3721        assert_eq!(deserialized.keys[0], Key::field("category"));
3722        assert_eq!(deserialized.keys[1], Key::field("author"));
3723        assert!(deserialized.aggregate.is_some());
3724
3725        // Test empty GroupBy
3726        let empty_group_by = GroupBy::default();
3727        let json = serde_json::to_value(&empty_group_by).unwrap();
3728        let deserialized: GroupBy = serde_json::from_value(json).unwrap();
3729        assert!(deserialized.keys.is_empty());
3730        assert!(deserialized.aggregate.is_none());
3731
3732        // Test deserialization from JSON
3733        let json = serde_json::json!({
3734            "keys": ["category"],
3735            "aggregate": {
3736                "$max_k": {
3737                    "keys": ["#score", "priority"],
3738                    "k": 5
3739                }
3740            }
3741        });
3742        let group_by: GroupBy = serde_json::from_value(json).unwrap();
3743        assert_eq!(group_by.keys.len(), 1);
3744        assert_eq!(group_by.keys[0], Key::field("category"));
3745        match group_by.aggregate {
3746            Some(Aggregate::MaxK { keys, k }) => {
3747                assert_eq!(k, 5);
3748                assert_eq!(keys.len(), 2);
3749                assert_eq!(keys[0], Key::Score);
3750            }
3751            _ => panic!("Expected MaxK aggregate"),
3752        }
3753    }
3754}