Skip to main content

oxide_sql_core/ast/
expression.rs

1//! Expression AST types.
2
3use core::fmt;
4
5use crate::lexer::Span;
6
/// A literal value appearing in an SQL expression.
///
/// Converted back to SQL text by this type's `Display` implementation.
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    /// Integer literal (signed 64-bit).
    Integer(i64),
    /// Floating-point literal (64-bit).
    Float(f64),
    /// String literal, holding the contents without SQL quoting; the
    /// `Display` impl adds the surrounding quotes and doubles embedded `'`.
    String(String),
    /// Blob literal as raw bytes; rendered as an `X'..'` hex literal.
    Blob(Vec<u8>),
    /// Boolean literal, rendered as `TRUE` / `FALSE`.
    Boolean(bool),
    /// The `NULL` literal.
    Null,
}
23
/// Binary (infix) operators.
///
/// See [`BinaryOp::as_str`] for each operator's SQL spelling and
/// [`BinaryOp::precedence`] for its binding strength.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    // Arithmetic
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`).
    Div,
    /// Remainder (`%`).
    Mod,

    // Comparison
    /// Equality (`=`).
    Eq,
    /// Inequality (`!=`).
    NotEq,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    LtEq,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    GtEq,

    // Logical
    /// Logical conjunction (`AND`).
    And,
    /// Logical disjunction (`OR`).
    Or,

    // String
    /// String concatenation (`||`).
    Concat,
    /// Pattern match (`LIKE`).
    Like,

    // Bitwise
    /// Bitwise AND (`&`).
    BitAnd,
    /// Bitwise OR (`|`).
    BitOr,
    /// Left shift (`<<`).
    LeftShift,
    /// Right shift (`>>`).
    RightShift,
}
56
57impl BinaryOp {
58    /// Returns the SQL representation of the operator.
59    #[must_use]
60    pub const fn as_str(&self) -> &'static str {
61        match self {
62            Self::Add => "+",
63            Self::Sub => "-",
64            Self::Mul => "*",
65            Self::Div => "/",
66            Self::Mod => "%",
67            Self::Eq => "=",
68            Self::NotEq => "!=",
69            Self::Lt => "<",
70            Self::LtEq => "<=",
71            Self::Gt => ">",
72            Self::GtEq => ">=",
73            Self::And => "AND",
74            Self::Or => "OR",
75            Self::Concat => "||",
76            Self::Like => "LIKE",
77            Self::BitAnd => "&",
78            Self::BitOr => "|",
79            Self::LeftShift => "<<",
80            Self::RightShift => ">>",
81        }
82    }
83
84    /// Returns the precedence of the operator (higher = binds tighter).
85    #[must_use]
86    pub const fn precedence(&self) -> u8 {
87        match self {
88            Self::Or => 1,
89            Self::And => 2,
90            Self::Eq | Self::NotEq | Self::Lt | Self::LtEq | Self::Gt | Self::GtEq => 3,
91            Self::Like => 4,
92            Self::BitOr => 5,
93            Self::BitAnd => 6,
94            Self::LeftShift | Self::RightShift => 7,
95            Self::Add | Self::Sub | Self::Concat => 8,
96            Self::Mul | Self::Div | Self::Mod => 9,
97        }
98    }
99}
100
101impl fmt::Display for Literal {
102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103        match self {
104            Self::Integer(n) => write!(f, "{n}"),
105            Self::Float(v) => write!(f, "{v}"),
106            Self::String(s) => {
107                let escaped = s.replace('\'', "''");
108                write!(f, "'{escaped}'")
109            }
110            Self::Blob(bytes) => {
111                write!(f, "X'")?;
112                for b in bytes {
113                    write!(f, "{b:02X}")?;
114                }
115                write!(f, "'")
116            }
117            Self::Boolean(true) => write!(f, "TRUE"),
118            Self::Boolean(false) => write!(f, "FALSE"),
119            Self::Null => write!(f, "NULL"),
120        }
121    }
122}
123
124impl fmt::Display for BinaryOp {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        f.write_str(self.as_str())
127    }
128}
129
/// Unary (prefix) operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-`).
    Neg,
    /// Logical negation (`NOT`).
    Not,
    /// Bitwise NOT (`~`).
    BitNot,
}
140
141impl UnaryOp {
142    /// Returns the SQL representation of the operator.
143    #[must_use]
144    pub const fn as_str(&self) -> &'static str {
145        match self {
146            Self::Neg => "-",
147            Self::Not => "NOT",
148            Self::BitNot => "~",
149        }
150    }
151}
152
153impl fmt::Display for UnaryOp {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        f.write_str(self.as_str())
156    }
157}
158
/// A function call expression such as `COUNT(DISTINCT x)` or `LOWER(name)`.
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionCall {
    /// The function name, rendered verbatim by `Display`.
    pub name: String,
    /// The argument expressions, in call order (rendered comma-separated).
    pub args: Vec<Expr>,
    /// Whether `DISTINCT` was specified; rendered as a `DISTINCT ` prefix
    /// before the arguments.
    pub distinct: bool,
}
169
/// An SQL expression tree node.
///
/// Converted back to SQL text by this type's `Display` implementation.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A literal value such as `42`, `'text'` or `NULL`.
    Literal(Literal),

    /// A column reference, optionally qualified as `table.column`.
    Column {
        /// Table name or alias qualifier, if present.
        table: Option<String>,
        /// Column name.
        name: String,
        /// Source span of the reference; not rendered by `Display`. The
        /// `column` / `qualified_column` builders set it to `Span::default()`.
        span: Span,
    },

    /// A binary (infix) expression: `left op right`.
    Binary {
        /// Left operand.
        left: Box<Expr>,
        /// Operator.
        op: BinaryOp,
        /// Right operand.
        right: Box<Expr>,
    },

    /// A prefix unary expression: `-x`, `NOT x` or `~x`.
    Unary {
        /// Operator.
        op: UnaryOp,
        /// Operand.
        operand: Box<Expr>,
    },

    /// A function call.
    Function(FunctionCall),

    /// A subquery used as an expression; wrapped in parentheses when rendered.
    Subquery(Box<super::SelectStatement>),

    /// `expr IS NULL`, or `expr IS NOT NULL` when `negated` is set.
    IsNull {
        /// The expression to check.
        expr: Box<Expr>,
        /// Whether this is IS NOT NULL.
        negated: bool,
    },

    /// `expr IN (...)`, or `expr NOT IN (...)` when `negated` is set.
    In {
        /// The expression to check.
        expr: Box<Expr>,
        /// The items inside the parentheses. NOTE(review): the subquery form
        /// (`IN (SELECT ...)`) appears to be stored as a single
        /// `Expr::Subquery` item; confirm against the parser.
        list: Vec<Expr>,
        /// Whether this is NOT IN.
        negated: bool,
    },

    /// `expr BETWEEN low AND high`, or `NOT BETWEEN` when `negated` is set.
    Between {
        /// The expression to check.
        expr: Box<Expr>,
        /// Lower bound.
        low: Box<Expr>,
        /// Upper bound.
        high: Box<Expr>,
        /// Whether this is NOT BETWEEN.
        negated: bool,
    },

    /// A CASE expression: `CASE [operand] WHEN .. THEN .. [ELSE ..] END`.
    Case {
        /// Base operand of a simple CASE (`CASE x WHEN ...`); `None` for a
        /// searched CASE (`CASE WHEN cond ...`).
        operand: Option<Box<Expr>>,
        /// `(WHEN, THEN)` pairs, in order.
        when_clauses: Vec<(Expr, Expr)>,
        /// The ELSE result, if present.
        else_clause: Option<Box<Expr>>,
    },

    /// `CAST(expr AS data_type)`.
    Cast {
        /// Expression to cast.
        expr: Box<Expr>,
        /// Target type.
        data_type: super::DataType,
    },

    /// An explicitly parenthesized expression; rendered as `(inner)`.
    Paren(Box<Expr>),

    /// A parameter placeholder: positional `?` or named `:name`.
    Parameter {
        /// The placeholder name for `:name`, or `None` for a positional `?`.
        name: Option<String>,
        /// Position in the query (1-based for ? placeholders); not rendered
        /// by `Display`.
        position: usize,
    },

    /// Wildcard: `*` or `table.*` (as used in a SELECT list).
    Wildcard {
        /// Table qualifier, if present.
        table: Option<String>,
    },
}
275
276impl fmt::Display for FunctionCall {
277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278        // EXISTS gets special handling: the subquery already
279        // contains its own parentheses in the rendered form,
280        // so we render `EXISTS(SELECT ...)` instead of
281        // `EXISTS((SELECT ...))`.
282        if self.name == "EXISTS" {
283            if let [Expr::Subquery(q)] = self.args.as_slice() {
284                return write!(f, "EXISTS({q})");
285            }
286        }
287        write!(f, "{}(", self.name)?;
288        if self.distinct {
289            write!(f, "DISTINCT ")?;
290        }
291        for (i, arg) in self.args.iter().enumerate() {
292            if i > 0 {
293                write!(f, ", ")?;
294            }
295            write!(f, "{arg}")?;
296        }
297        write!(f, ")")
298    }
299}
300
301impl fmt::Display for Expr {
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        match self {
304            Self::Literal(lit) => write!(f, "{lit}"),
305            Self::Column { table, name, .. } => {
306                if let Some(t) = table {
307                    write!(f, "{t}.{name}")
308                } else {
309                    write!(f, "{name}")
310                }
311            }
312            Self::Binary { left, op, right } => {
313                write!(f, "{left} {op} {right}")
314            }
315            Self::Unary { op, operand } => match op {
316                UnaryOp::Not => write!(f, "NOT {operand}"),
317                UnaryOp::Neg => write!(f, "-{operand}"),
318                UnaryOp::BitNot => write!(f, "~{operand}"),
319            },
320            Self::Function(func) => write!(f, "{func}"),
321            Self::Subquery(q) => write!(f, "({q})"),
322            Self::IsNull { expr, negated } => {
323                if *negated {
324                    write!(f, "{expr} IS NOT NULL")
325                } else {
326                    write!(f, "{expr} IS NULL")
327                }
328            }
329            Self::In {
330                expr,
331                list,
332                negated,
333            } => {
334                write!(f, "{expr}")?;
335                if *negated {
336                    write!(f, " NOT IN (")?;
337                } else {
338                    write!(f, " IN (")?;
339                }
340                for (i, item) in list.iter().enumerate() {
341                    if i > 0 {
342                        write!(f, ", ")?;
343                    }
344                    write!(f, "{item}")?;
345                }
346                write!(f, ")")
347            }
348            Self::Between {
349                expr,
350                low,
351                high,
352                negated,
353            } => {
354                if *negated {
355                    write!(f, "{expr} NOT BETWEEN {low} AND {high}")
356                } else {
357                    write!(f, "{expr} BETWEEN {low} AND {high}")
358                }
359            }
360            Self::Case {
361                operand,
362                when_clauses,
363                else_clause,
364            } => {
365                write!(f, "CASE")?;
366                if let Some(op) = operand {
367                    write!(f, " {op}")?;
368                }
369                for (when, then) in when_clauses {
370                    write!(f, " WHEN {when} THEN {then}")?;
371                }
372                if let Some(el) = else_clause {
373                    write!(f, " ELSE {el}")?;
374                }
375                write!(f, " END")
376            }
377            Self::Cast { expr, data_type } => {
378                write!(f, "CAST({expr} AS {data_type})")
379            }
380            Self::Paren(inner) => write!(f, "({inner})"),
381            Self::Parameter { name, .. } => {
382                if let Some(n) = name {
383                    write!(f, ":{n}")
384                } else {
385                    write!(f, "?")
386                }
387            }
388            Self::Wildcard { table } => {
389                if let Some(t) = table {
390                    write!(f, "{t}.*")
391                } else {
392                    write!(f, "*")
393                }
394            }
395        }
396    }
397}
398
399impl Expr {
400    /// Creates a new column reference.
401    #[must_use]
402    pub fn column(name: impl Into<String>) -> Self {
403        Self::Column {
404            table: None,
405            name: name.into(),
406            span: Span::default(),
407        }
408    }
409
410    /// Creates a new qualified column reference.
411    #[must_use]
412    pub fn qualified_column(table: impl Into<String>, name: impl Into<String>) -> Self {
413        Self::Column {
414            table: Some(table.into()),
415            name: name.into(),
416            span: Span::default(),
417        }
418    }
419
420    /// Creates a new integer literal.
421    #[must_use]
422    pub const fn integer(value: i64) -> Self {
423        Self::Literal(Literal::Integer(value))
424    }
425
426    /// Creates a new float literal.
427    #[must_use]
428    pub const fn float(value: f64) -> Self {
429        Self::Literal(Literal::Float(value))
430    }
431
432    /// Creates a new string literal.
433    #[must_use]
434    pub fn string(value: impl Into<String>) -> Self {
435        Self::Literal(Literal::String(value.into()))
436    }
437
438    /// Creates a new boolean literal.
439    #[must_use]
440    pub const fn boolean(value: bool) -> Self {
441        Self::Literal(Literal::Boolean(value))
442    }
443
444    /// Creates a NULL literal.
445    #[must_use]
446    pub const fn null() -> Self {
447        Self::Literal(Literal::Null)
448    }
449
450    /// Creates a binary expression.
451    #[must_use]
452    pub fn binary(self, op: BinaryOp, right: Self) -> Self {
453        Self::Binary {
454            left: Box::new(self),
455            op,
456            right: Box::new(right),
457        }
458    }
459
460    /// Creates an equality expression.
461    #[must_use]
462    pub fn eq(self, right: Self) -> Self {
463        self.binary(BinaryOp::Eq, right)
464    }
465
466    /// Creates an inequality expression.
467    #[must_use]
468    pub fn not_eq(self, right: Self) -> Self {
469        self.binary(BinaryOp::NotEq, right)
470    }
471
472    /// Creates a less-than expression.
473    #[must_use]
474    pub fn lt(self, right: Self) -> Self {
475        self.binary(BinaryOp::Lt, right)
476    }
477
478    /// Creates a less-than-or-equal expression.
479    #[must_use]
480    pub fn lt_eq(self, right: Self) -> Self {
481        self.binary(BinaryOp::LtEq, right)
482    }
483
484    /// Creates a greater-than expression.
485    #[must_use]
486    pub fn gt(self, right: Self) -> Self {
487        self.binary(BinaryOp::Gt, right)
488    }
489
490    /// Creates a greater-than-or-equal expression.
491    #[must_use]
492    pub fn gt_eq(self, right: Self) -> Self {
493        self.binary(BinaryOp::GtEq, right)
494    }
495
496    /// Creates an AND expression.
497    #[must_use]
498    pub fn and(self, right: Self) -> Self {
499        self.binary(BinaryOp::And, right)
500    }
501
502    /// Creates an OR expression.
503    #[must_use]
504    pub fn or(self, right: Self) -> Self {
505        self.binary(BinaryOp::Or, right)
506    }
507
508    /// Creates an IS NULL expression.
509    #[must_use]
510    pub fn is_null(self) -> Self {
511        Self::IsNull {
512            expr: Box::new(self),
513            negated: false,
514        }
515    }
516
517    /// Creates an IS NOT NULL expression.
518    #[must_use]
519    pub fn is_not_null(self) -> Self {
520        Self::IsNull {
521            expr: Box::new(self),
522            negated: true,
523        }
524    }
525
526    /// Creates a BETWEEN expression.
527    #[must_use]
528    pub fn between(self, low: Self, high: Self) -> Self {
529        Self::Between {
530            expr: Box::new(self),
531            low: Box::new(low),
532            high: Box::new(high),
533            negated: false,
534        }
535    }
536
537    /// Creates a NOT BETWEEN expression.
538    #[must_use]
539    pub fn not_between(self, low: Self, high: Self) -> Self {
540        Self::Between {
541            expr: Box::new(self),
542            low: Box::new(low),
543            high: Box::new(high),
544            negated: true,
545        }
546    }
547
548    /// Creates an IN expression.
549    #[must_use]
550    pub fn in_list(self, list: Vec<Self>) -> Self {
551        Self::In {
552            expr: Box::new(self),
553            list,
554            negated: false,
555        }
556    }
557
558    /// Creates a NOT IN expression.
559    #[must_use]
560    pub fn not_in_list(self, list: Vec<Self>) -> Self {
561        Self::In {
562            expr: Box::new(self),
563            list,
564            negated: true,
565        }
566    }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_op_precedence() {
        // Each pair is (tighter, looser).
        let pairs = [
            (BinaryOp::Mul, BinaryOp::Add),
            (BinaryOp::And, BinaryOp::Or),
            (BinaryOp::Eq, BinaryOp::And),
        ];
        for &(tighter, looser) in pairs.iter() {
            assert!(
                tighter.precedence() > looser.precedence(),
                "{tighter:?} should bind tighter than {looser:?}"
            );
        }
    }

    #[test]
    fn test_expr_builders() {
        match Expr::column("name") {
            Expr::Column { name, .. } => assert_eq!(name, "name"),
            other => panic!("expected a column reference, got {other:?}"),
        }

        assert_eq!(Expr::integer(42), Expr::Literal(Literal::Integer(42)));
    }

    #[test]
    fn test_expr_chaining() {
        let adult = Expr::column("age").gt(Expr::integer(18));
        let active = Expr::column("status").eq(Expr::string("active"));
        let combined = adult.and(active);

        assert!(
            matches!(
                combined,
                Expr::Binary {
                    op: BinaryOp::And,
                    ..
                }
            ),
            "expected a top-level AND, got {combined:?}"
        );
    }
}