Skip to main content

chroma_types/
where_parsing.rs

1use crate::regex::{ChromaRegex, ChromaRegexError};
2use crate::{
3    CompositeExpression, ContainsOperator, DocumentOperator, MetadataExpression, PrimitiveOperator,
4    Where,
5};
6use chroma_error::{ChromaError, ErrorCodes};
7use serde::Deserialize;
8use serde::Serialize;
9use serde_json::Value;
10use thiserror::Error;
11
12#[derive(Default, Deserialize, Debug, Clone, Serialize)]
13#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
14pub struct RawWhereFields {
15    #[serde(default)]
16    pub r#where: Value,
17    #[serde(default)]
18    pub where_document: Value,
19}
20
21impl RawWhereFields {
22    pub fn new(r#where: Value, where_document: Value) -> Self {
23        Self {
24            r#where,
25            where_document,
26        }
27    }
28
29    pub fn from_json_str(
30        r#where: Option<&str>,
31        where_document: Option<&str>,
32    ) -> Result<Self, WhereValidationError> {
33        let r#where = r#where
34            .map(|r#where| {
35                serde_json::from_str(r#where).map_err(|_| WhereValidationError::WhereClause)
36            })
37            .transpose()?
38            .unwrap_or(Value::Null);
39
40        let where_document = where_document
41            .map(|where_document| {
42                serde_json::from_str(where_document)
43                    .map_err(|_| WhereValidationError::WhereDocumentClause)
44            })
45            .transpose()?
46            .unwrap_or(Value::Null);
47
48        Ok(Self {
49            r#where,
50            where_document,
51        })
52    }
53}
54
55#[derive(Error, Debug)]
56pub enum WhereValidationError {
57    #[error(transparent)]
58    Regex(#[from] ChromaRegexError),
59    #[error("Invalid where clause")]
60    WhereClause,
61    #[error("Invalid where document clause")]
62    WhereDocumentClause,
63}
64
65impl ChromaError for WhereValidationError {
66    fn code(&self) -> chroma_error::ErrorCodes {
67        ErrorCodes::InvalidArgument
68    }
69}
70
71impl RawWhereFields {
72    pub fn parse(self) -> Result<Option<Where>, WhereValidationError> {
73        let mut where_clause = None;
74        if !self.r#where.is_null() {
75            let where_payload = &self.r#where;
76            where_clause = Some(parse_where(where_payload)?);
77        }
78        let mut where_document_clause = None;
79        if !self.where_document.is_null() {
80            let where_document_payload = &self.where_document;
81            where_document_clause = Some(parse_where_document(where_document_payload)?);
82        }
83        let combined_where = match where_clause {
84            Some(where_clause) => match where_document_clause {
85                Some(where_document_clause) => Some(Where::Composite(CompositeExpression {
86                    operator: crate::BooleanOperator::And,
87                    children: vec![where_clause, where_document_clause],
88                })),
89                None => Some(where_clause),
90            },
91            None => where_document_clause,
92        };
93
94        Ok(combined_where)
95    }
96}
97
98pub fn parse_where_document(json_payload: &Value) -> Result<Where, WhereValidationError> {
99    let where_doc_payload = json_payload
100        .as_object()
101        .ok_or(WhereValidationError::WhereDocumentClause)?;
102    if where_doc_payload.len() != 1 {
103        return Err(WhereValidationError::WhereDocumentClause);
104    }
105    let (key, value) = where_doc_payload.iter().next().unwrap();
106    // Check if it is a composite expression.
107    if key == "$and" {
108        let logical_operator = crate::BooleanOperator::And;
109        // Check that the value is list type.
110        let children = value
111            .as_array()
112            .ok_or(WhereValidationError::WhereDocumentClause)?;
113        let mut predicate_list = vec![];
114        // Recursively parse the children.
115        for child in children {
116            predicate_list.push(parse_where_document(child)?);
117        }
118        return Ok(Where::Composite(CompositeExpression {
119            operator: logical_operator,
120            children: predicate_list,
121        }));
122    }
123    if key == "$or" {
124        let logical_operator = crate::BooleanOperator::Or;
125        // Check that the value is list type.
126        let children = value
127            .as_array()
128            .ok_or(WhereValidationError::WhereDocumentClause)?;
129        let mut predicate_list = vec![];
130        // Recursively parse the children.
131        for child in children {
132            predicate_list.push(parse_where_document(child)?);
133        }
134        return Ok(Where::Composite(CompositeExpression {
135            operator: logical_operator,
136            children: predicate_list,
137        }));
138    }
139    if !value.is_string() {
140        return Err(WhereValidationError::WhereDocumentClause);
141    }
142    let value_str = value.as_str().unwrap();
143    let operator_type = match key.as_str() {
144        "$contains" => DocumentOperator::Contains,
145        "$not_contains" => DocumentOperator::NotContains,
146        "$regex" => DocumentOperator::Regex,
147        "$not_regex" => DocumentOperator::NotRegex,
148        _ => return Err(WhereValidationError::WhereDocumentClause),
149    };
150    if matches!(
151        operator_type,
152        DocumentOperator::Regex | DocumentOperator::NotRegex
153    ) {
154        ChromaRegex::try_from(value_str.to_string())?;
155    }
156    Ok(Where::Document(crate::DocumentExpression {
157        operator: operator_type,
158        pattern: value_str.to_string(),
159    }))
160}
161
162/// Returns the [`ContainsOperator`] for `$contains` / `$not_contains`,
163/// or `None` for any other operator string.
164fn parse_contains_operator(operator: &str) -> Option<ContainsOperator> {
165    match operator {
166        "$contains" => Some(ContainsOperator::Contains),
167        "$not_contains" => Some(ContainsOperator::NotContains),
168        _ => None,
169    }
170}
171
172pub fn parse_where(json_payload: &Value) -> Result<Where, WhereValidationError> {
173    let where_payload = json_payload
174        .as_object()
175        .ok_or(WhereValidationError::WhereClause)?;
176    if where_payload.len() != 1 {
177        return Err(WhereValidationError::WhereClause);
178    }
179    let (key, value) = where_payload.iter().next().unwrap();
180    // Check if it is a composite expression.
181    if key == "$and" {
182        let logical_operator = crate::BooleanOperator::And;
183        // Check that the value is list type.
184        let children = value.as_array().ok_or(WhereValidationError::WhereClause)?;
185        let mut predicate_list = vec![];
186        // Recursively parse the children.
187        for child in children {
188            predicate_list.push(parse_where(child)?);
189        }
190        return Ok(Where::Composite(CompositeExpression {
191            operator: logical_operator,
192            children: predicate_list,
193        }));
194    }
195    if key == "$or" {
196        let logical_operator = crate::BooleanOperator::Or;
197        // Check that the value is list type.
198        let children = value.as_array().ok_or(WhereValidationError::WhereClause)?;
199        let mut predicate_list = vec![];
200        // Recursively parse the children.
201        for child in children {
202            predicate_list.push(parse_where(child)?);
203        }
204        return Ok(Where::Composite(CompositeExpression {
205            operator: logical_operator,
206            children: predicate_list,
207        }));
208    }
209    // Any other $-prefixed key is an operator, not a metadata field name.
210    // Operators like $contains, $not_contains, $gt, etc. are only valid
211    // inside a field expression (e.g. {"field": {"$contains": val}}).
212    if key.starts_with('$') {
213        return Err(WhereValidationError::WhereClause);
214    }
215    // At this point we know we're at a direct comparison. It can either
216    // be of the form {"key": "value"} or {"key": {"$operator": "value"}}.
217    if value.is_string() {
218        return Ok(Where::Metadata(MetadataExpression {
219            key: key.clone(),
220            comparison: crate::MetadataComparison::Primitive(
221                crate::PrimitiveOperator::Equal,
222                crate::MetadataValue::Str(value.as_str().unwrap().to_string()),
223            ),
224        }));
225    }
226    if value.is_boolean() {
227        return Ok(Where::Metadata(MetadataExpression {
228            key: key.clone(),
229            comparison: crate::MetadataComparison::Primitive(
230                crate::PrimitiveOperator::Equal,
231                crate::MetadataValue::Bool(value.as_bool().unwrap()),
232            ),
233        }));
234    }
235    if value.is_f64() {
236        return Ok(Where::Metadata(MetadataExpression {
237            key: key.clone(),
238            comparison: crate::MetadataComparison::Primitive(
239                crate::PrimitiveOperator::Equal,
240                crate::MetadataValue::Float(value.as_f64().unwrap()),
241            ),
242        }));
243    }
244    if value.is_i64() {
245        return Ok(Where::Metadata(MetadataExpression {
246            key: key.clone(),
247            comparison: crate::MetadataComparison::Primitive(
248                crate::PrimitiveOperator::Equal,
249                crate::MetadataValue::Int(value.as_i64().unwrap()),
250            ),
251        }));
252    }
253    if value.is_object() {
254        let value_obj = value.as_object().unwrap();
255        // value_obj should have exactly one key.
256        if value_obj.len() != 1 {
257            return Err(WhereValidationError::WhereClause);
258        }
259        let (operator, operand) = value_obj.iter().next().unwrap();
260        if operand.is_array() {
261            let set_operator;
262            if operator == "$in" {
263                set_operator = crate::SetOperator::In;
264            } else if operator == "$nin" {
265                set_operator = crate::SetOperator::NotIn;
266            } else {
267                return Err(WhereValidationError::WhereClause);
268            }
269            let operand = operand.as_array().unwrap();
270            if operand.is_empty() {
271                return Err(WhereValidationError::WhereClause);
272            }
273            if operand[0].is_string() {
274                let operand_str = operand
275                    .iter()
276                    .map(|val| {
277                        val.as_str()
278                            .ok_or(WhereValidationError::WhereClause)
279                            .map(|s| s.to_string())
280                    })
281                    .collect::<Result<Vec<String>, _>>()?;
282                return Ok(Where::Metadata(MetadataExpression {
283                    key: key.clone(),
284                    comparison: crate::MetadataComparison::Set(
285                        set_operator,
286                        crate::MetadataSetValue::Str(operand_str),
287                    ),
288                }));
289            }
290            if operand[0].is_boolean() {
291                let operand_bool = operand
292                    .iter()
293                    .map(|val| val.as_bool().ok_or(WhereValidationError::WhereClause))
294                    .collect::<Result<Vec<bool>, _>>()?;
295                return Ok(Where::Metadata(MetadataExpression {
296                    key: key.clone(),
297                    comparison: crate::MetadataComparison::Set(
298                        set_operator,
299                        crate::MetadataSetValue::Bool(operand_bool),
300                    ),
301                }));
302            }
303            if operand[0].is_f64() {
304                let operand_f64 = operand
305                    .iter()
306                    .map(|val| val.as_f64().ok_or(WhereValidationError::WhereClause))
307                    .collect::<Result<Vec<f64>, _>>()?;
308                return Ok(Where::Metadata(MetadataExpression {
309                    key: key.clone(),
310                    comparison: crate::MetadataComparison::Set(
311                        set_operator,
312                        crate::MetadataSetValue::Float(operand_f64),
313                    ),
314                }));
315            }
316            if operand[0].is_i64() {
317                let operand_i64 = operand
318                    .iter()
319                    .map(|val| val.as_i64().ok_or(WhereValidationError::WhereClause))
320                    .collect::<Result<Vec<i64>, _>>()?;
321                return Ok(Where::Metadata(MetadataExpression {
322                    key: key.clone(),
323                    comparison: crate::MetadataComparison::Set(
324                        set_operator,
325                        crate::MetadataSetValue::Int(operand_i64),
326                    ),
327                }));
328            }
329            return Err(WhereValidationError::WhereClause);
330        }
331        if operand.is_string() {
332            let operand_str = operand.as_str().unwrap();
333            // $contains/$not_contains on the "#document" key are document
334            // search operators. On any other key they are metadata array
335            // contains operators.
336            if operator == "$contains" || operator == "$not_contains" {
337                if key == "#document" {
338                    let doc_op = if operator == "$contains" {
339                        DocumentOperator::Contains
340                    } else {
341                        DocumentOperator::NotContains
342                    };
343                    return Ok(Where::Document(crate::DocumentExpression {
344                        operator: doc_op,
345                        pattern: operand_str.to_string(),
346                    }));
347                }
348                let contains_op = if operator == "$contains" {
349                    ContainsOperator::Contains
350                } else {
351                    ContainsOperator::NotContains
352                };
353                return Ok(Where::Metadata(MetadataExpression {
354                    key: key.clone(),
355                    comparison: crate::MetadataComparison::ArrayContains(
356                        contains_op,
357                        crate::MetadataValue::Str(operand_str.to_string()),
358                    ),
359                }));
360            }
361            if operator == "$regex" || operator == "$not_regex" {
362                // Regex operators are only valid on document content.
363                if key != "#document" {
364                    return Err(WhereValidationError::WhereClause);
365                }
366                ChromaRegex::try_from(operand_str.to_string())?;
367                let doc_op = if operator == "$regex" {
368                    DocumentOperator::Regex
369                } else {
370                    DocumentOperator::NotRegex
371                };
372                return Ok(Where::Document(crate::DocumentExpression {
373                    operator: doc_op,
374                    pattern: operand_str.to_string(),
375                }));
376            }
377            let operator_type;
378            if operator == "$eq" {
379                operator_type = PrimitiveOperator::Equal;
380            } else if operator == "$ne" {
381                operator_type = PrimitiveOperator::NotEqual;
382            } else {
383                return Err(WhereValidationError::WhereClause);
384            }
385            return Ok(Where::Metadata(MetadataExpression {
386                key: key.clone(),
387                comparison: crate::MetadataComparison::Primitive(
388                    operator_type,
389                    crate::MetadataValue::Str(operand_str.to_string()),
390                ),
391            }));
392        }
393        if operand.is_boolean() {
394            let operand_bool = operand.as_bool().unwrap();
395            if let Some(contains_op) = parse_contains_operator(operator) {
396                // $contains/$not_contains on "#document" requires a string operand.
397                if key == "#document" {
398                    return Err(WhereValidationError::WhereClause);
399                }
400                return Ok(Where::Metadata(MetadataExpression {
401                    key: key.clone(),
402                    comparison: crate::MetadataComparison::ArrayContains(
403                        contains_op,
404                        crate::MetadataValue::Bool(operand_bool),
405                    ),
406                }));
407            }
408            let operator_type;
409            if operator == "$eq" {
410                operator_type = PrimitiveOperator::Equal;
411            } else if operator == "$ne" {
412                operator_type = PrimitiveOperator::NotEqual;
413            } else {
414                return Err(WhereValidationError::WhereClause);
415            }
416            return Ok(Where::Metadata(MetadataExpression {
417                key: key.clone(),
418                comparison: crate::MetadataComparison::Primitive(
419                    operator_type,
420                    crate::MetadataValue::Bool(operand_bool),
421                ),
422            }));
423        }
424        if operand.is_f64() {
425            let operand_f64 = operand.as_f64().unwrap();
426            if let Some(contains_op) = parse_contains_operator(operator) {
427                // $contains/$not_contains on "#document" requires a string operand.
428                if key == "#document" {
429                    return Err(WhereValidationError::WhereClause);
430                }
431                return Ok(Where::Metadata(MetadataExpression {
432                    key: key.clone(),
433                    comparison: crate::MetadataComparison::ArrayContains(
434                        contains_op,
435                        crate::MetadataValue::Float(operand_f64),
436                    ),
437                }));
438            }
439            let operator_type;
440            if operator == "$eq" {
441                operator_type = PrimitiveOperator::Equal;
442            } else if operator == "$ne" {
443                operator_type = PrimitiveOperator::NotEqual;
444            } else if operator == "$lt" {
445                operator_type = PrimitiveOperator::LessThan;
446            } else if operator == "$lte" {
447                operator_type = PrimitiveOperator::LessThanOrEqual;
448            } else if operator == "$gt" {
449                operator_type = PrimitiveOperator::GreaterThan;
450            } else if operator == "$gte" {
451                operator_type = PrimitiveOperator::GreaterThanOrEqual;
452            } else {
453                return Err(WhereValidationError::WhereClause);
454            }
455            return Ok(Where::Metadata(MetadataExpression {
456                key: key.clone(),
457                comparison: crate::MetadataComparison::Primitive(
458                    operator_type,
459                    crate::MetadataValue::Float(operand_f64),
460                ),
461            }));
462        }
463        if operand.is_i64() {
464            let operand_i64 = operand.as_i64().unwrap();
465            if let Some(contains_op) = parse_contains_operator(operator) {
466                // $contains/$not_contains on "#document" requires a string operand.
467                if key == "#document" {
468                    return Err(WhereValidationError::WhereClause);
469                }
470                return Ok(Where::Metadata(MetadataExpression {
471                    key: key.clone(),
472                    comparison: crate::MetadataComparison::ArrayContains(
473                        contains_op,
474                        crate::MetadataValue::Int(operand_i64),
475                    ),
476                }));
477            }
478            let operator_type;
479            if operator == "$eq" {
480                operator_type = PrimitiveOperator::Equal;
481            } else if operator == "$ne" {
482                operator_type = PrimitiveOperator::NotEqual;
483            } else if operator == "$lt" {
484                operator_type = PrimitiveOperator::LessThan;
485            } else if operator == "$lte" {
486                operator_type = PrimitiveOperator::LessThanOrEqual;
487            } else if operator == "$gt" {
488                operator_type = PrimitiveOperator::GreaterThan;
489            } else if operator == "$gte" {
490                operator_type = PrimitiveOperator::GreaterThanOrEqual;
491            } else {
492                return Err(WhereValidationError::WhereClause);
493            }
494            return Ok(Where::Metadata(MetadataExpression {
495                key: key.clone(),
496                comparison: crate::MetadataComparison::Primitive(
497                    operator_type,
498                    crate::MetadataValue::Int(operand_i64),
499                ),
500            }));
501        }
502        return Err(WhereValidationError::WhereClause);
503    }
504    Err(WhereValidationError::WhereClause)
505}
506
507#[cfg(test)]
508mod tests {
509    use super::*;
510    use serde_json::json;
511
512    #[test]
513    fn test_parse_where_direct_eq() {
514        let payload = json!({
515          "key1": "value1"
516        });
517        let expected_result = Where::Metadata(MetadataExpression {
518            key: "key1".to_string(),
519            comparison: crate::MetadataComparison::Primitive(
520                PrimitiveOperator::Equal,
521                crate::MetadataValue::Str("value1".to_string()),
522            ),
523        });
524
525        let result = parse_where(&payload).expect("This clause to parse successfully");
526        assert_eq!(result, expected_result);
527    }
528
529    // TODO: add a proptest when there's an Arbitrary impl for Where and WhereDocument
530    #[test]
531    fn test_parse_where_document() {
532        let payloads = [
533            // $contains
534            json!({
535              "$and": [
536                  {"$contains": "value1"},
537                  {"$or": [
538                      {"$contains": "value2"},
539                      {"$contains": "value3"}
540                  ]}
541              ]
542            }),
543            // $not_contains
544            json!({
545              "$not_contains": "value1",
546            }),
547        ];
548
549        let expected_results = [
550            // $contains
551            Where::Composite(CompositeExpression {
552                operator: crate::BooleanOperator::And,
553                children: vec![
554                    Where::Document(crate::DocumentExpression {
555                        operator: DocumentOperator::Contains,
556                        pattern: "value1".to_string(),
557                    }),
558                    Where::Composite(CompositeExpression {
559                        operator: crate::BooleanOperator::Or,
560                        children: vec![
561                            Where::Document(crate::DocumentExpression {
562                                operator: DocumentOperator::Contains,
563                                pattern: "value2".to_string(),
564                            }),
565                            Where::Document(crate::DocumentExpression {
566                                operator: DocumentOperator::Contains,
567                                pattern: "value3".to_string(),
568                            }),
569                        ],
570                    }),
571                ],
572            }),
573            // $not_contains
574            Where::Document(crate::DocumentExpression {
575                operator: DocumentOperator::NotContains,
576                pattern: "value1".to_string(),
577            }),
578        ];
579
580        for (payload, expected_result) in payloads.iter().zip(expected_results.iter()) {
581            let result = parse_where_document(payload);
582            assert!(
583                result.is_ok(),
584                "Parsing failed for payload: {}: {:?}",
585                serde_json::to_string_pretty(payload).unwrap(),
586                result
587            );
588            assert_eq!(
589                result.unwrap(),
590                *expected_result,
591                "Parsed result did not match expected result: {}",
592                serde_json::to_string_pretty(payload).unwrap(),
593            );
594        }
595    }
596
597    #[test]
598    fn test_parse_where() {
599        let payloads = [
600            // $in
601            json!({
602              "key1": {"$in": ["value1", "value2", "value3"]}
603            }),
604            // $nin
605            json!({
606              "key1": {"$nin": ["value1", "value2", "value3"]}
607            }),
608            // $eq
609            json!({
610              "key1": {"$eq": "value1"}
611            }),
612            // $ne
613            json!({
614              "key1": {"$ne": "value1"}
615            }),
616        ];
617
618        let expected_results = [
619            // $in
620            Where::Metadata(MetadataExpression {
621                key: "key1".to_string(),
622                comparison: crate::MetadataComparison::Set(
623                    crate::SetOperator::In,
624                    crate::MetadataSetValue::Str(vec![
625                        "value1".to_string(),
626                        "value2".to_string(),
627                        "value3".to_string(),
628                    ]),
629                ),
630            }),
631            // $nin
632            Where::Metadata(MetadataExpression {
633                key: "key1".to_string(),
634                comparison: crate::MetadataComparison::Set(
635                    crate::SetOperator::NotIn,
636                    crate::MetadataSetValue::Str(vec![
637                        "value1".to_string(),
638                        "value2".to_string(),
639                        "value3".to_string(),
640                    ]),
641                ),
642            }),
643            // $eq
644            Where::Metadata(MetadataExpression {
645                key: "key1".to_string(),
646                comparison: crate::MetadataComparison::Primitive(
647                    PrimitiveOperator::Equal,
648                    crate::MetadataValue::Str("value1".to_string()),
649                ),
650            }),
651            // $ne
652            Where::Metadata(MetadataExpression {
653                key: "key1".to_string(),
654                comparison: crate::MetadataComparison::Primitive(
655                    PrimitiveOperator::NotEqual,
656                    crate::MetadataValue::Str("value1".to_string()),
657                ),
658            }),
659        ];
660
661        for (payload, expected_result) in payloads.iter().zip(expected_results.iter()) {
662            let result = parse_where(payload);
663            assert!(
664                result.is_ok(),
665                "Parsing failed for payload: {}: {:?}",
666                serde_json::to_string_pretty(payload).unwrap(),
667                result
668            );
669            assert_eq!(
670                result.unwrap(),
671                *expected_result,
672                "Parsed result did not match expected result: {}",
673                serde_json::to_string_pretty(payload).unwrap(),
674            );
675        }
676    }
677
678    #[test]
679    fn test_parse_where_contains_metadata() {
680        // $contains on a metadata key should produce MetadataComparison::Contains,
681        // NOT a DocumentExpression.
682        let payloads = [
683            // string contains
684            json!({"tags": {"$contains": "action"}}),
685            // string not_contains
686            json!({"tags": {"$not_contains": "comedy"}}),
687            // int contains
688            json!({"scores": {"$contains": 42}}),
689            // float contains
690            json!({"ratings": {"$contains": 4.5}}),
691            // bool contains
692            json!({"flags": {"$contains": true}}),
693        ];
694
695        let expected_results = [
696            Where::Metadata(MetadataExpression {
697                key: "tags".to_string(),
698                comparison: crate::MetadataComparison::ArrayContains(
699                    ContainsOperator::Contains,
700                    crate::MetadataValue::Str("action".to_string()),
701                ),
702            }),
703            Where::Metadata(MetadataExpression {
704                key: "tags".to_string(),
705                comparison: crate::MetadataComparison::ArrayContains(
706                    ContainsOperator::NotContains,
707                    crate::MetadataValue::Str("comedy".to_string()),
708                ),
709            }),
710            Where::Metadata(MetadataExpression {
711                key: "scores".to_string(),
712                comparison: crate::MetadataComparison::ArrayContains(
713                    ContainsOperator::Contains,
714                    crate::MetadataValue::Int(42),
715                ),
716            }),
717            Where::Metadata(MetadataExpression {
718                key: "ratings".to_string(),
719                comparison: crate::MetadataComparison::ArrayContains(
720                    ContainsOperator::Contains,
721                    crate::MetadataValue::Float(4.5),
722                ),
723            }),
724            Where::Metadata(MetadataExpression {
725                key: "flags".to_string(),
726                comparison: crate::MetadataComparison::ArrayContains(
727                    ContainsOperator::Contains,
728                    crate::MetadataValue::Bool(true),
729                ),
730            }),
731        ];
732
733        for (payload, expected_result) in payloads.iter().zip(expected_results.iter()) {
734            let result = parse_where(payload);
735            assert!(
736                result.is_ok(),
737                "Parsing failed for payload: {}: {:?}",
738                serde_json::to_string_pretty(payload).unwrap(),
739                result
740            );
741            assert_eq!(
742                result.unwrap(),
743                *expected_result,
744                "Parsed result did not match expected result: {}",
745                serde_json::to_string_pretty(payload).unwrap(),
746            );
747        }
748    }
749
750    #[test]
751    fn test_parse_where_document_contains_in_where() {
752        // $contains on the "#document" key within a where clause should still
753        // produce a DocumentExpression for backwards compatibility.
754        let payload = json!({"#document": {"$contains": "search term"}});
755        let result = parse_where(&payload).expect("Should parse successfully");
756        assert_eq!(
757            result,
758            Where::Document(crate::DocumentExpression {
759                operator: DocumentOperator::Contains,
760                pattern: "search term".to_string(),
761            })
762        );
763    }
764
765    #[test]
766    fn test_parse_where_regex_only_on_document() {
767        // $regex / $not_regex are only valid on the "#document" key.
768        let payload = json!({"#document": {"$regex": "act.*"}});
769        let result = parse_where(&payload).expect("Should parse successfully");
770        assert_eq!(
771            result,
772            Where::Document(crate::DocumentExpression {
773                operator: DocumentOperator::Regex,
774                pattern: "act.*".to_string(),
775            })
776        );
777
778        let payload = json!({"#document": {"$not_regex": "draft.*"}});
779        let result = parse_where(&payload).expect("Should parse successfully");
780        assert_eq!(
781            result,
782            Where::Document(crate::DocumentExpression {
783                operator: DocumentOperator::NotRegex,
784                pattern: "draft.*".to_string(),
785            })
786        );
787
788        // $regex on a metadata key should be rejected.
789        let payload = json!({"tags": {"$regex": "act.*"}});
790        assert!(parse_where(&payload).is_err());
791
792        let payload = json!({"tags": {"$not_regex": "draft.*"}});
793        assert!(parse_where(&payload).is_err());
794    }
795
796    #[test]
797    fn test_where_contains_round_trip() {
798        // Verify that serializing a Contains expression and parsing it back
799        // produces the same result.
800        let original = Where::Metadata(MetadataExpression {
801            key: "tags".to_string(),
802            comparison: crate::MetadataComparison::ArrayContains(
803                ContainsOperator::Contains,
804                crate::MetadataValue::Str("action".to_string()),
805            ),
806        });
807        let json_str = serde_json::to_string(&original).unwrap();
808        let json_value: Value = serde_json::from_str(&json_str).unwrap();
809        let parsed = parse_where(&json_value).expect("Round-trip parsing should succeed");
810        assert_eq!(original, parsed);
811    }
812
813    #[test]
814    fn test_document_contains_rejects_non_string_operand() {
815        // $contains / $not_contains on "#document" must have a string operand.
816        // Non-string values should be rejected, not silently treated as metadata.
817        let payloads = [
818            json!({"#document": {"$contains": 42}}),
819            json!({"#document": {"$contains": 2.72}}),
820            json!({"#document": {"$contains": true}}),
821            json!({"#document": {"$not_contains": 42}}),
822            json!({"#document": {"$not_contains": false}}),
823        ];
824        for payload in &payloads {
825            let result = parse_where(payload);
826            assert!(
827                result.is_err(),
828                "Expected error for non-string #document contains, but got Ok for: {}",
829                serde_json::to_string_pretty(payload).unwrap(),
830            );
831        }
832    }
833
834    #[test]
835    fn test_parse_where_in_nin_typed_arrays() {
836        // $in / $nin with integer, boolean, and float arrays.
837        let payloads = [
838            // int $in
839            json!({"scores": {"$in": [1, 2, 3]}}),
840            // int $nin
841            json!({"scores": {"$nin": [10, 20]}}),
842            // bool $in
843            json!({"flags": {"$in": [true, false]}}),
844            // float $in
845            json!({"ratings": {"$in": [1.5, 2.5, 3.5]}}),
846        ];
847
848        let expected_results = [
849            Where::Metadata(MetadataExpression {
850                key: "scores".to_string(),
851                comparison: crate::MetadataComparison::Set(
852                    crate::SetOperator::In,
853                    crate::MetadataSetValue::Int(vec![1, 2, 3]),
854                ),
855            }),
856            Where::Metadata(MetadataExpression {
857                key: "scores".to_string(),
858                comparison: crate::MetadataComparison::Set(
859                    crate::SetOperator::NotIn,
860                    crate::MetadataSetValue::Int(vec![10, 20]),
861                ),
862            }),
863            Where::Metadata(MetadataExpression {
864                key: "flags".to_string(),
865                comparison: crate::MetadataComparison::Set(
866                    crate::SetOperator::In,
867                    crate::MetadataSetValue::Bool(vec![true, false]),
868                ),
869            }),
870            Where::Metadata(MetadataExpression {
871                key: "ratings".to_string(),
872                comparison: crate::MetadataComparison::Set(
873                    crate::SetOperator::In,
874                    crate::MetadataSetValue::Float(vec![1.5, 2.5, 3.5]),
875                ),
876            }),
877        ];
878
879        for (payload, expected_result) in payloads.iter().zip(expected_results.iter()) {
880            let result = parse_where(payload);
881            assert!(
882                result.is_ok(),
883                "Parsing failed for payload: {}: {:?}",
884                serde_json::to_string_pretty(payload).unwrap(),
885                result
886            );
887            assert_eq!(
888                result.unwrap(),
889                *expected_result,
890                "Parsed result did not match expected result: {}",
891                serde_json::to_string_pretty(payload).unwrap(),
892            );
893        }
894    }
895
896    #[test]
897    fn test_parse_where_in_mixed_types_rejected() {
898        // $in / $nin arrays with mixed types should be rejected because the
899        // parser requires all elements to match the type of the first element.
900        let payloads = [
901            json!({"key": {"$in": ["a", 1]}}),
902            json!({"key": {"$in": [1, "b"]}}),
903            json!({"key": {"$nin": [true, 1]}}),
904        ];
905        for payload in &payloads {
906            let result = parse_where(payload);
907            assert!(
908                result.is_err(),
909                "Expected error for mixed-type array, but got Ok for: {}",
910                serde_json::to_string_pretty(payload).unwrap(),
911            );
912        }
913    }
914
915    #[test]
916    fn test_parse_where_in_empty_array_rejected() {
917        // $in / $nin with an empty array should be rejected.
918        let payloads = [json!({"key": {"$in": []}}), json!({"key": {"$nin": []}})];
919        for payload in &payloads {
920            let result = parse_where(payload);
921            assert!(
922                result.is_err(),
923                "Expected error for empty array, but got Ok for: {}",
924                serde_json::to_string_pretty(payload).unwrap(),
925            );
926        }
927    }
928
929    #[test]
930    fn test_parse_where_contains_not_valid_with_array_operand() {
931        // $contains / $not_contains expect a scalar operand, not an array.
932        let payloads = [
933            json!({"tags": {"$contains": ["a", "b"]}}),
934            json!({"tags": {"$not_contains": [1, 2]}}),
935        ];
936        for payload in &payloads {
937            let result = parse_where(payload);
938            assert!(
939                result.is_err(),
940                "Expected error for array operand in $contains, but got Ok for: {}",
941                serde_json::to_string_pretty(payload).unwrap(),
942            );
943        }
944    }
945}