chroma_types/execution/plan.rs
1use super::{
2 error::QueryConversionError,
3 operator::{
4 Filter, GroupBy, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError,
5 Select,
6 },
7};
8use crate::{
9 chroma_proto,
10 operator::{Key, RankExpr},
11 validators::{validate_group_by, validate_rank, validate_search_payload},
12 Where,
13};
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16#[cfg(feature = "utoipa")]
17use utoipa::{
18 openapi::{
19 schema::{Schema, SchemaType},
20 ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
21 },
22 PartialSchema,
23};
24use validator::Validate;
25
26#[derive(Error, Debug)]
27pub enum PlanToProtoError {
28 #[error("Failed to convert scan to proto: {0}")]
29 Scan(#[from] ScanToProtoError),
30}
31
32/// The `Count` plan shoud ouutput the total number of records in the collection
33#[derive(Clone)]
34pub struct Count {
35 pub scan: Scan,
36}
37
38impl TryFrom<chroma_proto::CountPlan> for Count {
39 type Error = QueryConversionError;
40
41 fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
42 Ok(Self {
43 scan: value
44 .scan
45 .ok_or(QueryConversionError::field("scan"))?
46 .try_into()?,
47 })
48 }
49}
50
51impl TryFrom<Count> for chroma_proto::CountPlan {
52 type Error = PlanToProtoError;
53
54 fn try_from(value: Count) -> Result<Self, Self::Error> {
55 Ok(Self {
56 scan: Some(value.scan.try_into()?),
57 })
58 }
59}
60
61/// The `Get` plan should output records matching the specified filter and limit in the collection
62#[derive(Clone, Debug)]
63pub struct Get {
64 pub scan: Scan,
65 pub filter: Filter,
66 pub limit: Limit,
67 pub proj: Projection,
68}
69
70impl TryFrom<chroma_proto::GetPlan> for Get {
71 type Error = QueryConversionError;
72
73 fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
74 Ok(Self {
75 scan: value
76 .scan
77 .ok_or(QueryConversionError::field("scan"))?
78 .try_into()?,
79 filter: value
80 .filter
81 .ok_or(QueryConversionError::field("filter"))?
82 .try_into()?,
83 limit: value
84 .limit
85 .ok_or(QueryConversionError::field("limit"))?
86 .into(),
87 proj: value
88 .projection
89 .ok_or(QueryConversionError::field("projection"))?
90 .into(),
91 })
92 }
93}
94
95impl TryFrom<Get> for chroma_proto::GetPlan {
96 type Error = QueryConversionError;
97
98 fn try_from(value: Get) -> Result<Self, Self::Error> {
99 Ok(Self {
100 scan: Some(value.scan.try_into()?),
101 filter: Some(value.filter.try_into()?),
102 limit: Some(value.limit.into()),
103 projection: Some(value.proj.into()),
104 })
105 }
106}
107
108/// The `Knn` plan should output records nearest to the target embeddings that matches the specified filter
109#[derive(Clone, Debug)]
110pub struct Knn {
111 pub scan: Scan,
112 pub filter: Filter,
113 pub knn: KnnBatch,
114 pub proj: KnnProjection,
115}
116
117impl TryFrom<chroma_proto::KnnPlan> for Knn {
118 type Error = QueryConversionError;
119
120 fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
121 Ok(Self {
122 scan: value
123 .scan
124 .ok_or(QueryConversionError::field("scan"))?
125 .try_into()?,
126 filter: value
127 .filter
128 .ok_or(QueryConversionError::field("filter"))?
129 .try_into()?,
130 knn: value
131 .knn
132 .ok_or(QueryConversionError::field("knn"))?
133 .try_into()?,
134 proj: value
135 .projection
136 .ok_or(QueryConversionError::field("projection"))?
137 .try_into()?,
138 })
139 }
140}
141
142impl TryFrom<Knn> for chroma_proto::KnnPlan {
143 type Error = QueryConversionError;
144
145 fn try_from(value: Knn) -> Result<Self, Self::Error> {
146 Ok(Self {
147 scan: Some(value.scan.try_into()?),
148 filter: Some(value.filter.try_into()?),
149 knn: Some(value.knn.try_into()?),
150 projection: Some(value.proj.into()),
151 })
152 }
153}
154
155/// A search payload for the hybrid search API.
156///
157/// Combines filtering, ranking, pagination, and field selection into a single query.
158/// Use the builder methods to construct complex searches with a fluent interface.
159///
160/// # Examples
161///
162/// ## Basic vector search
163///
164/// ```
165/// use chroma_types::plan::SearchPayload;
166/// use chroma_types::operator::{RankExpr, QueryVector, Key};
167///
168/// let search = SearchPayload::default()
169/// .rank(RankExpr::Knn {
170/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
171/// key: Key::Embedding,
172/// limit: 100,
173/// default: None,
174/// return_rank: false,
175/// })
176/// .limit(Some(10), 0)
177/// .select([Key::Document, Key::Score]);
178/// ```
179///
180/// ## Filtered search
181///
182/// ```
183/// use chroma_types::plan::SearchPayload;
184/// use chroma_types::operator::{RankExpr, QueryVector, Key};
185///
186/// let search = SearchPayload::default()
187/// .r#where(
188/// Key::field("status").eq("published")
189/// & Key::field("year").gte(2020)
190/// )
191/// .rank(RankExpr::Knn {
192/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
193/// key: Key::Embedding,
194/// limit: 200,
195/// default: None,
196/// return_rank: false,
197/// })
198/// .limit(Some(5), 0)
199/// .select([Key::Document, Key::Score, Key::field("title")]);
200/// ```
201///
202/// ## Hybrid search with custom ranking
203///
204/// ```
205/// use chroma_types::plan::SearchPayload;
206/// use chroma_types::operator::{RankExpr, QueryVector, Key};
207///
208/// let dense = RankExpr::Knn {
209/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
210/// key: Key::Embedding,
211/// limit: 200,
212/// default: None,
213/// return_rank: false,
214/// };
215///
216/// let sparse = RankExpr::Knn {
217/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
218/// key: Key::field("sparse_embedding"),
219/// limit: 200,
220/// default: None,
221/// return_rank: false,
222/// };
223///
224/// let search = SearchPayload::default()
225/// .rank(dense * 0.7 + sparse * 0.3)
226/// .limit(Some(10), 0)
227/// .select([Key::Document, Key::Score]);
228/// ```
229#[derive(Clone, Debug, Default, Deserialize, Serialize, Validate)]
230#[validate(schema(function = "validate_search_payload"))]
231pub struct SearchPayload {
232 #[serde(default)]
233 pub filter: Filter,
234 #[serde(default)]
235 #[validate(custom(function = "validate_rank"))]
236 pub rank: Rank,
237 #[serde(default)]
238 #[validate(custom(function = "validate_group_by"))]
239 pub group_by: GroupBy,
240 #[serde(default)]
241 pub limit: Limit,
242 #[serde(default)]
243 pub select: Select,
244}
245
246impl SearchPayload {
247 /// Sets pagination parameters.
248 ///
249 /// # Arguments
250 ///
251 /// * `limit` - Maximum number of results to return (None = no limit)
252 /// * `offset` - Number of results to skip
253 ///
254 /// # Examples
255 ///
256 /// ```
257 /// use chroma_types::plan::SearchPayload;
258 ///
259 /// // First page: results 0-9
260 /// let search = SearchPayload::default().limit(Some(10), 0);
261 ///
262 /// // Second page: results 10-19
263 /// let search = SearchPayload::default().limit(Some(10), 10);
264 /// ```
265 pub fn limit(mut self, limit: Option<u32>, offset: u32) -> Self {
266 self.limit.limit = limit;
267 self.limit.offset = offset;
268 self
269 }
270
271 /// Sets the ranking expression for scoring and ordering results.
272 ///
273 /// # Arguments
274 ///
275 /// * `expr` - A ranking expression (typically Knn or a combination of expressions)
276 ///
277 /// # Examples
278 ///
279 /// ## Simple KNN ranking
280 ///
281 /// ```
282 /// use chroma_types::plan::SearchPayload;
283 /// use chroma_types::operator::{RankExpr, QueryVector, Key};
284 ///
285 /// let search = SearchPayload::default()
286 /// .rank(RankExpr::Knn {
287 /// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
288 /// key: Key::Embedding,
289 /// limit: 100,
290 /// default: None,
291 /// return_rank: false,
292 /// });
293 /// ```
294 ///
295 /// ## Weighted combination
296 ///
297 /// ```
298 /// use chroma_types::plan::SearchPayload;
299 /// use chroma_types::operator::{RankExpr, QueryVector, Key};
300 ///
301 /// let knn1 = RankExpr::Knn {
302 /// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
303 /// key: Key::Embedding,
304 /// limit: 100,
305 /// default: None,
306 /// return_rank: false,
307 /// };
308 ///
309 /// let knn2 = RankExpr::Knn {
310 /// query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
311 /// key: Key::field("other_embedding"),
312 /// limit: 100,
313 /// default: None,
314 /// return_rank: false,
315 /// };
316 ///
317 /// let search = SearchPayload::default()
318 /// .rank(knn1 * 0.8 + knn2 * 0.2);
319 /// ```
320 pub fn rank(mut self, expr: RankExpr) -> Self {
321 self.rank.expr = Some(expr);
322 self
323 }
324
325 /// Selects which fields to include in the results.
326 ///
327 /// # Arguments
328 ///
329 /// * `keys` - Fields to include (e.g., Document, Score, Metadata, or custom fields)
330 ///
331 /// # Examples
332 ///
333 /// ```
334 /// use chroma_types::plan::SearchPayload;
335 /// use chroma_types::operator::Key;
336 ///
337 /// // Select predefined fields
338 /// let search = SearchPayload::default()
339 /// .select([Key::Document, Key::Score]);
340 ///
341 /// // Select metadata fields
342 /// let search = SearchPayload::default()
343 /// .select([Key::field("title"), Key::field("author")]);
344 ///
345 /// // Mix predefined and custom fields
346 /// let search = SearchPayload::default()
347 /// .select([Key::Document, Key::Score, Key::field("title")]);
348 /// ```
349 pub fn select<I, T>(mut self, keys: I) -> Self
350 where
351 I: IntoIterator<Item = T>,
352 T: Into<Key>,
353 {
354 self.select.keys = keys.into_iter().map(Into::into).collect();
355 self
356 }
357
358 /// Sets the filter expression for narrowing results.
359 ///
360 /// # Arguments
361 ///
362 /// * `where` - A Where expression for filtering
363 ///
364 /// # Examples
365 ///
366 /// ## Simple equality filter
367 ///
368 /// ```
369 /// use chroma_types::plan::SearchPayload;
370 /// use chroma_types::operator::Key;
371 ///
372 /// let search = SearchPayload::default()
373 /// .r#where(Key::field("status").eq("published"));
374 /// ```
375 ///
376 /// ## Numeric comparisons
377 ///
378 /// ```
379 /// use chroma_types::plan::SearchPayload;
380 /// use chroma_types::operator::Key;
381 ///
382 /// let search = SearchPayload::default()
383 /// .r#where(Key::field("year").gte(2020));
384 /// ```
385 ///
386 /// ## Combining filters
387 ///
388 /// ```
389 /// use chroma_types::plan::SearchPayload;
390 /// use chroma_types::operator::Key;
391 ///
392 /// let search = SearchPayload::default()
393 /// .r#where(
394 /// Key::field("status").eq("published")
395 /// & Key::field("year").gte(2020)
396 /// & Key::field("category").is_in(vec!["tech", "science"])
397 /// );
398 /// ```
399 ///
400 /// ## Document content filtering
401 ///
402 /// ```
403 /// use chroma_types::plan::SearchPayload;
404 /// use chroma_types::operator::Key;
405 ///
406 /// let search = SearchPayload::default()
407 /// .r#where(Key::Document.contains("machine learning"));
408 /// ```
409 pub fn r#where(mut self, r#where: Where) -> Self {
410 self.filter.where_clause = Some(r#where);
411 self
412 }
413
414 /// Groups results by metadata keys and aggregates within each group.
415 ///
416 /// # Arguments
417 ///
418 /// * `group_by` - GroupBy configuration with keys and aggregation
419 ///
420 /// # Examples
421 ///
422 /// ```
423 /// use chroma_types::plan::SearchPayload;
424 /// use chroma_types::operator::{GroupBy, Aggregate, Key};
425 ///
426 /// // Top 3 best documents per category
427 /// let search = SearchPayload::default()
428 /// .group_by(GroupBy {
429 /// keys: vec![Key::field("category")],
430 /// aggregate: Some(Aggregate::MinK {
431 /// keys: vec![Key::Score],
432 /// k: 3,
433 /// }),
434 /// });
435 /// ```
436 pub fn group_by(mut self, group_by: GroupBy) -> Self {
437 self.group_by = group_by;
438 self
439 }
440}
441
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    /// Hand-written OpenAPI schema for `SearchPayload`.
    ///
    /// The payload is described as an object with the properties `filter`,
    /// `rank`, `group_by`, `limit` and `select`. `rank`, `where_clause` and
    /// `aggregate` are described only as plain objects.
    fn schema() -> RefOr<Schema> {
        // Small factories for the schema fragments that repeat below.
        let object = || ObjectBuilder::new().schema_type(SchemaType::Type(Type::Object));
        let plain_object = || Object::with_type(SchemaType::Type(Type::Object));
        let integer = || Object::with_type(SchemaType::Type(Type::Integer));
        let string_array =
            || ArrayBuilder::new().items(Object::with_type(SchemaType::Type(Type::String)));

        let filter = object()
            .property("query_ids", string_array())
            .property("where_clause", plain_object());

        let rank = plain_object();

        let group_by = object()
            .property("keys", string_array())
            .property("aggregate", plain_object());

        let limit = object()
            .property("offset", integer())
            .property("limit", integer());

        let select = object().property("keys", string_array());

        // Properties are registered in the same order as the struct fields.
        let payload = object()
            .property("filter", filter)
            .property("rank", rank)
            .property("group_by", group_by)
            .property("limit", limit)
            .property("select", select)
            .build();

        RefOr::T(Schema::Object(payload))
    }
}

// Uses the trait's default items; the schema body comes from the
// `PartialSchema` implementation for `SearchPayload`.
#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {}
501
502impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
503 type Error = QueryConversionError;
504
505 fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
506 Ok(Self {
507 filter: value
508 .filter
509 .ok_or(QueryConversionError::field("filter"))?
510 .try_into()?,
511 rank: value
512 .rank
513 .ok_or(QueryConversionError::field("rank"))?
514 .try_into()?,
515 group_by: value
516 .group_by
517 .map(TryInto::try_into)
518 .transpose()?
519 .unwrap_or_default(),
520 limit: value
521 .limit
522 .ok_or(QueryConversionError::field("limit"))?
523 .into(),
524 select: value
525 .select
526 .ok_or(QueryConversionError::field("select"))?
527 .try_into()?,
528 })
529 }
530}
531
532impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
533 type Error = QueryConversionError;
534
535 fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
536 Ok(Self {
537 filter: Some(value.filter.try_into()?),
538 rank: Some(value.rank.try_into()?),
539 group_by: Some(value.group_by.try_into()?),
540 limit: Some(value.limit.into()),
541 select: Some(value.select.try_into()?),
542 })
543 }
544}
545
546#[derive(Clone, Debug)]
547pub struct Search {
548 pub scan: Scan,
549 pub payloads: Vec<SearchPayload>,
550}
551
552impl TryFrom<chroma_proto::SearchPlan> for Search {
553 type Error = QueryConversionError;
554
555 fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
556 Ok(Self {
557 scan: value
558 .scan
559 .ok_or(QueryConversionError::field("scan"))?
560 .try_into()?,
561 payloads: value
562 .payloads
563 .into_iter()
564 .map(TryInto::try_into)
565 .collect::<Result<Vec<_>, _>>()?,
566 })
567 }
568}
569
570impl TryFrom<Search> for chroma_proto::SearchPlan {
571 type Error = QueryConversionError;
572
573 fn try_from(value: Search) -> Result<Self, Self::Error> {
574 Ok(Self {
575 scan: Some(value.scan.try_into()?),
576 payloads: value
577 .payloads
578 .into_iter()
579 .map(TryInto::try_into)
580 .collect::<Result<Vec<_>, _>>()?,
581 })
582 }
583}