chroma_types/execution/plan.rs
1use super::{
2 error::QueryConversionError,
3 operator::{
4 Filter, GroupBy, KnnBatch, KnnProjection, Limit, Projection, Rank, Scan, ScanToProtoError,
5 Select,
6 },
7};
8use crate::{
9 chroma_proto,
10 operator::{Key, RankExpr},
11 validators::{validate_group_by, validate_rank, validate_search_payload},
12 Where,
13};
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16#[cfg(feature = "utoipa")]
17use utoipa::{
18 openapi::{
19 schema::{Schema, SchemaType},
20 ArrayBuilder, Object, ObjectBuilder, RefOr, Type,
21 },
22 PartialSchema,
23};
24use validator::Validate;
25
26#[derive(Error, Debug)]
27pub enum PlanToProtoError {
28 #[error("Failed to convert scan to proto: {0}")]
29 Scan(#[from] ScanToProtoError),
30}
31
32/// The `Count` plan shoud ouutput the total number of records in the collection
33#[derive(Clone)]
34pub struct Count {
35 pub scan: Scan,
36 pub read_level: ReadLevel,
37}
38
39impl TryFrom<chroma_proto::CountPlan> for Count {
40 type Error = QueryConversionError;
41
42 fn try_from(value: chroma_proto::CountPlan) -> Result<Self, Self::Error> {
43 let read_level = value.read_level().into();
44 Ok(Self {
45 scan: value
46 .scan
47 .ok_or(QueryConversionError::field("scan"))?
48 .try_into()?,
49 read_level,
50 })
51 }
52}
53
54impl TryFrom<Count> for chroma_proto::CountPlan {
55 type Error = PlanToProtoError;
56
57 fn try_from(value: Count) -> Result<Self, Self::Error> {
58 Ok(Self {
59 scan: Some(value.scan.try_into()?),
60 read_level: chroma_proto::ReadLevel::from(value.read_level).into(),
61 })
62 }
63}
64
65/// The `Get` plan should output records matching the specified filter and limit in the collection
66#[derive(Clone, Debug)]
67pub struct Get {
68 pub scan: Scan,
69 pub filter: Filter,
70 pub limit: Limit,
71 pub proj: Projection,
72}
73
74impl TryFrom<chroma_proto::GetPlan> for Get {
75 type Error = QueryConversionError;
76
77 fn try_from(value: chroma_proto::GetPlan) -> Result<Self, Self::Error> {
78 Ok(Self {
79 scan: value
80 .scan
81 .ok_or(QueryConversionError::field("scan"))?
82 .try_into()?,
83 filter: value
84 .filter
85 .ok_or(QueryConversionError::field("filter"))?
86 .try_into()?,
87 limit: value
88 .limit
89 .ok_or(QueryConversionError::field("limit"))?
90 .into(),
91 proj: value
92 .projection
93 .ok_or(QueryConversionError::field("projection"))?
94 .into(),
95 })
96 }
97}
98
99impl TryFrom<Get> for chroma_proto::GetPlan {
100 type Error = QueryConversionError;
101
102 fn try_from(value: Get) -> Result<Self, Self::Error> {
103 Ok(Self {
104 scan: Some(value.scan.try_into()?),
105 filter: Some(value.filter.try_into()?),
106 limit: Some(value.limit.into()),
107 projection: Some(value.proj.into()),
108 })
109 }
110}
111
112/// The `Knn` plan should output records nearest to the target embeddings that matches the specified filter
113#[derive(Clone, Debug)]
114pub struct Knn {
115 pub scan: Scan,
116 pub filter: Filter,
117 pub knn: KnnBatch,
118 pub proj: KnnProjection,
119}
120
121impl TryFrom<chroma_proto::KnnPlan> for Knn {
122 type Error = QueryConversionError;
123
124 fn try_from(value: chroma_proto::KnnPlan) -> Result<Self, Self::Error> {
125 Ok(Self {
126 scan: value
127 .scan
128 .ok_or(QueryConversionError::field("scan"))?
129 .try_into()?,
130 filter: value
131 .filter
132 .ok_or(QueryConversionError::field("filter"))?
133 .try_into()?,
134 knn: value
135 .knn
136 .ok_or(QueryConversionError::field("knn"))?
137 .try_into()?,
138 proj: value
139 .projection
140 .ok_or(QueryConversionError::field("projection"))?
141 .try_into()?,
142 })
143 }
144}
145
146impl TryFrom<Knn> for chroma_proto::KnnPlan {
147 type Error = QueryConversionError;
148
149 fn try_from(value: Knn) -> Result<Self, Self::Error> {
150 Ok(Self {
151 scan: Some(value.scan.try_into()?),
152 filter: Some(value.filter.try_into()?),
153 knn: Some(value.knn.try_into()?),
154 projection: Some(value.proj.into()),
155 })
156 }
157}
158
159/// A search payload for the hybrid search API.
160///
161/// Combines filtering, ranking, pagination, and field selection into a single query.
162/// Use the builder methods to construct complex searches with a fluent interface.
163///
164/// # Examples
165///
166/// ## Basic vector search
167///
168/// ```
169/// use chroma_types::plan::SearchPayload;
170/// use chroma_types::operator::{RankExpr, QueryVector, Key};
171///
172/// let search = SearchPayload::default()
173/// .rank(RankExpr::Knn {
174/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
175/// key: Key::Embedding,
176/// limit: 100,
177/// default: None,
178/// return_rank: false,
179/// })
180/// .limit(Some(10), 0)
181/// .select([Key::Document, Key::Score]);
182/// ```
183///
184/// ## Filtered search
185///
186/// ```
187/// use chroma_types::plan::SearchPayload;
188/// use chroma_types::operator::{RankExpr, QueryVector, Key};
189///
190/// let search = SearchPayload::default()
191/// .r#where(
192/// Key::field("status").eq("published")
193/// & Key::field("year").gte(2020)
194/// )
195/// .rank(RankExpr::Knn {
196/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
197/// key: Key::Embedding,
198/// limit: 200,
199/// default: None,
200/// return_rank: false,
201/// })
202/// .limit(Some(5), 0)
203/// .select([Key::Document, Key::Score, Key::field("title")]);
204/// ```
205///
206/// ## Hybrid search with custom ranking
207///
208/// ```
209/// use chroma_types::plan::SearchPayload;
210/// use chroma_types::operator::{RankExpr, QueryVector, Key};
211///
212/// let dense = RankExpr::Knn {
213/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
214/// key: Key::Embedding,
215/// limit: 200,
216/// default: None,
217/// return_rank: false,
218/// };
219///
220/// let sparse = RankExpr::Knn {
221/// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
222/// key: Key::field("sparse_embedding"),
223/// limit: 200,
224/// default: None,
225/// return_rank: false,
226/// };
227///
228/// let search = SearchPayload::default()
229/// .rank(dense * 0.7 + sparse * 0.3)
230/// .limit(Some(10), 0)
231/// .select([Key::Document, Key::Score]);
232/// ```
233#[derive(Clone, Debug, Default, Deserialize, Serialize, Validate)]
234#[validate(schema(function = "validate_search_payload"))]
235pub struct SearchPayload {
236 #[serde(default)]
237 pub filter: Filter,
238 #[serde(default)]
239 #[validate(custom(function = "validate_rank"))]
240 pub rank: Rank,
241 #[serde(default)]
242 #[validate(custom(function = "validate_group_by"))]
243 pub group_by: GroupBy,
244 #[serde(default)]
245 pub limit: Limit,
246 #[serde(default)]
247 pub select: Select,
248}
249
250impl SearchPayload {
251 /// Sets pagination parameters.
252 ///
253 /// # Arguments
254 ///
255 /// * `limit` - Maximum number of results to return (None = no limit)
256 /// * `offset` - Number of results to skip
257 ///
258 /// # Examples
259 ///
260 /// ```
261 /// use chroma_types::plan::SearchPayload;
262 ///
263 /// // First page: results 0-9
264 /// let search = SearchPayload::default().limit(Some(10), 0);
265 ///
266 /// // Second page: results 10-19
267 /// let search = SearchPayload::default().limit(Some(10), 10);
268 /// ```
269 pub fn limit(mut self, limit: Option<u32>, offset: u32) -> Self {
270 self.limit.limit = limit;
271 self.limit.offset = offset;
272 self
273 }
274
275 /// Sets the ranking expression for scoring and ordering results.
276 ///
277 /// # Arguments
278 ///
279 /// * `expr` - A ranking expression (typically Knn or a combination of expressions)
280 ///
281 /// # Examples
282 ///
283 /// ## Simple KNN ranking
284 ///
285 /// ```
286 /// use chroma_types::plan::SearchPayload;
287 /// use chroma_types::operator::{RankExpr, QueryVector, Key};
288 ///
289 /// let search = SearchPayload::default()
290 /// .rank(RankExpr::Knn {
291 /// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
292 /// key: Key::Embedding,
293 /// limit: 100,
294 /// default: None,
295 /// return_rank: false,
296 /// });
297 /// ```
298 ///
299 /// ## Weighted combination
300 ///
301 /// ```
302 /// use chroma_types::plan::SearchPayload;
303 /// use chroma_types::operator::{RankExpr, QueryVector, Key};
304 ///
305 /// let knn1 = RankExpr::Knn {
306 /// query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
307 /// key: Key::Embedding,
308 /// limit: 100,
309 /// default: None,
310 /// return_rank: false,
311 /// };
312 ///
313 /// let knn2 = RankExpr::Knn {
314 /// query: QueryVector::Dense(vec![0.2, 0.3, 0.4]),
315 /// key: Key::field("other_embedding"),
316 /// limit: 100,
317 /// default: None,
318 /// return_rank: false,
319 /// };
320 ///
321 /// let search = SearchPayload::default()
322 /// .rank(knn1 * 0.8 + knn2 * 0.2);
323 /// ```
324 pub fn rank(mut self, expr: RankExpr) -> Self {
325 self.rank.expr = Some(expr);
326 self
327 }
328
329 /// Selects which fields to include in the results.
330 ///
331 /// # Arguments
332 ///
333 /// * `keys` - Fields to include (e.g., Document, Score, Metadata, or custom fields)
334 ///
335 /// # Examples
336 ///
337 /// ```
338 /// use chroma_types::plan::SearchPayload;
339 /// use chroma_types::operator::Key;
340 ///
341 /// // Select predefined fields
342 /// let search = SearchPayload::default()
343 /// .select([Key::Document, Key::Score]);
344 ///
345 /// // Select metadata fields
346 /// let search = SearchPayload::default()
347 /// .select([Key::field("title"), Key::field("author")]);
348 ///
349 /// // Mix predefined and custom fields
350 /// let search = SearchPayload::default()
351 /// .select([Key::Document, Key::Score, Key::field("title")]);
352 /// ```
353 pub fn select<I, T>(mut self, keys: I) -> Self
354 where
355 I: IntoIterator<Item = T>,
356 T: Into<Key>,
357 {
358 self.select.keys = keys.into_iter().map(Into::into).collect();
359 self
360 }
361
362 /// Sets the filter expression for narrowing results.
363 ///
364 /// # Arguments
365 ///
366 /// * `where` - A Where expression for filtering
367 ///
368 /// # Examples
369 ///
370 /// ## Simple equality filter
371 ///
372 /// ```
373 /// use chroma_types::plan::SearchPayload;
374 /// use chroma_types::operator::Key;
375 ///
376 /// let search = SearchPayload::default()
377 /// .r#where(Key::field("status").eq("published"));
378 /// ```
379 ///
380 /// ## Numeric comparisons
381 ///
382 /// ```
383 /// use chroma_types::plan::SearchPayload;
384 /// use chroma_types::operator::Key;
385 ///
386 /// let search = SearchPayload::default()
387 /// .r#where(Key::field("year").gte(2020));
388 /// ```
389 ///
390 /// ## Combining filters
391 ///
392 /// ```
393 /// use chroma_types::plan::SearchPayload;
394 /// use chroma_types::operator::Key;
395 ///
396 /// let search = SearchPayload::default()
397 /// .r#where(
398 /// Key::field("status").eq("published")
399 /// & Key::field("year").gte(2020)
400 /// & Key::field("category").is_in(vec!["tech", "science"])
401 /// );
402 /// ```
403 ///
404 /// ## Document content filtering
405 ///
406 /// ```
407 /// use chroma_types::plan::SearchPayload;
408 /// use chroma_types::operator::Key;
409 ///
410 /// let search = SearchPayload::default()
411 /// .r#where(Key::Document.contains("machine learning"));
412 /// ```
413 pub fn r#where(mut self, r#where: Where) -> Self {
414 self.filter.where_clause = Some(r#where);
415 self
416 }
417
418 /// Groups results by metadata keys and aggregates within each group.
419 ///
420 /// # Arguments
421 ///
422 /// * `group_by` - GroupBy configuration with keys and aggregation
423 ///
424 /// # Examples
425 ///
426 /// ```
427 /// use chroma_types::plan::SearchPayload;
428 /// use chroma_types::operator::{GroupBy, Aggregate, Key};
429 ///
430 /// // Top 3 best documents per category
431 /// let search = SearchPayload::default()
432 /// .group_by(GroupBy {
433 /// keys: vec![Key::field("category")],
434 /// aggregate: Some(Aggregate::MinK {
435 /// keys: vec![Key::Score],
436 /// k: 3,
437 /// }),
438 /// });
439 /// ```
440 pub fn group_by(mut self, group_by: GroupBy) -> Self {
441 self.group_by = group_by;
442 self
443 }
444}
445
#[cfg(feature = "utoipa")]
impl PartialSchema for SearchPayload {
    /// Hand-written OpenAPI schema for [`SearchPayload`].
    ///
    /// The payload is described as an object with the properties `filter`,
    /// `rank`, `group_by`, `limit` and `select`, added in that order.
    fn schema() -> RefOr<Schema> {
        // Small factories for the schema fragments that repeat below.
        let object_builder = || ObjectBuilder::new().schema_type(SchemaType::Type(Type::Object));
        let any_object = || Object::with_type(SchemaType::Type(Type::Object));
        let integer = || Object::with_type(SchemaType::Type(Type::Integer));
        let string_array =
            || ArrayBuilder::new().items(Object::with_type(SchemaType::Type(Type::String)));

        let filter = object_builder()
            .property("query_ids", string_array())
            .property("where_clause", any_object());

        let group_by = object_builder()
            .property("keys", string_array())
            .property("aggregate", any_object());

        let limit = object_builder()
            .property("offset", integer())
            .property("limit", integer());

        let select = object_builder().property("keys", string_array());

        let payload = object_builder()
            .property("filter", filter)
            .property("rank", any_object())
            .property("group_by", group_by)
            .property("limit", limit)
            .property("select", select)
            .build();

        RefOr::T(Schema::Object(payload))
    }
}

// `ToSchema` is implemented with no overrides; the schema itself comes from the
// `PartialSchema` impl for `SearchPayload`.
#[cfg(feature = "utoipa")]
impl utoipa::ToSchema for SearchPayload {}
505
506impl TryFrom<chroma_proto::SearchPayload> for SearchPayload {
507 type Error = QueryConversionError;
508
509 fn try_from(value: chroma_proto::SearchPayload) -> Result<Self, Self::Error> {
510 Ok(Self {
511 filter: value
512 .filter
513 .ok_or(QueryConversionError::field("filter"))?
514 .try_into()?,
515 rank: value
516 .rank
517 .ok_or(QueryConversionError::field("rank"))?
518 .try_into()?,
519 group_by: value
520 .group_by
521 .map(TryInto::try_into)
522 .transpose()?
523 .unwrap_or_default(),
524 limit: value
525 .limit
526 .ok_or(QueryConversionError::field("limit"))?
527 .into(),
528 select: value
529 .select
530 .ok_or(QueryConversionError::field("select"))?
531 .try_into()?,
532 })
533 }
534}
535
536impl TryFrom<SearchPayload> for chroma_proto::SearchPayload {
537 type Error = QueryConversionError;
538
539 fn try_from(value: SearchPayload) -> Result<Self, Self::Error> {
540 Ok(Self {
541 filter: Some(value.filter.try_into()?),
542 rank: Some(value.rank.try_into()?),
543 group_by: Some(value.group_by.try_into()?),
544 limit: Some(value.limit.into()),
545 select: Some(value.select.try_into()?),
546 })
547 }
548}
549
550#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
551#[serde(rename_all = "snake_case")]
552#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
553pub enum ReadLevel {
554 /// Read from both the index and the write-ahead log (default).
555 /// Provides full consistency with all committed writes visible.
556 #[default]
557 IndexAndWal,
558 /// Read only from the index, skipping the write-ahead log.
559 /// Provides eventual consistency - recent uncommitted writes may not be visible.
560 IndexOnly,
561}
562
563impl From<chroma_proto::ReadLevel> for ReadLevel {
564 fn from(value: chroma_proto::ReadLevel) -> Self {
565 match value {
566 chroma_proto::ReadLevel::IndexAndWal => ReadLevel::IndexAndWal,
567 chroma_proto::ReadLevel::IndexOnly => ReadLevel::IndexOnly,
568 }
569 }
570}
571
572impl From<ReadLevel> for chroma_proto::ReadLevel {
573 fn from(value: ReadLevel) -> Self {
574 match value {
575 ReadLevel::IndexAndWal => chroma_proto::ReadLevel::IndexAndWal,
576 ReadLevel::IndexOnly => chroma_proto::ReadLevel::IndexOnly,
577 }
578 }
579}
580
581#[derive(Clone, Debug)]
582pub struct Search {
583 pub scan: Scan,
584 pub payloads: Vec<SearchPayload>,
585 pub read_level: ReadLevel,
586}
587
588impl TryFrom<chroma_proto::SearchPlan> for Search {
589 type Error = QueryConversionError;
590
591 fn try_from(value: chroma_proto::SearchPlan) -> Result<Self, Self::Error> {
592 let read_level = value.read_level().into();
593 Ok(Self {
594 scan: value
595 .scan
596 .ok_or(QueryConversionError::field("scan"))?
597 .try_into()?,
598 payloads: value
599 .payloads
600 .into_iter()
601 .map(TryInto::try_into)
602 .collect::<Result<Vec<_>, _>>()?,
603 read_level,
604 })
605 }
606}
607
608impl TryFrom<Search> for chroma_proto::SearchPlan {
609 type Error = QueryConversionError;
610
611 fn try_from(value: Search) -> Result<Self, Self::Error> {
612 Ok(Self {
613 scan: Some(value.scan.try_into()?),
614 payloads: value
615 .payloads
616 .into_iter()
617 .map(TryInto::try_into)
618 .collect::<Result<Vec<_>, _>>()?,
619 read_level: chroma_proto::ReadLevel::from(value.read_level).into(),
620 })
621 }
622}