1use crate::regex::{ChromaRegex, ChromaRegexError};
2use crate::{
3 CompositeExpression, ContainsOperator, DocumentOperator, MetadataExpression, PrimitiveOperator,
4 Where,
5};
6use chroma_error::{ChromaError, ErrorCodes};
7use serde::Deserialize;
8use serde::Serialize;
9use serde_json::Value;
10use thiserror::Error;
11
// Raw, not-yet-validated filter payloads kept in their JSON form.
//
// Both fields are marked `#[serde(default)]`, so a field missing from the
// input deserializes to `Value::Null`; `RawWhereFields::parse` treats a null
// field as "no filter of this kind".
//
// NOTE(review): plain `//` comments are used on purpose — derive macros such
// as `utoipa::ToSchema` may read `///` doc comments into generated output.
// Confirm before converting these to doc comments.
#[derive(Default, Deserialize, Debug, Clone, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct RawWhereFields {
    // JSON for the `where` filter; validated by `parse_where`.
    #[serde(default)]
    pub r#where: Value,
    // JSON for the `where_document` filter; validated by `parse_where_document`.
    #[serde(default)]
    pub where_document: Value,
}
20
21impl RawWhereFields {
22 pub fn new(r#where: Value, where_document: Value) -> Self {
23 Self {
24 r#where,
25 where_document,
26 }
27 }
28
29 pub fn from_json_str(
30 r#where: Option<&str>,
31 where_document: Option<&str>,
32 ) -> Result<Self, WhereValidationError> {
33 let r#where = r#where
34 .map(|r#where| {
35 serde_json::from_str(r#where).map_err(|_| WhereValidationError::WhereClause)
36 })
37 .transpose()?
38 .unwrap_or(Value::Null);
39
40 let where_document = where_document
41 .map(|where_document| {
42 serde_json::from_str(where_document)
43 .map_err(|_| WhereValidationError::WhereDocumentClause)
44 })
45 .transpose()?
46 .unwrap_or(Value::Null);
47
48 Ok(Self {
49 r#where,
50 where_document,
51 })
52 }
53}
54
/// Reasons a `where` / `where_document` payload can be rejected.
#[derive(Error, Debug)]
pub enum WhereValidationError {
    /// A `$regex` / `$not_regex` pattern was rejected by [`ChromaRegex`].
    /// The message is forwarded unchanged from the wrapped error.
    #[error(transparent)]
    Regex(#[from] ChromaRegexError),
    /// The `where` payload is not valid JSON or does not have a supported shape.
    #[error("Invalid where clause")]
    WhereClause,
    /// The `where_document` payload is not valid JSON or does not have a
    /// supported shape.
    #[error("Invalid where document clause")]
    WhereDocumentClause,
}
64
65impl ChromaError for WhereValidationError {
66 fn code(&self) -> chroma_error::ErrorCodes {
67 ErrorCodes::InvalidArgument
68 }
69}
70
71impl RawWhereFields {
72 pub fn parse(self) -> Result<Option<Where>, WhereValidationError> {
73 let mut where_clause = None;
74 if !self.r#where.is_null() {
75 let where_payload = &self.r#where;
76 where_clause = Some(parse_where(where_payload)?);
77 }
78 let mut where_document_clause = None;
79 if !self.where_document.is_null() {
80 let where_document_payload = &self.where_document;
81 where_document_clause = Some(parse_where_document(where_document_payload)?);
82 }
83 let combined_where = match where_clause {
84 Some(where_clause) => match where_document_clause {
85 Some(where_document_clause) => Some(Where::Composite(CompositeExpression {
86 operator: crate::BooleanOperator::And,
87 children: vec![where_clause, where_document_clause],
88 })),
89 None => Some(where_clause),
90 },
91 None => where_document_clause,
92 };
93
94 Ok(combined_where)
95 }
96}
97
98pub fn parse_where_document(json_payload: &Value) -> Result<Where, WhereValidationError> {
99 let where_doc_payload = json_payload
100 .as_object()
101 .ok_or(WhereValidationError::WhereDocumentClause)?;
102 if where_doc_payload.len() != 1 {
103 return Err(WhereValidationError::WhereDocumentClause);
104 }
105 let (key, value) = where_doc_payload.iter().next().unwrap();
106 if key == "$and" {
108 let logical_operator = crate::BooleanOperator::And;
109 let children = value
111 .as_array()
112 .ok_or(WhereValidationError::WhereDocumentClause)?;
113 let mut predicate_list = vec![];
114 for child in children {
116 predicate_list.push(parse_where_document(child)?);
117 }
118 return Ok(Where::Composite(CompositeExpression {
119 operator: logical_operator,
120 children: predicate_list,
121 }));
122 }
123 if key == "$or" {
124 let logical_operator = crate::BooleanOperator::Or;
125 let children = value
127 .as_array()
128 .ok_or(WhereValidationError::WhereDocumentClause)?;
129 let mut predicate_list = vec![];
130 for child in children {
132 predicate_list.push(parse_where_document(child)?);
133 }
134 return Ok(Where::Composite(CompositeExpression {
135 operator: logical_operator,
136 children: predicate_list,
137 }));
138 }
139 if !value.is_string() {
140 return Err(WhereValidationError::WhereDocumentClause);
141 }
142 let value_str = value.as_str().unwrap();
143 let operator_type = match key.as_str() {
144 "$contains" => DocumentOperator::Contains,
145 "$not_contains" => DocumentOperator::NotContains,
146 "$regex" => DocumentOperator::Regex,
147 "$not_regex" => DocumentOperator::NotRegex,
148 _ => return Err(WhereValidationError::WhereDocumentClause),
149 };
150 if matches!(
151 operator_type,
152 DocumentOperator::Regex | DocumentOperator::NotRegex
153 ) {
154 ChromaRegex::try_from(value_str.to_string())?;
155 }
156 Ok(Where::Document(crate::DocumentExpression {
157 operator: operator_type,
158 pattern: value_str.to_string(),
159 }))
160}
161
162fn parse_contains_operator(operator: &str) -> Option<ContainsOperator> {
165 match operator {
166 "$contains" => Some(ContainsOperator::Contains),
167 "$not_contains" => Some(ContainsOperator::NotContains),
168 _ => None,
169 }
170}
171
172pub fn parse_where(json_payload: &Value) -> Result<Where, WhereValidationError> {
173 let where_payload = json_payload
174 .as_object()
175 .ok_or(WhereValidationError::WhereClause)?;
176 if where_payload.len() != 1 {
177 return Err(WhereValidationError::WhereClause);
178 }
179 let (key, value) = where_payload.iter().next().unwrap();
180 if key == "$and" {
182 let logical_operator = crate::BooleanOperator::And;
183 let children = value.as_array().ok_or(WhereValidationError::WhereClause)?;
185 let mut predicate_list = vec![];
186 for child in children {
188 predicate_list.push(parse_where(child)?);
189 }
190 return Ok(Where::Composite(CompositeExpression {
191 operator: logical_operator,
192 children: predicate_list,
193 }));
194 }
195 if key == "$or" {
196 let logical_operator = crate::BooleanOperator::Or;
197 let children = value.as_array().ok_or(WhereValidationError::WhereClause)?;
199 let mut predicate_list = vec![];
200 for child in children {
202 predicate_list.push(parse_where(child)?);
203 }
204 return Ok(Where::Composite(CompositeExpression {
205 operator: logical_operator,
206 children: predicate_list,
207 }));
208 }
209 if key.starts_with('$') {
213 return Err(WhereValidationError::WhereClause);
214 }
215 if value.is_string() {
218 return Ok(Where::Metadata(MetadataExpression {
219 key: key.clone(),
220 comparison: crate::MetadataComparison::Primitive(
221 crate::PrimitiveOperator::Equal,
222 crate::MetadataValue::Str(value.as_str().unwrap().to_string()),
223 ),
224 }));
225 }
226 if value.is_boolean() {
227 return Ok(Where::Metadata(MetadataExpression {
228 key: key.clone(),
229 comparison: crate::MetadataComparison::Primitive(
230 crate::PrimitiveOperator::Equal,
231 crate::MetadataValue::Bool(value.as_bool().unwrap()),
232 ),
233 }));
234 }
235 if value.is_f64() {
236 return Ok(Where::Metadata(MetadataExpression {
237 key: key.clone(),
238 comparison: crate::MetadataComparison::Primitive(
239 crate::PrimitiveOperator::Equal,
240 crate::MetadataValue::Float(value.as_f64().unwrap()),
241 ),
242 }));
243 }
244 if value.is_i64() {
245 return Ok(Where::Metadata(MetadataExpression {
246 key: key.clone(),
247 comparison: crate::MetadataComparison::Primitive(
248 crate::PrimitiveOperator::Equal,
249 crate::MetadataValue::Int(value.as_i64().unwrap()),
250 ),
251 }));
252 }
253 if value.is_object() {
254 let value_obj = value.as_object().unwrap();
255 if value_obj.len() != 1 {
257 return Err(WhereValidationError::WhereClause);
258 }
259 let (operator, operand) = value_obj.iter().next().unwrap();
260 if operand.is_array() {
261 let set_operator;
262 if operator == "$in" {
263 set_operator = crate::SetOperator::In;
264 } else if operator == "$nin" {
265 set_operator = crate::SetOperator::NotIn;
266 } else {
267 return Err(WhereValidationError::WhereClause);
268 }
269 let operand = operand.as_array().unwrap();
270 if operand.is_empty() {
271 return Err(WhereValidationError::WhereClause);
272 }
273 if operand[0].is_string() {
274 let operand_str = operand
275 .iter()
276 .map(|val| {
277 val.as_str()
278 .ok_or(WhereValidationError::WhereClause)
279 .map(|s| s.to_string())
280 })
281 .collect::<Result<Vec<String>, _>>()?;
282 return Ok(Where::Metadata(MetadataExpression {
283 key: key.clone(),
284 comparison: crate::MetadataComparison::Set(
285 set_operator,
286 crate::MetadataSetValue::Str(operand_str),
287 ),
288 }));
289 }
290 if operand[0].is_boolean() {
291 let operand_bool = operand
292 .iter()
293 .map(|val| val.as_bool().ok_or(WhereValidationError::WhereClause))
294 .collect::<Result<Vec<bool>, _>>()?;
295 return Ok(Where::Metadata(MetadataExpression {
296 key: key.clone(),
297 comparison: crate::MetadataComparison::Set(
298 set_operator,
299 crate::MetadataSetValue::Bool(operand_bool),
300 ),
301 }));
302 }
303 if operand[0].is_f64() {
304 let operand_f64 = operand
305 .iter()
306 .map(|val| val.as_f64().ok_or(WhereValidationError::WhereClause))
307 .collect::<Result<Vec<f64>, _>>()?;
308 return Ok(Where::Metadata(MetadataExpression {
309 key: key.clone(),
310 comparison: crate::MetadataComparison::Set(
311 set_operator,
312 crate::MetadataSetValue::Float(operand_f64),
313 ),
314 }));
315 }
316 if operand[0].is_i64() {
317 let operand_i64 = operand
318 .iter()
319 .map(|val| val.as_i64().ok_or(WhereValidationError::WhereClause))
320 .collect::<Result<Vec<i64>, _>>()?;
321 return Ok(Where::Metadata(MetadataExpression {
322 key: key.clone(),
323 comparison: crate::MetadataComparison::Set(
324 set_operator,
325 crate::MetadataSetValue::Int(operand_i64),
326 ),
327 }));
328 }
329 return Err(WhereValidationError::WhereClause);
330 }
331 if operand.is_string() {
332 let operand_str = operand.as_str().unwrap();
333 if operator == "$contains" || operator == "$not_contains" {
337 if key == "#document" {
338 let doc_op = if operator == "$contains" {
339 DocumentOperator::Contains
340 } else {
341 DocumentOperator::NotContains
342 };
343 return Ok(Where::Document(crate::DocumentExpression {
344 operator: doc_op,
345 pattern: operand_str.to_string(),
346 }));
347 }
348 let contains_op = if operator == "$contains" {
349 ContainsOperator::Contains
350 } else {
351 ContainsOperator::NotContains
352 };
353 return Ok(Where::Metadata(MetadataExpression {
354 key: key.clone(),
355 comparison: crate::MetadataComparison::ArrayContains(
356 contains_op,
357 crate::MetadataValue::Str(operand_str.to_string()),
358 ),
359 }));
360 }
361 if operator == "$regex" || operator == "$not_regex" {
362 if key != "#document" {
364 return Err(WhereValidationError::WhereClause);
365 }
366 ChromaRegex::try_from(operand_str.to_string())?;
367 let doc_op = if operator == "$regex" {
368 DocumentOperator::Regex
369 } else {
370 DocumentOperator::NotRegex
371 };
372 return Ok(Where::Document(crate::DocumentExpression {
373 operator: doc_op,
374 pattern: operand_str.to_string(),
375 }));
376 }
377 let operator_type;
378 if operator == "$eq" {
379 operator_type = PrimitiveOperator::Equal;
380 } else if operator == "$ne" {
381 operator_type = PrimitiveOperator::NotEqual;
382 } else {
383 return Err(WhereValidationError::WhereClause);
384 }
385 return Ok(Where::Metadata(MetadataExpression {
386 key: key.clone(),
387 comparison: crate::MetadataComparison::Primitive(
388 operator_type,
389 crate::MetadataValue::Str(operand_str.to_string()),
390 ),
391 }));
392 }
393 if operand.is_boolean() {
394 let operand_bool = operand.as_bool().unwrap();
395 if let Some(contains_op) = parse_contains_operator(operator) {
396 if key == "#document" {
398 return Err(WhereValidationError::WhereClause);
399 }
400 return Ok(Where::Metadata(MetadataExpression {
401 key: key.clone(),
402 comparison: crate::MetadataComparison::ArrayContains(
403 contains_op,
404 crate::MetadataValue::Bool(operand_bool),
405 ),
406 }));
407 }
408 let operator_type;
409 if operator == "$eq" {
410 operator_type = PrimitiveOperator::Equal;
411 } else if operator == "$ne" {
412 operator_type = PrimitiveOperator::NotEqual;
413 } else {
414 return Err(WhereValidationError::WhereClause);
415 }
416 return Ok(Where::Metadata(MetadataExpression {
417 key: key.clone(),
418 comparison: crate::MetadataComparison::Primitive(
419 operator_type,
420 crate::MetadataValue::Bool(operand_bool),
421 ),
422 }));
423 }
424 if operand.is_f64() {
425 let operand_f64 = operand.as_f64().unwrap();
426 if let Some(contains_op) = parse_contains_operator(operator) {
427 if key == "#document" {
429 return Err(WhereValidationError::WhereClause);
430 }
431 return Ok(Where::Metadata(MetadataExpression {
432 key: key.clone(),
433 comparison: crate::MetadataComparison::ArrayContains(
434 contains_op,
435 crate::MetadataValue::Float(operand_f64),
436 ),
437 }));
438 }
439 let operator_type;
440 if operator == "$eq" {
441 operator_type = PrimitiveOperator::Equal;
442 } else if operator == "$ne" {
443 operator_type = PrimitiveOperator::NotEqual;
444 } else if operator == "$lt" {
445 operator_type = PrimitiveOperator::LessThan;
446 } else if operator == "$lte" {
447 operator_type = PrimitiveOperator::LessThanOrEqual;
448 } else if operator == "$gt" {
449 operator_type = PrimitiveOperator::GreaterThan;
450 } else if operator == "$gte" {
451 operator_type = PrimitiveOperator::GreaterThanOrEqual;
452 } else {
453 return Err(WhereValidationError::WhereClause);
454 }
455 return Ok(Where::Metadata(MetadataExpression {
456 key: key.clone(),
457 comparison: crate::MetadataComparison::Primitive(
458 operator_type,
459 crate::MetadataValue::Float(operand_f64),
460 ),
461 }));
462 }
463 if operand.is_i64() {
464 let operand_i64 = operand.as_i64().unwrap();
465 if let Some(contains_op) = parse_contains_operator(operator) {
466 if key == "#document" {
468 return Err(WhereValidationError::WhereClause);
469 }
470 return Ok(Where::Metadata(MetadataExpression {
471 key: key.clone(),
472 comparison: crate::MetadataComparison::ArrayContains(
473 contains_op,
474 crate::MetadataValue::Int(operand_i64),
475 ),
476 }));
477 }
478 let operator_type;
479 if operator == "$eq" {
480 operator_type = PrimitiveOperator::Equal;
481 } else if operator == "$ne" {
482 operator_type = PrimitiveOperator::NotEqual;
483 } else if operator == "$lt" {
484 operator_type = PrimitiveOperator::LessThan;
485 } else if operator == "$lte" {
486 operator_type = PrimitiveOperator::LessThanOrEqual;
487 } else if operator == "$gt" {
488 operator_type = PrimitiveOperator::GreaterThan;
489 } else if operator == "$gte" {
490 operator_type = PrimitiveOperator::GreaterThanOrEqual;
491 } else {
492 return Err(WhereValidationError::WhereClause);
493 }
494 return Ok(Where::Metadata(MetadataExpression {
495 key: key.clone(),
496 comparison: crate::MetadataComparison::Primitive(
497 operator_type,
498 crate::MetadataValue::Int(operand_i64),
499 ),
500 }));
501 }
502 return Err(WhereValidationError::WhereClause);
503 }
504 Err(WhereValidationError::WhereClause)
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    type Parser = fn(&Value) -> Result<Where, WhereValidationError>;

    /// Shorthand for a metadata predicate on `key`.
    fn metadata(key: &str, comparison: crate::MetadataComparison) -> Where {
        Where::Metadata(MetadataExpression {
            key: key.to_string(),
            comparison,
        })
    }

    /// Shorthand for a document predicate.
    fn document(operator: DocumentOperator, pattern: &str) -> Where {
        Where::Document(crate::DocumentExpression {
            operator,
            pattern: pattern.to_string(),
        })
    }

    /// Shorthand for an owned string list.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Asserts that every payload parses to its paired expectation.
    #[track_caller]
    fn assert_parses_to(parse: Parser, cases: &[(Value, Where)]) {
        for (payload, expected) in cases {
            let rendered = serde_json::to_string_pretty(payload).unwrap();
            let parsed = parse(payload);
            assert!(
                parsed.is_ok(),
                "Parsing failed for payload: {}: {:?}",
                rendered,
                parsed
            );
            assert_eq!(
                parsed.unwrap(),
                *expected,
                "Parsed result did not match expected result: {}",
                rendered,
            );
        }
    }

    /// Asserts that `parse_where` rejects every payload.
    #[track_caller]
    fn assert_where_rejected(payloads: &[Value], reason: &str) {
        for payload in payloads {
            assert!(
                parse_where(payload).is_err(),
                "Expected error for {}, but got Ok for: {}",
                reason,
                serde_json::to_string_pretty(payload).unwrap(),
            );
        }
    }

    #[test]
    fn test_parse_where_direct_eq() {
        let payload = json!({"key1": "value1"});
        let parsed = parse_where(&payload).expect("This clause to parse successfully");
        assert_eq!(
            parsed,
            metadata(
                "key1",
                crate::MetadataComparison::Primitive(
                    PrimitiveOperator::Equal,
                    crate::MetadataValue::Str("value1".to_string()),
                ),
            )
        );
    }

    #[test]
    fn test_parse_where_document() {
        let cases = [
            (
                json!({
                    "$and": [
                        {"$contains": "value1"},
                        {"$or": [
                            {"$contains": "value2"},
                            {"$contains": "value3"}
                        ]}
                    ]
                }),
                Where::Composite(CompositeExpression {
                    operator: crate::BooleanOperator::And,
                    children: vec![
                        document(DocumentOperator::Contains, "value1"),
                        Where::Composite(CompositeExpression {
                            operator: crate::BooleanOperator::Or,
                            children: vec![
                                document(DocumentOperator::Contains, "value2"),
                                document(DocumentOperator::Contains, "value3"),
                            ],
                        }),
                    ],
                }),
            ),
            (
                json!({
                    "$not_contains": "value1",
                }),
                document(DocumentOperator::NotContains, "value1"),
            ),
        ];
        assert_parses_to(parse_where_document, &cases);
    }

    #[test]
    fn test_parse_where() {
        let values = ["value1", "value2", "value3"];
        let cases = [
            (
                json!({"key1": {"$in": ["value1", "value2", "value3"]}}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Set(
                        crate::SetOperator::In,
                        crate::MetadataSetValue::Str(strings(&values)),
                    ),
                ),
            ),
            (
                json!({"key1": {"$nin": ["value1", "value2", "value3"]}}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Set(
                        crate::SetOperator::NotIn,
                        crate::MetadataSetValue::Str(strings(&values)),
                    ),
                ),
            ),
            (
                json!({"key1": {"$eq": "value1"}}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::Equal,
                        crate::MetadataValue::Str("value1".to_string()),
                    ),
                ),
            ),
            (
                json!({"key1": {"$ne": "value1"}}),
                metadata(
                    "key1",
                    crate::MetadataComparison::Primitive(
                        PrimitiveOperator::NotEqual,
                        crate::MetadataValue::Str("value1".to_string()),
                    ),
                ),
            ),
        ];
        assert_parses_to(parse_where, &cases);
    }

    #[test]
    fn test_parse_where_contains_metadata() {
        let cases = [
            (
                json!({"tags": {"$contains": "action"}}),
                metadata(
                    "tags",
                    crate::MetadataComparison::ArrayContains(
                        ContainsOperator::Contains,
                        crate::MetadataValue::Str("action".to_string()),
                    ),
                ),
            ),
            (
                json!({"tags": {"$not_contains": "comedy"}}),
                metadata(
                    "tags",
                    crate::MetadataComparison::ArrayContains(
                        ContainsOperator::NotContains,
                        crate::MetadataValue::Str("comedy".to_string()),
                    ),
                ),
            ),
            (
                json!({"scores": {"$contains": 42}}),
                metadata(
                    "scores",
                    crate::MetadataComparison::ArrayContains(
                        ContainsOperator::Contains,
                        crate::MetadataValue::Int(42),
                    ),
                ),
            ),
            (
                json!({"ratings": {"$contains": 4.5}}),
                metadata(
                    "ratings",
                    crate::MetadataComparison::ArrayContains(
                        ContainsOperator::Contains,
                        crate::MetadataValue::Float(4.5),
                    ),
                ),
            ),
            (
                json!({"flags": {"$contains": true}}),
                metadata(
                    "flags",
                    crate::MetadataComparison::ArrayContains(
                        ContainsOperator::Contains,
                        crate::MetadataValue::Bool(true),
                    ),
                ),
            ),
        ];
        assert_parses_to(parse_where, &cases);
    }

    #[test]
    fn test_parse_where_document_contains_in_where() {
        let payload = json!({"#document": {"$contains": "search term"}});
        let parsed = parse_where(&payload).expect("Should parse successfully");
        assert_eq!(parsed, document(DocumentOperator::Contains, "search term"));
    }

    #[test]
    fn test_parse_where_regex_only_on_document() {
        let payload = json!({"#document": {"$regex": "act.*"}});
        let parsed = parse_where(&payload).expect("Should parse successfully");
        assert_eq!(parsed, document(DocumentOperator::Regex, "act.*"));

        let payload = json!({"#document": {"$not_regex": "draft.*"}});
        let parsed = parse_where(&payload).expect("Should parse successfully");
        assert_eq!(parsed, document(DocumentOperator::NotRegex, "draft.*"));

        // Regex operators are not accepted on ordinary metadata keys.
        let payload = json!({"tags": {"$regex": "act.*"}});
        assert!(parse_where(&payload).is_err());

        let payload = json!({"tags": {"$not_regex": "draft.*"}});
        assert!(parse_where(&payload).is_err());
    }

    #[test]
    fn test_where_contains_round_trip() {
        let original = metadata(
            "tags",
            crate::MetadataComparison::ArrayContains(
                ContainsOperator::Contains,
                crate::MetadataValue::Str("action".to_string()),
            ),
        );
        let json_str = serde_json::to_string(&original).unwrap();
        let json_value: Value = serde_json::from_str(&json_str).unwrap();
        let parsed = parse_where(&json_value).expect("Round-trip parsing should succeed");
        assert_eq!(original, parsed);
    }

    #[test]
    fn test_document_contains_rejects_non_string_operand() {
        assert_where_rejected(
            &[
                json!({"#document": {"$contains": 42}}),
                json!({"#document": {"$contains": 2.72}}),
                json!({"#document": {"$contains": true}}),
                json!({"#document": {"$not_contains": 42}}),
                json!({"#document": {"$not_contains": false}}),
            ],
            "non-string #document contains",
        );
    }

    #[test]
    fn test_parse_where_in_nin_typed_arrays() {
        let cases = [
            (
                json!({"scores": {"$in": [1, 2, 3]}}),
                metadata(
                    "scores",
                    crate::MetadataComparison::Set(
                        crate::SetOperator::In,
                        crate::MetadataSetValue::Int(vec![1, 2, 3]),
                    ),
                ),
            ),
            (
                json!({"scores": {"$nin": [10, 20]}}),
                metadata(
                    "scores",
                    crate::MetadataComparison::Set(
                        crate::SetOperator::NotIn,
                        crate::MetadataSetValue::Int(vec![10, 20]),
                    ),
                ),
            ),
            (
                json!({"flags": {"$in": [true, false]}}),
                metadata(
                    "flags",
                    crate::MetadataComparison::Set(
                        crate::SetOperator::In,
                        crate::MetadataSetValue::Bool(vec![true, false]),
                    ),
                ),
            ),
            (
                json!({"ratings": {"$in": [1.5, 2.5, 3.5]}}),
                metadata(
                    "ratings",
                    crate::MetadataComparison::Set(
                        crate::SetOperator::In,
                        crate::MetadataSetValue::Float(vec![1.5, 2.5, 3.5]),
                    ),
                ),
            ),
        ];
        assert_parses_to(parse_where, &cases);
    }

    #[test]
    fn test_parse_where_in_mixed_types_rejected() {
        assert_where_rejected(
            &[
                json!({"key": {"$in": ["a", 1]}}),
                json!({"key": {"$in": [1, "b"]}}),
                json!({"key": {"$nin": [true, 1]}}),
            ],
            "mixed-type array",
        );
    }

    #[test]
    fn test_parse_where_in_empty_array_rejected() {
        assert_where_rejected(
            &[json!({"key": {"$in": []}}), json!({"key": {"$nin": []}})],
            "empty array",
        );
    }

    #[test]
    fn test_parse_where_contains_not_valid_with_array_operand() {
        assert_where_rejected(
            &[
                json!({"tags": {"$contains": ["a", "b"]}}),
                json!({"tags": {"$not_contains": [1, 2]}}),
            ],
            "array operand in $contains",
        );
    }
}