chroma_types/execution/
operator.rs

1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4    cmp::Ordering,
5    collections::{BinaryHeap, HashSet},
6    fmt,
7    hash::Hash,
8    ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13    chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14    DocumentExpression, DocumentOperator, Metadata, MetadataComparison, MetadataExpression,
15    MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding, SetOperator, SparseVector,
16    Where,
17};
18
19use super::error::QueryConversionError;
20
/// Alias for the unit type `()`.
///
/// NOTE(review): presumably the input type of operators that start a pipeline
/// and therefore take no upstream data — confirm against callers.
pub type InitialInput = ();
22
/// The `Scan` operator pins the data used by all downstream operators
///
/// # Parameters
/// - `collection_and_segments`: The consistent snapshot of collection
#[derive(Clone, Debug)]
pub struct Scan {
    /// The collection together with its metadata, record and vector segments.
    pub collection_and_segments: CollectionAndSegments,
}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33    type Error = QueryConversionError;
34
35    fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36        Ok(Self {
37            collection_and_segments: CollectionAndSegments {
38                collection: value
39                    .collection
40                    .ok_or(QueryConversionError::field("collection"))?
41                    .try_into()?,
42                metadata_segment: value
43                    .metadata
44                    .ok_or(QueryConversionError::field("metadata segment"))?
45                    .try_into()?,
46                record_segment: value
47                    .record
48                    .ok_or(QueryConversionError::field("record segment"))?
49                    .try_into()?,
50                vector_segment: value
51                    .knn
52                    .ok_or(QueryConversionError::field("vector segment"))?
53                    .try_into()?,
54            },
55        })
56    }
57}
58
/// Error returned when a `Scan` cannot be converted into its protobuf form.
#[derive(Debug, Error)]
pub enum ScanToProtoError {
    /// The collection could not be converted; wraps the underlying
    /// `CollectionToProtoError` (created automatically via `#[from]`).
    #[error("Could not convert collection to proto")]
    CollectionToProto(#[from] crate::CollectionToProtoError),
}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66    type Error = ScanToProtoError;
67
68    fn try_from(value: Scan) -> Result<Self, Self::Error> {
69        Ok(Self {
70            collection: Some(value.collection_and_segments.collection.try_into()?),
71            knn: Some(value.collection_and_segments.vector_segment.into()),
72            metadata: Some(value.collection_and_segments.metadata_segment.into()),
73            record: Some(value.collection_and_segments.record_segment.into()),
74        })
75    }
76}
77
/// Result of a count request.
#[derive(Clone, Debug)]
pub struct CountResult {
    /// The count value carried by the result.
    pub count: u32,
    /// NOTE(review): presumably the number of log bytes pulled while serving
    /// the request — confirm against the producer.
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Size in bytes of the `count` field only; `pulled_log_bytes` is not
    /// included.
    pub fn size_bytes(&self) -> u64 {
        let Self { count, .. } = self;
        size_of_val(count) as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91    fn from(value: chroma_proto::CountResult) -> Self {
92        Self {
93            count: value.count,
94            pulled_log_bytes: value.pulled_log_bytes,
95        }
96    }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100    fn from(value: CountResult) -> Self {
101        Self {
102            count: value.count,
103            pulled_log_bytes: value.pulled_log_bytes,
104        }
105    }
106}
107
/// The `FetchLog` operator fetches logs from the log service
///
/// # Parameters
/// - `start_log_offset_id`: The offset id of the first log to read
/// - `maximum_fetch_count`: The maximum number of logs to fetch in total
/// - `collection_uuid`: The uuid of the collection where the fetched logs should belong
#[derive(Clone, Debug)]
pub struct FetchLog {
    /// Collection the fetched logs should belong to.
    pub collection_uuid: CollectionUuid,
    /// Upper bound on the total number of logs to fetch.
    /// NOTE(review): `None` presumably means "no bound" — confirm.
    pub maximum_fetch_count: Option<u32>,
    /// Offset id of the first log to read.
    pub start_log_offset_id: u32,
}
120
121/// Filter the search results.
122///
123/// Combines document ID filtering with metadata and document content predicates.
124/// For the Search API, use `where_clause` with Key expressions.
125///
126/// # Fields
127///
128/// * `query_ids` - Optional list of document IDs to filter (legacy, prefer Where expressions)
129/// * `where_clause` - Predicate on document metadata, content, or IDs
130///
131/// # Examples
132///
133/// ## Simple metadata filter
134///
135/// ```
136/// use chroma_types::operator::{Filter, Key};
137///
138/// let filter = Filter {
139///     query_ids: None,
140///     where_clause: Some(Key::field("status").eq("published")),
141/// };
142/// ```
143///
144/// ## Combined filters
145///
146/// ```
147/// use chroma_types::operator::{Filter, Key};
148///
149/// let filter = Filter {
150///     query_ids: None,
151///     where_clause: Some(
152///         Key::field("status").eq("published")
153///             & Key::field("year").gte(2020)
154///             & Key::field("category").is_in(vec!["tech", "science"])
155///     ),
156/// };
157/// ```
158///
159/// ## Document content filter
160///
161/// ```
162/// use chroma_types::operator::{Filter, Key};
163///
164/// let filter = Filter {
165///     query_ids: None,
166///     where_clause: Some(Key::Document.contains("machine learning")),
167/// };
168/// ```
#[derive(Clone, Debug, Default)]
pub struct Filter {
    /// Optional list of document IDs to filter (legacy, prefer `Where` expressions).
    pub query_ids: Option<Vec<String>>,
    /// Optional predicate on document metadata, content, or IDs.
    pub where_clause: Option<Where>,
}
174
175impl Serialize for Filter {
176    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177    where
178        S: Serializer,
179    {
180        // For the search API, serialize directly as the where clause (or empty object if None)
181        // If query_ids are present, they should be combined with the where_clause as Key::ID.is_in([...])
182
183        match (&self.query_ids, &self.where_clause) {
184            (None, None) => {
185                // No filter at all - serialize empty object
186                let map = serializer.serialize_map(Some(0))?;
187                map.end()
188            }
189            (None, Some(where_clause)) => {
190                // Only where clause - serialize it directly
191                where_clause.serialize(serializer)
192            }
193            (Some(ids), None) => {
194                // Only query_ids - create Where clause: Key::ID.is_in(ids)
195                let id_where = Where::Metadata(MetadataExpression {
196                    key: "#id".to_string(),
197                    comparison: MetadataComparison::Set(
198                        SetOperator::In,
199                        MetadataSetValue::Str(ids.clone()),
200                    ),
201                });
202                id_where.serialize(serializer)
203            }
204            (Some(ids), Some(where_clause)) => {
205                // Both present - combine with AND: Key::ID.is_in(ids) & where_clause
206                let id_where = Where::Metadata(MetadataExpression {
207                    key: "#id".to_string(),
208                    comparison: MetadataComparison::Set(
209                        SetOperator::In,
210                        MetadataSetValue::Str(ids.clone()),
211                    ),
212                });
213                let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214                combined.serialize(serializer)
215            }
216        }
217    }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222    where
223        D: Deserializer<'de>,
224    {
225        // For the new search API, the entire JSON is the where clause
226        let where_json = Value::deserialize(deserializer)?;
227        let where_clause =
228            if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229                None
230            } else {
231                Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232            };
233
234        Ok(Filter {
235            query_ids: None, // Always None for new search API
236            where_clause,
237        })
238    }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242    type Error = QueryConversionError;
243
244    fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245        let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246        let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247        let where_clause = match (where_metadata, where_document) {
248            (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249            (Some(w), None) | (None, Some(w)) => Some(w),
250            _ => None,
251        };
252
253        Ok(Self {
254            query_ids: value.ids.map(|uids| uids.ids),
255            where_clause,
256        })
257    }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261    type Error = QueryConversionError;
262
263    fn try_from(value: Filter) -> Result<Self, Self::Error> {
264        Ok(Self {
265            ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266            r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267            where_document: None,
268        })
269    }
270}
271
/// The `Knn` operator searches for the nearest neighbours of the specified embedding.
/// It is intended to be used by the executor.
///
/// # Parameters
/// - `embedding`: The target embedding to search around
/// - `fetch`: The number of records to fetch around the target
#[derive(Clone, Debug)]
pub struct Knn {
    /// The target embedding to search around.
    pub embedding: Vec<f32>,
    /// The number of records to fetch around the target.
    pub fetch: u32,
}
282
283impl From<KnnBatch> for Vec<Knn> {
284    fn from(value: KnnBatch) -> Self {
285        value
286            .embeddings
287            .into_iter()
288            .map(|embedding| Knn {
289                embedding,
290                fetch: value.fetch,
291            })
292            .collect()
293    }
294}
295
/// The `KnnBatch` operator searches for the nearest neighbours of each of the specified embeddings.
/// It is intended to be used by the frontend.
///
/// # Parameters
/// - `embeddings`: The target embeddings to search around
/// - `fetch`: The number of records to fetch around each target
#[derive(Clone, Debug)]
pub struct KnnBatch {
    /// The target embeddings to search around.
    pub embeddings: Vec<Vec<f32>>,
    /// The number of records to fetch; shared by every embedding in the batch.
    pub fetch: u32,
}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308    type Error = QueryConversionError;
309
310    fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311        Ok(Self {
312            embeddings: value
313                .embeddings
314                .into_iter()
315                .map(|vec| vec.try_into().map(|(v, _)| v))
316                .collect::<Result<_, _>>()?,
317            fetch: value.fetch,
318        })
319    }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323    type Error = QueryConversionError;
324
325    fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326        Ok(Self {
327            embeddings: value
328                .embeddings
329                .into_iter()
330                .map(|embedding| {
331                    let dim = embedding.len();
332                    chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333                })
334                .collect::<Result<_, _>>()?,
335            fetch: value.fetch,
336        })
337    }
338}
339
340/// Pagination control for search results.
341///
342/// Controls how many results to return and how many to skip for pagination.
343///
344/// # Fields
345///
346/// * `offset` - Number of results to skip (default: 0)
347/// * `limit` - Maximum results to return (None = no limit)
348///
349/// # Examples
350///
351/// ```
352/// use chroma_types::operator::Limit;
353///
354/// // First page: results 0-9
355/// let limit = Limit {
356///     offset: 0,
357///     limit: Some(10),
358/// };
359///
360/// // Second page: results 10-19
361/// let limit = Limit {
362///     offset: 10,
363///     limit: Some(10),
364/// };
365///
366/// // No limit: all results
367/// let limit = Limit {
368///     offset: 0,
369///     limit: None,
370/// };
371/// ```
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Limit {
    /// Number of results to skip; defaults to 0 when absent from the input.
    #[serde(default)]
    pub offset: u32,
    /// Maximum results to return; `None` (also the default when absent) means no limit.
    #[serde(default)]
    pub limit: Option<u32>,
}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381    fn from(value: chroma_proto::LimitOperator) -> Self {
382        Self {
383            offset: value.offset,
384            limit: value.limit,
385        }
386    }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390    fn from(value: Limit) -> Self {
391        Self {
392            offset: value.offset,
393            limit: value.limit,
394        }
395    }
396}
397
/// The `RecordMeasure` represents a measure of an embedding (identified by `offset_id`)
/// with respect to the query embedding.
///
/// Equality and ordering deliberately use different keys: two values are
/// equal when their `offset_id`s match (the measure is ignored), while the
/// ordering compares `measure` first (IEEE total order) and uses `offset_id`
/// only to break ties.
#[derive(Clone, Debug)]
pub struct RecordMeasure {
    pub offset_id: u32,
    pub measure: f32,
}

impl PartialEq for RecordMeasure {
    /// Two measures are equal when they refer to the same `offset_id`.
    fn eq(&self, other: &Self) -> bool {
        self.offset_id == other.offset_id
    }
}

impl Eq for RecordMeasure {}

impl Ord for RecordMeasure {
    /// Orders by `measure` using `f32::total_cmp`, then by `offset_id`.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.measure.total_cmp(&other.measure) {
            Ordering::Equal => self.offset_id.cmp(&other.offset_id),
            by_measure => by_measure,
        }
    }
}

impl PartialOrd for RecordMeasure {
    /// Always `Some`, delegating to the total order in `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
426
/// Output of a KNN search: a list of record measures.
#[derive(Debug, Default)]
pub struct KnnOutput {
    /// One `RecordMeasure` per returned record.
    /// NOTE(review): ordering of this list is not established here — confirm with the producer.
    pub distances: Vec<RecordMeasure>,
}
431
/// The `Merge` operator selects the top records from the batch vectors of records
/// which are all sorted in descending order. If the same record occurs multiple times
/// only one copy will remain in the final result.
///
/// # Parameters
/// - `k`: The total number of records to take after merge
///
/// # Usage
/// It can be used to merge the query results from different operators
#[derive(Clone, Debug)]
pub struct Merge {
    pub k: u32,
}

impl Merge {
    /// Merges batches that are each sorted in descending order into one
    /// descending list of at most `k` records.
    ///
    /// A record is dropped when it compares equal (`Eq`) to the record emitted
    /// immediately before it, so duplicates are removed only when they end up
    /// adjacent in the merged order.
    ///
    /// The output buffer is sized by the smaller of `k` and the number of
    /// records actually available, so a very large `k` does not cause a
    /// correspondingly large allocation.
    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = usize::try_from(self.k).unwrap_or(usize::MAX);
        // Upper bound on how many records can possibly be emitted.
        let available = input
            .iter()
            .fold(0usize, |total, batch| total.saturating_add(batch.len()));

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the head of every non-empty batch, remembering
        // which batch each record came from.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut fusion = Vec::with_capacity(limit.min(available));
        while fusion.len() < limit {
            let Some((m, idx)) = max_heap.pop() else {
                break;
            };
            // Refill from the batch that just yielded its head.
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            // Skip a record equal to the one emitted last.
            if fusion.last().is_some_and(|tail| tail == &m) {
                continue;
            }
            fusion.push(m);
        }
        fusion
    }
}
472
/// The `Projection` operator retrieves record content by offset ids
///
/// # Parameters
/// - `document`: Whether to retrieve document
/// - `embedding`: Whether to retrieve embedding
/// - `metadata`: Whether to retrieve metadata
#[derive(Clone, Debug, Default)]
pub struct Projection {
    /// Whether to retrieve the document.
    pub document: bool,
    /// Whether to retrieve the embedding.
    pub embedding: bool,
    /// Whether to retrieve the metadata.
    pub metadata: bool,
}
485
486impl From<chroma_proto::ProjectionOperator> for Projection {
487    fn from(value: chroma_proto::ProjectionOperator) -> Self {
488        Self {
489            document: value.document,
490            embedding: value.embedding,
491            metadata: value.metadata,
492        }
493    }
494}
495
496impl From<Projection> for chroma_proto::ProjectionOperator {
497    fn from(value: Projection) -> Self {
498        Self {
499            document: value.document,
500            embedding: value.embedding,
501            metadata: value.metadata,
502        }
503    }
504}
505
506#[derive(Clone, Debug, PartialEq)]
507pub struct ProjectionRecord {
508    pub id: String,
509    pub document: Option<String>,
510    pub embedding: Option<Vec<f32>>,
511    pub metadata: Option<Metadata>,
512}
513
514impl ProjectionRecord {
515    pub fn size_bytes(&self) -> u64 {
516        (self.id.len()
517            + self
518                .document
519                .as_ref()
520                .map(|doc| doc.len())
521                .unwrap_or_default()
522            + self
523                .embedding
524                .as_ref()
525                .map(|emb| size_of_val(&emb[..]))
526                .unwrap_or_default()
527            + self
528                .metadata
529                .as_ref()
530                .map(logical_size_of_metadata)
531                .unwrap_or_default()) as u64
532    }
533}
534
535impl Eq for ProjectionRecord {}
536
537impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
538    type Error = QueryConversionError;
539
540    fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
541        Ok(Self {
542            id: value.id,
543            document: value.document,
544            embedding: value
545                .embedding
546                .map(|vec| vec.try_into().map(|(v, _)| v))
547                .transpose()?,
548            metadata: value.metadata.map(TryInto::try_into).transpose()?,
549        })
550    }
551}
552
553impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
554    type Error = QueryConversionError;
555
556    fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
557        Ok(Self {
558            id: value.id,
559            document: value.document,
560            embedding: value
561                .embedding
562                .map(|embedding| {
563                    let embedding_dimension = embedding.len();
564                    chroma_proto::Vector::try_from((
565                        embedding,
566                        ScalarEncoding::FLOAT32,
567                        embedding_dimension,
568                    ))
569                })
570                .transpose()?,
571            metadata: value.metadata.map(|metadata| metadata.into()),
572        })
573    }
574}
575
/// Output of a projection: the list of projected records.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProjectionOutput {
    /// The projected records.
    pub records: Vec<ProjectionRecord>,
}
580
581#[derive(Clone, Debug, Eq, PartialEq)]
582pub struct GetResult {
583    pub pulled_log_bytes: u64,
584    pub result: ProjectionOutput,
585}
586
587impl GetResult {
588    pub fn size_bytes(&self) -> u64 {
589        self.result
590            .records
591            .iter()
592            .map(ProjectionRecord::size_bytes)
593            .sum()
594    }
595}
596
597impl TryFrom<chroma_proto::GetResult> for GetResult {
598    type Error = QueryConversionError;
599
600    fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
601        Ok(Self {
602            pulled_log_bytes: value.pulled_log_bytes,
603            result: ProjectionOutput {
604                records: value
605                    .records
606                    .into_iter()
607                    .map(TryInto::try_into)
608                    .collect::<Result<_, _>>()?,
609            },
610        })
611    }
612}
613
614impl TryFrom<GetResult> for chroma_proto::GetResult {
615    type Error = QueryConversionError;
616
617    fn try_from(value: GetResult) -> Result<Self, Self::Error> {
618        Ok(Self {
619            pulled_log_bytes: value.pulled_log_bytes,
620            records: value
621                .result
622                .records
623                .into_iter()
624                .map(TryInto::try_into)
625                .collect::<Result<_, _>>()?,
626        })
627    }
628}
629
/// The `KnnProjection` operator retrieves record content by offset ids
/// It is based on `ProjectionOperator`, and it attaches the distance
/// of the records to the target embedding to the record content
///
/// # Parameters
/// - `projection`: The parameters of the `ProjectionOperator`
/// - `distance`: Whether to attach distance information
#[derive(Clone, Debug)]
pub struct KnnProjection {
    /// Which record parts to retrieve.
    pub projection: Projection,
    /// Whether to attach distance information.
    pub distance: bool,
}
642
643impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
644    type Error = QueryConversionError;
645
646    fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
647        Ok(Self {
648            projection: value
649                .projection
650                .ok_or(QueryConversionError::field("projection"))?
651                .into(),
652            distance: value.distance,
653        })
654    }
655}
656
657impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
658    fn from(value: KnnProjection) -> Self {
659        Self {
660            projection: Some(value.projection.into()),
661            distance: value.distance,
662        }
663    }
664}
665
/// A projected record together with an optional distance.
#[derive(Clone, Debug)]
pub struct KnnProjectionRecord {
    /// The projected record content.
    pub record: ProjectionRecord,
    /// Distance attached to the record, if any.
    /// NOTE(review): presumably `None` when `KnnProjection::distance` is false — confirm.
    pub distance: Option<f32>,
}
671
672impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
673    type Error = QueryConversionError;
674
675    fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
676        Ok(Self {
677            record: value
678                .record
679                .ok_or(QueryConversionError::field("record"))?
680                .try_into()?,
681            distance: value.distance,
682        })
683    }
684}
685
686impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
687    type Error = QueryConversionError;
688
689    fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
690        Ok(Self {
691            record: Some(value.record.try_into()?),
692            distance: value.distance,
693        })
694    }
695}
696
/// Output of a KNN projection: the list of projected records with distances.
#[derive(Clone, Debug, Default)]
pub struct KnnProjectionOutput {
    /// The projected records.
    pub records: Vec<KnnProjectionRecord>,
}
701
702impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
703    type Error = QueryConversionError;
704
705    fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
706        Ok(Self {
707            records: value
708                .records
709                .into_iter()
710                .map(TryInto::try_into)
711                .collect::<Result<_, _>>()?,
712        })
713    }
714}
715
716impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
717    type Error = QueryConversionError;
718
719    fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
720        Ok(Self {
721            records: value
722                .records
723                .into_iter()
724                .map(TryInto::try_into)
725                .collect::<Result<_, _>>()?,
726        })
727    }
728}
729
730#[derive(Clone, Debug, Default)]
731pub struct KnnBatchResult {
732    pub pulled_log_bytes: u64,
733    pub results: Vec<KnnProjectionOutput>,
734}
735
736impl KnnBatchResult {
737    pub fn size_bytes(&self) -> u64 {
738        self.results
739            .iter()
740            .flat_map(|res| {
741                res.records
742                    .iter()
743                    .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
744            })
745            .sum()
746    }
747}
748
749impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
750    type Error = QueryConversionError;
751
752    fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
753        Ok(Self {
754            pulled_log_bytes: value.pulled_log_bytes,
755            results: value
756                .results
757                .into_iter()
758                .map(TryInto::try_into)
759                .collect::<Result<_, _>>()?,
760        })
761    }
762}
763
764impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
765    type Error = QueryConversionError;
766
767    fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
768        Ok(Self {
769            pulled_log_bytes: value.pulled_log_bytes,
770            results: value
771                .results
772                .into_iter()
773                .map(TryInto::try_into)
774                .collect::<Result<_, _>>()?,
775        })
776    }
777}
778
779/// A query vector for KNN search.
780///
781/// Supports both dense and sparse vector formats.
782///
783/// # Variants
784///
785/// ## Dense
786///
787/// Standard dense embeddings as a vector of floats.
788///
789/// ```
790/// use chroma_types::operator::QueryVector;
791///
792/// let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3, 0.4]);
793/// ```
794///
795/// ## Sparse
796///
797/// Sparse vectors with explicit indices and values.
798///
799/// ```
800/// use chroma_types::operator::QueryVector;
801/// use chroma_types::SparseVector;
802///
803/// let sparse = QueryVector::Sparse(SparseVector::new(
804///     vec![0, 5, 10, 50],      // indices
805///     vec![0.5, 0.3, 0.8, 0.2], // values
806/// ).unwrap());
807/// ```
808///
809/// # Examples
810///
811/// ## Dense vector in KNN
812///
813/// ```
814/// use chroma_types::operator::{RankExpr, QueryVector, Key};
815///
816/// let rank = RankExpr::Knn {
817///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
818///     key: Key::Embedding,
819///     limit: 100,
820///     default: None,
821///     return_rank: false,
822/// };
823/// ```
824///
825/// ## Sparse vector in KNN
826///
827/// ```
828/// use chroma_types::operator::{RankExpr, QueryVector, Key};
829/// use chroma_types::SparseVector;
830///
831/// let rank = RankExpr::Knn {
832///     query: QueryVector::Sparse(SparseVector::new(
833///         vec![1, 5, 10],
834///         vec![0.5, 0.3, 0.8],
835///     ).unwrap()),
836///     key: Key::field("sparse_embedding"),
837///     limit: 100,
838///     default: None,
839///     return_rank: false,
840/// };
841/// ```
// `untagged`: the variant is not named in the serialized form; on
// deserialization serde tries the variants in declaration order.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(untagged)]
pub enum QueryVector {
    /// Standard dense embedding as a vector of floats.
    Dense(Vec<f32>),
    /// Sparse vector with explicit indices and values.
    Sparse(SparseVector),
}
848
849impl TryFrom<chroma_proto::QueryVector> for QueryVector {
850    type Error = QueryConversionError;
851
852    fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
853        let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
854        match vector {
855            chroma_proto::query_vector::Vector::Dense(dense) => {
856                Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
857            }
858            chroma_proto::query_vector::Vector::Sparse(sparse) => {
859                Ok(QueryVector::Sparse(sparse.try_into().map_err(|_| {
860                    QueryConversionError::validation("sparse vector length mismatch")
861                })?))
862            }
863        }
864    }
865}
866
867impl TryFrom<QueryVector> for chroma_proto::QueryVector {
868    type Error = QueryConversionError;
869
870    fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
871        match value {
872            QueryVector::Dense(vec) => {
873                let dim = vec.len();
874                Ok(chroma_proto::QueryVector {
875                    vector: Some(chroma_proto::query_vector::Vector::Dense(
876                        chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
877                    )),
878                })
879            }
880            QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
881                vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
882            }),
883        }
884    }
885}
886
887impl From<Vec<f32>> for QueryVector {
888    fn from(vec: Vec<f32>) -> Self {
889        QueryVector::Dense(vec)
890    }
891}
892
893impl From<SparseVector> for QueryVector {
894    fn from(sparse: SparseVector) -> Self {
895        QueryVector::Sparse(sparse)
896    }
897}
898
/// A single KNN lookup: a query vector, the key it targets and a limit.
#[derive(Clone, Debug, PartialEq)]
pub struct KnnQuery {
    /// The query vector (dense or sparse).
    pub query: QueryVector,
    /// NOTE(review): presumably the field holding the embeddings to search
    /// (e.g. `Key::Embedding` or a named field) — confirm.
    pub key: Key,
    /// NOTE(review): presumably the number of candidates to consider — confirm.
    pub limit: u32,
}
905
906/// Wrapper for ranking expressions in search queries.
907///
908/// Contains an optional ranking expression. When None, results are returned in
909/// natural storage order without scoring.
910///
911/// # Fields
912///
913/// * `expr` - The ranking expression (None = no ranking)
914///
915/// # Examples
916///
917/// ```
918/// use chroma_types::operator::{Rank, RankExpr, QueryVector, Key};
919///
920/// // With ranking
921/// let rank = Rank {
922///     expr: Some(RankExpr::Knn {
923///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
924///         key: Key::Embedding,
925///         limit: 100,
926///         default: None,
927///         return_rank: false,
928///     }),
929/// };
930///
931/// // No ranking (natural order)
932/// let rank = Rank {
933///     expr: None,
934/// };
935/// ```
// `transparent`: serialized exactly as the inner `Option<RankExpr>`, with no
// wrapping object.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Rank {
    /// The ranking expression (`None` = no ranking).
    pub expr: Option<RankExpr>,
}
941
942impl Rank {
943    pub fn knn_queries(&self) -> Vec<KnnQuery> {
944        self.expr
945            .as_ref()
946            .map(RankExpr::knn_queries)
947            .unwrap_or_default()
948    }
949}
950
951impl TryFrom<chroma_proto::RankOperator> for Rank {
952    type Error = QueryConversionError;
953
954    fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
955        Ok(Rank {
956            expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
957        })
958    }
959}
960
961impl TryFrom<Rank> for chroma_proto::RankOperator {
962    type Error = QueryConversionError;
963
964    fn try_from(rank: Rank) -> Result<Self, Self::Error> {
965        Ok(chroma_proto::RankOperator {
966            expr: rank.expr.map(TryInto::try_into).transpose()?,
967        })
968    }
969}
970
971/// A ranking expression for scoring and ordering search results.
972///
973/// Ranking expressions determine which documents appear in results and their order.
974/// Lower scores indicate better matches (distance-based scoring).
975///
976/// # Variants
977///
978/// ## Knn - K-Nearest Neighbor Search
979///
980/// The primary ranking method for vector similarity search.
981///
982/// ```
983/// use chroma_types::operator::{RankExpr, QueryVector, Key};
984///
985/// let rank = RankExpr::Knn {
986///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
987///     key: Key::Embedding,
988///     limit: 100,        // Consider top 100 candidates
989///     default: None,     // No default score for missing documents
990///     return_rank: false, // Return distances, not rank positions
991/// };
992/// ```
993///
994/// ## Value - Constant
995///
996/// Represents a constant score.
997///
998/// ```
999/// use chroma_types::operator::RankExpr;
1000///
1001/// let rank = RankExpr::Value(0.5);
1002/// ```
1003///
1004/// ## Arithmetic Operations
1005///
1006/// Combine ranking expressions using standard operators (+, -, *, /).
1007///
1008/// ```
1009/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1010///
1011/// let knn1 = RankExpr::Knn {
1012///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1013///     key: Key::Embedding,
1014///     limit: 100,
1015///     default: None,
1016///     return_rank: false,
1017/// };
1018///
1019/// let knn2 = RankExpr::Knn {
1020///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
1021///     key: Key::field("other_embedding"),
1022///     limit: 100,
1023///     default: None,
1024///     return_rank: false,
1025/// };
1026///
1027/// // Weighted combination: 70% knn1 + 30% knn2
1028/// let combined = knn1 * 0.7 + knn2 * 0.3;
1029///
1030/// // Normalized
1031/// let normalized = combined / 2.0;
1032/// ```
1033///
1034/// ## Mathematical Functions
1035///
1036/// Apply mathematical transformations to scores.
1037///
1038/// ```
1039/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1040///
1041/// let knn = RankExpr::Knn {
1042///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1043///     key: Key::Embedding,
1044///     limit: 100,
1045///     default: None,
1046///     return_rank: false,
1047/// };
1048///
1049/// // Exponential - amplifies differences
1050/// let amplified = knn.clone().exp();
1051///
1052/// // Logarithm - compresses range (add constant to avoid log(0))
1053/// let compressed = (knn.clone() + 1.0).log();
1054///
1055/// // Absolute value
1056/// let absolute = knn.clone().abs();
1057///
1058/// // Min/Max - clamping
1059/// let clamped = knn.min(1.0).max(0.0);
1060/// ```
1061///
1062/// # Examples
1063///
1064/// ## Basic vector search
1065///
1066/// ```
1067/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1068///
1069/// let rank = RankExpr::Knn {
1070///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1071///     key: Key::Embedding,
1072///     limit: 100,
1073///     default: None,
1074///     return_rank: false,
1075/// };
1076/// ```
1077///
1078/// ## Hybrid search with weighted combination
1079///
1080/// ```
1081/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1082///
1083/// let dense = RankExpr::Knn {
1084///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1085///     key: Key::Embedding,
1086///     limit: 200,
1087///     default: None,
1088///     return_rank: false,
1089/// };
1090///
1091/// let sparse = RankExpr::Knn {
1092///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]), // Use sparse in practice
1093///     key: Key::field("sparse_embedding"),
1094///     limit: 200,
1095///     default: None,
1096///     return_rank: false,
1097/// };
1098///
1099/// // 70% semantic + 30% keyword
1100/// let hybrid = dense * 0.7 + sparse * 0.3;
1101/// ```
1102///
1103/// ## Reciprocal Rank Fusion (RRF)
1104///
1105/// Use the `rrf()` function for combining rankings with different score scales.
1106///
1107/// ```
1108/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
1109///
1110/// let dense = RankExpr::Knn {
1111///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1112///     key: Key::Embedding,
1113///     limit: 200,
1114///     default: None,
1115///     return_rank: true, // RRF requires rank positions
1116/// };
1117///
1118/// let sparse = RankExpr::Knn {
1119///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1120///     key: Key::field("sparse_embedding"),
1121///     limit: 200,
1122///     default: None,
1123///     return_rank: true, // RRF requires rank positions
1124/// };
1125///
1126/// let rrf_rank = rrf(
1127///     vec![dense, sparse],
1128///     Some(60),           // k parameter (smoothing)
1129///     Some(vec![0.7, 0.3]), // weights
1130///     false,              // normalize weights
1131/// ).unwrap();
1132/// ```
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RankExpr {
    /// Absolute value of the wrapped expression (serialized under `$abs`).
    #[serde(rename = "$abs")]
    Absolute(Box<RankExpr>),
    /// Quotient `left / right` (serialized under `$div`).
    #[serde(rename = "$div")]
    Division {
        /// Dividend.
        left: Box<RankExpr>,
        /// Divisor.
        right: Box<RankExpr>,
    },
    /// Exponential of the wrapped expression (serialized under `$exp`).
    #[serde(rename = "$exp")]
    Exponentiation(Box<RankExpr>),
    /// K-nearest-neighbor search (serialized under `$knn`).
    #[serde(rename = "$knn")]
    Knn {
        /// Query vector to search with; the only field without a serde default.
        query: QueryVector,
        /// Key to search against; falls back to `RankExpr::default_knn_key()`
        /// when omitted from the serialized form.
        #[serde(default = "RankExpr::default_knn_key")]
        key: Key,
        /// Number of candidates to consider; falls back to
        /// `RankExpr::default_knn_limit()` when omitted.
        #[serde(default = "RankExpr::default_knn_limit")]
        limit: u32,
        /// Optional default score for missing documents; `None` when omitted.
        #[serde(default)]
        default: Option<f32>,
        /// When `true`, rank positions are returned instead of distances;
        /// `false` when omitted.
        #[serde(default)]
        return_rank: bool,
    },
    /// Natural logarithm of the wrapped expression (serialized under `$log`).
    #[serde(rename = "$log")]
    Logarithm(Box<RankExpr>),
    /// Maximum over a list of operands (serialized under `$max`).
    #[serde(rename = "$max")]
    Maximum(Vec<RankExpr>),
    /// Minimum over a list of operands (serialized under `$min`).
    #[serde(rename = "$min")]
    Minimum(Vec<RankExpr>),
    /// Product of a list of operands (serialized under `$mul`).
    #[serde(rename = "$mul")]
    Multiplication(Vec<RankExpr>),
    /// Difference `left - right` (serialized under `$sub`).
    #[serde(rename = "$sub")]
    Subtraction {
        /// Minuend.
        left: Box<RankExpr>,
        /// Subtrahend.
        right: Box<RankExpr>,
    },
    /// Sum of a list of operands (serialized under `$sum`).
    #[serde(rename = "$sum")]
    Summation(Vec<RankExpr>),
    /// Constant score (serialized under `$val`).
    #[serde(rename = "$val")]
    Value(f32),
}
1174
1175impl RankExpr {
1176    pub fn default_knn_key() -> Key {
1177        Key::Embedding
1178    }
1179
1180    pub fn default_knn_limit() -> u32 {
1181        16
1182    }
1183
1184    pub fn knn_queries(&self) -> Vec<KnnQuery> {
1185        match self {
1186            RankExpr::Absolute(expr)
1187            | RankExpr::Exponentiation(expr)
1188            | RankExpr::Logarithm(expr) => expr.knn_queries(),
1189            RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1190                .knn_queries()
1191                .into_iter()
1192                .chain(right.knn_queries())
1193                .collect(),
1194            RankExpr::Maximum(exprs)
1195            | RankExpr::Minimum(exprs)
1196            | RankExpr::Multiplication(exprs)
1197            | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1198            RankExpr::Value(_) => Vec::new(),
1199            RankExpr::Knn {
1200                query,
1201                key,
1202                limit,
1203                default: _,
1204                return_rank: _,
1205            } => vec![KnnQuery {
1206                query: query.clone(),
1207                key: key.clone(),
1208                limit: *limit,
1209            }],
1210        }
1211    }
1212
1213    /// Applies exponential transformation: e^rank.
1214    ///
1215    /// Amplifies differences between scores.
1216    ///
1217    /// # Examples
1218    ///
1219    /// ```
1220    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1221    ///
1222    /// let knn = RankExpr::Knn {
1223    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1224    ///     key: Key::Embedding,
1225    ///     limit: 100,
1226    ///     default: None,
1227    ///     return_rank: false,
1228    /// };
1229    ///
1230    /// let amplified = knn.exp();
1231    /// ```
1232    pub fn exp(self) -> Self {
1233        RankExpr::Exponentiation(Box::new(self))
1234    }
1235
1236    /// Applies natural logarithm transformation: ln(rank).
1237    ///
1238    /// Compresses the score range. Add a constant to avoid log(0).
1239    ///
1240    /// # Examples
1241    ///
1242    /// ```
1243    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1244    ///
1245    /// let knn = RankExpr::Knn {
1246    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1247    ///     key: Key::Embedding,
1248    ///     limit: 100,
1249    ///     default: None,
1250    ///     return_rank: false,
1251    /// };
1252    ///
1253    /// // Add constant to avoid log(0)
1254    /// let compressed = (knn + 1.0).log();
1255    /// ```
1256    pub fn log(self) -> Self {
1257        RankExpr::Logarithm(Box::new(self))
1258    }
1259
1260    /// Takes absolute value of the ranking expression.
1261    ///
1262    /// # Examples
1263    ///
1264    /// ```
1265    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1266    ///
1267    /// let knn1 = RankExpr::Knn {
1268    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1269    ///     key: Key::Embedding,
1270    ///     limit: 100,
1271    ///     default: None,
1272    ///     return_rank: false,
1273    /// };
1274    ///
1275    /// let knn2 = RankExpr::Knn {
1276    ///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
1277    ///     key: Key::field("other"),
1278    ///     limit: 100,
1279    ///     default: None,
1280    ///     return_rank: false,
1281    /// };
1282    ///
1283    /// // Absolute difference
1284    /// let diff = (knn1 - knn2).abs();
1285    /// ```
1286    pub fn abs(self) -> Self {
1287        RankExpr::Absolute(Box::new(self))
1288    }
1289
1290    /// Returns maximum of this expression and another.
1291    ///
1292    /// Can be chained to clamp scores to a maximum value.
1293    ///
1294    /// # Examples
1295    ///
1296    /// ```
1297    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1298    ///
1299    /// let knn = RankExpr::Knn {
1300    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1301    ///     key: Key::Embedding,
1302    ///     limit: 100,
1303    ///     default: None,
1304    ///     return_rank: false,
1305    /// };
1306    ///
1307    /// // Clamp to maximum of 1.0
1308    /// let clamped = knn.clone().max(1.0);
1309    ///
1310    /// // Clamp to range [0.0, 1.0]
1311    /// let range_clamped = knn.min(0.0).max(1.0);
1312    /// ```
1313    pub fn max(self, other: impl Into<RankExpr>) -> Self {
1314        let other = other.into();
1315
1316        match self {
1317            RankExpr::Maximum(mut exprs) => match other {
1318                RankExpr::Maximum(other_exprs) => {
1319                    exprs.extend(other_exprs);
1320                    RankExpr::Maximum(exprs)
1321                }
1322                _ => {
1323                    exprs.push(other);
1324                    RankExpr::Maximum(exprs)
1325                }
1326            },
1327            _ => match other {
1328                RankExpr::Maximum(mut exprs) => {
1329                    exprs.insert(0, self);
1330                    RankExpr::Maximum(exprs)
1331                }
1332                _ => RankExpr::Maximum(vec![self, other]),
1333            },
1334        }
1335    }
1336
1337    /// Returns minimum of this expression and another.
1338    ///
1339    /// Can be chained to clamp scores to a minimum value.
1340    ///
1341    /// # Examples
1342    ///
1343    /// ```
1344    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1345    ///
1346    /// let knn = RankExpr::Knn {
1347    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1348    ///     key: Key::Embedding,
1349    ///     limit: 100,
1350    ///     default: None,
1351    ///     return_rank: false,
1352    /// };
1353    ///
1354    /// // Clamp to minimum of 0.0 (ensure non-negative)
1355    /// let clamped = knn.clone().min(0.0);
1356    ///
1357    /// // Clamp to range [0.0, 1.0]
1358    /// let range_clamped = knn.min(0.0).max(1.0);
1359    /// ```
1360    pub fn min(self, other: impl Into<RankExpr>) -> Self {
1361        let other = other.into();
1362
1363        match self {
1364            RankExpr::Minimum(mut exprs) => match other {
1365                RankExpr::Minimum(other_exprs) => {
1366                    exprs.extend(other_exprs);
1367                    RankExpr::Minimum(exprs)
1368                }
1369                _ => {
1370                    exprs.push(other);
1371                    RankExpr::Minimum(exprs)
1372                }
1373            },
1374            _ => match other {
1375                RankExpr::Minimum(mut exprs) => {
1376                    exprs.insert(0, self);
1377                    RankExpr::Minimum(exprs)
1378                }
1379                _ => RankExpr::Minimum(vec![self, other]),
1380            },
1381        }
1382    }
1383}
1384
1385impl Add for RankExpr {
1386    type Output = RankExpr;
1387
1388    fn add(self, rhs: Self) -> Self::Output {
1389        match self {
1390            RankExpr::Summation(mut exprs) => match rhs {
1391                RankExpr::Summation(rhs_exprs) => {
1392                    exprs.extend(rhs_exprs);
1393                    RankExpr::Summation(exprs)
1394                }
1395                _ => {
1396                    exprs.push(rhs);
1397                    RankExpr::Summation(exprs)
1398                }
1399            },
1400            _ => match rhs {
1401                RankExpr::Summation(mut exprs) => {
1402                    exprs.insert(0, self);
1403                    RankExpr::Summation(exprs)
1404                }
1405                _ => RankExpr::Summation(vec![self, rhs]),
1406            },
1407        }
1408    }
1409}
1410
1411impl Add<f32> for RankExpr {
1412    type Output = RankExpr;
1413
1414    fn add(self, rhs: f32) -> Self::Output {
1415        self + RankExpr::Value(rhs)
1416    }
1417}
1418
1419impl Add<RankExpr> for f32 {
1420    type Output = RankExpr;
1421
1422    fn add(self, rhs: RankExpr) -> Self::Output {
1423        RankExpr::Value(self) + rhs
1424    }
1425}
1426
1427impl Sub for RankExpr {
1428    type Output = RankExpr;
1429
1430    fn sub(self, rhs: Self) -> Self::Output {
1431        RankExpr::Subtraction {
1432            left: Box::new(self),
1433            right: Box::new(rhs),
1434        }
1435    }
1436}
1437
1438impl Sub<f32> for RankExpr {
1439    type Output = RankExpr;
1440
1441    fn sub(self, rhs: f32) -> Self::Output {
1442        self - RankExpr::Value(rhs)
1443    }
1444}
1445
1446impl Sub<RankExpr> for f32 {
1447    type Output = RankExpr;
1448
1449    fn sub(self, rhs: RankExpr) -> Self::Output {
1450        RankExpr::Value(self) - rhs
1451    }
1452}
1453
1454impl Mul for RankExpr {
1455    type Output = RankExpr;
1456
1457    fn mul(self, rhs: Self) -> Self::Output {
1458        match self {
1459            RankExpr::Multiplication(mut exprs) => match rhs {
1460                RankExpr::Multiplication(rhs_exprs) => {
1461                    exprs.extend(rhs_exprs);
1462                    RankExpr::Multiplication(exprs)
1463                }
1464                _ => {
1465                    exprs.push(rhs);
1466                    RankExpr::Multiplication(exprs)
1467                }
1468            },
1469            _ => match rhs {
1470                RankExpr::Multiplication(mut exprs) => {
1471                    exprs.insert(0, self);
1472                    RankExpr::Multiplication(exprs)
1473                }
1474                _ => RankExpr::Multiplication(vec![self, rhs]),
1475            },
1476        }
1477    }
1478}
1479
1480impl Mul<f32> for RankExpr {
1481    type Output = RankExpr;
1482
1483    fn mul(self, rhs: f32) -> Self::Output {
1484        self * RankExpr::Value(rhs)
1485    }
1486}
1487
1488impl Mul<RankExpr> for f32 {
1489    type Output = RankExpr;
1490
1491    fn mul(self, rhs: RankExpr) -> Self::Output {
1492        RankExpr::Value(self) * rhs
1493    }
1494}
1495
1496impl Div for RankExpr {
1497    type Output = RankExpr;
1498
1499    fn div(self, rhs: Self) -> Self::Output {
1500        RankExpr::Division {
1501            left: Box::new(self),
1502            right: Box::new(rhs),
1503        }
1504    }
1505}
1506
1507impl Div<f32> for RankExpr {
1508    type Output = RankExpr;
1509
1510    fn div(self, rhs: f32) -> Self::Output {
1511        self / RankExpr::Value(rhs)
1512    }
1513}
1514
1515impl Div<RankExpr> for f32 {
1516    type Output = RankExpr;
1517
1518    fn div(self, rhs: RankExpr) -> Self::Output {
1519        RankExpr::Value(self) / rhs
1520    }
1521}
1522
1523impl Neg for RankExpr {
1524    type Output = RankExpr;
1525
1526    fn neg(self) -> Self::Output {
1527        RankExpr::Value(-1.0) * self
1528    }
1529}
1530
1531impl From<f32> for RankExpr {
1532    fn from(v: f32) -> Self {
1533        RankExpr::Value(v)
1534    }
1535}
1536
1537impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1538    type Error = QueryConversionError;
1539
1540    fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1541        match proto_expr.rank {
1542            Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1543                Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1544            }
1545            Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1546                let left = div.left.ok_or(QueryConversionError::field("left"))?;
1547                let right = div.right.ok_or(QueryConversionError::field("right"))?;
1548                Ok(RankExpr::Division {
1549                    left: Box::new(RankExpr::try_from(*left)?),
1550                    right: Box::new(RankExpr::try_from(*right)?),
1551                })
1552            }
1553            Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1554                RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1555            ),
1556            Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1557                let query = knn
1558                    .query
1559                    .ok_or(QueryConversionError::field("query"))?
1560                    .try_into()?;
1561                Ok(RankExpr::Knn {
1562                    query,
1563                    key: Key::from(knn.key),
1564                    limit: knn.limit,
1565                    default: knn.default,
1566                    return_rank: knn.return_rank,
1567                })
1568            }
1569            Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1570                Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1571            }
1572            Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1573                let exprs = max
1574                    .exprs
1575                    .into_iter()
1576                    .map(RankExpr::try_from)
1577                    .collect::<Result<Vec<_>, _>>()?;
1578                Ok(RankExpr::Maximum(exprs))
1579            }
1580            Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1581                let exprs = min
1582                    .exprs
1583                    .into_iter()
1584                    .map(RankExpr::try_from)
1585                    .collect::<Result<Vec<_>, _>>()?;
1586                Ok(RankExpr::Minimum(exprs))
1587            }
1588            Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1589                let exprs = mul
1590                    .exprs
1591                    .into_iter()
1592                    .map(RankExpr::try_from)
1593                    .collect::<Result<Vec<_>, _>>()?;
1594                Ok(RankExpr::Multiplication(exprs))
1595            }
1596            Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1597                let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1598                let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1599                Ok(RankExpr::Subtraction {
1600                    left: Box::new(RankExpr::try_from(*left)?),
1601                    right: Box::new(RankExpr::try_from(*right)?),
1602                })
1603            }
1604            Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1605                let exprs = sum
1606                    .exprs
1607                    .into_iter()
1608                    .map(RankExpr::try_from)
1609                    .collect::<Result<Vec<_>, _>>()?;
1610                Ok(RankExpr::Summation(exprs))
1611            }
1612            Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1613            None => Err(QueryConversionError::field("rank")),
1614        }
1615    }
1616}
1617
1618impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1619    type Error = QueryConversionError;
1620
1621    fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1622        let proto_rank = match rank_expr {
1623            RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1624                chroma_proto::RankExpr::try_from(*expr)?,
1625            )),
1626            RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1627                Box::new(chroma_proto::rank_expr::RankPair {
1628                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1629                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1630                }),
1631            ),
1632            RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1633                Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1634            ),
1635            RankExpr::Knn {
1636                query,
1637                key,
1638                limit,
1639                default,
1640                return_rank,
1641            } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1642                query: Some(query.try_into()?),
1643                key: key.to_string(),
1644                limit,
1645                default,
1646                return_rank,
1647            }),
1648            RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1649                chroma_proto::RankExpr::try_from(*expr)?,
1650            )),
1651            RankExpr::Maximum(exprs) => {
1652                let proto_exprs = exprs
1653                    .into_iter()
1654                    .map(chroma_proto::RankExpr::try_from)
1655                    .collect::<Result<Vec<_>, _>>()?;
1656                chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1657                    exprs: proto_exprs,
1658                })
1659            }
1660            RankExpr::Minimum(exprs) => {
1661                let proto_exprs = exprs
1662                    .into_iter()
1663                    .map(chroma_proto::RankExpr::try_from)
1664                    .collect::<Result<Vec<_>, _>>()?;
1665                chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1666                    exprs: proto_exprs,
1667                })
1668            }
1669            RankExpr::Multiplication(exprs) => {
1670                let proto_exprs = exprs
1671                    .into_iter()
1672                    .map(chroma_proto::RankExpr::try_from)
1673                    .collect::<Result<Vec<_>, _>>()?;
1674                chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1675                    exprs: proto_exprs,
1676                })
1677            }
1678            RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1679                Box::new(chroma_proto::rank_expr::RankPair {
1680                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1681                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1682                }),
1683            ),
1684            RankExpr::Summation(exprs) => {
1685                let proto_exprs = exprs
1686                    .into_iter()
1687                    .map(chroma_proto::RankExpr::try_from)
1688                    .collect::<Result<Vec<_>, _>>()?;
1689                chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1690                    exprs: proto_exprs,
1691                })
1692            }
1693            RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1694        };
1695
1696        Ok(chroma_proto::RankExpr {
1697            rank: Some(proto_rank),
1698        })
1699    }
1700}
1701
1702/// Represents a field key in search queries.
1703///
1704/// Used for both selecting fields to return and building filter expressions.
1705/// Predefined keys access special fields, while custom keys access metadata.
1706///
1707/// # Predefined Keys
1708///
1709/// - `Key::Document` - Document text content (`#document`)
1710/// - `Key::Embedding` - Vector embeddings (`#embedding`)
1711/// - `Key::Metadata` - All metadata fields (`#metadata`)
1712/// - `Key::Score` - Search scores (`#score`)
1713///
1714/// # Custom Keys
1715///
1716/// Use `Key::field()` or `Key::from()` to reference metadata fields:
1717///
1718/// ```
1719/// use chroma_types::operator::Key;
1720///
1721/// let key = Key::field("author");
1722/// let key = Key::from("title");
1723/// ```
1724///
1725/// # Examples
1726///
1727/// ## Building filters
1728///
1729/// ```
1730/// use chroma_types::operator::Key;
1731///
1732/// // Equality
1733/// let filter = Key::field("status").eq("published");
1734///
1735/// // Comparisons
1736/// let filter = Key::field("year").gte(2020);
1737/// let filter = Key::field("score").lt(0.9);
1738///
1739/// // Set operations
1740/// let filter = Key::field("category").is_in(vec!["tech", "science"]);
1741/// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
1742///
1743/// // Document content
1744/// let filter = Key::Document.contains("machine learning");
1745/// let filter = Key::Document.regex(r"\bAPI\b");
1746///
1747/// // Combining filters
1748/// let filter = Key::field("status").eq("published")
1749///     & Key::field("year").gte(2020);
1750/// ```
1751///
1752/// ## Selecting fields
1753///
1754/// ```
1755/// use chroma_types::plan::SearchPayload;
1756/// use chroma_types::operator::Key;
1757///
1758/// let search = SearchPayload::default()
1759///     .select([
1760///         Key::Document,
1761///         Key::Score,
1762///         Key::field("title"),
1763///         Key::field("author"),
1764///     ]);
1765/// ```
// NOTE: the derived `PartialOrd`/`Ord` follow declaration order, so
// reordering variants changes how keys compare.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Key {
    // Predefined keys
    // Document text content; string form "#document".
    Document,
    // Vector embeddings; string form "#embedding".
    Embedding,
    // All metadata fields; string form "#metadata".
    Metadata,
    // Search scores; string form "#score".
    Score,
    // Custom metadata field; the string form is the field name itself.
    MetadataField(String),
}
1776
1777impl Serialize for Key {
1778    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1779    where
1780        S: serde::Serializer,
1781    {
1782        match self {
1783            Key::Document => serializer.serialize_str("#document"),
1784            Key::Embedding => serializer.serialize_str("#embedding"),
1785            Key::Metadata => serializer.serialize_str("#metadata"),
1786            Key::Score => serializer.serialize_str("#score"),
1787            Key::MetadataField(field) => serializer.serialize_str(field),
1788        }
1789    }
1790}
1791
1792impl<'de> Deserialize<'de> for Key {
1793    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1794    where
1795        D: Deserializer<'de>,
1796    {
1797        let s = String::deserialize(deserializer)?;
1798        Ok(Key::from(s))
1799    }
1800}
1801
1802impl fmt::Display for Key {
1803    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1804        match self {
1805            Key::Document => write!(f, "#document"),
1806            Key::Embedding => write!(f, "#embedding"),
1807            Key::Metadata => write!(f, "#metadata"),
1808            Key::Score => write!(f, "#score"),
1809            Key::MetadataField(field) => write!(f, "{}", field),
1810        }
1811    }
1812}
1813
1814impl From<&str> for Key {
1815    fn from(s: &str) -> Self {
1816        match s {
1817            "#document" => Key::Document,
1818            "#embedding" => Key::Embedding,
1819            "#metadata" => Key::Metadata,
1820            "#score" => Key::Score,
1821            // Any other string is treated as a metadata field key
1822            field => Key::MetadataField(field.to_string()),
1823        }
1824    }
1825}
1826
1827impl From<String> for Key {
1828    fn from(s: String) -> Self {
1829        Key::from(s.as_str())
1830    }
1831}
1832
1833impl Key {
1834    /// Creates a Key for a custom metadata field.
1835    ///
1836    /// # Examples
1837    ///
1838    /// ```
1839    /// use chroma_types::operator::Key;
1840    ///
1841    /// let status = Key::field("status");
1842    /// let year = Key::field("year");
1843    /// let author = Key::field("author");
1844    /// ```
1845    pub fn field(name: impl Into<String>) -> Self {
1846        Key::MetadataField(name.into())
1847    }
1848
1849    /// Creates an equality filter: `field == value`.
1850    ///
1851    /// # Examples
1852    ///
1853    /// ```
1854    /// use chroma_types::operator::Key;
1855    ///
1856    /// // String equality
1857    /// let filter = Key::field("status").eq("published");
1858    ///
1859    /// // Numeric equality
1860    /// let filter = Key::field("count").eq(42);
1861    ///
1862    /// // Boolean equality
1863    /// let filter = Key::field("featured").eq(true);
1864    /// ```
1865    pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1866        Where::Metadata(MetadataExpression {
1867            key: self.to_string(),
1868            comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1869        })
1870    }
1871
1872    /// Creates an inequality filter: `field != value`.
1873    ///
1874    /// # Examples
1875    ///
1876    /// ```
1877    /// use chroma_types::operator::Key;
1878    ///
1879    /// let filter = Key::field("status").ne("deleted");
1880    /// let filter = Key::field("count").ne(0);
1881    /// ```
1882    pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1883        Where::Metadata(MetadataExpression {
1884            key: self.to_string(),
1885            comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1886        })
1887    }
1888
1889    /// Creates a greater-than filter: `field > value` (numeric only).
1890    ///
1891    /// # Examples
1892    ///
1893    /// ```
1894    /// use chroma_types::operator::Key;
1895    ///
1896    /// let filter = Key::field("score").gt(0.5);
1897    /// let filter = Key::field("year").gt(2020);
1898    /// ```
1899    pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1900        Where::Metadata(MetadataExpression {
1901            key: self.to_string(),
1902            comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1903        })
1904    }
1905
1906    /// Creates a greater-than-or-equal filter: `field >= value` (numeric only).
1907    ///
1908    /// # Examples
1909    ///
1910    /// ```
1911    /// use chroma_types::operator::Key;
1912    ///
1913    /// let filter = Key::field("score").gte(0.5);
1914    /// let filter = Key::field("year").gte(2020);
1915    /// ```
1916    pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1917        Where::Metadata(MetadataExpression {
1918            key: self.to_string(),
1919            comparison: MetadataComparison::Primitive(
1920                PrimitiveOperator::GreaterThanOrEqual,
1921                value.into(),
1922            ),
1923        })
1924    }
1925
1926    /// Creates a less-than filter: `field < value` (numeric only).
1927    ///
1928    /// # Examples
1929    ///
1930    /// ```
1931    /// use chroma_types::operator::Key;
1932    ///
1933    /// let filter = Key::field("score").lt(0.9);
1934    /// let filter = Key::field("year").lt(2025);
1935    /// ```
1936    pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1937        Where::Metadata(MetadataExpression {
1938            key: self.to_string(),
1939            comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1940        })
1941    }
1942
1943    /// Creates a less-than-or-equal filter: `field <= value` (numeric only).
1944    ///
1945    /// # Examples
1946    ///
1947    /// ```
1948    /// use chroma_types::operator::Key;
1949    ///
1950    /// let filter = Key::field("score").lte(0.9);
1951    /// let filter = Key::field("year").lte(2024);
1952    /// ```
1953    pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1954        Where::Metadata(MetadataExpression {
1955            key: self.to_string(),
1956            comparison: MetadataComparison::Primitive(
1957                PrimitiveOperator::LessThanOrEqual,
1958                value.into(),
1959            ),
1960        })
1961    }
1962
1963    /// Creates a set membership filter: `field IN values`.
1964    ///
1965    /// Accepts any iterator (Vec, array, slice, etc.).
1966    ///
1967    /// # Examples
1968    ///
1969    /// ```
1970    /// use chroma_types::operator::Key;
1971    ///
1972    /// // With Vec
1973    /// let filter = Key::field("year").is_in(vec![2023, 2024, 2025]);
1974    ///
1975    /// // With array
1976    /// let filter = Key::field("category").is_in(["tech", "science", "math"]);
1977    ///
1978    /// // With owned strings
1979    /// let categories = vec!["tech".to_string(), "science".to_string()];
1980    /// let filter = Key::field("category").is_in(categories);
1981    /// ```
1982    pub fn is_in<I, T>(self, values: I) -> Where
1983    where
1984        I: IntoIterator<Item = T>,
1985        Vec<T>: Into<MetadataSetValue>,
1986    {
1987        let vec: Vec<T> = values.into_iter().collect();
1988        Where::Metadata(MetadataExpression {
1989            key: self.to_string(),
1990            comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1991        })
1992    }
1993
1994    /// Creates a set exclusion filter: `field NOT IN values`.
1995    ///
1996    /// Accepts any iterator (Vec, array, slice, etc.).
1997    ///
1998    /// # Examples
1999    ///
2000    /// ```
2001    /// use chroma_types::operator::Key;
2002    ///
2003    /// // Exclude deleted and archived
2004    /// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
2005    ///
2006    /// // Exclude specific years
2007    /// let filter = Key::field("year").not_in(vec![2019, 2020]);
2008    /// ```
2009    pub fn not_in<I, T>(self, values: I) -> Where
2010    where
2011        I: IntoIterator<Item = T>,
2012        Vec<T>: Into<MetadataSetValue>,
2013    {
2014        let vec: Vec<T> = values.into_iter().collect();
2015        Where::Metadata(MetadataExpression {
2016            key: self.to_string(),
2017            comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2018        })
2019    }
2020
2021    /// Creates a substring filter (case-sensitive, document content only).
2022    ///
2023    /// Note: Currently only works with `Key::Document`. Pattern must have at least
2024    /// 3 literal characters for accurate results.
2025    ///
2026    /// # Examples
2027    ///
2028    /// ```
2029    /// use chroma_types::operator::Key;
2030    ///
2031    /// let filter = Key::Document.contains("machine learning");
2032    /// let filter = Key::Document.contains("API");
2033    /// ```
2034    pub fn contains<S: Into<String>>(self, text: S) -> Where {
2035        Where::Document(DocumentExpression {
2036            operator: DocumentOperator::Contains,
2037            pattern: text.into(),
2038        })
2039    }
2040
2041    /// Creates a negative substring filter (case-sensitive, document content only).
2042    ///
2043    /// Note: Currently only works with `Key::Document`.
2044    ///
2045    /// # Examples
2046    ///
2047    /// ```
2048    /// use chroma_types::operator::Key;
2049    ///
2050    /// let filter = Key::Document.not_contains("deprecated");
2051    /// let filter = Key::Document.not_contains("beta");
2052    /// ```
2053    pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2054        Where::Document(DocumentExpression {
2055            operator: DocumentOperator::NotContains,
2056            pattern: text.into(),
2057        })
2058    }
2059
2060    /// Creates a regex filter (case-sensitive, document content only).
2061    ///
2062    /// Note: Currently only works with `Key::Document`. Pattern must have at least
2063    /// 3 literal characters for accurate results.
2064    ///
2065    /// # Examples
2066    ///
2067    /// ```
2068    /// use chroma_types::operator::Key;
2069    ///
2070    /// // Match whole word "API"
2071    /// let filter = Key::Document.regex(r"\bAPI\b");
2072    ///
2073    /// // Match version pattern
2074    /// let filter = Key::Document.regex(r"v\d+\.\d+\.\d+");
2075    /// ```
2076    pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2077        Where::Document(DocumentExpression {
2078            operator: DocumentOperator::Regex,
2079            pattern: pattern.into(),
2080        })
2081    }
2082
2083    /// Creates a negative regex filter (case-sensitive, document content only).
2084    ///
2085    /// Note: Currently only works with `Key::Document`.
2086    ///
2087    /// # Examples
2088    ///
2089    /// ```
2090    /// use chroma_types::operator::Key;
2091    ///
2092    /// // Exclude beta versions
2093    /// let filter = Key::Document.not_regex(r"beta");
2094    ///
2095    /// // Exclude test documents
2096    /// let filter = Key::Document.not_regex(r"\btest\b");
2097    /// ```
2098    pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2099        Where::Document(DocumentExpression {
2100            operator: DocumentOperator::NotRegex,
2101            pattern: pattern.into(),
2102        })
2103    }
2104}
2105
2106/// Field selection for search results.
2107///
2108/// Specifies which fields to include in the results. IDs are always included.
2109///
2110/// # Fields
2111///
2112/// * `keys` - Set of keys to include in results
2113///
2114/// # Available Keys
2115///
2116/// * `Key::Document` - Document text content
2117/// * `Key::Embedding` - Vector embeddings
2118/// * `Key::Metadata` - All metadata fields
2119/// * `Key::Score` - Search scores
2120/// * `Key::field("name")` - Specific metadata field
2121///
2122/// # Performance
2123///
2124/// Selecting fewer fields improves performance by reducing data transfer:
2125/// - Minimal: IDs only (default, fastest)
2126/// - Moderate: Scores + specific metadata fields
2127/// - Heavy: Documents + embeddings (larger payloads)
2128///
2129/// # Examples
2130///
2131/// ```
2132/// use chroma_types::operator::{Select, Key};
2133/// use std::collections::HashSet;
2134///
2135/// // Select predefined fields
2136/// let select = Select {
2137///     keys: [Key::Document, Key::Score].into_iter().collect(),
2138/// };
2139///
2140/// // Select specific metadata fields
2141/// let select = Select {
2142///     keys: [
2143///         Key::field("title"),
2144///         Key::field("author"),
2145///         Key::Score,
2146///     ].into_iter().collect(),
2147/// };
2148///
2149/// // Select everything
2150/// let select = Select {
2151///     keys: [
2152///         Key::Document,
2153///         Key::Embedding,
2154///         Key::Metadata,
2155///         Key::Score,
2156///     ].into_iter().collect(),
2157/// };
2158/// ```
2159#[derive(Clone, Debug, Default, Deserialize, Serialize)]
2160pub struct Select {
2161    #[serde(default)]
2162    pub keys: HashSet<Key>,
2163}
2164
2165impl TryFrom<chroma_proto::SelectOperator> for Select {
2166    type Error = QueryConversionError;
2167
2168    fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2169        let keys = value
2170            .keys
2171            .into_iter()
2172            .map(|key| {
2173                // Try to deserialize each string as a Key
2174                serde_json::from_value(serde_json::Value::String(key))
2175                    .map_err(|_| QueryConversionError::field("keys"))
2176            })
2177            .collect::<Result<HashSet<_>, _>>()?;
2178
2179        Ok(Self { keys })
2180    }
2181}
2182
2183impl TryFrom<Select> for chroma_proto::SelectOperator {
2184    type Error = QueryConversionError;
2185
2186    fn try_from(value: Select) -> Result<Self, Self::Error> {
2187        let keys = value
2188            .keys
2189            .into_iter()
2190            .map(|key| {
2191                // Serialize each Key back to string
2192                serde_json::to_value(&key)
2193                    .ok()
2194                    .and_then(|v| v.as_str().map(String::from))
2195                    .ok_or(QueryConversionError::field("keys"))
2196            })
2197            .collect::<Result<Vec<_>, _>>()?;
2198
2199        Ok(Self { keys })
2200    }
2201}
2202
2203/// A single search result record.
2204///
2205/// Contains the document ID and optionally document content, embeddings, metadata,
2206/// and search score based on what was selected in the search query.
2207///
2208/// # Fields
2209///
2210/// * `id` - Document ID (always present)
2211/// * `document` - Document text content (if selected)
2212/// * `embedding` - Vector embedding (if selected)
2213/// * `metadata` - Document metadata (if selected)
2214/// * `score` - Search score (present when ranking is used, lower = better match)
2215///
2216/// # Examples
2217///
2218/// ```
2219/// use chroma_types::operator::SearchRecord;
2220///
2221/// fn process_results(records: Vec<SearchRecord>) {
2222///     for record in records {
2223///         println!("ID: {}", record.id);
2224///         
2225///         if let Some(score) = record.score {
2226///             println!("  Score: {:.3}", score);
2227///         }
2228///         
2229///         if let Some(doc) = record.document {
2230///             println!("  Document: {}", doc);
2231///         }
2232///         
2233///         if let Some(meta) = record.metadata {
2234///             println!("  Metadata: {:?}", meta);
2235///         }
2236///     }
2237/// }
2238/// ```
2239#[derive(Clone, Debug, Deserialize, Serialize)]
2240#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
2241pub struct SearchRecord {
2242    pub id: String,
2243    pub document: Option<String>,
2244    pub embedding: Option<Vec<f32>>,
2245    pub metadata: Option<Metadata>,
2246    pub score: Option<f32>,
2247}
2248
2249impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2250    type Error = QueryConversionError;
2251
2252    fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2253        Ok(Self {
2254            id: value.id,
2255            document: value.document,
2256            embedding: value
2257                .embedding
2258                .map(|vec| vec.try_into().map(|(v, _)| v))
2259                .transpose()?,
2260            metadata: value.metadata.map(TryInto::try_into).transpose()?,
2261            score: value.score,
2262        })
2263    }
2264}
2265
2266impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2267    type Error = QueryConversionError;
2268
2269    fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2270        Ok(Self {
2271            id: value.id,
2272            document: value.document,
2273            embedding: value
2274                .embedding
2275                .map(|embedding| {
2276                    let embedding_dimension = embedding.len();
2277                    chroma_proto::Vector::try_from((
2278                        embedding,
2279                        ScalarEncoding::FLOAT32,
2280                        embedding_dimension,
2281                    ))
2282                })
2283                .transpose()?,
2284            metadata: value.metadata.map(Into::into),
2285            score: value.score,
2286        })
2287    }
2288}
2289
2290/// Results for a single search payload.
2291///
2292/// Contains all matching records for one search query.
2293///
2294/// # Fields
2295///
2296/// * `records` - Vector of search records, ordered by score (ascending)
2297///
2298/// # Examples
2299///
2300/// ```
2301/// use chroma_types::operator::{SearchPayloadResult, SearchRecord};
2302///
2303/// fn process_search_result(result: SearchPayloadResult) {
2304///     println!("Found {} results", result.records.len());
2305///     
2306///     for (i, record) in result.records.iter().enumerate() {
2307///         println!("{}. {} (score: {:?})", i + 1, record.id, record.score);
2308///     }
2309/// }
2310/// ```
2311#[derive(Clone, Debug, Default)]
2312pub struct SearchPayloadResult {
2313    pub records: Vec<SearchRecord>,
2314}
2315
2316impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2317    type Error = QueryConversionError;
2318
2319    fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2320        Ok(Self {
2321            records: value
2322                .records
2323                .into_iter()
2324                .map(TryInto::try_into)
2325                .collect::<Result<_, _>>()?,
2326        })
2327    }
2328}
2329
2330impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2331    type Error = QueryConversionError;
2332
2333    fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2334        Ok(Self {
2335            records: value
2336                .records
2337                .into_iter()
2338                .map(TryInto::try_into)
2339                .collect::<Result<Vec<_>, _>>()?,
2340        })
2341    }
2342}
2343
2344/// Results from a batch search operation.
2345///
2346/// Contains results for each search payload in the batch, maintaining the same order
2347/// as the input searches.
2348///
2349/// # Fields
2350///
2351/// * `results` - Results for each search payload (indexed by search position)
2352/// * `pulled_log_bytes` - Total bytes pulled from log (for internal metrics)
2353///
2354/// # Examples
2355///
2356/// ## Single search
2357///
2358/// ```
2359/// use chroma_types::operator::SearchResult;
2360///
2361/// fn process_single_search(result: SearchResult) {
2362///     // Single search, so results[0] contains our records
2363///     let records = &result.results[0].records;
2364///     
2365///     for record in records {
2366///         println!("{}: score={:?}", record.id, record.score);
2367///     }
2368/// }
2369/// ```
2370///
2371/// ## Batch search
2372///
2373/// ```
2374/// use chroma_types::operator::SearchResult;
2375///
2376/// fn process_batch_search(result: SearchResult) {
2377///     // Multiple searches in batch
2378///     for (i, search_result) in result.results.iter().enumerate() {
2379///         println!("\nSearch {}:", i + 1);
2380///         for record in &search_result.records {
2381///             println!("  {}: score={:?}", record.id, record.score);
2382///         }
2383///     }
2384/// }
2385/// ```
2386#[derive(Clone, Debug)]
2387pub struct SearchResult {
2388    pub results: Vec<SearchPayloadResult>,
2389    pub pulled_log_bytes: u64,
2390}
2391
2392impl SearchResult {
2393    pub fn size_bytes(&self) -> u64 {
2394        self.results
2395            .iter()
2396            .flat_map(|result| {
2397                result.records.iter().map(|record| {
2398                    (record.id.len()
2399                        + record
2400                            .document
2401                            .as_ref()
2402                            .map(|doc| doc.len())
2403                            .unwrap_or_default()
2404                        + record
2405                            .embedding
2406                            .as_ref()
2407                            .map(|emb| size_of_val(&emb[..]))
2408                            .unwrap_or_default()
2409                        + record
2410                            .metadata
2411                            .as_ref()
2412                            .map(logical_size_of_metadata)
2413                            .unwrap_or_default()
2414                        + record.score.as_ref().map(size_of_val).unwrap_or_default())
2415                        as u64
2416                })
2417            })
2418            .sum()
2419    }
2420}
2421
2422impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2423    type Error = QueryConversionError;
2424
2425    fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2426        Ok(Self {
2427            results: value
2428                .results
2429                .into_iter()
2430                .map(TryInto::try_into)
2431                .collect::<Result<_, _>>()?,
2432            pulled_log_bytes: value.pulled_log_bytes,
2433        })
2434    }
2435}
2436
2437impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2438    type Error = QueryConversionError;
2439
2440    fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2441        Ok(Self {
2442            results: value
2443                .results
2444                .into_iter()
2445                .map(TryInto::try_into)
2446                .collect::<Result<Vec<_>, _>>()?,
2447            pulled_log_bytes: value.pulled_log_bytes,
2448        })
2449    }
2450}
2451
2452/// Reciprocal Rank Fusion (RRF) - combines multiple ranking strategies.
2453///
2454/// RRF is ideal for hybrid search where you want to merge results from different
2455/// ranking methods (e.g., dense and sparse embeddings) with different score scales.
2456/// It uses rank positions instead of raw scores, making it scale-agnostic.
2457///
2458/// # Formula
2459///
2460/// ```text
2461/// score = -Σ(weight_i / (k + rank_i))
2462/// ```
2463///
2464/// Where:
2465/// - `weight_i` = weight for ranking i (default: 1.0)
2466/// - `rank_i` = rank position from ranking i (0, 1, 2...)
2467/// - `k` = smoothing parameter (default: 60)
2468///
2469/// Score is negative because Chroma uses ascending order (lower = better).
2470///
2471/// # Arguments
2472///
2473/// * `ranks` - List of ranking expressions (must have `return_rank=true`)
2474/// * `k` - Smoothing parameter (None = 60). Higher values reduce emphasis on top ranks.
2475/// * `weights` - Weight for each ranking (None = all 1.0)
2476/// * `normalize` - If true, normalize weights to sum to 1.0
2477///
2478/// # Returns
2479///
2480/// A combined RankExpr or an error if:
2481/// - `ranks` is empty
2482/// - `weights` length doesn't match `ranks` length
2483/// - `weights` sum to zero when normalizing
2484///
2485/// # Examples
2486///
2487/// ## Basic RRF with default parameters
2488///
2489/// ```
2490/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2491///
2492/// let dense = RankExpr::Knn {
2493///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2494///     key: Key::Embedding,
2495///     limit: 200,
2496///     default: None,
2497///     return_rank: true, // Required for RRF
2498/// };
2499///
2500/// let sparse = RankExpr::Knn {
2501///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2502///     key: Key::field("sparse_embedding"),
2503///     limit: 200,
2504///     default: None,
2505///     return_rank: true, // Required for RRF
2506/// };
2507///
2508/// // Equal weights, k=60 (defaults)
2509/// let combined = rrf(vec![dense, sparse], None, None, false).unwrap();
2510/// ```
2511///
2512/// ## RRF with custom weights
2513///
2514/// ```
2515/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2516///
2517/// # let dense = RankExpr::Knn {
2518/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2519/// #     key: Key::Embedding,
2520/// #     limit: 200,
2521/// #     default: None,
2522/// #     return_rank: true,
2523/// # };
2524/// # let sparse = RankExpr::Knn {
2525/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2526/// #     key: Key::field("sparse_embedding"),
2527/// #     limit: 200,
2528/// #     default: None,
2529/// #     return_rank: true,
2530/// # };
2531/// // 70% dense, 30% sparse
2532/// let combined = rrf(
2533///     vec![dense, sparse],
2534///     Some(60),
2535///     Some(vec![0.7, 0.3]),
2536///     false,
2537/// ).unwrap();
2538/// ```
2539///
2540/// ## RRF with normalized weights
2541///
2542/// ```
2543/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2544///
2545/// # let dense = RankExpr::Knn {
2546/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2547/// #     key: Key::Embedding,
2548/// #     limit: 200,
2549/// #     default: None,
2550/// #     return_rank: true,
2551/// # };
2552/// # let sparse = RankExpr::Knn {
2553/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2554/// #     key: Key::field("sparse_embedding"),
2555/// #     limit: 200,
2556/// #     default: None,
2557/// #     return_rank: true,
2558/// # };
2559/// // Weights [75, 25] normalized to [0.75, 0.25]
2560/// let combined = rrf(
2561///     vec![dense, sparse],
2562///     Some(60),
2563///     Some(vec![75.0, 25.0]),
2564///     true, // normalize
2565/// ).unwrap();
2566/// ```
2567///
2568/// ## Adjusting the k parameter
2569///
2570/// ```
2571/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2572///
2573/// # let dense = RankExpr::Knn {
2574/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2575/// #     key: Key::Embedding,
2576/// #     limit: 200,
2577/// #     default: None,
2578/// #     return_rank: true,
2579/// # };
2580/// # let sparse = RankExpr::Knn {
2581/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2582/// #     key: Key::field("sparse_embedding"),
2583/// #     limit: 200,
2584/// #     default: None,
2585/// #     return_rank: true,
2586/// # };
2587/// // Small k (10) = heavy emphasis on top ranks
2588/// let top_heavy = rrf(vec![dense.clone(), sparse.clone()], Some(10), None, false).unwrap();
2589///
2590/// // Default k (60) = balanced
2591/// let balanced = rrf(vec![dense.clone(), sparse.clone()], Some(60), None, false).unwrap();
2592///
2593/// // Large k (200) = more uniform weighting
2594/// let uniform = rrf(vec![dense, sparse], Some(200), None, false).unwrap();
2595/// ```
2596pub fn rrf(
2597    ranks: Vec<RankExpr>,
2598    k: Option<u32>,
2599    weights: Option<Vec<f32>>,
2600    normalize: bool,
2601) -> Result<RankExpr, QueryConversionError> {
2602    let k = k.unwrap_or(60);
2603
2604    if ranks.is_empty() {
2605        return Err(QueryConversionError::validation(
2606            "RRF requires at least one rank expression",
2607        ));
2608    }
2609
2610    let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2611
2612    if weights.len() != ranks.len() {
2613        return Err(QueryConversionError::validation(format!(
2614            "RRF weights length ({}) must match ranks length ({})",
2615            weights.len(),
2616            ranks.len()
2617        )));
2618    }
2619
2620    let weights = if normalize {
2621        let sum: f32 = weights.iter().sum();
2622        if sum == 0.0 {
2623            return Err(QueryConversionError::validation(
2624                "RRF weights sum to zero, cannot normalize",
2625            ));
2626        }
2627        weights.into_iter().map(|w| w / sum).collect()
2628    } else {
2629        weights
2630    };
2631
2632    let terms: Vec<RankExpr> = weights
2633        .into_iter()
2634        .zip(ranks)
2635        .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2636        .collect();
2637
2638    // Safe: ranks is validated as non-empty above, so terms cannot be empty.
2639    // Using unwrap_or_else as defensive programming to avoid panic.
2640    let sum = terms
2641        .into_iter()
2642        .reduce(|a, b| a + b)
2643        .unwrap_or(RankExpr::Value(0.0));
2644    Ok(-sum)
2645}
2646
2647#[cfg(test)]
2648mod tests {
2649    use super::*;
2650
2651    #[test]
2652    fn test_key_from_string() {
2653        // Test predefined keys
2654        assert_eq!(Key::from("#document"), Key::Document);
2655        assert_eq!(Key::from("#embedding"), Key::Embedding);
2656        assert_eq!(Key::from("#metadata"), Key::Metadata);
2657        assert_eq!(Key::from("#score"), Key::Score);
2658
2659        // Test metadata field keys
2660        assert_eq!(
2661            Key::from("custom_field"),
2662            Key::MetadataField("custom_field".to_string())
2663        );
2664        assert_eq!(
2665            Key::from("author"),
2666            Key::MetadataField("author".to_string())
2667        );
2668
2669        // Test String variant
2670        assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
2671        assert_eq!(
2672            Key::from("year".to_string()),
2673            Key::MetadataField("year".to_string())
2674        );
2675    }
2676
2677    #[test]
2678    fn test_query_vector_dense_proto_conversion() {
2679        let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2680        let query_vector = QueryVector::Dense(dense_vec.clone());
2681
2682        // Convert to proto
2683        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2684
2685        // Convert back
2686        let converted: QueryVector = proto.try_into().unwrap();
2687
2688        assert_eq!(converted, query_vector);
2689        if let QueryVector::Dense(v) = converted {
2690            assert_eq!(v, dense_vec);
2691        } else {
2692            panic!("Expected dense vector");
2693        }
2694    }
2695
2696    #[test]
2697    fn test_query_vector_sparse_proto_conversion() {
2698        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap();
2699        let query_vector = QueryVector::Sparse(sparse.clone());
2700
2701        // Convert to proto
2702        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2703
2704        // Convert back
2705        let converted: QueryVector = proto.try_into().unwrap();
2706
2707        assert_eq!(converted, query_vector);
2708        if let QueryVector::Sparse(s) = converted {
2709            assert_eq!(s, sparse);
2710        } else {
2711            panic!("Expected sparse vector");
2712        }
2713    }
2714
2715    #[test]
2716    fn test_filter_json_deserialization() {
2717        // For the new search API, deserialization treats the entire JSON as a where clause
2718
2719        // Test 1: Simple direct metadata comparison
2720        let simple_where = r#"{"author": "John Doe"}"#;
2721        let filter: Filter = serde_json::from_str(simple_where).unwrap();
2722        assert_eq!(filter.query_ids, None);
2723        assert!(filter.where_clause.is_some());
2724
2725        // Test 2: ID filter using #id with $in operator
2726        let id_filter_json = serde_json::json!({
2727            "#id": {
2728                "$in": ["doc1", "doc2", "doc3"]
2729            }
2730        });
2731        let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
2732        assert_eq!(filter.query_ids, None);
2733        assert!(filter.where_clause.is_some());
2734
2735        // Test 3: Complex nested expression with AND, OR, and various operators
2736        let complex_json = serde_json::json!({
2737            "$and": [
2738                {
2739                    "#id": {
2740                        "$in": ["doc1", "doc2", "doc3"]
2741                    }
2742                },
2743                {
2744                    "$or": [
2745                        {
2746                            "author": {
2747                                "$eq": "John Doe"
2748                            }
2749                        },
2750                        {
2751                            "author": {
2752                                "$eq": "Jane Smith"
2753                            }
2754                        }
2755                    ]
2756                },
2757                {
2758                    "year": {
2759                        "$gte": 2020
2760                    }
2761                },
2762                {
2763                    "tags": {
2764                        "$contains": "machine-learning"
2765                    }
2766                }
2767            ]
2768        });
2769
2770        let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
2771        assert_eq!(filter.query_ids, None);
2772        assert!(filter.where_clause.is_some());
2773
2774        // Verify the structure
2775        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2776            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2777            assert_eq!(composite.children.len(), 4);
2778
2779            // Check that the second child is an OR
2780            if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
2781                assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
2782                assert_eq!(or_composite.children.len(), 2);
2783            } else {
2784                panic!("Expected OR composite in second child");
2785            }
2786        } else {
2787            panic!("Expected AND composite where clause");
2788        }
2789
2790        // Test 4: Mixed operators - $ne, $lt, $gt, $lte
2791        let mixed_operators_json = serde_json::json!({
2792            "$and": [
2793                {
2794                    "status": {
2795                        "$ne": "deleted"
2796                    }
2797                },
2798                {
2799                    "score": {
2800                        "$gt": 0.5
2801                    }
2802                },
2803                {
2804                    "score": {
2805                        "$lt": 0.9
2806                    }
2807                },
2808                {
2809                    "priority": {
2810                        "$lte": 10
2811                    }
2812                }
2813            ]
2814        });
2815
2816        let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
2817        assert_eq!(filter.query_ids, None);
2818        assert!(filter.where_clause.is_some());
2819
2820        // Test 5: Deeply nested expression
2821        let deeply_nested_json = serde_json::json!({
2822            "$or": [
2823                {
2824                    "$and": [
2825                        {
2826                            "#id": {
2827                                "$in": ["id1", "id2"]
2828                            }
2829                        },
2830                        {
2831                            "$or": [
2832                                {
2833                                    "category": "tech"
2834                                },
2835                                {
2836                                    "category": "science"
2837                                }
2838                            ]
2839                        }
2840                    ]
2841                },
2842                {
2843                    "$and": [
2844                        {
2845                            "author": "Admin"
2846                        },
2847                        {
2848                            "published": true
2849                        }
2850                    ]
2851                }
2852            ]
2853        });
2854
2855        let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
2856        assert_eq!(filter.query_ids, None);
2857        assert!(filter.where_clause.is_some());
2858
2859        // Verify it's an OR at the top level
2860        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2861            assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
2862            assert_eq!(composite.children.len(), 2);
2863
2864            // Both children should be AND composites
2865            for child in &composite.children {
2866                if let crate::metadata::Where::Composite(and_composite) = child {
2867                    assert_eq!(
2868                        and_composite.operator,
2869                        crate::metadata::BooleanOperator::And
2870                    );
2871                } else {
2872                    panic!("Expected AND composite in OR children");
2873                }
2874            }
2875        } else {
2876            panic!("Expected OR composite at top level");
2877        }
2878
2879        // Test 6: Single ID filter (edge case)
2880        let single_id_json = serde_json::json!({
2881            "#id": {
2882                "$eq": "single-doc-id"
2883            }
2884        });
2885
2886        let filter: Filter = serde_json::from_value(single_id_json).unwrap();
2887        assert_eq!(filter.query_ids, None);
2888        assert!(filter.where_clause.is_some());
2889
2890        // Test 7: Empty object should create empty filter
2891        let empty_json = serde_json::json!({});
2892        let filter: Filter = serde_json::from_value(empty_json).unwrap();
2893        assert_eq!(filter.query_ids, None);
2894        // Empty object results in None where_clause
2895        assert_eq!(filter.where_clause, None);
2896
2897        // Test 8: Combining #id filter with $not_contains and numeric comparisons
2898        let advanced_json = serde_json::json!({
2899            "$and": [
2900                {
2901                    "#id": {
2902                        "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
2903                    }
2904                },
2905                {
2906                    "tags": {
2907                        "$not_contains": "deprecated"
2908                    }
2909                },
2910                {
2911                    "$or": [
2912                        {
2913                            "$and": [
2914                                {
2915                                    "confidence": {
2916                                        "$gte": 0.8
2917                                    }
2918                                },
2919                                {
2920                                    "verified": true
2921                                }
2922                            ]
2923                        },
2924                        {
2925                            "$and": [
2926                                {
2927                                    "confidence": {
2928                                        "$gte": 0.6
2929                                    }
2930                                },
2931                                {
2932                                    "confidence": {
2933                                        "$lt": 0.8
2934                                    }
2935                                },
2936                                {
2937                                    "reviews": {
2938                                        "$gte": 5
2939                                    }
2940                                }
2941                            ]
2942                        }
2943                    ]
2944                }
2945            ]
2946        });
2947
2948        let filter: Filter = serde_json::from_value(advanced_json).unwrap();
2949        assert_eq!(filter.query_ids, None);
2950        assert!(filter.where_clause.is_some());
2951
2952        // Verify top-level structure
2953        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2954            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2955            assert_eq!(composite.children.len(), 3);
2956        } else {
2957            panic!("Expected AND composite at top level");
2958        }
2959    }
2960
2961    #[test]
2962    fn test_limit_json_serialization() {
2963        let limit = Limit {
2964            offset: 10,
2965            limit: Some(20),
2966        };
2967
2968        let json = serde_json::to_string(&limit).unwrap();
2969        let deserialized: Limit = serde_json::from_str(&json).unwrap();
2970
2971        assert_eq!(deserialized.offset, limit.offset);
2972        assert_eq!(deserialized.limit, limit.limit);
2973    }
2974
2975    #[test]
2976    fn test_query_vector_json_serialization() {
2977        // Test dense vector
2978        let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
2979        let json = serde_json::to_string(&dense).unwrap();
2980        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
2981        assert_eq!(deserialized, dense);
2982
2983        // Test sparse vector
2984        let sparse =
2985            QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]).unwrap());
2986        let json = serde_json::to_string(&sparse).unwrap();
2987        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
2988        assert_eq!(deserialized, sparse);
2989    }
2990
2991    #[test]
2992    fn test_select_key_json_serialization() {
2993        use std::collections::HashSet;
2994
2995        // Test predefined keys
2996        let doc_key = Key::Document;
2997        assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
2998
2999        let embed_key = Key::Embedding;
3000        assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
3001
3002        let meta_key = Key::Metadata;
3003        assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
3004
3005        let score_key = Key::Score;
3006        assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
3007
3008        // Test metadata key
3009        let custom_key = Key::MetadataField("custom_key".to_string());
3010        assert_eq!(
3011            serde_json::to_string(&custom_key).unwrap(),
3012            "\"custom_key\""
3013        );
3014
3015        // Test deserialization
3016        let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
3017        assert!(matches!(deserialized, Key::Document));
3018
3019        let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
3020        assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
3021
3022        // Test Select struct with multiple keys
3023        let mut keys = HashSet::new();
3024        keys.insert(Key::Document);
3025        keys.insert(Key::Embedding);
3026        keys.insert(Key::MetadataField("author".to_string()));
3027
3028        let select = Select { keys };
3029        let json = serde_json::to_string(&select).unwrap();
3030        let deserialized: Select = serde_json::from_str(&json).unwrap();
3031
3032        assert_eq!(deserialized.keys.len(), 3);
3033        assert!(deserialized.keys.contains(&Key::Document));
3034        assert!(deserialized.keys.contains(&Key::Embedding));
3035        assert!(deserialized
3036            .keys
3037            .contains(&Key::MetadataField("author".to_string())));
3038    }
3039
3040    #[test]
3041    fn test_merge_basic_integers() {
3042        use std::cmp::Reverse;
3043
3044        let merge = Merge { k: 5 };
3045
3046        // Input: sorted vectors of Reverse(u32) - ascending order of inner values
3047        let input = vec![
3048            vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
3049            vec![Reverse(2), Reverse(5), Reverse(8)],
3050            vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
3051        ];
3052
3053        let result = merge.merge(input);
3054
3055        // Should get top-5 smallest values (largest Reverse values)
3056        assert_eq!(result.len(), 5);
3057        assert_eq!(
3058            result,
3059            vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
3060        );
3061    }
3062
3063    #[test]
3064    fn test_merge_u32_descending() {
3065        let merge = Merge { k: 6 };
3066
3067        // Regular u32 in descending order (largest first)
3068        let input = vec![
3069            vec![100u32, 75, 50, 25],
3070            vec![90, 60, 30],
3071            vec![95, 85, 70, 40, 10],
3072        ];
3073
3074        let result = merge.merge(input);
3075
3076        // Should get top-6 largest u32 values
3077        assert_eq!(result.len(), 6);
3078        assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
3079    }
3080
3081    #[test]
3082    fn test_merge_i32_descending() {
3083        let merge = Merge { k: 5 };
3084
3085        // i32 values in descending order (including negatives)
3086        let input = vec![
3087            vec![50i32, 10, -10, -50],
3088            vec![30, 0, -30],
3089            vec![40, 20, -20, -40],
3090        ];
3091
3092        let result = merge.merge(input);
3093
3094        // Should get top-5 largest i32 values
3095        assert_eq!(result.len(), 5);
3096        assert_eq!(result, vec![50, 40, 30, 20, 10]);
3097    }
3098
3099    #[test]
3100    fn test_merge_with_duplicates() {
3101        let merge = Merge { k: 10 };
3102
3103        // Input with duplicates using regular u32 in descending order
3104        let input = vec![
3105            vec![100u32, 80, 80, 60, 40],
3106            vec![90, 80, 50, 30],
3107            vec![100, 70, 60, 20],
3108        ];
3109
3110        let result = merge.merge(input);
3111
3112        // Duplicates should be removed
3113        assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
3114    }
3115
3116    #[test]
3117    fn test_merge_empty_vectors() {
3118        let merge = Merge { k: 5 };
3119
3120        // All empty with u32
3121        let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
3122        let result = merge.merge(input);
3123        assert_eq!(result.len(), 0);
3124
3125        // Some empty, some with data (u64)
3126        let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
3127        let result = merge.merge(input);
3128        assert_eq!(result, vec![1000, 850, 750, 600, 500]);
3129
3130        // Single non-empty vector (i32)
3131        let input = vec![vec![], vec![100i32, 50, 25], vec![]];
3132        let result = merge.merge(input);
3133        assert_eq!(result, vec![100, 50, 25]);
3134    }
3135
3136    #[test]
3137    fn test_merge_k_boundary_conditions() {
3138        // k = 0 with u32
3139        let merge = Merge { k: 0 };
3140        let input = vec![vec![100u32, 50], vec![75, 25]];
3141        let result = merge.merge(input);
3142        assert_eq!(result.len(), 0);
3143
3144        // k = 1 with i64
3145        let merge = Merge { k: 1 };
3146        let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
3147        let result = merge.merge(input);
3148        assert_eq!(result, vec![1000]);
3149
3150        // k larger than total unique elements with u128
3151        let merge = Merge { k: 100 };
3152        let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
3153        let result = merge.merge(input);
3154        assert_eq!(result, vec![10000, 8000, 5000, 3000]);
3155    }
3156
3157    #[test]
3158    fn test_merge_with_strings() {
3159        let merge = Merge { k: 4 };
3160
3161        // Strings must be sorted in descending order (largest first) for the max heap merge
3162        let input = vec![
3163            vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
3164            vec!["elephant".to_string(), "banana".to_string()],
3165            vec!["fish".to_string(), "cat".to_string()],
3166        ];
3167
3168        let result = merge.merge(input);
3169
3170        // Should get top-4 lexicographically largest strings
3171        assert_eq!(
3172            result,
3173            vec![
3174                "zebra".to_string(),
3175                "fish".to_string(),
3176                "elephant".to_string(),
3177                "dog".to_string()
3178            ]
3179        );
3180    }
3181
3182    #[test]
3183    fn test_merge_with_custom_struct() {
3184        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
3185        struct Score {
3186            value: i32,
3187            id: String,
3188        }
3189
3190        let merge = Merge { k: 3 };
3191
3192        // Custom structs sorted by value (descending), then by id
3193        let input = vec![
3194            vec![
3195                Score {
3196                    value: 100,
3197                    id: "a".to_string(),
3198                },
3199                Score {
3200                    value: 80,
3201                    id: "b".to_string(),
3202                },
3203                Score {
3204                    value: 60,
3205                    id: "c".to_string(),
3206                },
3207            ],
3208            vec![
3209                Score {
3210                    value: 90,
3211                    id: "d".to_string(),
3212                },
3213                Score {
3214                    value: 70,
3215                    id: "e".to_string(),
3216                },
3217            ],
3218            vec![
3219                Score {
3220                    value: 95,
3221                    id: "f".to_string(),
3222                },
3223                Score {
3224                    value: 85,
3225                    id: "g".to_string(),
3226                },
3227            ],
3228        ];
3229
3230        let result = merge.merge(input);
3231
3232        assert_eq!(result.len(), 3);
3233        assert_eq!(
3234            result[0],
3235            Score {
3236                value: 100,
3237                id: "a".to_string()
3238            }
3239        );
3240        assert_eq!(
3241            result[1],
3242            Score {
3243                value: 95,
3244                id: "f".to_string()
3245            }
3246        );
3247        assert_eq!(
3248            result[2],
3249            Score {
3250                value: 90,
3251                id: "d".to_string()
3252            }
3253        );
3254    }
3255
3256    #[test]
3257    fn test_merge_preserves_order() {
3258        use std::cmp::Reverse;
3259
3260        let merge = Merge { k: 10 };
3261
3262        // For Reverse, smaller inner values are "larger" in ordering
3263        // So vectors should be sorted with smallest inner values first
3264        let input = vec![
3265            vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
3266            vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
3267            vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
3268        ];
3269
3270        let result = merge.merge(input);
3271
3272        // Verify output maintains order - should be sorted by Reverse ordering
3273        // which means ascending inner values
3274        for i in 1..result.len() {
3275            assert!(
3276                result[i - 1] >= result[i],
3277                "Output should be in descending Reverse order"
3278            );
3279            assert!(
3280                result[i - 1].0 <= result[i].0,
3281                "Inner values should be in ascending order"
3282            );
3283        }
3284
3285        // Check we got the right elements
3286        assert_eq!(
3287            result,
3288            vec![
3289                Reverse(1),
3290                Reverse(2),
3291                Reverse(3),
3292                Reverse(4),
3293                Reverse(5),
3294                Reverse(6),
3295                Reverse(7),
3296                Reverse(8),
3297                Reverse(9),
3298                Reverse(10)
3299            ]
3300        );
3301    }
3302
3303    #[test]
3304    fn test_merge_single_vector() {
3305        let merge = Merge { k: 3 };
3306
3307        // Single vector input with u64
3308        let input = vec![vec![1000u64, 800, 600, 400, 200]];
3309
3310        let result = merge.merge(input);
3311
3312        assert_eq!(result, vec![1000, 800, 600]);
3313    }
3314
3315    #[test]
3316    fn test_merge_all_same_values() {
3317        let merge = Merge { k: 5 };
3318
3319        // All vectors contain the same value (using i16)
3320        let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];
3321
3322        let result = merge.merge(input);
3323
3324        // Should deduplicate to single value
3325        assert_eq!(result, vec![42]);
3326    }
3327
3328    #[test]
3329    fn test_merge_mixed_types_sizes() {
3330        // Test with usize (common in real usage)
3331        let merge = Merge { k: 4 };
3332        let input = vec![
3333            vec![1000usize, 500, 100],
3334            vec![800, 300],
3335            vec![900, 600, 200],
3336        ];
3337        let result = merge.merge(input);
3338        assert_eq!(result, vec![1000, 900, 800, 600]);
3339
3340        // Test with negative integers (i32)
3341        let merge = Merge { k: 5 };
3342        let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
3343        let result = merge.merge(input);
3344        assert_eq!(result, vec![15, 10, 5, 0, -5]);
3345    }
3346}