1use super::{
2 error::QueryConversionError,
3 operator::{
4 Filter, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError, Select,
5 },
6};
7use crate::{
8 chroma_proto,
9 operator::{Key, RankExpr},
10 validators::validate_rank,
11 Where,
12};
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15#[cfg(feature = "utoipa")]
16use utoipa::{
17 openapi::{
18 schema::{Schema, SchemaType},
19 ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
20 },
21 PartialSchema,
22};
23use validator::Validate;
24
/// Error returned when a plan cannot be converted into its protobuf form.
#[derive(Error, Debug)]
pub enum PlanToProtoError {
    /// The plan's scan operator failed to convert; wraps the underlying
    /// [`ScanToProtoError`] (also usable via `?` thanks to `#[from]`).
    #[error("Failed to convert scan to proto: {0}")]
    Scan(#[from] ScanToProtoError),
}
30
31#[derive(Clone)]
33pub struct Count {
34 pub scan: Scan,
35}
36
37impl TryFrom<chroma_proto::CountPlan> for Count {
38 type Error = QueryConversionError;
39
40 fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
41 Ok(Self {
42 scan: value
43 .scan
44 .ok_or(QueryConversionError::field("scan"))?
45 .try_into()?,
46 })
47 }
48}
49
50impl TryFrom<Count> for chroma_proto::CountPlan {
51 type Error = PlanToProtoError;
52
53 fn try_from(value: Count) -> Result<Self, Self::Error> {
54 Ok(Self {
55 scan: Some(value.scan.try_into()?),
56 })
57 }
58}
59
/// Plan composed of scan, filter, limit and projection operators.
///
/// Converts to and from `chroma_proto::GetPlan`.
#[derive(Clone, Debug)]
pub struct Get {
    /// Scan operator.
    pub scan: Scan,
    /// Filter operator.
    pub filter: Filter,
    /// Limit operator.
    pub limit: Limit,
    /// Projection operator (stored in the proto message as `projection`).
    pub proj: Projection,
}
68
69impl TryFrom<chroma_proto::GetPlan> for Get {
70 type Error = QueryConversionError;
71
72 fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
73 Ok(Self {
74 scan: value
75 .scan
76 .ok_or(QueryConversionError::field("scan"))?
77 .try_into()?,
78 filter: value
79 .filter
80 .ok_or(QueryConversionError::field("filter"))?
81 .try_into()?,
82 limit: value
83 .limit
84 .ok_or(QueryConversionError::field("limit"))?
85 .into(),
86 proj: value
87 .projection
88 .ok_or(QueryConversionError::field("projection"))?
89 .into(),
90 })
91 }
92}
93
94impl TryFrom<Get> for chroma_proto::GetPlan {
95 type Error = QueryConversionError;
96
97 fn try_from(value: Get) -> Result<Self, Self::Error> {
98 Ok(Self {
99 scan: Some(value.scan.try_into()?),
100 filter: Some(value.filter.try_into()?),
101 limit: Some(value.limit.into()),
102 projection: Some(value.proj.into()),
103 })
104 }
105}
106
/// Plan composed of scan, filter, batched KNN and KNN projection operators.
///
/// Converts to and from `chroma_proto::KnnPlan`.
#[derive(Clone, Debug)]
pub struct Knn {
    /// Scan operator.
    pub scan: Scan,
    /// Filter operator.
    pub filter: Filter,
    /// Batched KNN operator.
    pub knn: KnnBatch,
    /// KNN projection operator (stored in the proto message as `projection`).
    pub proj: KnnProjection,
}
115
116impl TryFrom<chroma_proto::KnnPlan> for Knn {
117 type Error = QueryConversionError;
118
119 fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
120 Ok(Self {
121 scan: value
122 .scan
123 .ok_or(QueryConversionError::field("scan"))?
124 .try_into()?,
125 filter: value
126 .filter
127 .ok_or(QueryConversionError::field("filter"))?
128 .try_into()?,
129 knn: value
130 .knn
131 .ok_or(QueryConversionError::field("knn"))?
132 .try_into()?,
133 proj: value
134 .projection
135 .ok_or(QueryConversionError::field("projection"))?
136 .try_into()?,
137 })
138 }
139}
140
141impl TryFrom<Knn> for chroma_proto::KnnPlan {
142 type Error = QueryConversionError;
143
144 fn try_from(value: Knn) -> Result<Self, Self::Error> {
145 Ok(Self {
146 scan: Some(value.scan.try_into()?),
147 filter: Some(value.filter.try_into()?),
148 knn: Some(value.knn.try_into()?),
149 projection: Some(value.proj.into()),
150 })
151 }
152}
153
/// One search request: filter, rank, limit and select operators.
///
/// Every field is marked `#[serde(default)]`, so a field omitted from the
/// serialized input falls back to that operator's `Default` value.
#[derive(Clone, Debug, Default, Deserialize, Serialize, Validate)]
pub struct SearchPayload {
    /// Filter operator.
    #[serde(default)]
    pub filter: Filter,
    /// Rank operator; checked by `validate_rank` when the payload is validated.
    #[serde(default)]
    #[validate(custom(function = "validate_rank"))]
    pub rank: Rank,
    /// Limit operator.
    #[serde(default)]
    pub limit: Limit,
    /// Select operator.
    #[serde(default)]
    pub select: Select,
}
166
167impl SearchPayload {
168 pub fn limit(mut self, limit: Option<u32>, offset: u32) -> Self {
169 self.limit.limit = limit;
170 self.limit.offset = offset;
171 self
172 }
173 pub fn rank(mut self, expr: RankExpr) -> Self {
174 self.rank.expr = Some(expr);
175 self
176 }
177 pub fn select<I, T>(mut self, keys: I) -> Self
178 where
179 I: IntoIterator<Item = T>,
180 T: Into<Key>,
181 {
182 self.select.keys = keys.into_iter().map(Into::into).collect();
183 self
184 }
185 pub fn r#where(mut self, r#where: Where) -> Self {
186 self.filter.where_clause = Some(r#where);
187 self
188 }
189}
190
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    /// Hand-written OpenAPI schema for [`SearchPayload`].
    ///
    /// The result is an object with four properties:
    /// - `filter`: object with `query_ids` (array of strings) and
    ///   `where_clause` (object),
    /// - `rank`: object,
    /// - `limit`: object with integer `offset` and `limit`,
    /// - `select`: object with `keys` (array of strings).
    fn schema() -> RefOr<Schema> {
        // Small factories for the leaf schemas that are reused below.
        let object_builder = || ObjectBuilder::new().schema_type(SchemaType::Type(Type::Object));
        let plain_object = || Object::with_type(SchemaType::Type(Type::Object));
        let integer = || Object::with_type(SchemaType::Type(Type::Integer));
        let string_array =
            || ArrayBuilder::new().items(Object::with_type(SchemaType::Type(Type::String)));

        let filter = object_builder()
            .property("query_ids", string_array())
            .property("where_clause", plain_object());

        let limit = object_builder()
            .property("offset", integer())
            .property("limit", integer());

        let select = object_builder().property("keys", string_array());

        let payload = object_builder()
            .property("filter", filter)
            .property("rank", plain_object())
            .property("limit", limit)
            .property("select", select)
            .build();

        RefOr::T(Schema::Object(payload))
    }
}
233
// Empty impl: nothing is overridden, so every `ToSchema` item comes from the
// trait's own defaults. Only compiled when the `utoipa` feature is enabled.
#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {}
236
237impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
238 type Error = QueryConversionError;
239
240 fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
241 Ok(Self {
242 filter: value
243 .filter
244 .ok_or(QueryConversionError::field("filter"))?
245 .try_into()?,
246 rank: value
247 .rank
248 .ok_or(QueryConversionError::field("rank"))?
249 .try_into()?,
250 limit: value
251 .limit
252 .ok_or(QueryConversionError::field("limit"))?
253 .into(),
254 select: value
255 .select
256 .ok_or(QueryConversionError::field("select"))?
257 .try_into()?,
258 })
259 }
260}
261
262impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
263 type Error = QueryConversionError;
264
265 fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
266 Ok(Self {
267 filter: Some(value.filter.try_into()?),
268 rank: Some(value.rank.try_into()?),
269 limit: Some(value.limit.into()),
270 select: Some(value.select.try_into()?),
271 })
272 }
273}
274
/// Plan pairing one scan operator with a list of search payloads.
///
/// Converts to and from `chroma_proto::SearchPlan`.
#[derive(Clone, Debug)]
pub struct Search {
    /// Scan operator.
    pub scan: Scan,
    /// Search payloads carried by the plan, in order.
    pub payloads: Vec<SearchPayload>,
}
280
281impl TryFrom<chroma_proto::SearchPlan> for Search {
282 type Error = QueryConversionError;
283
284 fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
285 Ok(Self {
286 scan: value
287 .scan
288 .ok_or(QueryConversionError::field("scan"))?
289 .try_into()?,
290 payloads: value
291 .payloads
292 .into_iter()
293 .map(TryInto::try_into)
294 .collect::<Result<Vec<_>, _>>()?,
295 })
296 }
297}
298
299impl TryFrom<Search> for chroma_proto::SearchPlan {
300 type Error = QueryConversionError;
301
302 fn try_from(value: Search) -> Result<Self, Self::Error> {
303 Ok(Self {
304 scan: Some(value.scan.try_into()?),
305 payloads: value
306 .payloads
307 .into_iter()
308 .map(TryInto::try_into)
309 .collect::<Result<Vec<_>, _>>()?,
310 })
311 }
312}