1use std::fmt;
10
11use crate::error::{Result, SchemaError};
12use crate::span::Span;
13use crate::token::{Token, TokenKind};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Binary (infix) operators that can join two operands of a [`SqlExpr`].
///
/// The textual symbol of each operator is produced by its `Display` impl.
pub enum BinOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Division, written `/`.
    Div,
    /// Remainder, written `%`.
    Mod,
    /// Concatenation, written `||`.
    Concat,
    /// Less-than comparison, written `<`.
    Lt,
    /// Greater-than comparison, written `>`.
    Gt,
}
34
35impl fmt::Display for BinOp {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 f.write_str(match self {
38 BinOp::Add => "+",
39 BinOp::Sub => "-",
40 BinOp::Mul => "*",
41 BinOp::Div => "/",
42 BinOp::Mod => "%",
43 BinOp::Concat => "||",
44 BinOp::Lt => "<",
45 BinOp::Gt => ">",
46 })
47 }
48}
49
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Prefix operators that can be applied to a single [`SqlExpr`] operand.
pub enum UnaryOp {
    /// Arithmetic negation, written `-`.
    Neg,
    /// Unary plus, written `+`.
    Pos,
}
58
59impl fmt::Display for UnaryOp {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.write_str(match self {
62 UnaryOp::Neg => "-",
63 UnaryOp::Pos => "+",
64 })
65 }
66}
67
#[derive(Debug, Clone, PartialEq)]
/// Syntax tree of a SQL expression.
///
/// The tree keeps explicit [`SqlExpr::Paren`] nodes, so rendering can
/// reproduce the grouping that was present in the parsed input.
pub enum SqlExpr {
    /// A bare identifier. `to_sql_mapped` passes this name through its
    /// field-mapping callback; `Display` prints it verbatim.
    Ident(String),
    /// A numeric literal, kept as text and rendered verbatim.
    Number(String),
    /// A string literal. Holds the contents only; rendering wraps them in
    /// double quotes.
    StringLit(String),
    /// A boolean literal, rendered as `true` or `false`.
    Bool(bool),
    /// An infix operation `left op right`.
    BinaryOp {
        /// Left-hand operand.
        left: Box<SqlExpr>,
        /// The operator joining both operands.
        op: BinOp,
        /// Right-hand operand.
        right: Box<SqlExpr>,
    },
    /// A prefix operation applied to a single operand.
    UnaryOp {
        /// The prefix operator.
        op: UnaryOp,
        /// The expression the operator applies to.
        operand: Box<SqlExpr>,
    },
    /// A function call `name(args...)`.
    FnCall {
        /// Function name, rendered verbatim.
        name: String,
        /// Call arguments in source order; may be empty.
        args: Vec<SqlExpr>,
    },
    /// An expression wrapped in explicit parentheses.
    Paren(Box<SqlExpr>),
}
105impl fmt::Display for SqlExpr {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 match self {
108 SqlExpr::Ident(name) => write!(f, "{}", name),
109 SqlExpr::Number(n) => write!(f, "{}", n),
110 SqlExpr::StringLit(s) => write!(f, "\"{}\"", s),
111 SqlExpr::Bool(b) => write!(f, "{}", b),
112 SqlExpr::BinaryOp { left, op, right } => {
113 write!(f, "{} {} {}", left, op, right)
114 }
115 SqlExpr::UnaryOp { op, operand } => write!(f, "{}{}", op, operand),
116 SqlExpr::FnCall { name, args } => {
117 write!(f, "{}(", name)?;
118 for (i, arg) in args.iter().enumerate() {
119 if i > 0 {
120 write!(f, ", ")?;
121 }
122 write!(f, "{}", arg)?;
123 }
124 write!(f, ")")
125 }
126 SqlExpr::Paren(inner) => write!(f, "({})", inner),
127 }
128 }
129}
130
131impl SqlExpr {
132 pub fn to_sql_mapped<F>(&self, map_field: &F) -> String
135 where
136 F: Fn(&str) -> String,
137 {
138 match self {
139 SqlExpr::Ident(name) => map_field(name),
140 SqlExpr::Number(n) => n.clone(),
141 SqlExpr::StringLit(s) => format!("\"{}\"", s),
142 SqlExpr::Bool(b) => b.to_string(),
143 SqlExpr::BinaryOp { left, op, right } => format!(
144 "{} {} {}",
145 left.to_sql_mapped(map_field),
146 op,
147 right.to_sql_mapped(map_field)
148 ),
149 SqlExpr::UnaryOp { op, operand } => {
150 format!("{}{}", op, operand.to_sql_mapped(map_field))
151 }
152 SqlExpr::FnCall { name, args } => {
153 let args_s: Vec<String> = args.iter().map(|a| a.to_sql_mapped(map_field)).collect();
154 format!("{}({})", name, args_s.join(", "))
155 }
156 SqlExpr::Paren(inner) => format!("({})", inner.to_sql_mapped(map_field)),
157 }
158 }
159}
/// Cursor over a borrowed token slice used to build a [`SqlExpr`] tree.
struct SqlExprParser<'a> {
    /// The tokens being parsed.
    tokens: &'a [Token],
    /// Index into `tokens` of the next token to be consumed.
    pos: usize,
    /// Span reported when the cursor is past the last token (see `span`).
    fallback_span: Span,
}
168
169impl<'a> SqlExprParser<'a> {
170 fn new(tokens: &'a [Token], fallback_span: Span) -> Self {
171 Self {
172 tokens,
173 pos: 0,
174 fallback_span,
175 }
176 }
177
178 fn peek(&self) -> Option<&TokenKind> {
179 self.tokens.get(self.pos).map(|t| &t.kind)
180 }
181
182 fn span(&self) -> Span {
183 self.tokens
184 .get(self.pos)
185 .map(|t| t.span)
186 .unwrap_or(self.fallback_span)
187 }
188
189 fn advance(&mut self) -> &Token {
190 let tok = &self.tokens[self.pos];
191 self.pos += 1;
192 tok
193 }
194
195 fn at_end(&self) -> bool {
196 self.pos >= self.tokens.len()
197 }
198
199 fn precedence(op: &BinOp) -> u8 {
201 match op {
202 BinOp::Concat => 1,
203 BinOp::Lt | BinOp::Gt => 2,
204 BinOp::Add | BinOp::Sub => 3,
205 BinOp::Mul | BinOp::Div | BinOp::Mod => 4,
206 }
207 }
208
209 fn token_to_binop(kind: &TokenKind) -> Option<BinOp> {
210 match kind {
211 TokenKind::Plus => Some(BinOp::Add),
212 TokenKind::Minus => Some(BinOp::Sub),
213 TokenKind::Star => Some(BinOp::Mul),
214 TokenKind::Slash => Some(BinOp::Div),
215 TokenKind::Percent => Some(BinOp::Mod),
216 TokenKind::DoublePipe => Some(BinOp::Concat),
217 TokenKind::LAngle => Some(BinOp::Lt),
218 TokenKind::RAngle => Some(BinOp::Gt),
219 _ => None,
220 }
221 }
222
223 fn parse_expr(&mut self) -> Result<SqlExpr> {
224 self.parse_binary(0)
225 }
226
227 fn parse_binary(&mut self, min_prec: u8) -> Result<SqlExpr> {
228 let mut left = self.parse_unary()?;
229
230 while let Some(kind) = self.peek().cloned() {
231 let Some(op) = Self::token_to_binop(&kind) else {
232 break;
233 };
234 let prec = Self::precedence(&op);
235 if prec < min_prec {
236 break;
237 }
238 self.advance();
239 let right = self.parse_binary(prec + 1)?;
240 left = SqlExpr::BinaryOp {
241 left: Box::new(left),
242 op,
243 right: Box::new(right),
244 };
245 }
246
247 Ok(left)
248 }
249
250 fn parse_unary(&mut self) -> Result<SqlExpr> {
251 match self.peek() {
252 Some(TokenKind::Minus) => {
253 self.advance();
254 let operand = self.parse_unary()?;
255 Ok(SqlExpr::UnaryOp {
256 op: UnaryOp::Neg,
257 operand: Box::new(operand),
258 })
259 }
260 Some(TokenKind::Plus) => {
261 self.advance();
262 let operand = self.parse_unary()?;
263 Ok(SqlExpr::UnaryOp {
264 op: UnaryOp::Pos,
265 operand: Box::new(operand),
266 })
267 }
268 _ => self.parse_primary(),
269 }
270 }
271
272 fn parse_primary(&mut self) -> Result<SqlExpr> {
273 if self.at_end() {
274 return Err(SchemaError::Parse(
275 "Unexpected end of SQL expression".to_string(),
276 self.span(),
277 ));
278 }
279
280 match self.peek().cloned() {
281 Some(TokenKind::Number(n)) => {
282 self.advance();
283 Ok(SqlExpr::Number(n))
284 }
285 Some(TokenKind::String(s)) => {
286 self.advance();
287 Ok(SqlExpr::StringLit(s))
288 }
289 Some(TokenKind::True) => {
290 self.advance();
291 Ok(SqlExpr::Bool(true))
292 }
293 Some(TokenKind::False) => {
294 self.advance();
295 Ok(SqlExpr::Bool(false))
296 }
297 Some(TokenKind::Ident(_)) => self.parse_ident_or_call(),
298 Some(k) if k.is_keyword() => self.parse_ident_or_call(),
300 Some(TokenKind::LParen) => {
301 self.advance();
302 let inner = self.parse_expr()?;
303 match self.peek() {
304 Some(TokenKind::RParen) => {
305 self.advance();
306 Ok(SqlExpr::Paren(Box::new(inner)))
307 }
308 _ => Err(SchemaError::Parse(
309 "Expected ')' after parenthesised expression".to_string(),
310 self.span(),
311 )),
312 }
313 }
314 Some(other) => Err(SchemaError::Parse(
315 format!("Unexpected token '{}' in SQL expression", other),
316 self.span(),
317 )),
318 None => Err(SchemaError::Parse(
319 "Unexpected end of SQL expression".to_string(),
320 self.span(),
321 )),
322 }
323 }
324
325 fn parse_ident_or_call(&mut self) -> Result<SqlExpr> {
326 let tok = self.advance();
327 let name = match &tok.kind {
328 TokenKind::Ident(s) => s.clone(),
329 other => other.to_string(),
331 };
332
333 if self.peek() == Some(&TokenKind::LParen) {
334 self.advance();
335 let mut args = Vec::new();
336 if self.peek() != Some(&TokenKind::RParen) {
337 args.push(self.parse_expr()?);
338 while self.peek() == Some(&TokenKind::Comma) {
339 self.advance();
340 args.push(self.parse_expr()?);
341 }
342 }
343 match self.peek() {
344 Some(TokenKind::RParen) => {
345 self.advance();
346 Ok(SqlExpr::FnCall { name, args })
347 }
348 _ => Err(SchemaError::Parse(
349 format!("Expected ')' after arguments of function '{}'", name),
350 self.span(),
351 )),
352 }
353 } else {
354 Ok(SqlExpr::Ident(name))
355 }
356 }
357}
358pub fn parse_sql_expr(tokens: &[Token], fallback_span: Span) -> Result<SqlExpr> {
365 if tokens.is_empty() {
366 return Err(SchemaError::Parse(
367 "@computed expression is empty".to_string(),
368 fallback_span,
369 ));
370 }
371
372 let mut parser = SqlExprParser::new(tokens, fallback_span);
373 let expr = parser.parse_expr()?;
374
375 if !parser.at_end() {
376 return Err(SchemaError::Parse(
377 format!(
378 "Unexpected token '{}' after SQL expression",
379 parser.tokens[parser.pos].kind
380 ),
381 parser.span(),
382 ));
383 }
384
385 Ok(expr)
386}
// Unit tests for the SQL expression parser and its `Display` rendering.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes `src` until `Eof`, dropping `Newline` tokens. Panics on a lex
    /// error.
    fn tokenize(src: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(src);
        let mut tokens = Vec::new();
        loop {
            let tok = lexer.next_token().expect("lex error");
            match tok.kind {
                TokenKind::Eof => break,
                TokenKind::Newline => continue,
                _ => tokens.push(tok),
            }
        }
        tokens
    }

    /// Parses `src`, panicking if lexing or parsing fails.
    fn parse(src: &str) -> SqlExpr {
        let tokens = tokenize(src);
        parse_sql_expr(&tokens, Span::new(0, 0)).expect("parse error")
    }

    /// Parses `src` and returns the formatted error; panics if parsing
    /// unexpectedly succeeds.
    fn parse_err(src: &str) -> String {
        let tokens = tokenize(src);
        match parse_sql_expr(&tokens, Span::new(0, 0)) {
            Err(e) => format!("{}", e),
            Ok(expr) => panic!("Expected error, got: {:?}", expr),
        }
    }

    // --- Successful parses, checked through the rendered text or tree shape ---

    #[test]
    fn simple_ident() {
        assert_eq!(parse("price").to_string(), "price");
    }

    #[test]
    fn binary_mul() {
        let expr = parse("price * quantity");
        assert_eq!(expr.to_string(), "price * quantity");
    }

    #[test]
    fn precedence_add_mul() {
        // `*` binds tighter than `+`, so the root of the tree must be the addition.
        let expr = parse("a + b * c");
        assert!(matches!(expr, SqlExpr::BinaryOp { op: BinOp::Add, .. }));
    }

    #[test]
    fn concat_operator() {
        let expr = parse("first_name || \" \" || last_name");
        assert_eq!(expr.to_string(), "first_name || \" \" || last_name");
    }

    #[test]
    fn function_call() {
        let expr = parse("COALESCE(a, b)");
        assert!(matches!(expr, SqlExpr::FnCall { .. }));
        assert_eq!(expr.to_string(), "COALESCE(a, b)");
    }

    #[test]
    fn nested_function() {
        let expr = parse("UPPER(TRIM(name))");
        assert_eq!(expr.to_string(), "UPPER(TRIM(name))");
    }

    #[test]
    fn paren_expr() {
        // Explicit parentheses are kept in the tree and therefore in the output.
        let expr = parse("(a + b) * c");
        assert_eq!(expr.to_string(), "(a + b) * c");
    }

    #[test]
    fn unary_neg() {
        let expr = parse("-amount");
        assert_eq!(expr.to_string(), "-amount");
    }

    #[test]
    fn number_literal() {
        let expr = parse("score * 10");
        assert_eq!(expr.to_string(), "score * 10");
    }

    #[test]
    fn boolean_literal() {
        let expr = parse("true");
        assert_eq!(expr.to_string(), "true");
    }

    #[test]
    fn complex_expr() {
        let expr = parse("(price * quantity) - COALESCE(discount, 0)");
        assert_eq!(
            expr.to_string(),
            "(price * quantity) - COALESCE(discount, 0)"
        );
    }

    // --- Malformed input must be rejected with a descriptive message ---

    #[test]
    fn empty_is_error() {
        let tokens: Vec<Token> = vec![];
        assert!(parse_sql_expr(&tokens, Span::new(0, 0)).is_err());
    }

    #[test]
    fn only_operators_is_error() {
        let err = parse_err("* * *");
        assert!(err.contains("Unexpected token"));
    }

    #[test]
    fn trailing_operator_is_error() {
        let err = parse_err("a +");
        assert!(err.contains("Unexpected end"));
    }

    #[test]
    fn unclosed_paren_is_error() {
        let err = parse_err("(a + b");
        assert!(err.contains("Expected ')'"));
    }

    #[test]
    fn double_operator_is_error() {
        let err = parse_err("a + * b");
        assert!(err.contains("Unexpected token"));
    }
}