chroma_types/execution/
plan.rs

1use super::{
2    error::QueryConversionError,
3    operator::{
4        Filter, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError, Select,
5    },
6};
7use crate::{chroma_proto, validators::validate_rank};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10#[cfg(feature = "utoipa")]
11use utoipa::{
12    openapi::{
13        schema::{Schema, SchemaType},
14        ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
15    },
16    PartialSchema,
17};
18use validator::Validate;
19
20#[derive(Error, Debug)]
21pub enum PlanToProtoError {
22    #[error("Failed to convert scan to proto: {0}")]
23    Scan(#[from] ScanToProtoError),
24}
25
26/// The `Count` plan shoud ouutput the total number of records in the collection
27#[derive(Clone)]
28pub struct Count {
29    pub scan: Scan,
30}
31
32impl TryFrom<chroma_proto::CountPlan> for Count {
33    type Error = QueryConversionError;
34
35    fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
36        Ok(Self {
37            scan: value
38                .scan
39                .ok_or(QueryConversionError::field("scan"))?
40                .try_into()?,
41        })
42    }
43}
44
45impl TryFrom<Count> for chroma_proto::CountPlan {
46    type Error = PlanToProtoError;
47
48    fn try_from(value: Count) -> Result<Self, Self::Error> {
49        Ok(Self {
50            scan: Some(value.scan.try_into()?),
51        })
52    }
53}
54
55/// The `Get` plan should output records matching the specified filter and limit in the collection
56#[derive(Clone, Debug)]
57pub struct Get {
58    pub scan: Scan,
59    pub filter: Filter,
60    pub limit: Limit,
61    pub proj: Projection,
62}
63
64impl TryFrom<chroma_proto::GetPlan> for Get {
65    type Error = QueryConversionError;
66
67    fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
68        Ok(Self {
69            scan: value
70                .scan
71                .ok_or(QueryConversionError::field("scan"))?
72                .try_into()?,
73            filter: value
74                .filter
75                .ok_or(QueryConversionError::field("filter"))?
76                .try_into()?,
77            limit: value
78                .limit
79                .ok_or(QueryConversionError::field("limit"))?
80                .into(),
81            proj: value
82                .projection
83                .ok_or(QueryConversionError::field("projection"))?
84                .into(),
85        })
86    }
87}
88
89impl TryFrom<Get> for chroma_proto::GetPlan {
90    type Error = QueryConversionError;
91
92    fn try_from(value: Get) -> Result<Self, Self::Error> {
93        Ok(Self {
94            scan: Some(value.scan.try_into()?),
95            filter: Some(value.filter.try_into()?),
96            limit: Some(value.limit.into()),
97            projection: Some(value.proj.into()),
98        })
99    }
100}
101
102/// The `Knn` plan should output records nearest to the target embeddings that matches the specified filter
103#[derive(Clone, Debug)]
104pub struct Knn {
105    pub scan: Scan,
106    pub filter: Filter,
107    pub knn: KnnBatch,
108    pub proj: KnnProjection,
109}
110
111impl TryFrom<chroma_proto::KnnPlan> for Knn {
112    type Error = QueryConversionError;
113
114    fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
115        Ok(Self {
116            scan: value
117                .scan
118                .ok_or(QueryConversionError::field("scan"))?
119                .try_into()?,
120            filter: value
121                .filter
122                .ok_or(QueryConversionError::field("filter"))?
123                .try_into()?,
124            knn: value
125                .knn
126                .ok_or(QueryConversionError::field("knn"))?
127                .try_into()?,
128            proj: value
129                .projection
130                .ok_or(QueryConversionError::field("projection"))?
131                .try_into()?,
132        })
133    }
134}
135
136impl TryFrom<Knn> for chroma_proto::KnnPlan {
137    type Error = QueryConversionError;
138
139    fn try_from(value: Knn) -> Result<Self, Self::Error> {
140        Ok(Self {
141            scan: Some(value.scan.try_into()?),
142            filter: Some(value.filter.try_into()?),
143            knn: Some(value.knn.try_into()?),
144            projection: Some(value.proj.into()),
145        })
146    }
147}
148
149#[derive(Clone, Debug, Deserialize, Serialize, Validate)]
150pub struct SearchPayload {
151    #[serde(default)]
152    pub filter: Filter,
153    #[serde(default)]
154    #[validate(custom(function = "validate_rank"))]
155    pub rank: Rank,
156    #[serde(default)]
157    pub limit: Limit,
158    #[serde(default)]
159    pub select: Select,
160}
161
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    /// Hand-written OpenAPI schema describing the JSON shape of a [`SearchPayload`].
    fn schema() -> RefOr<Schema> {
        // Leaf schema of a single primitive type.
        fn leaf(kind: Type) -> Object {
            Object::with_type(SchemaType::Type(kind))
        }

        // Array whose items are strings.
        fn string_array() -> ArrayBuilder {
            ArrayBuilder::new().items(leaf(Type::String))
        }

        // Object builder already tagged with `type: object`.
        fn object() -> ObjectBuilder {
            ObjectBuilder::new().schema_type(SchemaType::Type(Type::Object))
        }

        // Properties are inserted in the same order as the JSON fields.
        let filter = object()
            .property("query_ids", string_array())
            .property("where_clause", leaf(Type::Object));
        let limit = object()
            .property("offset", leaf(Type::Integer))
            .property("limit", leaf(Type::Integer));
        let select = object().property("keys", string_array());

        let payload = object()
            .property("filter", filter)
            .property("rank", leaf(Type::Object))
            .property("limit", limit)
            .property("select", select)
            .build();

        RefOr::T(Schema::Object(payload))
    }
}

#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {}
207
208impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
209    type Error = QueryConversionError;
210
211    fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
212        Ok(Self {
213            filter: value
214                .filter
215                .ok_or(QueryConversionError::field("filter"))?
216                .try_into()?,
217            rank: value
218                .rank
219                .ok_or(QueryConversionError::field("rank"))?
220                .try_into()?,
221            limit: value
222                .limit
223                .ok_or(QueryConversionError::field("limit"))?
224                .into(),
225            select: value
226                .select
227                .ok_or(QueryConversionError::field("select"))?
228                .try_into()?,
229        })
230    }
231}
232
233impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
234    type Error = QueryConversionError;
235
236    fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
237        Ok(Self {
238            filter: Some(value.filter.try_into()?),
239            rank: Some(value.rank.try_into()?),
240            limit: Some(value.limit.into()),
241            select: Some(value.select.try_into()?),
242        })
243    }
244}
245
246#[derive(Clone, Debug)]
247pub struct Search {
248    pub scan: Scan,
249    pub payloads: Vec<SearchPayload>,
250}
251
252impl TryFrom<chroma_proto::SearchPlan> for Search {
253    type Error = QueryConversionError;
254
255    fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
256        Ok(Self {
257            scan: value
258                .scan
259                .ok_or(QueryConversionError::field("scan"))?
260                .try_into()?,
261            payloads: value
262                .payloads
263                .into_iter()
264                .map(TryInto::try_into)
265                .collect::<Result<Vec<_>, _>>()?,
266        })
267    }
268}
269
270impl TryFrom<Search> for chroma_proto::SearchPlan {
271    type Error = QueryConversionError;
272
273    fn try_from(value: Search) -> Result<Self, Self::Error> {
274        Ok(Self {
275            scan: Some(value.scan.try_into()?),
276            payloads: value
277                .payloads
278                .into_iter()
279                .map(TryInto::try_into)
280                .collect::<Result<Vec<_>, _>>()?,
281        })
282    }
283}