chroma_types/execution/
plan.rs

1use super::{
2    error::QueryConversionError,
3    operator::{
4        Filter, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError, Select,
5    },
6};
7use crate::{
8    chroma_proto,
9    operator::{Key, RankExpr},
10    validators::validate_rank,
11    Where,
12};
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15#[cfg(feature = "utoipa")]
16use utoipa::{
17    openapi::{
18        schema::{Schema, SchemaType},
19        ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
20    },
21    PartialSchema,
22};
23use validator::Validate;
24
/// Error raised while converting an execution plan into its protobuf form.
///
/// In this file it is the error type of `TryFrom<Count> for chroma_proto::CountPlan`.
#[derive(Error, Debug)]
pub enum PlanToProtoError {
    /// The plan's scan operator could not be converted; wraps the underlying
    /// `ScanToProtoError` (also usable through `?` thanks to `#[from]`).
    #[error("Failed to convert scan to proto: {0}")]
    Scan(#[from] ScanToProtoError),
}
30
31/// The `Count` plan shoud ouutput the total number of records in the collection
32#[derive(Clone)]
33pub struct Count {
34    pub scan: Scan,
35}
36
37impl TryFrom<chroma_proto::CountPlan> for Count {
38    type Error = QueryConversionError;
39
40    fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
41        Ok(Self {
42            scan: value
43                .scan
44                .ok_or(QueryConversionError::field("scan"))?
45                .try_into()?,
46        })
47    }
48}
49
50impl TryFrom<Count> for chroma_proto::CountPlan {
51    type Error = PlanToProtoError;
52
53    fn try_from(value: Count) -> Result<Self, Self::Error> {
54        Ok(Self {
55            scan: Some(value.scan.try_into()?),
56        })
57    }
58}
59
/// The `Get` plan should output records matching the specified filter and limit in the collection
///
/// Converted to and from `chroma_proto::GetPlan`; every field is required when
/// decoding from the protobuf message.
#[derive(Clone, Debug)]
pub struct Get {
    /// Scan operator (protobuf field `scan`).
    pub scan: Scan,
    /// Filter operator (protobuf field `filter`).
    pub filter: Filter,
    /// Limit operator (protobuf field `limit`).
    pub limit: Limit,
    /// Projection operator; note the protobuf field is named `projection`.
    pub proj: Projection,
}
68
69impl TryFrom<chroma_proto::GetPlan> for Get {
70    type Error = QueryConversionError;
71
72    fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
73        Ok(Self {
74            scan: value
75                .scan
76                .ok_or(QueryConversionError::field("scan"))?
77                .try_into()?,
78            filter: value
79                .filter
80                .ok_or(QueryConversionError::field("filter"))?
81                .try_into()?,
82            limit: value
83                .limit
84                .ok_or(QueryConversionError::field("limit"))?
85                .into(),
86            proj: value
87                .projection
88                .ok_or(QueryConversionError::field("projection"))?
89                .into(),
90        })
91    }
92}
93
94impl TryFrom<Get> for chroma_proto::GetPlan {
95    type Error = QueryConversionError;
96
97    fn try_from(value: Get) -> Result<Self, Self::Error> {
98        Ok(Self {
99            scan: Some(value.scan.try_into()?),
100            filter: Some(value.filter.try_into()?),
101            limit: Some(value.limit.into()),
102            projection: Some(value.proj.into()),
103        })
104    }
105}
106
/// The `Knn` plan should output records nearest to the target embeddings that match the specified filter
///
/// Converted to and from `chroma_proto::KnnPlan`; every field is required when
/// decoding from the protobuf message.
#[derive(Clone, Debug)]
pub struct Knn {
    /// Scan operator (protobuf field `scan`).
    pub scan: Scan,
    /// Filter operator (protobuf field `filter`).
    pub filter: Filter,
    /// Batched KNN operator (protobuf field `knn`).
    pub knn: KnnBatch,
    /// KNN projection operator; note the protobuf field is named `projection`.
    pub proj: KnnProjection,
}
115
116impl TryFrom<chroma_proto::KnnPlan> for Knn {
117    type Error = QueryConversionError;
118
119    fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
120        Ok(Self {
121            scan: value
122                .scan
123                .ok_or(QueryConversionError::field("scan"))?
124                .try_into()?,
125            filter: value
126                .filter
127                .ok_or(QueryConversionError::field("filter"))?
128                .try_into()?,
129            knn: value
130                .knn
131                .ok_or(QueryConversionError::field("knn"))?
132                .try_into()?,
133            proj: value
134                .projection
135                .ok_or(QueryConversionError::field("projection"))?
136                .try_into()?,
137        })
138    }
139}
140
141impl TryFrom<Knn> for chroma_proto::KnnPlan {
142    type Error = QueryConversionError;
143
144    fn try_from(value: Knn) -> Result<Self, Self::Error> {
145        Ok(Self {
146            scan: Some(value.scan.try_into()?),
147            filter: Some(value.filter.try_into()?),
148            knn: Some(value.knn.try_into()?),
149            projection: Some(value.proj.into()),
150        })
151    }
152}
153
/// A search payload for the hybrid search API.
///
/// Combines filtering, ranking, pagination, and field selection into a single query.
/// Use the builder methods to construct complex searches with a fluent interface.
///
/// Every field is marked `#[serde(default)]`, so a field omitted from the
/// serialized payload deserializes to its `Default` value.
///
/// # Examples
///
/// ## Basic vector search
///
/// ```
/// use chroma_types::plan::SearchPayload;
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
///
/// let search = SearchPayload::default()
///     .rank(RankExpr::Knn {
///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///         key: Key::Embedding,
///         limit: 100,
///         default: None,
///         return_rank: false,
///     })
///     .limit(Some(10), 0)
///     .select([Key::Document, Key::Score]);
/// ```
///
/// ## Filtered search
///
/// ```
/// use chroma_types::plan::SearchPayload;
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
///
/// let search = SearchPayload::default()
///     .r#where(
///         Key::field("status").eq("published")
///             & Key::field("year").gte(2020)
///     )
///     .rank(RankExpr::Knn {
///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///         key: Key::Embedding,
///         limit: 200,
///         default: None,
///         return_rank: false,
///     })
///     .limit(Some(5), 0)
///     .select([Key::Document, Key::Score, Key::field("title")]);
/// ```
///
/// ## Hybrid search with custom ranking
///
/// ```
/// use chroma_types::plan::SearchPayload;
/// use chroma_types::operator::{RankExpr, QueryVector, Key};
///
/// let dense = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::Embedding,
///     limit: 200,
///     default: None,
///     return_rank: false,
/// };
///
/// let sparse = RankExpr::Knn {
///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
///     key: Key::field("sparse_embedding"),
///     limit: 200,
///     default: None,
///     return_rank: false,
/// };
///
/// let search = SearchPayload::default()
///     .rank(dense * 0.7 + sparse * 0.3)
///     .limit(Some(10), 0)
///     .select([Key::Document, Key::Score]);
/// ```
#[derive(Clone, Debug, Default, Deserialize, Serialize, Validate)]
pub struct SearchPayload {
    /// Filter operator used to narrow the results; set through [`SearchPayload::r#where`].
    #[serde(default)]
    pub filter: Filter,
    /// Ranking operator used to score and order results; set through
    /// [`SearchPayload::rank`] and checked by `validate_rank` during validation.
    #[serde(default)]
    #[validate(custom(function = "validate_rank"))]
    pub rank: Rank,
    /// Pagination (limit and offset); set through [`SearchPayload::limit`].
    #[serde(default)]
    pub limit: Limit,
    /// Keys to include in the results; set through [`SearchPayload::select`].
    #[serde(default)]
    pub select: Select,
}
240
241impl SearchPayload {
242    /// Sets pagination parameters.
243    ///
244    /// # Arguments
245    ///
246    /// * `limit` - Maximum number of results to return (None = no limit)
247    /// * `offset` - Number of results to skip
248    ///
249    /// # Examples
250    ///
251    /// ```
252    /// use chroma_types::plan::SearchPayload;
253    ///
254    /// // First page: results 0-9
255    /// let search = SearchPayload::default().limit(Some(10), 0);
256    ///
257    /// // Second page: results 10-19
258    /// let search = SearchPayload::default().limit(Some(10), 10);
259    /// ```
260    pub fn limit(mut self, limit: Option<u32>, offset: u32) -> Self {
261        self.limit.limit = limit;
262        self.limit.offset = offset;
263        self
264    }
265
266    /// Sets the ranking expression for scoring and ordering results.
267    ///
268    /// # Arguments
269    ///
270    /// * `expr` - A ranking expression (typically Knn or a combination of expressions)
271    ///
272    /// # Examples
273    ///
274    /// ## Simple KNN ranking
275    ///
276    /// ```
277    /// use chroma_types::plan::SearchPayload;
278    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
279    ///
280    /// let search = SearchPayload::default()
281    ///     .rank(RankExpr::Knn {
282    ///         query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
283    ///         key: Key::Embedding,
284    ///         limit: 100,
285    ///         default: None,
286    ///         return_rank: false,
287    ///     });
288    /// ```
289    ///
290    /// ## Weighted combination
291    ///
292    /// ```
293    /// use chroma_types::plan::SearchPayload;
294    /// use chroma_types::operator::{RankExpr, QueryVector, Key};
295    ///
296    /// let knn1 = RankExpr::Knn {
297    ///     query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
298    ///     key: Key::Embedding,
299    ///     limit: 100,
300    ///     default: None,
301    ///     return_rank: false,
302    /// };
303    ///
304    /// let knn2 = RankExpr::Knn {
305    ///     query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
306    ///     key: Key::field("other_embedding"),
307    ///     limit: 100,
308    ///     default: None,
309    ///     return_rank: false,
310    /// };
311    ///
312    /// let search = SearchPayload::default()
313    ///     .rank(knn1 * 0.8 + knn2 * 0.2);
314    /// ```
315    pub fn rank(mut self, expr: RankExpr) -> Self {
316        self.rank.expr = Some(expr);
317        self
318    }
319
320    /// Selects which fields to include in the results.
321    ///
322    /// # Arguments
323    ///
324    /// * `keys` - Fields to include (e.g., Document, Score, Metadata, or custom fields)
325    ///
326    /// # Examples
327    ///
328    /// ```
329    /// use chroma_types::plan::SearchPayload;
330    /// use chroma_types::operator::Key;
331    ///
332    /// // Select predefined fields
333    /// let search = SearchPayload::default()
334    ///     .select([Key::Document, Key::Score]);
335    ///
336    /// // Select metadata fields
337    /// let search = SearchPayload::default()
338    ///     .select([Key::field("title"), Key::field("author")]);
339    ///
340    /// // Mix predefined and custom fields
341    /// let search = SearchPayload::default()
342    ///     .select([Key::Document, Key::Score, Key::field("title")]);
343    /// ```
344    pub fn select<I, T>(mut self, keys: I) -> Self
345    where
346        I: IntoIterator<Item = T>,
347        T: Into<Key>,
348    {
349        self.select.keys = keys.into_iter().map(Into::into).collect();
350        self
351    }
352
353    /// Sets the filter expression for narrowing results.
354    ///
355    /// # Arguments
356    ///
357    /// * `where` - A Where expression for filtering
358    ///
359    /// # Examples
360    ///
361    /// ## Simple equality filter
362    ///
363    /// ```
364    /// use chroma_types::plan::SearchPayload;
365    /// use chroma_types::operator::Key;
366    ///
367    /// let search = SearchPayload::default()
368    ///     .r#where(Key::field("status").eq("published"));
369    /// ```
370    ///
371    /// ## Numeric comparisons
372    ///
373    /// ```
374    /// use chroma_types::plan::SearchPayload;
375    /// use chroma_types::operator::Key;
376    ///
377    /// let search = SearchPayload::default()
378    ///     .r#where(Key::field("year").gte(2020));
379    /// ```
380    ///
381    /// ## Combining filters
382    ///
383    /// ```
384    /// use chroma_types::plan::SearchPayload;
385    /// use chroma_types::operator::Key;
386    ///
387    /// let search = SearchPayload::default()
388    ///     .r#where(
389    ///         Key::field("status").eq("published")
390    ///             & Key::field("year").gte(2020)
391    ///             & Key::field("category").is_in(vec!["tech", "science"])
392    ///     );
393    /// ```
394    ///
395    /// ## Document content filtering
396    ///
397    /// ```
398    /// use chroma_types::plan::SearchPayload;
399    /// use chroma_types::operator::Key;
400    ///
401    /// let search = SearchPayload::default()
402    ///     .r#where(Key::Document.contains("machine learning"));
403    /// ```
404    pub fn r#where(mut self, r#where: Where) -> Self {
405        self.filter.where_clause = Some(r#where);
406        self
407    }
408}
409
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    /// Builds the hand-written OpenAPI schema for a search payload.
    ///
    /// The schema is an object with four properties, added in this order:
    /// `filter` (`query_ids`: array of strings, `where_clause`: object),
    /// `rank` (object), `limit` (`offset` and `limit`: integers) and
    /// `select` (`keys`: array of strings).
    fn schema() -> RefOr<Schema> {
        // Shorthand for a bare schema object of the given type.
        let typed = |ty: Type| Object::with_type(SchemaType::Type(ty));
        // Shorthand for an array whose items are strings.
        let string_list = || ArrayBuilder::new().items(typed(Type::String));

        let filter = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("query_ids", string_list())
            .property("where_clause", typed(Type::Object));

        let limit = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("offset", typed(Type::Integer))
            .property("limit", typed(Type::Integer));

        let select = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("keys", string_list());

        let payload = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("filter", filter)
            .property("rank", typed(Type::Object))
            .property("limit", limit)
            .property("select", select)
            .build();

        RefOr::T(Schema::Object(payload))
    }
}
452
// Empty impl: relies entirely on the default items of `utoipa::ToSchema`.
// Only compiled when the `utoipa` feature is enabled.
#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {}
455
456impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
457    type Error = QueryConversionError;
458
459    fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
460        Ok(Self {
461            filter: value
462                .filter
463                .ok_or(QueryConversionError::field("filter"))?
464                .try_into()?,
465            rank: value
466                .rank
467                .ok_or(QueryConversionError::field("rank"))?
468                .try_into()?,
469            limit: value
470                .limit
471                .ok_or(QueryConversionError::field("limit"))?
472                .into(),
473            select: value
474                .select
475                .ok_or(QueryConversionError::field("select"))?
476                .try_into()?,
477        })
478    }
479}
480
481impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
482    type Error = QueryConversionError;
483
484    fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
485        Ok(Self {
486            filter: Some(value.filter.try_into()?),
487            rank: Some(value.rank.try_into()?),
488            limit: Some(value.limit.into()),
489            select: Some(value.select.try_into()?),
490        })
491    }
492}
493
/// The `Search` plan pairs a scan with a list of search payloads.
///
/// Converted to and from `chroma_proto::SearchPlan`.
#[derive(Clone, Debug)]
pub struct Search {
    /// Scan operator (protobuf field `scan`, required when decoding).
    pub scan: Scan,
    /// Search payloads carried by the plan (protobuf field `payloads`); may be empty.
    pub payloads: Vec<SearchPayload>,
}
499
500impl TryFrom<chroma_proto::SearchPlan> for Search {
501    type Error = QueryConversionError;
502
503    fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
504        Ok(Self {
505            scan: value
506                .scan
507                .ok_or(QueryConversionError::field("scan"))?
508                .try_into()?,
509            payloads: value
510                .payloads
511                .into_iter()
512                .map(TryInto::try_into)
513                .collect::<Result<Vec<_>, _>>()?,
514        })
515    }
516}
517
518impl TryFrom<Search> for chroma_proto::SearchPlan {
519    type Error = QueryConversionError;
520
521    fn try_from(value: Search) -> Result<Self, Self::Error> {
522        Ok(Self {
523            scan: Some(value.scan.try_into()?),
524            payloads: value
525                .payloads
526                .into_iter()
527                .map(TryInto::try_into)
528                .collect::<Result<Vec<_>, _>>()?,
529        })
530    }
531}