1use super::{
2 error::QueryConversionError,
3 operator::{
4 Filter, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError, Select,
5 },
6};
7use crate::{chroma_proto, validators::validate_rank};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10#[cfg(feature = "utoipa")]
11use utoipa::{
12 openapi::{
13 schema::{Schema, SchemaType},
14 ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
15 },
16 PartialSchema,
17};
18use validator::Validate;
19
20#[derive(Error, Debug)]
21pub enum PlanToProtoError {
22 #[error("Failed to convert scan to proto: {0}")]
23 Scan(#[from] ScanToProtoError),
24}
25
26#[derive(Clone)]
28pub struct Count {
29 pub scan: Scan,
30}
31
32impl TryFrom<chroma_proto::CountPlan> for Count {
33 type Error = QueryConversionError;
34
35 fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
36 Ok(Self {
37 scan: value
38 .scan
39 .ok_or(QueryConversionError::field("scan"))?
40 .try_into()?,
41 })
42 }
43}
44
45impl TryFrom<Count> for chroma_proto::CountPlan {
46 type Error = PlanToProtoError;
47
48 fn try_from(value: Count) -> Result<Self, Self::Error> {
49 Ok(Self {
50 scan: Some(value.scan.try_into()?),
51 })
52 }
53}
54
55#[derive(Clone, Debug)]
57pub struct Get {
58 pub scan: Scan,
59 pub filter: Filter,
60 pub limit: Limit,
61 pub proj: Projection,
62}
63
64impl TryFrom<chroma_proto::GetPlan> for Get {
65 type Error = QueryConversionError;
66
67 fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
68 Ok(Self {
69 scan: value
70 .scan
71 .ok_or(QueryConversionError::field("scan"))?
72 .try_into()?,
73 filter: value
74 .filter
75 .ok_or(QueryConversionError::field("filter"))?
76 .try_into()?,
77 limit: value
78 .limit
79 .ok_or(QueryConversionError::field("limit"))?
80 .into(),
81 proj: value
82 .projection
83 .ok_or(QueryConversionError::field("projection"))?
84 .into(),
85 })
86 }
87}
88
89impl TryFrom<Get> for chroma_proto::GetPlan {
90 type Error = QueryConversionError;
91
92 fn try_from(value: Get) -> Result<Self, Self::Error> {
93 Ok(Self {
94 scan: Some(value.scan.try_into()?),
95 filter: Some(value.filter.try_into()?),
96 limit: Some(value.limit.into()),
97 projection: Some(value.proj.into()),
98 })
99 }
100}
101
102#[derive(Clone, Debug)]
104pub struct Knn {
105 pub scan: Scan,
106 pub filter: Filter,
107 pub knn: KnnBatch,
108 pub proj: KnnProjection,
109}
110
111impl TryFrom<chroma_proto::KnnPlan> for Knn {
112 type Error = QueryConversionError;
113
114 fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
115 Ok(Self {
116 scan: value
117 .scan
118 .ok_or(QueryConversionError::field("scan"))?
119 .try_into()?,
120 filter: value
121 .filter
122 .ok_or(QueryConversionError::field("filter"))?
123 .try_into()?,
124 knn: value
125 .knn
126 .ok_or(QueryConversionError::field("knn"))?
127 .try_into()?,
128 proj: value
129 .projection
130 .ok_or(QueryConversionError::field("projection"))?
131 .try_into()?,
132 })
133 }
134}
135
136impl TryFrom<Knn> for chroma_proto::KnnPlan {
137 type Error = QueryConversionError;
138
139 fn try_from(value: Knn) -> Result<Self, Self::Error> {
140 Ok(Self {
141 scan: Some(value.scan.try_into()?),
142 filter: Some(value.filter.try_into()?),
143 knn: Some(value.knn.try_into()?),
144 projection: Some(value.proj.into()),
145 })
146 }
147}
148
149#[derive(Clone, Debug, Deserialize, Serialize, Validate)]
150pub struct SearchPayload {
151 #[serde(default)]
152 pub filter: Filter,
153 #[serde(default)]
154 #[validate(custom(function = "validate_rank"))]
155 pub rank: Rank,
156 #[serde(default)]
157 pub limit: Limit,
158 #[serde(default)]
159 pub select: Select,
160}
161
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    /// Hand-written OpenAPI schema for [`SearchPayload`].
    ///
    /// Describes an object with `filter`, `rank`, `limit` and `select`
    /// properties. `rank` and `filter.where_clause` are typed only as
    /// `object`, with no declared properties.
    fn schema() -> RefOr<Schema> {
        // Bare schema that only carries a primitive type.
        let of_type = |ty: Type| Object::with_type(SchemaType::Type(ty));
        // Fresh builder already marked as an object schema.
        let object_builder =
            || ObjectBuilder::new().schema_type(SchemaType::Type(Type::Object));
        // Array whose items are plain strings.
        let string_array = || ArrayBuilder::new().items(of_type(Type::String));

        let filter = object_builder()
            .property("query_ids", string_array())
            .property("where_clause", of_type(Type::Object));
        let limit = object_builder()
            .property("offset", of_type(Type::Integer))
            .property("limit", of_type(Type::Integer));
        let select = object_builder().property("keys", string_array());

        let payload = object_builder()
            .property("filter", filter)
            .property("rank", of_type(Type::Object))
            .property("limit", limit)
            .property("select", select)
            .build();

        RefOr::T(Schema::Object(payload))
    }
}
204
/// Empty impl: every `ToSchema` item uses the trait's defaults, so
/// `SearchPayload` can be used wherever utoipa requires `ToSchema`.
#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {
}
207
208impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
209 type Error = QueryConversionError;
210
211 fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
212 Ok(Self {
213 filter: value
214 .filter
215 .ok_or(QueryConversionError::field("filter"))?
216 .try_into()?,
217 rank: value
218 .rank
219 .ok_or(QueryConversionError::field("rank"))?
220 .try_into()?,
221 limit: value
222 .limit
223 .ok_or(QueryConversionError::field("limit"))?
224 .into(),
225 select: value
226 .select
227 .ok_or(QueryConversionError::field("select"))?
228 .try_into()?,
229 })
230 }
231}
232
233impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
234 type Error = QueryConversionError;
235
236 fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
237 Ok(Self {
238 filter: Some(value.filter.try_into()?),
239 rank: Some(value.rank.try_into()?),
240 limit: Some(value.limit.into()),
241 select: Some(value.select.try_into()?),
242 })
243 }
244}
245
246#[derive(Clone, Debug)]
247pub struct Search {
248 pub scan: Scan,
249 pub payloads: Vec<SearchPayload>,
250}
251
252impl TryFrom<chroma_proto::SearchPlan> for Search {
253 type Error = QueryConversionError;
254
255 fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
256 Ok(Self {
257 scan: value
258 .scan
259 .ok_or(QueryConversionError::field("scan"))?
260 .try_into()?,
261 payloads: value
262 .payloads
263 .into_iter()
264 .map(TryInto::try_into)
265 .collect::<Result<Vec<_>, _>>()?,
266 })
267 }
268}
269
270impl TryFrom<Search> for chroma_proto::SearchPlan {
271 type Error = QueryConversionError;
272
273 fn try_from(value: Search) -> Result<Self, Self::Error> {
274 Ok(Self {
275 scan: Some(value.scan.try_into()?),
276 payloads: value
277 .payloads
278 .into_iter()
279 .map(TryInto::try_into)
280 .collect::<Result<Vec<_>, _>>()?,
281 })
282 }
283}