chroma_types/execution/
plan.rs

1use super::{
2    error::QueryConversionError,
3    operator::{
4        Filter, GroupBy, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError,
5        Select,
6    },
7};
8use crate::{
9    chroma_proto,
10    operator::{Key, RankExpr},
11    validators::{validate_group_by, validate_rank, validate_search_payload},
12    Where,
13};
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16#[cfg(feature = "utoipa")]
17use utoipa::{
18    openapi::{
19        schema::{Schema, SchemaType},
20        ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
21    },
22    PartialSchema,
23};
24use validator::Validate;
25
26#[derive(Error, Debug)]
27pub enum PlanToProtoError {
28    #[error("Failed to convert scan to proto: {0}")]
29    Scan(#[from] ScanToProtoError),
30}
31
32/// The `Count` plan shoud ouutput the total number of records in the collection
33#[derive(Clone)]
34pub struct Count {
35    pub scan: Scan,
36}
37
38impl TryFrom<chroma_proto::CountPlan> for Count {
39    type Error = QueryConversionError;
40
41    fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
42        Ok(Self {
43            scan: value
44                .scan
45                .ok_or(QueryConversionError::field("scan"))?
46                .try_into()?,
47        })
48    }
49}
50
51impl TryFrom<Count> for chroma_proto::CountPlan {
52    type Error = PlanToProtoError;
53
54    fn try_from(value: Count) -> Result<Self, Self::Error> {
55        Ok(Self {
56            scan: Some(value.scan.try_into()?),
57        })
58    }
59}
60
61/// The `Get` plan should output records matching the specified filter and limit in the collection
62#[derive(Clone, Debug)]
63pub struct Get {
64    pub scan: Scan,
65    pub filter: Filter,
66    pub limit: Limit,
67    pub proj: Projection,
68}
69
70impl TryFrom<chroma_proto::GetPlan> for Get {
71    type Error = QueryConversionError;
72
73    fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
74        Ok(Self {
75            scan: value
76                .scan
77                .ok_or(QueryConversionError::field("scan"))?
78                .try_into()?,
79            filter: value
80                .filter
81                .ok_or(QueryConversionError::field("filter"))?
82                .try_into()?,
83            limit: value
84                .limit
85                .ok_or(QueryConversionError::field("limit"))?
86                .into(),
87            proj: value
88                .projection
89                .ok_or(QueryConversionError::field("projection"))?
90                .into(),
91        })
92    }
93}
94
95impl TryFrom<Get> for chroma_proto::GetPlan {
96    type Error = QueryConversionError;
97
98    fn try_from(value: Get) -> Result<Self, Self::Error> {
99        Ok(Self {
100            scan: Some(value.scan.try_into()?),
101            filter: Some(value.filter.try_into()?),
102            limit: Some(value.limit.into()),
103            projection: Some(value.proj.into()),
104        })
105    }
106}
107
108/// The `Knn` plan should output records nearest to the target embeddings that matches the specified filter
109#[derive(Clone, Debug)]
110pub struct Knn {
111    pub scan: Scan,
112    pub filter: Filter,
113    pub knn: KnnBatch,
114    pub proj: KnnProjection,
115}
116
117impl TryFrom<chroma_proto::KnnPlan> for Knn {
118    type Error = QueryConversionError;
119
120    fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
121        Ok(Self {
122            scan: value
123                .scan
124                .ok_or(QueryConversionError::field("scan"))?
125                .try_into()?,
126            filter: value
127                .filter
128                .ok_or(QueryConversionError::field("filter"))?
129                .try_into()?,
130            knn: value
131                .knn
132                .ok_or(QueryConversionError::field("knn"))?
133                .try_into()?,
134            proj: value
135                .projection
136                .ok_or(QueryConversionError::field("projection"))?
137                .try_into()?,
138        })
139    }
140}
141
142impl TryFrom<Knn> for chroma_proto::KnnPlan {
143    type Error = QueryConversionError;
144
145    fn try_from(value: Knn) -> Result<Self, Self::Error> {
146        Ok(Self {
147            scan: Some(value.scan.try_into()?),
148            filter: Some(value.filter.try_into()?),
149            knn: Some(value.knn.try_into()?),
150            projection: Some(value.proj.into()),
151        })
152    }
153}
154
155/// A search payload for the hybrid search API.
156///
157/// Combines filtering, ranking, pagination, and field selection into a single query.
158/// Use the builder methods to construct complex searches with a fluent interface.
159///
160/// # Examples
161///
162/// ## Basic vector search
163///
164/// ```
165/// use chroma_types::plan::SearchPayload;
166/// use chroma_types::operator::{RankExpr, QueryVector, Key};
167///
168/// let search = SearchPayload::default()
169///     .rank(RankExpr::Knn {
170///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
171///         key: Key::Embedding,
172///         limit: 100,
173///         default: None,
174///         return_rank: false,
175///     })
176///     .limit(Some(10), 0)
177///     .select([Key::Document, Key::Score]);
178/// ```
179///
180/// ## Filtered search
181///
182/// ```
183/// use chroma_types::plan::SearchPayload;
184/// use chroma_types::operator::{RankExpr, QueryVector, Key};
185///
186/// let search = SearchPayload::default()
187///     .r#where(
188///         Key::field("status").eq("published")
189///             & Key::field("year").gte(2020)
190///     )
191///     .rank(RankExpr::Knn {
192///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
193///         key: Key::Embedding,
194///         limit: 200,
195///         default: None,
196///         return_rank: false,
197///     })
198///     .limit(Some(5), 0)
199///     .select([Key::Document, Key::Score, Key::field("title")]);
200/// ```
201///
202/// ## Hybrid search with custom ranking
203///
204/// ```
205/// use chroma_types::plan::SearchPayload;
206/// use chroma_types::operator::{RankExpr, QueryVector, Key};
207///
208/// let dense = RankExpr::Knn {
209///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
210///     key: Key::Embedding,
211///     limit: 200,
212///     default: None,
213///     return_rank: false,
214/// };
215///
216/// let sparse = RankExpr::Knn {
217///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
218///     key: Key::field("sparse_embedding"),
219///     limit: 200,
220///     default: None,
221///     return_rank: false,
222/// };
223///
224/// let search = SearchPayload::default()
225///     .rank(dense * 0.7 + sparse * 0.3)
226///     .limit(Some(10), 0)
227///     .select([Key::Document, Key::Score]);
228/// ```
229#[derive(Clone, Debug, Default, Deserialize, Serialize, Validate)]
230#[validate(schema(function = "validate_search_payload"))]
231pub struct SearchPayload {
232    #[serde(default)]
233    pub filter: Filter,
234    #[serde(default)]
235    #[validate(custom(function = "validate_rank"))]
236    pub rank: Rank,
237    #[serde(default)]
238    #[validate(custom(function = "validate_group_by"))]
239    pub group_by: GroupBy,
240    #[serde(default)]
241    pub limit: Limit,
242    #[serde(default)]
243    pub select: Select,
244}
245
246impl SearchPayload {
247    /// Sets pagination parameters.
248    ///
249    /// # Arguments
250    ///
251    /// * `limit` - Maximum number of results to return (None = no limit)
252    /// * `offset` - Number of results to skip
253    ///
254    /// # Examples
255    ///
256    /// ```
257    /// use chroma_types::plan::SearchPayload;
258    ///
259    /// // First page: results 0-9
260    /// let search = SearchPayload::default().limit(Some(10), 0);
261    ///
262    /// // Second page: results 10-19
263    /// let search = SearchPayload::default().limit(Some(10), 10);
264    /// ```
265    pub fn limit(mut self, limit: Option<u32>, offset: u32) -> Self {
266        self.limit.limit = limit;
267        self.limit.offset = offset;
268        self
269    }
270
271    /// Sets the ranking expression for scoring and ordering results.
272    ///
273    /// # Arguments
274    ///
275    /// * `expr` - A ranking expression (typically Knn or a combination of expressions)
276    ///
277    /// # Examples
278    ///
279    /// ## Simple KNN ranking
280    ///
281    /// ```
282    /// use chroma_types::plan::SearchPayload;
283    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
284    ///
285    /// let search = SearchPayload::default()
286    ///     .rank(RankExpr::Knn {
287    ///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
288    ///         key: Key::Embedding,
289    ///         limit: 100,
290    ///         default: None,
291    ///         return_rank: false,
292    ///     });
293    /// ```
294    ///
295    /// ## Weighted combination
296    ///
297    /// ```
298    /// use chroma_types::plan::SearchPayload;
299    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
300    ///
301    /// let knn1 = RankExpr::Knn {
302    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
303    ///     key: Key::Embedding,
304    ///     limit: 100,
305    ///     default: None,
306    ///     return_rank: false,
307    /// };
308    ///
309    /// let knn2 = RankExpr::Knn {
310    ///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
311    ///     key: Key::field("other_embedding"),
312    ///     limit: 100,
313    ///     default: None,
314    ///     return_rank: false,
315    /// };
316    ///
317    /// let search = SearchPayload::default()
318    ///     .rank(knn1 * 0.8 + knn2 * 0.2);
319    /// ```
320    pub fn rank(mut self, expr: RankExpr) -> Self {
321        self.rank.expr = Some(expr);
322        self
323    }
324
325    /// Selects which fields to include in the results.
326    ///
327    /// # Arguments
328    ///
329    /// * `keys` - Fields to include (e.g., Document, Score, Metadata, or custom fields)
330    ///
331    /// # Examples
332    ///
333    /// ```
334    /// use chroma_types::plan::SearchPayload;
335    /// use chroma_types::operator::Key;
336    ///
337    /// // Select predefined fields
338    /// let search = SearchPayload::default()
339    ///     .select([Key::Document, Key::Score]);
340    ///
341    /// // Select metadata fields
342    /// let search = SearchPayload::default()
343    ///     .select([Key::field("title"), Key::field("author")]);
344    ///
345    /// // Mix predefined and custom fields
346    /// let search = SearchPayload::default()
347    ///     .select([Key::Document, Key::Score, Key::field("title")]);
348    /// ```
349    pub fn select<I, T>(mut self, keys: I) -> Self
350    where
351        I: IntoIterator<Item = T>,
352        T: Into<Key>,
353    {
354        self.select.keys = keys.into_iter().map(Into::into).collect();
355        self
356    }
357
358    /// Sets the filter expression for narrowing results.
359    ///
360    /// # Arguments
361    ///
362    /// * `where` - A Where expression for filtering
363    ///
364    /// # Examples
365    ///
366    /// ## Simple equality filter
367    ///
368    /// ```
369    /// use chroma_types::plan::SearchPayload;
370    /// use chroma_types::operator::Key;
371    ///
372    /// let search = SearchPayload::default()
373    ///     .r#where(Key::field("status").eq("published"));
374    /// ```
375    ///
376    /// ## Numeric comparisons
377    ///
378    /// ```
379    /// use chroma_types::plan::SearchPayload;
380    /// use chroma_types::operator::Key;
381    ///
382    /// let search = SearchPayload::default()
383    ///     .r#where(Key::field("year").gte(2020));
384    /// ```
385    ///
386    /// ## Combining filters
387    ///
388    /// ```
389    /// use chroma_types::plan::SearchPayload;
390    /// use chroma_types::operator::Key;
391    ///
392    /// let search = SearchPayload::default()
393    ///     .r#where(
394    ///         Key::field("status").eq("published")
395    ///             & Key::field("year").gte(2020)
396    ///             & Key::field("category").is_in(vec!["tech", "science"])
397    ///     );
398    /// ```
399    ///
400    /// ## Document content filtering
401    ///
402    /// ```
403    /// use chroma_types::plan::SearchPayload;
404    /// use chroma_types::operator::Key;
405    ///
406    /// let search = SearchPayload::default()
407    ///     .r#where(Key::Document.contains("machine learning"));
408    /// ```
409    pub fn r#where(mut self, r#where: Where) -> Self {
410        self.filter.where_clause = Some(r#where);
411        self
412    }
413
414    /// Groups results by metadata keys and aggregates within each group.
415    ///
416    /// # Arguments
417    ///
418    /// * `group_by` - GroupBy configuration with keys and aggregation
419    ///
420    /// # Examples
421    ///
422    /// ```
423    /// use chroma_types::plan::SearchPayload;
424    /// use chroma_types::operator::{GroupBy, Aggregate, Key};
425    ///
426    /// // Top 3 best documents per category
427    /// let search = SearchPayload::default()
428    ///     .group_by(GroupBy {
429    ///         keys: vec![Key::field("category")],
430    ///         aggregate: Some(Aggregate::MinK {
431    ///             keys: vec![Key::Score],
432    ///             k: 3,
433    ///         }),
434    ///     });
435    /// ```
436    pub fn group_by(mut self, group_by: GroupBy) -> Self {
437        self.group_by = group_by;
438        self
439    }
440}
441
/// Hand-written OpenAPI schema for [`SearchPayload`].
///
/// The schema is an object with the properties `filter`, `rank`, `group_by`,
/// `limit` and `select`; nested expression types (`where_clause`, `rank`,
/// `aggregate`) are described only as opaque objects.
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    fn schema() -> RefOr<Schema> {
        // Factories for the leaf schemas that are reused several times below.
        let opaque_object = || Object::with_type(SchemaType::Type(Type::Object));
        let integer = || Object::with_type(SchemaType::Type(Type::Integer));
        let string_list =
            || ArrayBuilder::new().items(Object::with_type(SchemaType::Type(Type::String)));
        let object_builder = || ObjectBuilder::new().schema_type(SchemaType::Type(Type::Object));

        let filter = object_builder()
            .property("query_ids", string_list())
            .property("where_clause", opaque_object());

        let group_by = object_builder()
            .property("keys", string_list())
            .property("aggregate", opaque_object());

        let limit = object_builder()
            .property("offset", integer())
            .property("limit", integer());

        let select = object_builder().property("keys", string_list());

        // Properties are registered in the same order as the struct's fields.
        let payload = object_builder()
            .property("filter", filter)
            .property("rank", opaque_object())
            .property("group_by", group_by)
            .property("limit", limit)
            .property("select", select)
            .build();

        RefOr::T(Schema::Object(payload))
    }
}
498
// Empty impl: only the trait's provided items are used; the schema body itself
// comes from the `PartialSchema` impl for `SearchPayload`.
#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {}
501
502impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
503    type Error = QueryConversionError;
504
505    fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
506        Ok(Self {
507            filter: value
508                .filter
509                .ok_or(QueryConversionError::field("filter"))?
510                .try_into()?,
511            rank: value
512                .rank
513                .ok_or(QueryConversionError::field("rank"))?
514                .try_into()?,
515            group_by: value
516                .group_by
517                .map(TryInto::try_into)
518                .transpose()?
519                .unwrap_or_default(),
520            limit: value
521                .limit
522                .ok_or(QueryConversionError::field("limit"))?
523                .into(),
524            select: value
525                .select
526                .ok_or(QueryConversionError::field("select"))?
527                .try_into()?,
528        })
529    }
530}
531
532impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
533    type Error = QueryConversionError;
534
535    fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
536        Ok(Self {
537            filter: Some(value.filter.try_into()?),
538            rank: Some(value.rank.try_into()?),
539            group_by: Some(value.group_by.try_into()?),
540            limit: Some(value.limit.into()),
541            select: Some(value.select.try_into()?),
542        })
543    }
544}
545
546#[derive(Clone, Debug)]
547pub struct Search {
548    pub scan: Scan,
549    pub payloads: Vec<SearchPayload>,
550}
551
552impl TryFrom<chroma_proto::SearchPlan> for Search {
553    type Error = QueryConversionError;
554
555    fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
556        Ok(Self {
557            scan: value
558                .scan
559                .ok_or(QueryConversionError::field("scan"))?
560                .try_into()?,
561            payloads: value
562                .payloads
563                .into_iter()
564                .map(TryInto::try_into)
565                .collect::<Result<Vec<_>, _>>()?,
566        })
567    }
568}
569
570impl TryFrom<Search> for chroma_proto::SearchPlan {
571    type Error = QueryConversionError;
572
573    fn try_from(value: Search) -> Result<Self, Self::Error> {
574        Ok(Self {
575            scan: Some(value.scan.try_into()?),
576            payloads: value
577                .payloads
578                .into_iter()
579                .map(TryInto::try_into)
580                .collect::<Result<Vec<_>, _>>()?,
581        })
582    }
583}