Skip to main content

chroma_types/execution/
plan.rs

1use super::{
2    error::QueryConversionError,
3    operator::{
4        Filter, GroupBy, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError,
5        Select,
6    },
7};
8use crate::{
9    chroma_proto,
10    operator::{Key, RankExpr},
11    validators::{validate_group_by, validate_rank, validate_search_payload},
12    Where,
13};
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16#[cfg(feature = "utoipa")]
17use utoipa::{
18    openapi::{
19        schema::{Schema, SchemaType},
20        ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
21    },
22    PartialSchema,
23};
24use validator::Validate;
25
26#[derive(Error, Debug)]
27pub enum PlanToProtoError {
28    #[error("Failed to convert scan to proto: {0}")]
29    Scan(#[from] ScanToProtoError),
30}
31
32/// The `Count` plan shoud ouutput the total number of records in the collection
33#[derive(Clone)]
34pub struct Count {
35    pub scan: Scan,
36    pub read_level: ReadLevel,
37}
38
39impl TryFrom<chroma_proto::CountPlan> for Count {
40    type Error = QueryConversionError;
41
42    fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
43        let read_level = value.read_level().into();
44        Ok(Self {
45            scan: value
46                .scan
47                .ok_or(QueryConversionError::field("scan"))?
48                .try_into()?,
49            read_level,
50        })
51    }
52}
53
54impl TryFrom<Count> for chroma_proto::CountPlan {
55    type Error = PlanToProtoError;
56
57    fn try_from(value: Count) -> Result<Self, Self::Error> {
58        Ok(Self {
59            scan: Some(value.scan.try_into()?),
60            read_level: chroma_proto::ReadLevel::from(value.read_level).into(),
61        })
62    }
63}
64
65/// The `Get` plan should output records matching the specified filter and limit in the collection
66#[derive(Clone, Debug)]
67pub struct Get {
68    pub scan: Scan,
69    pub filter: Filter,
70    pub limit: Limit,
71    pub proj: Projection,
72}
73
74impl TryFrom<chroma_proto::GetPlan> for Get {
75    type Error = QueryConversionError;
76
77    fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
78        Ok(Self {
79            scan: value
80                .scan
81                .ok_or(QueryConversionError::field("scan"))?
82                .try_into()?,
83            filter: value
84                .filter
85                .ok_or(QueryConversionError::field("filter"))?
86                .try_into()?,
87            limit: value
88                .limit
89                .ok_or(QueryConversionError::field("limit"))?
90                .into(),
91            proj: value
92                .projection
93                .ok_or(QueryConversionError::field("projection"))?
94                .into(),
95        })
96    }
97}
98
99impl TryFrom<Get> for chroma_proto::GetPlan {
100    type Error = QueryConversionError;
101
102    fn try_from(value: Get) -> Result<Self, Self::Error> {
103        Ok(Self {
104            scan: Some(value.scan.try_into()?),
105            filter: Some(value.filter.try_into()?),
106            limit: Some(value.limit.into()),
107            projection: Some(value.proj.into()),
108        })
109    }
110}
111
112/// The `Knn` plan should output records nearest to the target embeddings that matches the specified filter
113#[derive(Clone, Debug)]
114pub struct Knn {
115    pub scan: Scan,
116    pub filter: Filter,
117    pub knn: KnnBatch,
118    pub proj: KnnProjection,
119}
120
121impl TryFrom<chroma_proto::KnnPlan> for Knn {
122    type Error = QueryConversionError;
123
124    fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
125        Ok(Self {
126            scan: value
127                .scan
128                .ok_or(QueryConversionError::field("scan"))?
129                .try_into()?,
130            filter: value
131                .filter
132                .ok_or(QueryConversionError::field("filter"))?
133                .try_into()?,
134            knn: value
135                .knn
136                .ok_or(QueryConversionError::field("knn"))?
137                .try_into()?,
138            proj: value
139                .projection
140                .ok_or(QueryConversionError::field("projection"))?
141                .try_into()?,
142        })
143    }
144}
145
146impl TryFrom<Knn> for chroma_proto::KnnPlan {
147    type Error = QueryConversionError;
148
149    fn try_from(value: Knn) -> Result<Self, Self::Error> {
150        Ok(Self {
151            scan: Some(value.scan.try_into()?),
152            filter: Some(value.filter.try_into()?),
153            knn: Some(value.knn.try_into()?),
154            projection: Some(value.proj.into()),
155        })
156    }
157}
158
159/// A search payload for the hybrid search API.
160///
161/// Combines filtering, ranking, pagination, and field selection into a single query.
162/// Use the builder methods to construct complex searches with a fluent interface.
163///
164/// # Examples
165///
166/// ## Basic vector search
167///
168/// ```
169/// use chroma_types::plan::SearchPayload;
170/// use chroma_types::operator::{RankExpr, QueryVector, Key};
171///
172/// let search = SearchPayload::default()
173///     .rank(RankExpr::Knn {
174///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
175///         key: Key::Embedding,
176///         limit: 100,
177///         default: None,
178///         return_rank: false,
179///     })
180///     .limit(Some(10), 0)
181///     .select([Key::Document, Key::Score]);
182/// ```
183///
184/// ## Filtered search
185///
186/// ```
187/// use chroma_types::plan::SearchPayload;
188/// use chroma_types::operator::{RankExpr, QueryVector, Key};
189///
190/// let search = SearchPayload::default()
191///     .r#where(
192///         Key::field("status").eq("published")
193///             & Key::field("year").gte(2020)
194///     )
195///     .rank(RankExpr::Knn {
196///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
197///         key: Key::Embedding,
198///         limit: 200,
199///         default: None,
200///         return_rank: false,
201///     })
202///     .limit(Some(5), 0)
203///     .select([Key::Document, Key::Score, Key::field("title")]);
204/// ```
205///
206/// ## Hybrid search with custom ranking
207///
208/// ```
209/// use chroma_types::plan::SearchPayload;
210/// use chroma_types::operator::{RankExpr, QueryVector, Key};
211///
212/// let dense = RankExpr::Knn {
213///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
214///     key: Key::Embedding,
215///     limit: 200,
216///     default: None,
217///     return_rank: false,
218/// };
219///
220/// let sparse = RankExpr::Knn {
221///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
222///     key: Key::field("sparse_embedding"),
223///     limit: 200,
224///     default: None,
225///     return_rank: false,
226/// };
227///
228/// let search = SearchPayload::default()
229///     .rank(dense * 0.7 + sparse * 0.3)
230///     .limit(Some(10), 0)
231///     .select([Key::Document, Key::Score]);
232/// ```
233#[derive(Clone, Debug, Default, Deserialize, Serialize, Validate)]
234#[validate(schema(function = "validate_search_payload"))]
235pub struct SearchPayload {
236    #[serde(default)]
237    pub filter: Filter,
238    #[serde(default)]
239    #[validate(custom(function = "validate_rank"))]
240    pub rank: Rank,
241    #[serde(default)]
242    #[validate(custom(function = "validate_group_by"))]
243    pub group_by: GroupBy,
244    #[serde(default)]
245    pub limit: Limit,
246    #[serde(default)]
247    pub select: Select,
248}
249
250impl SearchPayload {
251    /// Sets pagination parameters.
252    ///
253    /// # Arguments
254    ///
255    /// * `limit` - Maximum number of results to return (None = no limit)
256    /// * `offset` - Number of results to skip
257    ///
258    /// # Examples
259    ///
260    /// ```
261    /// use chroma_types::plan::SearchPayload;
262    ///
263    /// // First page: results 0-9
264    /// let search = SearchPayload::default().limit(Some(10), 0);
265    ///
266    /// // Second page: results 10-19
267    /// let search = SearchPayload::default().limit(Some(10), 10);
268    /// ```
269    pub fn limit(mut self, limit: Option<u32>, offset: u32) -> Self {
270        self.limit.limit = limit;
271        self.limit.offset = offset;
272        self
273    }
274
275    /// Sets the ranking expression for scoring and ordering results.
276    ///
277    /// # Arguments
278    ///
279    /// * `expr` - A ranking expression (typically Knn or a combination of expressions)
280    ///
281    /// # Examples
282    ///
283    /// ## Simple KNN ranking
284    ///
285    /// ```
286    /// use chroma_types::plan::SearchPayload;
287    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
288    ///
289    /// let search = SearchPayload::default()
290    ///     .rank(RankExpr::Knn {
291    ///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
292    ///         key: Key::Embedding,
293    ///         limit: 100,
294    ///         default: None,
295    ///         return_rank: false,
296    ///     });
297    /// ```
298    ///
299    /// ## Weighted combination
300    ///
301    /// ```
302    /// use chroma_types::plan::SearchPayload;
303    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
304    ///
305    /// let knn1 = RankExpr::Knn {
306    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
307    ///     key: Key::Embedding,
308    ///     limit: 100,
309    ///     default: None,
310    ///     return_rank: false,
311    /// };
312    ///
313    /// let knn2 = RankExpr::Knn {
314    ///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
315    ///     key: Key::field("other_embedding"),
316    ///     limit: 100,
317    ///     default: None,
318    ///     return_rank: false,
319    /// };
320    ///
321    /// let search = SearchPayload::default()
322    ///     .rank(knn1 * 0.8 + knn2 * 0.2);
323    /// ```
324    pub fn rank(mut self, expr: RankExpr) -> Self {
325        self.rank.expr = Some(expr);
326        self
327    }
328
329    /// Selects which fields to include in the results.
330    ///
331    /// # Arguments
332    ///
333    /// * `keys` - Fields to include (e.g., Document, Score, Metadata, or custom fields)
334    ///
335    /// # Examples
336    ///
337    /// ```
338    /// use chroma_types::plan::SearchPayload;
339    /// use chroma_types::operator::Key;
340    ///
341    /// // Select predefined fields
342    /// let search = SearchPayload::default()
343    ///     .select([Key::Document, Key::Score]);
344    ///
345    /// // Select metadata fields
346    /// let search = SearchPayload::default()
347    ///     .select([Key::field("title"), Key::field("author")]);
348    ///
349    /// // Mix predefined and custom fields
350    /// let search = SearchPayload::default()
351    ///     .select([Key::Document, Key::Score, Key::field("title")]);
352    /// ```
353    pub fn select<I, T>(mut self, keys: I) -> Self
354    where
355        I: IntoIterator<Item = T>,
356        T: Into<Key>,
357    {
358        self.select.keys = keys.into_iter().map(Into::into).collect();
359        self
360    }
361
362    /// Sets the filter expression for narrowing results.
363    ///
364    /// # Arguments
365    ///
366    /// * `where` - A Where expression for filtering
367    ///
368    /// # Examples
369    ///
370    /// ## Simple equality filter
371    ///
372    /// ```
373    /// use chroma_types::plan::SearchPayload;
374    /// use chroma_types::operator::Key;
375    ///
376    /// let search = SearchPayload::default()
377    ///     .r#where(Key::field("status").eq("published"));
378    /// ```
379    ///
380    /// ## Numeric comparisons
381    ///
382    /// ```
383    /// use chroma_types::plan::SearchPayload;
384    /// use chroma_types::operator::Key;
385    ///
386    /// let search = SearchPayload::default()
387    ///     .r#where(Key::field("year").gte(2020));
388    /// ```
389    ///
390    /// ## Combining filters
391    ///
392    /// ```
393    /// use chroma_types::plan::SearchPayload;
394    /// use chroma_types::operator::Key;
395    ///
396    /// let search = SearchPayload::default()
397    ///     .r#where(
398    ///         Key::field("status").eq("published")
399    ///             & Key::field("year").gte(2020)
400    ///             & Key::field("category").is_in(vec!["tech", "science"])
401    ///     );
402    /// ```
403    ///
404    /// ## Document content filtering
405    ///
406    /// ```
407    /// use chroma_types::plan::SearchPayload;
408    /// use chroma_types::operator::Key;
409    ///
410    /// let search = SearchPayload::default()
411    ///     .r#where(Key::Document.contains("machine learning"));
412    /// ```
413    pub fn r#where(mut self, r#where: Where) -> Self {
414        self.filter.where_clause = Some(r#where);
415        self
416    }
417
418    /// Groups results by metadata keys and aggregates within each group.
419    ///
420    /// # Arguments
421    ///
422    /// * `group_by` - GroupBy configuration with keys and aggregation
423    ///
424    /// # Examples
425    ///
426    /// ```
427    /// use chroma_types::plan::SearchPayload;
428    /// use chroma_types::operator::{GroupBy, Aggregate, Key};
429    ///
430    /// // Top 3 best documents per category
431    /// let search = SearchPayload::default()
432    ///     .group_by(GroupBy {
433    ///         keys: vec![Key::field("category")],
434    ///         aggregate: Some(Aggregate::MinK {
435    ///             keys: vec![Key::Score],
436    ///             k: 3,
437    ///         }),
438    ///     });
439    /// ```
440    pub fn group_by(mut self, group_by: GroupBy) -> Self {
441        self.group_by = group_by;
442        self
443    }
444}
445
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    /// Hand-written OpenAPI schema for [`SearchPayload`].
    ///
    /// The top-level object exposes `filter`, `rank`, `group_by`, `limit` and
    /// `select`. `rank`, `filter.where_clause` and `group_by.aggregate` are
    /// described only as untyped objects.
    fn schema() -> RefOr<Schema> {
        // Constructors for the leaf schemas that repeat below.
        let typed = |ty: Type| Object::with_type(SchemaType::Type(ty));
        let object = || ObjectBuilder::new().schema_type(SchemaType::Type(Type::Object));
        let string_array =
            || ArrayBuilder::new().items(Object::with_type(SchemaType::Type(Type::String)));

        let filter = object()
            .property("query_ids", string_array())
            .property("where_clause", typed(Type::Object));
        let group_by = object()
            .property("keys", string_array())
            .property("aggregate", typed(Type::Object));
        let limit = object()
            .property("offset", typed(Type::Integer))
            .property("limit", typed(Type::Integer));
        let select = object().property("keys", string_array());

        // Properties are added in the same order as the struct's fields.
        let payload = object()
            .property("filter", filter)
            .property("rank", typed(Type::Object))
            .property("group_by", group_by)
            .property("limit", limit)
            .property("select", select)
            .build();

        RefOr::T(Schema::Object(payload))
    }
}

// `ToSchema` relies on the `PartialSchema` impl above; no overrides are needed.
#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {}
505
506impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
507    type Error = QueryConversionError;
508
509    fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
510        Ok(Self {
511            filter: value
512                .filter
513                .ok_or(QueryConversionError::field("filter"))?
514                .try_into()?,
515            rank: value
516                .rank
517                .ok_or(QueryConversionError::field("rank"))?
518                .try_into()?,
519            group_by: value
520                .group_by
521                .map(TryInto::try_into)
522                .transpose()?
523                .unwrap_or_default(),
524            limit: value
525                .limit
526                .ok_or(QueryConversionError::field("limit"))?
527                .into(),
528            select: value
529                .select
530                .ok_or(QueryConversionError::field("select"))?
531                .try_into()?,
532        })
533    }
534}
535
536impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
537    type Error = QueryConversionError;
538
539    fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
540        Ok(Self {
541            filter: Some(value.filter.try_into()?),
542            rank: Some(value.rank.try_into()?),
543            group_by: Some(value.group_by.try_into()?),
544            limit: Some(value.limit.into()),
545            select: Some(value.select.try_into()?),
546        })
547    }
548}
549
550#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
551#[serde(rename_all = "snake_case")]
552#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
553pub enum ReadLevel {
554    /// Read from both the index and the write-ahead log (default).
555    /// Provides full consistency with all committed writes visible.
556    #[default]
557    IndexAndWal,
558    /// Read only from the index, skipping the write-ahead log.
559    /// Provides eventual consistency - recent uncommitted writes may not be visible.
560    IndexOnly,
561}
562
563impl From<chroma_proto::ReadLevel> for ReadLevel {
564    fn from(value: chroma_proto::ReadLevel) -> Self {
565        match value {
566            chroma_proto::ReadLevel::IndexAndWal => ReadLevel::IndexAndWal,
567            chroma_proto::ReadLevel::IndexOnly => ReadLevel::IndexOnly,
568        }
569    }
570}
571
572impl From<ReadLevel> for chroma_proto::ReadLevel {
573    fn from(value: ReadLevel) -> Self {
574        match value {
575            ReadLevel::IndexAndWal => chroma_proto::ReadLevel::IndexAndWal,
576            ReadLevel::IndexOnly => chroma_proto::ReadLevel::IndexOnly,
577        }
578    }
579}
580
581#[derive(Clone, Debug)]
582pub struct Search {
583    pub scan: Scan,
584    pub payloads: Vec<SearchPayload>,
585    pub read_level: ReadLevel,
586}
587
588impl TryFrom<chroma_proto::SearchPlan> for Search {
589    type Error = QueryConversionError;
590
591    fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
592        let read_level = value.read_level().into();
593        Ok(Self {
594            scan: value
595                .scan
596                .ok_or(QueryConversionError::field("scan"))?
597                .try_into()?,
598            payloads: value
599                .payloads
600                .into_iter()
601                .map(TryInto::try_into)
602                .collect::<Result<Vec<_>, _>>()?,
603            read_level,
604        })
605    }
606}
607
608impl TryFrom<Search> for chroma_proto::SearchPlan {
609    type Error = QueryConversionError;
610
611    fn try_from(value: Search) -> Result<Self, Self::Error> {
612        Ok(Self {
613            scan: Some(value.scan.try_into()?),
614            payloads: value
615                .payloads
616                .into_iter()
617                .map(TryInto::try_into)
618                .collect::<Result<Vec<_>, _>>()?,
619            read_level: chroma_proto::ReadLevel::from(value.read_level).into(),
620        })
621    }
622}