1use crate::regex::{ChromaRegex, ChromaRegexError};
2use crate::{CompositeExpression, DocumentOperator, MetadataExpression, PrimitiveOperator, Where};
3use chroma_error::{ChromaError, ErrorCodes};
4use serde::Deserialize;
5use serde::Serialize;
6use serde_json::Value;
7use thiserror::Error;
8
9#[derive(Deserialize, Debug, Clone, Serialize)]
10#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
11pub struct RawWhereFields {
12 #[serde(default)]
13 r#where: Value,
14 #[serde(default)]
15 where_document: Value,
16}
17
18impl RawWhereFields {
19 pub fn new(r#where: Value, where_document: Value) -> Self {
20 Self {
21 r#where,
22 where_document,
23 }
24 }
25
26 pub fn from_json_str(
27 r#where: Option<&str>,
28 where_document: Option<&str>,
29 ) -> Result<Self, WhereValidationError> {
30 let r#where = r#where
31 .map(|r#where| {
32 serde_json::from_str(r#where).map_err(|_| WhereValidationError::WhereClause)
33 })
34 .transpose()?
35 .unwrap_or(Value::Null);
36
37 let where_document = where_document
38 .map(|where_document| {
39 serde_json::from_str(where_document)
40 .map_err(|_| WhereValidationError::WhereDocumentClause)
41 })
42 .transpose()?
43 .unwrap_or(Value::Null);
44
45 Ok(Self {
46 r#where,
47 where_document,
48 })
49 }
50}
51
52#[derive(Error, Debug)]
53pub enum WhereValidationError {
54 #[error(transparent)]
55 Regex(#[from] ChromaRegexError),
56 #[error("Invalid where clause")]
57 WhereClause,
58 #[error("Invalid where document clause")]
59 WhereDocumentClause,
60}
61
62impl ChromaError for WhereValidationError {
63 fn code(&self) -> chroma_error::ErrorCodes {
64 ErrorCodes::InvalidArgument
65 }
66}
67
68impl RawWhereFields {
69 pub fn parse(self) -> Result<Option<Where>, WhereValidationError> {
70 let mut where_clause = None;
71 if !self.r#where.is_null() {
72 let where_payload = &self.r#where;
73 where_clause = Some(parse_where(where_payload)?);
74 }
75 let mut where_document_clause = None;
76 if !self.where_document.is_null() {
77 let where_document_payload = &self.where_document;
78 where_document_clause = Some(parse_where_document(where_document_payload)?);
79 }
80 let combined_where = match where_clause {
81 Some(where_clause) => match where_document_clause {
82 Some(where_document_clause) => Some(Where::Composite(CompositeExpression {
83 operator: crate::BooleanOperator::And,
84 children: vec![where_clause, where_document_clause],
85 })),
86 None => Some(where_clause),
87 },
88 None => where_document_clause,
89 };
90
91 Ok(combined_where)
92 }
93}
94
95pub fn parse_where_document(json_payload: &Value) -> Result<Where, WhereValidationError> {
96 let where_doc_payload = json_payload
97 .as_object()
98 .ok_or(WhereValidationError::WhereDocumentClause)?;
99 if where_doc_payload.len() != 1 {
100 return Err(WhereValidationError::WhereDocumentClause);
101 }
102 let (key, value) = where_doc_payload.iter().next().unwrap();
103 if key == "$and" {
105 let logical_operator = crate::BooleanOperator::And;
106 let children = value
108 .as_array()
109 .ok_or(WhereValidationError::WhereDocumentClause)?;
110 let mut predicate_list = vec![];
111 for child in children {
113 predicate_list.push(parse_where_document(child)?);
114 }
115 return Ok(Where::Composite(CompositeExpression {
116 operator: logical_operator,
117 children: predicate_list,
118 }));
119 }
120 if key == "$or" {
121 let logical_operator = crate::BooleanOperator::Or;
122 let children = value
124 .as_array()
125 .ok_or(WhereValidationError::WhereDocumentClause)?;
126 let mut predicate_list = vec![];
127 for child in children {
129 predicate_list.push(parse_where_document(child)?);
130 }
131 return Ok(Where::Composite(CompositeExpression {
132 operator: logical_operator,
133 children: predicate_list,
134 }));
135 }
136 if !value.is_string() {
137 return Err(WhereValidationError::WhereDocumentClause);
138 }
139 let value_str = value.as_str().unwrap();
140 let operator_type = match key.as_str() {
141 "$contains" => DocumentOperator::Contains,
142 "$not_contains" => DocumentOperator::NotContains,
143 "$regex" => DocumentOperator::Regex,
144 "$not_regex" => DocumentOperator::NotRegex,
145 _ => return Err(WhereValidationError::WhereDocumentClause),
146 };
147 if matches!(
148 operator_type,
149 DocumentOperator::Regex | DocumentOperator::NotRegex
150 ) {
151 ChromaRegex::try_from(value_str.to_string())?;
152 }
153 Ok(Where::Document(crate::DocumentExpression {
154 operator: operator_type,
155 pattern: value_str.to_string(),
156 }))
157}
158
159pub fn parse_where(json_payload: &Value) -> Result<Where, WhereValidationError> {
160 let where_payload = json_payload
161 .as_object()
162 .ok_or(WhereValidationError::WhereClause)?;
163 if where_payload.len() != 1 {
164 return Err(WhereValidationError::WhereClause);
165 }
166 let (key, value) = where_payload.iter().next().unwrap();
167 if key == "$and" {
169 let logical_operator = crate::BooleanOperator::And;
170 let children = value.as_array().ok_or(WhereValidationError::WhereClause)?;
172 let mut predicate_list = vec![];
173 for child in children {
175 predicate_list.push(parse_where(child)?);
176 }
177 return Ok(Where::Composite(CompositeExpression {
178 operator: logical_operator,
179 children: predicate_list,
180 }));
181 }
182 if key == "$or" {
183 let logical_operator = crate::BooleanOperator::Or;
184 let children = value.as_array().ok_or(WhereValidationError::WhereClause)?;
186 let mut predicate_list = vec![];
187 for child in children {
189 predicate_list.push(parse_where(child)?);
190 }
191 return Ok(Where::Composite(CompositeExpression {
192 operator: logical_operator,
193 children: predicate_list,
194 }));
195 }
196 if value.is_string() {
199 return Ok(Where::Metadata(MetadataExpression {
200 key: key.clone(),
201 comparison: crate::MetadataComparison::Primitive(
202 crate::PrimitiveOperator::Equal,
203 crate::MetadataValue::Str(value.as_str().unwrap().to_string()),
204 ),
205 }));
206 }
207 if value.is_boolean() {
208 return Ok(Where::Metadata(MetadataExpression {
209 key: key.clone(),
210 comparison: crate::MetadataComparison::Primitive(
211 crate::PrimitiveOperator::Equal,
212 crate::MetadataValue::Bool(value.as_bool().unwrap()),
213 ),
214 }));
215 }
216 if value.is_f64() {
217 return Ok(Where::Metadata(MetadataExpression {
218 key: key.clone(),
219 comparison: crate::MetadataComparison::Primitive(
220 crate::PrimitiveOperator::Equal,
221 crate::MetadataValue::Float(value.as_f64().unwrap()),
222 ),
223 }));
224 }
225 if value.is_i64() {
226 return Ok(Where::Metadata(MetadataExpression {
227 key: key.clone(),
228 comparison: crate::MetadataComparison::Primitive(
229 crate::PrimitiveOperator::Equal,
230 crate::MetadataValue::Int(value.as_i64().unwrap()),
231 ),
232 }));
233 }
234 if value.is_object() {
235 let value_obj = value.as_object().unwrap();
236 if value_obj.len() != 1 {
238 return Err(WhereValidationError::WhereClause);
239 }
240 let (operator, operand) = value_obj.iter().next().unwrap();
241 if operand.is_array() {
242 let set_operator;
243 if operator == "$in" {
244 set_operator = crate::SetOperator::In;
245 } else if operator == "$nin" {
246 set_operator = crate::SetOperator::NotIn;
247 } else {
248 return Err(WhereValidationError::WhereClause);
249 }
250 let operand = operand.as_array().unwrap();
251 if operand.is_empty() {
252 return Err(WhereValidationError::WhereClause);
253 }
254 if operand[0].is_string() {
255 let operand_str = operand
256 .iter()
257 .map(|val| {
258 val.as_str()
259 .ok_or(WhereValidationError::WhereClause)
260 .map(|s| s.to_string())
261 })
262 .collect::<Result<Vec<String>, _>>()?;
263 return Ok(Where::Metadata(MetadataExpression {
264 key: key.clone(),
265 comparison: crate::MetadataComparison::Set(
266 set_operator,
267 crate::MetadataSetValue::Str(operand_str),
268 ),
269 }));
270 }
271 if operand[0].is_boolean() {
272 let operand_bool = operand
273 .iter()
274 .map(|val| val.as_bool().ok_or(WhereValidationError::WhereClause))
275 .collect::<Result<Vec<bool>, _>>()?;
276 return Ok(Where::Metadata(MetadataExpression {
277 key: key.clone(),
278 comparison: crate::MetadataComparison::Set(
279 set_operator,
280 crate::MetadataSetValue::Bool(operand_bool),
281 ),
282 }));
283 }
284 if operand[0].is_f64() {
285 let operand_f64 = operand
286 .iter()
287 .map(|val| val.as_f64().ok_or(WhereValidationError::WhereClause))
288 .collect::<Result<Vec<f64>, _>>()?;
289 return Ok(Where::Metadata(MetadataExpression {
290 key: key.clone(),
291 comparison: crate::MetadataComparison::Set(
292 set_operator,
293 crate::MetadataSetValue::Float(operand_f64),
294 ),
295 }));
296 }
297 if operand[0].is_i64() {
298 let operand_i64 = operand
299 .iter()
300 .map(|val| val.as_i64().ok_or(WhereValidationError::WhereClause))
301 .collect::<Result<Vec<i64>, _>>()?;
302 return Ok(Where::Metadata(MetadataExpression {
303 key: key.clone(),
304 comparison: crate::MetadataComparison::Set(
305 set_operator,
306 crate::MetadataSetValue::Int(operand_i64),
307 ),
308 }));
309 }
310 return Err(WhereValidationError::WhereClause);
311 }
312 if operand.is_string() {
313 let operand_str = operand.as_str().unwrap();
314 let document_operator_type = match operator.as_str() {
315 "$contains" => Some(DocumentOperator::Contains),
316 "$not_contains" => Some(DocumentOperator::NotContains),
317 "$regex" => Some(DocumentOperator::Regex),
318 "$not_regex" => Some(DocumentOperator::NotRegex),
319 _ => None,
320 };
321 if let Some(doc_op) = document_operator_type {
322 if matches!(doc_op, DocumentOperator::Regex | DocumentOperator::NotRegex) {
323 ChromaRegex::try_from(operand_str.to_string())?;
324 }
325 return Ok(Where::Document(crate::DocumentExpression {
326 operator: doc_op,
327 pattern: operand_str.to_string(),
328 }));
329 }
330 let operator_type;
331 if operator == "$eq" {
332 operator_type = PrimitiveOperator::Equal;
333 } else if operator == "$ne" {
334 operator_type = PrimitiveOperator::NotEqual;
335 } else {
336 return Err(WhereValidationError::WhereClause);
337 }
338 return Ok(Where::Metadata(MetadataExpression {
339 key: key.clone(),
340 comparison: crate::MetadataComparison::Primitive(
341 operator_type,
342 crate::MetadataValue::Str(operand_str.to_string()),
343 ),
344 }));
345 }
346 if operand.is_boolean() {
347 let operand_bool = operand.as_bool().unwrap();
348 let operator_type;
349 if operator == "$eq" {
350 operator_type = PrimitiveOperator::Equal;
351 } else if operator == "$ne" {
352 operator_type = PrimitiveOperator::NotEqual;
353 } else {
354 return Err(WhereValidationError::WhereClause);
355 }
356 return Ok(Where::Metadata(MetadataExpression {
357 key: key.clone(),
358 comparison: crate::MetadataComparison::Primitive(
359 operator_type,
360 crate::MetadataValue::Bool(operand_bool),
361 ),
362 }));
363 }
364 if operand.is_f64() {
365 let operand_f64 = operand.as_f64().unwrap();
366 let operator_type;
367 if operator == "$eq" {
368 operator_type = PrimitiveOperator::Equal;
369 } else if operator == "$ne" {
370 operator_type = PrimitiveOperator::NotEqual;
371 } else if operator == "$lt" {
372 operator_type = PrimitiveOperator::LessThan;
373 } else if operator == "$lte" {
374 operator_type = PrimitiveOperator::LessThanOrEqual;
375 } else if operator == "$gt" {
376 operator_type = PrimitiveOperator::GreaterThan;
377 } else if operator == "$gte" {
378 operator_type = PrimitiveOperator::GreaterThanOrEqual;
379 } else {
380 return Err(WhereValidationError::WhereClause);
381 }
382 return Ok(Where::Metadata(MetadataExpression {
383 key: key.clone(),
384 comparison: crate::MetadataComparison::Primitive(
385 operator_type,
386 crate::MetadataValue::Float(operand_f64),
387 ),
388 }));
389 }
390 if operand.is_i64() {
391 let operand_i64 = operand.as_i64().unwrap();
392 let operator_type;
393 if operator == "$eq" {
394 operator_type = PrimitiveOperator::Equal;
395 } else if operator == "$ne" {
396 operator_type = PrimitiveOperator::NotEqual;
397 } else if operator == "$lt" {
398 operator_type = PrimitiveOperator::LessThan;
399 } else if operator == "$lte" {
400 operator_type = PrimitiveOperator::LessThanOrEqual;
401 } else if operator == "$gt" {
402 operator_type = PrimitiveOperator::GreaterThan;
403 } else if operator == "$gte" {
404 operator_type = PrimitiveOperator::GreaterThanOrEqual;
405 } else {
406 return Err(WhereValidationError::WhereClause);
407 }
408 return Ok(Where::Metadata(MetadataExpression {
409 key: key.clone(),
410 comparison: crate::MetadataComparison::Primitive(
411 operator_type,
412 crate::MetadataValue::Int(operand_i64),
413 ),
414 }));
415 }
416 return Err(WhereValidationError::WhereClause);
417 }
418 Err(WhereValidationError::WhereClause)
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an expected metadata expression.
    fn metadata(key: &str, comparison: crate::MetadataComparison) -> Where {
        Where::Metadata(MetadataExpression {
            key: key.to_string(),
            comparison,
        })
    }

    /// Builds an expected document expression.
    fn document(operator: DocumentOperator, pattern: &str) -> Where {
        Where::Document(crate::DocumentExpression {
            operator,
            pattern: pattern.to_string(),
        })
    }

    #[test]
    fn test_parse_where_direct_eq() {
        let payload = json!({
            "key1": "value1"
        });
        let expected_result = Where::Metadata(MetadataExpression {
            key: "key1".to_string(),
            comparison: crate::MetadataComparison::Primitive(
                PrimitiveOperator::Equal,
                crate::MetadataValue::Str("value1".to_string()),
            ),
        });

        let result = parse_where(&payload).expect("This clause to parse successfully");
        assert_eq!(result, expected_result);
    }

    #[test]
    fn test_parse_where_document() {
        let payloads = [
            json!({
                "$and": [
                    {"$contains": "value1"},
                    {"$or": [
                        {"$contains": "value2"},
                        {"$contains": "value3"}
                    ]}
                ]
            }),
            json!({
                "$not_contains": "value1",
            }),
        ];

        let expected_results = [
            Where::Composite(CompositeExpression {
                operator: crate::BooleanOperator::And,
                children: vec![
                    Where::Document(crate::DocumentExpression {
                        operator: DocumentOperator::Contains,
                        pattern: "value1".to_string(),
                    }),
                    Where::Composite(CompositeExpression {
                        operator: crate::BooleanOperator::Or,
                        children: vec![
                            Where::Document(crate::DocumentExpression {
                                operator: DocumentOperator::Contains,
                                pattern: "value2".to_string(),
                            }),
                            Where::Document(crate::DocumentExpression {
                                operator: DocumentOperator::Contains,
                                pattern: "value3".to_string(),
                            }),
                        ],
                    }),
                ],
            }),
            Where::Document(crate::DocumentExpression {
                operator: DocumentOperator::NotContains,
                pattern: "value1".to_string(),
            }),
        ];

        for (payload, expected_result) in payloads.iter().zip(expected_results.iter()) {
            let result = parse_where_document(payload);
            assert!(
                result.is_ok(),
                "Parsing failed for payload: {}: {:?}",
                serde_json::to_string_pretty(payload).unwrap(),
                result
            );
            assert_eq!(
                result.unwrap(),
                *expected_result,
                "Parsed result did not match expected result: {}",
                serde_json::to_string_pretty(payload).unwrap(),
            );
        }
    }

    #[test]
    fn test_parse_where_document_rejects_malformed_clauses() {
        let payloads = [
            json!(null),
            json!({}),
            json!({"$contains": 1}),
            json!({"$contains": "value1", "$not_contains": "value2"}),
            json!({"$unknown": "value1"}),
            json!({"$or": "value1"}),
            json!({"$and": [{"$contains": 1}]}),
        ];

        for payload in payloads.iter() {
            assert!(
                matches!(
                    parse_where_document(payload),
                    Err(WhereValidationError::WhereDocumentClause)
                ),
                "Expected payload to be rejected: {}",
                payload
            );
        }
    }

    #[test]
    fn test_parse_where() {
        let payloads = [
            json!({
                "key1": {"$in": ["value1", "value2", "value3"]}
            }),
            json!({
                "key1": {"$nin": ["value1", "value2", "value3"]}
            }),
            json!({
                "key1": {"$eq": "value1"}
            }),
            json!({
                "key1": {"$ne": "value1"}
            }),
        ];

        let expected_results = [
            Where::Metadata(MetadataExpression {
                key: "key1".to_string(),
                comparison: crate::MetadataComparison::Set(
                    crate::SetOperator::In,
                    crate::MetadataSetValue::Str(vec![
                        "value1".to_string(),
                        "value2".to_string(),
                        "value3".to_string(),
                    ]),
                ),
            }),
            Where::Metadata(MetadataExpression {
                key: "key1".to_string(),
                comparison: crate::MetadataComparison::Set(
                    crate::SetOperator::NotIn,
                    crate::MetadataSetValue::Str(vec![
                        "value1".to_string(),
                        "value2".to_string(),
                        "value3".to_string(),
                    ]),
                ),
            }),
            Where::Metadata(MetadataExpression {
                key: "key1".to_string(),
                comparison: crate::MetadataComparison::Primitive(
                    PrimitiveOperator::Equal,
                    crate::MetadataValue::Str("value1".to_string()),
                ),
            }),
            Where::Metadata(MetadataExpression {
                key: "key1".to_string(),
                comparison: crate::MetadataComparison::Primitive(
                    PrimitiveOperator::NotEqual,
                    crate::MetadataValue::Str("value1".to_string()),
                ),
            }),
        ];

        for (payload, expected_result) in payloads.iter().zip(expected_results.iter()) {
            let result = parse_where(payload);
            assert!(
                result.is_ok(),
                "Parsing failed for payload: {}: {:?}",
                serde_json::to_string_pretty(payload).unwrap(),
                result
            );
            assert_eq!(
                result.unwrap(),
                *expected_result,
                "Parsed result did not match expected result: {}",
                serde_json::to_string_pretty(payload).unwrap(),
            );
        }
    }

    #[test]
    fn test_parse_where_scalar_and_numeric_operands() {
        let cases = [
            (
                json!({"key1": true}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::Equal,
                        crate::MetadataValue::Bool(true),
                    ),
                ),
            ),
            (
                json!({"key1": 3}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::Equal,
                        crate::MetadataValue::Int(3),
                    ),
                ),
            ),
            (
                json!({"key1": 2.5}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::Equal,
                        crate::MetadataValue::Float(2.5),
                    ),
                ),
            ),
            (
                json!({"key1": {"$ne": false}}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::NotEqual,
                        crate::MetadataValue::Bool(false),
                    ),
                ),
            ),
            (
                json!({"key1": {"$gt": 5}}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::GreaterThan,
                        crate::MetadataValue::Int(5),
                    ),
                ),
            ),
            (
                json!({"key1": {"$lte": 1.5}}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::LessThanOrEqual,
                        crate::MetadataValue::Float(1.5),
                    ),
                ),
            ),
            (
                json!({"key1": {"$in": [1, 2]}}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Set(
                        crate::SetOperator::In,
                        crate::MetadataSetValue::Int(vec![1, 2]),
                    ),
                ),
            ),
            (
                json!({"key1": {"$nin": [true]}}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Set(
                        crate::SetOperator::NotIn,
                        crate::MetadataSetValue::Bool(vec![true]),
                    ),
                ),
            ),
        ];

        for (payload, expected_result) in cases.iter() {
            let result = parse_where(payload).unwrap_or_else(|err| {
                panic!("Parsing failed for payload {}: {:?}", payload, err)
            });
            assert_eq!(
                result, *expected_result,
                "Unexpected result for payload {}",
                payload
            );
        }
    }

    #[test]
    fn test_parse_where_composite() {
        let payload = json!({
            "$or": [
                {"key1": 1},
                {"key2": {"$ne": "value2"}}
            ]
        });
        let expected_result = Where::Composite(CompositeExpression {
            operator: crate::BooleanOperator::Or,
            children: vec![
                metadata(
                    "key1",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::Equal,
                        crate::MetadataValue::Int(1),
                    ),
                ),
                metadata(
                    "key2",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::NotEqual,
                        crate::MetadataValue::Str("value2".to_string()),
                    ),
                ),
            ],
        });

        let result = parse_where(&payload).expect("This clause to parse successfully");
        assert_eq!(result, expected_result);
    }

    #[test]
    fn test_parse_where_rejects_malformed_clauses() {
        let payloads = [
            json!("key1"),
            json!({}),
            json!({"key1": "value1", "key2": "value2"}),
            json!({"key1": null}),
            json!({"key1": ["value1"]}),
            json!({"key1": {"$lt": "value1"}}),
            json!({"key1": {"$gt": true}}),
            json!({"key1": {"$eq": 1, "$ne": 2}}),
            json!({"key1": {"$in": []}}),
            json!({"key1": {"$in": ["value1", 1]}}),
            json!({"key1": {"$in": "value1"}}),
            json!({"$and": {"key1": "value1"}}),
        ];

        for payload in payloads.iter() {
            assert!(
                matches!(parse_where(payload), Err(WhereValidationError::WhereClause)),
                "Expected payload to be rejected: {}",
                payload
            );
        }
    }

    #[test]
    fn test_raw_where_fields_parse() {
        let metadata_clause = metadata(
            "key1",
            crate::MetadataComparison::Primitive(
                PrimitiveOperator::Equal,
                crate::MetadataValue::Str("value1".to_string()),
            ),
        );
        let document_clause = document(DocumentOperator::Contains, "value2");

        let combined = RawWhereFields::new(json!({"key1": "value1"}), json!({"$contains": "value2"}))
            .parse()
            .expect("Both clauses to parse successfully");
        assert_eq!(
            combined,
            Some(Where::Composite(CompositeExpression {
                operator: crate::BooleanOperator::And,
                children: vec![metadata_clause.clone(), document_clause.clone()],
            }))
        );

        let where_only = RawWhereFields::new(json!({"key1": "value1"}), Value::Null)
            .parse()
            .expect("The where clause to parse successfully");
        assert_eq!(where_only, Some(metadata_clause));

        let document_only = RawWhereFields::new(Value::Null, json!({"$contains": "value2"}))
            .parse()
            .expect("The where document clause to parse successfully");
        assert_eq!(document_only, Some(document_clause));

        let empty = RawWhereFields::new(Value::Null, Value::Null)
            .parse()
            .expect("Absent clauses to parse successfully");
        assert_eq!(empty, None);
    }

    #[test]
    fn test_raw_where_fields_from_json_str() {
        assert!(matches!(
            RawWhereFields::from_json_str(Some("{"), None),
            Err(WhereValidationError::WhereClause)
        ));
        assert!(matches!(
            RawWhereFields::from_json_str(None, Some("{")),
            Err(WhereValidationError::WhereDocumentClause)
        ));

        let parsed = RawWhereFields::from_json_str(Some(r#"{"key1": 1}"#), None)
            .expect("Valid JSON to be accepted")
            .parse()
            .expect("This clause to parse successfully");
        assert_eq!(
            parsed,
            Some(metadata(
                "key1",
                crate::MetadataComparison::Primitive(
                    PrimitiveOperator::Equal,
                    crate::MetadataValue::Int(1),
                ),
            ))
        );
    }
}