Skip to main content

oak_sql/ast/
expression_nodes.rs

1use crate::ast::statements::query::SelectStatement;
2use core::range::Range;
3use oak_core::source::{SourceBuffer, ToSource};
4use std::sync::Arc;
5
6/// Represents an SQL expression.
7#[derive(Debug, Clone)]
8#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
9pub enum Expression {
10    /// An identifier expression.
11    Identifier(Identifier),
12    /// A literal value expression.
13    Literal(Literal),
14    /// A binary operation expression.
15    Binary {
16        /// The left-hand side of the binary operation.
17        left: Box<Expression>,
18        /// The binary operator.
19        op: BinaryOperator,
20        /// The right-hand side of the binary operation.
21        right: Box<Expression>,
22        /// The span of the binary operation.
23        #[serde(with = "oak_core::serde_range")]
24        span: Range<usize>,
25    },
26    /// A unary operation expression.
27    Unary {
28        /// The unary operator.
29        op: UnaryOperator,
30        /// The expression being operated on.
31        expr: Box<Expression>,
32        /// The span of the unary operation.
33        #[serde(with = "oak_core::serde_range")]
34        span: Range<usize>,
35    },
36    /// A function call expression.
37    FunctionCall {
38        /// The name of the function being called.
39        name: Identifier,
40        /// The arguments passed to the function.
41        args: Vec<Expression>,
42        /// The span of the function call.
43        #[serde(with = "oak_core::serde_range")]
44        span: Range<usize>,
45    },
46    /// An IN expression.
47    InList {
48        /// The expression being checked.
49        expr: Box<Expression>,
50        /// The list of values to check against.
51        list: Vec<Expression>,
52        /// Whether the IN condition is negated (NOT IN).
53        negated: bool,
54        /// The span of the IN expression.
55        #[serde(with = "oak_core::serde_range")]
56        span: Range<usize>,
57    },
58    /// A BETWEEN expression.
59    Between {
60        /// The expression being checked.
61        expr: Box<Expression>,
62        /// The lower bound of the range.
63        low: Box<Expression>,
64        /// The upper bound of the range.
65        high: Box<Expression>,
66        /// Whether the BETWEEN condition is negated (NOT BETWEEN).
67        negated: bool,
68        /// The span of the BETWEEN expression.
69        #[serde(with = "oak_core::serde_range")]
70        span: Range<usize>,
71    },
72    /// A subquery expression.
73    Subquery {
74        /// The subquery SELECT statement.
75        query: Box<SelectStatement>,
76        /// The span of the subquery.
77        #[serde(with = "oak_core::serde_range")]
78        span: Range<usize>,
79    },
80    /// An IN expression with a subquery.
81    InSubquery {
82        /// The expression being checked.
83        expr: Box<Expression>,
84        /// The subquery to check against.
85        query: Box<SelectStatement>,
86        /// Whether the IN condition is negated (NOT IN).
87        negated: bool,
88        /// The span of the IN expression.
89        #[serde(with = "oak_core::serde_range")]
90        span: Range<usize>,
91    },
92    /// An error occurred during expression parsing or building.
93    Error {
94        /// The error message.
95        message: Arc<str>,
96        /// The span where the error occurred.
97        #[serde(with = "oak_core::serde_range")]
98        span: Range<usize>,
99    },
100}
101
102impl Expression {
103    pub fn span(&self) -> Range<usize> {
104        match self {
105            Expression::Identifier(id) => id.span.clone(),
106            Expression::Literal(lit) => lit.span().clone(),
107            Expression::Binary { span, .. } => span.clone(),
108            Expression::Unary { span, .. } => span.clone(),
109            Expression::FunctionCall { span, .. } => span.clone(),
110            Expression::InList { span, .. } => span.clone(),
111            Expression::Between { span, .. } => span.clone(),
112            Expression::Subquery { span, .. } => span.clone(),
113            Expression::InSubquery { span, .. } => span.clone(),
114            Expression::Error { span, .. } => span.clone(),
115        }
116    }
117}
118
119impl ToSource for Expression {
120    fn to_source(&self, buffer: &mut SourceBuffer) {
121        match self {
122            Expression::Identifier(id) => id.to_source(buffer),
123            Expression::Literal(lit) => lit.to_source(buffer),
124            Expression::Binary { left, op, right, .. } => {
125                left.to_source(buffer);
126                op.to_source(buffer);
127                right.to_source(buffer);
128            }
129            Expression::Unary { op, expr, .. } => {
130                op.to_source(buffer);
131                expr.to_source(buffer);
132            }
133            Expression::FunctionCall { name, args, .. } => {
134                name.to_source(buffer);
135                buffer.push("(");
136                for (i, arg) in args.iter().enumerate() {
137                    if i > 0 {
138                        buffer.push(",");
139                    }
140                    arg.to_source(buffer);
141                }
142                buffer.push(")");
143            }
144            Expression::InList { expr, list, negated, .. } => {
145                expr.to_source(buffer);
146                if *negated {
147                    buffer.push("NOT");
148                }
149                buffer.push("IN");
150                buffer.push("(");
151                for (i, item) in list.iter().enumerate() {
152                    if i > 0 {
153                        buffer.push(",");
154                    }
155                    item.to_source(buffer);
156                }
157                buffer.push(")");
158            }
159            Expression::Between { expr, low, high, negated, .. } => {
160                expr.to_source(buffer);
161                if *negated {
162                    buffer.push("NOT");
163                }
164                buffer.push("BETWEEN");
165                low.to_source(buffer);
166                buffer.push("AND");
167                high.to_source(buffer);
168            }
169            Expression::Subquery { query, .. } => {
170                buffer.push("(");
171                query.to_source(buffer);
172                buffer.push(")");
173            }
174            Expression::InSubquery { expr, query, negated, .. } => {
175                expr.to_source(buffer);
176                if *negated {
177                    buffer.push("NOT");
178                }
179                buffer.push("IN");
180                buffer.push("(");
181                query.to_source(buffer);
182                buffer.push(")");
183            }
184            Expression::Error { message, .. } => {
185                buffer.push("/* EXPR ERROR: ");
186                buffer.push(message);
187                buffer.push(" */");
188            }
189        }
190    }
191}
192
/// Binary operators in SQL.
///
/// The type is a small `Copy` enum with `Eq` and `Hash`, so operators can be
/// used directly as keys in hash maps and sets (e.g. operator tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum BinaryOperator {
    /// Addition (+).
    Plus,
    /// Subtraction (-).
    Minus,
    /// Multiplication (*).
    Star,
    /// Division (/).
    Slash,
    /// Modulo (%).
    Percent,
    /// Logical AND.
    And,
    /// Logical OR.
    Or,
    /// Equality (=).
    Equal,
    /// Inequality (<> or !=); printed as `<>` by `to_source`.
    NotEqual,
    /// Less than (<).
    Less,
    /// Greater than (>).
    Greater,
    /// Less than or equal to (<=).
    LessEqual,
    /// Greater than or equal to (>=).
    GreaterEqual,
    /// Pattern matching (LIKE).
    Like,
}
226
227impl ToSource for BinaryOperator {
228    fn to_source(&self, buffer: &mut SourceBuffer) {
229        match self {
230            BinaryOperator::Plus => buffer.push("+"),
231            BinaryOperator::Minus => buffer.push("-"),
232            BinaryOperator::Star => buffer.push("*"),
233            BinaryOperator::Slash => buffer.push("/"),
234            BinaryOperator::Percent => buffer.push("%"),
235            BinaryOperator::And => buffer.push("AND"),
236            BinaryOperator::Or => buffer.push("OR"),
237            BinaryOperator::Equal => buffer.push("="),
238            BinaryOperator::NotEqual => buffer.push("<>"),
239            BinaryOperator::Less => buffer.push("<"),
240            BinaryOperator::Greater => buffer.push(">"),
241            BinaryOperator::LessEqual => buffer.push("<="),
242            BinaryOperator::GreaterEqual => buffer.push(">="),
243            BinaryOperator::Like => buffer.push("LIKE"),
244        }
245    }
246}
247
/// Unary operators in SQL.
///
/// The type is a small `Copy` enum with `Eq` and `Hash`, so operators can be
/// used directly as keys in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum UnaryOperator {
    /// Unary plus (+).
    Plus,
    /// Unary minus (-).
    Minus,
    /// Logical NOT.
    Not,
}
259
260impl ToSource for UnaryOperator {
261    fn to_source(&self, buffer: &mut SourceBuffer) {
262        match self {
263            UnaryOperator::Plus => buffer.push("+"),
264            UnaryOperator::Minus => buffer.push("-"),
265            UnaryOperator::Not => buffer.push("NOT"),
266        }
267    }
268}
269
/// SQL literals.
///
/// Every variant stores its source span as its last field; read it with
/// [`Literal::span`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Literal {
    // The `serde(...)` helper attribute is only recognised while the serde
    // derive is applied, so it is gated on the same feature. An ungated
    // `#[serde]` breaks builds with the feature disabled.
    /// Numeric literal, stored as text rather than as a parsed number.
    Number(
        Arc<str>,
        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))] Range<usize>,
    ),
    /// String literal.
    String(
        Arc<str>,
        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))] Range<usize>,
    ),
    /// Boolean literal.
    Boolean(
        bool,
        #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))] Range<usize>,
    ),
    /// NULL literal.
    Null(#[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))] Range<usize>),
}

impl Literal {
    /// Returns the source span of this literal.
    pub fn span(&self) -> Range<usize> {
        match self {
            Literal::Number(_, span)
            | Literal::String(_, span)
            | Literal::Boolean(_, span)
            | Literal::Null(span) => span.clone(),
        }
    }
}
294
295impl ToSource for Literal {
296    fn to_source(&self, buffer: &mut SourceBuffer) {
297        match self {
298            Literal::Number(n, _) => buffer.push(n),
299            Literal::String(s, _) => {
300                buffer.push("'");
301                buffer.push(s);
302                buffer.push("'");
303            }
304            Literal::Boolean(b, _) => buffer.push(if *b { "TRUE" } else { "FALSE" }),
305            Literal::Null(_) => buffer.push("NULL"),
306        }
307    }
308}
309
/// SQL identifier.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Identifier {
    /// The name of the identifier.
    pub name: Arc<str>,
    /// The span of the identifier.
    // Gated on the same feature as the derive: the `serde(...)` helper
    // attribute is only recognised while the serde derive is applied.
    #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
    pub span: Range<usize>,
}
320
321impl ToSource for Identifier {
322    fn to_source(&self, buffer: &mut SourceBuffer) {
323        buffer.push(&self.name);
324    }
325}
326
327/// SQL table name.
328#[derive(Debug, Clone)]
329#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
330pub struct TableName {
331    /// The name of the table.
332    pub name: Identifier,
333    /// The span of the table name.
334    #[serde(with = "oak_core::serde_range")]
335    pub span: Range<usize>,
336}
337
338impl ToSource for TableName {
339    fn to_source(&self, buffer: &mut SourceBuffer) {
340        self.name.to_source(buffer);
341    }
342}
343
344/// Column name.
345#[derive(Debug, Clone)]
346#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
347pub struct ColumnName {
348    /// The name of the column.
349    pub name: Identifier,
350    /// The span of the column name in the source.
351    #[serde(with = "oak_core::serde_range")]
352    pub span: Range<usize>,
353}
354
355impl ToSource for ColumnName {
356    fn to_source(&self, buffer: &mut SourceBuffer) {
357        self.name.to_source(buffer);
358    }
359}