1use super::ast::*;
50use super::compatibility::SqlDialect;
51use super::error::{SqlError, SqlResult};
52use super::parser::Parser;
53use std::collections::HashMap;
54use sochdb_core::SochValue;
55
56#[derive(Debug, Clone)]
58pub enum ExecutionResult {
59 Rows {
61 columns: Vec<String>,
62 rows: Vec<HashMap<String, SochValue>>,
63 },
64 RowsAffected(usize),
66 Ok,
68 TransactionOk,
70}
71
72impl ExecutionResult {
73 pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
75 match self {
76 ExecutionResult::Rows { rows, .. } => Some(rows),
77 _ => None,
78 }
79 }
80
81 pub fn columns(&self) -> Option<&Vec<String>> {
83 match self {
84 ExecutionResult::Rows { columns, .. } => Some(columns),
85 _ => None,
86 }
87 }
88
89 pub fn rows_affected(&self) -> usize {
91 match self {
92 ExecutionResult::RowsAffected(n) => *n,
93 ExecutionResult::Rows { rows, .. } => rows.len(),
94 _ => 0,
95 }
96 }
97}
98
99pub trait SqlConnection {
104 fn select(
106 &self,
107 table: &str,
108 columns: &[String],
109 where_clause: Option<&Expr>,
110 order_by: &[OrderByItem],
111 limit: Option<usize>,
112 offset: Option<usize>,
113 params: &[SochValue],
114 ) -> SqlResult<ExecutionResult>;
115
116 fn insert(
118 &mut self,
119 table: &str,
120 columns: Option<&[String]>,
121 rows: &[Vec<Expr>],
122 on_conflict: Option<&OnConflict>,
123 params: &[SochValue],
124 ) -> SqlResult<ExecutionResult>;
125
126 fn update(
128 &mut self,
129 table: &str,
130 assignments: &[Assignment],
131 where_clause: Option<&Expr>,
132 params: &[SochValue],
133 ) -> SqlResult<ExecutionResult>;
134
135 fn delete(
137 &mut self,
138 table: &str,
139 where_clause: Option<&Expr>,
140 params: &[SochValue],
141 ) -> SqlResult<ExecutionResult>;
142
143 fn create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult>;
145
146 fn drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult>;
148
149 fn create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult>;
151
152 fn drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult>;
154
155 fn begin(&mut self, stmt: &BeginStmt) -> SqlResult<ExecutionResult>;
157
158 fn commit(&mut self) -> SqlResult<ExecutionResult>;
160
161 fn rollback(&mut self, savepoint: Option<&str>) -> SqlResult<ExecutionResult>;
163
164 fn table_exists(&self, table: &str) -> SqlResult<bool>;
166
167 fn index_exists(&self, index: &str) -> SqlResult<bool>;
169}
170
171pub struct SqlBridge<C: SqlConnection> {
173 conn: C,
174}
175
176impl<C: SqlConnection> SqlBridge<C> {
177 pub fn new(conn: C) -> Self {
179 Self { conn }
180 }
181
182 pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
184 self.execute_with_params(sql, &[])
185 }
186
187 pub fn execute_with_params(
189 &mut self,
190 sql: &str,
191 params: &[SochValue],
192 ) -> SqlResult<ExecutionResult> {
193 let _dialect = SqlDialect::detect(sql);
195
196 let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
198
199 let max_placeholder = self.find_max_placeholder(&stmt);
201 if max_placeholder as usize > params.len() {
202 return Err(SqlError::InvalidArgument(format!(
203 "Query contains {} placeholders but only {} parameters provided",
204 max_placeholder,
205 params.len()
206 )));
207 }
208
209 self.execute_statement(&stmt, params)
211 }
212
213 pub fn execute_statement(
215 &mut self,
216 stmt: &Statement,
217 params: &[SochValue],
218 ) -> SqlResult<ExecutionResult> {
219 match stmt {
220 Statement::Select(select) => self.execute_select(select, params),
221 Statement::Insert(insert) => self.execute_insert(insert, params),
222 Statement::Update(update) => self.execute_update(update, params),
223 Statement::Delete(delete) => self.execute_delete(delete, params),
224 Statement::CreateTable(create) => self.execute_create_table(create),
225 Statement::DropTable(drop) => self.execute_drop_table(drop),
226 Statement::CreateIndex(create) => self.execute_create_index(create),
227 Statement::DropIndex(drop) => self.execute_drop_index(drop),
228 Statement::AlterTable(_alter) => Err(SqlError::NotImplemented(
229 "ALTER TABLE not yet implemented".into(),
230 )),
231 Statement::Begin(begin) => self.conn.begin(begin),
232 Statement::Commit => self.conn.commit(),
233 Statement::Rollback(savepoint) => self.conn.rollback(savepoint.as_deref()),
234 Statement::Savepoint(_name) => Err(SqlError::NotImplemented(
235 "SAVEPOINT not yet implemented".into(),
236 )),
237 Statement::Release(_name) => Err(SqlError::NotImplemented(
238 "RELEASE SAVEPOINT not yet implemented".into(),
239 )),
240 Statement::Explain(_stmt) => Err(SqlError::NotImplemented(
241 "EXPLAIN not yet implemented".into(),
242 )),
243 }
244 }
245
246 fn execute_select(
247 &self,
248 select: &SelectStmt,
249 params: &[SochValue],
250 ) -> SqlResult<ExecutionResult> {
251 let from = select
253 .from
254 .as_ref()
255 .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
256
257 if from.tables.len() != 1 {
258 return Err(SqlError::NotImplemented(
259 "Multi-table queries not yet supported".into(),
260 ));
261 }
262
263 let table_name = match &from.tables[0] {
264 TableRef::Table { name, .. } => name.name().to_string(),
265 TableRef::Subquery { .. } => {
266 return Err(SqlError::NotImplemented(
267 "Subqueries not yet supported".into(),
268 ));
269 }
270 TableRef::Join { .. } => {
271 return Err(SqlError::NotImplemented(
272 "JOINs not yet supported".into(),
273 ));
274 }
275 TableRef::Function { .. } => {
276 return Err(SqlError::NotImplemented(
277 "Table functions not yet supported".into(),
278 ));
279 }
280 };
281
282 let columns = self.extract_select_columns(&select.columns)?;
284
285 let limit = self.extract_limit(&select.limit)?;
287 let offset = self.extract_limit(&select.offset)?;
288
289 self.conn.select(
290 &table_name,
291 &columns,
292 select.where_clause.as_ref(),
293 &select.order_by,
294 limit,
295 offset,
296 params,
297 )
298 }
299
300 fn execute_insert(
301 &mut self,
302 insert: &InsertStmt,
303 params: &[SochValue],
304 ) -> SqlResult<ExecutionResult> {
305 let table_name = insert.table.name();
306
307 let rows = match &insert.source {
308 InsertSource::Values(values) => values,
309 InsertSource::Query(_) => {
310 return Err(SqlError::NotImplemented(
311 "INSERT ... SELECT not yet supported".into(),
312 ));
313 }
314 InsertSource::Default => {
315 return Err(SqlError::NotImplemented(
316 "INSERT DEFAULT VALUES not yet supported".into(),
317 ));
318 }
319 };
320
321 self.conn.insert(
322 table_name,
323 insert.columns.as_deref(),
324 rows,
325 insert.on_conflict.as_ref(),
326 params,
327 )
328 }
329
330 fn execute_update(
331 &mut self,
332 update: &UpdateStmt,
333 params: &[SochValue],
334 ) -> SqlResult<ExecutionResult> {
335 let table_name = update.table.name();
336
337 self.conn.update(
338 table_name,
339 &update.assignments,
340 update.where_clause.as_ref(),
341 params,
342 )
343 }
344
345 fn execute_delete(
346 &mut self,
347 delete: &DeleteStmt,
348 params: &[SochValue],
349 ) -> SqlResult<ExecutionResult> {
350 let table_name = delete.table.name();
351
352 self.conn.delete(
353 table_name,
354 delete.where_clause.as_ref(),
355 params,
356 )
357 }
358
359 fn execute_create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult> {
360 if stmt.if_not_exists {
362 let table_name = stmt.name.name();
363 if self.conn.table_exists(table_name)? {
364 return Ok(ExecutionResult::Ok);
365 }
366 }
367
368 self.conn.create_table(stmt)
369 }
370
371 fn execute_drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult> {
372 if stmt.if_exists {
374 for name in &stmt.names {
375 if !self.conn.table_exists(name.name())? {
376 return Ok(ExecutionResult::Ok);
377 }
378 }
379 }
380
381 self.conn.drop_table(stmt)
382 }
383
384 fn execute_create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
385 if stmt.if_not_exists {
387 if self.conn.index_exists(&stmt.name)? {
388 return Ok(ExecutionResult::Ok);
389 }
390 }
391
392 self.conn.create_index(stmt)
393 }
394
395 fn execute_drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult> {
396 if stmt.if_exists {
398 if !self.conn.index_exists(&stmt.name)? {
399 return Ok(ExecutionResult::Ok);
400 }
401 }
402
403 self.conn.drop_index(stmt)
404 }
405
406 fn extract_select_columns(&self, items: &[SelectItem]) -> SqlResult<Vec<String>> {
408 let mut columns = Vec::new();
409
410 for item in items {
411 match item {
412 SelectItem::Wildcard => columns.push("*".to_string()),
413 SelectItem::QualifiedWildcard(table) => columns.push(format!("{}.*", table)),
414 SelectItem::Expr { expr, alias } => {
415 let name = alias.clone().unwrap_or_else(|| match expr {
416 Expr::Column(col) => col.column.clone(),
417 Expr::Function(func) => format!("{}()", func.name.name()),
418 _ => "?column?".to_string(),
419 });
420 columns.push(name);
421 }
422 }
423 }
424
425 Ok(columns)
426 }
427
428 fn extract_limit(&self, expr: &Option<Expr>) -> SqlResult<Option<usize>> {
430 match expr {
431 Some(Expr::Literal(Literal::Integer(n))) => Ok(Some(*n as usize)),
432 Some(_) => Err(SqlError::InvalidArgument(
433 "LIMIT/OFFSET must be an integer literal".into(),
434 )),
435 None => Ok(None),
436 }
437 }
438
439 fn find_max_placeholder(&self, stmt: &Statement) -> u32 {
441 let mut visitor = PlaceholderVisitor::new();
442 visitor.visit_statement(stmt);
443 visitor.max_placeholder
444 }
445}
446
447struct PlaceholderVisitor {
449 max_placeholder: u32,
450}
451
452impl PlaceholderVisitor {
453 fn new() -> Self {
454 Self { max_placeholder: 0 }
455 }
456
457 fn visit_statement(&mut self, stmt: &Statement) {
458 match stmt {
459 Statement::Select(s) => self.visit_select(s),
460 Statement::Insert(i) => self.visit_insert(i),
461 Statement::Update(u) => self.visit_update(u),
462 Statement::Delete(d) => self.visit_delete(d),
463 _ => {}
464 }
465 }
466
467 fn visit_select(&mut self, select: &SelectStmt) {
468 for item in &select.columns {
469 if let SelectItem::Expr { expr, .. } = item {
470 self.visit_expr(expr);
471 }
472 }
473 if let Some(where_clause) = &select.where_clause {
474 self.visit_expr(where_clause);
475 }
476 if let Some(having) = &select.having {
477 self.visit_expr(having);
478 }
479 for order in &select.order_by {
480 self.visit_expr(&order.expr);
481 }
482 if let Some(limit) = &select.limit {
483 self.visit_expr(limit);
484 }
485 if let Some(offset) = &select.offset {
486 self.visit_expr(offset);
487 }
488 }
489
490 fn visit_insert(&mut self, insert: &InsertStmt) {
491 if let InsertSource::Values(rows) = &insert.source {
492 for row in rows {
493 for expr in row {
494 self.visit_expr(expr);
495 }
496 }
497 }
498 }
499
500 fn visit_update(&mut self, update: &UpdateStmt) {
501 for assign in &update.assignments {
502 self.visit_expr(&assign.value);
503 }
504 if let Some(where_clause) = &update.where_clause {
505 self.visit_expr(where_clause);
506 }
507 }
508
509 fn visit_delete(&mut self, delete: &DeleteStmt) {
510 if let Some(where_clause) = &delete.where_clause {
511 self.visit_expr(where_clause);
512 }
513 }
514
515 fn visit_expr(&mut self, expr: &Expr) {
516 match expr {
517 Expr::Placeholder(n) => {
518 self.max_placeholder = self.max_placeholder.max(*n);
519 }
520 Expr::BinaryOp { left, right, .. } => {
521 self.visit_expr(left);
522 self.visit_expr(right);
523 }
524 Expr::UnaryOp { expr, .. } => {
525 self.visit_expr(expr);
526 }
527 Expr::Function(func) => {
528 for arg in &func.args {
529 self.visit_expr(arg);
530 }
531 }
532 Expr::Case { operand, conditions, else_result } => {
533 if let Some(op) = operand {
534 self.visit_expr(op);
535 }
536 for (when, then) in conditions {
537 self.visit_expr(when);
538 self.visit_expr(then);
539 }
540 if let Some(else_expr) = else_result {
541 self.visit_expr(else_expr);
542 }
543 }
544 Expr::InList { expr, list, .. } => {
545 self.visit_expr(expr);
546 for item in list {
547 self.visit_expr(item);
548 }
549 }
550 Expr::Between { expr, low, high, .. } => {
551 self.visit_expr(expr);
552 self.visit_expr(low);
553 self.visit_expr(high);
554 }
555 Expr::Cast { expr, .. } => {
556 self.visit_expr(expr);
557 }
558 _ => {}
559 }
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and returns the highest placeholder index the visitor finds.
    fn max_placeholder_in(sql: &str) -> u32 {
        let stmt = Parser::parse(sql).unwrap();
        let mut visitor = PlaceholderVisitor::new();
        visitor.visit_statement(&stmt);
        visitor.max_placeholder
    }

    #[test]
    fn test_placeholder_visitor() {
        assert_eq!(
            max_placeholder_in("SELECT * FROM users WHERE id = $1 AND name = $2"),
            2
        );
    }

    #[test]
    fn test_question_mark_placeholders() {
        assert_eq!(
            max_placeholder_in("SELECT * FROM users WHERE id = ? AND name = ?"),
            2
        );
    }

    #[test]
    fn test_dialect_detection() {
        let cases = [
            ("SELECT * FROM users", SqlDialect::Standard),
            ("INSERT IGNORE INTO users VALUES (1)", SqlDialect::MySQL),
            ("INSERT OR IGNORE INTO users VALUES (1)", SqlDialect::SQLite),
        ];
        for (sql, expected) in cases.iter() {
            assert_eq!(&SqlDialect::detect(sql), expected, "dialect for {:?}", sql);
        }
    }
}