chroma_types/
where_parsing.rs

1use crate::regex::{ChromaRegex, ChromaRegexError};
2use crate::{CompositeExpression, DocumentOperator, MetadataExpression, PrimitiveOperator, Where};
3use chroma_error::{ChromaError, ErrorCodes};
4use serde::Deserialize;
5use serde::Serialize;
6use serde_json::Value;
7use thiserror::Error;
8
// Raw, not-yet-validated JSON for the `where` (metadata) and `where_document` filters.
//
// Both fields fall back to `Value::Null` when missing during deserialization
// (`#[serde(default)]`); `RawWhereFields::parse` treats a `Null` field as "no filter of
// this kind".
#[derive(Deserialize, Debug, Clone, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct RawWhereFields {
    // Metadata filter JSON; validated by `parse_where`.
    #[serde(default)]
    r#where: Value,
    // Document-content filter JSON; validated by `parse_where_document`.
    #[serde(default)]
    where_document: Value,
}
17
18impl RawWhereFields {
19    pub fn new(r#where: Value, where_document: Value) -> Self {
20        Self {
21            r#where,
22            where_document,
23        }
24    }
25
26    pub fn from_json_str(
27        r#where: Option<&str>,
28        where_document: Option<&str>,
29    ) -> Result<Self, WhereValidationError> {
30        let r#where = r#where
31            .map(|r#where| {
32                serde_json::from_str(r#where).map_err(|_| WhereValidationError::WhereClause)
33            })
34            .transpose()?
35            .unwrap_or(Value::Null);
36
37        let where_document = where_document
38            .map(|where_document| {
39                serde_json::from_str(where_document)
40                    .map_err(|_| WhereValidationError::WhereDocumentClause)
41            })
42            .transpose()?
43            .unwrap_or(Value::Null);
44
45        Ok(Self {
46            r#where,
47            where_document,
48        })
49    }
50}
51
/// Errors produced while decoding or validating `where` / `where_document` filters.
#[derive(Error, Debug)]
pub enum WhereValidationError {
    /// A `$regex` / `$not_regex` pattern was rejected by `ChromaRegex` validation.
    #[error(transparent)]
    Regex(#[from] ChromaRegexError),
    /// The `where` (metadata) filter is not valid JSON or does not have an accepted shape.
    #[error("Invalid where clause")]
    WhereClause,
    /// The `where_document` filter is not valid JSON or does not have an accepted shape.
    #[error("Invalid where document clause")]
    WhereDocumentClause,
}
61
62impl ChromaError for WhereValidationError {
63    fn code(&self) -> chroma_error::ErrorCodes {
64        ErrorCodes::InvalidArgument
65    }
66}
67
68impl RawWhereFields {
69    pub fn parse(self) -> Result<Option<Where>, WhereValidationError> {
70        let mut where_clause = None;
71        if !self.r#where.is_null() {
72            let where_payload = &self.r#where;
73            where_clause = Some(parse_where(where_payload)?);
74        }
75        let mut where_document_clause = None;
76        if !self.where_document.is_null() {
77            let where_document_payload = &self.where_document;
78            where_document_clause = Some(parse_where_document(where_document_payload)?);
79        }
80        let combined_where = match where_clause {
81            Some(where_clause) => match where_document_clause {
82                Some(where_document_clause) => Some(Where::Composite(CompositeExpression {
83                    operator: crate::BooleanOperator::And,
84                    children: vec![where_clause, where_document_clause],
85                })),
86                None => Some(where_clause),
87            },
88            None => where_document_clause,
89        };
90
91        Ok(combined_where)
92    }
93}
94
95pub fn parse_where_document(json_payload: &Value) -> Result<Where, WhereValidationError> {
96    let where_doc_payload = json_payload
97        .as_object()
98        .ok_or(WhereValidationError::WhereDocumentClause)?;
99    if where_doc_payload.len() != 1 {
100        return Err(WhereValidationError::WhereDocumentClause);
101    }
102    let (key, value) = where_doc_payload.iter().next().unwrap();
103    // Check if it is a composite expression.
104    if key == "$and" {
105        let logical_operator = crate::BooleanOperator::And;
106        // Check that the value is list type.
107        let children = value
108            .as_array()
109            .ok_or(WhereValidationError::WhereDocumentClause)?;
110        let mut predicate_list = vec![];
111        // Recursively parse the children.
112        for child in children {
113            predicate_list.push(parse_where_document(child)?);
114        }
115        return Ok(Where::Composite(CompositeExpression {
116            operator: logical_operator,
117            children: predicate_list,
118        }));
119    }
120    if key == "$or" {
121        let logical_operator = crate::BooleanOperator::Or;
122        // Check that the value is list type.
123        let children = value
124            .as_array()
125            .ok_or(WhereValidationError::WhereDocumentClause)?;
126        let mut predicate_list = vec![];
127        // Recursively parse the children.
128        for child in children {
129            predicate_list.push(parse_where_document(child)?);
130        }
131        return Ok(Where::Composite(CompositeExpression {
132            operator: logical_operator,
133            children: predicate_list,
134        }));
135    }
136    if !value.is_string() {
137        return Err(WhereValidationError::WhereDocumentClause);
138    }
139    let value_str = value.as_str().unwrap();
140    let operator_type = match key.as_str() {
141        "$contains" => DocumentOperator::Contains,
142        "$not_contains" => DocumentOperator::NotContains,
143        "$regex" => DocumentOperator::Regex,
144        "$not_regex" => DocumentOperator::NotRegex,
145        _ => return Err(WhereValidationError::WhereDocumentClause),
146    };
147    if matches!(
148        operator_type,
149        DocumentOperator::Regex | DocumentOperator::NotRegex
150    ) {
151        ChromaRegex::try_from(value_str.to_string())?;
152    }
153    Ok(Where::Document(crate::DocumentExpression {
154        operator: operator_type,
155        pattern: value_str.to_string(),
156    }))
157}
158
159pub fn parse_where(json_payload: &Value) -> Result<Where, WhereValidationError> {
160    let where_payload = json_payload
161        .as_object()
162        .ok_or(WhereValidationError::WhereClause)?;
163    if where_payload.len() != 1 {
164        return Err(WhereValidationError::WhereClause);
165    }
166    let (key, value) = where_payload.iter().next().unwrap();
167    // Check if it is a composite expression.
168    if key == "$and" {
169        let logical_operator = crate::BooleanOperator::And;
170        // Check that the value is list type.
171        let children = value.as_array().ok_or(WhereValidationError::WhereClause)?;
172        let mut predicate_list = vec![];
173        // Recursively parse the children.
174        for child in children {
175            predicate_list.push(parse_where(child)?);
176        }
177        return Ok(Where::Composite(CompositeExpression {
178            operator: logical_operator,
179            children: predicate_list,
180        }));
181    }
182    if key == "$or" {
183        let logical_operator = crate::BooleanOperator::Or;
184        // Check that the value is list type.
185        let children = value.as_array().ok_or(WhereValidationError::WhereClause)?;
186        let mut predicate_list = vec![];
187        // Recursively parse the children.
188        for child in children {
189            predicate_list.push(parse_where(child)?);
190        }
191        return Ok(Where::Composite(CompositeExpression {
192            operator: logical_operator,
193            children: predicate_list,
194        }));
195    }
196    // At this point we know we're at a direct comparison. It can either
197    // be of the form {"key": "value"} or {"key": {"$operator": "value"}}.
198    if value.is_string() {
199        return Ok(Where::Metadata(MetadataExpression {
200            key: key.clone(),
201            comparison: crate::MetadataComparison::Primitive(
202                crate::PrimitiveOperator::Equal,
203                crate::MetadataValue::Str(value.as_str().unwrap().to_string()),
204            ),
205        }));
206    }
207    if value.is_boolean() {
208        return Ok(Where::Metadata(MetadataExpression {
209            key: key.clone(),
210            comparison: crate::MetadataComparison::Primitive(
211                crate::PrimitiveOperator::Equal,
212                crate::MetadataValue::Bool(value.as_bool().unwrap()),
213            ),
214        }));
215    }
216    if value.is_f64() {
217        return Ok(Where::Metadata(MetadataExpression {
218            key: key.clone(),
219            comparison: crate::MetadataComparison::Primitive(
220                crate::PrimitiveOperator::Equal,
221                crate::MetadataValue::Float(value.as_f64().unwrap()),
222            ),
223        }));
224    }
225    if value.is_i64() {
226        return Ok(Where::Metadata(MetadataExpression {
227            key: key.clone(),
228            comparison: crate::MetadataComparison::Primitive(
229                crate::PrimitiveOperator::Equal,
230                crate::MetadataValue::Int(value.as_i64().unwrap()),
231            ),
232        }));
233    }
234    if value.is_object() {
235        let value_obj = value.as_object().unwrap();
236        // value_obj should have exactly one key.
237        if value_obj.len() != 1 {
238            return Err(WhereValidationError::WhereClause);
239        }
240        let (operator, operand) = value_obj.iter().next().unwrap();
241        if operand.is_array() {
242            let set_operator;
243            if operator == "$in" {
244                set_operator = crate::SetOperator::In;
245            } else if operator == "$nin" {
246                set_operator = crate::SetOperator::NotIn;
247            } else {
248                return Err(WhereValidationError::WhereClause);
249            }
250            let operand = operand.as_array().unwrap();
251            if operand.is_empty() {
252                return Err(WhereValidationError::WhereClause);
253            }
254            if operand[0].is_string() {
255                let operand_str = operand
256                    .iter()
257                    .map(|val| {
258                        val.as_str()
259                            .ok_or(WhereValidationError::WhereClause)
260                            .map(|s| s.to_string())
261                    })
262                    .collect::<Result<Vec<String>, _>>()?;
263                return Ok(Where::Metadata(MetadataExpression {
264                    key: key.clone(),
265                    comparison: crate::MetadataComparison::Set(
266                        set_operator,
267                        crate::MetadataSetValue::Str(operand_str),
268                    ),
269                }));
270            }
271            if operand[0].is_boolean() {
272                let operand_bool = operand
273                    .iter()
274                    .map(|val| val.as_bool().ok_or(WhereValidationError::WhereClause))
275                    .collect::<Result<Vec<bool>, _>>()?;
276                return Ok(Where::Metadata(MetadataExpression {
277                    key: key.clone(),
278                    comparison: crate::MetadataComparison::Set(
279                        set_operator,
280                        crate::MetadataSetValue::Bool(operand_bool),
281                    ),
282                }));
283            }
284            if operand[0].is_f64() {
285                let operand_f64 = operand
286                    .iter()
287                    .map(|val| val.as_f64().ok_or(WhereValidationError::WhereClause))
288                    .collect::<Result<Vec<f64>, _>>()?;
289                return Ok(Where::Metadata(MetadataExpression {
290                    key: key.clone(),
291                    comparison: crate::MetadataComparison::Set(
292                        set_operator,
293                        crate::MetadataSetValue::Float(operand_f64),
294                    ),
295                }));
296            }
297            if operand[0].is_i64() {
298                let operand_i64 = operand
299                    .iter()
300                    .map(|val| val.as_i64().ok_or(WhereValidationError::WhereClause))
301                    .collect::<Result<Vec<i64>, _>>()?;
302                return Ok(Where::Metadata(MetadataExpression {
303                    key: key.clone(),
304                    comparison: crate::MetadataComparison::Set(
305                        set_operator,
306                        crate::MetadataSetValue::Int(operand_i64),
307                    ),
308                }));
309            }
310            return Err(WhereValidationError::WhereClause);
311        }
312        if operand.is_string() {
313            let operand_str = operand.as_str().unwrap();
314            let document_operator_type = match operator.as_str() {
315                "$contains" => Some(DocumentOperator::Contains),
316                "$not_contains" => Some(DocumentOperator::NotContains),
317                "$regex" => Some(DocumentOperator::Regex),
318                "$not_regex" => Some(DocumentOperator::NotRegex),
319                _ => None,
320            };
321            if let Some(doc_op) = document_operator_type {
322                if matches!(doc_op, DocumentOperator::Regex | DocumentOperator::NotRegex) {
323                    ChromaRegex::try_from(operand_str.to_string())?;
324                }
325                return Ok(Where::Document(crate::DocumentExpression {
326                    operator: doc_op,
327                    pattern: operand_str.to_string(),
328                }));
329            }
330            let operator_type;
331            if operator == "$eq" {
332                operator_type = PrimitiveOperator::Equal;
333            } else if operator == "$ne" {
334                operator_type = PrimitiveOperator::NotEqual;
335            } else {
336                return Err(WhereValidationError::WhereClause);
337            }
338            return Ok(Where::Metadata(MetadataExpression {
339                key: key.clone(),
340                comparison: crate::MetadataComparison::Primitive(
341                    operator_type,
342                    crate::MetadataValue::Str(operand_str.to_string()),
343                ),
344            }));
345        }
346        if operand.is_boolean() {
347            let operand_bool = operand.as_bool().unwrap();
348            let operator_type;
349            if operator == "$eq" {
350                operator_type = PrimitiveOperator::Equal;
351            } else if operator == "$ne" {
352                operator_type = PrimitiveOperator::NotEqual;
353            } else {
354                return Err(WhereValidationError::WhereClause);
355            }
356            return Ok(Where::Metadata(MetadataExpression {
357                key: key.clone(),
358                comparison: crate::MetadataComparison::Primitive(
359                    operator_type,
360                    crate::MetadataValue::Bool(operand_bool),
361                ),
362            }));
363        }
364        if operand.is_f64() {
365            let operand_f64 = operand.as_f64().unwrap();
366            let operator_type;
367            if operator == "$eq" {
368                operator_type = PrimitiveOperator::Equal;
369            } else if operator == "$ne" {
370                operator_type = PrimitiveOperator::NotEqual;
371            } else if operator == "$lt" {
372                operator_type = PrimitiveOperator::LessThan;
373            } else if operator == "$lte" {
374                operator_type = PrimitiveOperator::LessThanOrEqual;
375            } else if operator == "$gt" {
376                operator_type = PrimitiveOperator::GreaterThan;
377            } else if operator == "$gte" {
378                operator_type = PrimitiveOperator::GreaterThanOrEqual;
379            } else {
380                return Err(WhereValidationError::WhereClause);
381            }
382            return Ok(Where::Metadata(MetadataExpression {
383                key: key.clone(),
384                comparison: crate::MetadataComparison::Primitive(
385                    operator_type,
386                    crate::MetadataValue::Float(operand_f64),
387                ),
388            }));
389        }
390        if operand.is_i64() {
391            let operand_i64 = operand.as_i64().unwrap();
392            let operator_type;
393            if operator == "$eq" {
394                operator_type = PrimitiveOperator::Equal;
395            } else if operator == "$ne" {
396                operator_type = PrimitiveOperator::NotEqual;
397            } else if operator == "$lt" {
398                operator_type = PrimitiveOperator::LessThan;
399            } else if operator == "$lte" {
400                operator_type = PrimitiveOperator::LessThanOrEqual;
401            } else if operator == "$gt" {
402                operator_type = PrimitiveOperator::GreaterThan;
403            } else if operator == "$gte" {
404                operator_type = PrimitiveOperator::GreaterThanOrEqual;
405            } else {
406                return Err(WhereValidationError::WhereClause);
407            }
408            return Ok(Where::Metadata(MetadataExpression {
409                key: key.clone(),
410                comparison: crate::MetadataComparison::Primitive(
411                    operator_type,
412                    crate::MetadataValue::Int(operand_i64),
413                ),
414            }));
415        }
416        return Err(WhereValidationError::WhereClause);
417    }
418    Err(WhereValidationError::WhereClause)
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_parse_where_direct_eq() {
        let payload = json!({
          "key1": "value1"
        });
        let expected_result = Where::Metadata(MetadataExpression {
            key: "key1".to_string(),
            comparison: crate::MetadataComparison::Primitive(
                PrimitiveOperator::Equal,
                crate::MetadataValue::Str("value1".to_string()),
            ),
        });

        let result = parse_where(&payload).expect("This clause to parse successfully");
        assert_eq!(result, expected_result);
    }

    // TODO: add a proptest when there's an Arbitrary impl for Where and WhereDocument
    #[test]
    fn test_parse_where_document() {
        let payloads = [
            // $and / $or of $contains
            json!({
              "$and": [
                  {"$contains": "value1"},
                  {"$or": [
                      {"$contains": "value2"},
                      {"$contains": "value3"}
                  ]}
              ]
            }),
            // $not_contains
            json!({
              "$not_contains": "value1",
            }),
        ];

        let expected_results = [
            // $and / $or of $contains
            Where::Composite(CompositeExpression {
                operator: crate::BooleanOperator::And,
                children: vec![
                    Where::Document(crate::DocumentExpression {
                        operator: DocumentOperator::Contains,
                        pattern: "value1".to_string(),
                    }),
                    Where::Composite(CompositeExpression {
                        operator: crate::BooleanOperator::Or,
                        children: vec![
                            Where::Document(crate::DocumentExpression {
                                operator: DocumentOperator::Contains,
                                pattern: "value2".to_string(),
                            }),
                            Where::Document(crate::DocumentExpression {
                                operator: DocumentOperator::Contains,
                                pattern: "value3".to_string(),
                            }),
                        ],
                    }),
                ],
            }),
            // $not_contains
            Where::Document(crate::DocumentExpression {
                operator: DocumentOperator::NotContains,
                pattern: "value1".to_string(),
            }),
        ];

        for (payload, expected_result) in payloads.iter().zip(expected_results.iter()) {
            let result = parse_where_document(payload);
            assert!(
                result.is_ok(),
                "Parsing failed for payload: {}: {:?}",
                serde_json::to_string_pretty(payload).unwrap(),
                result
            );
            assert_eq!(
                result.unwrap(),
                *expected_result,
                "Parsed result did not match expected result: {}",
                serde_json::to_string_pretty(payload).unwrap(),
            );
        }
    }

    #[test]
    fn test_parse_where() {
        let payloads = [
            // $in
            json!({
              "key1": {"$in": ["value1", "value2", "value3"]}
            }),
            // $nin
            json!({
              "key1": {"$nin": ["value1", "value2", "value3"]}
            }),
            // $eq
            json!({
              "key1": {"$eq": "value1"}
            }),
            // $ne
            json!({
              "key1": {"$ne": "value1"}
            }),
        ];

        let expected_results = [
            // $in
            Where::Metadata(MetadataExpression {
                key: "key1".to_string(),
                comparison: crate::MetadataComparison::Set(
                    crate::SetOperator::In,
                    crate::MetadataSetValue::Str(vec![
                        "value1".to_string(),
                        "value2".to_string(),
                        "value3".to_string(),
                    ]),
                ),
            }),
            // $nin
            Where::Metadata(MetadataExpression {
                key: "key1".to_string(),
                comparison: crate::MetadataComparison::Set(
                    crate::SetOperator::NotIn,
                    crate::MetadataSetValue::Str(vec![
                        "value1".to_string(),
                        "value2".to_string(),
                        "value3".to_string(),
                    ]),
                ),
            }),
            // $eq
            Where::Metadata(MetadataExpression {
                key: "key1".to_string(),
                comparison: crate::MetadataComparison::Primitive(
                    PrimitiveOperator::Equal,
                    crate::MetadataValue::Str("value1".to_string()),
                ),
            }),
            // $ne
            Where::Metadata(MetadataExpression {
                key: "key1".to_string(),
                comparison: crate::MetadataComparison::Primitive(
                    PrimitiveOperator::NotEqual,
                    crate::MetadataValue::Str("value1".to_string()),
                ),
            }),
        ];

        for (payload, expected_result) in payloads.iter().zip(expected_results.iter()) {
            let result = parse_where(payload);
            assert!(
                result.is_ok(),
                "Parsing failed for payload: {}: {:?}",
                serde_json::to_string_pretty(payload).unwrap(),
                result
            );
            assert_eq!(
                result.unwrap(),
                *expected_result,
                "Parsed result did not match expected result: {}",
                serde_json::to_string_pretty(payload).unwrap(),
            );
        }
    }

    /// Numeric ordering operators, boolean shorthand and integer sets.
    #[test]
    fn test_parse_where_numeric_and_boolean() {
        let cases = [
            (
                json!({"age": {"$gte": 18}}),
                Where::Metadata(MetadataExpression {
                    key: "age".to_string(),
                    comparison: crate::MetadataComparison::Primitive(
                        PrimitiveOperator::GreaterThanOrEqual,
                        crate::MetadataValue::Int(18),
                    ),
                }),
            ),
            (
                json!({"score": {"$lt": 0.5}}),
                Where::Metadata(MetadataExpression {
                    key: "score".to_string(),
                    comparison: crate::MetadataComparison::Primitive(
                        PrimitiveOperator::LessThan,
                        crate::MetadataValue::Float(0.5),
                    ),
                }),
            ),
            (
                json!({"flag": true}),
                Where::Metadata(MetadataExpression {
                    key: "flag".to_string(),
                    comparison: crate::MetadataComparison::Primitive(
                        PrimitiveOperator::Equal,
                        crate::MetadataValue::Bool(true),
                    ),
                }),
            ),
            (
                json!({"ids": {"$nin": [1, 2]}}),
                Where::Metadata(MetadataExpression {
                    key: "ids".to_string(),
                    comparison: crate::MetadataComparison::Set(
                        crate::SetOperator::NotIn,
                        crate::MetadataSetValue::Int(vec![1, 2]),
                    ),
                }),
            ),
        ];

        for (payload, expected) in cases.iter() {
            let result = parse_where(payload).unwrap_or_else(|err| {
                panic!(
                    "Parsing failed for payload {}: {:?}",
                    serde_json::to_string_pretty(payload).unwrap(),
                    err
                )
            });
            assert_eq!(result, *expected);
        }
    }

    #[test]
    fn test_parse_where_rejects_malformed_clauses() {
        let payloads = [
            json!("not an object"),
            json!({"key1": "value1", "key2": "value2"}),
            json!({"key1": {"$in": []}}),
            json!({"key1": {"$in": ["value1", 1]}}),
            json!({"key1": {"$gt": "value1"}}),
            json!({"key1": {"$eq": "value1", "$ne": "value2"}}),
            json!({"$and": {"key1": "value1"}}),
            json!({"key1": null}),
        ];

        for payload in payloads.iter() {
            let result = parse_where(payload);
            assert!(
                matches!(result, Err(WhereValidationError::WhereClause)),
                "Expected WhereClause error for payload {}: {:?}",
                serde_json::to_string_pretty(payload).unwrap(),
                result
            );
        }
    }

    #[test]
    fn test_parse_where_document_rejects_malformed_clauses() {
        let payloads = [
            json!({}),
            json!({"$contains": 1}),
            json!({"$unknown": "value1"}),
            json!({"$or": "value1"}),
        ];

        for payload in payloads.iter() {
            let result = parse_where_document(payload);
            assert!(
                matches!(result, Err(WhereValidationError::WhereDocumentClause)),
                "Expected WhereDocumentClause error for payload {}: {:?}",
                serde_json::to_string_pretty(payload).unwrap(),
                result
            );
        }
    }

    #[test]
    fn test_raw_where_fields_parse_combines_clauses() {
        let fields = RawWhereFields::new(
            json!({"key1": "value1"}),
            json!({"$contains": "value2"}),
        );
        let expected = Where::Composite(CompositeExpression {
            operator: crate::BooleanOperator::And,
            children: vec![
                Where::Metadata(MetadataExpression {
                    key: "key1".to_string(),
                    comparison: crate::MetadataComparison::Primitive(
                        PrimitiveOperator::Equal,
                        crate::MetadataValue::Str("value1".to_string()),
                    ),
                }),
                Where::Document(crate::DocumentExpression {
                    operator: DocumentOperator::Contains,
                    pattern: "value2".to_string(),
                }),
            ],
        });
        assert_eq!(
            fields.parse().expect("Both clauses to parse successfully"),
            Some(expected)
        );

        let empty = RawWhereFields::new(Value::Null, Value::Null);
        assert_eq!(empty.parse().expect("Null clauses to be skipped"), None);
    }

    #[test]
    fn test_raw_where_fields_from_json_str() {
        assert!(matches!(
            RawWhereFields::from_json_str(Some("{"), None),
            Err(WhereValidationError::WhereClause)
        ));
        assert!(matches!(
            RawWhereFields::from_json_str(None, Some("{")),
            Err(WhereValidationError::WhereDocumentClause)
        ));

        let fields = RawWhereFields::from_json_str(Some(r#"{"key1": "value1"}"#), None)
            .expect("Valid JSON to decode");
        assert_eq!(
            fields.parse().expect("This clause to parse successfully"),
            Some(Where::Metadata(MetadataExpression {
                key: "key1".to_string(),
                comparison: crate::MetadataComparison::Primitive(
                    PrimitiveOperator::Equal,
                    crate::MetadataValue::Str("value1".to_string()),
                ),
            }))
        );
    }
}