1use std::{error::Error, fmt::Display};
4
5use sqlparser::ast;
6
7use crate::{
8 identifier::{ColumnRef, IdentifierError},
9 value::{Value, ValueError},
10 BoundedString,
11};
12
13pub mod eval;
14
/// An expression in the engine's own representation.
///
/// Values of this type are produced from a `sqlparser` AST through the
/// `TryFrom<ast::Expr>` implementation in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A literal value.
    Value(Value),
    /// A reference to a column, optionally qualified by table and schema name.
    ColumnRef(ColumnRef),
    /// `*`; the AST conversion produces this for wildcard function arguments
    /// such as `COUNT(*)`.
    Wildcard,
    /// A binary operation `left op right`.
    Binary {
        left: Box<Expr>,
        op: BinOp,
        right: Box<Expr>,
    },
    /// A unary operation. Depending on `op` this is a prefix form in SQL
    /// (`-x`, `NOT x`) or a postfix test (`x IS NULL`).
    Unary {
        op: UnOp,
        operand: Box<Expr>,
    },
    /// A function call `name(args...)`.
    Function {
        name: BoundedString,
        args: Vec<Expr>,
    },
}
35
36impl Display for Expr {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 Self::Value(v) => write!(f, "{}", v),
40 Self::ColumnRef(c) => write!(f, "column '{}'", c),
41 Self::Wildcard => write!(f, "*"),
42 Self::Binary { left, op, right } => write!(f, "({} {} {})", left, op, right),
43 Self::Unary { op, operand } => write!(f, "{}{}", op, operand),
44 Self::Function { name, args } => write!(
45 f,
46 "{}({})",
47 name,
48 args.iter()
49 .map(|a| a.to_string())
50 .collect::<Vec<String>>()
51 .join(", ")
52 ),
53 }
54 }
55}
56
/// Binary operators usable in [`Expr::Binary`].
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum BinOp {
    /// Arithmetic `+`.
    Plus,
    /// Arithmetic `-`.
    Minus,
    /// Arithmetic `*`.
    Multiply,
    /// Arithmetic `/`.
    Divide,
    /// Arithmetic `%`.
    Modulo,
    /// Comparison `=`.
    Equal,
    /// Comparison `!=` (also the target of SQL `<>` via the AST `NotEq` operator).
    NotEqual,
    /// Comparison `<`.
    LessThan,
    /// Comparison `<=`; also used to desugar `BETWEEN`.
    LessThanOrEqual,
    /// Comparison `>`.
    GreaterThan,
    /// Comparison `>=`.
    GreaterThanOrEqual,
    /// Pattern match `LIKE`.
    Like,
    /// Pattern match `ILIKE` (conventionally case-insensitive `LIKE`; confirm
    /// against the evaluator).
    ILike,
    /// Logical `AND`; also used to desugar `BETWEEN`.
    And,
    /// Logical `OR`.
    Or,
}
76
77impl Display for BinOp {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 write!(
80 f,
81 "{}",
82 match self {
83 BinOp::Plus => "+",
84 BinOp::Minus => "-",
85 BinOp::Multiply => "*",
86 BinOp::Divide => "/",
87 BinOp::Modulo => "%",
88 BinOp::Equal => "=",
89 BinOp::NotEqual => "!=",
90 BinOp::LessThan => "<",
91 BinOp::LessThanOrEqual => "<=",
92 BinOp::GreaterThan => ">",
93 BinOp::GreaterThanOrEqual => ">=",
94 BinOp::Like => "LIKE",
95 BinOp::ILike => "ILIKE",
96 BinOp::And => "AND",
97 BinOp::Or => "OR",
98 }
99 )
100 }
101}
102
/// Unary operators usable in [`Expr::Unary`].
///
/// `Plus`, `Minus` and `Not` are prefix operators in SQL. The `Is*` variants
/// are the postfix tests `IS FALSE`, `IS TRUE`, `IS NULL` and `IS NOT NULL`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum UnOp {
    /// Prefix `+`.
    Plus,
    /// Prefix `-`.
    Minus,
    /// Prefix `NOT`; also wraps negated `BETWEEN`.
    Not,
    /// Postfix `IS FALSE`.
    IsFalse,
    /// Postfix `IS TRUE`.
    IsTrue,
    /// Postfix `IS NULL`.
    IsNull,
    /// Postfix `IS NOT NULL`.
    IsNotNull,
}
114
115impl Display for UnOp {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 write!(
118 f,
119 "{}",
120 match self {
121 UnOp::Plus => "+",
122 UnOp::Minus => "-",
123 UnOp::Not => "NOT",
124 UnOp::IsFalse => "IS FALSE",
125 UnOp::IsTrue => "IS TRUE",
126 UnOp::IsNull => "IS NULL",
127 UnOp::IsNotNull => "IS NOT NULL",
128 }
129 )
130 }
131}
132
133impl TryFrom<ast::Expr> for Expr {
134 type Error = ExprError;
135 fn try_from(expr_ast: ast::Expr) -> Result<Self, Self::Error> {
136 match expr_ast {
137 ast::Expr::Identifier(i) => Ok(Expr::ColumnRef(vec![i].try_into()?)),
138 ast::Expr::CompoundIdentifier(i) => Ok(Expr::ColumnRef(i.try_into()?)),
139 ast::Expr::IsFalse(e) => Ok(Expr::Unary {
140 op: UnOp::IsFalse,
141 operand: Box::new((*e).try_into()?),
142 }),
143 ast::Expr::IsTrue(e) => Ok(Expr::Unary {
144 op: UnOp::IsTrue,
145 operand: Box::new((*e).try_into()?),
146 }),
147 ast::Expr::IsNull(e) => Ok(Expr::Unary {
148 op: UnOp::IsNull,
149 operand: Box::new((*e).try_into()?),
150 }),
151 ast::Expr::IsNotNull(e) => Ok(Expr::Unary {
152 op: UnOp::IsNotNull,
153 operand: Box::new((*e).try_into()?),
154 }),
155 ast::Expr::Between {
156 expr,
157 negated,
158 low,
159 high,
160 } => {
161 let expr: Box<Expr> = Box::new((*expr).try_into()?);
162 let left = Box::new((*low).try_into()?);
163 let right = Box::new((*high).try_into()?);
164 let between = Expr::Binary {
165 left: Box::new(Expr::Binary {
166 left,
167 op: BinOp::LessThanOrEqual,
168 right: expr.clone(),
169 }),
170 op: BinOp::And,
171 right: Box::new(Expr::Binary {
172 left: expr,
173 op: BinOp::LessThanOrEqual,
174 right,
175 }),
176 };
177 if negated {
178 Ok(Expr::Unary {
179 op: UnOp::Not,
180 operand: Box::new(between),
181 })
182 } else {
183 Ok(between)
184 }
185 }
186 ast::Expr::BinaryOp { left, op, right } => Ok(Expr::Binary {
187 left: Box::new((*left).try_into()?),
188 op: op.try_into()?,
189 right: Box::new((*right).try_into()?),
190 }),
191 ast::Expr::UnaryOp { op, expr } => Ok(Expr::Unary {
192 op: op.try_into()?,
193 operand: Box::new((*expr).try_into()?),
194 }),
195 ast::Expr::Value(v) => Ok(Expr::Value(v.try_into()?)),
196 ast::Expr::Function(ref f) => Ok(Expr::Function {
197 name: f.name.to_string().as_str().into(),
198 args: f
199 .args
200 .iter()
201 .map(|arg| match arg {
202 ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
203 ast::FunctionArgExpr::Expr(e) => Ok(e.clone().try_into()?),
204 ast::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard),
205 ast::FunctionArgExpr::QualifiedWildcard(_) => Err(ExprError::Expr {
206 reason: "Qualified wildcards are not supported yet",
207 expr: expr_ast.clone(),
208 }),
209 },
210 ast::FunctionArg::Named { .. } => Err(ExprError::Expr {
211 reason: "Named function arguments are not supported",
212 expr: expr_ast.clone(),
213 }),
214 })
215 .collect::<Result<Vec<_>, _>>()?,
216 }),
217 _ => Err(ExprError::Expr {
218 reason: "Unsupported expression",
219 expr: expr_ast,
220 }),
221 }
222 }
223}
224
225impl TryFrom<ast::BinaryOperator> for BinOp {
226 type Error = ExprError;
227 fn try_from(op: ast::BinaryOperator) -> Result<Self, Self::Error> {
228 match op {
229 ast::BinaryOperator::Plus => Ok(BinOp::Plus),
230 ast::BinaryOperator::Minus => Ok(BinOp::Minus),
231 ast::BinaryOperator::Multiply => Ok(BinOp::Multiply),
232 ast::BinaryOperator::Divide => Ok(BinOp::Divide),
233 ast::BinaryOperator::Modulo => Ok(BinOp::Modulo),
234 ast::BinaryOperator::Eq => Ok(BinOp::Equal),
235 ast::BinaryOperator::NotEq => Ok(BinOp::NotEqual),
236 ast::BinaryOperator::Lt => Ok(BinOp::LessThan),
237 ast::BinaryOperator::LtEq => Ok(BinOp::LessThanOrEqual),
238 ast::BinaryOperator::Gt => Ok(BinOp::GreaterThan),
239 ast::BinaryOperator::GtEq => Ok(BinOp::GreaterThanOrEqual),
240 ast::BinaryOperator::Like => Ok(BinOp::Like),
241 ast::BinaryOperator::ILike => Ok(BinOp::ILike),
242 ast::BinaryOperator::And => Ok(BinOp::And),
243 ast::BinaryOperator::Or => Ok(BinOp::Or),
244 _ => Err(ExprError::Binary {
246 reason: "Unknown binary operator",
247 op,
248 }),
249 }
250 }
251}
252
253impl TryFrom<ast::UnaryOperator> for UnOp {
254 type Error = ExprError;
255 fn try_from(op: ast::UnaryOperator) -> Result<Self, Self::Error> {
256 match op {
257 ast::UnaryOperator::Plus => Ok(UnOp::Plus),
258 ast::UnaryOperator::Minus => Ok(UnOp::Minus),
259 ast::UnaryOperator::Not => Ok(UnOp::Not),
260 _ => Err(ExprError::Unary {
263 reason: "Unkown unary operator",
264 op,
265 }),
266 }
267 }
268}
269
/// Errors raised while converting `sqlparser` AST nodes into [`Expr`],
/// [`BinOp`] or [`UnOp`].
#[derive(Debug, PartialEq)]
pub enum ExprError {
    /// An expression form that cannot be converted; `expr` is the offending
    /// AST node and `reason` explains why.
    Expr {
        reason: &'static str,
        expr: ast::Expr,
    },
    /// A binary operator with no [`BinOp`] counterpart.
    Binary {
        reason: &'static str,
        op: ast::BinaryOperator,
    },
    /// A unary operator with no [`UnOp`] counterpart.
    Unary {
        reason: &'static str,
        op: ast::UnaryOperator,
    },
    /// A literal failed to convert into a value.
    Value(ValueError),
    /// An identifier failed to convert into a column reference.
    Identifier(IdentifierError),
}
288
289impl Display for ExprError {
290 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
291 match self {
292 ExprError::Expr { reason, expr } => {
293 write!(f, "ExprError: {}: {}", reason, expr)
294 }
295 ExprError::Binary { reason, op } => {
296 write!(f, "ExprError: {}: {}", reason, op)
297 }
298 ExprError::Unary { reason, op } => {
299 write!(f, "ExprError: {}: {}", reason, op)
300 }
301 ExprError::Value(v) => write!(f, "{}", v),
302 ExprError::Identifier(v) => write!(f, "{}", v),
303 }
304 }
305}
306
307impl From<ValueError> for ExprError {
308 fn from(v: ValueError) -> Self {
309 Self::Value(v)
310 }
311}
312
313impl From<IdentifierError> for ExprError {
314 fn from(i: IdentifierError) -> Self {
315 Self::Identifier(i)
316 }
317}
318
319impl Error for ExprError {}
320
#[cfg(test)]
mod tests {
    use sqlparser::{ast, dialect::GenericDialect, parser::Parser, tokenizer::Tokenizer};

    use crate::{
        expr::{BinOp, Expr, ExprError, UnOp},
        identifier::ColumnRef,
        value::Value,
    };

    /// Tokenises and parses `sql` as a single expression using the generic dialect.
    fn parse_expr(sql: &str) -> ast::Expr {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        Parser::new(tokens, &dialect).parse_expr().unwrap()
    }

    /// Parses `sql` and converts it into the crate's [`Expr`].
    fn convert(sql: &str) -> Result<Expr, ExprError> {
        Expr::try_from(parse_expr(sql))
    }

    /// Builds a column reference with optional schema and table qualifiers.
    fn column(
        schema: Option<&'static str>,
        table: Option<&'static str>,
        name: &'static str,
    ) -> Expr {
        Expr::ColumnRef(ColumnRef {
            schema_name: schema.map(|s| s.into()),
            table_name: table.map(|t| t.into()),
            col_name: name.into(),
        })
    }

    /// Builds an integer literal.
    fn int(n: i64) -> Expr {
        Expr::Value(Value::Int64(n))
    }

    /// The expected desugaring of `target BETWEEN low AND high`.
    fn between(target: i64, low: i64, high: i64) -> Expr {
        Expr::Binary {
            left: Box::new(Expr::Binary {
                left: Box::new(int(low)),
                op: BinOp::LessThanOrEqual,
                right: Box::new(int(target)),
            }),
            op: BinOp::And,
            right: Box::new(Expr::Binary {
                left: Box::new(int(target)),
                op: BinOp::LessThanOrEqual,
                right: Box::new(int(high)),
            }),
        }
    }

    #[test]
    fn conversion_from_ast() {
        // Plain and qualified column references.
        assert_eq!(convert("abc"), Ok(column(None, None, "abc")));
        assert_ne!(convert("abc"), Ok(column(None, None, "cab")));
        assert_eq!(
            convert("table1.col1"),
            Ok(column(None, Some("table1"), "col1"))
        );
        assert_eq!(
            convert("schema1.table1.col1"),
            Ok(column(Some("schema1"), Some("table1"), "col1"))
        );

        // Postfix `IS ...` tests become unary expressions.
        assert_eq!(
            convert("5 IS NULL"),
            Ok(Expr::Unary {
                op: UnOp::IsNull,
                operand: Box::new(int(5)),
            })
        );
        assert_eq!(
            convert("1 IS TRUE"),
            Ok(Expr::Unary {
                op: UnOp::IsTrue,
                operand: Box::new(int(1)),
            })
        );

        // BETWEEN is desugared, and NOT BETWEEN wraps the desugaring in NOT.
        assert_eq!(convert("4 BETWEEN 3 AND 5"), Ok(between(4, 3, 5)));
        assert_eq!(
            convert("4 NOT BETWEEN 3 AND 5"),
            Ok(Expr::Unary {
                op: UnOp::Not,
                operand: Box::new(between(4, 3, 5)),
            })
        );

        // Function calls with column, literal and wildcard arguments.
        assert_eq!(
            convert("MAX(col1)"),
            Ok(Expr::Function {
                name: "MAX".into(),
                args: vec![column(None, None, "col1")],
            })
        );
        assert_eq!(
            convert("some_func(col1, 1, 'abc')"),
            Ok(Expr::Function {
                name: "some_func".into(),
                args: vec![
                    column(None, None, "col1"),
                    int(1),
                    Expr::Value(Value::String("abc".to_owned())),
                ],
            })
        );
        assert_eq!(
            convert("COUNT(*)"),
            Ok(Expr::Function {
                name: "COUNT".into(),
                args: vec![Expr::Wildcard],
            })
        );
    }
}