crdb_core/
query.rs

1use rust_decimal::Decimal;
2
3use crate::fts;
4
/// One step of a path into a JSON document.
///
/// A path is a list of these items applied left to right: each item selects a child of the
/// value selected by the items before it.
#[derive(Debug, serde::Deserialize, serde::Serialize)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum JsonPathItem {
    /// Selects the member with this name; only resolves when the current value is an object.
    Key(String),

    /// Selects the element at this position; only resolves when the current value is an array.
    ///
    /// Negative values count from the end
    // PostgreSQL throws an error if trying to use -> with a value beyond i32 range
    Id(i32),
}
14
/// A predicate over JSON documents.
///
/// Leaf variants pair a path (a list of [`JsonPathItem`]) with the operand that the value found
/// at that path is tested against. `All`, `Any` and `Not` combine other queries.
#[derive(Debug, serde::Deserialize, serde::Serialize)]
pub enum Query {
    // Logic operators
    /// Matches when every sub-query matches; an empty list therefore always matches.
    All(Vec<Query>),
    /// Matches when at least one sub-query matches; an empty list therefore never matches.
    Any(Vec<Query>),
    /// Matches when the wrapped query does not match.
    Not(Box<Query>),

    // TODO(misc-low): this could be useful?
    // Any/all the values in the array at JsonPathItem must match Query
    // AnyIn(Vec<JsonPathItem>, Box<Query>),
    // AllIn(Vec<JsonPathItem>, Box<Query>),

    // JSON tests
    // TODO(api-high): allow comparing JsonPath with JsonPath
    /// Matches when the value at the path equals the given JSON value. Numbers are compared
    /// after conversion to `Decimal`. A path that does not resolve never matches.
    Eq(Vec<JsonPathItem>, serde_json::Value),

    // Numeric comparisons (the operand is a `Decimal`, not only an integer).
    // They only ever match when the value at the path is a JSON number.
    /// Value at the path is less than or equal to the operand.
    Le(Vec<JsonPathItem>, Decimal),
    /// Value at the path is strictly less than the operand.
    Lt(Vec<JsonPathItem>, Decimal),
    /// Value at the path is greater than or equal to the operand.
    Ge(Vec<JsonPathItem>, Decimal),
    /// Value at the path is strictly greater than the operand.
    Gt(Vec<JsonPathItem>, Decimal),

    // Arrays and object containment
    /// Matches when the value at the path contains the given JSON value.
    Contains(Vec<JsonPathItem>, serde_json::Value),

    // Full text search
    /// Matches when the full-text search for the given string succeeds on the value at the path.
    ContainsStr(Vec<JsonPathItem>, String),
}
43
// Delegates to `arbitrary_impl`, which tracks how deeply the generated query is nested.
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Query {
    /// Generates an arbitrary query, starting at nesting depth zero.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let root_depth = 0;
        arbitrary_impl(u, root_depth)
    }
}
50
/// Builds an arbitrary [`Query`] from fuzzer input, where `depth` is the current nesting level.
///
/// When the input is exhausted or `depth` exceeds 50, a trivial leaf query is returned instead
/// of generating anything further.
#[cfg(feature = "arbitrary")]
fn arbitrary_impl(u: &mut arbitrary::Unstructured<'_>, depth: usize) -> arbitrary::Result<Query> {
    /// Generates the operand list of an `All`/`Any` node, one level deeper than `depth`.
    /// The number of entries is decided by `arbitrary_loop`, called with a maximum of 50.
    fn sub_queries(
        u: &mut arbitrary::Unstructured<'_>,
        depth: usize,
    ) -> arbitrary::Result<Vec<Query>> {
        let mut collected = Vec::new();
        u.arbitrary_loop(None, Some(50), |u| {
            let child = arbitrary_impl(u, depth + 1)?;
            collected.push(child);
            Ok(std::ops::ControlFlow::Continue(()))
        })?;
        Ok(collected)
    }

    // avoid stack overflow in arbitrary
    let out_of_budget = u.is_empty() || depth > 50;
    if out_of_budget {
        return Ok(Query::Eq(Vec::new(), serde_json::Value::Null));
    }

    // One input byte picks the variant; the remaining input fills in its fields, in order.
    let selector = u.arbitrary::<u8>()? % 10;
    let query = match selector {
        0 => Query::All(sub_queries(u, depth)?),
        1 => Query::Any(sub_queries(u, depth)?),
        2 => {
            let inner = arbitrary_impl(u, depth + 1)?;
            Query::Not(Box::new(inner))
        }
        3 => {
            let path: Vec<JsonPathItem> = u.arbitrary()?;
            let value = u.arbitrary::<arbitrary_json::ArbitraryValue>()?;
            Query::Eq(path, value.into())
        }
        4 => {
            let path: Vec<JsonPathItem> = u.arbitrary()?;
            let bound: Decimal = u.arbitrary()?;
            Query::Le(path, bound)
        }
        5 => {
            let path: Vec<JsonPathItem> = u.arbitrary()?;
            let bound: Decimal = u.arbitrary()?;
            Query::Lt(path, bound)
        }
        6 => {
            let path: Vec<JsonPathItem> = u.arbitrary()?;
            let bound: Decimal = u.arbitrary()?;
            Query::Ge(path, bound)
        }
        7 => {
            let path: Vec<JsonPathItem> = u.arbitrary()?;
            let bound: Decimal = u.arbitrary()?;
            Query::Gt(path, bound)
        }
        8 => {
            let path: Vec<JsonPathItem> = u.arbitrary()?;
            let value = u.arbitrary::<arbitrary_json::ArbitraryValue>()?;
            Query::Contains(path, value.into())
        }
        9 => {
            let path: Vec<JsonPathItem> = u.arbitrary()?;
            let pattern: String = u.arbitrary()?;
            Query::ContainsStr(path, pattern)
        }
        _ => unimplemented!(),
    };
    Ok(query)
}
92
93impl Query {
94    pub fn check(&self) -> crate::Result<()> {
95        match self {
96            Query::All(v) => {
97                for v in v {
98                    v.check()?;
99                }
100            }
101            Query::Any(v) => {
102                for v in v {
103                    v.check()?;
104                }
105            }
106            Query::Not(v) => v.check()?,
107            Query::Eq(p, v) => {
108                Self::check_path(p)?;
109                Self::check_value(v)?;
110            }
111            Query::Le(p, _) => Self::check_path(p)?,
112            Query::Lt(p, _) => Self::check_path(p)?,
113            Query::Ge(p, _) => Self::check_path(p)?,
114            Query::Gt(p, _) => Self::check_path(p)?,
115            Query::Contains(p, v) => {
116                Self::check_path(p)?;
117                Self::check_value(v)?;
118            }
119            Query::ContainsStr(p, s) => {
120                Self::check_path(p)?;
121                crate::check_string(s)?;
122            }
123        }
124
125        Ok(())
126    }
127
128    fn check_value(v: &serde_json::Value) -> crate::Result<()> {
129        match v {
130            serde_json::Value::Null => (),
131            serde_json::Value::Bool(_) => (),
132            serde_json::Value::Number(_) => (),
133            serde_json::Value::String(s) => crate::check_string(s)?,
134            serde_json::Value::Array(v) => {
135                for v in v.iter() {
136                    Self::check_value(v)?;
137                }
138            }
139            serde_json::Value::Object(m) => {
140                for (k, v) in m.iter() {
141                    crate::check_string(k)?;
142                    Self::check_value(v)?;
143                }
144            }
145        }
146        Ok(())
147    }
148
149    fn check_path(p: &[JsonPathItem]) -> crate::Result<()> {
150        p.iter().try_for_each(|i| match i {
151            JsonPathItem::Id(_) => Ok(()),
152            JsonPathItem::Key(k) => crate::check_string(k),
153        })
154    }
155
156    pub fn matches<T: serde::Serialize>(&self, v: T) -> serde_json::Result<bool> {
157        let json = serde_json::to_value(v)?;
158        Ok(self.matches_json(&json))
159    }
160
161    pub fn matches_json(&self, v: &serde_json::Value) -> bool {
162        match self {
163            Query::All(q) => q.iter().all(|q| q.matches_json(v)),
164            Query::Any(q) => q.iter().any(|q| q.matches_json(v)),
165            Query::Not(q) => !q.matches_json(v),
166            Query::Eq(p, to) => Self::deref(v, p)
167                .map(|v| Self::compare_with_nums(v, to))
168                .unwrap_or(false),
169            Query::Le(p, to) => Self::deref_num(v, p).map(|n| n <= *to).unwrap_or(false),
170            Query::Lt(p, to) => Self::deref_num(v, p).map(|n| n < *to).unwrap_or(false),
171            Query::Ge(p, to) => Self::deref_num(v, p).map(|n| n >= *to).unwrap_or(false),
172            Query::Gt(p, to) => Self::deref_num(v, p).map(|n| n > *to).unwrap_or(false),
173            Query::Contains(p, pat) => {
174                let Some(v) = Self::deref(v, p) else {
175                    return false;
176                };
177                Self::contains(v, pat)
178            }
179            Query::ContainsStr(p, pat) => Self::deref(v, p)
180                .and_then(|v| v.as_object())
181                .and_then(|v| v.get("_crdb-normalized"))
182                .and_then(|s| s.as_str())
183                .map(|s| fts::matches(s, &fts::normalize(pat)))
184                .unwrap_or(false),
185        }
186    }
187
188    fn compare_with_nums(l: &serde_json::Value, r: &serde_json::Value) -> bool {
189        use serde_json::Value::*;
190        match (l, r) {
191            (Null, Null) => true,
192            (Bool(l), Bool(r)) => l == r,
193            (l @ Number(_), r @ Number(_)) => {
194                let normalized_l = serde_json::from_value::<Decimal>(l.clone());
195                normalized_l.is_ok()
196                    && normalized_l.ok() == serde_json::from_value::<Decimal>(r.clone()).ok()
197            }
198            (String(l), String(r)) => l == r,
199            (Array(l), Array(r)) => {
200                l.len() == r.len()
201                    && l.iter()
202                        .zip(r.iter())
203                        .all(|(l, r)| Self::compare_with_nums(l, r))
204            }
205            (Object(l), Object(r)) => {
206                l.len() == r.len()
207                    && l.iter()
208                        .zip(r.iter())
209                        .all(|((lk, lv), (rk, rv))| lk == rk && Self::compare_with_nums(lv, rv))
210            }
211            _ => false,
212        }
213    }
214
215    fn contains(v: &serde_json::Value, pat: &serde_json::Value) -> bool {
216        use serde_json::Value::*;
217        match (v, pat) {
218            (Null, Null) => true,
219            (Bool(l), Bool(r)) => l == r,
220            (l @ Number(_), r @ Number(_)) => Self::compare_with_nums(l, r),
221            (String(l), String(r)) => l == r,
222            (Object(v), Object(pat)) => {
223                for (key, pat) in pat.iter() {
224                    if !v.get(key).map(|v| Self::contains(v, pat)).unwrap_or(false) {
225                        return false;
226                    }
227                }
228                true
229            }
230            (Array(v), Array(pat)) => {
231                for pat in pat.iter() {
232                    if !v.iter().any(|v| Self::contains(v, pat)) {
233                        return false;
234                    }
235                }
236                true
237            }
238            (Array(_), Object(_)) => false, // primitive containment doesn't work on objects
239            (Array(v), pat) => v.iter().any(|v| Self::compare_with_nums(v, pat)), // but does work on primitives
240            _ => false,
241        }
242    }
243
244    fn deref_num(v: &serde_json::Value, path: &[JsonPathItem]) -> Option<Decimal> {
245        use serde_json::Value;
246        match Self::deref(v, path)? {
247            Value::Number(n) => serde_json::from_value(Value::Number(n.clone())).ok(),
248            _ => None,
249        }
250    }
251
252    fn deref<'a>(v: &'a serde_json::Value, path: &[JsonPathItem]) -> Option<&'a serde_json::Value> {
253        match path.first() {
254            None => Some(v),
255            Some(JsonPathItem::Key(k)) => match v.as_object() {
256                None => None,
257                Some(v) => v.get(k).and_then(|v| Self::deref(v, &path[1..])),
258            },
259            Some(JsonPathItem::Id(k)) if *k >= 0 => match v.as_array() {
260                None => None,
261                Some(v) => v.get(*k as usize).and_then(|v| Self::deref(v, &path[1..])),
262            },
263            Some(JsonPathItem::Id(k)) /* if *k < 0 */ => match v.as_array() {
264                None => None,
265                Some(v) => v
266                    .len()
267                    .checked_add_signed(isize::try_from(*k).unwrap())
268                    .and_then(|i| v.get(i))
269                    .and_then(|v| Self::deref(v, &path[1..])),
270            },
271        }
272    }
273}