1use std::fmt;
10
11use crate::error::{Result, SchemaError};
12use crate::span::Span;
13use crate::token::{Token, TokenKind};
14
/// A binary operator that may appear inside a `@computed` SQL expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Division, written `/`.
    Div,
    /// Remainder, written `%`.
    Mod,
    /// Concatenation, written `||`.
    Concat,
    /// Less-than comparison, written `<`.
    Lt,
    /// Greater-than comparison, written `>`.
    Gt,
}

impl fmt::Display for BinOp {
    /// Writes the operator's source symbol, e.g. `||` for [`BinOp::Concat`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match self {
            BinOp::Add => "+",
            BinOp::Sub => "-",
            BinOp::Mul => "*",
            BinOp::Div => "/",
            BinOp::Mod => "%",
            BinOp::Concat => "||",
            BinOp::Lt => "<",
            BinOp::Gt => ">",
        };
        f.write_str(symbol)
    }
}

/// A prefix (unary) operator that may appear inside a `@computed` SQL expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Negation, written `-`.
    Neg,
    /// Unary plus, written `+`.
    Pos,
}

impl fmt::Display for UnaryOp {
    /// Writes the operator's source symbol (`-` or `+`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            UnaryOp::Neg => "-",
            UnaryOp::Pos => "+",
        })
    }
}

/// Syntax tree of a `@computed` SQL expression.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlExpr {
    /// A bare name, such as `price`.
    Ident(String),
    /// A numeric literal, stored as text.
    Number(String),
    /// A string literal; rendered wrapped in double quotes, without escaping.
    StringLit(String),
    /// `true` or `false`.
    Bool(bool),
    /// `left op right`.
    BinaryOp {
        /// Left-hand operand.
        left: Box<SqlExpr>,
        /// The operator.
        op: BinOp,
        /// Right-hand operand.
        right: Box<SqlExpr>,
    },
    /// `op operand`, e.g. `-amount`.
    UnaryOp {
        /// The prefix operator.
        op: UnaryOp,
        /// The expression the operator applies to.
        operand: Box<SqlExpr>,
    },
    /// `name(arg, arg, ...)`.
    FnCall {
        /// Function name as written.
        name: String,
        /// Arguments in call order; may be empty.
        args: Vec<SqlExpr>,
    },
    /// An explicitly parenthesised sub-expression, kept so that rendering reproduces it.
    Paren(Box<SqlExpr>),
}

impl fmt::Display for SqlExpr {
    /// Renders the expression back to infix text.
    ///
    /// Parentheses are emitted only for explicit [`SqlExpr::Paren`] nodes; binary nodes are
    /// written as `left op right` with single spaces. When a unary minus is applied to an
    /// operand whose text itself begins with `-`, a space is inserted (`- -x`), because `--`
    /// would otherwise start an SQL line comment and swallow the rest of the line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SqlExpr::Ident(name) => f.write_str(name),
            SqlExpr::Number(n) => f.write_str(n),
            SqlExpr::StringLit(s) => write!(f, "\"{}\"", s),
            SqlExpr::Bool(b) => write!(f, "{}", b),
            SqlExpr::BinaryOp { left, op, right } => write!(f, "{} {} {}", left, op, right),
            SqlExpr::UnaryOp { op, operand } => {
                let inner = operand.to_string();
                if *op == UnaryOp::Neg && inner.starts_with('-') {
                    // Separate the two minus signs so they cannot be read as `--` (a comment).
                    write!(f, "{} {}", op, inner)
                } else {
                    write!(f, "{}{}", op, inner)
                }
            }
            SqlExpr::FnCall { name, args } => {
                write!(f, "{}(", name)?;
                for (i, arg) in args.iter().enumerate() {
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{}", arg)?;
                }
                f.write_str(")")
            }
            SqlExpr::Paren(inner) => write!(f, "({})", inner),
        }
    }
}
140
141struct SqlExprParser<'a> {
148 tokens: &'a [Token],
149 pos: usize,
150 fallback_span: Span,
152}
153
154impl<'a> SqlExprParser<'a> {
155 fn new(tokens: &'a [Token], fallback_span: Span) -> Self {
156 Self {
157 tokens,
158 pos: 0,
159 fallback_span,
160 }
161 }
162
163 fn peek(&self) -> Option<&TokenKind> {
164 self.tokens.get(self.pos).map(|t| &t.kind)
165 }
166
167 fn span(&self) -> Span {
168 self.tokens
169 .get(self.pos)
170 .map(|t| t.span)
171 .unwrap_or(self.fallback_span)
172 }
173
174 fn advance(&mut self) -> &Token {
175 let tok = &self.tokens[self.pos];
176 self.pos += 1;
177 tok
178 }
179
180 fn at_end(&self) -> bool {
181 self.pos >= self.tokens.len()
182 }
183
184 fn precedence(op: &BinOp) -> u8 {
186 match op {
187 BinOp::Concat => 1,
188 BinOp::Lt | BinOp::Gt => 2,
189 BinOp::Add | BinOp::Sub => 3,
190 BinOp::Mul | BinOp::Div | BinOp::Mod => 4,
191 }
192 }
193
194 fn token_to_binop(kind: &TokenKind) -> Option<BinOp> {
195 match kind {
196 TokenKind::Plus => Some(BinOp::Add),
197 TokenKind::Minus => Some(BinOp::Sub),
198 TokenKind::Star => Some(BinOp::Mul),
199 TokenKind::Slash => Some(BinOp::Div),
200 TokenKind::Percent => Some(BinOp::Mod),
201 TokenKind::DoublePipe => Some(BinOp::Concat),
202 TokenKind::LAngle => Some(BinOp::Lt),
203 TokenKind::RAngle => Some(BinOp::Gt),
204 _ => None,
205 }
206 }
207
208 fn parse_expr(&mut self) -> Result<SqlExpr> {
209 self.parse_binary(0)
210 }
211
212 fn parse_binary(&mut self, min_prec: u8) -> Result<SqlExpr> {
213 let mut left = self.parse_unary()?;
214
215 while let Some(kind) = self.peek().cloned() {
216 let Some(op) = Self::token_to_binop(&kind) else {
217 break;
218 };
219 let prec = Self::precedence(&op);
220 if prec < min_prec {
221 break;
222 }
223 self.advance();
224 let right = self.parse_binary(prec + 1)?;
225 left = SqlExpr::BinaryOp {
226 left: Box::new(left),
227 op,
228 right: Box::new(right),
229 };
230 }
231
232 Ok(left)
233 }
234
235 fn parse_unary(&mut self) -> Result<SqlExpr> {
236 match self.peek() {
237 Some(TokenKind::Minus) => {
238 self.advance();
239 let operand = self.parse_unary()?;
240 Ok(SqlExpr::UnaryOp {
241 op: UnaryOp::Neg,
242 operand: Box::new(operand),
243 })
244 }
245 Some(TokenKind::Plus) => {
246 self.advance();
247 let operand = self.parse_unary()?;
248 Ok(SqlExpr::UnaryOp {
249 op: UnaryOp::Pos,
250 operand: Box::new(operand),
251 })
252 }
253 _ => self.parse_primary(),
254 }
255 }
256
257 fn parse_primary(&mut self) -> Result<SqlExpr> {
258 if self.at_end() {
259 return Err(SchemaError::Parse(
260 "Unexpected end of SQL expression".to_string(),
261 self.span(),
262 ));
263 }
264
265 match self.peek().cloned() {
266 Some(TokenKind::Number(n)) => {
267 self.advance();
268 Ok(SqlExpr::Number(n))
269 }
270 Some(TokenKind::String(s)) => {
271 self.advance();
272 Ok(SqlExpr::StringLit(s))
273 }
274 Some(TokenKind::True) => {
275 self.advance();
276 Ok(SqlExpr::Bool(true))
277 }
278 Some(TokenKind::False) => {
279 self.advance();
280 Ok(SqlExpr::Bool(false))
281 }
282 Some(TokenKind::Ident(_)) => self.parse_ident_or_call(),
283 Some(k) if k.is_keyword() => self.parse_ident_or_call(),
285 Some(TokenKind::LParen) => {
286 self.advance();
287 let inner = self.parse_expr()?;
288 match self.peek() {
289 Some(TokenKind::RParen) => {
290 self.advance();
291 Ok(SqlExpr::Paren(Box::new(inner)))
292 }
293 _ => Err(SchemaError::Parse(
294 "Expected ')' after parenthesised expression".to_string(),
295 self.span(),
296 )),
297 }
298 }
299 Some(other) => Err(SchemaError::Parse(
300 format!("Unexpected token '{}' in SQL expression", other),
301 self.span(),
302 )),
303 None => Err(SchemaError::Parse(
304 "Unexpected end of SQL expression".to_string(),
305 self.span(),
306 )),
307 }
308 }
309
310 fn parse_ident_or_call(&mut self) -> Result<SqlExpr> {
311 let tok = self.advance();
312 let name = match &tok.kind {
313 TokenKind::Ident(s) => s.clone(),
314 other => other.to_string(),
316 };
317
318 if self.peek() == Some(&TokenKind::LParen) {
319 self.advance();
320 let mut args = Vec::new();
321 if self.peek() != Some(&TokenKind::RParen) {
322 args.push(self.parse_expr()?);
323 while self.peek() == Some(&TokenKind::Comma) {
324 self.advance();
325 args.push(self.parse_expr()?);
326 }
327 }
328 match self.peek() {
329 Some(TokenKind::RParen) => {
330 self.advance();
331 Ok(SqlExpr::FnCall { name, args })
332 }
333 _ => Err(SchemaError::Parse(
334 format!("Expected ')' after arguments of function '{}'", name),
335 self.span(),
336 )),
337 }
338 } else {
339 Ok(SqlExpr::Ident(name))
340 }
341 }
342}
343
344pub fn parse_sql_expr(tokens: &[Token], fallback_span: Span) -> Result<SqlExpr> {
355 if tokens.is_empty() {
356 return Err(SchemaError::Parse(
357 "@computed expression is empty".to_string(),
358 fallback_span,
359 ));
360 }
361
362 let mut parser = SqlExprParser::new(tokens, fallback_span);
363 let expr = parser.parse_expr()?;
364
365 if !parser.at_end() {
366 return Err(SchemaError::Parse(
367 format!(
368 "Unexpected token '{}' after SQL expression",
369 parser.tokens[parser.pos].kind
370 ),
371 parser.span(),
372 ));
373 }
374
375 Ok(expr)
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes `src` into the token list the parser expects: newline tokens are dropped and
    /// the end-of-file token is not included.
    fn tokenize(src: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(src);
        std::iter::from_fn(|| {
            let tok = lexer.next_token().expect("lex error");
            (!matches!(tok.kind, TokenKind::Eof)).then_some(tok)
        })
        .filter(|tok| !matches!(tok.kind, TokenKind::Newline))
        .collect()
    }

    /// Parses `src`, panicking if it is not a valid expression.
    fn parse(src: &str) -> SqlExpr {
        parse_sql_expr(&tokenize(src), Span::new(0, 0)).expect("parse error")
    }

    /// Parses `src`, expecting failure, and returns the rendered error message.
    fn parse_err(src: &str) -> String {
        match parse_sql_expr(&tokenize(src), Span::new(0, 0)) {
            Ok(expr) => panic!("Expected error, got: {:?}", expr),
            Err(e) => e.to_string(),
        }
    }

    #[test]
    fn simple_ident() {
        assert_eq!(parse("price").to_string(), "price");
    }

    #[test]
    fn binary_mul() {
        assert_eq!(parse("price * quantity").to_string(), "price * quantity");
    }

    #[test]
    fn precedence_add_mul() {
        // `*` binds tighter than `+`, so the root of the tree is the addition.
        let root = parse("a + b * c");
        assert!(matches!(root, SqlExpr::BinaryOp { op: BinOp::Add, .. }));
    }

    #[test]
    fn concat_operator() {
        let src = "first_name || \" \" || last_name";
        assert_eq!(parse(src).to_string(), src);
    }

    #[test]
    fn function_call() {
        let call = parse("COALESCE(a, b)");
        assert!(matches!(call, SqlExpr::FnCall { .. }));
        assert_eq!(call.to_string(), "COALESCE(a, b)");
    }

    #[test]
    fn nested_function() {
        assert_eq!(parse("UPPER(TRIM(name))").to_string(), "UPPER(TRIM(name))");
    }

    #[test]
    fn paren_expr() {
        assert_eq!(parse("(a + b) * c").to_string(), "(a + b) * c");
    }

    #[test]
    fn unary_neg() {
        assert_eq!(parse("-amount").to_string(), "-amount");
    }

    #[test]
    fn number_literal() {
        assert_eq!(parse("score * 10").to_string(), "score * 10");
    }

    #[test]
    fn boolean_literal() {
        assert_eq!(parse("true").to_string(), "true");
    }

    #[test]
    fn complex_expr() {
        let src = "(price * quantity) - COALESCE(discount, 0)";
        assert_eq!(parse(src).to_string(), src);
    }

    #[test]
    fn empty_is_error() {
        let no_tokens: Vec<Token> = Vec::new();
        assert!(parse_sql_expr(&no_tokens, Span::new(0, 0)).is_err());
    }

    #[test]
    fn only_operators_is_error() {
        assert!(parse_err("* * *").contains("Unexpected token"));
    }

    #[test]
    fn trailing_operator_is_error() {
        assert!(parse_err("a +").contains("Unexpected end"));
    }

    #[test]
    fn unclosed_paren_is_error() {
        assert!(parse_err("(a + b").contains("Expected ')'"));
    }

    #[test]
    fn double_operator_is_error() {
        assert!(parse_err("a + * b").contains("Unexpected token"));
    }
}