// Source: sql_json_path/ast.rs
//
// Copyright 2023 RisingWave Labs
// Modifications Copyright (c) Citadel contributors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file has been modified by Citadel contributors.

//! The AST of JSON Path.
19
20use std::fmt::Display;
21use std::fmt::Formatter;
22use std::ops::Deref;
23
24use serde_json::Number;
25
26/// A JSON Path value.
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct JsonPath {
29    pub(crate) mode: Mode,
30    pub(crate) expr: ExprOrPredicate,
31}
32
/// The mode of JSON Path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Lax mode converts errors to empty SQL/JSON sequences.
    Lax,
    /// Strict mode raises an error if the data does not strictly adhere to
    /// the requirements of a path expression.
    Strict,
}
41
42/// An expression or predicate.
43#[derive(Debug, Clone, PartialEq, Eq)]
44pub enum ExprOrPredicate {
45    Expr(Expr),
46    Pred(Predicate),
47}
48
49/// An expression in JSON Path.
50#[derive(Debug, Clone, PartialEq, Eq)]
51pub enum Expr {
52    /// Path primary
53    PathPrimary(PathPrimary),
54    /// Accessor expression.
55    Accessor(Box<Expr>, AccessorOp),
56    /// Unary operation.
57    UnaryOp(UnaryOp, Box<Expr>),
58    /// Binary operation.
59    BinaryOp(BinaryOp, Box<Expr>, Box<Expr>),
60}
61
62/// A filter expression that evaluates to a truth value.
63#[derive(Debug, Clone, PartialEq, Eq)]
64pub enum Predicate {
65    /// `==`, `!=`, `<`, `<=`, `>`, `>=` represents the comparison between two values.
66    Compare(CompareOp, Box<Expr>, Box<Expr>),
67    /// `exists` represents the value exists.
68    Exists(Box<Expr>),
69    /// `&&` represents logical AND.
70    And(Box<Predicate>, Box<Predicate>),
71    /// `||` represents logical OR.
72    Or(Box<Predicate>, Box<Predicate>),
73    /// `!` represents logical NOT.
74    Not(Box<Predicate>),
75    /// `is unknown` represents the value is unknown.
76    IsUnknown(Box<Predicate>),
77    /// `starts with` represents the value starts with the given value.
78    StartsWith(Box<Expr>, Value),
79    /// `like_regex` represents the value matches the given regular expression.
80    LikeRegex(Box<Expr>, Box<Regex>),
81}
82
83/// A primary expression.
84#[derive(Debug, Clone, PartialEq, Eq)]
85pub enum PathPrimary {
86    /// `$` represents the root node or element.
87    Root,
88    /// `@` represents the current node or element being processed in the filter expression.
89    Current,
90    /// `last` is the size of the array minus 1.
91    Last,
92    /// Literal value.
93    Value(Value),
94    /// `(expr)` represents an expression.
95    ExprOrPred(Box<ExprOrPredicate>),
96}
97
98/// An accessor operation.
99#[derive(Debug, Clone, PartialEq, Eq)]
100pub enum AccessorOp {
101    /// `.*` represents selecting all elements in an object.
102    MemberWildcard,
103    /// `.**` represents selecting all elements in an object and its sub-objects.
104    DescendantMemberWildcard(LevelRange),
105    /// `[*]` represents selecting all elements in an array.
106    ElementWildcard,
107    /// `.<name>` represents selecting element that matched the name in an object, like `$.event`.
108    /// The name can also be written as a string literal, allowing the name to contain special characters, like `$." $price"`.
109    Member(String),
110    /// `[<index1>,<index2>,..]` represents selecting elements specified by the indices in an Array.
111    Element(Vec<ArrayIndex>),
112    /// `?(<predicate>)` represents filtering elements using the predicate.
113    FilterExpr(Box<Predicate>),
114    /// `.method()` represents calling a method.
115    Method(Method),
116}
117
118/// A level range.
119#[derive(Debug, Clone, PartialEq, Eq)]
120pub enum LevelRange {
121    /// none
122    All,
123    /// `{level}`
124    One(Level),
125    /// `{start to end}`
126    Range(Level, Level),
127}
128
/// A level number.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Level {
    /// An explicit numeric level.
    N(u32),
    /// The `last` keyword.
    Last,
}
135
136/// An array index.
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub enum ArrayIndex {
139    /// The single number index.
140    Index(Expr),
141    /// `<start> to <end>` represents the slice of the array.
142    Slice(Expr, Expr),
143}
144
145/// Represents a scalar value.
146#[derive(Debug, Clone, PartialEq, Eq)]
147pub enum Value {
148    /// Null value.
149    Null,
150    /// Boolean value.
151    Boolean(bool),
152    /// Number value.
153    Number(Number),
154    /// UTF-8 string.
155    String(String),
156    /// Variable
157    Variable(String),
158}
159
/// A comparison operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
    /// `==` represents left is equal to right.
    Eq,
    /// `!=` and `<>` represent left is not equal to right.
    Ne,
    /// `<` represents left is less than right.
    Lt,
    /// `<=` represents left is less than or equal to right.
    Le,
    /// `>` represents left is greater than right.
    Gt,
    /// `>=` represents left is greater than or equal to right.
    Ge,
}
176
/// A unary operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// `+` represents unary plus.
    Plus,
    /// `-` represents unary minus.
    Minus,
}
185
/// A binary arithmetic operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// `+` represents left plus right.
    Add,
    /// `-` represents left minus right.
    Sub,
    /// `*` represents left multiplied by right.
    Mul,
    /// `/` represents left divided by right.
    Div,
    /// `%` represents left modulo right.
    Rem,
}
200
/// An item method.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Method {
    /// `.type()` returns a character string that names the type of the SQL/JSON item.
    Type,
    /// `.size()` returns the size of an SQL/JSON item.
    Size,
    /// `.double()` converts a string or numeric to an approximate numeric value.
    Double,
    /// `.ceiling()` returns the smallest integer that is greater than or equal to the argument.
    Ceiling,
    /// `.floor()` returns the largest integer that is less than or equal to the argument.
    Floor,
    /// `.abs()` returns the absolute value of the argument.
    Abs,
    /// `.keyvalue()` returns the key-value pairs of an object.
    ///
    /// For example, suppose:
    /// ```json
    /// { who: "Fred", what: 64 }
    /// ```
    /// Then:
    /// ```json
    /// $.keyvalue() =
    /// ( { name: "who",  value: "Fred", id: 9045 },
    ///   { name: "what", value: 64,     id: 9045 }
    /// )
    /// ```
    Keyvalue,
    /// `.datetime()` or `.datetime("<template>")`.
    Datetime {
        /// The optional template string argument; `None` when the method is
        /// written without an argument.
        template: Option<String>,
    },
}
234
235impl PathPrimary {
236    /// If this is a nested path primary, unnest it.
237    /// `(primary) => primary`
238    pub(crate) fn unnest(self) -> Self {
239        match self {
240            Self::ExprOrPred(expr) => match *expr {
241                ExprOrPredicate::Expr(Expr::PathPrimary(inner)) => inner,
242                other => Self::ExprOrPred(Box::new(other)),
243            },
244            _ => self,
245        }
246    }
247}
248
249impl LevelRange {
250    /// Returns the upper bound of the range.
251    /// If no upper bound, returns `u32::MAX`.
252    pub(crate) fn end(&self) -> u32 {
253        match self {
254            Self::One(Level::N(n)) => *n,
255            Self::Range(_, Level::N(end)) => *end,
256            _ => u32::MAX,
257        }
258    }
259
260    /// Resolve the range with the given `last`.
261    ///
262    /// # Examples
263    ///
264    /// ```text
265    /// last = 3
266    /// .**             => 0..4
267    /// .**{1}          => 1..2
268    /// .**{1 to 4}     => 1..3
269    /// .**{1 to last}  => 1..4
270    /// .**{last to 2}  => 3..3
271    /// ```
272    pub(crate) fn to_range(&self, last: usize) -> std::ops::Range<usize> {
273        match self {
274            Self::All => 0..last + 1,
275            Self::One(level) => {
276                level.to_usize(last).min(last + 1)..level.to_usize(last).min(last) + 1
277            }
278            Self::Range(start, end) => {
279                start.to_usize(last).min(last + 1)..end.to_usize(last).min(last) + 1
280            }
281        }
282    }
283}
284
285impl Level {
286    fn to_usize(&self, last: usize) -> usize {
287        match self {
288            Self::N(n) => *n as usize,
289            Self::Last => last,
290        }
291    }
292}
293
294impl Display for JsonPath {
295    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
296        if self.mode == Mode::Strict {
297            write!(f, "strict ")?;
298        }
299        write!(f, "{}", self.expr)
300    }
301}
302
303impl Display for Mode {
304    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
305        match self {
306            Self::Lax => write!(f, "lax"),
307            Self::Strict => write!(f, "strict"),
308        }
309    }
310}
311
312impl Display for ExprOrPredicate {
313    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
314        match self {
315            Self::Expr(expr) => match expr {
316                Expr::BinaryOp(_, _, _) => write!(f, "({})", expr),
317                _ => write!(f, "{}", expr),
318            },
319            Self::Pred(pred) => match pred {
320                Predicate::Compare(_, _, _) | Predicate::And(_, _) | Predicate::Or(_, _) => {
321                    write!(f, "({})", pred)
322                }
323                _ => write!(f, "{}", pred),
324            },
325        }
326    }
327}
328
329impl Display for Predicate {
330    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
331        match self {
332            Self::Compare(op, left, right) => write!(f, "{left} {op} {right}"),
333            Self::Exists(expr) => write!(f, "exists ({expr})"),
334            Self::And(left, right) => {
335                match left.as_ref() {
336                    Self::Or(_, _) => write!(f, "({left})")?,
337                    _ => write!(f, "{left}")?,
338                }
339                write!(f, " && ")?;
340                match right.as_ref() {
341                    Self::Or(_, _) => write!(f, "({right})"),
342                    _ => write!(f, "{right}"),
343                }
344            }
345            Self::Or(left, right) => write!(f, "{left} || {right}"),
346            Self::Not(expr) => write!(f, "!({expr})"),
347            Self::IsUnknown(expr) => write!(f, "({expr}) is unknown"),
348            Self::StartsWith(expr, v) => write!(f, "{expr} starts with {v}"),
349            Self::LikeRegex(expr, regex) => {
350                write!(f, "{expr} like_regex \"{}\"", regex.pattern())?;
351                if let Some(flags) = regex.flags() {
352                    write!(f, " flag \"{flags}\"")?;
353                }
354                Ok(())
355            }
356        }
357    }
358}
359
360impl Display for Expr {
361    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
362        match self {
363            Expr::PathPrimary(primary) => write!(f, "{primary}"),
364            Expr::Accessor(base, op) => {
365                match base.as_ref() {
366                    Expr::PathPrimary(PathPrimary::Value(Value::Number(_))) => {
367                        write!(f, "({base})")?
368                    }
369                    Expr::PathPrimary(PathPrimary::ExprOrPred(expr)) => match expr.as_ref() {
370                        ExprOrPredicate::Expr(Expr::UnaryOp(_, _)) => write!(f, "({base})")?,
371                        _ => write!(f, "{base}")?,
372                    },
373                    _ => write!(f, "{base}")?,
374                }
375                write!(f, "{op}")?;
376                Ok(())
377            }
378            Expr::UnaryOp(op, expr) => match expr.as_ref() {
379                Expr::PathPrimary(_) | Expr::Accessor(_, _) => write!(f, "{op}{expr}"),
380                _ => write!(f, "{op}({expr})"),
381            },
382            Expr::BinaryOp(op, left, right) => write!(f, "{left} {op} {right}"),
383        }
384    }
385}
386
387impl Display for ArrayIndex {
388    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
389        match self {
390            Self::Index(idx) => write!(f, "{idx}"),
391            Self::Slice(start, end) => write!(f, "{start} to {end}"),
392        }
393    }
394}
395
396impl Display for PathPrimary {
397    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
398        match self {
399            Self::Root => write!(f, "$"),
400            Self::Current => write!(f, "@"),
401            Self::Value(v) => write!(f, "{v}"),
402            Self::Last => write!(f, "last"),
403            Self::ExprOrPred(expr) => write!(f, "{expr}"),
404        }
405    }
406}
407
408impl Display for AccessorOp {
409    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
410        match self {
411            Self::MemberWildcard => write!(f, ".*"),
412            Self::DescendantMemberWildcard(level) => write!(f, ".**{level}"),
413            Self::ElementWildcard => write!(f, "[*]"),
414            Self::Member(field) => write!(f, ".\"{field}\""),
415            Self::Element(indices) => {
416                write!(f, "[")?;
417                for (i, idx) in indices.iter().enumerate() {
418                    if i > 0 {
419                        write!(f, ",")?;
420                    }
421                    write!(f, "{idx}")?;
422                }
423                write!(f, "]")
424            }
425            Self::FilterExpr(expr) => write!(f, "?({expr})"),
426            Self::Method(method) => match method {
427                Method::Datetime { template: Some(t) } => write!(f, ".datetime(\"{t}\")"),
428                _ => write!(f, ".{method}()"),
429            },
430        }
431    }
432}
433
434impl Display for LevelRange {
435    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
436        match self {
437            Self::All => Ok(()),
438            Self::One(level) => write!(f, "{{{level}}}"),
439            Self::Range(start, end) => write!(f, "{{{start} to {end}}}"),
440        }
441    }
442}
443
444impl Display for Level {
445    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
446        match self {
447            Self::N(n) => write!(f, "{n}"),
448            Self::Last => write!(f, "last"),
449        }
450    }
451}
452
453impl Display for Value {
454    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
455        match self {
456            Self::Null => write!(f, "null"),
457            Self::Boolean(v) => write!(f, "{v}"),
458            Self::Number(v) => write!(f, "{v}"),
459            Self::String(v) => write!(f, "\"{v}\""),
460            Self::Variable(v) => write!(f, "$\"{v}\""),
461        }
462    }
463}
464
465impl Display for UnaryOp {
466    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
467        match self {
468            Self::Plus => write!(f, "+"),
469            Self::Minus => write!(f, "-"),
470        }
471    }
472}
473
474impl Display for CompareOp {
475    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
476        match self {
477            Self::Eq => write!(f, "=="),
478            Self::Ne => write!(f, "!="),
479            Self::Lt => write!(f, "<"),
480            Self::Le => write!(f, "<="),
481            Self::Gt => write!(f, ">"),
482            Self::Ge => write!(f, ">="),
483        }
484    }
485}
486
487impl Display for BinaryOp {
488    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
489        match self {
490            Self::Add => write!(f, "+"),
491            Self::Sub => write!(f, "-"),
492            Self::Mul => write!(f, "*"),
493            Self::Div => write!(f, "/"),
494            Self::Rem => write!(f, "%"),
495        }
496    }
497}
498
499impl Display for Method {
500    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
501        match self {
502            Self::Type => write!(f, "type"),
503            Self::Size => write!(f, "size"),
504            Self::Double => write!(f, "double"),
505            Self::Ceiling => write!(f, "ceiling"),
506            Self::Floor => write!(f, "floor"),
507            Self::Abs => write!(f, "abs"),
508            Self::Keyvalue => write!(f, "keyvalue"),
509            Self::Datetime { .. } => write!(f, "datetime"),
510        }
511    }
512}
513
514/// A wrapper of `regex::Regex` to combine the pattern and flags.
515#[derive(Debug, Clone)]
516pub struct Regex {
517    regex: regex::Regex,
518    flags: String,
519}
520
521impl Regex {
522    pub(crate) fn with_flags(pattern: &str, flags: Option<String>) -> Result<Self, regex::Error> {
523        let translated;
524        let mut builder = match flags.as_deref() {
525            Some(flags) if flags.contains('q') => regex::RegexBuilder::new(&regex::escape(pattern)),
526            _ => {
527                translated = translate_pg_regex(pattern);
528                regex::RegexBuilder::new(&translated)
529            }
530        };
531        let mut out_flags = String::new();
532        if let Some(flags) = flags.as_deref() {
533            for c in flags.chars() {
534                match c {
535                    'q' => {}
536                    'i' => {
537                        builder.case_insensitive(true);
538                    }
539                    'm' => {
540                        builder.multi_line(true);
541                    }
542                    's' => {
543                        builder.dot_matches_new_line(true);
544                    }
545                    'x' => {
546                        return Err(regex::Error::Syntax(
547                            "XQuery \"x\" flag (expanded regular expressions) is not implemented"
548                                .to_string(),
549                        ))
550                    }
551                    _ => {
552                        return Err(regex::Error::Syntax(format!(
553                            "Unrecognized flag character \"{c}\" in LIKE_REGEX predicate."
554                        )))
555                    }
556                };
557                // Remove duplicated flags.
558                if !out_flags.contains(c) {
559                    out_flags.push(c);
560                }
561            }
562        }
563        let regex = builder.build()?;
564        Ok(Self {
565            regex,
566            flags: out_flags,
567        })
568    }
569
570    pub fn pattern(&self) -> &str {
571        self.regex.as_str()
572    }
573
574    pub fn flags(&self) -> Option<&str> {
575        if self.flags.is_empty() {
576            None
577        } else {
578            Some(&self.flags)
579        }
580    }
581}
582
583impl Deref for Regex {
584    type Target = regex::Regex;
585
586    fn deref(&self) -> &Self::Target {
587        &self.regex
588    }
589}
590
/// Rewrites backslash escapes in `pat` before it is compiled by the `regex` crate.
///
/// The only escape that changes is `\b`, which becomes `\x08` (the backspace
/// character). Every other `\<char>` pair, including `\\`, is copied through
/// unchanged, and the character after a backslash is never treated as the
/// start of a new escape. A lone trailing backslash is also copied through.
// NOTE(review): presumably this exists because PostgreSQL regexes read `\b` as
// backspace while the `regex` crate reads it as a word boundary — confirm.
fn translate_pg_regex(pat: &str) -> String {
    let mut out = String::with_capacity(pat.len() + 4);
    let mut chars = pat.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        // Consume the escaped character together with its backslash.
        match chars.next() {
            Some('b') => out.push_str("\\x08"),
            Some(escaped) => {
                out.push('\\');
                out.push(escaped);
            }
            None => out.push('\\'),
        }
    }
    out
}
611
612impl PartialEq for Regex {
613    fn eq(&self, other: &Self) -> bool {
614        self.pattern() == other.pattern() && self.flags() == other.flags()
615    }
616}
617
618impl Eq for Regex {}