chroma_types/execution/
operator.rs

1use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
2use serde_json::Value;
3use std::{
4    cmp::Ordering,
5    collections::{BinaryHeap, HashSet},
6    fmt,
7    hash::Hash,
8    ops::{Add, Div, Mul, Neg, Sub},
9};
10use thiserror::Error;
11
12use crate::{
13    chroma_proto, logical_size_of_metadata, parse_where, CollectionAndSegments, CollectionUuid,
14    DocumentExpression, DocumentOperator, Metadata, MetadataComparison, MetadataExpression,
15    MetadataSetValue, MetadataValue, PrimitiveOperator, ScalarEncoding, SetOperator, SparseVector,
16    Where,
17};
18
19use super::error::QueryConversionError;
20
/// Alias for the unit type `()`.
///
/// NOTE(review): presumably the input type of operators that start a pipeline
/// and therefore receive no upstream data — confirm against callers.
pub type InitialInput = ();
22
23/// The `Scan` opeartor pins the data used by all downstream operators
24///
25/// # Parameters
26/// - `collection_and_segments`: The consistent snapshot of collection
27#[derive(Clone, Debug)]
28pub struct Scan {
29    pub collection_and_segments: CollectionAndSegments,
30}
31
32impl TryFrom<chroma_proto::ScanOperator> for Scan {
33    type Error = QueryConversionError;
34
35    fn try_from(value: chroma_proto::ScanOperator) -> Result<Self, Self::Error> {
36        Ok(Self {
37            collection_and_segments: CollectionAndSegments {
38                collection: value
39                    .collection
40                    .ok_or(QueryConversionError::field("collection"))?
41                    .try_into()?,
42                metadata_segment: value
43                    .metadata
44                    .ok_or(QueryConversionError::field("metadata segment"))?
45                    .try_into()?,
46                record_segment: value
47                    .record
48                    .ok_or(QueryConversionError::field("record segment"))?
49                    .try_into()?,
50                vector_segment: value
51                    .knn
52                    .ok_or(QueryConversionError::field("vector segment"))?
53                    .try_into()?,
54            },
55        })
56    }
57}
58
59#[derive(Debug, Error)]
60pub enum ScanToProtoError {
61    #[error("Could not convert collection to proto")]
62    CollectionToProto(#[from] crate::CollectionToProtoError),
63}
64
65impl TryFrom<Scan> for chroma_proto::ScanOperator {
66    type Error = ScanToProtoError;
67
68    fn try_from(value: Scan) -> Result<Self, Self::Error> {
69        Ok(Self {
70            collection: Some(value.collection_and_segments.collection.try_into()?),
71            knn: Some(value.collection_and_segments.vector_segment.into()),
72            metadata: Some(value.collection_and_segments.metadata_segment.into()),
73            record: Some(value.collection_and_segments.record_segment.into()),
74        })
75    }
76}
77
/// Result of a count query.
///
/// # Fields
/// - `count`: The number reported by the count
/// - `pulled_log_bytes`: Byte counter carried along with the result
#[derive(Clone, Debug)]
pub struct CountResult {
    pub count: u32,
    pub pulled_log_bytes: u64,
}

impl CountResult {
    /// Returns the in-memory size of the `count` field in bytes.
    ///
    /// `pulled_log_bytes` is not part of the reported size.
    pub fn size_bytes(&self) -> u64 {
        let Self { count, .. } = self;
        std::mem::size_of_val(count) as u64
    }
}
89
90impl From<chroma_proto::CountResult> for CountResult {
91    fn from(value: chroma_proto::CountResult) -> Self {
92        Self {
93            count: value.count,
94            pulled_log_bytes: value.pulled_log_bytes,
95        }
96    }
97}
98
99impl From<CountResult> for chroma_proto::CountResult {
100    fn from(value: CountResult) -> Self {
101        Self {
102            count: value.count,
103            pulled_log_bytes: value.pulled_log_bytes,
104        }
105    }
106}
107
108/// The `FetchLog` operator fetches logs from the log service
109///
110/// # Parameters
111/// - `start_log_offset_id`: The offset id of the first log to read
112/// - `maximum_fetch_count`: The maximum number of logs to fetch in total
113/// - `collection_uuid`: The uuid of the collection where the fetched logs should belong
114#[derive(Clone, Debug)]
115pub struct FetchLog {
116    pub collection_uuid: CollectionUuid,
117    pub maximum_fetch_count: Option<u32>,
118    pub start_log_offset_id: u32,
119}
120
121/// Filter the search results.
122///
123/// Combines document ID filtering with metadata and document content predicates.
124/// For the Search API, use `where_clause` with Key expressions.
125///
126/// # Fields
127///
128/// * `query_ids` - Optional list of document IDs to filter (legacy, prefer Where expressions)
129/// * `where_clause` - Predicate on document metadata, content, or IDs
130///
131/// # Examples
132///
133/// ## Simple metadata filter
134///
135/// ```
136/// use chroma_types::operator::{Filter, Key};
137///
138/// let filter = Filter {
139///     query_ids: None,
140///     where_clause: Some(Key::field("status").eq("published")),
141/// };
142/// ```
143///
144/// ## Combined filters
145///
146/// ```
147/// use chroma_types::operator::{Filter, Key};
148///
149/// let filter = Filter {
150///     query_ids: None,
151///     where_clause: Some(
152///         Key::field("status").eq("published")
153///             & Key::field("year").gte(2020)
154///             & Key::field("category").is_in(vec!["tech", "science"])
155///     ),
156/// };
157/// ```
158///
159/// ## Document content filter
160///
161/// ```
162/// use chroma_types::operator::{Filter, Key};
163///
164/// let filter = Filter {
165///     query_ids: None,
166///     where_clause: Some(Key::Document.contains("machine learning")),
167/// };
168/// ```
169#[derive(Clone, Debug, Default)]
170pub struct Filter {
171    pub query_ids: Option<Vec<String>>,
172    pub where_clause: Option<Where>,
173}
174
175impl Serialize for Filter {
176    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
177    where
178        S: Serializer,
179    {
180        // For the search API, serialize directly as the where clause (or empty object if None)
181        // If query_ids are present, they should be combined with the where_clause as Key::ID.is_in([...])
182
183        match (&self.query_ids, &self.where_clause) {
184            (None, None) => {
185                // No filter at all - serialize empty object
186                let map = serializer.serialize_map(Some(0))?;
187                map.end()
188            }
189            (None, Some(where_clause)) => {
190                // Only where clause - serialize it directly
191                where_clause.serialize(serializer)
192            }
193            (Some(ids), None) => {
194                // Only query_ids - create Where clause: Key::ID.is_in(ids)
195                let id_where = Where::Metadata(MetadataExpression {
196                    key: "#id".to_string(),
197                    comparison: MetadataComparison::Set(
198                        SetOperator::In,
199                        MetadataSetValue::Str(ids.clone()),
200                    ),
201                });
202                id_where.serialize(serializer)
203            }
204            (Some(ids), Some(where_clause)) => {
205                // Both present - combine with AND: Key::ID.is_in(ids) & where_clause
206                let id_where = Where::Metadata(MetadataExpression {
207                    key: "#id".to_string(),
208                    comparison: MetadataComparison::Set(
209                        SetOperator::In,
210                        MetadataSetValue::Str(ids.clone()),
211                    ),
212                });
213                let combined = Where::conjunction(vec![id_where, where_clause.clone()]);
214                combined.serialize(serializer)
215            }
216        }
217    }
218}
219
220impl<'de> Deserialize<'de> for Filter {
221    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222    where
223        D: Deserializer<'de>,
224    {
225        // For the new search API, the entire JSON is the where clause
226        let where_json = Value::deserialize(deserializer)?;
227        let where_clause =
228            if where_json.is_null() || where_json.as_object().is_some_and(|obj| obj.is_empty()) {
229                None
230            } else {
231                Some(parse_where(&where_json).map_err(|e| D::Error::custom(e.to_string()))?)
232            };
233
234        Ok(Filter {
235            query_ids: None, // Always None for new search API
236            where_clause,
237        })
238    }
239}
240
241impl TryFrom<chroma_proto::FilterOperator> for Filter {
242    type Error = QueryConversionError;
243
244    fn try_from(value: chroma_proto::FilterOperator) -> Result<Self, Self::Error> {
245        let where_metadata = value.r#where.map(TryInto::try_into).transpose()?;
246        let where_document = value.where_document.map(TryInto::try_into).transpose()?;
247        let where_clause = match (where_metadata, where_document) {
248            (Some(w), Some(wd)) => Some(Where::conjunction(vec![w, wd])),
249            (Some(w), None) | (None, Some(w)) => Some(w),
250            _ => None,
251        };
252
253        Ok(Self {
254            query_ids: value.ids.map(|uids| uids.ids),
255            where_clause,
256        })
257    }
258}
259
260impl TryFrom<Filter> for chroma_proto::FilterOperator {
261    type Error = QueryConversionError;
262
263    fn try_from(value: Filter) -> Result<Self, Self::Error> {
264        Ok(Self {
265            ids: value.query_ids.map(|ids| chroma_proto::UserIds { ids }),
266            r#where: value.where_clause.map(TryInto::try_into).transpose()?,
267            where_document: None,
268        })
269    }
270}
271
/// The `Knn` operator searches for the nearest neighbours of the specified embedding.
/// It is intended to be used by the executor.
///
/// # Parameters
/// - `embedding`: The target embedding to search around
/// - `fetch`: The number of records to fetch around the target
#[derive(Clone, Debug)]
pub struct Knn {
    pub embedding: Vec<f32>,
    pub fetch: u32,
}

impl From<KnnBatch> for Vec<Knn> {
    /// Splits a batch into one [`Knn`] per embedding, each carrying the batch's `fetch`.
    fn from(value: KnnBatch) -> Self {
        let KnnBatch { embeddings, fetch } = value;
        let mut knns = Vec::with_capacity(embeddings.len());
        knns.extend(
            embeddings
                .into_iter()
                .map(|embedding| Knn { embedding, fetch }),
        );
        knns
    }
}

/// The `KnnBatch` operator searches for the nearest neighbours of each of the specified
/// embeddings. It is intended to be used by the frontend.
///
/// # Parameters
/// - `embeddings`: The target embeddings to search around
/// - `fetch`: The number of records to fetch around each target
#[derive(Clone, Debug)]
pub struct KnnBatch {
    pub embeddings: Vec<Vec<f32>>,
    pub fetch: u32,
}
306
307impl TryFrom<chroma_proto::KnnOperator> for KnnBatch {
308    type Error = QueryConversionError;
309
310    fn try_from(value: chroma_proto::KnnOperator) -> Result<Self, Self::Error> {
311        Ok(Self {
312            embeddings: value
313                .embeddings
314                .into_iter()
315                .map(|vec| vec.try_into().map(|(v, _)| v))
316                .collect::<Result<_, _>>()?,
317            fetch: value.fetch,
318        })
319    }
320}
321
322impl TryFrom<KnnBatch> for chroma_proto::KnnOperator {
323    type Error = QueryConversionError;
324
325    fn try_from(value: KnnBatch) -> Result<Self, Self::Error> {
326        Ok(Self {
327            embeddings: value
328                .embeddings
329                .into_iter()
330                .map(|embedding| {
331                    let dim = embedding.len();
332                    chroma_proto::Vector::try_from((embedding, ScalarEncoding::FLOAT32, dim))
333                })
334                .collect::<Result<_, _>>()?,
335            fetch: value.fetch,
336        })
337    }
338}
339
340/// Pagination control for search results.
341///
342/// Controls how many results to return and how many to skip for pagination.
343///
344/// # Fields
345///
346/// * `offset` - Number of results to skip (default: 0)
347/// * `limit` - Maximum results to return (None = no limit)
348///
349/// # Examples
350///
351/// ```
352/// use chroma_types::operator::Limit;
353///
354/// // First page: results 0-9
355/// let limit = Limit {
356///     offset: 0,
357///     limit: Some(10),
358/// };
359///
360/// // Second page: results 10-19
361/// let limit = Limit {
362///     offset: 10,
363///     limit: Some(10),
364/// };
365///
366/// // No limit: all results
367/// let limit = Limit {
368///     offset: 0,
369///     limit: None,
370/// };
371/// ```
372#[derive(Clone, Debug, Default, Deserialize, Serialize)]
373pub struct Limit {
374    #[serde(default)]
375    pub offset: u32,
376    #[serde(default)]
377    pub limit: Option<u32>,
378}
379
380impl From<chroma_proto::LimitOperator> for Limit {
381    fn from(value: chroma_proto::LimitOperator) -> Self {
382        Self {
383            offset: value.offset,
384            limit: value.limit,
385        }
386    }
387}
388
389impl From<Limit> for chroma_proto::LimitOperator {
390    fn from(value: Limit) -> Self {
391        Self {
392            offset: value.offset,
393            limit: value.limit,
394        }
395    }
396}
397
/// A `RecordMeasure` is the measure of an embedding (identified by `offset_id`)
/// with respect to the query embedding.
///
/// Equality looks only at `offset_id`, while ordering looks only at `measure`.
#[derive(Clone, Debug)]
pub struct RecordMeasure {
    pub offset_id: u32,
    pub measure: f32,
}

impl PartialEq for RecordMeasure {
    /// Two measures are equal when they refer to the same `offset_id`.
    fn eq(&self, other: &Self) -> bool {
        self.offset_id == other.offset_id
    }
}

impl Eq for RecordMeasure {}

impl Ord for RecordMeasure {
    /// Orders by `measure` using the total order on `f32`.
    fn cmp(&self, other: &Self) -> Ordering {
        f32::total_cmp(&self.measure, &other.measure)
    }
}

impl PartialOrd for RecordMeasure {
    /// Always `Some`, delegating to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
424
425#[derive(Debug, Default)]
426pub struct KnnOutput {
427    pub distances: Vec<RecordMeasure>,
428}
429
/// The `Merge` operator selects the top records from the batch vectors of records
/// which are all sorted in descending order. If the same record occurs multiple times
/// only one copy will remain in the final result.
///
/// # Parameters
/// - `k`: The total number of records to take after merge
///
/// # Usage
/// It can be used to merge the query results from different operators
#[derive(Clone, Debug)]
pub struct Merge {
    pub k: u32,
}

impl Merge {
    /// Merges the descending-sorted `input` batches into at most `k` records,
    /// largest first.
    ///
    /// A record equal to the one emitted immediately before it is dropped, so
    /// duplicates are removed as long as they surface next to each other.
    ///
    /// The output buffer is sized by the smaller of `k` and the number of
    /// available records, so a very large `k` does not reserve memory that can
    /// never be filled.
    pub fn merge<M: Eq + Ord>(&self, input: Vec<Vec<M>>) -> Vec<M> {
        let limit = self.k as usize;
        let available: usize = input.iter().map(Vec::len).sum();

        let mut batch_iters = input.into_iter().map(Vec::into_iter).collect::<Vec<_>>();

        // Seed the heap with the head of every non-empty batch, remembering
        // which batch each record came from.
        let mut max_heap = batch_iters
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, itr)| itr.next().map(|rec| (rec, idx)))
            .collect::<BinaryHeap<_>>();

        let mut fusion = Vec::with_capacity(limit.min(available));
        while fusion.len() < limit {
            let Some((m, idx)) = max_heap.pop() else {
                break;
            };
            // Refill from the batch that just yielded the current maximum.
            if let Some(next_m) = batch_iters[idx].next() {
                max_heap.push((next_m, idx));
            }
            if fusion.last() != Some(&m) {
                fusion.push(m);
            }
        }
        fusion
    }
}
470
/// The `Projection` operator retrieves record content by offset ids.
///
/// Each flag selects one part of the record content; all default to `false`.
///
/// # Parameters
/// - `document`: Whether to retrieve document
/// - `embedding`: Whether to retrieve embedding
/// - `metadata`: Whether to retrieve metadata
#[derive(Clone, Debug, Default)]
pub struct Projection {
    pub document: bool,
    pub embedding: bool,
    pub metadata: bool,
}
483
484impl From<chroma_proto::ProjectionOperator> for Projection {
485    fn from(value: chroma_proto::ProjectionOperator) -> Self {
486        Self {
487            document: value.document,
488            embedding: value.embedding,
489            metadata: value.metadata,
490        }
491    }
492}
493
494impl From<Projection> for chroma_proto::ProjectionOperator {
495    fn from(value: Projection) -> Self {
496        Self {
497            document: value.document,
498            embedding: value.embedding,
499            metadata: value.metadata,
500        }
501    }
502}
503
504#[derive(Clone, Debug, PartialEq)]
505pub struct ProjectionRecord {
506    pub id: String,
507    pub document: Option<String>,
508    pub embedding: Option<Vec<f32>>,
509    pub metadata: Option<Metadata>,
510}
511
512impl ProjectionRecord {
513    pub fn size_bytes(&self) -> u64 {
514        (self.id.len()
515            + self
516                .document
517                .as_ref()
518                .map(|doc| doc.len())
519                .unwrap_or_default()
520            + self
521                .embedding
522                .as_ref()
523                .map(|emb| size_of_val(&emb[..]))
524                .unwrap_or_default()
525            + self
526                .metadata
527                .as_ref()
528                .map(logical_size_of_metadata)
529                .unwrap_or_default()) as u64
530    }
531}
532
533impl Eq for ProjectionRecord {}
534
535impl TryFrom<chroma_proto::ProjectionRecord> for ProjectionRecord {
536    type Error = QueryConversionError;
537
538    fn try_from(value: chroma_proto::ProjectionRecord) -> Result<Self, Self::Error> {
539        Ok(Self {
540            id: value.id,
541            document: value.document,
542            embedding: value
543                .embedding
544                .map(|vec| vec.try_into().map(|(v, _)| v))
545                .transpose()?,
546            metadata: value.metadata.map(TryInto::try_into).transpose()?,
547        })
548    }
549}
550
551impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
552    type Error = QueryConversionError;
553
554    fn try_from(value: ProjectionRecord) -> Result<Self, Self::Error> {
555        Ok(Self {
556            id: value.id,
557            document: value.document,
558            embedding: value
559                .embedding
560                .map(|embedding| {
561                    let embedding_dimension = embedding.len();
562                    chroma_proto::Vector::try_from((
563                        embedding,
564                        ScalarEncoding::FLOAT32,
565                        embedding_dimension,
566                    ))
567                })
568                .transpose()?,
569            metadata: value.metadata.map(|metadata| metadata.into()),
570        })
571    }
572}
573
574#[derive(Clone, Debug, Eq, PartialEq)]
575pub struct ProjectionOutput {
576    pub records: Vec<ProjectionRecord>,
577}
578
579#[derive(Clone, Debug, Eq, PartialEq)]
580pub struct GetResult {
581    pub pulled_log_bytes: u64,
582    pub result: ProjectionOutput,
583}
584
585impl GetResult {
586    pub fn size_bytes(&self) -> u64 {
587        self.result
588            .records
589            .iter()
590            .map(ProjectionRecord::size_bytes)
591            .sum()
592    }
593}
594
595impl TryFrom<chroma_proto::GetResult> for GetResult {
596    type Error = QueryConversionError;
597
598    fn try_from(value: chroma_proto::GetResult) -> Result<Self, Self::Error> {
599        Ok(Self {
600            pulled_log_bytes: value.pulled_log_bytes,
601            result: ProjectionOutput {
602                records: value
603                    .records
604                    .into_iter()
605                    .map(TryInto::try_into)
606                    .collect::<Result<_, _>>()?,
607            },
608        })
609    }
610}
611
612impl TryFrom<GetResult> for chroma_proto::GetResult {
613    type Error = QueryConversionError;
614
615    fn try_from(value: GetResult) -> Result<Self, Self::Error> {
616        Ok(Self {
617            pulled_log_bytes: value.pulled_log_bytes,
618            records: value
619                .result
620                .records
621                .into_iter()
622                .map(TryInto::try_into)
623                .collect::<Result<_, _>>()?,
624        })
625    }
626}
627
628/// The `KnnProjection` operator retrieves record content by offset ids
629/// It is based on `ProjectionOperator`, and it attaches the distance
630/// of the records to the target embedding to the record content
631///
632/// # Parameters
633/// - `projection`: The parameters of the `ProjectionOperator`
634/// - `distance`: Whether to attach distance information
635#[derive(Clone, Debug)]
636pub struct KnnProjection {
637    pub projection: Projection,
638    pub distance: bool,
639}
640
641impl TryFrom<chroma_proto::KnnProjectionOperator> for KnnProjection {
642    type Error = QueryConversionError;
643
644    fn try_from(value: chroma_proto::KnnProjectionOperator) -> Result<Self, Self::Error> {
645        Ok(Self {
646            projection: value
647                .projection
648                .ok_or(QueryConversionError::field("projection"))?
649                .into(),
650            distance: value.distance,
651        })
652    }
653}
654
655impl From<KnnProjection> for chroma_proto::KnnProjectionOperator {
656    fn from(value: KnnProjection) -> Self {
657        Self {
658            projection: Some(value.projection.into()),
659            distance: value.distance,
660        }
661    }
662}
663
664#[derive(Clone, Debug)]
665pub struct KnnProjectionRecord {
666    pub record: ProjectionRecord,
667    pub distance: Option<f32>,
668}
669
670impl TryFrom<chroma_proto::KnnProjectionRecord> for KnnProjectionRecord {
671    type Error = QueryConversionError;
672
673    fn try_from(value: chroma_proto::KnnProjectionRecord) -> Result<Self, Self::Error> {
674        Ok(Self {
675            record: value
676                .record
677                .ok_or(QueryConversionError::field("record"))?
678                .try_into()?,
679            distance: value.distance,
680        })
681    }
682}
683
684impl TryFrom<KnnProjectionRecord> for chroma_proto::KnnProjectionRecord {
685    type Error = QueryConversionError;
686
687    fn try_from(value: KnnProjectionRecord) -> Result<Self, Self::Error> {
688        Ok(Self {
689            record: Some(value.record.try_into()?),
690            distance: value.distance,
691        })
692    }
693}
694
695#[derive(Clone, Debug, Default)]
696pub struct KnnProjectionOutput {
697    pub records: Vec<KnnProjectionRecord>,
698}
699
700impl TryFrom<chroma_proto::KnnResult> for KnnProjectionOutput {
701    type Error = QueryConversionError;
702
703    fn try_from(value: chroma_proto::KnnResult) -> Result<Self, Self::Error> {
704        Ok(Self {
705            records: value
706                .records
707                .into_iter()
708                .map(TryInto::try_into)
709                .collect::<Result<_, _>>()?,
710        })
711    }
712}
713
714impl TryFrom<KnnProjectionOutput> for chroma_proto::KnnResult {
715    type Error = QueryConversionError;
716
717    fn try_from(value: KnnProjectionOutput) -> Result<Self, Self::Error> {
718        Ok(Self {
719            records: value
720                .records
721                .into_iter()
722                .map(TryInto::try_into)
723                .collect::<Result<_, _>>()?,
724        })
725    }
726}
727
728#[derive(Clone, Debug, Default)]
729pub struct KnnBatchResult {
730    pub pulled_log_bytes: u64,
731    pub results: Vec<KnnProjectionOutput>,
732}
733
734impl KnnBatchResult {
735    pub fn size_bytes(&self) -> u64 {
736        self.results
737            .iter()
738            .flat_map(|res| {
739                res.records
740                    .iter()
741                    .map(|rec| rec.record.size_bytes() + size_of_val(&rec.distance) as u64)
742            })
743            .sum()
744    }
745}
746
747impl TryFrom<chroma_proto::KnnBatchResult> for KnnBatchResult {
748    type Error = QueryConversionError;
749
750    fn try_from(value: chroma_proto::KnnBatchResult) -> Result<Self, Self::Error> {
751        Ok(Self {
752            pulled_log_bytes: value.pulled_log_bytes,
753            results: value
754                .results
755                .into_iter()
756                .map(TryInto::try_into)
757                .collect::<Result<_, _>>()?,
758        })
759    }
760}
761
762impl TryFrom<KnnBatchResult> for chroma_proto::KnnBatchResult {
763    type Error = QueryConversionError;
764
765    fn try_from(value: KnnBatchResult) -> Result<Self, Self::Error> {
766        Ok(Self {
767            pulled_log_bytes: value.pulled_log_bytes,
768            results: value
769                .results
770                .into_iter()
771                .map(TryInto::try_into)
772                .collect::<Result<_, _>>()?,
773        })
774    }
775}
776
777/// A query vector for KNN search.
778///
779/// Supports both dense and sparse vector formats.
780///
781/// # Variants
782///
783/// ## Dense
784///
785/// Standard dense embeddings as a vector of floats.
786///
787/// ```
788/// use chroma_types::operator::QueryVector;
789///
790/// let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3, 0.4]);
791/// ```
792///
793/// ## Sparse
794///
795/// Sparse vectors with explicit indices and values.
796///
797/// ```
798/// use chroma_types::operator::QueryVector;
799/// use chroma_types::SparseVector;
800///
801/// let sparse = QueryVector::Sparse(SparseVector::new(
802///     vec![0, 5, 10, 50],      // indices
803///     vec![0.5, 0.3, 0.8, 0.2] // values
804/// ));
805/// ```
806///
807/// # Examples
808///
809/// ## Dense vector in KNN
810///
811/// ```
812/// use chroma_types::operator::{RankExpr, QueryVector, Key};
813///
814/// let rank = RankExpr::Knn {
815///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
816///     key: Key::Embedding,
817///     limit: 100,
818///     default: None,
819///     return_rank: false,
820/// };
821/// ```
822///
823/// ## Sparse vector in KNN
824///
825/// ```
826/// use chroma_types::operator::{RankExpr, QueryVector, Key};
827/// use chroma_types::SparseVector;
828///
829/// let rank = RankExpr::Knn {
830///     query: QueryVector::Sparse(SparseVector::new(
831///         vec![1, 5, 10],
832///         vec![0.5, 0.3, 0.8]
833///     )),
834///     key: Key::field("sparse_embedding"),
835///     limit: 100,
836///     default: None,
837///     return_rank: false,
838/// };
839/// ```
840#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
841#[serde(untagged)]
842pub enum QueryVector {
843    Dense(Vec<f32>),
844    Sparse(SparseVector),
845}
846
847impl TryFrom<chroma_proto::QueryVector> for QueryVector {
848    type Error = QueryConversionError;
849
850    fn try_from(value: chroma_proto::QueryVector) -> Result<Self, Self::Error> {
851        let vector = value.vector.ok_or(QueryConversionError::field("vector"))?;
852        match vector {
853            chroma_proto::query_vector::Vector::Dense(dense) => {
854                Ok(QueryVector::Dense(dense.try_into().map(|(v, _)| v)?))
855            }
856            chroma_proto::query_vector::Vector::Sparse(sparse) => {
857                Ok(QueryVector::Sparse(sparse.into()))
858            }
859        }
860    }
861}
862
863impl TryFrom<QueryVector> for chroma_proto::QueryVector {
864    type Error = QueryConversionError;
865
866    fn try_from(value: QueryVector) -> Result<Self, Self::Error> {
867        match value {
868            QueryVector::Dense(vec) => {
869                let dim = vec.len();
870                Ok(chroma_proto::QueryVector {
871                    vector: Some(chroma_proto::query_vector::Vector::Dense(
872                        chroma_proto::Vector::try_from((vec, ScalarEncoding::FLOAT32, dim))?,
873                    )),
874                })
875            }
876            QueryVector::Sparse(sparse) => Ok(chroma_proto::QueryVector {
877                vector: Some(chroma_proto::query_vector::Vector::Sparse(sparse.into())),
878            }),
879        }
880    }
881}
882
883impl From<Vec<f32>> for QueryVector {
884    fn from(vec: Vec<f32>) -> Self {
885        QueryVector::Dense(vec)
886    }
887}
888
889impl From<SparseVector> for QueryVector {
890    fn from(sparse: SparseVector) -> Self {
891        QueryVector::Sparse(sparse)
892    }
893}
894
895#[derive(Clone, Debug, PartialEq)]
896pub struct KnnQuery {
897    pub query: QueryVector,
898    pub key: Key,
899    pub limit: u32,
900}
901
902/// Wrapper for ranking expressions in search queries.
903///
904/// Contains an optional ranking expression. When None, results are returned in
905/// natural storage order without scoring.
906///
907/// # Fields
908///
909/// * `expr` - The ranking expression (None = no ranking)
910///
911/// # Examples
912///
913/// ```
914/// use chroma_types::operator::{Rank, RankExpr, QueryVector, Key};
915///
916/// // With ranking
917/// let rank = Rank {
918///     expr: Some(RankExpr::Knn {
919///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
920///         key: Key::Embedding,
921///         limit: 100,
922///         default: None,
923///         return_rank: false,
924///     }),
925/// };
926///
927/// // No ranking (natural order)
928/// let rank = Rank {
929///     expr: None,
930/// };
931/// ```
932#[derive(Clone, Debug, Default, Deserialize, Serialize)]
933#[serde(transparent)]
934pub struct Rank {
935    pub expr: Option<RankExpr>,
936}
937
938impl Rank {
939    pub fn knn_queries(&self) -> Vec<KnnQuery> {
940        self.expr
941            .as_ref()
942            .map(RankExpr::knn_queries)
943            .unwrap_or_default()
944    }
945}
946
947impl TryFrom<chroma_proto::RankOperator> for Rank {
948    type Error = QueryConversionError;
949
950    fn try_from(proto_rank: chroma_proto::RankOperator) -> Result<Self, Self::Error> {
951        Ok(Rank {
952            expr: proto_rank.expr.map(TryInto::try_into).transpose()?,
953        })
954    }
955}
956
957impl TryFrom<Rank> for chroma_proto::RankOperator {
958    type Error = QueryConversionError;
959
960    fn try_from(rank: Rank) -> Result<Self, Self::Error> {
961        Ok(chroma_proto::RankOperator {
962            expr: rank.expr.map(TryInto::try_into).transpose()?,
963        })
964    }
965}
966
967/// A ranking expression for scoring and ordering search results.
968///
969/// Ranking expressions determine which documents appear in results and their order.
970/// Lower scores indicate better matches (distance-based scoring).
971///
972/// # Variants
973///
974/// ## Knn - K-Nearest Neighbor Search
975///
976/// The primary ranking method for vector similarity search.
977///
978/// ```
979/// use chroma_types::operator::{RankExpr, QueryVector, Key};
980///
981/// let rank = RankExpr::Knn {
982///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
983///     key: Key::Embedding,
984///     limit: 100,        // Consider top 100 candidates
985///     default: None,     // No default score for missing documents
986///     return_rank: false, // Return distances, not rank positions
987/// };
988/// ```
989///
990/// ## Value - Constant
991///
992/// Represents a constant score.
993///
994/// ```
995/// use chroma_types::operator::RankExpr;
996///
997/// let rank = RankExpr::Value(0.5);
998/// ```
999///
1000/// ## Arithmetic Operations
1001///
1002/// Combine ranking expressions using standard operators (+, -, *, /).
1003///
1004/// ```
1005/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1006///
1007/// let knn1 = RankExpr::Knn {
1008///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1009///     key: Key::Embedding,
1010///     limit: 100,
1011///     default: None,
1012///     return_rank: false,
1013/// };
1014///
1015/// let knn2 = RankExpr::Knn {
1016///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
1017///     key: Key::field("other_embedding"),
1018///     limit: 100,
1019///     default: None,
1020///     return_rank: false,
1021/// };
1022///
1023/// // Weighted combination: 70% knn1 + 30% knn2
1024/// let combined = knn1 * 0.7 + knn2 * 0.3;
1025///
1026/// // Normalized
1027/// let normalized = combined / 2.0;
1028/// ```
1029///
1030/// ## Mathematical Functions
1031///
1032/// Apply mathematical transformations to scores.
1033///
1034/// ```
1035/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1036///
1037/// let knn = RankExpr::Knn {
1038///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1039///     key: Key::Embedding,
1040///     limit: 100,
1041///     default: None,
1042///     return_rank: false,
1043/// };
1044///
1045/// // Exponential - amplifies differences
1046/// let amplified = knn.clone().exp();
1047///
1048/// // Logarithm - compresses range (add constant to avoid log(0))
1049/// let compressed = (knn.clone() + 1.0).log();
1050///
1051/// // Absolute value
1052/// let absolute = knn.clone().abs();
1053///
1054/// // Min/Max - clamping
1055/// let clamped = knn.min(1.0).max(0.0);
1056/// ```
1057///
1058/// # Examples
1059///
1060/// ## Basic vector search
1061///
1062/// ```
1063/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1064///
1065/// let rank = RankExpr::Knn {
1066///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1067///     key: Key::Embedding,
1068///     limit: 100,
1069///     default: None,
1070///     return_rank: false,
1071/// };
1072/// ```
1073///
1074/// ## Hybrid search with weighted combination
1075///
1076/// ```
1077/// use chroma_types::operator::{RankExpr, QueryVector, Key};
1078///
1079/// let dense = RankExpr::Knn {
1080///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1081///     key: Key::Embedding,
1082///     limit: 200,
1083///     default: None,
1084///     return_rank: false,
1085/// };
1086///
1087/// let sparse = RankExpr::Knn {
1088///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]), // Use sparse in practice
1089///     key: Key::field("sparse_embedding"),
1090///     limit: 200,
1091///     default: None,
1092///     return_rank: false,
1093/// };
1094///
1095/// // 70% semantic + 30% keyword
1096/// let hybrid = dense * 0.7 + sparse * 0.3;
1097/// ```
1098///
1099/// ## Reciprocal Rank Fusion (RRF)
1100///
1101/// Use the `rrf()` function for combining rankings with different score scales.
1102///
1103/// ```
1104/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
1105///
1106/// let dense = RankExpr::Knn {
1107///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1108///     key: Key::Embedding,
1109///     limit: 200,
1110///     default: None,
1111///     return_rank: true, // RRF requires rank positions
1112/// };
1113///
1114/// let sparse = RankExpr::Knn {
1115///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1116///     key: Key::field("sparse_embedding"),
1117///     limit: 200,
1118///     default: None,
1119///     return_rank: true, // RRF requires rank positions
1120/// };
1121///
1122/// let rrf_rank = rrf(
1123///     vec![dense, sparse],
1124///     Some(60),           // k parameter (smoothing)
1125///     Some(vec![0.7, 0.3]), // weights
1126///     false,              // normalize weights
1127/// ).unwrap();
1128/// ```
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum RankExpr {
    /// Absolute value of the wrapped expression (serde tag `$abs`).
    #[serde(rename = "$abs")]
    Absolute(Box<RankExpr>),
    /// Quotient `left / right` (serde tag `$div`).
    #[serde(rename = "$div")]
    Division {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// Exponential of the wrapped expression (serde tag `$exp`).
    #[serde(rename = "$exp")]
    Exponentiation(Box<RankExpr>),
    /// K-nearest-neighbor search (serde tag `$knn`).
    #[serde(rename = "$knn")]
    Knn {
        /// Query vector to search with; the only field without a serde default.
        query: QueryVector,
        /// Field to search against; `Key::Embedding` when omitted on deserialization.
        #[serde(default = "RankExpr::default_knn_key")]
        key: Key,
        /// Number of top candidates to consider; 16 when omitted on deserialization.
        #[serde(default = "RankExpr::default_knn_limit")]
        limit: u32,
        /// Score for documents missing from the candidates; `None` (the serde
        /// default) means no default score.
        #[serde(default)]
        default: Option<f32>,
        /// When `true`, rank positions are returned instead of distances;
        /// `false` when omitted on deserialization.
        #[serde(default)]
        return_rank: bool,
    },
    /// Natural logarithm of the wrapped expression (serde tag `$log`).
    #[serde(rename = "$log")]
    Logarithm(Box<RankExpr>),
    /// Maximum over the listed expressions (serde tag `$max`).
    #[serde(rename = "$max")]
    Maximum(Vec<RankExpr>),
    /// Minimum over the listed expressions (serde tag `$min`).
    #[serde(rename = "$min")]
    Minimum(Vec<RankExpr>),
    /// Product of the listed expressions (serde tag `$mul`).
    #[serde(rename = "$mul")]
    Multiplication(Vec<RankExpr>),
    /// Difference `left - right` (serde tag `$sub`).
    #[serde(rename = "$sub")]
    Subtraction {
        left: Box<RankExpr>,
        right: Box<RankExpr>,
    },
    /// Sum of the listed expressions (serde tag `$sum`).
    #[serde(rename = "$sum")]
    Summation(Vec<RankExpr>),
    /// Constant score (serde tag `$val`).
    #[serde(rename = "$val")]
    Value(f32),
}
1170
1171impl RankExpr {
1172    pub fn default_knn_key() -> Key {
1173        Key::Embedding
1174    }
1175
1176    pub fn default_knn_limit() -> u32 {
1177        16
1178    }
1179
1180    pub fn knn_queries(&self) -> Vec<KnnQuery> {
1181        match self {
1182            RankExpr::Absolute(expr)
1183            | RankExpr::Exponentiation(expr)
1184            | RankExpr::Logarithm(expr) => expr.knn_queries(),
1185            RankExpr::Division { left, right } | RankExpr::Subtraction { left, right } => left
1186                .knn_queries()
1187                .into_iter()
1188                .chain(right.knn_queries())
1189                .collect(),
1190            RankExpr::Maximum(exprs)
1191            | RankExpr::Minimum(exprs)
1192            | RankExpr::Multiplication(exprs)
1193            | RankExpr::Summation(exprs) => exprs.iter().flat_map(RankExpr::knn_queries).collect(),
1194            RankExpr::Value(_) => Vec::new(),
1195            RankExpr::Knn {
1196                query,
1197                key,
1198                limit,
1199                default: _,
1200                return_rank: _,
1201            } => vec![KnnQuery {
1202                query: query.clone(),
1203                key: key.clone(),
1204                limit: *limit,
1205            }],
1206        }
1207    }
1208
1209    /// Applies exponential transformation: e^rank.
1210    ///
1211    /// Amplifies differences between scores.
1212    ///
1213    /// # Examples
1214    ///
1215    /// ```
1216    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1217    ///
1218    /// let knn = RankExpr::Knn {
1219    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1220    ///     key: Key::Embedding,
1221    ///     limit: 100,
1222    ///     default: None,
1223    ///     return_rank: false,
1224    /// };
1225    ///
1226    /// let amplified = knn.exp();
1227    /// ```
1228    pub fn exp(self) -> Self {
1229        RankExpr::Exponentiation(Box::new(self))
1230    }
1231
1232    /// Applies natural logarithm transformation: ln(rank).
1233    ///
1234    /// Compresses the score range. Add a constant to avoid log(0).
1235    ///
1236    /// # Examples
1237    ///
1238    /// ```
1239    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1240    ///
1241    /// let knn = RankExpr::Knn {
1242    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1243    ///     key: Key::Embedding,
1244    ///     limit: 100,
1245    ///     default: None,
1246    ///     return_rank: false,
1247    /// };
1248    ///
1249    /// // Add constant to avoid log(0)
1250    /// let compressed = (knn + 1.0).log();
1251    /// ```
1252    pub fn log(self) -> Self {
1253        RankExpr::Logarithm(Box::new(self))
1254    }
1255
1256    /// Takes absolute value of the ranking expression.
1257    ///
1258    /// # Examples
1259    ///
1260    /// ```
1261    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1262    ///
1263    /// let knn1 = RankExpr::Knn {
1264    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1265    ///     key: Key::Embedding,
1266    ///     limit: 100,
1267    ///     default: None,
1268    ///     return_rank: false,
1269    /// };
1270    ///
1271    /// let knn2 = RankExpr::Knn {
1272    ///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
1273    ///     key: Key::field("other"),
1274    ///     limit: 100,
1275    ///     default: None,
1276    ///     return_rank: false,
1277    /// };
1278    ///
1279    /// // Absolute difference
1280    /// let diff = (knn1 - knn2).abs();
1281    /// ```
1282    pub fn abs(self) -> Self {
1283        RankExpr::Absolute(Box::new(self))
1284    }
1285
1286    /// Returns maximum of this expression and another.
1287    ///
1288    /// Can be chained to clamp scores to a maximum value.
1289    ///
1290    /// # Examples
1291    ///
1292    /// ```
1293    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1294    ///
1295    /// let knn = RankExpr::Knn {
1296    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1297    ///     key: Key::Embedding,
1298    ///     limit: 100,
1299    ///     default: None,
1300    ///     return_rank: false,
1301    /// };
1302    ///
1303    /// // Clamp to maximum of 1.0
1304    /// let clamped = knn.clone().max(1.0);
1305    ///
1306    /// // Clamp to range [0.0, 1.0]
1307    /// let range_clamped = knn.min(0.0).max(1.0);
1308    /// ```
1309    pub fn max(self, other: impl Into<RankExpr>) -> Self {
1310        let other = other.into();
1311
1312        match self {
1313            RankExpr::Maximum(mut exprs) => match other {
1314                RankExpr::Maximum(other_exprs) => {
1315                    exprs.extend(other_exprs);
1316                    RankExpr::Maximum(exprs)
1317                }
1318                _ => {
1319                    exprs.push(other);
1320                    RankExpr::Maximum(exprs)
1321                }
1322            },
1323            _ => match other {
1324                RankExpr::Maximum(mut exprs) => {
1325                    exprs.insert(0, self);
1326                    RankExpr::Maximum(exprs)
1327                }
1328                _ => RankExpr::Maximum(vec![self, other]),
1329            },
1330        }
1331    }
1332
1333    /// Returns minimum of this expression and another.
1334    ///
1335    /// Can be chained to clamp scores to a minimum value.
1336    ///
1337    /// # Examples
1338    ///
1339    /// ```
1340    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
1341    ///
1342    /// let knn = RankExpr::Knn {
1343    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
1344    ///     key: Key::Embedding,
1345    ///     limit: 100,
1346    ///     default: None,
1347    ///     return_rank: false,
1348    /// };
1349    ///
1350    /// // Clamp to minimum of 0.0 (ensure non-negative)
1351    /// let clamped = knn.clone().min(0.0);
1352    ///
1353    /// // Clamp to range [0.0, 1.0]
1354    /// let range_clamped = knn.min(0.0).max(1.0);
1355    /// ```
1356    pub fn min(self, other: impl Into<RankExpr>) -> Self {
1357        let other = other.into();
1358
1359        match self {
1360            RankExpr::Minimum(mut exprs) => match other {
1361                RankExpr::Minimum(other_exprs) => {
1362                    exprs.extend(other_exprs);
1363                    RankExpr::Minimum(exprs)
1364                }
1365                _ => {
1366                    exprs.push(other);
1367                    RankExpr::Minimum(exprs)
1368                }
1369            },
1370            _ => match other {
1371                RankExpr::Minimum(mut exprs) => {
1372                    exprs.insert(0, self);
1373                    RankExpr::Minimum(exprs)
1374                }
1375                _ => RankExpr::Minimum(vec![self, other]),
1376            },
1377        }
1378    }
1379}
1380
1381impl Add for RankExpr {
1382    type Output = RankExpr;
1383
1384    fn add(self, rhs: Self) -> Self::Output {
1385        match self {
1386            RankExpr::Summation(mut exprs) => match rhs {
1387                RankExpr::Summation(rhs_exprs) => {
1388                    exprs.extend(rhs_exprs);
1389                    RankExpr::Summation(exprs)
1390                }
1391                _ => {
1392                    exprs.push(rhs);
1393                    RankExpr::Summation(exprs)
1394                }
1395            },
1396            _ => match rhs {
1397                RankExpr::Summation(mut exprs) => {
1398                    exprs.insert(0, self);
1399                    RankExpr::Summation(exprs)
1400                }
1401                _ => RankExpr::Summation(vec![self, rhs]),
1402            },
1403        }
1404    }
1405}
1406
1407impl Add<f32> for RankExpr {
1408    type Output = RankExpr;
1409
1410    fn add(self, rhs: f32) -> Self::Output {
1411        self + RankExpr::Value(rhs)
1412    }
1413}
1414
1415impl Add<RankExpr> for f32 {
1416    type Output = RankExpr;
1417
1418    fn add(self, rhs: RankExpr) -> Self::Output {
1419        RankExpr::Value(self) + rhs
1420    }
1421}
1422
1423impl Sub for RankExpr {
1424    type Output = RankExpr;
1425
1426    fn sub(self, rhs: Self) -> Self::Output {
1427        RankExpr::Subtraction {
1428            left: Box::new(self),
1429            right: Box::new(rhs),
1430        }
1431    }
1432}
1433
1434impl Sub<f32> for RankExpr {
1435    type Output = RankExpr;
1436
1437    fn sub(self, rhs: f32) -> Self::Output {
1438        self - RankExpr::Value(rhs)
1439    }
1440}
1441
1442impl Sub<RankExpr> for f32 {
1443    type Output = RankExpr;
1444
1445    fn sub(self, rhs: RankExpr) -> Self::Output {
1446        RankExpr::Value(self) - rhs
1447    }
1448}
1449
1450impl Mul for RankExpr {
1451    type Output = RankExpr;
1452
1453    fn mul(self, rhs: Self) -> Self::Output {
1454        match self {
1455            RankExpr::Multiplication(mut exprs) => match rhs {
1456                RankExpr::Multiplication(rhs_exprs) => {
1457                    exprs.extend(rhs_exprs);
1458                    RankExpr::Multiplication(exprs)
1459                }
1460                _ => {
1461                    exprs.push(rhs);
1462                    RankExpr::Multiplication(exprs)
1463                }
1464            },
1465            _ => match rhs {
1466                RankExpr::Multiplication(mut exprs) => {
1467                    exprs.insert(0, self);
1468                    RankExpr::Multiplication(exprs)
1469                }
1470                _ => RankExpr::Multiplication(vec![self, rhs]),
1471            },
1472        }
1473    }
1474}
1475
1476impl Mul<f32> for RankExpr {
1477    type Output = RankExpr;
1478
1479    fn mul(self, rhs: f32) -> Self::Output {
1480        self * RankExpr::Value(rhs)
1481    }
1482}
1483
1484impl Mul<RankExpr> for f32 {
1485    type Output = RankExpr;
1486
1487    fn mul(self, rhs: RankExpr) -> Self::Output {
1488        RankExpr::Value(self) * rhs
1489    }
1490}
1491
1492impl Div for RankExpr {
1493    type Output = RankExpr;
1494
1495    fn div(self, rhs: Self) -> Self::Output {
1496        RankExpr::Division {
1497            left: Box::new(self),
1498            right: Box::new(rhs),
1499        }
1500    }
1501}
1502
1503impl Div<f32> for RankExpr {
1504    type Output = RankExpr;
1505
1506    fn div(self, rhs: f32) -> Self::Output {
1507        self / RankExpr::Value(rhs)
1508    }
1509}
1510
1511impl Div<RankExpr> for f32 {
1512    type Output = RankExpr;
1513
1514    fn div(self, rhs: RankExpr) -> Self::Output {
1515        RankExpr::Value(self) / rhs
1516    }
1517}
1518
1519impl Neg for RankExpr {
1520    type Output = RankExpr;
1521
1522    fn neg(self) -> Self::Output {
1523        RankExpr::Value(-1.0) * self
1524    }
1525}
1526
1527impl From<f32> for RankExpr {
1528    fn from(v: f32) -> Self {
1529        RankExpr::Value(v)
1530    }
1531}
1532
1533impl TryFrom<chroma_proto::RankExpr> for RankExpr {
1534    type Error = QueryConversionError;
1535
1536    fn try_from(proto_expr: chroma_proto::RankExpr) -> Result<Self, Self::Error> {
1537        match proto_expr.rank {
1538            Some(chroma_proto::rank_expr::Rank::Absolute(expr)) => {
1539                Ok(RankExpr::Absolute(Box::new(RankExpr::try_from(*expr)?)))
1540            }
1541            Some(chroma_proto::rank_expr::Rank::Division(div)) => {
1542                let left = div.left.ok_or(QueryConversionError::field("left"))?;
1543                let right = div.right.ok_or(QueryConversionError::field("right"))?;
1544                Ok(RankExpr::Division {
1545                    left: Box::new(RankExpr::try_from(*left)?),
1546                    right: Box::new(RankExpr::try_from(*right)?),
1547                })
1548            }
1549            Some(chroma_proto::rank_expr::Rank::Exponentiation(expr)) => Ok(
1550                RankExpr::Exponentiation(Box::new(RankExpr::try_from(*expr)?)),
1551            ),
1552            Some(chroma_proto::rank_expr::Rank::Knn(knn)) => {
1553                let query = knn
1554                    .query
1555                    .ok_or(QueryConversionError::field("query"))?
1556                    .try_into()?;
1557                Ok(RankExpr::Knn {
1558                    query,
1559                    key: Key::from(knn.key),
1560                    limit: knn.limit,
1561                    default: knn.default,
1562                    return_rank: knn.return_rank,
1563                })
1564            }
1565            Some(chroma_proto::rank_expr::Rank::Logarithm(expr)) => {
1566                Ok(RankExpr::Logarithm(Box::new(RankExpr::try_from(*expr)?)))
1567            }
1568            Some(chroma_proto::rank_expr::Rank::Maximum(max)) => {
1569                let exprs = max
1570                    .exprs
1571                    .into_iter()
1572                    .map(RankExpr::try_from)
1573                    .collect::<Result<Vec<_>, _>>()?;
1574                Ok(RankExpr::Maximum(exprs))
1575            }
1576            Some(chroma_proto::rank_expr::Rank::Minimum(min)) => {
1577                let exprs = min
1578                    .exprs
1579                    .into_iter()
1580                    .map(RankExpr::try_from)
1581                    .collect::<Result<Vec<_>, _>>()?;
1582                Ok(RankExpr::Minimum(exprs))
1583            }
1584            Some(chroma_proto::rank_expr::Rank::Multiplication(mul)) => {
1585                let exprs = mul
1586                    .exprs
1587                    .into_iter()
1588                    .map(RankExpr::try_from)
1589                    .collect::<Result<Vec<_>, _>>()?;
1590                Ok(RankExpr::Multiplication(exprs))
1591            }
1592            Some(chroma_proto::rank_expr::Rank::Subtraction(sub)) => {
1593                let left = sub.left.ok_or(QueryConversionError::field("left"))?;
1594                let right = sub.right.ok_or(QueryConversionError::field("right"))?;
1595                Ok(RankExpr::Subtraction {
1596                    left: Box::new(RankExpr::try_from(*left)?),
1597                    right: Box::new(RankExpr::try_from(*right)?),
1598                })
1599            }
1600            Some(chroma_proto::rank_expr::Rank::Summation(sum)) => {
1601                let exprs = sum
1602                    .exprs
1603                    .into_iter()
1604                    .map(RankExpr::try_from)
1605                    .collect::<Result<Vec<_>, _>>()?;
1606                Ok(RankExpr::Summation(exprs))
1607            }
1608            Some(chroma_proto::rank_expr::Rank::Value(value)) => Ok(RankExpr::Value(value)),
1609            None => Err(QueryConversionError::field("rank")),
1610        }
1611    }
1612}
1613
1614impl TryFrom<RankExpr> for chroma_proto::RankExpr {
1615    type Error = QueryConversionError;
1616
1617    fn try_from(rank_expr: RankExpr) -> Result<Self, Self::Error> {
1618        let proto_rank = match rank_expr {
1619            RankExpr::Absolute(expr) => chroma_proto::rank_expr::Rank::Absolute(Box::new(
1620                chroma_proto::RankExpr::try_from(*expr)?,
1621            )),
1622            RankExpr::Division { left, right } => chroma_proto::rank_expr::Rank::Division(
1623                Box::new(chroma_proto::rank_expr::RankPair {
1624                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1625                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1626                }),
1627            ),
1628            RankExpr::Exponentiation(expr) => chroma_proto::rank_expr::Rank::Exponentiation(
1629                Box::new(chroma_proto::RankExpr::try_from(*expr)?),
1630            ),
1631            RankExpr::Knn {
1632                query,
1633                key,
1634                limit,
1635                default,
1636                return_rank,
1637            } => chroma_proto::rank_expr::Rank::Knn(chroma_proto::rank_expr::Knn {
1638                query: Some(query.try_into()?),
1639                key: key.to_string(),
1640                limit,
1641                default,
1642                return_rank,
1643            }),
1644            RankExpr::Logarithm(expr) => chroma_proto::rank_expr::Rank::Logarithm(Box::new(
1645                chroma_proto::RankExpr::try_from(*expr)?,
1646            )),
1647            RankExpr::Maximum(exprs) => {
1648                let proto_exprs = exprs
1649                    .into_iter()
1650                    .map(chroma_proto::RankExpr::try_from)
1651                    .collect::<Result<Vec<_>, _>>()?;
1652                chroma_proto::rank_expr::Rank::Maximum(chroma_proto::rank_expr::RankList {
1653                    exprs: proto_exprs,
1654                })
1655            }
1656            RankExpr::Minimum(exprs) => {
1657                let proto_exprs = exprs
1658                    .into_iter()
1659                    .map(chroma_proto::RankExpr::try_from)
1660                    .collect::<Result<Vec<_>, _>>()?;
1661                chroma_proto::rank_expr::Rank::Minimum(chroma_proto::rank_expr::RankList {
1662                    exprs: proto_exprs,
1663                })
1664            }
1665            RankExpr::Multiplication(exprs) => {
1666                let proto_exprs = exprs
1667                    .into_iter()
1668                    .map(chroma_proto::RankExpr::try_from)
1669                    .collect::<Result<Vec<_>, _>>()?;
1670                chroma_proto::rank_expr::Rank::Multiplication(chroma_proto::rank_expr::RankList {
1671                    exprs: proto_exprs,
1672                })
1673            }
1674            RankExpr::Subtraction { left, right } => chroma_proto::rank_expr::Rank::Subtraction(
1675                Box::new(chroma_proto::rank_expr::RankPair {
1676                    left: Some(Box::new(chroma_proto::RankExpr::try_from(*left)?)),
1677                    right: Some(Box::new(chroma_proto::RankExpr::try_from(*right)?)),
1678                }),
1679            ),
1680            RankExpr::Summation(exprs) => {
1681                let proto_exprs = exprs
1682                    .into_iter()
1683                    .map(chroma_proto::RankExpr::try_from)
1684                    .collect::<Result<Vec<_>, _>>()?;
1685                chroma_proto::rank_expr::Rank::Summation(chroma_proto::rank_expr::RankList {
1686                    exprs: proto_exprs,
1687                })
1688            }
1689            RankExpr::Value(value) => chroma_proto::rank_expr::Rank::Value(value),
1690        };
1691
1692        Ok(chroma_proto::RankExpr {
1693            rank: Some(proto_rank),
1694        })
1695    }
1696}
1697
1698/// Represents a field key in search queries.
1699///
1700/// Used for both selecting fields to return and building filter expressions.
1701/// Predefined keys access special fields, while custom keys access metadata.
1702///
1703/// # Predefined Keys
1704///
1705/// - `Key::Document` - Document text content (`#document`)
1706/// - `Key::Embedding` - Vector embeddings (`#embedding`)
1707/// - `Key::Metadata` - All metadata fields (`#metadata`)
1708/// - `Key::Score` - Search scores (`#score`)
1709///
1710/// # Custom Keys
1711///
1712/// Use `Key::field()` or `Key::from()` to reference metadata fields:
1713///
1714/// ```
1715/// use chroma_types::operator::Key;
1716///
1717/// let key = Key::field("author");
1718/// let key = Key::from("title");
1719/// ```
1720///
1721/// # Examples
1722///
1723/// ## Building filters
1724///
1725/// ```
1726/// use chroma_types::operator::Key;
1727///
1728/// // Equality
1729/// let filter = Key::field("status").eq("published");
1730///
1731/// // Comparisons
1732/// let filter = Key::field("year").gte(2020);
1733/// let filter = Key::field("score").lt(0.9);
1734///
1735/// // Set operations
1736/// let filter = Key::field("category").is_in(vec!["tech", "science"]);
1737/// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
1738///
1739/// // Document content
1740/// let filter = Key::Document.contains("machine learning");
1741/// let filter = Key::Document.regex(r"\bAPI\b");
1742///
1743/// // Combining filters
1744/// let filter = Key::field("status").eq("published")
1745///     & Key::field("year").gte(2020);
1746/// ```
1747///
1748/// ## Selecting fields
1749///
1750/// ```
1751/// use chroma_types::plan::SearchPayload;
1752/// use chroma_types::operator::Key;
1753///
1754/// let search = SearchPayload::default()
1755///     .select([
1756///         Key::Document,
1757///         Key::Score,
1758///         Key::field("title"),
1759///         Key::field("author"),
1760///     ]);
1761/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum Key {
    // Predefined keys. Each has a reserved `#`-prefixed string form shared by
    // the `Serialize`, `Display` and `From<&str>` implementations below.
    // String form: `#document`
    Document,
    // String form: `#embedding`
    Embedding,
    // String form: `#metadata`
    Metadata,
    // String form: `#score`
    Score,
    // Any other string: names a metadata field and is written out verbatim.
    MetadataField(String),
}
1772
1773impl Serialize for Key {
1774    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1775    where
1776        S: serde::Serializer,
1777    {
1778        match self {
1779            Key::Document => serializer.serialize_str("#document"),
1780            Key::Embedding => serializer.serialize_str("#embedding"),
1781            Key::Metadata => serializer.serialize_str("#metadata"),
1782            Key::Score => serializer.serialize_str("#score"),
1783            Key::MetadataField(field) => serializer.serialize_str(field),
1784        }
1785    }
1786}
1787
1788impl<'de> Deserialize<'de> for Key {
1789    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1790    where
1791        D: Deserializer<'de>,
1792    {
1793        let s = String::deserialize(deserializer)?;
1794        Ok(Key::from(s))
1795    }
1796}
1797
1798impl fmt::Display for Key {
1799    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1800        match self {
1801            Key::Document => write!(f, "#document"),
1802            Key::Embedding => write!(f, "#embedding"),
1803            Key::Metadata => write!(f, "#metadata"),
1804            Key::Score => write!(f, "#score"),
1805            Key::MetadataField(field) => write!(f, "{}", field),
1806        }
1807    }
1808}
1809
1810impl From<&str> for Key {
1811    fn from(s: &str) -> Self {
1812        match s {
1813            "#document" => Key::Document,
1814            "#embedding" => Key::Embedding,
1815            "#metadata" => Key::Metadata,
1816            "#score" => Key::Score,
1817            // Any other string is treated as a metadata field key
1818            field => Key::MetadataField(field.to_string()),
1819        }
1820    }
1821}
1822
1823impl From<String> for Key {
1824    fn from(s: String) -> Self {
1825        Key::from(s.as_str())
1826    }
1827}
1828
1829impl Key {
1830    /// Creates a Key for a custom metadata field.
1831    ///
1832    /// # Examples
1833    ///
1834    /// ```
1835    /// use chroma_types::operator::Key;
1836    ///
1837    /// let status = Key::field("status");
1838    /// let year = Key::field("year");
1839    /// let author = Key::field("author");
1840    /// ```
1841    pub fn field(name: impl Into<String>) -> Self {
1842        Key::MetadataField(name.into())
1843    }
1844
1845    /// Creates an equality filter: `field == value`.
1846    ///
1847    /// # Examples
1848    ///
1849    /// ```
1850    /// use chroma_types::operator::Key;
1851    ///
1852    /// // String equality
1853    /// let filter = Key::field("status").eq("published");
1854    ///
1855    /// // Numeric equality
1856    /// let filter = Key::field("count").eq(42);
1857    ///
1858    /// // Boolean equality
1859    /// let filter = Key::field("featured").eq(true);
1860    /// ```
1861    pub fn eq<T: Into<MetadataValue>>(self, value: T) -> Where {
1862        Where::Metadata(MetadataExpression {
1863            key: self.to_string(),
1864            comparison: MetadataComparison::Primitive(PrimitiveOperator::Equal, value.into()),
1865        })
1866    }
1867
1868    /// Creates an inequality filter: `field != value`.
1869    ///
1870    /// # Examples
1871    ///
1872    /// ```
1873    /// use chroma_types::operator::Key;
1874    ///
1875    /// let filter = Key::field("status").ne("deleted");
1876    /// let filter = Key::field("count").ne(0);
1877    /// ```
1878    pub fn ne<T: Into<MetadataValue>>(self, value: T) -> Where {
1879        Where::Metadata(MetadataExpression {
1880            key: self.to_string(),
1881            comparison: MetadataComparison::Primitive(PrimitiveOperator::NotEqual, value.into()),
1882        })
1883    }
1884
1885    /// Creates a greater-than filter: `field > value` (numeric only).
1886    ///
1887    /// # Examples
1888    ///
1889    /// ```
1890    /// use chroma_types::operator::Key;
1891    ///
1892    /// let filter = Key::field("score").gt(0.5);
1893    /// let filter = Key::field("year").gt(2020);
1894    /// ```
1895    pub fn gt<T: Into<MetadataValue>>(self, value: T) -> Where {
1896        Where::Metadata(MetadataExpression {
1897            key: self.to_string(),
1898            comparison: MetadataComparison::Primitive(PrimitiveOperator::GreaterThan, value.into()),
1899        })
1900    }
1901
1902    /// Creates a greater-than-or-equal filter: `field >= value` (numeric only).
1903    ///
1904    /// # Examples
1905    ///
1906    /// ```
1907    /// use chroma_types::operator::Key;
1908    ///
1909    /// let filter = Key::field("score").gte(0.5);
1910    /// let filter = Key::field("year").gte(2020);
1911    /// ```
1912    pub fn gte<T: Into<MetadataValue>>(self, value: T) -> Where {
1913        Where::Metadata(MetadataExpression {
1914            key: self.to_string(),
1915            comparison: MetadataComparison::Primitive(
1916                PrimitiveOperator::GreaterThanOrEqual,
1917                value.into(),
1918            ),
1919        })
1920    }
1921
1922    /// Creates a less-than filter: `field < value` (numeric only).
1923    ///
1924    /// # Examples
1925    ///
1926    /// ```
1927    /// use chroma_types::operator::Key;
1928    ///
1929    /// let filter = Key::field("score").lt(0.9);
1930    /// let filter = Key::field("year").lt(2025);
1931    /// ```
1932    pub fn lt<T: Into<MetadataValue>>(self, value: T) -> Where {
1933        Where::Metadata(MetadataExpression {
1934            key: self.to_string(),
1935            comparison: MetadataComparison::Primitive(PrimitiveOperator::LessThan, value.into()),
1936        })
1937    }
1938
1939    /// Creates a less-than-or-equal filter: `field <= value` (numeric only).
1940    ///
1941    /// # Examples
1942    ///
1943    /// ```
1944    /// use chroma_types::operator::Key;
1945    ///
1946    /// let filter = Key::field("score").lte(0.9);
1947    /// let filter = Key::field("year").lte(2024);
1948    /// ```
1949    pub fn lte<T: Into<MetadataValue>>(self, value: T) -> Where {
1950        Where::Metadata(MetadataExpression {
1951            key: self.to_string(),
1952            comparison: MetadataComparison::Primitive(
1953                PrimitiveOperator::LessThanOrEqual,
1954                value.into(),
1955            ),
1956        })
1957    }
1958
1959    /// Creates a set membership filter: `field IN values`.
1960    ///
1961    /// Accepts any iterator (Vec, array, slice, etc.).
1962    ///
1963    /// # Examples
1964    ///
1965    /// ```
1966    /// use chroma_types::operator::Key;
1967    ///
1968    /// // With Vec
1969    /// let filter = Key::field("year").is_in(vec![2023, 2024, 2025]);
1970    ///
1971    /// // With array
1972    /// let filter = Key::field("category").is_in(["tech", "science", "math"]);
1973    ///
1974    /// // With owned strings
1975    /// let categories = vec!["tech".to_string(), "science".to_string()];
1976    /// let filter = Key::field("category").is_in(categories);
1977    /// ```
1978    pub fn is_in<I, T>(self, values: I) -> Where
1979    where
1980        I: IntoIterator<Item = T>,
1981        Vec<T>: Into<MetadataSetValue>,
1982    {
1983        let vec: Vec<T> = values.into_iter().collect();
1984        Where::Metadata(MetadataExpression {
1985            key: self.to_string(),
1986            comparison: MetadataComparison::Set(SetOperator::In, vec.into()),
1987        })
1988    }
1989
1990    /// Creates a set exclusion filter: `field NOT IN values`.
1991    ///
1992    /// Accepts any iterator (Vec, array, slice, etc.).
1993    ///
1994    /// # Examples
1995    ///
1996    /// ```
1997    /// use chroma_types::operator::Key;
1998    ///
1999    /// // Exclude deleted and archived
2000    /// let filter = Key::field("status").not_in(vec!["deleted", "archived"]);
2001    ///
2002    /// // Exclude specific years
2003    /// let filter = Key::field("year").not_in(vec![2019, 2020]);
2004    /// ```
2005    pub fn not_in<I, T>(self, values: I) -> Where
2006    where
2007        I: IntoIterator<Item = T>,
2008        Vec<T>: Into<MetadataSetValue>,
2009    {
2010        let vec: Vec<T> = values.into_iter().collect();
2011        Where::Metadata(MetadataExpression {
2012            key: self.to_string(),
2013            comparison: MetadataComparison::Set(SetOperator::NotIn, vec.into()),
2014        })
2015    }
2016
2017    /// Creates a substring filter (case-sensitive, document content only).
2018    ///
2019    /// Note: Currently only works with `Key::Document`. Pattern must have at least
2020    /// 3 literal characters for accurate results.
2021    ///
2022    /// # Examples
2023    ///
2024    /// ```
2025    /// use chroma_types::operator::Key;
2026    ///
2027    /// let filter = Key::Document.contains("machine learning");
2028    /// let filter = Key::Document.contains("API");
2029    /// ```
2030    pub fn contains<S: Into<String>>(self, text: S) -> Where {
2031        Where::Document(DocumentExpression {
2032            operator: DocumentOperator::Contains,
2033            pattern: text.into(),
2034        })
2035    }
2036
2037    /// Creates a negative substring filter (case-sensitive, document content only).
2038    ///
2039    /// Note: Currently only works with `Key::Document`.
2040    ///
2041    /// # Examples
2042    ///
2043    /// ```
2044    /// use chroma_types::operator::Key;
2045    ///
2046    /// let filter = Key::Document.not_contains("deprecated");
2047    /// let filter = Key::Document.not_contains("beta");
2048    /// ```
2049    pub fn not_contains<S: Into<String>>(self, text: S) -> Where {
2050        Where::Document(DocumentExpression {
2051            operator: DocumentOperator::NotContains,
2052            pattern: text.into(),
2053        })
2054    }
2055
2056    /// Creates a regex filter (case-sensitive, document content only).
2057    ///
2058    /// Note: Currently only works with `Key::Document`. Pattern must have at least
2059    /// 3 literal characters for accurate results.
2060    ///
2061    /// # Examples
2062    ///
2063    /// ```
2064    /// use chroma_types::operator::Key;
2065    ///
2066    /// // Match whole word "API"
2067    /// let filter = Key::Document.regex(r"\bAPI\b");
2068    ///
2069    /// // Match version pattern
2070    /// let filter = Key::Document.regex(r"v\d+\.\d+\.\d+");
2071    /// ```
2072    pub fn regex<S: Into<String>>(self, pattern: S) -> Where {
2073        Where::Document(DocumentExpression {
2074            operator: DocumentOperator::Regex,
2075            pattern: pattern.into(),
2076        })
2077    }
2078
2079    /// Creates a negative regex filter (case-sensitive, document content only).
2080    ///
2081    /// Note: Currently only works with `Key::Document`.
2082    ///
2083    /// # Examples
2084    ///
2085    /// ```
2086    /// use chroma_types::operator::Key;
2087    ///
2088    /// // Exclude beta versions
2089    /// let filter = Key::Document.not_regex(r"beta");
2090    ///
2091    /// // Exclude test documents
2092    /// let filter = Key::Document.not_regex(r"\btest\b");
2093    /// ```
2094    pub fn not_regex<S: Into<String>>(self, pattern: S) -> Where {
2095        Where::Document(DocumentExpression {
2096            operator: DocumentOperator::NotRegex,
2097            pattern: pattern.into(),
2098        })
2099    }
2100}
2101
2102/// Field selection for search results.
2103///
2104/// Specifies which fields to include in the results. IDs are always included.
2105///
2106/// # Fields
2107///
2108/// * `keys` - Set of keys to include in results
2109///
2110/// # Available Keys
2111///
2112/// * `Key::Document` - Document text content
2113/// * `Key::Embedding` - Vector embeddings
2114/// * `Key::Metadata` - All metadata fields
2115/// * `Key::Score` - Search scores
2116/// * `Key::field("name")` - Specific metadata field
2117///
2118/// # Performance
2119///
2120/// Selecting fewer fields improves performance by reducing data transfer:
2121/// - Minimal: IDs only (default, fastest)
2122/// - Moderate: Scores + specific metadata fields
2123/// - Heavy: Documents + embeddings (larger payloads)
2124///
2125/// # Examples
2126///
2127/// ```
2128/// use chroma_types::operator::{Select, Key};
2129/// use std::collections::HashSet;
2130///
2131/// // Select predefined fields
2132/// let select = Select {
2133///     keys: [Key::Document, Key::Score].into_iter().collect(),
2134/// };
2135///
2136/// // Select specific metadata fields
2137/// let select = Select {
2138///     keys: [
2139///         Key::field("title"),
2140///         Key::field("author"),
2141///         Key::Score,
2142///     ].into_iter().collect(),
2143/// };
2144///
2145/// // Select everything
2146/// let select = Select {
2147///     keys: [
2148///         Key::Document,
2149///         Key::Embedding,
2150///         Key::Metadata,
2151///         Key::Score,
2152///     ].into_iter().collect(),
2153/// };
2154/// ```
2155#[derive(Clone, Debug, Default, Deserialize, Serialize)]
2156pub struct Select {
2157    #[serde(default)]
2158    pub keys: HashSet<Key>,
2159}
2160
2161impl TryFrom<chroma_proto::SelectOperator> for Select {
2162    type Error = QueryConversionError;
2163
2164    fn try_from(value: chroma_proto::SelectOperator) -> Result<Self, Self::Error> {
2165        let keys = value
2166            .keys
2167            .into_iter()
2168            .map(|key| {
2169                // Try to deserialize each string as a Key
2170                serde_json::from_value(serde_json::Value::String(key))
2171                    .map_err(|_| QueryConversionError::field("keys"))
2172            })
2173            .collect::<Result<HashSet<_>, _>>()?;
2174
2175        Ok(Self { keys })
2176    }
2177}
2178
2179impl TryFrom<Select> for chroma_proto::SelectOperator {
2180    type Error = QueryConversionError;
2181
2182    fn try_from(value: Select) -> Result<Self, Self::Error> {
2183        let keys = value
2184            .keys
2185            .into_iter()
2186            .map(|key| {
2187                // Serialize each Key back to string
2188                serde_json::to_value(&key)
2189                    .ok()
2190                    .and_then(|v| v.as_str().map(String::from))
2191                    .ok_or(QueryConversionError::field("keys"))
2192            })
2193            .collect::<Result<Vec<_>, _>>()?;
2194
2195        Ok(Self { keys })
2196    }
2197}
2198
2199/// A single search result record.
2200///
2201/// Contains the document ID and optionally document content, embeddings, metadata,
2202/// and search score based on what was selected in the search query.
2203///
2204/// # Fields
2205///
2206/// * `id` - Document ID (always present)
2207/// * `document` - Document text content (if selected)
2208/// * `embedding` - Vector embedding (if selected)
2209/// * `metadata` - Document metadata (if selected)
2210/// * `score` - Search score (present when ranking is used, lower = better match)
2211///
2212/// # Examples
2213///
2214/// ```
2215/// use chroma_types::operator::SearchRecord;
2216///
2217/// fn process_results(records: Vec<SearchRecord>) {
2218///     for record in records {
2219///         println!("ID: {}", record.id);
2220///         
2221///         if let Some(score) = record.score {
2222///             println!("  Score: {:.3}", score);
2223///         }
2224///         
2225///         if let Some(doc) = record.document {
2226///             println!("  Document: {}", doc);
2227///         }
2228///         
2229///         if let Some(meta) = record.metadata {
2230///             println!("  Metadata: {:?}", meta);
2231///         }
2232///     }
2233/// }
2234/// ```
2235#[derive(Clone, Debug, Deserialize, Serialize)]
2236#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
2237pub struct SearchRecord {
2238    pub id: String,
2239    pub document: Option<String>,
2240    pub embedding: Option<Vec<f32>>,
2241    pub metadata: Option<Metadata>,
2242    pub score: Option<f32>,
2243}
2244
2245impl TryFrom<chroma_proto::SearchRecord> for SearchRecord {
2246    type Error = QueryConversionError;
2247
2248    fn try_from(value: chroma_proto::SearchRecord) -> Result<Self, Self::Error> {
2249        Ok(Self {
2250            id: value.id,
2251            document: value.document,
2252            embedding: value
2253                .embedding
2254                .map(|vec| vec.try_into().map(|(v, _)| v))
2255                .transpose()?,
2256            metadata: value.metadata.map(TryInto::try_into).transpose()?,
2257            score: value.score,
2258        })
2259    }
2260}
2261
2262impl TryFrom<SearchRecord> for chroma_proto::SearchRecord {
2263    type Error = QueryConversionError;
2264
2265    fn try_from(value: SearchRecord) -> Result<Self, Self::Error> {
2266        Ok(Self {
2267            id: value.id,
2268            document: value.document,
2269            embedding: value
2270                .embedding
2271                .map(|embedding| {
2272                    let embedding_dimension = embedding.len();
2273                    chroma_proto::Vector::try_from((
2274                        embedding,
2275                        ScalarEncoding::FLOAT32,
2276                        embedding_dimension,
2277                    ))
2278                })
2279                .transpose()?,
2280            metadata: value.metadata.map(Into::into),
2281            score: value.score,
2282        })
2283    }
2284}
2285
2286/// Results for a single search payload.
2287///
2288/// Contains all matching records for one search query.
2289///
2290/// # Fields
2291///
2292/// * `records` - Vector of search records, ordered by score (ascending)
2293///
2294/// # Examples
2295///
2296/// ```
2297/// use chroma_types::operator::{SearchPayloadResult, SearchRecord};
2298///
2299/// fn process_search_result(result: SearchPayloadResult) {
2300///     println!("Found {} results", result.records.len());
2301///     
2302///     for (i, record) in result.records.iter().enumerate() {
2303///         println!("{}. {} (score: {:?})", i + 1, record.id, record.score);
2304///     }
2305/// }
2306/// ```
2307#[derive(Clone, Debug, Default)]
2308pub struct SearchPayloadResult {
2309    pub records: Vec<SearchRecord>,
2310}
2311
2312impl TryFrom<chroma_proto::SearchPayloadResult> for SearchPayloadResult {
2313    type Error = QueryConversionError;
2314
2315    fn try_from(value: chroma_proto::SearchPayloadResult) -> Result<Self, Self::Error> {
2316        Ok(Self {
2317            records: value
2318                .records
2319                .into_iter()
2320                .map(TryInto::try_into)
2321                .collect::<Result<_, _>>()?,
2322        })
2323    }
2324}
2325
2326impl TryFrom<SearchPayloadResult> for chroma_proto::SearchPayloadResult {
2327    type Error = QueryConversionError;
2328
2329    fn try_from(value: SearchPayloadResult) -> Result<Self, Self::Error> {
2330        Ok(Self {
2331            records: value
2332                .records
2333                .into_iter()
2334                .map(TryInto::try_into)
2335                .collect::<Result<Vec<_>, _>>()?,
2336        })
2337    }
2338}
2339
2340/// Results from a batch search operation.
2341///
2342/// Contains results for each search payload in the batch, maintaining the same order
2343/// as the input searches.
2344///
2345/// # Fields
2346///
2347/// * `results` - Results for each search payload (indexed by search position)
2348/// * `pulled_log_bytes` - Total bytes pulled from log (for internal metrics)
2349///
2350/// # Examples
2351///
2352/// ## Single search
2353///
2354/// ```
2355/// use chroma_types::operator::SearchResult;
2356///
2357/// fn process_single_search(result: SearchResult) {
2358///     // Single search, so results[0] contains our records
2359///     let records = &result.results[0].records;
2360///     
2361///     for record in records {
2362///         println!("{}: score={:?}", record.id, record.score);
2363///     }
2364/// }
2365/// ```
2366///
2367/// ## Batch search
2368///
2369/// ```
2370/// use chroma_types::operator::SearchResult;
2371///
2372/// fn process_batch_search(result: SearchResult) {
2373///     // Multiple searches in batch
2374///     for (i, search_result) in result.results.iter().enumerate() {
2375///         println!("\nSearch {}:", i + 1);
2376///         for record in &search_result.records {
2377///             println!("  {}: score={:?}", record.id, record.score);
2378///         }
2379///     }
2380/// }
2381/// ```
2382#[derive(Clone, Debug)]
2383pub struct SearchResult {
2384    pub results: Vec<SearchPayloadResult>,
2385    pub pulled_log_bytes: u64,
2386}
2387
2388impl SearchResult {
2389    pub fn size_bytes(&self) -> u64 {
2390        self.results
2391            .iter()
2392            .flat_map(|result| {
2393                result.records.iter().map(|record| {
2394                    (record.id.len()
2395                        + record
2396                            .document
2397                            .as_ref()
2398                            .map(|doc| doc.len())
2399                            .unwrap_or_default()
2400                        + record
2401                            .embedding
2402                            .as_ref()
2403                            .map(|emb| size_of_val(&emb[..]))
2404                            .unwrap_or_default()
2405                        + record
2406                            .metadata
2407                            .as_ref()
2408                            .map(logical_size_of_metadata)
2409                            .unwrap_or_default()
2410                        + record.score.as_ref().map(size_of_val).unwrap_or_default())
2411                        as u64
2412                })
2413            })
2414            .sum()
2415    }
2416}
2417
2418impl TryFrom<chroma_proto::SearchResult> for SearchResult {
2419    type Error = QueryConversionError;
2420
2421    fn try_from(value: chroma_proto::SearchResult) -> Result<Self, Self::Error> {
2422        Ok(Self {
2423            results: value
2424                .results
2425                .into_iter()
2426                .map(TryInto::try_into)
2427                .collect::<Result<_, _>>()?,
2428            pulled_log_bytes: value.pulled_log_bytes,
2429        })
2430    }
2431}
2432
2433impl TryFrom<SearchResult> for chroma_proto::SearchResult {
2434    type Error = QueryConversionError;
2435
2436    fn try_from(value: SearchResult) -> Result<Self, Self::Error> {
2437        Ok(Self {
2438            results: value
2439                .results
2440                .into_iter()
2441                .map(TryInto::try_into)
2442                .collect::<Result<Vec<_>, _>>()?,
2443            pulled_log_bytes: value.pulled_log_bytes,
2444        })
2445    }
2446}
2447
2448/// Reciprocal Rank Fusion (RRF) - combines multiple ranking strategies.
2449///
2450/// RRF is ideal for hybrid search where you want to merge results from different
2451/// ranking methods (e.g., dense and sparse embeddings) with different score scales.
2452/// It uses rank positions instead of raw scores, making it scale-agnostic.
2453///
2454/// # Formula
2455///
2456/// ```text
2457/// score = -Σ(weight_i / (k + rank_i))
2458/// ```
2459///
2460/// Where:
2461/// - `weight_i` = weight for ranking i (default: 1.0)
2462/// - `rank_i` = rank position from ranking i (0, 1, 2...)
2463/// - `k` = smoothing parameter (default: 60)
2464///
2465/// Score is negative because Chroma uses ascending order (lower = better).
2466///
2467/// # Arguments
2468///
2469/// * `ranks` - List of ranking expressions (must have `return_rank=true`)
2470/// * `k` - Smoothing parameter (None = 60). Higher values reduce emphasis on top ranks.
2471/// * `weights` - Weight for each ranking (None = all 1.0)
2472/// * `normalize` - If true, normalize weights to sum to 1.0
2473///
2474/// # Returns
2475///
2476/// A combined RankExpr or an error if:
2477/// - `ranks` is empty
2478/// - `weights` length doesn't match `ranks` length
2479/// - `weights` sum to zero when normalizing
2480///
2481/// # Examples
2482///
2483/// ## Basic RRF with default parameters
2484///
2485/// ```
2486/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2487///
2488/// let dense = RankExpr::Knn {
2489///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2490///     key: Key::Embedding,
2491///     limit: 200,
2492///     default: None,
2493///     return_rank: true, // Required for RRF
2494/// };
2495///
2496/// let sparse = RankExpr::Knn {
2497///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2498///     key: Key::field("sparse_embedding"),
2499///     limit: 200,
2500///     default: None,
2501///     return_rank: true, // Required for RRF
2502/// };
2503///
2504/// // Equal weights, k=60 (defaults)
2505/// let combined = rrf(vec![dense, sparse], None, None, false).unwrap();
2506/// ```
2507///
2508/// ## RRF with custom weights
2509///
2510/// ```
2511/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2512///
2513/// # let dense = RankExpr::Knn {
2514/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2515/// #     key: Key::Embedding,
2516/// #     limit: 200,
2517/// #     default: None,
2518/// #     return_rank: true,
2519/// # };
2520/// # let sparse = RankExpr::Knn {
2521/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2522/// #     key: Key::field("sparse_embedding"),
2523/// #     limit: 200,
2524/// #     default: None,
2525/// #     return_rank: true,
2526/// # };
2527/// // 70% dense, 30% sparse
2528/// let combined = rrf(
2529///     vec![dense, sparse],
2530///     Some(60),
2531///     Some(vec![0.7, 0.3]),
2532///     false,
2533/// ).unwrap();
2534/// ```
2535///
2536/// ## RRF with normalized weights
2537///
2538/// ```
2539/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2540///
2541/// # let dense = RankExpr::Knn {
2542/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2543/// #     key: Key::Embedding,
2544/// #     limit: 200,
2545/// #     default: None,
2546/// #     return_rank: true,
2547/// # };
2548/// # let sparse = RankExpr::Knn {
2549/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2550/// #     key: Key::field("sparse_embedding"),
2551/// #     limit: 200,
2552/// #     default: None,
2553/// #     return_rank: true,
2554/// # };
2555/// // Weights [75, 25] normalized to [0.75, 0.25]
2556/// let combined = rrf(
2557///     vec![dense, sparse],
2558///     Some(60),
2559///     Some(vec![75.0, 25.0]),
2560///     true, // normalize
2561/// ).unwrap();
2562/// ```
2563///
2564/// ## Adjusting the k parameter
2565///
2566/// ```
2567/// use chroma_types::operator::{RankExpr, QueryVector, Key, rrf};
2568///
2569/// # let dense = RankExpr::Knn {
2570/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2571/// #     key: Key::Embedding,
2572/// #     limit: 200,
2573/// #     default: None,
2574/// #     return_rank: true,
2575/// # };
2576/// # let sparse = RankExpr::Knn {
2577/// #     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
2578/// #     key: Key::field("sparse_embedding"),
2579/// #     limit: 200,
2580/// #     default: None,
2581/// #     return_rank: true,
2582/// # };
2583/// // Small k (10) = heavy emphasis on top ranks
2584/// let top_heavy = rrf(vec![dense.clone(), sparse.clone()], Some(10), None, false).unwrap();
2585///
2586/// // Default k (60) = balanced
2587/// let balanced = rrf(vec![dense.clone(), sparse.clone()], Some(60), None, false).unwrap();
2588///
2589/// // Large k (200) = more uniform weighting
2590/// let uniform = rrf(vec![dense, sparse], Some(200), None, false).unwrap();
2591/// ```
2592pub fn rrf(
2593    ranks: Vec<RankExpr>,
2594    k: Option<u32>,
2595    weights: Option<Vec<f32>>,
2596    normalize: bool,
2597) -> Result<RankExpr, QueryConversionError> {
2598    let k = k.unwrap_or(60);
2599
2600    if ranks.is_empty() {
2601        return Err(QueryConversionError::validation(
2602            "RRF requires at least one rank expression",
2603        ));
2604    }
2605
2606    let weights = weights.unwrap_or_else(|| vec![1.0; ranks.len()]);
2607
2608    if weights.len() != ranks.len() {
2609        return Err(QueryConversionError::validation(format!(
2610            "RRF weights length ({}) must match ranks length ({})",
2611            weights.len(),
2612            ranks.len()
2613        )));
2614    }
2615
2616    let weights = if normalize {
2617        let sum: f32 = weights.iter().sum();
2618        if sum == 0.0 {
2619            return Err(QueryConversionError::validation(
2620                "RRF weights sum to zero, cannot normalize",
2621            ));
2622        }
2623        weights.into_iter().map(|w| w / sum).collect()
2624    } else {
2625        weights
2626    };
2627
2628    let terms: Vec<RankExpr> = weights
2629        .into_iter()
2630        .zip(ranks)
2631        .map(|(w, rank)| RankExpr::Value(w) / (RankExpr::Value(k as f32) + rank))
2632        .collect();
2633
2634    // Safe: ranks is validated as non-empty above, so terms cannot be empty.
2635    // Using unwrap_or_else as defensive programming to avoid panic.
2636    let sum = terms
2637        .into_iter()
2638        .reduce(|a, b| a + b)
2639        .unwrap_or(RankExpr::Value(0.0));
2640    Ok(-sum)
2641}
2642
2643#[cfg(test)]
2644mod tests {
2645    use super::*;
2646
2647    #[test]
2648    fn test_key_from_string() {
2649        // Test predefined keys
2650        assert_eq!(Key::from("#document"), Key::Document);
2651        assert_eq!(Key::from("#embedding"), Key::Embedding);
2652        assert_eq!(Key::from("#metadata"), Key::Metadata);
2653        assert_eq!(Key::from("#score"), Key::Score);
2654
2655        // Test metadata field keys
2656        assert_eq!(
2657            Key::from("custom_field"),
2658            Key::MetadataField("custom_field".to_string())
2659        );
2660        assert_eq!(
2661            Key::from("author"),
2662            Key::MetadataField("author".to_string())
2663        );
2664
2665        // Test String variant
2666        assert_eq!(Key::from("#embedding".to_string()), Key::Embedding);
2667        assert_eq!(
2668            Key::from("year".to_string()),
2669            Key::MetadataField("year".to_string())
2670        );
2671    }
2672
2673    #[test]
2674    fn test_query_vector_dense_proto_conversion() {
2675        let dense_vec = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2676        let query_vector = QueryVector::Dense(dense_vec.clone());
2677
2678        // Convert to proto
2679        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2680
2681        // Convert back
2682        let converted: QueryVector = proto.try_into().unwrap();
2683
2684        assert_eq!(converted, query_vector);
2685        if let QueryVector::Dense(v) = converted {
2686            assert_eq!(v, dense_vec);
2687        } else {
2688            panic!("Expected dense vector");
2689        }
2690    }
2691
2692    #[test]
2693    fn test_query_vector_sparse_proto_conversion() {
2694        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]);
2695        let query_vector = QueryVector::Sparse(sparse.clone());
2696
2697        // Convert to proto
2698        let proto: chroma_proto::QueryVector = query_vector.clone().try_into().unwrap();
2699
2700        // Convert back
2701        let converted: QueryVector = proto.try_into().unwrap();
2702
2703        assert_eq!(converted, query_vector);
2704        if let QueryVector::Sparse(s) = converted {
2705            assert_eq!(s, sparse);
2706        } else {
2707            panic!("Expected sparse vector");
2708        }
2709    }
2710
2711    #[test]
2712    fn test_filter_json_deserialization() {
2713        // For the new search API, deserialization treats the entire JSON as a where clause
2714
2715        // Test 1: Simple direct metadata comparison
2716        let simple_where = r#"{"author": "John Doe"}"#;
2717        let filter: Filter = serde_json::from_str(simple_where).unwrap();
2718        assert_eq!(filter.query_ids, None);
2719        assert!(filter.where_clause.is_some());
2720
2721        // Test 2: ID filter using #id with $in operator
2722        let id_filter_json = serde_json::json!({
2723            "#id": {
2724                "$in": ["doc1", "doc2", "doc3"]
2725            }
2726        });
2727        let filter: Filter = serde_json::from_value(id_filter_json).unwrap();
2728        assert_eq!(filter.query_ids, None);
2729        assert!(filter.where_clause.is_some());
2730
2731        // Test 3: Complex nested expression with AND, OR, and various operators
2732        let complex_json = serde_json::json!({
2733            "$and": [
2734                {
2735                    "#id": {
2736                        "$in": ["doc1", "doc2", "doc3"]
2737                    }
2738                },
2739                {
2740                    "$or": [
2741                        {
2742                            "author": {
2743                                "$eq": "John Doe"
2744                            }
2745                        },
2746                        {
2747                            "author": {
2748                                "$eq": "Jane Smith"
2749                            }
2750                        }
2751                    ]
2752                },
2753                {
2754                    "year": {
2755                        "$gte": 2020
2756                    }
2757                },
2758                {
2759                    "tags": {
2760                        "$contains": "machine-learning"
2761                    }
2762                }
2763            ]
2764        });
2765
2766        let filter: Filter = serde_json::from_value(complex_json.clone()).unwrap();
2767        assert_eq!(filter.query_ids, None);
2768        assert!(filter.where_clause.is_some());
2769
2770        // Verify the structure
2771        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2772            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2773            assert_eq!(composite.children.len(), 4);
2774
2775            // Check that the second child is an OR
2776            if let crate::metadata::Where::Composite(or_composite) = &composite.children[1] {
2777                assert_eq!(or_composite.operator, crate::metadata::BooleanOperator::Or);
2778                assert_eq!(or_composite.children.len(), 2);
2779            } else {
2780                panic!("Expected OR composite in second child");
2781            }
2782        } else {
2783            panic!("Expected AND composite where clause");
2784        }
2785
2786        // Test 4: Mixed operators - $ne, $lt, $gt, $lte
2787        let mixed_operators_json = serde_json::json!({
2788            "$and": [
2789                {
2790                    "status": {
2791                        "$ne": "deleted"
2792                    }
2793                },
2794                {
2795                    "score": {
2796                        "$gt": 0.5
2797                    }
2798                },
2799                {
2800                    "score": {
2801                        "$lt": 0.9
2802                    }
2803                },
2804                {
2805                    "priority": {
2806                        "$lte": 10
2807                    }
2808                }
2809            ]
2810        });
2811
2812        let filter: Filter = serde_json::from_value(mixed_operators_json).unwrap();
2813        assert_eq!(filter.query_ids, None);
2814        assert!(filter.where_clause.is_some());
2815
2816        // Test 5: Deeply nested expression
2817        let deeply_nested_json = serde_json::json!({
2818            "$or": [
2819                {
2820                    "$and": [
2821                        {
2822                            "#id": {
2823                                "$in": ["id1", "id2"]
2824                            }
2825                        },
2826                        {
2827                            "$or": [
2828                                {
2829                                    "category": "tech"
2830                                },
2831                                {
2832                                    "category": "science"
2833                                }
2834                            ]
2835                        }
2836                    ]
2837                },
2838                {
2839                    "$and": [
2840                        {
2841                            "author": "Admin"
2842                        },
2843                        {
2844                            "published": true
2845                        }
2846                    ]
2847                }
2848            ]
2849        });
2850
2851        let filter: Filter = serde_json::from_value(deeply_nested_json).unwrap();
2852        assert_eq!(filter.query_ids, None);
2853        assert!(filter.where_clause.is_some());
2854
2855        // Verify it's an OR at the top level
2856        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2857            assert_eq!(composite.operator, crate::metadata::BooleanOperator::Or);
2858            assert_eq!(composite.children.len(), 2);
2859
2860            // Both children should be AND composites
2861            for child in &composite.children {
2862                if let crate::metadata::Where::Composite(and_composite) = child {
2863                    assert_eq!(
2864                        and_composite.operator,
2865                        crate::metadata::BooleanOperator::And
2866                    );
2867                } else {
2868                    panic!("Expected AND composite in OR children");
2869                }
2870            }
2871        } else {
2872            panic!("Expected OR composite at top level");
2873        }
2874
2875        // Test 6: Single ID filter (edge case)
2876        let single_id_json = serde_json::json!({
2877            "#id": {
2878                "$eq": "single-doc-id"
2879            }
2880        });
2881
2882        let filter: Filter = serde_json::from_value(single_id_json).unwrap();
2883        assert_eq!(filter.query_ids, None);
2884        assert!(filter.where_clause.is_some());
2885
2886        // Test 7: Empty object should create empty filter
2887        let empty_json = serde_json::json!({});
2888        let filter: Filter = serde_json::from_value(empty_json).unwrap();
2889        assert_eq!(filter.query_ids, None);
2890        // Empty object results in None where_clause
2891        assert_eq!(filter.where_clause, None);
2892
2893        // Test 8: Combining #id filter with $not_contains and numeric comparisons
2894        let advanced_json = serde_json::json!({
2895            "$and": [
2896                {
2897                    "#id": {
2898                        "$in": ["doc1", "doc2", "doc3", "doc4", "doc5"]
2899                    }
2900                },
2901                {
2902                    "tags": {
2903                        "$not_contains": "deprecated"
2904                    }
2905                },
2906                {
2907                    "$or": [
2908                        {
2909                            "$and": [
2910                                {
2911                                    "confidence": {
2912                                        "$gte": 0.8
2913                                    }
2914                                },
2915                                {
2916                                    "verified": true
2917                                }
2918                            ]
2919                        },
2920                        {
2921                            "$and": [
2922                                {
2923                                    "confidence": {
2924                                        "$gte": 0.6
2925                                    }
2926                                },
2927                                {
2928                                    "confidence": {
2929                                        "$lt": 0.8
2930                                    }
2931                                },
2932                                {
2933                                    "reviews": {
2934                                        "$gte": 5
2935                                    }
2936                                }
2937                            ]
2938                        }
2939                    ]
2940                }
2941            ]
2942        });
2943
2944        let filter: Filter = serde_json::from_value(advanced_json).unwrap();
2945        assert_eq!(filter.query_ids, None);
2946        assert!(filter.where_clause.is_some());
2947
2948        // Verify top-level structure
2949        if let crate::metadata::Where::Composite(composite) = filter.where_clause.unwrap() {
2950            assert_eq!(composite.operator, crate::metadata::BooleanOperator::And);
2951            assert_eq!(composite.children.len(), 3);
2952        } else {
2953            panic!("Expected AND composite at top level");
2954        }
2955    }
2956
2957    #[test]
2958    fn test_limit_json_serialization() {
2959        let limit = Limit {
2960            offset: 10,
2961            limit: Some(20),
2962        };
2963
2964        let json = serde_json::to_string(&limit).unwrap();
2965        let deserialized: Limit = serde_json::from_str(&json).unwrap();
2966
2967        assert_eq!(deserialized.offset, limit.offset);
2968        assert_eq!(deserialized.limit, limit.limit);
2969    }
2970
2971    #[test]
2972    fn test_query_vector_json_serialization() {
2973        // Test dense vector
2974        let dense = QueryVector::Dense(vec![0.1, 0.2, 0.3]);
2975        let json = serde_json::to_string(&dense).unwrap();
2976        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
2977        assert_eq!(deserialized, dense);
2978
2979        // Test sparse vector
2980        let sparse = QueryVector::Sparse(SparseVector::new(vec![0, 5, 10], vec![0.1, 0.5, 0.9]));
2981        let json = serde_json::to_string(&sparse).unwrap();
2982        let deserialized: QueryVector = serde_json::from_str(&json).unwrap();
2983        assert_eq!(deserialized, sparse);
2984    }
2985
2986    #[test]
2987    fn test_select_key_json_serialization() {
2988        use std::collections::HashSet;
2989
2990        // Test predefined keys
2991        let doc_key = Key::Document;
2992        assert_eq!(serde_json::to_string(&doc_key).unwrap(), "\"#document\"");
2993
2994        let embed_key = Key::Embedding;
2995        assert_eq!(serde_json::to_string(&embed_key).unwrap(), "\"#embedding\"");
2996
2997        let meta_key = Key::Metadata;
2998        assert_eq!(serde_json::to_string(&meta_key).unwrap(), "\"#metadata\"");
2999
3000        let score_key = Key::Score;
3001        assert_eq!(serde_json::to_string(&score_key).unwrap(), "\"#score\"");
3002
3003        // Test metadata key
3004        let custom_key = Key::MetadataField("custom_key".to_string());
3005        assert_eq!(
3006            serde_json::to_string(&custom_key).unwrap(),
3007            "\"custom_key\""
3008        );
3009
3010        // Test deserialization
3011        let deserialized: Key = serde_json::from_str("\"#document\"").unwrap();
3012        assert!(matches!(deserialized, Key::Document));
3013
3014        let deserialized: Key = serde_json::from_str("\"custom_field\"").unwrap();
3015        assert!(matches!(deserialized, Key::MetadataField(s) if s == "custom_field"));
3016
3017        // Test Select struct with multiple keys
3018        let mut keys = HashSet::new();
3019        keys.insert(Key::Document);
3020        keys.insert(Key::Embedding);
3021        keys.insert(Key::MetadataField("author".to_string()));
3022
3023        let select = Select { keys };
3024        let json = serde_json::to_string(&select).unwrap();
3025        let deserialized: Select = serde_json::from_str(&json).unwrap();
3026
3027        assert_eq!(deserialized.keys.len(), 3);
3028        assert!(deserialized.keys.contains(&Key::Document));
3029        assert!(deserialized.keys.contains(&Key::Embedding));
3030        assert!(deserialized
3031            .keys
3032            .contains(&Key::MetadataField("author".to_string())));
3033    }
3034
3035    #[test]
3036    fn test_merge_basic_integers() {
3037        use std::cmp::Reverse;
3038
3039        let merge = Merge { k: 5 };
3040
3041        // Input: sorted vectors of Reverse(u32) - ascending order of inner values
3042        let input = vec![
3043            vec![Reverse(1), Reverse(4), Reverse(7), Reverse(10)],
3044            vec![Reverse(2), Reverse(5), Reverse(8)],
3045            vec![Reverse(3), Reverse(6), Reverse(9), Reverse(11), Reverse(12)],
3046        ];
3047
3048        let result = merge.merge(input);
3049
3050        // Should get top-5 smallest values (largest Reverse values)
3051        assert_eq!(result.len(), 5);
3052        assert_eq!(
3053            result,
3054            vec![Reverse(1), Reverse(2), Reverse(3), Reverse(4), Reverse(5)]
3055        );
3056    }
3057
3058    #[test]
3059    fn test_merge_u32_descending() {
3060        let merge = Merge { k: 6 };
3061
3062        // Regular u32 in descending order (largest first)
3063        let input = vec![
3064            vec![100u32, 75, 50, 25],
3065            vec![90, 60, 30],
3066            vec![95, 85, 70, 40, 10],
3067        ];
3068
3069        let result = merge.merge(input);
3070
3071        // Should get top-6 largest u32 values
3072        assert_eq!(result.len(), 6);
3073        assert_eq!(result, vec![100, 95, 90, 85, 75, 70]);
3074    }
3075
3076    #[test]
3077    fn test_merge_i32_descending() {
3078        let merge = Merge { k: 5 };
3079
3080        // i32 values in descending order (including negatives)
3081        let input = vec![
3082            vec![50i32, 10, -10, -50],
3083            vec![30, 0, -30],
3084            vec![40, 20, -20, -40],
3085        ];
3086
3087        let result = merge.merge(input);
3088
3089        // Should get top-5 largest i32 values
3090        assert_eq!(result.len(), 5);
3091        assert_eq!(result, vec![50, 40, 30, 20, 10]);
3092    }
3093
3094    #[test]
3095    fn test_merge_with_duplicates() {
3096        let merge = Merge { k: 10 };
3097
3098        // Input with duplicates using regular u32 in descending order
3099        let input = vec![
3100            vec![100u32, 80, 80, 60, 40],
3101            vec![90, 80, 50, 30],
3102            vec![100, 70, 60, 20],
3103        ];
3104
3105        let result = merge.merge(input);
3106
3107        // Duplicates should be removed
3108        assert_eq!(result, vec![100, 90, 80, 70, 60, 50, 40, 30, 20]);
3109    }
3110
3111    #[test]
3112    fn test_merge_empty_vectors() {
3113        let merge = Merge { k: 5 };
3114
3115        // All empty with u32
3116        let input: Vec<Vec<u32>> = vec![vec![], vec![], vec![]];
3117        let result = merge.merge(input);
3118        assert_eq!(result.len(), 0);
3119
3120        // Some empty, some with data (u64)
3121        let input = vec![vec![], vec![1000u64, 750, 500], vec![], vec![850, 600]];
3122        let result = merge.merge(input);
3123        assert_eq!(result, vec![1000, 850, 750, 600, 500]);
3124
3125        // Single non-empty vector (i32)
3126        let input = vec![vec![], vec![100i32, 50, 25], vec![]];
3127        let result = merge.merge(input);
3128        assert_eq!(result, vec![100, 50, 25]);
3129    }
3130
3131    #[test]
3132    fn test_merge_k_boundary_conditions() {
3133        // k = 0 with u32
3134        let merge = Merge { k: 0 };
3135        let input = vec![vec![100u32, 50], vec![75, 25]];
3136        let result = merge.merge(input);
3137        assert_eq!(result.len(), 0);
3138
3139        // k = 1 with i64
3140        let merge = Merge { k: 1 };
3141        let input = vec![vec![1000i64, 500], vec![750, 250], vec![900, 100]];
3142        let result = merge.merge(input);
3143        assert_eq!(result, vec![1000]);
3144
3145        // k larger than total unique elements with u128
3146        let merge = Merge { k: 100 };
3147        let input = vec![vec![10000u128, 5000], vec![8000, 3000]];
3148        let result = merge.merge(input);
3149        assert_eq!(result, vec![10000, 8000, 5000, 3000]);
3150    }
3151
3152    #[test]
3153    fn test_merge_with_strings() {
3154        let merge = Merge { k: 4 };
3155
3156        // Strings must be sorted in descending order (largest first) for the max heap merge
3157        let input = vec![
3158            vec!["zebra".to_string(), "dog".to_string(), "apple".to_string()],
3159            vec!["elephant".to_string(), "banana".to_string()],
3160            vec!["fish".to_string(), "cat".to_string()],
3161        ];
3162
3163        let result = merge.merge(input);
3164
3165        // Should get top-4 lexicographically largest strings
3166        assert_eq!(
3167            result,
3168            vec![
3169                "zebra".to_string(),
3170                "fish".to_string(),
3171                "elephant".to_string(),
3172                "dog".to_string()
3173            ]
3174        );
3175    }
3176
3177    #[test]
3178    fn test_merge_with_custom_struct() {
3179        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
3180        struct Score {
3181            value: i32,
3182            id: String,
3183        }
3184
3185        let merge = Merge { k: 3 };
3186
3187        // Custom structs sorted by value (descending), then by id
3188        let input = vec![
3189            vec![
3190                Score {
3191                    value: 100,
3192                    id: "a".to_string(),
3193                },
3194                Score {
3195                    value: 80,
3196                    id: "b".to_string(),
3197                },
3198                Score {
3199                    value: 60,
3200                    id: "c".to_string(),
3201                },
3202            ],
3203            vec![
3204                Score {
3205                    value: 90,
3206                    id: "d".to_string(),
3207                },
3208                Score {
3209                    value: 70,
3210                    id: "e".to_string(),
3211                },
3212            ],
3213            vec![
3214                Score {
3215                    value: 95,
3216                    id: "f".to_string(),
3217                },
3218                Score {
3219                    value: 85,
3220                    id: "g".to_string(),
3221                },
3222            ],
3223        ];
3224
3225        let result = merge.merge(input);
3226
3227        assert_eq!(result.len(), 3);
3228        assert_eq!(
3229            result[0],
3230            Score {
3231                value: 100,
3232                id: "a".to_string()
3233            }
3234        );
3235        assert_eq!(
3236            result[1],
3237            Score {
3238                value: 95,
3239                id: "f".to_string()
3240            }
3241        );
3242        assert_eq!(
3243            result[2],
3244            Score {
3245                value: 90,
3246                id: "d".to_string()
3247            }
3248        );
3249    }
3250
3251    #[test]
3252    fn test_merge_preserves_order() {
3253        use std::cmp::Reverse;
3254
3255        let merge = Merge { k: 10 };
3256
3257        // For Reverse, smaller inner values are "larger" in ordering
3258        // So vectors should be sorted with smallest inner values first
3259        let input = vec![
3260            vec![Reverse(2), Reverse(6), Reverse(10), Reverse(14)],
3261            vec![Reverse(4), Reverse(8), Reverse(12), Reverse(16)],
3262            vec![Reverse(1), Reverse(3), Reverse(5), Reverse(7), Reverse(9)],
3263        ];
3264
3265        let result = merge.merge(input);
3266
3267        // Verify output maintains order - should be sorted by Reverse ordering
3268        // which means ascending inner values
3269        for i in 1..result.len() {
3270            assert!(
3271                result[i - 1] >= result[i],
3272                "Output should be in descending Reverse order"
3273            );
3274            assert!(
3275                result[i - 1].0 <= result[i].0,
3276                "Inner values should be in ascending order"
3277            );
3278        }
3279
3280        // Check we got the right elements
3281        assert_eq!(
3282            result,
3283            vec![
3284                Reverse(1),
3285                Reverse(2),
3286                Reverse(3),
3287                Reverse(4),
3288                Reverse(5),
3289                Reverse(6),
3290                Reverse(7),
3291                Reverse(8),
3292                Reverse(9),
3293                Reverse(10)
3294            ]
3295        );
3296    }
3297
3298    #[test]
3299    fn test_merge_single_vector() {
3300        let merge = Merge { k: 3 };
3301
3302        // Single vector input with u64
3303        let input = vec![vec![1000u64, 800, 600, 400, 200]];
3304
3305        let result = merge.merge(input);
3306
3307        assert_eq!(result, vec![1000, 800, 600]);
3308    }
3309
3310    #[test]
3311    fn test_merge_all_same_values() {
3312        let merge = Merge { k: 5 };
3313
3314        // All vectors contain the same value (using i16)
3315        let input = vec![vec![42i16, 42, 42], vec![42, 42], vec![42, 42, 42, 42]];
3316
3317        let result = merge.merge(input);
3318
3319        // Should deduplicate to single value
3320        assert_eq!(result, vec![42]);
3321    }
3322
3323    #[test]
3324    fn test_merge_mixed_types_sizes() {
3325        // Test with usize (common in real usage)
3326        let merge = Merge { k: 4 };
3327        let input = vec![
3328            vec![1000usize, 500, 100],
3329            vec![800, 300],
3330            vec![900, 600, 200],
3331        ];
3332        let result = merge.merge(input);
3333        assert_eq!(result, vec![1000, 900, 800, 600]);
3334
3335        // Test with negative integers (i32)
3336        let merge = Merge { k: 5 };
3337        let input = vec![vec![10i32, 0, -10, -20], vec![5, -5, -15], vec![15, -25]];
3338        let result = merge.merge(input);
3339        assert_eq!(result, vec![15, 10, 5, 0, -5]);
3340    }
3341}