1use super::ast::*;
50use super::compatibility::SqlDialect;
51use super::error::{SqlError, SqlResult};
52use super::parser::Parser;
53use std::collections::HashMap;
54use sochdb_core::SochValue;
55
/// Outcome of executing one statement through [`SqlBridge`].
#[derive(Debug, Clone)]
pub enum ExecutionResult {
    /// A result set.
    Rows {
        /// Output column names, in result order.
        columns: Vec<String>,
        /// One map per row, keyed by output column name.
        rows: Vec<HashMap<String, SochValue>>,
    },
    /// Count of affected rows, as reported by the connection.
    RowsAffected(usize),
    /// Success with nothing to report; returned by the bridge for DEFINE/REMOVE
    /// statements and for DDL skipped by IF [NOT] EXISTS.
    Ok,
    /// Acknowledgement of a transaction-control statement (returned by the
    /// test mock for BEGIN/COMMIT/ROLLBACK).
    TransactionOk,
}
71
72impl ExecutionResult {
73 pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
75 match self {
76 ExecutionResult::Rows { rows, .. } => Some(rows),
77 _ => None,
78 }
79 }
80
81 pub fn columns(&self) -> Option<&Vec<String>> {
83 match self {
84 ExecutionResult::Rows { columns, .. } => Some(columns),
85 _ => None,
86 }
87 }
88
89 pub fn rows_affected(&self) -> usize {
91 match self {
92 ExecutionResult::RowsAffected(n) => *n,
93 ExecutionResult::Rows { rows, .. } => rows.len(),
94 _ => 0,
95 }
96 }
97}
98
/// Storage backend that [`SqlBridge`] dispatches parsed statements to.
///
/// The bridge itself handles parsing, placeholder-count validation, table
/// permission checks, IF [NOT] EXISTS pre-checks and JOIN evaluation; the
/// remaining work is delegated through these methods.
pub trait SqlConnection {
    /// Executes a single-table SELECT (queries with a JOIN never reach this).
    ///
    /// `columns` are the projected names built by the bridge: `"*"` for a
    /// wildcard, `"t.*"` for a qualified wildcard, otherwise the alias or the
    /// bare column name. `params` are the bound placeholder values.
    fn select(
        &self,
        table: &str,
        columns: &[String],
        where_clause: Option<&Expr>,
        order_by: &[OrderByItem],
        limit: Option<usize>,
        offset: Option<usize>,
        params: &[SochValue],
    ) -> SqlResult<ExecutionResult>;

    /// Inserts literal `VALUES` rows; `columns` is `None` when the statement
    /// lists no column names.
    fn insert(
        &mut self,
        table: &str,
        columns: Option<&[String]>,
        rows: &[Vec<Expr>],
        on_conflict: Option<&OnConflict>,
        params: &[SochValue],
    ) -> SqlResult<ExecutionResult>;

    /// Applies `assignments` to rows of `table` selected by `where_clause`.
    fn update(
        &mut self,
        table: &str,
        assignments: &[Assignment],
        where_clause: Option<&Expr>,
        params: &[SochValue],
    ) -> SqlResult<ExecutionResult>;

    /// Deletes rows of `table` selected by `where_clause`.
    fn delete(
        &mut self,
        table: &str,
        where_clause: Option<&Expr>,
        params: &[SochValue],
    ) -> SqlResult<ExecutionResult>;

    /// Creates a table; called after the bridge's IF NOT EXISTS check.
    fn create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult>;

    /// Drops the tables named in `stmt`; called after the bridge's IF EXISTS check.
    fn drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult>;

    /// Creates an index; called after the bridge's IF NOT EXISTS check.
    fn create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult>;

    /// Drops an index; called after the bridge's IF EXISTS check.
    fn drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult>;

    /// Applies an ALTER TABLE statement (forwarded unchanged by the bridge).
    fn alter_table(&mut self, stmt: &AlterTableStmt) -> SqlResult<ExecutionResult>;

    /// Starts a transaction (forwarded unchanged by the bridge).
    fn begin(&mut self, stmt: &BeginStmt) -> SqlResult<ExecutionResult>;

    /// Commits the current transaction (forwarded unchanged by the bridge).
    fn commit(&mut self) -> SqlResult<ExecutionResult>;

    /// Rolls back, optionally to the named savepoint.
    fn rollback(&mut self, savepoint: Option<&str>) -> SqlResult<ExecutionResult>;

    /// Reports whether `table` exists; used for IF [NOT] EXISTS handling.
    fn table_exists(&self, table: &str) -> SqlResult<bool>;

    /// Reports whether index `index` exists; used for IF [NOT] EXISTS handling.
    fn index_exists(&self, index: &str) -> SqlResult<bool>;

    /// Returns the rows of `table` keyed by column name.
    ///
    /// The bridge calls this with an empty `columns` slice when loading JOIN
    /// inputs. NOTE(review): presumably an empty slice means "all columns" —
    /// confirm with implementations.
    fn scan_all(
        &self,
        table: &str,
        columns: &[String],
    ) -> SqlResult<Vec<HashMap<String, SochValue>>>;

    /// Evaluates a boolean predicate against a joined row.
    ///
    /// The row holds both `alias.column` and bare `column` keys. The bridge
    /// treats a `None` result as "does not match".
    fn eval_join_predicate(
        &self,
        expr: &Expr,
        row: &HashMap<String, SochValue>,
        params: &[SochValue],
    ) -> Option<bool>;
}
191
/// A scope registered by `DEFINE SCOPE` and kept in the bridge's in-memory map.
#[derive(Debug, Clone)]
pub struct ScopeDefinition {
    /// Scope name; also the key it is stored under.
    pub name: String,
    /// Value of the `SESSION` clause (e.g. `SESSION 86400` gives `Some(86400)`).
    pub session_duration_secs: Option<u64>,
    /// Sign-in expression copied from the statement; not evaluated by the bridge.
    pub signin: Option<Box<Expr>>,
    /// Sign-up expression copied from the statement; not evaluated by the bridge.
    pub signup: Option<Box<Expr>>,
}
204
/// Permission rules registered by `DEFINE TABLE ... PERMISSIONS` for one table.
#[derive(Debug, Clone)]
pub struct StoredTablePermissions {
    /// Table the rules apply to; also the key they are stored under.
    pub table: String,
    /// Rules in declaration order; `check_table_permission` consults the first
    /// rule whose operation matches.
    pub permissions: Vec<TablePermission>,
}
213
/// Parses SQL, enforces bridge-level checks and dispatches work to a [`SqlConnection`].
pub struct SqlBridge<C: SqlConnection> {
    /// Backend that performs storage operations.
    conn: C,
    /// Scopes from `DEFINE SCOPE`, keyed by scope name (kept here, not sent to `conn`).
    scope_definitions: HashMap<String, ScopeDefinition>,
    /// Rules from `DEFINE TABLE ... PERMISSIONS`, keyed by table name (kept here, not sent to `conn`).
    table_permissions: HashMap<String, StoredTablePermissions>,
}
222
223impl<C: SqlConnection> SqlBridge<C> {
224 pub fn new(conn: C) -> Self {
226 Self {
227 conn,
228 scope_definitions: HashMap::new(),
229 table_permissions: HashMap::new(),
230 }
231 }
232
233 pub fn get_scope(&self, name: &str) -> Option<&ScopeDefinition> {
235 self.scope_definitions.get(name)
236 }
237
238 pub fn get_table_permissions(&self, table: &str) -> Option<&StoredTablePermissions> {
240 self.table_permissions.get(table)
241 }
242
243 pub fn check_table_permission(&self, table: &str, op: PermissionOp) -> SqlResult<()> {
247 if let Some(perms) = self.table_permissions.get(table) {
248 let rule = perms.permissions.iter().find(|p| p.operation == op);
250 match rule {
251 Some(perm) => {
252 match &perm.condition {
256 Expr::Literal(Literal::Boolean(true)) => Ok(()),
257 Expr::Literal(Literal::Boolean(false)) => {
258 Err(SqlError::PermissionDenied(format!(
259 "{:?} denied on table '{}' by table permission rule",
260 op, table
261 )))
262 }
263 _ => Ok(()),
267 }
268 }
269 None => Err(SqlError::PermissionDenied(format!(
272 "{:?} not permitted on table '{}' (no matching permission rule)",
273 op, table
274 ))),
275 }
276 } else {
277 Ok(())
279 }
280 }
281
282 pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
284 self.execute_with_params(sql, &[])
285 }
286
287 pub fn execute_with_params(
289 &mut self,
290 sql: &str,
291 params: &[SochValue],
292 ) -> SqlResult<ExecutionResult> {
293 let _dialect = SqlDialect::detect(sql);
295
296 let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
298
299 let max_placeholder = self.find_max_placeholder(&stmt);
301 if max_placeholder as usize > params.len() {
302 return Err(SqlError::InvalidArgument(format!(
303 "Query contains {} placeholders but only {} parameters provided",
304 max_placeholder,
305 params.len()
306 )));
307 }
308
309 self.execute_statement(&stmt, params)
311 }
312
313 pub fn execute_statement(
315 &mut self,
316 stmt: &Statement,
317 params: &[SochValue],
318 ) -> SqlResult<ExecutionResult> {
319 match stmt {
320 Statement::Select(select) => self.execute_select(select, params),
321 Statement::Insert(insert) => self.execute_insert(insert, params),
322 Statement::Update(update) => self.execute_update(update, params),
323 Statement::Delete(delete) => self.execute_delete(delete, params),
324 Statement::CreateTable(create) => self.execute_create_table(create),
325 Statement::DropTable(drop) => self.execute_drop_table(drop),
326 Statement::CreateIndex(create) => self.execute_create_index(create),
327 Statement::DropIndex(drop) => self.execute_drop_index(drop),
328 Statement::AlterTable(alter) => self.execute_alter_table(alter),
329 Statement::Begin(begin) => self.conn.begin(begin),
330 Statement::Commit => self.conn.commit(),
331 Statement::Rollback(savepoint) => self.conn.rollback(savepoint.as_deref()),
332 Statement::Savepoint(_name) => Err(SqlError::NotImplemented(
333 "SAVEPOINT not yet implemented".into(),
334 )),
335 Statement::Release(_name) => Err(SqlError::NotImplemented(
336 "RELEASE SAVEPOINT not yet implemented".into(),
337 )),
338 Statement::Explain(_stmt) => Err(SqlError::NotImplemented(
339 "EXPLAIN not yet implemented".into(),
340 )),
341 Statement::DefineScope(def) => {
342 self.scope_definitions.insert(def.name.clone(), ScopeDefinition {
343 name: def.name.clone(),
344 session_duration_secs: def.session_duration_secs,
345 signin: def.signin.clone(),
346 signup: def.signup.clone(),
347 });
348 Ok(ExecutionResult::Ok)
349 },
350 Statement::DefineTablePermissions(def) => {
351 let table_name = def.table.name().to_string();
352 self.table_permissions.insert(table_name.clone(), StoredTablePermissions {
353 table: table_name,
354 permissions: def.permissions.clone(),
355 });
356 Ok(ExecutionResult::Ok)
357 },
358 Statement::RemoveScope(name) => {
359 self.scope_definitions.remove(name);
360 Ok(ExecutionResult::Ok)
361 },
362 Statement::Relate(_) => Err(SqlError::NotImplemented(
363 "RELATE not yet implemented — graph execution engine required".into(),
364 )),
365 Statement::LiveSelect(_) => Err(SqlError::NotImplemented(
366 "LIVE SELECT not yet implemented — CDC subscription engine required".into(),
367 )),
368 Statement::DefineEvent(_) => Err(SqlError::NotImplemented(
369 "DEFINE EVENT not yet implemented — event trigger engine required".into(),
370 )),
371 }
372 }
373
374 fn execute_select(
375 &self,
376 select: &SelectStmt,
377 params: &[SochValue],
378 ) -> SqlResult<ExecutionResult> {
379 let from = select
381 .from
382 .as_ref()
383 .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
384
385 if from.tables.len() != 1 {
386 return Err(SqlError::NotImplemented(
387 "Multi-table queries (comma-separated) not yet supported".into(),
388 ));
389 }
390
391 let table_ref = &from.tables[0];
393 if self.contains_join(table_ref) {
394 return self.execute_join_select(select, table_ref, params);
395 }
396
397 let table_name = match table_ref {
399 TableRef::Table { name, .. } => name.name().to_string(),
400 TableRef::Subquery { .. } => {
401 return Err(SqlError::NotImplemented(
402 "Subqueries not yet supported".into(),
403 ));
404 }
405 TableRef::Function { .. } => {
406 return Err(SqlError::NotImplemented(
407 "Table functions not yet supported".into(),
408 ));
409 }
410 TableRef::Join { .. } => unreachable!("handled above"),
411 };
412
413 self.check_table_permission(&table_name, PermissionOp::Select)?;
415
416 let columns = self.extract_select_columns(&select.columns)?;
418
419 let limit = self.extract_limit(&select.limit)?;
421 let offset = self.extract_limit(&select.offset)?;
422
423 self.conn.select(
424 &table_name,
425 &columns,
426 select.where_clause.as_ref(),
427 &select.order_by,
428 limit,
429 offset,
430 params,
431 )
432 }
433
434 fn contains_join(&self, table_ref: &TableRef) -> bool {
436 matches!(table_ref, TableRef::Join { .. })
437 }
438
439 fn execute_join_select(
444 &self,
445 select: &SelectStmt,
446 table_ref: &TableRef,
447 params: &[SochValue],
448 ) -> SqlResult<ExecutionResult> {
449 let mut rows = self.resolve_table_ref(table_ref, params)?;
453
454 if let Some(ref expr) = select.where_clause {
456 rows.retain(|row| {
457 self.conn
458 .eval_join_predicate(expr, row, params)
459 .unwrap_or(false)
460 });
461 }
462
463 if !select.order_by.is_empty() {
465 rows.sort_by(|a, b| {
466 for item in &select.order_by {
467 let col = Self::extract_order_column(&item.expr);
468 let va = a.get(&col);
469 let vb = b.get(&col);
470 let cmp = Self::compare_optional_values(va, vb);
471 let cmp = if !item.asc { cmp.reverse() } else { cmp };
472 if cmp != std::cmp::Ordering::Equal {
473 return cmp;
474 }
475 }
476 std::cmp::Ordering::Equal
477 });
478 }
479
480 let offset = self.extract_limit(&select.offset)?;
482 if let Some(off) = offset {
483 rows = rows.into_iter().skip(off).collect();
484 }
485
486 let limit = self.extract_limit(&select.limit)?;
488 if let Some(lim) = limit {
489 rows.truncate(lim);
490 }
491
492 let select_columns = self.extract_select_columns(&select.columns)?;
494 let (result_columns, projected_rows) =
495 self.project_join_rows(&select_columns, &rows)?;
496
497 Ok(ExecutionResult::Rows {
498 columns: result_columns,
499 rows: projected_rows,
500 })
501 }
502
503 fn resolve_table_ref(
508 &self,
509 table_ref: &TableRef,
510 params: &[SochValue],
511 ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
512 match table_ref {
513 TableRef::Table { name, alias } => {
514 let table_name = name.name().to_string();
515 let prefix = alias.as_deref().unwrap_or(&table_name);
516 let raw_rows = self.conn.scan_all(&table_name, &[])?;
517
518 let mut result = Vec::with_capacity(raw_rows.len());
520 for row in raw_rows {
521 let mut merged = HashMap::new();
522 for (k, v) in &row {
523 merged.insert(format!("{}.{}", prefix, k), v.clone());
524 merged.insert(k.clone(), v.clone());
527 }
528 result.push(merged);
529 }
530 Ok(result)
531 }
532 TableRef::Join {
533 left,
534 join_type,
535 right,
536 condition,
537 } => {
538 let left_rows = self.resolve_table_ref(left, params)?;
539 let right_rows = self.resolve_table_ref(right, params)?;
540 self.execute_join(&left_rows, &right_rows, *join_type, condition.as_ref(), params)
541 }
542 TableRef::Subquery { .. } => Err(SqlError::NotImplemented(
543 "Subqueries in FROM not yet supported".into(),
544 )),
545 TableRef::Function { .. } => Err(SqlError::NotImplemented(
546 "Table functions not yet supported".into(),
547 )),
548 }
549 }
550
551 fn execute_join(
555 &self,
556 left_rows: &[HashMap<String, SochValue>],
557 right_rows: &[HashMap<String, SochValue>],
558 join_type: JoinType,
559 condition: Option<&JoinCondition>,
560 params: &[SochValue],
561 ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
562 let (on_expr, using_cols) = match condition {
564 Some(JoinCondition::On(expr)) => (Some(expr), None),
565 Some(JoinCondition::Using(cols)) => (None, Some(cols.as_slice())),
566 Some(JoinCondition::Natural) => {
567 return Err(SqlError::NotImplemented(
568 "NATURAL JOIN not yet supported".into(),
569 ))
570 }
571 None => (None, None), };
573
574 if let Some(expr) = on_expr {
576 if let Some((left_key, right_key)) = Self::extract_equi_join_keys(expr) {
577 return self.hash_join(
578 left_rows, right_rows, &left_key, &right_key, join_type, params,
579 );
580 }
581 }
582
583 let mut result = Vec::new();
585 let null_right: HashMap<String, SochValue> = Self::null_row(right_rows);
586 let null_left: HashMap<String, SochValue> = Self::null_row(left_rows);
587
588 let mut right_matched = vec![false; right_rows.len()];
589
590 for left in left_rows {
591 let mut found_match = false;
592
593 for (ri, right) in right_rows.iter().enumerate() {
594 let merged = Self::merge_rows(left, right);
595 let matches = match (on_expr, using_cols) {
596 (Some(expr), _) => self
597 .conn
598 .eval_join_predicate(expr, &merged, params)
599 .unwrap_or(false),
600 (_, Some(cols)) => Self::using_matches(left, right, cols),
601 (None, None) => true, };
603
604 if matches {
605 result.push(merged);
606 found_match = true;
607 right_matched[ri] = true;
608 }
609 }
610
611 if !found_match
613 && matches!(join_type, JoinType::Left | JoinType::Full)
614 {
615 result.push(Self::merge_rows(left, &null_right));
616 }
617 }
618
619 if matches!(join_type, JoinType::Right | JoinType::Full) {
621 for (ri, right) in right_rows.iter().enumerate() {
622 if !right_matched[ri] {
623 result.push(Self::merge_rows(&null_left, right));
624 }
625 }
626 }
627
628 Ok(result)
629 }
630
631 fn hash_join(
633 &self,
634 left_rows: &[HashMap<String, SochValue>],
635 right_rows: &[HashMap<String, SochValue>],
636 left_key: &str,
637 right_key: &str,
638 join_type: JoinType,
639 _params: &[SochValue],
640 ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
641 let mut hash_table: HashMap<String, Vec<usize>> = HashMap::new();
643 for (i, row) in right_rows.iter().enumerate() {
644 if let Some(val) = row.get(right_key) {
645 let key = Self::value_to_hash_key(val);
646 hash_table.entry(key).or_default().push(i);
647 }
648 }
649
650 let null_right = Self::null_row(right_rows);
651 let null_left = Self::null_row(left_rows);
652 let mut right_matched = vec![false; right_rows.len()];
653 let mut result = Vec::new();
654
655 for left in left_rows {
657 let mut found_match = false;
658 if let Some(val) = left.get(left_key) {
659 let key = Self::value_to_hash_key(val);
660 if let Some(indices) = hash_table.get(&key) {
661 for &ri in indices {
662 result.push(Self::merge_rows(left, &right_rows[ri]));
663 found_match = true;
664 right_matched[ri] = true;
665 }
666 }
667 }
668 if !found_match && matches!(join_type, JoinType::Left | JoinType::Full) {
669 result.push(Self::merge_rows(left, &null_right));
670 }
671 }
672
673 if matches!(join_type, JoinType::Right | JoinType::Full) {
674 for (ri, right) in right_rows.iter().enumerate() {
675 if !right_matched[ri] {
676 result.push(Self::merge_rows(&null_left, right));
677 }
678 }
679 }
680
681 Ok(result)
682 }
683
684 fn extract_equi_join_keys(expr: &Expr) -> Option<(String, String)> {
687 if let Expr::BinaryOp { left, op, right } = expr {
688 if *op == BinaryOperator::Eq {
689 if let (Expr::Column(l), Expr::Column(r)) = (left.as_ref(), right.as_ref()) {
690 let lk = if let Some(ref t) = l.table {
691 format!("{}.{}", t, l.column)
692 } else {
693 l.column.clone()
694 };
695 let rk = if let Some(ref t) = r.table {
696 format!("{}.{}", t, r.column)
697 } else {
698 r.column.clone()
699 };
700 return Some((lk, rk));
701 }
702 }
703 }
704 None
705 }
706
707 fn merge_rows(
709 left: &HashMap<String, SochValue>,
710 right: &HashMap<String, SochValue>,
711 ) -> HashMap<String, SochValue> {
712 let mut merged = left.clone();
713 for (k, v) in right {
714 if !merged.contains_key(k) || k.contains('.') {
717 merged.insert(k.clone(), v.clone());
718 }
719 }
720 merged
721 }
722
723 fn null_row(rows: &[HashMap<String, SochValue>]) -> HashMap<String, SochValue> {
725 if let Some(sample) = rows.first() {
726 sample
727 .keys()
728 .map(|k| (k.clone(), SochValue::Null))
729 .collect()
730 } else {
731 HashMap::new()
732 }
733 }
734
735 fn using_matches(
737 left: &HashMap<String, SochValue>,
738 right: &HashMap<String, SochValue>,
739 cols: &[String],
740 ) -> bool {
741 cols.iter().all(|col| {
742 let lv = left.get(col);
743 let rv = right.get(col);
744 match (lv, rv) {
745 (Some(l), Some(r)) => l == r,
746 _ => false,
747 }
748 })
749 }
750
751 fn value_to_hash_key(val: &SochValue) -> String {
753 format!("{:?}", val)
754 }
755
756 fn extract_order_column(expr: &Expr) -> String {
758 match expr {
759 Expr::Column(col) => {
760 if let Some(ref t) = col.table {
761 format!("{}.{}", t, col.column)
762 } else {
763 col.column.clone()
764 }
765 }
766 _ => String::new(),
767 }
768 }
769
770 fn compare_optional_values(
772 a: Option<&SochValue>,
773 b: Option<&SochValue>,
774 ) -> std::cmp::Ordering {
775 match (a, b) {
776 (None, None) => std::cmp::Ordering::Equal,
777 (None, Some(_)) => std::cmp::Ordering::Less,
778 (Some(_), None) => std::cmp::Ordering::Greater,
779 (Some(va), Some(vb)) => Self::compare_values(va, vb),
780 }
781 }
782
783 fn compare_values(a: &SochValue, b: &SochValue) -> std::cmp::Ordering {
785 match (a, b) {
786 (SochValue::Int(a), SochValue::Int(b)) => a.cmp(b),
787 (SochValue::Float(a), SochValue::Float(b)) => {
788 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
789 }
790 (SochValue::Text(a), SochValue::Text(b)) => a.cmp(b),
791 (SochValue::Bool(a), SochValue::Bool(b)) => a.cmp(b),
792 (SochValue::Null, SochValue::Null) => std::cmp::Ordering::Equal,
793 (SochValue::Null, _) => std::cmp::Ordering::Less,
794 (_, SochValue::Null) => std::cmp::Ordering::Greater,
795 _ => std::cmp::Ordering::Equal,
796 }
797 }
798
799 fn project_join_rows(
801 &self,
802 select_columns: &[String],
803 rows: &[HashMap<String, SochValue>],
804 ) -> SqlResult<(Vec<String>, Vec<HashMap<String, SochValue>>)> {
805 if select_columns.is_empty() || select_columns.iter().any(|c| c == "*") {
807 let all_cols: Vec<String> = rows
808 .first()
809 .map(|r| {
810 let mut cols: Vec<String> = r
812 .keys()
813 .filter(|k| k.contains('.'))
814 .cloned()
815 .collect();
816 cols.sort();
817 if cols.is_empty() {
818 cols = r.keys().cloned().collect();
820 cols.sort();
821 }
822 cols
823 })
824 .unwrap_or_default();
825
826 let projected: Vec<HashMap<String, SochValue>> = rows
827 .iter()
828 .map(|row| {
829 all_cols
830 .iter()
831 .map(|c| {
832 let short = c.rsplit('.').next().unwrap_or(c);
833 (short.to_string(), row.get(c).cloned().unwrap_or(SochValue::Null))
834 })
835 .collect()
836 })
837 .collect();
838 let short_cols: Vec<String> = all_cols
839 .iter()
840 .map(|c| c.rsplit('.').next().unwrap_or(c).to_string())
841 .collect();
842 return Ok((short_cols, projected));
843 }
844
845 let mut result_rows = Vec::with_capacity(rows.len());
847 for row in rows {
848 let mut projected = HashMap::new();
849 for col in select_columns {
850 let val = row
852 .get(col)
853 .or_else(|| {
854 row.iter()
856 .find(|(k, _)| {
857 k.ends_with(&format!(".{}", col))
858 || k.as_str() == col
859 })
860 .map(|(_, v)| v)
861 })
862 .cloned()
863 .unwrap_or(SochValue::Null);
864 projected.insert(col.clone(), val);
865 }
866 result_rows.push(projected);
867 }
868
869 Ok((select_columns.to_vec(), result_rows))
870 }
871
872 fn execute_insert(
873 &mut self,
874 insert: &InsertStmt,
875 params: &[SochValue],
876 ) -> SqlResult<ExecutionResult> {
877 let table_name = insert.table.name();
878
879 self.check_table_permission(table_name, PermissionOp::Create)?;
881
882 let rows = match &insert.source {
883 InsertSource::Values(values) => values,
884 InsertSource::Query(_) => {
885 return Err(SqlError::NotImplemented(
886 "INSERT ... SELECT not yet supported".into(),
887 ));
888 }
889 InsertSource::Default => {
890 return Err(SqlError::NotImplemented(
891 "INSERT DEFAULT VALUES not yet supported".into(),
892 ));
893 }
894 };
895
896 self.conn.insert(
897 table_name,
898 insert.columns.as_deref(),
899 rows,
900 insert.on_conflict.as_ref(),
901 params,
902 )
903 }
904
905 fn execute_update(
906 &mut self,
907 update: &UpdateStmt,
908 params: &[SochValue],
909 ) -> SqlResult<ExecutionResult> {
910 let table_name = update.table.name();
911
912 self.check_table_permission(table_name, PermissionOp::Update)?;
914
915 self.conn.update(
916 table_name,
917 &update.assignments,
918 update.where_clause.as_ref(),
919 params,
920 )
921 }
922
923 fn execute_delete(
924 &mut self,
925 delete: &DeleteStmt,
926 params: &[SochValue],
927 ) -> SqlResult<ExecutionResult> {
928 let table_name = delete.table.name();
929
930 self.check_table_permission(table_name, PermissionOp::Delete)?;
932
933 self.conn.delete(
934 table_name,
935 delete.where_clause.as_ref(),
936 params,
937 )
938 }
939
940 fn execute_create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult> {
941 if stmt.if_not_exists {
943 let table_name = stmt.name.name();
944 if self.conn.table_exists(table_name)? {
945 return Ok(ExecutionResult::Ok);
946 }
947 }
948
949 self.conn.create_table(stmt)
950 }
951
952 fn execute_drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult> {
953 if stmt.if_exists {
955 for name in &stmt.names {
956 if !self.conn.table_exists(name.name())? {
957 return Ok(ExecutionResult::Ok);
958 }
959 }
960 }
961
962 self.conn.drop_table(stmt)
963 }
964
965 fn execute_create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
966 if stmt.if_not_exists {
968 if self.conn.index_exists(&stmt.name)? {
969 return Ok(ExecutionResult::Ok);
970 }
971 }
972
973 self.conn.create_index(stmt)
974 }
975
976 fn execute_drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult> {
977 if stmt.if_exists {
979 if !self.conn.index_exists(&stmt.name)? {
980 return Ok(ExecutionResult::Ok);
981 }
982 }
983
984 self.conn.drop_index(stmt)
985 }
986
987 fn execute_alter_table(&mut self, stmt: &AlterTableStmt) -> SqlResult<ExecutionResult> {
988 self.conn.alter_table(stmt)
989 }
990
991 fn extract_select_columns(&self, items: &[SelectItem]) -> SqlResult<Vec<String>> {
993 let mut columns = Vec::new();
994
995 for item in items {
996 match item {
997 SelectItem::Wildcard => columns.push("*".to_string()),
998 SelectItem::QualifiedWildcard(table) => columns.push(format!("{}.*", table)),
999 SelectItem::Expr { expr, alias } => {
1000 let name = alias.clone().unwrap_or_else(|| match expr {
1001 Expr::Column(col) => col.column.clone(),
1002 Expr::Function(func) => format!("{}()", func.name.name()),
1003 _ => "?column?".to_string(),
1004 });
1005 columns.push(name);
1006 }
1007 }
1008 }
1009
1010 Ok(columns)
1011 }
1012
1013 fn extract_limit(&self, expr: &Option<Expr>) -> SqlResult<Option<usize>> {
1015 match expr {
1016 Some(Expr::Literal(Literal::Integer(n))) => Ok(Some(*n as usize)),
1017 Some(_) => Err(SqlError::InvalidArgument(
1018 "LIMIT/OFFSET must be an integer literal".into(),
1019 )),
1020 None => Ok(None),
1021 }
1022 }
1023
1024 fn find_max_placeholder(&self, stmt: &Statement) -> u32 {
1026 let mut visitor = PlaceholderVisitor::new();
1027 visitor.visit_statement(stmt);
1028 visitor.max_placeholder
1029 }
1030}
1031
/// AST walker that records the highest placeholder index seen in a statement.
struct PlaceholderVisitor {
    /// Largest `Expr::Placeholder` index visited so far; 0 when none were seen.
    max_placeholder: u32,
}
1036
1037impl PlaceholderVisitor {
1038 fn new() -> Self {
1039 Self { max_placeholder: 0 }
1040 }
1041
1042 fn visit_statement(&mut self, stmt: &Statement) {
1043 match stmt {
1044 Statement::Select(s) => self.visit_select(s),
1045 Statement::Insert(i) => self.visit_insert(i),
1046 Statement::Update(u) => self.visit_update(u),
1047 Statement::Delete(d) => self.visit_delete(d),
1048 _ => {}
1049 }
1050 }
1051
1052 fn visit_select(&mut self, select: &SelectStmt) {
1053 for item in &select.columns {
1054 if let SelectItem::Expr { expr, .. } = item {
1055 self.visit_expr(expr);
1056 }
1057 }
1058 if let Some(where_clause) = &select.where_clause {
1059 self.visit_expr(where_clause);
1060 }
1061 if let Some(having) = &select.having {
1062 self.visit_expr(having);
1063 }
1064 for order in &select.order_by {
1065 self.visit_expr(&order.expr);
1066 }
1067 if let Some(limit) = &select.limit {
1068 self.visit_expr(limit);
1069 }
1070 if let Some(offset) = &select.offset {
1071 self.visit_expr(offset);
1072 }
1073 }
1074
1075 fn visit_insert(&mut self, insert: &InsertStmt) {
1076 if let InsertSource::Values(rows) = &insert.source {
1077 for row in rows {
1078 for expr in row {
1079 self.visit_expr(expr);
1080 }
1081 }
1082 }
1083 }
1084
1085 fn visit_update(&mut self, update: &UpdateStmt) {
1086 for assign in &update.assignments {
1087 self.visit_expr(&assign.value);
1088 }
1089 if let Some(where_clause) = &update.where_clause {
1090 self.visit_expr(where_clause);
1091 }
1092 }
1093
1094 fn visit_delete(&mut self, delete: &DeleteStmt) {
1095 if let Some(where_clause) = &delete.where_clause {
1096 self.visit_expr(where_clause);
1097 }
1098 }
1099
1100 fn visit_expr(&mut self, expr: &Expr) {
1101 match expr {
1102 Expr::Placeholder(n) => {
1103 self.max_placeholder = self.max_placeholder.max(*n);
1104 }
1105 Expr::BinaryOp { left, right, .. } => {
1106 self.visit_expr(left);
1107 self.visit_expr(right);
1108 }
1109 Expr::UnaryOp { expr, .. } => {
1110 self.visit_expr(expr);
1111 }
1112 Expr::Function(func) => {
1113 for arg in &func.args {
1114 self.visit_expr(arg);
1115 }
1116 }
1117 Expr::Case { operand, conditions, else_result } => {
1118 if let Some(op) = operand {
1119 self.visit_expr(op);
1120 }
1121 for (when, then) in conditions {
1122 self.visit_expr(when);
1123 self.visit_expr(then);
1124 }
1125 if let Some(else_expr) = else_result {
1126 self.visit_expr(else_expr);
1127 }
1128 }
1129 Expr::InList { expr, list, .. } => {
1130 self.visit_expr(expr);
1131 for item in list {
1132 self.visit_expr(item);
1133 }
1134 }
1135 Expr::Between { expr, low, high, .. } => {
1136 self.visit_expr(expr);
1137 self.visit_expr(low);
1138 self.visit_expr(high);
1139 }
1140 Expr::Cast { expr, .. } => {
1141 self.visit_expr(expr);
1142 }
1143 _ => {}
1144 }
1145 }
1146}
1147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_placeholder_visitor() {
        let stmt = Parser::parse("SELECT * FROM users WHERE id = $1 AND name = $2").unwrap();
        let mut visitor = PlaceholderVisitor::new();
        visitor.visit_statement(&stmt);
        assert_eq!(visitor.max_placeholder, 2);
    }

    #[test]
    fn test_question_mark_placeholders() {
        let stmt = Parser::parse("SELECT * FROM users WHERE id = ? AND name = ?").unwrap();
        let mut visitor = PlaceholderVisitor::new();
        visitor.visit_statement(&stmt);
        assert_eq!(visitor.max_placeholder, 2);
    }

    #[test]
    fn test_missing_parameters_rejected() {
        let mut bridge = make_mock_bridge();
        let result = bridge.execute_with_params("SELECT * FROM users WHERE id = $1", &[]);
        assert!(result.is_err(), "placeholder without a bound parameter must be rejected");
    }

    #[test]
    fn test_rows_affected_per_variant() {
        assert_eq!(ExecutionResult::RowsAffected(3).rows_affected(), 3);
        assert_eq!(ExecutionResult::Ok.rows_affected(), 0);
        assert_eq!(ExecutionResult::TransactionOk.rows_affected(), 0);
        let rows = ExecutionResult::Rows {
            columns: vec!["a".to_string()],
            rows: vec![HashMap::new(), HashMap::new()],
        };
        assert_eq!(rows.rows_affected(), 2);
        assert_eq!(rows.columns().map(Vec::len), Some(1));
        assert!(ExecutionResult::Ok.rows().is_none());
    }

    #[test]
    fn test_dialect_detection() {
        assert_eq!(SqlDialect::detect("SELECT * FROM users"), SqlDialect::Standard);
        assert_eq!(
            SqlDialect::detect("INSERT IGNORE INTO users VALUES (1)"),
            SqlDialect::MySQL
        );
        assert_eq!(
            SqlDialect::detect("INSERT OR IGNORE INTO users VALUES (1)"),
            SqlDialect::SQLite
        );
    }

    #[test]
    fn test_define_scope_stores_definition() {
        let mut bridge = make_mock_bridge();
        let result = bridge.execute("DEFINE SCOPE user_scope SESSION 86400");
        result.unwrap();
        let scope = bridge.get_scope("user_scope");
        assert!(scope.is_some(), "Scope not stored");
        let scope = scope.unwrap();
        assert_eq!(scope.name, "user_scope");
        assert_eq!(scope.session_duration_secs, Some(86400));
    }

    #[test]
    fn test_remove_scope_deletes_definition() {
        let mut bridge = make_mock_bridge();
        bridge.execute("DEFINE SCOPE temp_scope SESSION 3600").unwrap();
        assert!(bridge.get_scope("temp_scope").is_some());
        bridge.execute("REMOVE SCOPE temp_scope").unwrap();
        assert!(bridge.get_scope("temp_scope").is_none());
    }

    #[test]
    fn test_define_table_permissions_stores_rules() {
        let mut bridge = make_mock_bridge();
        let result = bridge.execute(
            "DEFINE TABLE posts PERMISSIONS FOR select WHERE true FOR delete WHERE false"
        );
        assert!(result.is_ok());
        let perms = bridge.get_table_permissions("posts");
        assert!(perms.is_some());
        assert_eq!(perms.unwrap().permissions.len(), 2);
    }

    #[test]
    fn test_table_permission_check_allows_matching_true() {
        let mut bridge = make_mock_bridge();
        bridge.execute(
            "DEFINE TABLE docs PERMISSIONS FOR select WHERE true FOR insert WHERE true FOR update WHERE true FOR delete WHERE true"
        ).unwrap();
        assert!(bridge.check_table_permission("docs", PermissionOp::Select).is_ok());
        assert!(bridge.check_table_permission("docs", PermissionOp::Create).is_ok());
        assert!(bridge.check_table_permission("docs", PermissionOp::Update).is_ok());
        assert!(bridge.check_table_permission("docs", PermissionOp::Delete).is_ok());
    }

    #[test]
    fn test_table_permission_check_denies_matching_false() {
        let mut bridge = make_mock_bridge();
        bridge.execute(
            "DEFINE TABLE secrets PERMISSIONS FOR select WHERE false FOR delete WHERE false"
        ).unwrap();
        let err = bridge.check_table_permission("secrets", PermissionOp::Select);
        assert!(err.is_err());
        assert!(format!("{}", err.unwrap_err()).contains("Permission denied"));
    }

    #[test]
    fn test_delete_statement_denied_by_table_permission() {
        let mut bridge = make_mock_bridge();
        bridge.execute(
            "DEFINE TABLE secrets PERMISSIONS FOR select WHERE false FOR delete WHERE false"
        ).unwrap();
        let err = bridge.execute("DELETE FROM secrets WHERE id = 1").unwrap_err();
        assert!(format!("{}", err).contains("Permission denied"));
    }

    #[test]
    fn test_table_permission_denies_undefined_op_when_rules_exist() {
        let mut bridge = make_mock_bridge();
        // Only SELECT is granted; any other operation must be refused.
        bridge.execute(
            "DEFINE TABLE restricted PERMISSIONS FOR select WHERE true"
        ).unwrap();
        assert!(bridge.check_table_permission("restricted", PermissionOp::Select).is_ok());
        let err = bridge.check_table_permission("restricted", PermissionOp::Update);
        assert!(err.is_err());
    }

    #[test]
    fn test_no_permissions_allows_all() {
        let bridge = make_mock_bridge();
        assert!(bridge.check_table_permission("any_table", PermissionOp::Select).is_ok());
        assert!(bridge.check_table_permission("any_table", PermissionOp::Delete).is_ok());
    }

    /// Bridge over a connection stub that succeeds at everything and stores nothing.
    fn make_mock_bridge() -> SqlBridge<MockPermConn> {
        SqlBridge::new(MockPermConn)
    }

    /// No-op connection: every table "exists", no index exists, scans are empty.
    struct MockPermConn;

    impl SqlConnection for MockPermConn {
        fn select(&self, _: &str, _: &[String], _: Option<&Expr>, _: &[OrderByItem], _: Option<usize>, _: Option<usize>, _: &[SochValue]) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Rows { columns: vec![], rows: vec![] })
        }
        fn insert(&mut self, _: &str, _: Option<&[String]>, _: &[Vec<Expr>], _: Option<&OnConflict>, _: &[SochValue]) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }
        fn update(&mut self, _: &str, _: &[Assignment], _: Option<&Expr>, _: &[SochValue]) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }
        fn delete(&mut self, _: &str, _: Option<&Expr>, _: &[SochValue]) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }
        fn create_table(&mut self, _: &CreateTableStmt) -> SqlResult<ExecutionResult> { Ok(ExecutionResult::Ok) }
        fn drop_table(&mut self, _: &DropTableStmt) -> SqlResult<ExecutionResult> { Ok(ExecutionResult::Ok) }
        fn create_index(&mut self, _: &CreateIndexStmt) -> SqlResult<ExecutionResult> { Ok(ExecutionResult::Ok) }
        fn drop_index(&mut self, _: &DropIndexStmt) -> SqlResult<ExecutionResult> { Ok(ExecutionResult::Ok) }
        fn alter_table(&mut self, _: &AlterTableStmt) -> SqlResult<ExecutionResult> { Ok(ExecutionResult::Ok) }
        fn begin(&mut self, _: &BeginStmt) -> SqlResult<ExecutionResult> { Ok(ExecutionResult::TransactionOk) }
        fn commit(&mut self) -> SqlResult<ExecutionResult> { Ok(ExecutionResult::TransactionOk) }
        fn rollback(&mut self, _: Option<&str>) -> SqlResult<ExecutionResult> { Ok(ExecutionResult::TransactionOk) }
        fn table_exists(&self, _: &str) -> SqlResult<bool> { Ok(true) }
        fn index_exists(&self, _: &str) -> SqlResult<bool> { Ok(false) }
        fn scan_all(&self, _: &str, _: &[String]) -> SqlResult<Vec<HashMap<String, SochValue>>> { Ok(vec![]) }
        fn eval_join_predicate(&self, _: &Expr, _: &HashMap<String, SochValue>, _: &[SochValue]) -> Option<bool> { Some(true) }
    }
}