manifoldb_query/parser/
sql.rs

1//! SQL parser implementation.
2//!
3//! This module provides the core SQL parsing functionality using `sqlparser-rs`
4//! as the foundation, with custom transformations to our AST types.
5
6use sqlparser::ast as sp;
7use sqlparser::dialect::GenericDialect;
8use sqlparser::parser::Parser;
9
10use crate::ast::{
11    Assignment, BinaryOp, CaseExpr, ColumnConstraint, ColumnDef, ConflictAction, ConflictTarget,
12    CreateIndexStatement, CreateTableStatement, DataType, DeleteStatement, DropIndexStatement,
13    DropTableStatement, Expr, FunctionCall, Identifier, IndexColumn, InsertSource, InsertStatement,
14    JoinClause, JoinCondition, JoinType, Literal, OnConflict, OrderByExpr, ParameterRef,
15    QualifiedName, SelectItem, SelectStatement, SetOperation, SetOperator, Statement, TableAlias,
16    TableConstraint, TableRef, UnaryOp, UpdateStatement, WindowFrame, WindowFrameBound,
17    WindowFrameUnits, WindowSpec, WithClause,
18};
19use crate::error::{ParseError, ParseResult};
20
21/// Parses a SQL string into a list of statements.
22///
23/// # Errors
24///
25/// Returns an error if the SQL is syntactically invalid.
26pub fn parse_sql(sql: &str) -> ParseResult<Vec<Statement>> {
27    if sql.trim().is_empty() {
28        return Err(ParseError::EmptyQuery);
29    }
30
31    let dialect = GenericDialect {};
32    let statements = Parser::parse_sql(&dialect, sql)?;
33
34    statements.into_iter().map(convert_statement).collect()
35}
36
37/// Parses a single SQL statement.
38///
39/// # Errors
40///
41/// Returns an error if the SQL is invalid or contains multiple statements.
42pub fn parse_single_statement(sql: &str) -> ParseResult<Statement> {
43    let mut stmts = parse_sql(sql)?;
44    if stmts.len() != 1 {
45        return Err(ParseError::SqlSyntax(format!("expected 1 statement, found {}", stmts.len())));
46    }
47    // SAFETY: We just verified there's exactly one statement
48    Ok(stmts.remove(0))
49}
50
51/// Converts a sqlparser Statement to our Statement.
52fn convert_statement(stmt: sp::Statement) -> ParseResult<Statement> {
53    match stmt {
54        sp::Statement::Query(query) => {
55            let select = convert_query(*query)?;
56            Ok(Statement::Select(Box::new(select)))
57        }
58        sp::Statement::Insert(insert) => {
59            let insert_stmt = convert_insert(insert)?;
60            Ok(Statement::Insert(Box::new(insert_stmt)))
61        }
62        sp::Statement::Update { table, assignments, from, selection, returning } => {
63            let from_vec = from.map(|t| vec![t]);
64            let update_stmt = convert_update(table, assignments, from_vec, selection, returning)?;
65            Ok(Statement::Update(Box::new(update_stmt)))
66        }
67        sp::Statement::Delete(delete) => {
68            let delete_stmt = convert_delete(delete)?;
69            Ok(Statement::Delete(Box::new(delete_stmt)))
70        }
71        sp::Statement::CreateTable(create) => {
72            let create_stmt = convert_create_table(create)?;
73            Ok(Statement::CreateTable(create_stmt))
74        }
75        sp::Statement::CreateIndex(create) => {
76            let create_stmt = convert_create_index(create)?;
77            Ok(Statement::CreateIndex(Box::new(create_stmt)))
78        }
79        sp::Statement::Drop { object_type, if_exists, names, cascade, .. } => match object_type {
80            sp::ObjectType::Table => {
81                let drop_stmt = DropTableStatement {
82                    if_exists,
83                    names: names.into_iter().map(convert_object_name).collect(),
84                    cascade,
85                };
86                Ok(Statement::DropTable(drop_stmt))
87            }
88            sp::ObjectType::Index => {
89                let drop_stmt = DropIndexStatement {
90                    if_exists,
91                    names: names.into_iter().map(convert_object_name).collect(),
92                    cascade,
93                };
94                Ok(Statement::DropIndex(drop_stmt))
95            }
96            _ => Err(ParseError::Unsupported(format!("DROP {object_type:?}"))),
97        },
98        sp::Statement::Explain { statement, .. } => {
99            let inner = convert_statement(*statement)?;
100            Ok(Statement::Explain(Box::new(inner)))
101        }
102        _ => Err(ParseError::Unsupported(format!("statement type: {stmt:?}"))),
103    }
104}
105
106/// Converts a sqlparser Query to our `SelectStatement`.
107fn convert_query(query: sp::Query) -> ParseResult<SelectStatement> {
108    // Handle WITH clause if present
109    let with_clauses =
110        if let Some(with) = query.with { convert_with_clause(with)? } else { vec![] };
111
112    let body = match *query.body {
113        sp::SetExpr::Select(select) => convert_select(*select)?,
114        sp::SetExpr::Query(subquery) => convert_query(*subquery)?,
115        sp::SetExpr::SetOperation { op, set_quantifier, left, right } => {
116            let mut base = match *left {
117                sp::SetExpr::Select(select) => convert_select(*select)?,
118                sp::SetExpr::Query(q) => convert_query(*q)?,
119                _ => return Err(ParseError::Unsupported("nested set operation".to_string())),
120            };
121            let right_stmt = match *right {
122                sp::SetExpr::Select(select) => convert_select(*select)?,
123                sp::SetExpr::Query(q) => convert_query(*q)?,
124                _ => return Err(ParseError::Unsupported("nested set operation".to_string())),
125            };
126            let set_op = SetOperation {
127                op: match op {
128                    sp::SetOperator::Union => SetOperator::Union,
129                    sp::SetOperator::Intersect => SetOperator::Intersect,
130                    sp::SetOperator::Except => SetOperator::Except,
131                },
132                all: matches!(set_quantifier, sp::SetQuantifier::All),
133                right: right_stmt,
134            };
135            base.set_op = Some(Box::new(set_op));
136            base
137        }
138        sp::SetExpr::Values(values) => {
139            // VALUES as a standalone select
140            let rows: Vec<Vec<Expr>> = values
141                .rows
142                .into_iter()
143                .map(|row| row.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>())
144                .collect::<ParseResult<Vec<_>>>()?;
145
146            if rows.is_empty() {
147                return Err(ParseError::SqlSyntax("empty VALUES".to_string()));
148            }
149
150            // Create column aliases (column1, column2, etc.)
151            let num_cols = rows.first().map_or(0, Vec::len);
152            let projection: Vec<SelectItem> = (1..=num_cols)
153                .map(|i| SelectItem::Expr {
154                    expr: Expr::Column(QualifiedName::simple(format!("column{i}"))),
155                    alias: None,
156                })
157                .collect();
158
159            SelectStatement::new(projection)
160        }
161        _ => return Err(ParseError::Unsupported("set expression type".to_string())),
162    };
163
164    // Apply ORDER BY, LIMIT, OFFSET from the outer query
165    let mut result = body;
166
167    if let Some(order_by) = query.order_by {
168        result.order_by = order_by
169            .exprs
170            .into_iter()
171            .map(convert_order_by_expr)
172            .collect::<ParseResult<Vec<_>>>()?;
173    }
174
175    if let Some(limit_expr) = query.limit {
176        result.limit = Some(convert_expr(limit_expr)?);
177    }
178
179    if let Some(offset_expr) = query.offset {
180        result.offset = Some(convert_expr(offset_expr.value)?);
181    }
182
183    // Add WITH clauses to the result
184    result.with_clauses = with_clauses;
185
186    Ok(result)
187}
188
189/// Converts a sqlparser WITH clause to our `WithClause` list.
190fn convert_with_clause(with: sp::With) -> ParseResult<Vec<WithClause>> {
191    // Recursive CTEs are not supported yet
192    if with.recursive {
193        return Err(ParseError::Unsupported("WITH RECURSIVE".to_string()));
194    }
195
196    with.cte_tables
197        .into_iter()
198        .map(|cte| {
199            let name = convert_ident(cte.alias.name);
200            let columns: Vec<Identifier> =
201                cte.alias.columns.into_iter().map(convert_ident).collect();
202            let query = convert_query(*cte.query)?;
203
204            Ok(WithClause { name, columns, query: Box::new(query) })
205        })
206        .collect()
207}
208
209/// Converts a sqlparser Select to our `SelectStatement`.
210fn convert_select(select: sp::Select) -> ParseResult<SelectStatement> {
211    let distinct = match select.distinct {
212        Some(sp::Distinct::Distinct) => true,
213        Some(sp::Distinct::On(_)) => {
214            return Err(ParseError::Unsupported("DISTINCT ON".to_string()))
215        }
216        None => false,
217    };
218
219    let projection =
220        select.projection.into_iter().map(convert_select_item).collect::<ParseResult<Vec<_>>>()?;
221
222    let from =
223        select.from.into_iter().map(convert_table_with_joins).collect::<ParseResult<Vec<_>>>()?;
224
225    let where_clause = select.selection.map(convert_expr).transpose()?;
226
227    let group_by = match select.group_by {
228        sp::GroupByExpr::Expressions(exprs, _) => {
229            exprs.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?
230        }
231        sp::GroupByExpr::All(_) => return Err(ParseError::Unsupported("GROUP BY ALL".to_string())),
232    };
233
234    let having = select.having.map(convert_expr).transpose()?;
235
236    Ok(SelectStatement {
237        with_clauses: vec![], // CTEs are handled at the Query level, not Select level
238        distinct,
239        projection,
240        from,
241        match_clause: None,             // Handled separately by extension parser
242        optional_match_clauses: vec![], // Handled separately by extension parser
243        where_clause,
244        group_by,
245        having,
246        order_by: vec![],
247        limit: None,
248        offset: None,
249        set_op: None,
250    })
251}
252
253/// Converts a sqlparser `SelectItem`.
254fn convert_select_item(item: sp::SelectItem) -> ParseResult<SelectItem> {
255    match item {
256        sp::SelectItem::UnnamedExpr(expr) => {
257            Ok(SelectItem::Expr { expr: convert_expr(expr)?, alias: None })
258        }
259        sp::SelectItem::ExprWithAlias { expr, alias } => {
260            Ok(SelectItem::Expr { expr: convert_expr(expr)?, alias: Some(convert_ident(alias)) })
261        }
262        sp::SelectItem::Wildcard(_) => Ok(SelectItem::Wildcard),
263        sp::SelectItem::QualifiedWildcard(name, _) => {
264            Ok(SelectItem::QualifiedWildcard(convert_object_name(name)))
265        }
266    }
267}
268
269/// Converts a table with joins.
270fn convert_table_with_joins(twj: sp::TableWithJoins) -> ParseResult<TableRef> {
271    let mut result = convert_table_factor(twj.relation)?;
272
273    for join in twj.joins {
274        let right = convert_table_factor(join.relation)?;
275        let join_type = match join.join_operator {
276            sp::JoinOperator::Inner(_) => JoinType::Inner,
277            sp::JoinOperator::LeftOuter(_) => JoinType::LeftOuter,
278            sp::JoinOperator::RightOuter(_) => JoinType::RightOuter,
279            sp::JoinOperator::FullOuter(_) => JoinType::FullOuter,
280            sp::JoinOperator::CrossJoin => JoinType::Cross,
281            sp::JoinOperator::LeftSemi(_) | sp::JoinOperator::RightSemi(_) => {
282                return Err(ParseError::Unsupported("SEMI JOIN".to_string()));
283            }
284            sp::JoinOperator::LeftAnti(_) | sp::JoinOperator::RightAnti(_) => {
285                return Err(ParseError::Unsupported("ANTI JOIN".to_string()));
286            }
287            sp::JoinOperator::AsOf { .. } => {
288                return Err(ParseError::Unsupported("AS OF JOIN".to_string()));
289            }
290            sp::JoinOperator::CrossApply | sp::JoinOperator::OuterApply => {
291                return Err(ParseError::Unsupported("APPLY".to_string()));
292            }
293        };
294
295        let condition = match join.join_operator {
296            sp::JoinOperator::Inner(constraint)
297            | sp::JoinOperator::LeftOuter(constraint)
298            | sp::JoinOperator::RightOuter(constraint)
299            | sp::JoinOperator::FullOuter(constraint) => convert_join_constraint(constraint)?,
300            // All other join types have no condition
301            _ => JoinCondition::None,
302        };
303
304        result = TableRef::Join(Box::new(JoinClause { left: result, right, join_type, condition }));
305    }
306
307    Ok(result)
308}
309
310/// Converts a join constraint.
311fn convert_join_constraint(constraint: sp::JoinConstraint) -> ParseResult<JoinCondition> {
312    match constraint {
313        sp::JoinConstraint::On(expr) => Ok(JoinCondition::On(convert_expr(expr)?)),
314        sp::JoinConstraint::Using(idents) => {
315            Ok(JoinCondition::Using(idents.into_iter().map(convert_ident).collect()))
316        }
317        sp::JoinConstraint::Natural => Ok(JoinCondition::Natural),
318        sp::JoinConstraint::None => Ok(JoinCondition::None),
319    }
320}
321
322/// Converts a table factor.
323fn convert_table_factor(factor: sp::TableFactor) -> ParseResult<TableRef> {
324    match factor {
325        sp::TableFactor::Table { name, alias, .. } => Ok(TableRef::Table {
326            name: convert_object_name(name),
327            alias: alias.map(convert_table_alias),
328        }),
329        sp::TableFactor::Derived { subquery, alias, .. } => {
330            let alias =
331                alias.ok_or_else(|| ParseError::MissingClause("alias for subquery".to_string()))?;
332            Ok(TableRef::Subquery {
333                query: Box::new(convert_query(*subquery)?),
334                alias: convert_table_alias(alias),
335            })
336        }
337        sp::TableFactor::TableFunction { expr, alias } => {
338            // Extract function name and args from the expression
339            if let sp::Expr::Function(func) = expr {
340                Ok(TableRef::TableFunction {
341                    name: convert_object_name(func.name),
342                    args: convert_function_args(func.args)?,
343                    alias: alias.map(convert_table_alias),
344                })
345            } else {
346                Err(ParseError::Unsupported("non-function table function".to_string()))
347            }
348        }
349        sp::TableFactor::NestedJoin { table_with_joins, alias } => {
350            let mut result = convert_table_with_joins(*table_with_joins)?;
351            if let Some(alias) = alias {
352                // Wrap in another table ref with alias if needed
353                match &mut result {
354                    TableRef::Table { alias: ref mut a, .. } => {
355                        *a = Some(convert_table_alias(alias))
356                    }
357                    TableRef::Subquery { alias: ref mut a, .. } => *a = convert_table_alias(alias),
358                    _ => {}
359                }
360            }
361            Ok(result)
362        }
363        _ => Err(ParseError::Unsupported("table factor type".to_string())),
364    }
365}
366
367/// Converts function arguments.
368fn convert_function_args(args: sp::FunctionArguments) -> ParseResult<Vec<Expr>> {
369    match args {
370        sp::FunctionArguments::None => Ok(vec![]),
371        sp::FunctionArguments::Subquery(_) => {
372            Err(ParseError::Unsupported("subquery function argument".to_string()))
373        }
374        sp::FunctionArguments::List(arg_list) => arg_list
375            .args
376            .into_iter()
377            .map(|arg| match arg {
378                sp::FunctionArg::Unnamed(expr) => expr,
379                sp::FunctionArg::Named { arg, .. } => arg,
380            })
381            .map(|arg_expr| match arg_expr {
382                sp::FunctionArgExpr::Expr(e) => convert_expr(e),
383                sp::FunctionArgExpr::QualifiedWildcard(name) => {
384                    Ok(Expr::QualifiedWildcard(convert_object_name(name)))
385                }
386                sp::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard),
387            })
388            .collect::<ParseResult<Vec<_>>>(),
389    }
390}
391
392/// Converts a table alias.
393fn convert_table_alias(alias: sp::TableAlias) -> TableAlias {
394    TableAlias {
395        name: convert_ident(alias.name),
396        columns: alias.columns.into_iter().map(convert_ident).collect(),
397    }
398}
399
400/// Converts an expression.
401#[allow(clippy::too_many_lines)]
402fn convert_expr(expr: sp::Expr) -> ParseResult<Expr> {
403    match expr {
404        sp::Expr::Identifier(ident) => {
405            Ok(Expr::Column(QualifiedName::simple(convert_ident(ident))))
406        }
407        sp::Expr::CompoundIdentifier(idents) => {
408            Ok(Expr::Column(QualifiedName::new(idents.into_iter().map(convert_ident).collect())))
409        }
410        sp::Expr::Value(value) => convert_value(value),
411        sp::Expr::BinaryOp { left, op, right } => {
412            let left = convert_expr(*left)?;
413            let right = convert_expr(*right)?;
414            let op = convert_binary_op(&op)?;
415            Ok(Expr::BinaryOp { left: Box::new(left), op, right: Box::new(right) })
416        }
417        sp::Expr::UnaryOp { op, expr } => {
418            let operand = convert_expr(*expr)?;
419            let op = convert_unary_op(op)?;
420            Ok(Expr::UnaryOp { op, operand: Box::new(operand) })
421        }
422        sp::Expr::Nested(inner) => convert_expr(*inner),
423        sp::Expr::Function(func) => convert_function(func),
424        sp::Expr::Cast { expr, data_type, .. } => Ok(Expr::Cast {
425            expr: Box::new(convert_expr(*expr)?),
426            data_type: format_data_type(&data_type),
427        }),
428        sp::Expr::Case { operand, conditions, results, else_result } => {
429            let when_clauses: Vec<(Expr, Expr)> = conditions
430                .into_iter()
431                .zip(results)
432                .map(|(cond, result)| Ok((convert_expr(cond)?, convert_expr(result)?)))
433                .collect::<ParseResult<Vec<_>>>()?;
434
435            Ok(Expr::Case(CaseExpr {
436                operand: operand.map(|e| convert_expr(*e)).transpose()?.map(Box::new),
437                when_clauses,
438                else_result: else_result.map(|e| convert_expr(*e)).transpose()?.map(Box::new),
439            }))
440        }
441        sp::Expr::Subquery(query) => Ok(Expr::Subquery(crate::ast::expr::Subquery {
442            query: Box::new(convert_query(*query)?),
443        })),
444        sp::Expr::Exists { subquery, .. } => Ok(Expr::Exists(crate::ast::expr::Subquery {
445            query: Box::new(convert_query(*subquery)?),
446        })),
447        sp::Expr::InList { expr, list, negated } => Ok(Expr::InList {
448            expr: Box::new(convert_expr(*expr)?),
449            list: list.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?,
450            negated,
451        }),
452        sp::Expr::InSubquery { expr, subquery, negated } => Ok(Expr::InSubquery {
453            expr: Box::new(convert_expr(*expr)?),
454            subquery: crate::ast::expr::Subquery { query: Box::new(convert_query(*subquery)?) },
455            negated,
456        }),
457        sp::Expr::Between { expr, low, high, negated } => Ok(Expr::Between {
458            expr: Box::new(convert_expr(*expr)?),
459            low: Box::new(convert_expr(*low)?),
460            high: Box::new(convert_expr(*high)?),
461            negated,
462        }),
463        sp::Expr::IsNull(expr) => {
464            Ok(Expr::UnaryOp { op: UnaryOp::IsNull, operand: Box::new(convert_expr(*expr)?) })
465        }
466        sp::Expr::IsNotNull(expr) => {
467            Ok(Expr::UnaryOp { op: UnaryOp::IsNotNull, operand: Box::new(convert_expr(*expr)?) })
468        }
469        sp::Expr::Tuple(exprs) => {
470            Ok(Expr::Tuple(exprs.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?))
471        }
472        sp::Expr::Array(arr) => {
473            let sp::Array { elem, .. } = arr;
474            // Try to convert to a vector or multi-vector literal
475            convert_array_expr(elem)
476        }
477        sp::Expr::Subscript { expr, subscript } => match *subscript {
478            sp::Subscript::Index { index } => Ok(Expr::ArrayIndex {
479                array: Box::new(convert_expr(*expr)?),
480                index: Box::new(convert_expr(index)?),
481            }),
482            sp::Subscript::Slice { .. } => {
483                Err(ParseError::Unsupported("subscript slice".to_string()))
484            }
485        },
486        sp::Expr::Like { negated, expr, pattern, escape_char: _, any: _ } => Ok(Expr::BinaryOp {
487            left: Box::new(convert_expr(*expr)?),
488            op: if negated { BinaryOp::NotLike } else { BinaryOp::Like },
489            right: Box::new(convert_expr(*pattern)?),
490        }),
491        sp::Expr::ILike { negated, expr, pattern, escape_char: _, any: _ } => Ok(Expr::BinaryOp {
492            left: Box::new(convert_expr(*expr)?),
493            op: if negated { BinaryOp::NotILike } else { BinaryOp::ILike },
494            right: Box::new(convert_expr(*pattern)?),
495        }),
496        sp::Expr::Named { name, .. } => {
497            // Named parameter like $name
498            Ok(Expr::Parameter(ParameterRef::Named(name.value)))
499        }
500        // Handle placeholder for positional parameters
501        _ => Err(ParseError::Unsupported(format!("expression type: {expr:?}"))),
502    }
503}
504
505/// Converts an array expression, detecting vector and multi-vector literals.
506///
507/// This function analyzes the array elements to determine if they form:
508/// - A vector literal: `[0.1, 0.2, 0.3]` -> `Literal::Vector`
509/// - A multi-vector literal: `[[0.1, 0.2], [0.3, 0.4]]` -> `Literal::MultiVector`
510/// - A general tuple/array for other cases
511fn convert_array_expr(elements: Vec<sp::Expr>) -> ParseResult<Expr> {
512    // Check if all elements are numeric literals (for vector)
513    let all_numeric =
514        elements.iter().all(|e| matches!(e, sp::Expr::Value(v) if is_numeric_value(v)));
515
516    if all_numeric && !elements.is_empty() {
517        // Convert to a vector literal
518        let values: Vec<f32> = elements
519            .iter()
520            .map(|e| {
521                if let sp::Expr::Value(v) = e {
522                    value_to_f32(v)
523                } else {
524                    Err(ParseError::InvalidLiteral("expected numeric value".to_string()))
525                }
526            })
527            .collect::<ParseResult<Vec<_>>>()?;
528        return Ok(Expr::Literal(Literal::Vector(values)));
529    }
530
531    // Check if all elements are arrays of numeric literals (for multi-vector)
532    let all_arrays = elements.iter().all(|e| {
533        matches!(e, sp::Expr::Array(arr) if arr.elem.iter().all(|inner| matches!(inner, sp::Expr::Value(v) if is_numeric_value(v))))
534    });
535
536    if all_arrays && !elements.is_empty() {
537        // Convert to a multi-vector literal
538        let vectors: Vec<Vec<f32>> = elements
539            .iter()
540            .map(|e| {
541                if let sp::Expr::Array(arr) = e {
542                    arr.elem
543                        .iter()
544                        .map(|inner| {
545                            if let sp::Expr::Value(v) = inner {
546                                value_to_f32(v)
547                            } else {
548                                Err(ParseError::InvalidLiteral(
549                                    "expected numeric value in nested array".to_string(),
550                                ))
551                            }
552                        })
553                        .collect::<ParseResult<Vec<_>>>()
554                } else {
555                    Err(ParseError::InvalidLiteral("expected array in multi-vector".to_string()))
556                }
557            })
558            .collect::<ParseResult<Vec<_>>>()?;
559        return Ok(Expr::Literal(Literal::MultiVector(vectors)));
560    }
561
562    // Fall back to Tuple for other cases
563    let converted = elements.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?;
564    Ok(Expr::Tuple(converted))
565}
566
567/// Checks if a sqlparser Value is a numeric literal.
568fn is_numeric_value(value: &sp::Value) -> bool {
569    matches!(value, sp::Value::Number(_, _))
570}
571
572/// Converts a sqlparser Value to f32.
573fn value_to_f32(value: &sp::Value) -> ParseResult<f32> {
574    match value {
575        sp::Value::Number(n, _) => {
576            n.parse::<f32>().map_err(|_| ParseError::InvalidLiteral(format!("invalid f32: {n}")))
577        }
578        _ => Err(ParseError::InvalidLiteral("expected numeric value".to_string())),
579    }
580}
581
582/// Converts a sqlparser Value to our Expr.
583fn convert_value(value: sp::Value) -> ParseResult<Expr> {
584    match value {
585        sp::Value::Null => Ok(Expr::Literal(Literal::Null)),
586        sp::Value::Boolean(b) => Ok(Expr::Literal(Literal::Boolean(b))),
587        sp::Value::Number(n, _) => {
588            // Try to parse as integer first, then float
589            if let Ok(i) = n.parse::<i64>() {
590                Ok(Expr::Literal(Literal::Integer(i)))
591            } else if let Ok(f) = n.parse::<f64>() {
592                Ok(Expr::Literal(Literal::Float(f)))
593            } else {
594                Err(ParseError::InvalidLiteral(format!("invalid number: {n}")))
595            }
596        }
597        sp::Value::SingleQuotedString(s) | sp::Value::DoubleQuotedString(s) => {
598            Ok(Expr::Literal(Literal::String(s)))
599        }
600        sp::Value::Placeholder(p) => {
601            if p == "?" {
602                Ok(Expr::Parameter(ParameterRef::Anonymous))
603            } else if let Some(n) = p.strip_prefix('$') {
604                if let Ok(pos) = n.parse::<u32>() {
605                    Ok(Expr::Parameter(ParameterRef::Positional(pos)))
606                } else {
607                    Ok(Expr::Parameter(ParameterRef::Named(n.to_string())))
608                }
609            } else {
610                Err(ParseError::InvalidLiteral(format!("unknown placeholder: {p}")))
611            }
612        }
613        _ => Err(ParseError::Unsupported(format!("value type: {value:?}"))),
614    }
615}
616
617/// Converts a binary operator.
618fn convert_binary_op(op: &sp::BinaryOperator) -> ParseResult<BinaryOp> {
619    match op {
620        sp::BinaryOperator::Plus => Ok(BinaryOp::Add),
621        sp::BinaryOperator::Minus => Ok(BinaryOp::Sub),
622        sp::BinaryOperator::Multiply => Ok(BinaryOp::Mul),
623        sp::BinaryOperator::Divide => Ok(BinaryOp::Div),
624        sp::BinaryOperator::Modulo => Ok(BinaryOp::Mod),
625        sp::BinaryOperator::Eq => Ok(BinaryOp::Eq),
626        sp::BinaryOperator::NotEq => Ok(BinaryOp::NotEq),
627        sp::BinaryOperator::Lt => Ok(BinaryOp::Lt),
628        sp::BinaryOperator::LtEq => Ok(BinaryOp::LtEq),
629        sp::BinaryOperator::Gt => Ok(BinaryOp::Gt),
630        sp::BinaryOperator::GtEq => Ok(BinaryOp::GtEq),
631        sp::BinaryOperator::And => Ok(BinaryOp::And),
632        sp::BinaryOperator::Or => Ok(BinaryOp::Or),
633        // Custom operators for vector operations (will be handled by extension parser)
634        sp::BinaryOperator::Arrow => Err(ParseError::Unsupported("-> operator".to_string())),
635        sp::BinaryOperator::LongArrow => Err(ParseError::Unsupported("->> operator".to_string())),
636        sp::BinaryOperator::HashArrow => Err(ParseError::Unsupported("#> operator".to_string())),
637        sp::BinaryOperator::HashLongArrow => {
638            Err(ParseError::Unsupported("#>> operator".to_string()))
639        }
640        _ => Err(ParseError::Unsupported(format!("binary operator: {op:?}"))),
641    }
642}
643
644/// Converts a unary operator.
645fn convert_unary_op(op: sp::UnaryOperator) -> ParseResult<UnaryOp> {
646    match op {
647        sp::UnaryOperator::Not => Ok(UnaryOp::Not),
648        // Unary plus is treated as a no-op (identity), but we convert it to Neg
649        // with a special case since there's no identity op - the caller should
650        // handle this by not wrapping in UnaryOp at all for plus
651        sp::UnaryOperator::Plus | sp::UnaryOperator::Minus => Ok(UnaryOp::Neg),
652        _ => Err(ParseError::Unsupported(format!("unary operator: {op:?}"))),
653    }
654}
655
656/// Converts a function call.
657fn convert_function(func: sp::Function) -> ParseResult<Expr> {
658    let name = convert_object_name(func.name);
659    let args = convert_function_args(func.args)?;
660
661    let filter = func.filter.map(|f| convert_expr(*f)).transpose()?.map(Box::new);
662
663    let over = func.over.map(convert_window_spec).transpose()?;
664
665    Ok(Expr::Function(FunctionCall {
666        name,
667        args,
668        distinct: false, // sqlparser 0.52 handles this differently
669        filter,
670        over,
671    }))
672}
673
674/// Converts a window specification.
675fn convert_window_spec(spec: sp::WindowType) -> ParseResult<WindowSpec> {
676    match spec {
677        sp::WindowType::WindowSpec(spec) => {
678            let partition_by =
679                spec.partition_by.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?;
680
681            let order_by = spec
682                .order_by
683                .into_iter()
684                .map(convert_order_by_expr)
685                .collect::<ParseResult<Vec<_>>>()?;
686
687            let frame = spec.window_frame.map(convert_window_frame).transpose()?;
688
689            Ok(WindowSpec { partition_by, order_by, frame })
690        }
691        sp::WindowType::NamedWindow(_) => {
692            Err(ParseError::Unsupported("named window reference".to_string()))
693        }
694    }
695}
696
697/// Converts a window frame.
698fn convert_window_frame(frame: sp::WindowFrame) -> ParseResult<WindowFrame> {
699    let units = match frame.units {
700        sp::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
701        sp::WindowFrameUnits::Range => WindowFrameUnits::Range,
702        sp::WindowFrameUnits::Groups => WindowFrameUnits::Groups,
703    };
704
705    let start = convert_window_frame_bound(frame.start_bound)?;
706    let end = frame.end_bound.map(convert_window_frame_bound).transpose()?;
707
708    Ok(WindowFrame { units, start, end })
709}
710
711/// Converts a window frame bound.
712fn convert_window_frame_bound(bound: sp::WindowFrameBound) -> ParseResult<WindowFrameBound> {
713    match bound {
714        sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
715        sp::WindowFrameBound::Preceding(None) => Ok(WindowFrameBound::UnboundedPreceding),
716        sp::WindowFrameBound::Following(None) => Ok(WindowFrameBound::UnboundedFollowing),
717        sp::WindowFrameBound::Preceding(Some(expr)) => {
718            Ok(WindowFrameBound::Preceding(Box::new(convert_expr(*expr)?)))
719        }
720        sp::WindowFrameBound::Following(Some(expr)) => {
721            Ok(WindowFrameBound::Following(Box::new(convert_expr(*expr)?)))
722        }
723    }
724}
725
726/// Converts an ORDER BY expression.
727fn convert_order_by_expr(expr: sp::OrderByExpr) -> ParseResult<OrderByExpr> {
728    let asc = expr.asc.unwrap_or(true); // Default to ASC
729
730    Ok(OrderByExpr { expr: Box::new(convert_expr(expr.expr)?), asc, nulls_first: expr.nulls_first })
731}
732
733/// Converts an INSERT statement.
734fn convert_insert(insert: sp::Insert) -> ParseResult<InsertStatement> {
735    // Extract table name
736    let table = convert_object_name(insert.table_name);
737
738    let columns: Vec<Identifier> = insert.columns.into_iter().map(convert_ident).collect();
739
740    let source = match insert.source {
741        Some(source) => match *source.body {
742            sp::SetExpr::Values(values) => {
743                let rows: Vec<Vec<Expr>> = values
744                    .rows
745                    .into_iter()
746                    .map(|row| row.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>())
747                    .collect::<ParseResult<Vec<_>>>()?;
748                InsertSource::Values(rows)
749            }
750            sp::SetExpr::Select(select) => {
751                let query = convert_select(*select)?;
752                InsertSource::Query(Box::new(query))
753            }
754            _ => return Err(ParseError::Unsupported("INSERT source type".to_string())),
755        },
756        None => InsertSource::DefaultValues,
757    };
758
759    let on_conflict = insert.on.map(convert_on_conflict).transpose()?;
760
761    let returning = insert
762        .returning
763        .map(|items| items.into_iter().map(convert_select_item).collect::<ParseResult<Vec<_>>>())
764        .transpose()?
765        .unwrap_or_default();
766
767    Ok(InsertStatement { table, columns, source, on_conflict, returning })
768}
769
770/// Converts ON CONFLICT clause.
771fn convert_on_conflict(on: sp::OnInsert) -> ParseResult<OnConflict> {
772    match on {
773        sp::OnInsert::DuplicateKeyUpdate(assignments) => {
774            Ok(OnConflict {
775                target: ConflictTarget::Columns(vec![]), // MySQL doesn't specify columns
776                action: ConflictAction::DoUpdate {
777                    assignments: assignments
778                        .into_iter()
779                        .map(convert_assignment)
780                        .collect::<ParseResult<Vec<_>>>()?,
781                    where_clause: None,
782                },
783            })
784        }
785        sp::OnInsert::OnConflict(conflict) => {
786            let target = match conflict.conflict_target {
787                Some(sp::ConflictTarget::Columns(cols)) => {
788                    ConflictTarget::Columns(cols.into_iter().map(convert_ident).collect())
789                }
790                Some(sp::ConflictTarget::OnConstraint(name)) => {
791                    let converted = convert_object_name(name);
792                    let ident = converted.parts.into_iter().next().ok_or_else(|| {
793                        ParseError::MissingClause("constraint name in ON CONFLICT".to_string())
794                    })?;
795                    ConflictTarget::Constraint(ident)
796                }
797                None => ConflictTarget::Columns(vec![]),
798            };
799
800            let action = match conflict.action {
801                sp::OnConflictAction::DoNothing => ConflictAction::DoNothing,
802                sp::OnConflictAction::DoUpdate(update) => ConflictAction::DoUpdate {
803                    assignments: update
804                        .assignments
805                        .into_iter()
806                        .map(convert_assignment)
807                        .collect::<ParseResult<Vec<_>>>()?,
808                    where_clause: update.selection.map(convert_expr).transpose()?,
809                },
810            };
811
812            Ok(OnConflict { target, action })
813        }
814        _ => Err(ParseError::Unsupported("ON INSERT type".to_string())),
815    }
816}
817
818/// Converts an assignment (for UPDATE or ON CONFLICT).
819fn convert_assignment(assign: sp::Assignment) -> ParseResult<Assignment> {
820    // Convert assignment target to column name
821    let column = match assign.target {
822        sp::AssignmentTarget::ColumnName(names) => names
823            .0
824            .into_iter()
825            .next()
826            .map(convert_ident)
827            .ok_or_else(|| ParseError::MissingClause("assignment target".to_string()))?,
828        sp::AssignmentTarget::Tuple(_) => {
829            return Err(ParseError::Unsupported("tuple assignment target".to_string()));
830        }
831    };
832
833    Ok(Assignment { column, value: convert_expr(assign.value)? })
834}
835
836/// Converts an UPDATE statement.
837fn convert_update(
838    table: sp::TableWithJoins,
839    assignments: Vec<sp::Assignment>,
840    from: Option<Vec<sp::TableWithJoins>>,
841    selection: Option<sp::Expr>,
842    returning: Option<Vec<sp::SelectItem>>,
843) -> ParseResult<UpdateStatement> {
844    let table_ref = convert_table_with_joins(table)?;
845    let TableRef::Table { name: table_name, alias } = table_ref else {
846        return Err(ParseError::Unsupported("complex UPDATE target".to_string()));
847    };
848
849    let assignments =
850        assignments.into_iter().map(convert_assignment).collect::<ParseResult<Vec<_>>>()?;
851
852    let from_clause = from
853        .map(|f| f.into_iter().map(convert_table_with_joins).collect::<ParseResult<Vec<_>>>())
854        .transpose()?
855        .unwrap_or_default();
856
857    let where_clause = selection.map(convert_expr).transpose()?;
858
859    let returning = returning
860        .map(|items| items.into_iter().map(convert_select_item).collect::<ParseResult<Vec<_>>>())
861        .transpose()?
862        .unwrap_or_default();
863
864    Ok(UpdateStatement {
865        table: table_name,
866        alias,
867        assignments,
868        from: from_clause,
869        match_clause: None,
870        where_clause,
871        returning,
872    })
873}
874
875/// Converts a DELETE statement.
876fn convert_delete(delete: sp::Delete) -> ParseResult<DeleteStatement> {
877    let from_table = match delete.from {
878        sp::FromTable::WithFromKeyword(tables) => tables
879            .into_iter()
880            .next()
881            .ok_or_else(|| ParseError::MissingClause("FROM".to_string()))?,
882        sp::FromTable::WithoutKeyword(tables) => tables
883            .into_iter()
884            .next()
885            .ok_or_else(|| ParseError::MissingClause("table".to_string()))?,
886    };
887
888    let table_ref = convert_table_with_joins(from_table)?;
889    let TableRef::Table { name: table_name, alias } = table_ref else {
890        return Err(ParseError::Unsupported("complex DELETE target".to_string()));
891    };
892
893    let using = delete
894        .using
895        .map(|u| u.into_iter().map(convert_table_with_joins).collect::<ParseResult<Vec<_>>>())
896        .transpose()?
897        .unwrap_or_default();
898
899    let where_clause = delete.selection.map(convert_expr).transpose()?;
900
901    let returning = delete
902        .returning
903        .map(|items| items.into_iter().map(convert_select_item).collect::<ParseResult<Vec<_>>>())
904        .transpose()?
905        .unwrap_or_default();
906
907    Ok(DeleteStatement {
908        table: table_name,
909        alias,
910        using,
911        match_clause: None,
912        where_clause,
913        returning,
914    })
915}
916
917/// Converts a CREATE TABLE statement.
918fn convert_create_table(create: sp::CreateTable) -> ParseResult<CreateTableStatement> {
919    let columns =
920        create.columns.into_iter().map(convert_column_def).collect::<ParseResult<Vec<_>>>()?;
921
922    let constraints = create
923        .constraints
924        .into_iter()
925        .map(convert_table_constraint)
926        .collect::<ParseResult<Vec<_>>>()?;
927
928    Ok(CreateTableStatement {
929        if_not_exists: create.if_not_exists,
930        name: convert_object_name(create.name),
931        columns,
932        constraints,
933    })
934}
935
936/// Converts a column definition.
937fn convert_column_def(col: sp::ColumnDef) -> ParseResult<ColumnDef> {
938    let constraints =
939        col.options.into_iter().filter_map(|opt| convert_column_option(opt.option).ok()).collect();
940
941    Ok(ColumnDef {
942        name: convert_ident(col.name),
943        data_type: convert_data_type(col.data_type)?,
944        constraints,
945    })
946}
947
948/// Converts a column option to a constraint.
949fn convert_column_option(opt: sp::ColumnOption) -> ParseResult<ColumnConstraint> {
950    match opt {
951        sp::ColumnOption::Null => Ok(ColumnConstraint::Null),
952        sp::ColumnOption::NotNull => Ok(ColumnConstraint::NotNull),
953        sp::ColumnOption::Unique { is_primary, .. } => {
954            if is_primary {
955                Ok(ColumnConstraint::PrimaryKey)
956            } else {
957                Ok(ColumnConstraint::Unique)
958            }
959        }
960        sp::ColumnOption::ForeignKey { foreign_table, referred_columns, .. } => {
961            Ok(ColumnConstraint::References {
962                table: convert_object_name(foreign_table),
963                column: referred_columns.into_iter().next().map(convert_ident),
964            })
965        }
966        sp::ColumnOption::Check(expr) => Ok(ColumnConstraint::Check(convert_expr(expr)?)),
967        sp::ColumnOption::Default(expr) => Ok(ColumnConstraint::Default(convert_expr(expr)?)),
968        _ => Err(ParseError::Unsupported("column option".to_string())),
969    }
970}
971
972/// Converts a table constraint.
973fn convert_table_constraint(constraint: sp::TableConstraint) -> ParseResult<TableConstraint> {
974    match constraint {
975        sp::TableConstraint::PrimaryKey { columns, name, .. } => Ok(TableConstraint::PrimaryKey {
976            name: name.map(convert_ident),
977            columns: columns.into_iter().map(convert_ident).collect(),
978        }),
979        sp::TableConstraint::Unique { columns, name, .. } => Ok(TableConstraint::Unique {
980            name: name.map(convert_ident),
981            columns: columns.into_iter().map(convert_ident).collect(),
982        }),
983        sp::TableConstraint::ForeignKey {
984            columns, foreign_table, referred_columns, name, ..
985        } => Ok(TableConstraint::ForeignKey {
986            name: name.map(convert_ident),
987            columns: columns.into_iter().map(convert_ident).collect(),
988            references_table: convert_object_name(foreign_table),
989            references_columns: referred_columns.into_iter().map(convert_ident).collect(),
990        }),
991        sp::TableConstraint::Check { name, expr } => {
992            Ok(TableConstraint::Check { name: name.map(convert_ident), expr: convert_expr(*expr)? })
993        }
994        _ => Err(ParseError::Unsupported("table constraint".to_string())),
995    }
996}
997
998/// Converts a CREATE INDEX statement.
999fn convert_create_index(create: sp::CreateIndex) -> ParseResult<CreateIndexStatement> {
1000    let name = create
1001        .name
1002        .map(convert_object_name)
1003        .and_then(|n| n.parts.into_iter().next())
1004        .ok_or_else(|| ParseError::MissingClause("index name".to_string()))?;
1005
1006    let table = convert_object_name(create.table_name);
1007
1008    let columns = create
1009        .columns
1010        .into_iter()
1011        .map(|col| {
1012            Ok(IndexColumn {
1013                expr: convert_expr(col.expr)?,
1014                asc: col.asc,
1015                nulls_first: col.nulls_first,
1016                opclass: None,
1017            })
1018        })
1019        .collect::<ParseResult<Vec<_>>>()?;
1020
1021    Ok(CreateIndexStatement {
1022        unique: create.unique,
1023        if_not_exists: create.if_not_exists,
1024        name,
1025        table,
1026        columns,
1027        using: create.using.map(convert_ident).map(|i| i.name),
1028        with: vec![],
1029        where_clause: create.predicate.map(convert_expr).transpose()?,
1030    })
1031}
1032
1033/// Converts a data type.
1034#[allow(clippy::cast_possible_truncation)]
1035fn convert_data_type(dt: sp::DataType) -> ParseResult<DataType> {
1036    match dt {
1037        sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
1038        sp::DataType::SmallInt(_) | sp::DataType::Int2(_) => Ok(DataType::SmallInt),
1039        sp::DataType::Int(_) | sp::DataType::Integer(_) | sp::DataType::Int4(_) => {
1040            Ok(DataType::Integer)
1041        }
1042        sp::DataType::BigInt(_) | sp::DataType::Int8(_) => Ok(DataType::BigInt),
1043        sp::DataType::Real | sp::DataType::Float4 => Ok(DataType::Real),
1044        sp::DataType::DoublePrecision | sp::DataType::Double | sp::DataType::Float8 => {
1045            Ok(DataType::DoublePrecision)
1046        }
1047        sp::DataType::Numeric(info) | sp::DataType::Decimal(info) => {
1048            let (precision, scale) = match info {
1049                sp::ExactNumberInfo::None => (None, None),
1050                sp::ExactNumberInfo::Precision(p) => (Some(p as u32), None),
1051                sp::ExactNumberInfo::PrecisionAndScale(p, s) => (Some(p as u32), Some(s as u32)),
1052            };
1053            Ok(DataType::Numeric { precision, scale })
1054        }
1055        sp::DataType::Varchar(len) | sp::DataType::CharVarying(len) => {
1056            let len_val = len.and_then(|l| match l {
1057                sp::CharacterLength::IntegerLength { length, .. } => Some(length as u32),
1058                sp::CharacterLength::Max => None,
1059            });
1060            Ok(DataType::Varchar(len_val))
1061        }
1062        sp::DataType::Text => Ok(DataType::Text),
1063        sp::DataType::Bytea => Ok(DataType::Bytea),
1064        sp::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
1065        sp::DataType::Date => Ok(DataType::Date),
1066        sp::DataType::Time(_, _) => Ok(DataType::Time),
1067        sp::DataType::Interval => Ok(DataType::Interval),
1068        sp::DataType::JSON => Ok(DataType::Json),
1069        sp::DataType::Uuid => Ok(DataType::Uuid),
1070        sp::DataType::Array(elem) => match elem {
1071            sp::ArrayElemTypeDef::AngleBracket(inner)
1072            | sp::ArrayElemTypeDef::SquareBracket(inner, _) => {
1073                Ok(DataType::Array(Box::new(convert_data_type(*inner)?)))
1074            }
1075            sp::ArrayElemTypeDef::None => Err(ParseError::Unsupported("untyped array".to_string())),
1076            sp::ArrayElemTypeDef::Parenthesis(_) => {
1077                Err(ParseError::Unsupported("parenthesized array type".to_string()))
1078            }
1079        },
1080        sp::DataType::Custom(name, _) => {
1081            let name_str = name.0.iter().map(|p| p.value.clone()).collect::<Vec<_>>().join(".");
1082
1083            // Check for VECTOR type
1084            if name_str.eq_ignore_ascii_case("vector") {
1085                Ok(DataType::Vector(None))
1086            } else {
1087                Ok(DataType::Custom(name_str))
1088            }
1089        }
1090        _ => Err(ParseError::Unsupported(format!("data type: {dt:?}"))),
1091    }
1092}
1093
1094/// Formats a data type as a string.
1095fn format_data_type(dt: &sp::DataType) -> String {
1096    format!("{dt}")
1097}
1098
1099/// Converts an object name.
1100fn convert_object_name(name: sp::ObjectName) -> QualifiedName {
1101    QualifiedName::new(name.0.into_iter().map(convert_ident).collect())
1102}
1103
1104/// Converts an identifier.
1105fn convert_ident(ident: sp::Ident) -> Identifier {
1106    Identifier { name: ident.value, quote_style: ident.quote_style }
1107}
1108
/// Smoke tests for the SQL-to-AST conversion: each test parses one statement
/// and checks the shape of the resulting AST.
#[cfg(test)]
mod tests {
    use super::*;

    /// `SELECT *` yields a single wildcard projection item.
    #[test]
    fn parse_simple_select() {
        let stmt = parse_single_statement("SELECT * FROM users").unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.projection.len(), 1);
                assert!(matches!(select.projection[0], SelectItem::Wildcard));
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// A WHERE clause is carried over alongside the projection list.
    #[test]
    fn parse_select_with_where() {
        let stmt = parse_single_statement("SELECT id, name FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.projection.len(), 2);
                assert!(select.where_clause.is_some());
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// `INSERT ... VALUES` keeps the column list and the row of values.
    #[test]
    fn parse_insert() {
        let stmt =
            parse_single_statement("INSERT INTO users (name, age) VALUES ('Alice', 30)").unwrap();
        match stmt {
            Statement::Insert(insert) => {
                assert_eq!(insert.columns.len(), 2);
                match &insert.source {
                    InsertSource::Values(rows) => {
                        assert_eq!(rows.len(), 1);
                        assert_eq!(rows[0].len(), 2);
                    }
                    _ => panic!("expected VALUES"),
                }
            }
            _ => panic!("expected INSERT"),
        }
    }

    /// `UPDATE ... SET ... WHERE` keeps the assignment and the predicate.
    #[test]
    fn parse_update() {
        let stmt = parse_single_statement("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
        match stmt {
            Statement::Update(update) => {
                assert_eq!(update.assignments.len(), 1);
                assert!(update.where_clause.is_some());
            }
            _ => panic!("expected UPDATE"),
        }
    }

    /// `DELETE ... WHERE` keeps the predicate.
    #[test]
    fn parse_delete() {
        let stmt = parse_single_statement("DELETE FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Delete(delete) => {
                assert!(delete.where_clause.is_some());
            }
            _ => panic!("expected DELETE"),
        }
    }

    /// `CREATE TABLE` yields one column definition per declared column.
    #[test]
    fn parse_create_table() {
        let stmt = parse_single_statement(
            "CREATE TABLE users (id BIGINT PRIMARY KEY, name VARCHAR(100) NOT NULL)",
        )
        .unwrap();
        match stmt {
            Statement::CreateTable(create) => {
                assert_eq!(create.columns.len(), 2);
            }
            _ => panic!("expected CREATE TABLE"),
        }
    }

    /// An INNER JOIN is represented as a single joined table reference.
    #[test]
    fn parse_join() {
        let stmt = parse_single_statement(
            "SELECT u.name, o.total FROM users u INNER JOIN orders o ON u.id = o.user_id",
        )
        .unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.from.len(), 1);
                match &select.from[0] {
                    TableRef::Join(join) => {
                        assert_eq!(join.join_type, JoinType::Inner);
                    }
                    _ => panic!("expected JOIN"),
                }
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// An empty input is rejected with `ParseError::EmptyQuery`.
    #[test]
    fn parse_empty_query() {
        let result = parse_sql("");
        assert!(matches!(result, Err(ParseError::EmptyQuery)));
    }

    /// `$1` on the right-hand side of a comparison becomes positional parameter 1.
    #[test]
    fn parse_parameter() {
        let stmt = parse_single_statement("SELECT * FROM users WHERE id = $1").unwrap();
        match stmt {
            Statement::Select(select) => {
                // Fail loudly if the WHERE clause has an unexpected shape; otherwise
                // the parameter check below would be skipped and the test would
                // pass without verifying anything.
                let Some(Expr::BinaryOp { right, .. }) = select.where_clause else {
                    panic!("expected binary operation in WHERE clause");
                };
                match *right {
                    Expr::Parameter(ParameterRef::Positional(1)) => {}
                    _ => panic!("expected positional parameter"),
                }
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// A flat bracketed list of numbers is parsed as a vector literal.
    #[test]
    fn parse_vector_literal() {
        let stmt = parse_single_statement("SELECT [0.1, 0.2, 0.3]").unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.projection.len(), 1);
                if let SelectItem::Expr { expr, .. } = &select.projection[0] {
                    match expr {
                        Expr::Literal(Literal::Vector(v)) => {
                            assert_eq!(v.len(), 3);
                            assert!((v[0] - 0.1).abs() < 0.001);
                            assert!((v[1] - 0.2).abs() < 0.001);
                            assert!((v[2] - 0.3).abs() < 0.001);
                        }
                        _ => panic!("expected Vector literal, got {:?}", expr),
                    }
                } else {
                    panic!("expected expression in projection");
                }
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// A nested bracketed list is parsed as a multi-vector literal.
    #[test]
    fn parse_multi_vector_literal() {
        let stmt = parse_single_statement("SELECT [[0.1, 0.2], [0.3, 0.4]]").unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.projection.len(), 1);
                if let SelectItem::Expr { expr, .. } = &select.projection[0] {
                    match expr {
                        Expr::Literal(Literal::MultiVector(v)) => {
                            assert_eq!(v.len(), 2);
                            assert_eq!(v[0].len(), 2);
                            assert_eq!(v[1].len(), 2);
                            assert!((v[0][0] - 0.1).abs() < 0.001);
                            assert!((v[0][1] - 0.2).abs() < 0.001);
                            assert!((v[1][0] - 0.3).abs() < 0.001);
                            assert!((v[1][1] - 0.4).abs() < 0.001);
                        }
                        _ => panic!("expected MultiVector literal, got {:?}", expr),
                    }
                } else {
                    panic!("expected expression in projection");
                }
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// The `<->` operator is rejected by this parser, even with a valid
    /// multi-vector literal as its operand.
    #[test]
    fn parse_multi_vector_in_order_by() {
        let stmt = parse_single_statement(
            "SELECT * FROM docs ORDER BY embedding <-> [[0.1, 0.2], [0.3, 0.4]]",
        );
        // `<->` is not a standard SQL operator, so parsing must fail here.
        assert!(stmt.is_err());
    }

    /// A multi-vector literal is accepted as a value inside `INSERT ... VALUES`.
    #[test]
    fn parse_insert_with_multi_vector() {
        let stmt = parse_single_statement(
            "INSERT INTO docs (id, embedding) VALUES (1, [[0.1, 0.2], [0.3, 0.4]])",
        )
        .unwrap();
        match stmt {
            Statement::Insert(insert) => {
                assert_eq!(insert.columns.len(), 2);
                match &insert.source {
                    InsertSource::Values(rows) => {
                        assert_eq!(rows.len(), 1);
                        assert_eq!(rows[0].len(), 2);
                        match &rows[0][1] {
                            Expr::Literal(Literal::MultiVector(v)) => {
                                assert_eq!(v.len(), 2);
                                assert_eq!(v[0].len(), 2);
                            }
                            _ => panic!("expected MultiVector literal in insert"),
                        }
                    }
                    _ => panic!("expected VALUES"),
                }
            }
            _ => panic!("expected INSERT"),
        }
    }
}