chroma_types/execution/
plan.rs

1use super::{
2    error::QueryConversionError,
3    operator::{
4        Filter, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError, Select,
5    },
6};
7use crate::{
8    chroma_proto,
9    operator::{Key, RankExpr},
10    validators::validate_rank,
11    Where,
12};
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15#[cfg(feature = "utoipa")]
16use utoipa::{
17    openapi::{
18        schema::{Schema, SchemaType},
19        ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
20    },
21    PartialSchema,
22};
23use validator::Validate;
24
25#[derive(Error, Debug)]
26pub enum PlanToProtoError {
27    #[error("Failed to convert scan to proto: {0}")]
28    Scan(#[from] ScanToProtoError),
29}
30
31/// The `Count` plan shoud ouutput the total number of records in the collection
32#[derive(Clone)]
33pub struct Count {
34    pub scan: Scan,
35}
36
37impl TryFrom<chroma_proto::CountPlan> for Count {
38    type Error = QueryConversionError;
39
40    fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
41        Ok(Self {
42            scan: value
43                .scan
44                .ok_or(QueryConversionError::field("scan"))?
45                .try_into()?,
46        })
47    }
48}
49
50impl TryFrom<Count> for chroma_proto::CountPlan {
51    type Error = PlanToProtoError;
52
53    fn try_from(value: Count) -> Result<Self, Self::Error> {
54        Ok(Self {
55            scan: Some(value.scan.try_into()?),
56        })
57    }
58}
59
60/// The `Get` plan should output records matching the specified filter and limit in the collection
61#[derive(Clone, Debug)]
62pub struct Get {
63    pub scan: Scan,
64    pub filter: Filter,
65    pub limit: Limit,
66    pub proj: Projection,
67}
68
69impl TryFrom<chroma_proto::GetPlan> for Get {
70    type Error = QueryConversionError;
71
72    fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
73        Ok(Self {
74            scan: value
75                .scan
76                .ok_or(QueryConversionError::field("scan"))?
77                .try_into()?,
78            filter: value
79                .filter
80                .ok_or(QueryConversionError::field("filter"))?
81                .try_into()?,
82            limit: value
83                .limit
84                .ok_or(QueryConversionError::field("limit"))?
85                .into(),
86            proj: value
87                .projection
88                .ok_or(QueryConversionError::field("projection"))?
89                .into(),
90        })
91    }
92}
93
94impl TryFrom<Get> for chroma_proto::GetPlan {
95    type Error = QueryConversionError;
96
97    fn try_from(value: Get) -> Result<Self, Self::Error> {
98        Ok(Self {
99            scan: Some(value.scan.try_into()?),
100            filter: Some(value.filter.try_into()?),
101            limit: Some(value.limit.into()),
102            projection: Some(value.proj.into()),
103        })
104    }
105}
106
107/// The `Knn` plan should output records nearest to the target embeddings that matches the specified filter
108#[derive(Clone, Debug)]
109pub struct Knn {
110    pub scan: Scan,
111    pub filter: Filter,
112    pub knn: KnnBatch,
113    pub proj: KnnProjection,
114}
115
116impl TryFrom<chroma_proto::KnnPlan> for Knn {
117    type Error = QueryConversionError;
118
119    fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
120        Ok(Self {
121            scan: value
122                .scan
123                .ok_or(QueryConversionError::field("scan"))?
124                .try_into()?,
125            filter: value
126                .filter
127                .ok_or(QueryConversionError::field("filter"))?
128                .try_into()?,
129            knn: value
130                .knn
131                .ok_or(QueryConversionError::field("knn"))?
132                .try_into()?,
133            proj: value
134                .projection
135                .ok_or(QueryConversionError::field("projection"))?
136                .try_into()?,
137        })
138    }
139}
140
141impl TryFrom<Knn> for chroma_proto::KnnPlan {
142    type Error = QueryConversionError;
143
144    fn try_from(value: Knn) -> Result<Self, Self::Error> {
145        Ok(Self {
146            scan: Some(value.scan.try_into()?),
147            filter: Some(value.filter.try_into()?),
148            knn: Some(value.knn.try_into()?),
149            projection: Some(value.proj.into()),
150        })
151    }
152}
153
154#[derive(Clone, Debug, Default, Deserialize, Serialize, Validate)]
155pub struct SearchPayload {
156    #[serde(default)]
157    pub filter: Filter,
158    #[serde(default)]
159    #[validate(custom(function = "validate_rank"))]
160    pub rank: Rank,
161    #[serde(default)]
162    pub limit: Limit,
163    #[serde(default)]
164    pub select: Select,
165}
166
167impl SearchPayload {
168    pub fn limit(mut self, limit: Option<u32>, offset: u32) -> Self {
169        self.limit.limit = limit;
170        self.limit.offset = offset;
171        self
172    }
173    pub fn rank(mut self, expr: RankExpr) -> Self {
174        self.rank.expr = Some(expr);
175        self
176    }
177    pub fn select<I, T>(mut self, keys: I) -> Self
178    where
179        I: IntoIterator<Item = T>,
180        T: Into<Key>,
181    {
182        self.select.keys = keys.into_iter().map(Into::into).collect();
183        self
184    }
185    pub fn r#where(mut self, r#where: Where) -> Self {
186        self.filter.where_clause = Some(r#where);
187        self
188    }
189}
190
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    /// Hand-written OpenAPI schema for [`SearchPayload`].
    ///
    /// Describes the four top-level properties `filter`, `rank`, `limit` and `select`.
    /// `rank` and `filter.where_clause` are advertised only as untyped objects; their
    /// inner structure is not described here.
    fn schema() -> RefOr<Schema> {
        // Leaf schemas for the primitive JSON types used below.
        let string_leaf = || Object::with_type(SchemaType::Type(Type::String));
        let integer_leaf = || Object::with_type(SchemaType::Type(Type::Integer));
        let object_leaf = || Object::with_type(SchemaType::Type(Type::Object));

        // `filter`: `query_ids` (array of strings) and `where_clause` (untyped object).
        let filter_schema = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("query_ids", ArrayBuilder::new().items(string_leaf()))
            .property("where_clause", object_leaf());

        // `limit`: integer `offset` and `limit`.
        let limit_schema = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("offset", integer_leaf())
            .property("limit", integer_leaf());

        // `select`: `keys` (array of strings).
        let select_schema = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("keys", ArrayBuilder::new().items(string_leaf()));

        let payload_schema = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("filter", filter_schema)
            .property("rank", object_leaf())
            .property("limit", limit_schema)
            .property("select", select_schema)
            .build();

        RefOr::T(Schema::Object(payload_schema))
    }
}
233
/// Registers [`SearchPayload`] as a named OpenAPI component.
///
/// The schema body comes from the hand-written `PartialSchema` impl; trait defaults are
/// used for everything else.
#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {}
236
237impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
238    type Error = QueryConversionError;
239
240    fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
241        Ok(Self {
242            filter: value
243                .filter
244                .ok_or(QueryConversionError::field("filter"))?
245                .try_into()?,
246            rank: value
247                .rank
248                .ok_or(QueryConversionError::field("rank"))?
249                .try_into()?,
250            limit: value
251                .limit
252                .ok_or(QueryConversionError::field("limit"))?
253                .into(),
254            select: value
255                .select
256                .ok_or(QueryConversionError::field("select"))?
257                .try_into()?,
258        })
259    }
260}
261
262impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
263    type Error = QueryConversionError;
264
265    fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
266        Ok(Self {
267            filter: Some(value.filter.try_into()?),
268            rank: Some(value.rank.try_into()?),
269            limit: Some(value.limit.into()),
270            select: Some(value.select.try_into()?),
271        })
272    }
273}
274
275#[derive(Clone, Debug)]
276pub struct Search {
277    pub scan: Scan,
278    pub payloads: Vec<SearchPayload>,
279}
280
281impl TryFrom<chroma_proto::SearchPlan> for Search {
282    type Error = QueryConversionError;
283
284    fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
285        Ok(Self {
286            scan: value
287                .scan
288                .ok_or(QueryConversionError::field("scan"))?
289                .try_into()?,
290            payloads: value
291                .payloads
292                .into_iter()
293                .map(TryInto::try_into)
294                .collect::<Result<Vec<_>, _>>()?,
295        })
296    }
297}
298
299impl TryFrom<Search> for chroma_proto::SearchPlan {
300    type Error = QueryConversionError;
301
302    fn try_from(value: Search) -> Result<Self, Self::Error> {
303        Ok(Self {
304            scan: Some(value.scan.try_into()?),
305            payloads: value
306                .payloads
307                .into_iter()
308                .map(TryInto::try_into)
309                .collect::<Result<Vec<_>, _>>()?,
310        })
311    }
312}