Skip to main content

oxigdal_query/parser/
sql.rs

1//! SQL parser implementation.
2
3use crate::error::{QueryError, Result};
4use crate::parser::ast::*;
5use sqlparser::ast as sql_ast;
6use sqlparser::dialect::GenericDialect;
7use sqlparser::parser::Parser as SqlParser;
8
9/// Parse SQL string into AST.
10pub fn parse_sql(sql: &str) -> Result<Statement> {
11    let dialect = GenericDialect {};
12    let statements = SqlParser::parse_sql(&dialect, sql)?;
13
14    if statements.is_empty() {
15        return Err(QueryError::semantic("Empty SQL statement"));
16    }
17
18    if statements.len() > 1 {
19        return Err(QueryError::semantic("Multiple statements not supported"));
20    }
21
22    convert_statement(&statements[0])
23}
24
25fn convert_statement(stmt: &sql_ast::Statement) -> Result<Statement> {
26    match stmt {
27        sql_ast::Statement::Query(query) => {
28            let select = convert_query(query)?;
29            Ok(Statement::Select(select))
30        }
31        _ => Err(QueryError::unsupported("Only SELECT statements supported")),
32    }
33}
34
35fn convert_query(query: &sql_ast::Query) -> Result<SelectStatement> {
36    if let sql_ast::SetExpr::Select(select) = &*query.body {
37        let mut stmt = SelectStatement {
38            projection: Vec::new(),
39            from: None,
40            selection: None,
41            group_by: Vec::new(),
42            having: None,
43            order_by: Vec::new(),
44            limit: None,
45            offset: None,
46        };
47
48        // Convert projection
49        for item in &select.projection {
50            stmt.projection.push(convert_select_item(item)?);
51        }
52
53        // Convert FROM clause
54        if !select.from.is_empty() {
55            stmt.from = Some(convert_table_reference(&select.from[0])?);
56        }
57
58        // Convert WHERE clause
59        if let Some(selection) = &select.selection {
60            stmt.selection = Some(convert_expr(selection)?);
61        }
62
63        // Convert GROUP BY clause
64        match &select.group_by {
65            sql_ast::GroupByExpr::Expressions(exprs, _) => {
66                for expr in exprs {
67                    stmt.group_by.push(convert_expr(expr)?);
68                }
69            }
70            sql_ast::GroupByExpr::All(_) => {
71                return Err(QueryError::unsupported("GROUP BY ALL not supported"));
72            }
73        }
74
75        // Convert HAVING clause
76        if let Some(having) = &select.having {
77            stmt.having = Some(convert_expr(having)?);
78        }
79
80        // Convert ORDER BY clause
81        if let Some(order_by) = &query.order_by {
82            if let sql_ast::OrderByKind::Expressions(exprs) = &order_by.kind {
83                for order_expr in exprs {
84                    stmt.order_by.push(convert_order_by_expr(order_expr)?);
85                }
86            }
87        }
88
89        // Convert LIMIT and OFFSET clauses
90        if let Some(limit_clause) = &query.limit_clause {
91            match limit_clause {
92                sql_ast::LimitClause::LimitOffset { limit, offset, .. } => {
93                    if let Some(limit_expr) = limit {
94                        stmt.limit = Some(convert_limit(limit_expr)?);
95                    }
96                    if let Some(offset_val) = offset {
97                        stmt.offset = Some(convert_offset(offset_val)?);
98                    }
99                }
100                sql_ast::LimitClause::OffsetCommaLimit { limit, offset } => {
101                    stmt.limit = Some(convert_limit(limit)?);
102                    stmt.offset = Some(convert_limit(offset)?);
103                }
104            }
105        }
106
107        Ok(stmt)
108    } else {
109        Err(QueryError::unsupported(
110            "Only simple SELECT queries supported",
111        ))
112    }
113}
114
115fn convert_select_item(item: &sql_ast::SelectItem) -> Result<SelectItem> {
116    match item {
117        sql_ast::SelectItem::UnnamedExpr(expr) => Ok(SelectItem::Expr {
118            expr: convert_expr(expr)?,
119            alias: None,
120        }),
121        sql_ast::SelectItem::ExprWithAlias { expr, alias } => Ok(SelectItem::Expr {
122            expr: convert_expr(expr)?,
123            alias: Some(alias.value.clone()),
124        }),
125        sql_ast::SelectItem::Wildcard(_) => Ok(SelectItem::Wildcard),
126        sql_ast::SelectItem::QualifiedWildcard(obj_name, _) => {
127            Ok(SelectItem::QualifiedWildcard(obj_name.to_string()))
128        }
129    }
130}
131
132fn convert_table_reference(table: &sql_ast::TableWithJoins) -> Result<TableReference> {
133    let mut result = convert_table_factor(&table.relation)?;
134
135    for join in &table.joins {
136        let right = convert_table_factor(&join.relation)?;
137        let join_type = convert_join_type(&join.join_operator)?;
138        let on = match &join.join_operator {
139            sql_ast::JoinOperator::Inner(constraint)
140            | sql_ast::JoinOperator::LeftOuter(constraint)
141            | sql_ast::JoinOperator::RightOuter(constraint)
142            | sql_ast::JoinOperator::FullOuter(constraint) => convert_join_constraint(constraint)?,
143            sql_ast::JoinOperator::CrossJoin(_) => None,
144            _ => return Err(QueryError::unsupported("Unsupported join type")),
145        };
146
147        result = TableReference::Join {
148            left: Box::new(result),
149            right: Box::new(right),
150            join_type,
151            on,
152        };
153    }
154
155    Ok(result)
156}
157
158fn convert_table_factor(factor: &sql_ast::TableFactor) -> Result<TableReference> {
159    match factor {
160        sql_ast::TableFactor::Table {
161            name, alias, args, ..
162        } => {
163            if args.is_some() {
164                return Err(QueryError::unsupported("Table functions not supported"));
165            }
166            Ok(TableReference::Table {
167                name: name.to_string(),
168                alias: alias.as_ref().map(|a| a.name.value.clone()),
169            })
170        }
171        sql_ast::TableFactor::Derived {
172            subquery, alias, ..
173        } => {
174            let query = convert_query(subquery)?;
175            let alias_name = alias
176                .as_ref()
177                .map(|a| a.name.value.clone())
178                .ok_or_else(|| QueryError::semantic("Subquery must have an alias"))?;
179            Ok(TableReference::Subquery {
180                query: Box::new(query),
181                alias: alias_name,
182            })
183        }
184        _ => Err(QueryError::unsupported("Unsupported table reference")),
185    }
186}
187
188fn convert_join_type(op: &sql_ast::JoinOperator) -> Result<JoinType> {
189    match op {
190        sql_ast::JoinOperator::Inner(_) => Ok(JoinType::Inner),
191        sql_ast::JoinOperator::LeftOuter(_) => Ok(JoinType::Left),
192        sql_ast::JoinOperator::RightOuter(_) => Ok(JoinType::Right),
193        sql_ast::JoinOperator::FullOuter(_) => Ok(JoinType::Full),
194        sql_ast::JoinOperator::CrossJoin(_) => Ok(JoinType::Cross),
195        _ => Err(QueryError::unsupported("Unsupported join type")),
196    }
197}
198
199fn convert_join_constraint(constraint: &sql_ast::JoinConstraint) -> Result<Option<Expr>> {
200    match constraint {
201        sql_ast::JoinConstraint::On(expr) => Ok(Some(convert_expr(expr)?)),
202        sql_ast::JoinConstraint::Using(_) => {
203            Err(QueryError::unsupported("USING clause not supported"))
204        }
205        sql_ast::JoinConstraint::Natural => {
206            Err(QueryError::unsupported("NATURAL join not supported"))
207        }
208        sql_ast::JoinConstraint::None => Ok(None),
209    }
210}
211
212fn convert_expr(expr: &sql_ast::Expr) -> Result<Expr> {
213    match expr {
214        sql_ast::Expr::Identifier(ident) => Ok(Expr::Column {
215            table: None,
216            name: ident.value.clone(),
217        }),
218        sql_ast::Expr::CompoundIdentifier(parts) => {
219            if parts.len() == 2 {
220                Ok(Expr::Column {
221                    table: Some(parts[0].value.clone()),
222                    name: parts[1].value.clone(),
223                })
224            } else {
225                Err(QueryError::semantic("Invalid column reference"))
226            }
227        }
228        sql_ast::Expr::Value(value_with_span) => {
229            Ok(Expr::Literal(convert_value(&value_with_span.value)?))
230        }
231        sql_ast::Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
232            left: Box::new(convert_expr(left)?),
233            op: convert_binary_op(op)?,
234            right: Box::new(convert_expr(right)?),
235        }),
236        sql_ast::Expr::UnaryOp { op, expr } => Ok(Expr::UnaryOp {
237            op: convert_unary_op(op)?,
238            expr: Box::new(convert_expr(expr)?),
239        }),
240        sql_ast::Expr::Function(func) => {
241            let name = func.name.to_string();
242            let mut args = Vec::new();
243
244            // Handle FunctionArguments enum
245            match &func.args {
246                sql_ast::FunctionArguments::None => {
247                    // No arguments
248                }
249                sql_ast::FunctionArguments::Subquery(_) => {
250                    return Err(QueryError::unsupported(
251                        "Subquery in function arguments not supported",
252                    ));
253                }
254                sql_ast::FunctionArguments::List(arg_list) => {
255                    for arg in &arg_list.args {
256                        match arg {
257                            sql_ast::FunctionArg::Unnamed(sql_ast::FunctionArgExpr::Expr(e)) => {
258                                args.push(convert_expr(e)?);
259                            }
260                            sql_ast::FunctionArg::Unnamed(sql_ast::FunctionArgExpr::Wildcard) => {
261                                // Handle COUNT(*) and similar
262                                args.push(Expr::Wildcard);
263                            }
264                            sql_ast::FunctionArg::Named {
265                                name: _,
266                                arg: sql_ast::FunctionArgExpr::Expr(e),
267                                ..
268                            } => {
269                                args.push(convert_expr(e)?);
270                            }
271                            sql_ast::FunctionArg::Named {
272                                name: _,
273                                arg: sql_ast::FunctionArgExpr::Wildcard,
274                                ..
275                            } => {
276                                args.push(Expr::Wildcard);
277                            }
278                            _ => {
279                                return Err(QueryError::unsupported(
280                                    "Unsupported function argument",
281                                ));
282                            }
283                        }
284                    }
285                }
286            }
287            Ok(Expr::Function { name, args })
288        }
289        sql_ast::Expr::Case {
290            operand,
291            conditions,
292            else_result,
293            ..
294        } => {
295            let operand = operand
296                .as_ref()
297                .map(|e| convert_expr(e))
298                .transpose()?
299                .map(Box::new);
300            let mut when_then = Vec::new();
301            for case_when in conditions.iter() {
302                when_then.push((
303                    convert_expr(&case_when.condition)?,
304                    convert_expr(&case_when.result)?,
305                ));
306            }
307            let else_result = else_result
308                .as_ref()
309                .map(|e| convert_expr(e))
310                .transpose()?
311                .map(Box::new);
312            Ok(Expr::Case {
313                operand,
314                when_then,
315                else_result,
316            })
317        }
318        sql_ast::Expr::Cast {
319            expr, data_type, ..
320        } => Ok(Expr::Cast {
321            expr: Box::new(convert_expr(expr)?),
322            data_type: convert_data_type(data_type)?,
323        }),
324        sql_ast::Expr::IsNull(expr) => Ok(Expr::IsNull(Box::new(convert_expr(expr)?))),
325        sql_ast::Expr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(convert_expr(expr)?))),
326        sql_ast::Expr::InList {
327            expr,
328            list,
329            negated,
330        } => {
331            let expr = Box::new(convert_expr(expr)?);
332            let list = list.iter().map(convert_expr).collect::<Result<Vec<_>>>()?;
333            Ok(Expr::InList {
334                expr,
335                list,
336                negated: *negated,
337            })
338        }
339        sql_ast::Expr::Between {
340            expr,
341            low,
342            high,
343            negated,
344        } => Ok(Expr::Between {
345            expr: Box::new(convert_expr(expr)?),
346            low: Box::new(convert_expr(low)?),
347            high: Box::new(convert_expr(high)?),
348            negated: *negated,
349        }),
350        sql_ast::Expr::Like {
351            negated,
352            expr,
353            pattern,
354            ..
355        } => Ok(Expr::BinaryOp {
356            left: Box::new(convert_expr(expr)?),
357            op: if *negated {
358                BinaryOperator::NotLike
359            } else {
360                BinaryOperator::Like
361            },
362            right: Box::new(convert_expr(pattern)?),
363        }),
364        sql_ast::Expr::ILike {
365            negated,
366            expr,
367            pattern,
368            ..
369        } => {
370            // ILIKE (case-insensitive) is treated as regular LIKE for now
371            Ok(Expr::BinaryOp {
372                left: Box::new(convert_expr(expr)?),
373                op: if *negated {
374                    BinaryOperator::NotLike
375                } else {
376                    BinaryOperator::Like
377                },
378                right: Box::new(convert_expr(pattern)?),
379            })
380        }
381        sql_ast::Expr::Subquery(query) => Ok(Expr::Subquery(Box::new(convert_query(query)?))),
382        sql_ast::Expr::Nested(expr) => convert_expr(expr),
383        sql_ast::Expr::Wildcard(_) => Ok(Expr::Wildcard),
384        _ => Err(QueryError::unsupported(format!(
385            "Unsupported expression: {:?}",
386            expr
387        ))),
388    }
389}
390
391fn convert_value(value: &sql_ast::Value) -> Result<Literal> {
392    match value {
393        sql_ast::Value::Null => Ok(Literal::Null),
394        sql_ast::Value::Boolean(b) => Ok(Literal::Boolean(*b)),
395        sql_ast::Value::Number(n, _) => {
396            if let Ok(i) = n.parse::<i64>() {
397                Ok(Literal::Integer(i))
398            } else if let Ok(f) = n.parse::<f64>() {
399                Ok(Literal::Float(f))
400            } else {
401                Err(QueryError::parse_error("Invalid number", 0, 0))
402            }
403        }
404        sql_ast::Value::SingleQuotedString(s) | sql_ast::Value::DoubleQuotedString(s) => {
405            Ok(Literal::String(s.clone()))
406        }
407        _ => Err(QueryError::unsupported("Unsupported literal value")),
408    }
409}
410
411fn convert_binary_op(op: &sql_ast::BinaryOperator) -> Result<BinaryOperator> {
412    match op {
413        sql_ast::BinaryOperator::Plus => Ok(BinaryOperator::Plus),
414        sql_ast::BinaryOperator::Minus => Ok(BinaryOperator::Minus),
415        sql_ast::BinaryOperator::Multiply => Ok(BinaryOperator::Multiply),
416        sql_ast::BinaryOperator::Divide => Ok(BinaryOperator::Divide),
417        sql_ast::BinaryOperator::Modulo => Ok(BinaryOperator::Modulo),
418        sql_ast::BinaryOperator::Eq => Ok(BinaryOperator::Eq),
419        sql_ast::BinaryOperator::NotEq => Ok(BinaryOperator::NotEq),
420        sql_ast::BinaryOperator::Lt => Ok(BinaryOperator::Lt),
421        sql_ast::BinaryOperator::LtEq => Ok(BinaryOperator::LtEq),
422        sql_ast::BinaryOperator::Gt => Ok(BinaryOperator::Gt),
423        sql_ast::BinaryOperator::GtEq => Ok(BinaryOperator::GtEq),
424        sql_ast::BinaryOperator::And => Ok(BinaryOperator::And),
425        sql_ast::BinaryOperator::Or => Ok(BinaryOperator::Or),
426        sql_ast::BinaryOperator::StringConcat => Ok(BinaryOperator::Concat),
427        // Note: LIKE and NOT LIKE are handled as separate expression types in sqlparser 0.52+
428        _ => Err(QueryError::unsupported("Unsupported binary operator")),
429    }
430}
431
432fn convert_unary_op(op: &sql_ast::UnaryOperator) -> Result<UnaryOperator> {
433    match op {
434        sql_ast::UnaryOperator::Minus => Ok(UnaryOperator::Minus),
435        sql_ast::UnaryOperator::Not => Ok(UnaryOperator::Not),
436        _ => Err(QueryError::unsupported("Unsupported unary operator")),
437    }
438}
439
440fn convert_order_by_expr(order: &sql_ast::OrderByExpr) -> Result<OrderByExpr> {
441    Ok(OrderByExpr {
442        expr: convert_expr(&order.expr)?,
443        asc: order.options.asc.unwrap_or(true),
444        nulls_first: order.options.nulls_first.unwrap_or(false),
445    })
446}
447
448fn convert_limit(limit: &sql_ast::Expr) -> Result<usize> {
449    match limit {
450        sql_ast::Expr::Value(value_with_span) => match &value_with_span.value {
451            sql_ast::Value::Number(n, _) => n
452                .parse::<usize>()
453                .map_err(|_| QueryError::semantic("Invalid LIMIT value")),
454            _ => Err(QueryError::semantic("LIMIT must be a number")),
455        },
456        _ => Err(QueryError::semantic("LIMIT must be a number")),
457    }
458}
459
460fn convert_offset(offset: &sql_ast::Offset) -> Result<usize> {
461    match &offset.value {
462        sql_ast::Expr::Value(value_with_span) => match &value_with_span.value {
463            sql_ast::Value::Number(n, _) => n
464                .parse::<usize>()
465                .map_err(|_| QueryError::semantic("Invalid OFFSET value")),
466            _ => Err(QueryError::semantic("OFFSET must be a number")),
467        },
468        _ => Err(QueryError::semantic("OFFSET must be a number")),
469    }
470}
471
472fn convert_data_type(data_type: &sql_ast::DataType) -> Result<DataType> {
473    match data_type {
474        sql_ast::DataType::Boolean => Ok(DataType::Boolean),
475        sql_ast::DataType::TinyInt(_) => Ok(DataType::Int8),
476        sql_ast::DataType::SmallInt(_) => Ok(DataType::Int16),
477        sql_ast::DataType::Int(_) | sql_ast::DataType::Integer(_) => Ok(DataType::Int32),
478        sql_ast::DataType::BigInt(_) => Ok(DataType::Int64),
479        sql_ast::DataType::TinyIntUnsigned(_) => Ok(DataType::UInt8),
480        sql_ast::DataType::SmallIntUnsigned(_) => Ok(DataType::UInt16),
481        sql_ast::DataType::IntUnsigned(_)
482        | sql_ast::DataType::IntegerUnsigned(_)
483        | sql_ast::DataType::UnsignedInteger => Ok(DataType::UInt32),
484        sql_ast::DataType::BigIntUnsigned(_) => Ok(DataType::UInt64),
485        sql_ast::DataType::Float(_) | sql_ast::DataType::Real => Ok(DataType::Float32),
486        sql_ast::DataType::Double(_) | sql_ast::DataType::DoublePrecision => Ok(DataType::Float64),
487        sql_ast::DataType::Varchar(_)
488        | sql_ast::DataType::Char(_)
489        | sql_ast::DataType::Text
490        | sql_ast::DataType::String(_) => Ok(DataType::String),
491        sql_ast::DataType::Binary(_) | sql_ast::DataType::Varbinary(_) => Ok(DataType::Binary),
492        sql_ast::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
493        sql_ast::DataType::Date => Ok(DataType::Date),
494        sql_ast::DataType::Custom(name, _) if name.to_string().to_uppercase() == "GEOMETRY" => {
495            Ok(DataType::Geometry)
496        }
497        _ => Err(QueryError::unsupported(format!(
498            "Unsupported data type: {:?}",
499            data_type
500        ))),
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `parse_sql` accepts the given statement.
    #[track_caller]
    fn assert_parses(sql: &str) {
        assert!(parse_sql(sql).is_ok());
    }

    #[test]
    fn test_parse_simple_select() {
        assert_parses("SELECT id, name FROM users");
    }

    #[test]
    fn test_parse_select_with_where() {
        assert_parses("SELECT * FROM users WHERE age > 18");
    }

    #[test]
    fn test_parse_select_with_join() {
        assert_parses(
            "SELECT u.name, o.total FROM users u INNER JOIN orders o ON u.id = o.user_id",
        );
    }

    #[test]
    fn test_parse_spatial_function() {
        assert_parses("SELECT ST_Area(geom) FROM buildings");
    }
}
535}