1use sqlparser::ast as sp;
7use sqlparser::dialect::GenericDialect;
8use sqlparser::parser::Parser;
9
10use crate::ast::{
11 Assignment, BinaryOp, CaseExpr, ColumnConstraint, ColumnDef, ConflictAction, ConflictTarget,
12 CreateIndexStatement, CreateTableStatement, DataType, DeleteStatement, DropIndexStatement,
13 DropTableStatement, Expr, FunctionCall, Identifier, IndexColumn, InsertSource, InsertStatement,
14 JoinClause, JoinCondition, JoinType, Literal, OnConflict, OrderByExpr, ParameterRef,
15 QualifiedName, SelectItem, SelectStatement, SetOperation, SetOperator, Statement, TableAlias,
16 TableConstraint, TableRef, UnaryOp, UpdateStatement, WindowFrame, WindowFrameBound,
17 WindowFrameUnits, WindowSpec, WithClause,
18};
19use crate::error::{ParseError, ParseResult};
20
21pub fn parse_sql(sql: &str) -> ParseResult<Vec<Statement>> {
27 if sql.trim().is_empty() {
28 return Err(ParseError::EmptyQuery);
29 }
30
31 let dialect = GenericDialect {};
32 let statements = Parser::parse_sql(&dialect, sql)?;
33
34 statements.into_iter().map(convert_statement).collect()
35}
36
37pub fn parse_single_statement(sql: &str) -> ParseResult<Statement> {
43 let mut stmts = parse_sql(sql)?;
44 if stmts.len() != 1 {
45 return Err(ParseError::SqlSyntax(format!("expected 1 statement, found {}", stmts.len())));
46 }
47 Ok(stmts.remove(0))
49}
50
51fn convert_statement(stmt: sp::Statement) -> ParseResult<Statement> {
53 match stmt {
54 sp::Statement::Query(query) => {
55 let select = convert_query(*query)?;
56 Ok(Statement::Select(Box::new(select)))
57 }
58 sp::Statement::Insert(insert) => {
59 let insert_stmt = convert_insert(insert)?;
60 Ok(Statement::Insert(Box::new(insert_stmt)))
61 }
62 sp::Statement::Update { table, assignments, from, selection, returning } => {
63 let from_vec = from.map(|t| vec![t]);
64 let update_stmt = convert_update(table, assignments, from_vec, selection, returning)?;
65 Ok(Statement::Update(Box::new(update_stmt)))
66 }
67 sp::Statement::Delete(delete) => {
68 let delete_stmt = convert_delete(delete)?;
69 Ok(Statement::Delete(Box::new(delete_stmt)))
70 }
71 sp::Statement::CreateTable(create) => {
72 let create_stmt = convert_create_table(create)?;
73 Ok(Statement::CreateTable(create_stmt))
74 }
75 sp::Statement::CreateIndex(create) => {
76 let create_stmt = convert_create_index(create)?;
77 Ok(Statement::CreateIndex(Box::new(create_stmt)))
78 }
79 sp::Statement::Drop { object_type, if_exists, names, cascade, .. } => match object_type {
80 sp::ObjectType::Table => {
81 let drop_stmt = DropTableStatement {
82 if_exists,
83 names: names.into_iter().map(convert_object_name).collect(),
84 cascade,
85 };
86 Ok(Statement::DropTable(drop_stmt))
87 }
88 sp::ObjectType::Index => {
89 let drop_stmt = DropIndexStatement {
90 if_exists,
91 names: names.into_iter().map(convert_object_name).collect(),
92 cascade,
93 };
94 Ok(Statement::DropIndex(drop_stmt))
95 }
96 _ => Err(ParseError::Unsupported(format!("DROP {object_type:?}"))),
97 },
98 sp::Statement::Explain { statement, .. } => {
99 let inner = convert_statement(*statement)?;
100 Ok(Statement::Explain(Box::new(inner)))
101 }
102 _ => Err(ParseError::Unsupported(format!("statement type: {stmt:?}"))),
103 }
104}
105
106fn convert_query(query: sp::Query) -> ParseResult<SelectStatement> {
108 let with_clauses =
110 if let Some(with) = query.with { convert_with_clause(with)? } else { vec![] };
111
112 let body = match *query.body {
113 sp::SetExpr::Select(select) => convert_select(*select)?,
114 sp::SetExpr::Query(subquery) => convert_query(*subquery)?,
115 sp::SetExpr::SetOperation { op, set_quantifier, left, right } => {
116 let mut base = match *left {
117 sp::SetExpr::Select(select) => convert_select(*select)?,
118 sp::SetExpr::Query(q) => convert_query(*q)?,
119 _ => return Err(ParseError::Unsupported("nested set operation".to_string())),
120 };
121 let right_stmt = match *right {
122 sp::SetExpr::Select(select) => convert_select(*select)?,
123 sp::SetExpr::Query(q) => convert_query(*q)?,
124 _ => return Err(ParseError::Unsupported("nested set operation".to_string())),
125 };
126 let set_op = SetOperation {
127 op: match op {
128 sp::SetOperator::Union => SetOperator::Union,
129 sp::SetOperator::Intersect => SetOperator::Intersect,
130 sp::SetOperator::Except => SetOperator::Except,
131 },
132 all: matches!(set_quantifier, sp::SetQuantifier::All),
133 right: right_stmt,
134 };
135 base.set_op = Some(Box::new(set_op));
136 base
137 }
138 sp::SetExpr::Values(values) => {
139 let rows: Vec<Vec<Expr>> = values
141 .rows
142 .into_iter()
143 .map(|row| row.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>())
144 .collect::<ParseResult<Vec<_>>>()?;
145
146 if rows.is_empty() {
147 return Err(ParseError::SqlSyntax("empty VALUES".to_string()));
148 }
149
150 let num_cols = rows.first().map_or(0, Vec::len);
152 let projection: Vec<SelectItem> = (1..=num_cols)
153 .map(|i| SelectItem::Expr {
154 expr: Expr::Column(QualifiedName::simple(format!("column{i}"))),
155 alias: None,
156 })
157 .collect();
158
159 SelectStatement::new(projection)
160 }
161 _ => return Err(ParseError::Unsupported("set expression type".to_string())),
162 };
163
164 let mut result = body;
166
167 if let Some(order_by) = query.order_by {
168 result.order_by = order_by
169 .exprs
170 .into_iter()
171 .map(convert_order_by_expr)
172 .collect::<ParseResult<Vec<_>>>()?;
173 }
174
175 if let Some(limit_expr) = query.limit {
176 result.limit = Some(convert_expr(limit_expr)?);
177 }
178
179 if let Some(offset_expr) = query.offset {
180 result.offset = Some(convert_expr(offset_expr.value)?);
181 }
182
183 result.with_clauses = with_clauses;
185
186 Ok(result)
187}
188
189fn convert_with_clause(with: sp::With) -> ParseResult<Vec<WithClause>> {
191 if with.recursive {
193 return Err(ParseError::Unsupported("WITH RECURSIVE".to_string()));
194 }
195
196 with.cte_tables
197 .into_iter()
198 .map(|cte| {
199 let name = convert_ident(cte.alias.name);
200 let columns: Vec<Identifier> =
201 cte.alias.columns.into_iter().map(convert_ident).collect();
202 let query = convert_query(*cte.query)?;
203
204 Ok(WithClause { name, columns, query: Box::new(query) })
205 })
206 .collect()
207}
208
209fn convert_select(select: sp::Select) -> ParseResult<SelectStatement> {
211 let distinct = match select.distinct {
212 Some(sp::Distinct::Distinct) => true,
213 Some(sp::Distinct::On(_)) => {
214 return Err(ParseError::Unsupported("DISTINCT ON".to_string()))
215 }
216 None => false,
217 };
218
219 let projection =
220 select.projection.into_iter().map(convert_select_item).collect::<ParseResult<Vec<_>>>()?;
221
222 let from =
223 select.from.into_iter().map(convert_table_with_joins).collect::<ParseResult<Vec<_>>>()?;
224
225 let where_clause = select.selection.map(convert_expr).transpose()?;
226
227 let group_by = match select.group_by {
228 sp::GroupByExpr::Expressions(exprs, _) => {
229 exprs.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?
230 }
231 sp::GroupByExpr::All(_) => return Err(ParseError::Unsupported("GROUP BY ALL".to_string())),
232 };
233
234 let having = select.having.map(convert_expr).transpose()?;
235
236 Ok(SelectStatement {
237 with_clauses: vec![], distinct,
239 projection,
240 from,
241 match_clause: None, optional_match_clauses: vec![], where_clause,
244 group_by,
245 having,
246 order_by: vec![],
247 limit: None,
248 offset: None,
249 set_op: None,
250 })
251}
252
253fn convert_select_item(item: sp::SelectItem) -> ParseResult<SelectItem> {
255 match item {
256 sp::SelectItem::UnnamedExpr(expr) => {
257 Ok(SelectItem::Expr { expr: convert_expr(expr)?, alias: None })
258 }
259 sp::SelectItem::ExprWithAlias { expr, alias } => {
260 Ok(SelectItem::Expr { expr: convert_expr(expr)?, alias: Some(convert_ident(alias)) })
261 }
262 sp::SelectItem::Wildcard(_) => Ok(SelectItem::Wildcard),
263 sp::SelectItem::QualifiedWildcard(name, _) => {
264 Ok(SelectItem::QualifiedWildcard(convert_object_name(name)))
265 }
266 }
267}
268
269fn convert_table_with_joins(twj: sp::TableWithJoins) -> ParseResult<TableRef> {
271 let mut result = convert_table_factor(twj.relation)?;
272
273 for join in twj.joins {
274 let right = convert_table_factor(join.relation)?;
275 let join_type = match join.join_operator {
276 sp::JoinOperator::Inner(_) => JoinType::Inner,
277 sp::JoinOperator::LeftOuter(_) => JoinType::LeftOuter,
278 sp::JoinOperator::RightOuter(_) => JoinType::RightOuter,
279 sp::JoinOperator::FullOuter(_) => JoinType::FullOuter,
280 sp::JoinOperator::CrossJoin => JoinType::Cross,
281 sp::JoinOperator::LeftSemi(_) | sp::JoinOperator::RightSemi(_) => {
282 return Err(ParseError::Unsupported("SEMI JOIN".to_string()));
283 }
284 sp::JoinOperator::LeftAnti(_) | sp::JoinOperator::RightAnti(_) => {
285 return Err(ParseError::Unsupported("ANTI JOIN".to_string()));
286 }
287 sp::JoinOperator::AsOf { .. } => {
288 return Err(ParseError::Unsupported("AS OF JOIN".to_string()));
289 }
290 sp::JoinOperator::CrossApply | sp::JoinOperator::OuterApply => {
291 return Err(ParseError::Unsupported("APPLY".to_string()));
292 }
293 };
294
295 let condition = match join.join_operator {
296 sp::JoinOperator::Inner(constraint)
297 | sp::JoinOperator::LeftOuter(constraint)
298 | sp::JoinOperator::RightOuter(constraint)
299 | sp::JoinOperator::FullOuter(constraint) => convert_join_constraint(constraint)?,
300 _ => JoinCondition::None,
302 };
303
304 result = TableRef::Join(Box::new(JoinClause { left: result, right, join_type, condition }));
305 }
306
307 Ok(result)
308}
309
310fn convert_join_constraint(constraint: sp::JoinConstraint) -> ParseResult<JoinCondition> {
312 match constraint {
313 sp::JoinConstraint::On(expr) => Ok(JoinCondition::On(convert_expr(expr)?)),
314 sp::JoinConstraint::Using(idents) => {
315 Ok(JoinCondition::Using(idents.into_iter().map(convert_ident).collect()))
316 }
317 sp::JoinConstraint::Natural => Ok(JoinCondition::Natural),
318 sp::JoinConstraint::None => Ok(JoinCondition::None),
319 }
320}
321
322fn convert_table_factor(factor: sp::TableFactor) -> ParseResult<TableRef> {
324 match factor {
325 sp::TableFactor::Table { name, alias, .. } => Ok(TableRef::Table {
326 name: convert_object_name(name),
327 alias: alias.map(convert_table_alias),
328 }),
329 sp::TableFactor::Derived { subquery, alias, .. } => {
330 let alias =
331 alias.ok_or_else(|| ParseError::MissingClause("alias for subquery".to_string()))?;
332 Ok(TableRef::Subquery {
333 query: Box::new(convert_query(*subquery)?),
334 alias: convert_table_alias(alias),
335 })
336 }
337 sp::TableFactor::TableFunction { expr, alias } => {
338 if let sp::Expr::Function(func) = expr {
340 Ok(TableRef::TableFunction {
341 name: convert_object_name(func.name),
342 args: convert_function_args(func.args)?,
343 alias: alias.map(convert_table_alias),
344 })
345 } else {
346 Err(ParseError::Unsupported("non-function table function".to_string()))
347 }
348 }
349 sp::TableFactor::NestedJoin { table_with_joins, alias } => {
350 let mut result = convert_table_with_joins(*table_with_joins)?;
351 if let Some(alias) = alias {
352 match &mut result {
354 TableRef::Table { alias: ref mut a, .. } => {
355 *a = Some(convert_table_alias(alias))
356 }
357 TableRef::Subquery { alias: ref mut a, .. } => *a = convert_table_alias(alias),
358 _ => {}
359 }
360 }
361 Ok(result)
362 }
363 _ => Err(ParseError::Unsupported("table factor type".to_string())),
364 }
365}
366
367fn convert_function_args(args: sp::FunctionArguments) -> ParseResult<Vec<Expr>> {
369 match args {
370 sp::FunctionArguments::None => Ok(vec![]),
371 sp::FunctionArguments::Subquery(_) => {
372 Err(ParseError::Unsupported("subquery function argument".to_string()))
373 }
374 sp::FunctionArguments::List(arg_list) => arg_list
375 .args
376 .into_iter()
377 .map(|arg| match arg {
378 sp::FunctionArg::Unnamed(expr) => expr,
379 sp::FunctionArg::Named { arg, .. } => arg,
380 })
381 .map(|arg_expr| match arg_expr {
382 sp::FunctionArgExpr::Expr(e) => convert_expr(e),
383 sp::FunctionArgExpr::QualifiedWildcard(name) => {
384 Ok(Expr::QualifiedWildcard(convert_object_name(name)))
385 }
386 sp::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard),
387 })
388 .collect::<ParseResult<Vec<_>>>(),
389 }
390}
391
392fn convert_table_alias(alias: sp::TableAlias) -> TableAlias {
394 TableAlias {
395 name: convert_ident(alias.name),
396 columns: alias.columns.into_iter().map(convert_ident).collect(),
397 }
398}
399
400#[allow(clippy::too_many_lines)]
402fn convert_expr(expr: sp::Expr) -> ParseResult<Expr> {
403 match expr {
404 sp::Expr::Identifier(ident) => {
405 Ok(Expr::Column(QualifiedName::simple(convert_ident(ident))))
406 }
407 sp::Expr::CompoundIdentifier(idents) => {
408 Ok(Expr::Column(QualifiedName::new(idents.into_iter().map(convert_ident).collect())))
409 }
410 sp::Expr::Value(value) => convert_value(value),
411 sp::Expr::BinaryOp { left, op, right } => {
412 let left = convert_expr(*left)?;
413 let right = convert_expr(*right)?;
414 let op = convert_binary_op(&op)?;
415 Ok(Expr::BinaryOp { left: Box::new(left), op, right: Box::new(right) })
416 }
417 sp::Expr::UnaryOp { op, expr } => {
418 let operand = convert_expr(*expr)?;
419 let op = convert_unary_op(op)?;
420 Ok(Expr::UnaryOp { op, operand: Box::new(operand) })
421 }
422 sp::Expr::Nested(inner) => convert_expr(*inner),
423 sp::Expr::Function(func) => convert_function(func),
424 sp::Expr::Cast { expr, data_type, .. } => Ok(Expr::Cast {
425 expr: Box::new(convert_expr(*expr)?),
426 data_type: format_data_type(&data_type),
427 }),
428 sp::Expr::Case { operand, conditions, results, else_result } => {
429 let when_clauses: Vec<(Expr, Expr)> = conditions
430 .into_iter()
431 .zip(results)
432 .map(|(cond, result)| Ok((convert_expr(cond)?, convert_expr(result)?)))
433 .collect::<ParseResult<Vec<_>>>()?;
434
435 Ok(Expr::Case(CaseExpr {
436 operand: operand.map(|e| convert_expr(*e)).transpose()?.map(Box::new),
437 when_clauses,
438 else_result: else_result.map(|e| convert_expr(*e)).transpose()?.map(Box::new),
439 }))
440 }
441 sp::Expr::Subquery(query) => Ok(Expr::Subquery(crate::ast::expr::Subquery {
442 query: Box::new(convert_query(*query)?),
443 })),
444 sp::Expr::Exists { subquery, .. } => Ok(Expr::Exists(crate::ast::expr::Subquery {
445 query: Box::new(convert_query(*subquery)?),
446 })),
447 sp::Expr::InList { expr, list, negated } => Ok(Expr::InList {
448 expr: Box::new(convert_expr(*expr)?),
449 list: list.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?,
450 negated,
451 }),
452 sp::Expr::InSubquery { expr, subquery, negated } => Ok(Expr::InSubquery {
453 expr: Box::new(convert_expr(*expr)?),
454 subquery: crate::ast::expr::Subquery { query: Box::new(convert_query(*subquery)?) },
455 negated,
456 }),
457 sp::Expr::Between { expr, low, high, negated } => Ok(Expr::Between {
458 expr: Box::new(convert_expr(*expr)?),
459 low: Box::new(convert_expr(*low)?),
460 high: Box::new(convert_expr(*high)?),
461 negated,
462 }),
463 sp::Expr::IsNull(expr) => {
464 Ok(Expr::UnaryOp { op: UnaryOp::IsNull, operand: Box::new(convert_expr(*expr)?) })
465 }
466 sp::Expr::IsNotNull(expr) => {
467 Ok(Expr::UnaryOp { op: UnaryOp::IsNotNull, operand: Box::new(convert_expr(*expr)?) })
468 }
469 sp::Expr::Tuple(exprs) => {
470 Ok(Expr::Tuple(exprs.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?))
471 }
472 sp::Expr::Array(arr) => {
473 let sp::Array { elem, .. } = arr;
474 convert_array_expr(elem)
476 }
477 sp::Expr::Subscript { expr, subscript } => match *subscript {
478 sp::Subscript::Index { index } => Ok(Expr::ArrayIndex {
479 array: Box::new(convert_expr(*expr)?),
480 index: Box::new(convert_expr(index)?),
481 }),
482 sp::Subscript::Slice { .. } => {
483 Err(ParseError::Unsupported("subscript slice".to_string()))
484 }
485 },
486 sp::Expr::Like { negated, expr, pattern, escape_char: _, any: _ } => Ok(Expr::BinaryOp {
487 left: Box::new(convert_expr(*expr)?),
488 op: if negated { BinaryOp::NotLike } else { BinaryOp::Like },
489 right: Box::new(convert_expr(*pattern)?),
490 }),
491 sp::Expr::ILike { negated, expr, pattern, escape_char: _, any: _ } => Ok(Expr::BinaryOp {
492 left: Box::new(convert_expr(*expr)?),
493 op: if negated { BinaryOp::NotILike } else { BinaryOp::ILike },
494 right: Box::new(convert_expr(*pattern)?),
495 }),
496 sp::Expr::Named { name, .. } => {
497 Ok(Expr::Parameter(ParameterRef::Named(name.value)))
499 }
500 _ => Err(ParseError::Unsupported(format!("expression type: {expr:?}"))),
502 }
503}
504
505fn convert_array_expr(elements: Vec<sp::Expr>) -> ParseResult<Expr> {
512 let all_numeric =
514 elements.iter().all(|e| matches!(e, sp::Expr::Value(v) if is_numeric_value(v)));
515
516 if all_numeric && !elements.is_empty() {
517 let values: Vec<f32> = elements
519 .iter()
520 .map(|e| {
521 if let sp::Expr::Value(v) = e {
522 value_to_f32(v)
523 } else {
524 Err(ParseError::InvalidLiteral("expected numeric value".to_string()))
525 }
526 })
527 .collect::<ParseResult<Vec<_>>>()?;
528 return Ok(Expr::Literal(Literal::Vector(values)));
529 }
530
531 let all_arrays = elements.iter().all(|e| {
533 matches!(e, sp::Expr::Array(arr) if arr.elem.iter().all(|inner| matches!(inner, sp::Expr::Value(v) if is_numeric_value(v))))
534 });
535
536 if all_arrays && !elements.is_empty() {
537 let vectors: Vec<Vec<f32>> = elements
539 .iter()
540 .map(|e| {
541 if let sp::Expr::Array(arr) = e {
542 arr.elem
543 .iter()
544 .map(|inner| {
545 if let sp::Expr::Value(v) = inner {
546 value_to_f32(v)
547 } else {
548 Err(ParseError::InvalidLiteral(
549 "expected numeric value in nested array".to_string(),
550 ))
551 }
552 })
553 .collect::<ParseResult<Vec<_>>>()
554 } else {
555 Err(ParseError::InvalidLiteral("expected array in multi-vector".to_string()))
556 }
557 })
558 .collect::<ParseResult<Vec<_>>>()?;
559 return Ok(Expr::Literal(Literal::MultiVector(vectors)));
560 }
561
562 let converted = elements.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?;
564 Ok(Expr::Tuple(converted))
565}
566
567fn is_numeric_value(value: &sp::Value) -> bool {
569 matches!(value, sp::Value::Number(_, _))
570}
571
572fn value_to_f32(value: &sp::Value) -> ParseResult<f32> {
574 match value {
575 sp::Value::Number(n, _) => {
576 n.parse::<f32>().map_err(|_| ParseError::InvalidLiteral(format!("invalid f32: {n}")))
577 }
578 _ => Err(ParseError::InvalidLiteral("expected numeric value".to_string())),
579 }
580}
581
582fn convert_value(value: sp::Value) -> ParseResult<Expr> {
584 match value {
585 sp::Value::Null => Ok(Expr::Literal(Literal::Null)),
586 sp::Value::Boolean(b) => Ok(Expr::Literal(Literal::Boolean(b))),
587 sp::Value::Number(n, _) => {
588 if let Ok(i) = n.parse::<i64>() {
590 Ok(Expr::Literal(Literal::Integer(i)))
591 } else if let Ok(f) = n.parse::<f64>() {
592 Ok(Expr::Literal(Literal::Float(f)))
593 } else {
594 Err(ParseError::InvalidLiteral(format!("invalid number: {n}")))
595 }
596 }
597 sp::Value::SingleQuotedString(s) | sp::Value::DoubleQuotedString(s) => {
598 Ok(Expr::Literal(Literal::String(s)))
599 }
600 sp::Value::Placeholder(p) => {
601 if p == "?" {
602 Ok(Expr::Parameter(ParameterRef::Anonymous))
603 } else if let Some(n) = p.strip_prefix('$') {
604 if let Ok(pos) = n.parse::<u32>() {
605 Ok(Expr::Parameter(ParameterRef::Positional(pos)))
606 } else {
607 Ok(Expr::Parameter(ParameterRef::Named(n.to_string())))
608 }
609 } else {
610 Err(ParseError::InvalidLiteral(format!("unknown placeholder: {p}")))
611 }
612 }
613 _ => Err(ParseError::Unsupported(format!("value type: {value:?}"))),
614 }
615}
616
617fn convert_binary_op(op: &sp::BinaryOperator) -> ParseResult<BinaryOp> {
619 match op {
620 sp::BinaryOperator::Plus => Ok(BinaryOp::Add),
621 sp::BinaryOperator::Minus => Ok(BinaryOp::Sub),
622 sp::BinaryOperator::Multiply => Ok(BinaryOp::Mul),
623 sp::BinaryOperator::Divide => Ok(BinaryOp::Div),
624 sp::BinaryOperator::Modulo => Ok(BinaryOp::Mod),
625 sp::BinaryOperator::Eq => Ok(BinaryOp::Eq),
626 sp::BinaryOperator::NotEq => Ok(BinaryOp::NotEq),
627 sp::BinaryOperator::Lt => Ok(BinaryOp::Lt),
628 sp::BinaryOperator::LtEq => Ok(BinaryOp::LtEq),
629 sp::BinaryOperator::Gt => Ok(BinaryOp::Gt),
630 sp::BinaryOperator::GtEq => Ok(BinaryOp::GtEq),
631 sp::BinaryOperator::And => Ok(BinaryOp::And),
632 sp::BinaryOperator::Or => Ok(BinaryOp::Or),
633 sp::BinaryOperator::Arrow => Err(ParseError::Unsupported("-> operator".to_string())),
635 sp::BinaryOperator::LongArrow => Err(ParseError::Unsupported("->> operator".to_string())),
636 sp::BinaryOperator::HashArrow => Err(ParseError::Unsupported("#> operator".to_string())),
637 sp::BinaryOperator::HashLongArrow => {
638 Err(ParseError::Unsupported("#>> operator".to_string()))
639 }
640 _ => Err(ParseError::Unsupported(format!("binary operator: {op:?}"))),
641 }
642}
643
644fn convert_unary_op(op: sp::UnaryOperator) -> ParseResult<UnaryOp> {
646 match op {
647 sp::UnaryOperator::Not => Ok(UnaryOp::Not),
648 sp::UnaryOperator::Plus | sp::UnaryOperator::Minus => Ok(UnaryOp::Neg),
652 _ => Err(ParseError::Unsupported(format!("unary operator: {op:?}"))),
653 }
654}
655
656fn convert_function(func: sp::Function) -> ParseResult<Expr> {
658 let name = convert_object_name(func.name);
659 let args = convert_function_args(func.args)?;
660
661 let filter = func.filter.map(|f| convert_expr(*f)).transpose()?.map(Box::new);
662
663 let over = func.over.map(convert_window_spec).transpose()?;
664
665 Ok(Expr::Function(FunctionCall {
666 name,
667 args,
668 distinct: false, filter,
670 over,
671 }))
672}
673
674fn convert_window_spec(spec: sp::WindowType) -> ParseResult<WindowSpec> {
676 match spec {
677 sp::WindowType::WindowSpec(spec) => {
678 let partition_by =
679 spec.partition_by.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>()?;
680
681 let order_by = spec
682 .order_by
683 .into_iter()
684 .map(convert_order_by_expr)
685 .collect::<ParseResult<Vec<_>>>()?;
686
687 let frame = spec.window_frame.map(convert_window_frame).transpose()?;
688
689 Ok(WindowSpec { partition_by, order_by, frame })
690 }
691 sp::WindowType::NamedWindow(_) => {
692 Err(ParseError::Unsupported("named window reference".to_string()))
693 }
694 }
695}
696
697fn convert_window_frame(frame: sp::WindowFrame) -> ParseResult<WindowFrame> {
699 let units = match frame.units {
700 sp::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
701 sp::WindowFrameUnits::Range => WindowFrameUnits::Range,
702 sp::WindowFrameUnits::Groups => WindowFrameUnits::Groups,
703 };
704
705 let start = convert_window_frame_bound(frame.start_bound)?;
706 let end = frame.end_bound.map(convert_window_frame_bound).transpose()?;
707
708 Ok(WindowFrame { units, start, end })
709}
710
711fn convert_window_frame_bound(bound: sp::WindowFrameBound) -> ParseResult<WindowFrameBound> {
713 match bound {
714 sp::WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
715 sp::WindowFrameBound::Preceding(None) => Ok(WindowFrameBound::UnboundedPreceding),
716 sp::WindowFrameBound::Following(None) => Ok(WindowFrameBound::UnboundedFollowing),
717 sp::WindowFrameBound::Preceding(Some(expr)) => {
718 Ok(WindowFrameBound::Preceding(Box::new(convert_expr(*expr)?)))
719 }
720 sp::WindowFrameBound::Following(Some(expr)) => {
721 Ok(WindowFrameBound::Following(Box::new(convert_expr(*expr)?)))
722 }
723 }
724}
725
726fn convert_order_by_expr(expr: sp::OrderByExpr) -> ParseResult<OrderByExpr> {
728 let asc = expr.asc.unwrap_or(true); Ok(OrderByExpr { expr: Box::new(convert_expr(expr.expr)?), asc, nulls_first: expr.nulls_first })
731}
732
733fn convert_insert(insert: sp::Insert) -> ParseResult<InsertStatement> {
735 let table = convert_object_name(insert.table_name);
737
738 let columns: Vec<Identifier> = insert.columns.into_iter().map(convert_ident).collect();
739
740 let source = match insert.source {
741 Some(source) => match *source.body {
742 sp::SetExpr::Values(values) => {
743 let rows: Vec<Vec<Expr>> = values
744 .rows
745 .into_iter()
746 .map(|row| row.into_iter().map(convert_expr).collect::<ParseResult<Vec<_>>>())
747 .collect::<ParseResult<Vec<_>>>()?;
748 InsertSource::Values(rows)
749 }
750 sp::SetExpr::Select(select) => {
751 let query = convert_select(*select)?;
752 InsertSource::Query(Box::new(query))
753 }
754 _ => return Err(ParseError::Unsupported("INSERT source type".to_string())),
755 },
756 None => InsertSource::DefaultValues,
757 };
758
759 let on_conflict = insert.on.map(convert_on_conflict).transpose()?;
760
761 let returning = insert
762 .returning
763 .map(|items| items.into_iter().map(convert_select_item).collect::<ParseResult<Vec<_>>>())
764 .transpose()?
765 .unwrap_or_default();
766
767 Ok(InsertStatement { table, columns, source, on_conflict, returning })
768}
769
770fn convert_on_conflict(on: sp::OnInsert) -> ParseResult<OnConflict> {
772 match on {
773 sp::OnInsert::DuplicateKeyUpdate(assignments) => {
774 Ok(OnConflict {
775 target: ConflictTarget::Columns(vec![]), action: ConflictAction::DoUpdate {
777 assignments: assignments
778 .into_iter()
779 .map(convert_assignment)
780 .collect::<ParseResult<Vec<_>>>()?,
781 where_clause: None,
782 },
783 })
784 }
785 sp::OnInsert::OnConflict(conflict) => {
786 let target = match conflict.conflict_target {
787 Some(sp::ConflictTarget::Columns(cols)) => {
788 ConflictTarget::Columns(cols.into_iter().map(convert_ident).collect())
789 }
790 Some(sp::ConflictTarget::OnConstraint(name)) => {
791 let converted = convert_object_name(name);
792 let ident = converted.parts.into_iter().next().ok_or_else(|| {
793 ParseError::MissingClause("constraint name in ON CONFLICT".to_string())
794 })?;
795 ConflictTarget::Constraint(ident)
796 }
797 None => ConflictTarget::Columns(vec![]),
798 };
799
800 let action = match conflict.action {
801 sp::OnConflictAction::DoNothing => ConflictAction::DoNothing,
802 sp::OnConflictAction::DoUpdate(update) => ConflictAction::DoUpdate {
803 assignments: update
804 .assignments
805 .into_iter()
806 .map(convert_assignment)
807 .collect::<ParseResult<Vec<_>>>()?,
808 where_clause: update.selection.map(convert_expr).transpose()?,
809 },
810 };
811
812 Ok(OnConflict { target, action })
813 }
814 _ => Err(ParseError::Unsupported("ON INSERT type".to_string())),
815 }
816}
817
818fn convert_assignment(assign: sp::Assignment) -> ParseResult<Assignment> {
820 let column = match assign.target {
822 sp::AssignmentTarget::ColumnName(names) => names
823 .0
824 .into_iter()
825 .next()
826 .map(convert_ident)
827 .ok_or_else(|| ParseError::MissingClause("assignment target".to_string()))?,
828 sp::AssignmentTarget::Tuple(_) => {
829 return Err(ParseError::Unsupported("tuple assignment target".to_string()));
830 }
831 };
832
833 Ok(Assignment { column, value: convert_expr(assign.value)? })
834}
835
836fn convert_update(
838 table: sp::TableWithJoins,
839 assignments: Vec<sp::Assignment>,
840 from: Option<Vec<sp::TableWithJoins>>,
841 selection: Option<sp::Expr>,
842 returning: Option<Vec<sp::SelectItem>>,
843) -> ParseResult<UpdateStatement> {
844 let table_ref = convert_table_with_joins(table)?;
845 let TableRef::Table { name: table_name, alias } = table_ref else {
846 return Err(ParseError::Unsupported("complex UPDATE target".to_string()));
847 };
848
849 let assignments =
850 assignments.into_iter().map(convert_assignment).collect::<ParseResult<Vec<_>>>()?;
851
852 let from_clause = from
853 .map(|f| f.into_iter().map(convert_table_with_joins).collect::<ParseResult<Vec<_>>>())
854 .transpose()?
855 .unwrap_or_default();
856
857 let where_clause = selection.map(convert_expr).transpose()?;
858
859 let returning = returning
860 .map(|items| items.into_iter().map(convert_select_item).collect::<ParseResult<Vec<_>>>())
861 .transpose()?
862 .unwrap_or_default();
863
864 Ok(UpdateStatement {
865 table: table_name,
866 alias,
867 assignments,
868 from: from_clause,
869 match_clause: None,
870 where_clause,
871 returning,
872 })
873}
874
875fn convert_delete(delete: sp::Delete) -> ParseResult<DeleteStatement> {
877 let from_table = match delete.from {
878 sp::FromTable::WithFromKeyword(tables) => tables
879 .into_iter()
880 .next()
881 .ok_or_else(|| ParseError::MissingClause("FROM".to_string()))?,
882 sp::FromTable::WithoutKeyword(tables) => tables
883 .into_iter()
884 .next()
885 .ok_or_else(|| ParseError::MissingClause("table".to_string()))?,
886 };
887
888 let table_ref = convert_table_with_joins(from_table)?;
889 let TableRef::Table { name: table_name, alias } = table_ref else {
890 return Err(ParseError::Unsupported("complex DELETE target".to_string()));
891 };
892
893 let using = delete
894 .using
895 .map(|u| u.into_iter().map(convert_table_with_joins).collect::<ParseResult<Vec<_>>>())
896 .transpose()?
897 .unwrap_or_default();
898
899 let where_clause = delete.selection.map(convert_expr).transpose()?;
900
901 let returning = delete
902 .returning
903 .map(|items| items.into_iter().map(convert_select_item).collect::<ParseResult<Vec<_>>>())
904 .transpose()?
905 .unwrap_or_default();
906
907 Ok(DeleteStatement {
908 table: table_name,
909 alias,
910 using,
911 match_clause: None,
912 where_clause,
913 returning,
914 })
915}
916
917fn convert_create_table(create: sp::CreateTable) -> ParseResult<CreateTableStatement> {
919 let columns =
920 create.columns.into_iter().map(convert_column_def).collect::<ParseResult<Vec<_>>>()?;
921
922 let constraints = create
923 .constraints
924 .into_iter()
925 .map(convert_table_constraint)
926 .collect::<ParseResult<Vec<_>>>()?;
927
928 Ok(CreateTableStatement {
929 if_not_exists: create.if_not_exists,
930 name: convert_object_name(create.name),
931 columns,
932 constraints,
933 })
934}
935
936fn convert_column_def(col: sp::ColumnDef) -> ParseResult<ColumnDef> {
938 let constraints =
939 col.options.into_iter().filter_map(|opt| convert_column_option(opt.option).ok()).collect();
940
941 Ok(ColumnDef {
942 name: convert_ident(col.name),
943 data_type: convert_data_type(col.data_type)?,
944 constraints,
945 })
946}
947
948fn convert_column_option(opt: sp::ColumnOption) -> ParseResult<ColumnConstraint> {
950 match opt {
951 sp::ColumnOption::Null => Ok(ColumnConstraint::Null),
952 sp::ColumnOption::NotNull => Ok(ColumnConstraint::NotNull),
953 sp::ColumnOption::Unique { is_primary, .. } => {
954 if is_primary {
955 Ok(ColumnConstraint::PrimaryKey)
956 } else {
957 Ok(ColumnConstraint::Unique)
958 }
959 }
960 sp::ColumnOption::ForeignKey { foreign_table, referred_columns, .. } => {
961 Ok(ColumnConstraint::References {
962 table: convert_object_name(foreign_table),
963 column: referred_columns.into_iter().next().map(convert_ident),
964 })
965 }
966 sp::ColumnOption::Check(expr) => Ok(ColumnConstraint::Check(convert_expr(expr)?)),
967 sp::ColumnOption::Default(expr) => Ok(ColumnConstraint::Default(convert_expr(expr)?)),
968 _ => Err(ParseError::Unsupported("column option".to_string())),
969 }
970}
971
972fn convert_table_constraint(constraint: sp::TableConstraint) -> ParseResult<TableConstraint> {
974 match constraint {
975 sp::TableConstraint::PrimaryKey { columns, name, .. } => Ok(TableConstraint::PrimaryKey {
976 name: name.map(convert_ident),
977 columns: columns.into_iter().map(convert_ident).collect(),
978 }),
979 sp::TableConstraint::Unique { columns, name, .. } => Ok(TableConstraint::Unique {
980 name: name.map(convert_ident),
981 columns: columns.into_iter().map(convert_ident).collect(),
982 }),
983 sp::TableConstraint::ForeignKey {
984 columns, foreign_table, referred_columns, name, ..
985 } => Ok(TableConstraint::ForeignKey {
986 name: name.map(convert_ident),
987 columns: columns.into_iter().map(convert_ident).collect(),
988 references_table: convert_object_name(foreign_table),
989 references_columns: referred_columns.into_iter().map(convert_ident).collect(),
990 }),
991 sp::TableConstraint::Check { name, expr } => {
992 Ok(TableConstraint::Check { name: name.map(convert_ident), expr: convert_expr(*expr)? })
993 }
994 _ => Err(ParseError::Unsupported("table constraint".to_string())),
995 }
996}
997
998fn convert_create_index(create: sp::CreateIndex) -> ParseResult<CreateIndexStatement> {
1000 let name = create
1001 .name
1002 .map(convert_object_name)
1003 .and_then(|n| n.parts.into_iter().next())
1004 .ok_or_else(|| ParseError::MissingClause("index name".to_string()))?;
1005
1006 let table = convert_object_name(create.table_name);
1007
1008 let columns = create
1009 .columns
1010 .into_iter()
1011 .map(|col| {
1012 Ok(IndexColumn {
1013 expr: convert_expr(col.expr)?,
1014 asc: col.asc,
1015 nulls_first: col.nulls_first,
1016 opclass: None,
1017 })
1018 })
1019 .collect::<ParseResult<Vec<_>>>()?;
1020
1021 Ok(CreateIndexStatement {
1022 unique: create.unique,
1023 if_not_exists: create.if_not_exists,
1024 name,
1025 table,
1026 columns,
1027 using: create.using.map(convert_ident).map(|i| i.name),
1028 with: vec![],
1029 where_clause: create.predicate.map(convert_expr).transpose()?,
1030 })
1031}
1032
1033#[allow(clippy::cast_possible_truncation)]
1035fn convert_data_type(dt: sp::DataType) -> ParseResult<DataType> {
1036 match dt {
1037 sp::DataType::Boolean | sp::DataType::Bool => Ok(DataType::Boolean),
1038 sp::DataType::SmallInt(_) | sp::DataType::Int2(_) => Ok(DataType::SmallInt),
1039 sp::DataType::Int(_) | sp::DataType::Integer(_) | sp::DataType::Int4(_) => {
1040 Ok(DataType::Integer)
1041 }
1042 sp::DataType::BigInt(_) | sp::DataType::Int8(_) => Ok(DataType::BigInt),
1043 sp::DataType::Real | sp::DataType::Float4 => Ok(DataType::Real),
1044 sp::DataType::DoublePrecision | sp::DataType::Double | sp::DataType::Float8 => {
1045 Ok(DataType::DoublePrecision)
1046 }
1047 sp::DataType::Numeric(info) | sp::DataType::Decimal(info) => {
1048 let (precision, scale) = match info {
1049 sp::ExactNumberInfo::None => (None, None),
1050 sp::ExactNumberInfo::Precision(p) => (Some(p as u32), None),
1051 sp::ExactNumberInfo::PrecisionAndScale(p, s) => (Some(p as u32), Some(s as u32)),
1052 };
1053 Ok(DataType::Numeric { precision, scale })
1054 }
1055 sp::DataType::Varchar(len) | sp::DataType::CharVarying(len) => {
1056 let len_val = len.and_then(|l| match l {
1057 sp::CharacterLength::IntegerLength { length, .. } => Some(length as u32),
1058 sp::CharacterLength::Max => None,
1059 });
1060 Ok(DataType::Varchar(len_val))
1061 }
1062 sp::DataType::Text => Ok(DataType::Text),
1063 sp::DataType::Bytea => Ok(DataType::Bytea),
1064 sp::DataType::Timestamp(_, _) => Ok(DataType::Timestamp),
1065 sp::DataType::Date => Ok(DataType::Date),
1066 sp::DataType::Time(_, _) => Ok(DataType::Time),
1067 sp::DataType::Interval => Ok(DataType::Interval),
1068 sp::DataType::JSON => Ok(DataType::Json),
1069 sp::DataType::Uuid => Ok(DataType::Uuid),
1070 sp::DataType::Array(elem) => match elem {
1071 sp::ArrayElemTypeDef::AngleBracket(inner)
1072 | sp::ArrayElemTypeDef::SquareBracket(inner, _) => {
1073 Ok(DataType::Array(Box::new(convert_data_type(*inner)?)))
1074 }
1075 sp::ArrayElemTypeDef::None => Err(ParseError::Unsupported("untyped array".to_string())),
1076 sp::ArrayElemTypeDef::Parenthesis(_) => {
1077 Err(ParseError::Unsupported("parenthesized array type".to_string()))
1078 }
1079 },
1080 sp::DataType::Custom(name, _) => {
1081 let name_str = name.0.iter().map(|p| p.value.clone()).collect::<Vec<_>>().join(".");
1082
1083 if name_str.eq_ignore_ascii_case("vector") {
1085 Ok(DataType::Vector(None))
1086 } else {
1087 Ok(DataType::Custom(name_str))
1088 }
1089 }
1090 _ => Err(ParseError::Unsupported(format!("data type: {dt:?}"))),
1091 }
1092}
1093
1094fn format_data_type(dt: &sp::DataType) -> String {
1096 format!("{dt}")
1097}
1098
1099fn convert_object_name(name: sp::ObjectName) -> QualifiedName {
1101 QualifiedName::new(name.0.into_iter().map(convert_ident).collect())
1102}
1103
1104fn convert_ident(ident: sp::Ident) -> Identifier {
1106 Identifier { name: ident.value, quote_style: ident.quote_style }
1107}
1108
#[cfg(test)]
mod tests {
    use super::*;

    /// `SELECT *` yields a single wildcard projection item.
    #[test]
    fn parse_simple_select() {
        let stmt = parse_single_statement("SELECT * FROM users").unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.projection.len(), 1);
                assert!(matches!(select.projection[0], SelectItem::Wildcard));
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// A `WHERE` clause is preserved alongside the projected columns.
    #[test]
    fn parse_select_with_where() {
        let stmt = parse_single_statement("SELECT id, name FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.projection.len(), 2);
                assert!(select.where_clause.is_some());
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// `INSERT ... VALUES` keeps the column list and the row of values.
    #[test]
    fn parse_insert() {
        let stmt =
            parse_single_statement("INSERT INTO users (name, age) VALUES ('Alice', 30)").unwrap();
        match stmt {
            Statement::Insert(insert) => {
                assert_eq!(insert.columns.len(), 2);
                match &insert.source {
                    InsertSource::Values(rows) => {
                        assert_eq!(rows.len(), 1);
                        assert_eq!(rows[0].len(), 2);
                    }
                    _ => panic!("expected VALUES"),
                }
            }
            _ => panic!("expected INSERT"),
        }
    }

    /// `UPDATE` keeps its assignments and filter.
    #[test]
    fn parse_update() {
        let stmt = parse_single_statement("UPDATE users SET name = 'Bob' WHERE id = 1").unwrap();
        match stmt {
            Statement::Update(update) => {
                assert_eq!(update.assignments.len(), 1);
                assert!(update.where_clause.is_some());
            }
            _ => panic!("expected UPDATE"),
        }
    }

    /// `DELETE` keeps its filter.
    #[test]
    fn parse_delete() {
        let stmt = parse_single_statement("DELETE FROM users WHERE id = 1").unwrap();
        match stmt {
            Statement::Delete(delete) => {
                assert!(delete.where_clause.is_some());
            }
            _ => panic!("expected DELETE"),
        }
    }

    /// `CREATE TABLE` produces one column definition per declared column.
    #[test]
    fn parse_create_table() {
        let stmt = parse_single_statement(
            "CREATE TABLE users (id BIGINT PRIMARY KEY, name VARCHAR(100) NOT NULL)",
        )
        .unwrap();
        match stmt {
            Statement::CreateTable(create) => {
                assert_eq!(create.columns.len(), 2);
            }
            _ => panic!("expected CREATE TABLE"),
        }
    }

    /// An `INNER JOIN` is represented as a single join entry in `from`.
    #[test]
    fn parse_join() {
        let stmt = parse_single_statement(
            "SELECT u.name, o.total FROM users u INNER JOIN orders o ON u.id = o.user_id",
        )
        .unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.from.len(), 1);
                match &select.from[0] {
                    TableRef::Join(join) => {
                        assert_eq!(join.join_type, JoinType::Inner);
                    }
                    _ => panic!("expected JOIN"),
                }
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// Empty input is reported as `ParseError::EmptyQuery`.
    #[test]
    fn parse_empty_query() {
        let result = parse_sql("");
        assert!(matches!(result, Err(ParseError::EmptyQuery)));
    }

    /// `$1` is parsed as a positional parameter on the right-hand side of the comparison.
    #[test]
    fn parse_parameter() {
        let stmt = parse_single_statement("SELECT * FROM users WHERE id = $1").unwrap();
        match stmt {
            // Match exhaustively so a missing or differently shaped WHERE clause
            // fails the test instead of letting it pass without checking anything.
            Statement::Select(select) => match select.where_clause {
                Some(Expr::BinaryOp { right, .. }) => match *right {
                    Expr::Parameter(ParameterRef::Positional(1)) => {}
                    _ => panic!("expected positional parameter"),
                },
                _ => panic!("expected binary comparison in WHERE clause"),
            },
            _ => panic!("expected SELECT"),
        }
    }

    /// A flat bracketed list of numbers is parsed as a vector literal.
    #[test]
    fn parse_vector_literal() {
        let stmt = parse_single_statement("SELECT [0.1, 0.2, 0.3]").unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.projection.len(), 1);
                if let SelectItem::Expr { expr, .. } = &select.projection[0] {
                    match expr {
                        Expr::Literal(Literal::Vector(v)) => {
                            assert_eq!(v.len(), 3);
                            assert!((v[0] - 0.1).abs() < 0.001);
                            assert!((v[1] - 0.2).abs() < 0.001);
                            assert!((v[2] - 0.3).abs() < 0.001);
                        }
                        _ => panic!("expected Vector literal, got {:?}", expr),
                    }
                } else {
                    panic!("expected expression in projection");
                }
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// A nested bracketed list is parsed as a multi-vector literal.
    #[test]
    fn parse_multi_vector_literal() {
        let stmt = parse_single_statement("SELECT [[0.1, 0.2], [0.3, 0.4]]").unwrap();
        match stmt {
            Statement::Select(select) => {
                assert_eq!(select.projection.len(), 1);
                if let SelectItem::Expr { expr, .. } = &select.projection[0] {
                    match expr {
                        Expr::Literal(Literal::MultiVector(v)) => {
                            assert_eq!(v.len(), 2);
                            assert_eq!(v[0].len(), 2);
                            assert_eq!(v[1].len(), 2);
                            assert!((v[0][0] - 0.1).abs() < 0.001);
                            assert!((v[0][1] - 0.2).abs() < 0.001);
                            assert!((v[1][0] - 0.3).abs() < 0.001);
                            assert!((v[1][1] - 0.4).abs() < 0.001);
                        }
                        _ => panic!("expected MultiVector literal, got {:?}", expr),
                    }
                } else {
                    panic!("expected expression in projection");
                }
            }
            _ => panic!("expected SELECT"),
        }
    }

    /// Ordering by `<->` against a multi-vector literal is currently expected to fail.
    ///
    /// NOTE(review): the reason for the failure is not visible here — presumably
    /// the `<->` operator is not accepted; confirm before relying on this.
    #[test]
    fn parse_multi_vector_in_order_by() {
        let stmt = parse_single_statement(
            "SELECT * FROM docs ORDER BY embedding <-> [[0.1, 0.2], [0.3, 0.4]]",
        );
        assert!(stmt.is_err());
    }

    /// A multi-vector literal is accepted as a value in `INSERT ... VALUES`.
    #[test]
    fn parse_insert_with_multi_vector() {
        let stmt = parse_single_statement(
            "INSERT INTO docs (id, embedding) VALUES (1, [[0.1, 0.2], [0.3, 0.4]])",
        )
        .unwrap();
        match stmt {
            Statement::Insert(insert) => {
                assert_eq!(insert.columns.len(), 2);
                match &insert.source {
                    InsertSource::Values(rows) => {
                        assert_eq!(rows.len(), 1);
                        assert_eq!(rows[0].len(), 2);
                        match &rows[0][1] {
                            Expr::Literal(Literal::MultiVector(v)) => {
                                assert_eq!(v.len(), 2);
                                assert_eq!(v[0].len(), 2);
                            }
                            _ => panic!("expected MultiVector literal in insert"),
                        }
                    }
                    _ => panic!("expected VALUES"),
                }
            }
            _ => panic!("expected INSERT"),
        }
    }
}