Skip to main content

oak_sql/ast/
expression_nodes.rs

1use crate::ast::statements::query::SelectStatement;
2use core::range::Range;
3use oak_core::source::{SourceBuffer, ToSource};
4use std::sync::Arc;
5
6/// Represents an SQL expression.
7#[derive(Debug, Clone)]
8#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
9pub enum Expression {
10    /// An identifier expression.
11    Identifier(Identifier),
12    /// A literal value expression.
13    Literal(Literal),
14    /// A binary operation expression.
15    Binary {
16        /// The left-hand side of the binary operation.
17        left: Box<Expression>,
18        /// The binary operator.
19        op: BinaryOperator,
20        /// The right-hand side of the binary operation.
21        right: Box<Expression>,
22        /// The span of the binary operation.
23        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
24        span: Range<usize>,
25    },
26    /// A unary operation expression.
27    Unary {
28        /// The unary operator.
29        op: UnaryOperator,
30        /// The expression being operated on.
31        expr: Box<Expression>,
32        /// The span of the unary operation.
33        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
34        span: Range<usize>,
35    },
36    /// A function call expression.
37    FunctionCall {
38        /// The name of the function being called.
39        name: Identifier,
40        /// The arguments passed to the function.
41        args: Vec<Expression>,
42        /// The span of the function call.
43        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
44        span: Range<usize>,
45    },
46    /// Vector/Array literal.
47    Vector {
48        /// The vector elements.
49        elements: Vec<Expression>,
50        /// The span of the vector.
51        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
52        span: Range<usize>,
53    },
54    /// An IN expression.
55    InList {
56        /// The expression being checked.
57        expr: Box<Expression>,
58        /// The list of values to check against.
59        list: Vec<Expression>,
60        /// Whether the IN condition is negated (NOT IN).
61        negated: bool,
62        /// The span of the IN expression.
63        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
64        span: Range<usize>,
65    },
66    /// A BETWEEN expression.
67    Between {
68        /// The expression being checked.
69        expr: Box<Expression>,
70        /// The lower bound of the range.
71        low: Box<Expression>,
72        /// The upper bound of the range.
73        high: Box<Expression>,
74        /// Whether the BETWEEN condition is negated (NOT BETWEEN).
75        negated: bool,
76        /// The span of the BETWEEN expression.
77        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
78        span: Range<usize>,
79    },
80    /// A subquery expression.
81    Subquery {
82        /// The subquery SELECT statement.
83        query: Box<SelectStatement>,
84        /// The span of the subquery.
85        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
86        span: Range<usize>,
87    },
88    /// An IN expression with a subquery.
89    InSubquery {
90        /// The expression being checked.
91        expr: Box<Expression>,
92        /// The subquery to check against.
93        query: Box<SelectStatement>,
94        /// Whether the IN condition is negated (NOT IN).
95        negated: bool,
96        /// The span of the IN expression.
97        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
98        span: Range<usize>,
99    },
100    /// An error occurred during expression parsing or building.
101    Error {
102        /// The error message.
103        message: Arc<str>,
104        /// The span where the error occurred.
105        #[serde(with = "oak_core::serde_range")]
106        span: Range<usize>,
107    },
108}
109
110impl Expression {
111    /// Returns the source span of this expression.
112    pub fn span(&self) -> Range<usize> {
113        match self {
114            Expression::Identifier(id) => id.span.clone(),
115            Expression::Literal(lit) => lit.span().clone(),
116            Expression::Binary { span, .. } => span.clone(),
117            Expression::Unary { span, .. } => span.clone(),
118            Expression::FunctionCall { span, .. } => span.clone(),
119            Expression::InList { span, .. } => span.clone(),
120            Expression::Between { span, .. } => span.clone(),
121            Expression::Subquery { span, .. } => span.clone(),
122            Expression::InSubquery { span, .. } => span.clone(),
123            Expression::Vector { span, .. } => span.clone(),
124            Expression::Error { span, .. } => span.clone(),
125        }
126    }
127}
128
129impl ToSource for Expression {
130    fn to_source(&self, buffer: &mut SourceBuffer) {
131        match self {
132            Expression::Identifier(id) => id.to_source(buffer),
133            Expression::Literal(lit) => lit.to_source(buffer),
134            Expression::Binary { left, op, right, .. } => {
135                left.to_source(buffer);
136                op.to_source(buffer);
137                right.to_source(buffer);
138            }
139            Expression::Unary { op, expr, .. } => {
140                op.to_source(buffer);
141                expr.to_source(buffer);
142            }
143            Expression::FunctionCall { name, args, .. } => {
144                name.to_source(buffer);
145                buffer.push("(");
146                for (i, arg) in args.iter().enumerate() {
147                    if i > 0 {
148                        buffer.push(",");
149                    }
150                    arg.to_source(buffer);
151                }
152                buffer.push(")");
153            }
154            Expression::Vector { elements, .. } => {
155                buffer.push("[");
156                for (i, elem) in elements.iter().enumerate() {
157                    if i > 0 {
158                        buffer.push(", ");
159                    }
160                    elem.to_source(buffer);
161                }
162                buffer.push("]");
163            }
164            Expression::InList { expr, list, negated, .. } => {
165                expr.to_source(buffer);
166                if *negated {
167                    buffer.push("NOT");
168                }
169                buffer.push("IN");
170                buffer.push("(");
171                for (i, item) in list.iter().enumerate() {
172                    if i > 0 {
173                        buffer.push(",");
174                    }
175                    item.to_source(buffer);
176                }
177                buffer.push(")");
178            }
179            Expression::Between { expr, low, high, negated, .. } => {
180                expr.to_source(buffer);
181                if *negated {
182                    buffer.push("NOT");
183                }
184                buffer.push("BETWEEN");
185                low.to_source(buffer);
186                buffer.push("AND");
187                high.to_source(buffer);
188            }
189            Expression::Subquery { query, .. } => {
190                buffer.push("(");
191                query.to_source(buffer);
192                buffer.push(")");
193            }
194            Expression::InSubquery { expr, query, negated, .. } => {
195                expr.to_source(buffer);
196                if *negated {
197                    buffer.push("NOT");
198                }
199                buffer.push("IN");
200                buffer.push("(");
201                query.to_source(buffer);
202                buffer.push(")");
203            }
204            Expression::Error { message, .. } => {
205                buffer.push("/* EXPR ERROR: ");
206                buffer.push(message);
207                buffer.push(" */");
208            }
209        }
210    }
211}
212
/// An infix SQL operator joining two operand expressions.
///
/// Each variant's documentation names the SQL spelling it corresponds to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum BinaryOperator {
    /// `+` — arithmetic addition.
    Plus,
    /// `-` — arithmetic subtraction.
    Minus,
    /// `*` — arithmetic multiplication.
    Star,
    /// `/` — arithmetic division.
    Slash,
    /// `%` — remainder (modulo).
    Percent,
    /// `AND` — logical conjunction.
    And,
    /// `OR` — logical disjunction.
    Or,
    /// `=` — equality comparison.
    Equal,
    /// `<>` (also written `!=`) — inequality comparison.
    NotEqual,
    /// `<` — strictly-less-than comparison.
    Less,
    /// `>` — strictly-greater-than comparison.
    Greater,
    /// `<=` — less-than-or-equal comparison.
    LessEqual,
    /// `>=` — greater-than-or-equal comparison.
    GreaterEqual,
    /// `LIKE` — pattern matching.
    Like,
}
246
247impl ToSource for BinaryOperator {
248    fn to_source(&self, buffer: &mut SourceBuffer) {
249        match self {
250            BinaryOperator::Plus => buffer.push("+"),
251            BinaryOperator::Minus => buffer.push("-"),
252            BinaryOperator::Star => buffer.push("*"),
253            BinaryOperator::Slash => buffer.push("/"),
254            BinaryOperator::Percent => buffer.push("%"),
255            BinaryOperator::And => buffer.push("AND"),
256            BinaryOperator::Or => buffer.push("OR"),
257            BinaryOperator::Equal => buffer.push("="),
258            BinaryOperator::NotEqual => buffer.push("<>"),
259            BinaryOperator::Less => buffer.push("<"),
260            BinaryOperator::Greater => buffer.push(">"),
261            BinaryOperator::LessEqual => buffer.push("<="),
262            BinaryOperator::GreaterEqual => buffer.push(">="),
263            BinaryOperator::Like => buffer.push("LIKE"),
264        }
265    }
266}
267
/// A prefix SQL operator applied to a single operand expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum UnaryOperator {
    /// `+` — unary plus.
    Plus,
    /// `-` — arithmetic negation.
    Minus,
    /// `NOT` — logical negation.
    Not,
}
279
280impl ToSource for UnaryOperator {
281    fn to_source(&self, buffer: &mut SourceBuffer) {
282        match self {
283            UnaryOperator::Plus => buffer.push("+"),
284            UnaryOperator::Minus => buffer.push("-"),
285            UnaryOperator::Not => buffer.push("NOT"),
286        }
287    }
288}
289
/// An SQL literal value, paired with the source range it was parsed from.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Literal {
    // The span attributes are gated on the `serde` feature. Without it the
    // `serde` helper attribute is not registered and an ungated `#[serde(...)]`
    // fails to compile.
    /// Numeric literal, stored as its text.
    Number(
        Arc<str>,
        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))] Range<usize>,
    ),
    /// String literal (the quotes are not part of the stored text).
    String(
        Arc<str>,
        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))] Range<usize>,
    ),
    /// Boolean literal (`TRUE` / `FALSE`).
    Boolean(
        bool,
        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))] Range<usize>,
    ),
    /// NULL literal.
    Null(#[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))] Range<usize>),
}
303
304impl Literal {
305    /// Returns the source span of this literal.
306    pub fn span(&self) -> Range<usize> {
307        match self {
308            Literal::Number(_, span) => span.clone(),
309            Literal::String(_, span) => span.clone(),
310            Literal::Boolean(_, span) => span.clone(),
311            Literal::Null(span) => span.clone(),
312        }
313    }
314}
315
316impl ToSource for Literal {
317    fn to_source(&self, buffer: &mut SourceBuffer) {
318        match self {
319            Literal::Number(n, _) => buffer.push(n),
320            Literal::String(s, _) => {
321                buffer.push("'");
322                buffer.push(s);
323                buffer.push("'");
324            }
325            Literal::Boolean(b, _) => buffer.push(if *b { "TRUE" } else { "FALSE" }),
326            Literal::Null(_) => buffer.push("NULL"),
327        }
328    }
329}
330
/// SQL identifier.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Identifier {
    /// The name of the identifier.
    pub name: Arc<str>,
    /// The span of the identifier.
    // Gated on the `serde` feature: an ungated `#[serde(...)]` does not compile
    // when the serde derives above are disabled.
    #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
    pub span: Range<usize>,
}
341
342impl ToSource for Identifier {
343    fn to_source(&self, buffer: &mut SourceBuffer) {
344        buffer.push(&self.name);
345    }
346}
347
348/// SQL table name.
349#[derive(Debug, Clone)]
350#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
351pub struct TableName {
352    /// The name of the table.
353    pub name: Identifier,
354    /// The span of the table name.
355    #[serde(with = "oak_core::serde_range")]
356    pub span: Range<usize>,
357}
358
359impl ToSource for TableName {
360    fn to_source(&self, buffer: &mut SourceBuffer) {
361        self.name.to_source(buffer);
362    }
363}
364
365/// Column name.
366#[derive(Debug, Clone)]
367#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
368pub struct ColumnName {
369    /// The name of the column.
370    pub name: Identifier,
371    /// The span of the column name in the source.
372    #[serde(with = "oak_core::serde_range")]
373    pub span: Range<usize>,
374}
375
376impl ToSource for ColumnName {
377    fn to_source(&self, buffer: &mut SourceBuffer) {
378        self.name.to_source(buffer);
379    }
380}