chroma_types/execution/plan.rs
1use super::{
2 error::QueryConversionError,
3 operator::{
4 Filter, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError, Select,
5 },
6};
7use crate::{
8 chroma_proto,
9 operator::{Key, RankExpr},
10 validators::validate_rank,
11 Where,
12};
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15#[cfg(feature = "utoipa")]
16use utoipa::{
17 openapi::{
18 schema::{Schema, SchemaType},
19 ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
20 },
21 PartialSchema,
22};
23use validator::Validate;
24
25#[derive(Error, Debug)]
26pub enum PlanToProtoError {
27 #[error("Failed to convert scan to proto: {0}")]
28 Scan(#[from] ScanToProtoError),
29}
30
31/// The `Count` plan shoud ouutput the total number of records in the collection
32#[derive(Clone)]
33pub struct Count {
34 pub scan: Scan,
35}
36
37impl TryFrom<chroma_proto::CountPlan> for Count {
38 type Error = QueryConversionError;
39
40 fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
41 Ok(Self {
42 scan: value
43 .scan
44 .ok_or(QueryConversionError::field("scan"))?
45 .try_into()?,
46 })
47 }
48}
49
50impl TryFrom<Count> for chroma_proto::CountPlan {
51 type Error = PlanToProtoError;
52
53 fn try_from(value: Count) -> Result<Self, Self::Error> {
54 Ok(Self {
55 scan: Some(value.scan.try_into()?),
56 })
57 }
58}
59
60/// The `Get` plan should output records matching the specified filter and limit in the collection
61#[derive(Clone, Debug)]
62pub struct Get {
63 pub scan: Scan,
64 pub filter: Filter,
65 pub limit: Limit,
66 pub proj: Projection,
67}
68
69impl TryFrom<chroma_proto::GetPlan> for Get {
70 type Error = QueryConversionError;
71
72 fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
73 Ok(Self {
74 scan: value
75 .scan
76 .ok_or(QueryConversionError::field("scan"))?
77 .try_into()?,
78 filter: value
79 .filter
80 .ok_or(QueryConversionError::field("filter"))?
81 .try_into()?,
82 limit: value
83 .limit
84 .ok_or(QueryConversionError::field("limit"))?
85 .into(),
86 proj: value
87 .projection
88 .ok_or(QueryConversionError::field("projection"))?
89 .into(),
90 })
91 }
92}
93
94impl TryFrom<Get> for chroma_proto::GetPlan {
95 type Error = QueryConversionError;
96
97 fn try_from(value: Get) -> Result<Self, Self::Error> {
98 Ok(Self {
99 scan: Some(value.scan.try_into()?),
100 filter: Some(value.filter.try_into()?),
101 limit: Some(value.limit.into()),
102 projection: Some(value.proj.into()),
103 })
104 }
105}
106
107/// The `Knn` plan should output records nearest to the target embeddings that matches the specified filter
108#[derive(Clone, Debug)]
109pub struct Knn {
110 pub scan: Scan,
111 pub filter: Filter,
112 pub knn: KnnBatch,
113 pub proj: KnnProjection,
114}
115
116impl TryFrom<chroma_proto::KnnPlan> for Knn {
117 type Error = QueryConversionError;
118
119 fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
120 Ok(Self {
121 scan: value
122 .scan
123 .ok_or(QueryConversionError::field("scan"))?
124 .try_into()?,
125 filter: value
126 .filter
127 .ok_or(QueryConversionError::field("filter"))?
128 .try_into()?,
129 knn: value
130 .knn
131 .ok_or(QueryConversionError::field("knn"))?
132 .try_into()?,
133 proj: value
134 .projection
135 .ok_or(QueryConversionError::field("projection"))?
136 .try_into()?,
137 })
138 }
139}
140
141impl TryFrom<Knn> for chroma_proto::KnnPlan {
142 type Error = QueryConversionError;
143
144 fn try_from(value: Knn) -> Result<Self, Self::Error> {
145 Ok(Self {
146 scan: Some(value.scan.try_into()?),
147 filter: Some(value.filter.try_into()?),
148 knn: Some(value.knn.try_into()?),
149 projection: Some(value.proj.into()),
150 })
151 }
152}
153
154/// A search payload for the hybrid search API.
155///
156/// Combines filtering, ranking, pagination, and field selection into a single query.
157/// Use the builder methods to construct complex searches with a fluent interface.
158///
159/// # Examples
160///
161/// ## Basic vector search
162///
163/// ```
164/// use chroma_types::plan::SearchPayload;
165/// use chroma_types::operator::{RankExpr, QueryVector, Key};
166///
167/// let search = SearchPayload::default()
168/// .rank(RankExpr::Knn {
169/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
170/// key: Key::Embedding,
171/// limit: 100,
172/// default: None,
173/// return_rank: false,
174/// })
175/// .limit(Some(10), 0)
176/// .select([Key::Document, Key::Score]);
177/// ```
178///
179/// ## Filtered search
180///
181/// ```
182/// use chroma_types::plan::SearchPayload;
183/// use chroma_types::operator::{RankExpr, QueryVector, Key};
184///
185/// let search = SearchPayload::default()
186/// .r#where(
187/// Key::field("status").eq("published")
188/// & Key::field("year").gte(2020)
189/// )
190/// .rank(RankExpr::Knn {
191/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
192/// key: Key::Embedding,
193/// limit: 200,
194/// default: None,
195/// return_rank: false,
196/// })
197/// .limit(Some(5), 0)
198/// .select([Key::Document, Key::Score, Key::field("title")]);
199/// ```
200///
201/// ## Hybrid search with custom ranking
202///
203/// ```
204/// use chroma_types::plan::SearchPayload;
205/// use chroma_types::operator::{RankExpr, QueryVector, Key};
206///
207/// let dense = RankExpr::Knn {
208/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
209/// key: Key::Embedding,
210/// limit: 200,
211/// default: None,
212/// return_rank: false,
213/// };
214///
215/// let sparse = RankExpr::Knn {
216/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
217/// key: Key::field("sparse_embedding"),
218/// limit: 200,
219/// default: None,
220/// return_rank: false,
221/// };
222///
223/// let search = SearchPayload::default()
224/// .rank(dense * 0.7 + sparse * 0.3)
225/// .limit(Some(10), 0)
226/// .select([Key::Document, Key::Score]);
227/// ```
228#[derive(Clone, Debug, Default, Deserialize, Serialize, Validate)]
229pub struct SearchPayload {
230 #[serde(default)]
231 pub filter: Filter,
232 #[serde(default)]
233 #[validate(custom(function = "validate_rank"))]
234 pub rank: Rank,
235 #[serde(default)]
236 pub limit: Limit,
237 #[serde(default)]
238 pub select: Select,
239}
240
241impl SearchPayload {
242 /// Sets pagination parameters.
243 ///
244 /// # Arguments
245 ///
246 /// * `limit` - Maximum number of results to return (None = no limit)
247 /// * `offset` - Number of results to skip
248 ///
249 /// # Examples
250 ///
251 /// ```
252 /// use chroma_types::plan::SearchPayload;
253 ///
254 /// // First page: results 0-9
255 /// let search = SearchPayload::default().limit(Some(10), 0);
256 ///
257 /// // Second page: results 10-19
258 /// let search = SearchPayload::default().limit(Some(10), 10);
259 /// ```
260 pub fn limit(mut self, limit: Option<u32>, offset: u32) -> Self {
261 self.limit.limit = limit;
262 self.limit.offset = offset;
263 self
264 }
265
266 /// Sets the ranking expression for scoring and ordering results.
267 ///
268 /// # Arguments
269 ///
270 /// * `expr` - A ranking expression (typically Knn or a combination of expressions)
271 ///
272 /// # Examples
273 ///
274 /// ## Simple KNN ranking
275 ///
276 /// ```
277 /// use chroma_types::plan::SearchPayload;
278 /// use chroma_types::operator::{RankExpr, QueryVector, Key};
279 ///
280 /// let search = SearchPayload::default()
281 /// .rank(RankExpr::Knn {
282 /// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
283 /// key: Key::Embedding,
284 /// limit: 100,
285 /// default: None,
286 /// return_rank: false,
287 /// });
288 /// ```
289 ///
290 /// ## Weighted combination
291 ///
292 /// ```
293 /// use chroma_types::plan::SearchPayload;
294 /// use chroma_types::operator::{RankExpr, QueryVector, Key};
295 ///
296 /// let knn1 = RankExpr::Knn {
297 /// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
298 /// key: Key::Embedding,
299 /// limit: 100,
300 /// default: None,
301 /// return_rank: false,
302 /// };
303 ///
304 /// let knn2 = RankExpr::Knn {
305 /// query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
306 /// key: Key::field("other_embedding"),
307 /// limit: 100,
308 /// default: None,
309 /// return_rank: false,
310 /// };
311 ///
312 /// let search = SearchPayload::default()
313 /// .rank(knn1 * 0.8 + knn2 * 0.2);
314 /// ```
315 pub fn rank(mut self, expr: RankExpr) -> Self {
316 self.rank.expr = Some(expr);
317 self
318 }
319
320 /// Selects which fields to include in the results.
321 ///
322 /// # Arguments
323 ///
324 /// * `keys` - Fields to include (e.g., Document, Score, Metadata, or custom fields)
325 ///
326 /// # Examples
327 ///
328 /// ```
329 /// use chroma_types::plan::SearchPayload;
330 /// use chroma_types::operator::Key;
331 ///
332 /// // Select predefined fields
333 /// let search = SearchPayload::default()
334 /// .select([Key::Document, Key::Score]);
335 ///
336 /// // Select metadata fields
337 /// let search = SearchPayload::default()
338 /// .select([Key::field("title"), Key::field("author")]);
339 ///
340 /// // Mix predefined and custom fields
341 /// let search = SearchPayload::default()
342 /// .select([Key::Document, Key::Score, Key::field("title")]);
343 /// ```
344 pub fn select<I, T>(mut self, keys: I) -> Self
345 where
346 I: IntoIterator<Item = T>,
347 T: Into<Key>,
348 {
349 self.select.keys = keys.into_iter().map(Into::into).collect();
350 self
351 }
352
353 /// Sets the filter expression for narrowing results.
354 ///
355 /// # Arguments
356 ///
357 /// * `where` - A Where expression for filtering
358 ///
359 /// # Examples
360 ///
361 /// ## Simple equality filter
362 ///
363 /// ```
364 /// use chroma_types::plan::SearchPayload;
365 /// use chroma_types::operator::Key;
366 ///
367 /// let search = SearchPayload::default()
368 /// .r#where(Key::field("status").eq("published"));
369 /// ```
370 ///
371 /// ## Numeric comparisons
372 ///
373 /// ```
374 /// use chroma_types::plan::SearchPayload;
375 /// use chroma_types::operator::Key;
376 ///
377 /// let search = SearchPayload::default()
378 /// .r#where(Key::field("year").gte(2020));
379 /// ```
380 ///
381 /// ## Combining filters
382 ///
383 /// ```
384 /// use chroma_types::plan::SearchPayload;
385 /// use chroma_types::operator::Key;
386 ///
387 /// let search = SearchPayload::default()
388 /// .r#where(
389 /// Key::field("status").eq("published")
390 /// & Key::field("year").gte(2020)
391 /// & Key::field("category").is_in(vec!["tech", "science"])
392 /// );
393 /// ```
394 ///
395 /// ## Document content filtering
396 ///
397 /// ```
398 /// use chroma_types::plan::SearchPayload;
399 /// use chroma_types::operator::Key;
400 ///
401 /// let search = SearchPayload::default()
402 /// .r#where(Key::Document.contains("machine learning"));
403 /// ```
404 pub fn r#where(mut self, r#where: Where) -> Self {
405 self.filter.where_clause = Some(r#where);
406 self
407 }
408}
409
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    /// Hand-written OpenAPI schema for [`SearchPayload`].
    ///
    /// The top-level object has four properties — `filter`, `rank`, `limit`,
    /// and `select` — registered in that order. `rank` and
    /// `filter.where_clause` are described only as generic objects.
    fn schema() -> RefOr<Schema> {
        /// Leaf schema carrying nothing but the given type.
        fn typed(ty: Type) -> Object {
            Object::with_type(SchemaType::Type(ty))
        }

        /// Array schema whose items are plain strings.
        fn string_array() -> ArrayBuilder {
            ArrayBuilder::new().items(typed(Type::String))
        }

        let filter = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("query_ids", string_array())
            .property("where_clause", typed(Type::Object));

        let limit = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("offset", typed(Type::Integer))
            .property("limit", typed(Type::Integer));

        let select = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("keys", string_array());

        let payload = ObjectBuilder::new()
            .schema_type(SchemaType::Type(Type::Object))
            .property("filter", filter)
            .property("rank", typed(Type::Object))
            .property("limit", limit)
            .property("select", select)
            .build();

        RefOr::T(Schema::Object(payload))
    }
}

// Empty impl: every `ToSchema` item falls back to the trait's provided default.
#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {}
455
456impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
457 type Error = QueryConversionError;
458
459 fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
460 Ok(Self {
461 filter: value
462 .filter
463 .ok_or(QueryConversionError::field("filter"))?
464 .try_into()?,
465 rank: value
466 .rank
467 .ok_or(QueryConversionError::field("rank"))?
468 .try_into()?,
469 limit: value
470 .limit
471 .ok_or(QueryConversionError::field("limit"))?
472 .into(),
473 select: value
474 .select
475 .ok_or(QueryConversionError::field("select"))?
476 .try_into()?,
477 })
478 }
479}
480
481impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
482 type Error = QueryConversionError;
483
484 fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
485 Ok(Self {
486 filter: Some(value.filter.try_into()?),
487 rank: Some(value.rank.try_into()?),
488 limit: Some(value.limit.into()),
489 select: Some(value.select.try_into()?),
490 })
491 }
492}
493
494#[derive(Clone, Debug)]
495pub struct Search {
496 pub scan: Scan,
497 pub payloads: Vec<SearchPayload>,
498}
499
500impl TryFrom<chroma_proto::SearchPlan> for Search {
501 type Error = QueryConversionError;
502
503 fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
504 Ok(Self {
505 scan: value
506 .scan
507 .ok_or(QueryConversionError::field("scan"))?
508 .try_into()?,
509 payloads: value
510 .payloads
511 .into_iter()
512 .map(TryInto::try_into)
513 .collect::<Result<Vec<_>, _>>()?,
514 })
515 }
516}
517
518impl TryFrom<Search> for chroma_proto::SearchPlan {
519 type Error = QueryConversionError;
520
521 fn try_from(value: Search) -> Result<Self, Self::Error> {
522 Ok(Self {
523 scan: Some(value.scan.try_into()?),
524 payloads: value
525 .payloads
526 .into_iter()
527 .map(TryInto::try_into)
528 .collect::<Result<Vec<_>, _>>()?,
529 })
530 }
531}