1use super::ast::*;
50use super::compatibility::SqlDialect;
51use super::error::{SqlError, SqlResult};
52use super::parser::Parser;
53use sochdb_core::SochValue;
54use std::collections::HashMap;
55
56#[derive(Debug, Clone)]
58pub enum ExecutionResult {
59 Rows {
61 columns: Vec<String>,
62 rows: Vec<HashMap<String, SochValue>>,
63 },
64 RowsAffected(usize),
66 Ok,
68 TransactionOk,
70}
71
72impl ExecutionResult {
73 pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
75 match self {
76 ExecutionResult::Rows { rows, .. } => Some(rows),
77 _ => None,
78 }
79 }
80
81 pub fn columns(&self) -> Option<&Vec<String>> {
83 match self {
84 ExecutionResult::Rows { columns, .. } => Some(columns),
85 _ => None,
86 }
87 }
88
89 pub fn rows_affected(&self) -> usize {
91 match self {
92 ExecutionResult::RowsAffected(n) => *n,
93 ExecutionResult::Rows { rows, .. } => rows.len(),
94 _ => 0,
95 }
96 }
97}
98
99pub trait SqlConnection {
104 fn select(
106 &self,
107 table: &str,
108 columns: &[String],
109 where_clause: Option<&Expr>,
110 order_by: &[OrderByItem],
111 limit: Option<usize>,
112 offset: Option<usize>,
113 params: &[SochValue],
114 ) -> SqlResult<ExecutionResult>;
115
116 fn insert(
118 &mut self,
119 table: &str,
120 columns: Option<&[String]>,
121 rows: &[Vec<Expr>],
122 on_conflict: Option<&OnConflict>,
123 params: &[SochValue],
124 ) -> SqlResult<ExecutionResult>;
125
126 fn update(
128 &mut self,
129 table: &str,
130 assignments: &[Assignment],
131 where_clause: Option<&Expr>,
132 params: &[SochValue],
133 ) -> SqlResult<ExecutionResult>;
134
135 fn delete(
137 &mut self,
138 table: &str,
139 where_clause: Option<&Expr>,
140 params: &[SochValue],
141 ) -> SqlResult<ExecutionResult>;
142
143 fn create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult>;
145
146 fn drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult>;
148
149 fn create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult>;
151
152 fn drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult>;
154
155 fn alter_table(&mut self, stmt: &AlterTableStmt) -> SqlResult<ExecutionResult>;
157
158 fn begin(&mut self, stmt: &BeginStmt) -> SqlResult<ExecutionResult>;
160
161 fn commit(&mut self) -> SqlResult<ExecutionResult>;
163
164 fn rollback(&mut self, savepoint: Option<&str>) -> SqlResult<ExecutionResult>;
166
167 fn table_exists(&self, table: &str) -> SqlResult<bool>;
169
170 fn index_exists(&self, index: &str) -> SqlResult<bool>;
172
173 fn scan_all(
177 &self,
178 table: &str,
179 columns: &[String],
180 ) -> SqlResult<Vec<HashMap<String, SochValue>>>;
181
182 fn eval_join_predicate(
185 &self,
186 expr: &Expr,
187 row: &HashMap<String, SochValue>,
188 params: &[SochValue],
189 ) -> Option<bool>;
190}
191
192#[derive(Debug, Clone)]
194pub struct ScopeDefinition {
195 pub name: String,
197 pub session_duration_secs: Option<u64>,
199 pub signin: Option<Box<Expr>>,
201 pub signup: Option<Box<Expr>>,
203}
204
205#[derive(Debug, Clone)]
207pub struct StoredTablePermissions {
208 pub table: String,
210 pub permissions: Vec<TablePermission>,
212}
213
214pub struct SqlBridge<C: SqlConnection> {
216 conn: C,
217 scope_definitions: HashMap<String, ScopeDefinition>,
219 table_permissions: HashMap<String, StoredTablePermissions>,
221}
222
223impl<C: SqlConnection> SqlBridge<C> {
224 pub fn new(conn: C) -> Self {
226 Self {
227 conn,
228 scope_definitions: HashMap::new(),
229 table_permissions: HashMap::new(),
230 }
231 }
232
233 pub fn get_scope(&self, name: &str) -> Option<&ScopeDefinition> {
235 self.scope_definitions.get(name)
236 }
237
238 pub fn get_table_permissions(&self, table: &str) -> Option<&StoredTablePermissions> {
240 self.table_permissions.get(table)
241 }
242
243 pub fn check_table_permission(&self, table: &str, op: PermissionOp) -> SqlResult<()> {
247 if let Some(perms) = self.table_permissions.get(table) {
248 let rule = perms.permissions.iter().find(|p| p.operation == op);
250 match rule {
251 Some(perm) => {
252 match &perm.condition {
256 Expr::Literal(Literal::Boolean(true)) => Ok(()),
257 Expr::Literal(Literal::Boolean(false)) => {
258 Err(SqlError::PermissionDenied(format!(
259 "{:?} denied on table '{}' by table permission rule",
260 op, table
261 )))
262 }
263 _ => Ok(()),
267 }
268 }
269 None => Err(SqlError::PermissionDenied(format!(
272 "{:?} not permitted on table '{}' (no matching permission rule)",
273 op, table
274 ))),
275 }
276 } else {
277 Ok(())
279 }
280 }
281
282 pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
284 self.execute_with_params(sql, &[])
285 }
286
287 pub fn execute_with_params(
289 &mut self,
290 sql: &str,
291 params: &[SochValue],
292 ) -> SqlResult<ExecutionResult> {
293 let _dialect = SqlDialect::detect(sql);
295
296 let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
298
299 let max_placeholder = self.find_max_placeholder(&stmt);
301 if max_placeholder as usize > params.len() {
302 return Err(SqlError::InvalidArgument(format!(
303 "Query contains {} placeholders but only {} parameters provided",
304 max_placeholder,
305 params.len()
306 )));
307 }
308
309 self.execute_statement(&stmt, params)
311 }
312
313 pub fn execute_statement(
315 &mut self,
316 stmt: &Statement,
317 params: &[SochValue],
318 ) -> SqlResult<ExecutionResult> {
319 match stmt {
320 Statement::Select(select) => self.execute_select(select, params),
321 Statement::Insert(insert) => self.execute_insert(insert, params),
322 Statement::Update(update) => self.execute_update(update, params),
323 Statement::Delete(delete) => self.execute_delete(delete, params),
324 Statement::CreateTable(create) => self.execute_create_table(create),
325 Statement::DropTable(drop) => self.execute_drop_table(drop),
326 Statement::CreateIndex(create) => self.execute_create_index(create),
327 Statement::DropIndex(drop) => self.execute_drop_index(drop),
328 Statement::AlterTable(alter) => self.execute_alter_table(alter),
329 Statement::Begin(begin) => self.conn.begin(begin),
330 Statement::Commit => self.conn.commit(),
331 Statement::Rollback(savepoint) => self.conn.rollback(savepoint.as_deref()),
332 Statement::Savepoint(_name) => Err(SqlError::NotImplemented(
333 "SAVEPOINT not yet implemented".into(),
334 )),
335 Statement::Release(_name) => Err(SqlError::NotImplemented(
336 "RELEASE SAVEPOINT not yet implemented".into(),
337 )),
338 Statement::Explain(_stmt) => Err(SqlError::NotImplemented(
339 "EXPLAIN not yet implemented".into(),
340 )),
341 Statement::DefineScope(def) => {
342 self.scope_definitions.insert(
343 def.name.clone(),
344 ScopeDefinition {
345 name: def.name.clone(),
346 session_duration_secs: def.session_duration_secs,
347 signin: def.signin.clone(),
348 signup: def.signup.clone(),
349 },
350 );
351 Ok(ExecutionResult::Ok)
352 }
353 Statement::DefineTablePermissions(def) => {
354 let table_name = def.table.name().to_string();
355 self.table_permissions.insert(
356 table_name.clone(),
357 StoredTablePermissions {
358 table: table_name,
359 permissions: def.permissions.clone(),
360 },
361 );
362 Ok(ExecutionResult::Ok)
363 }
364 Statement::RemoveScope(name) => {
365 self.scope_definitions.remove(name);
366 Ok(ExecutionResult::Ok)
367 }
368 Statement::Relate(_) => Err(SqlError::NotImplemented(
369 "RELATE not yet implemented — graph execution engine required".into(),
370 )),
371 Statement::LiveSelect(_) => Err(SqlError::NotImplemented(
372 "LIVE SELECT not yet implemented — CDC subscription engine required".into(),
373 )),
374 Statement::DefineEvent(_) => Err(SqlError::NotImplemented(
375 "DEFINE EVENT not yet implemented — event trigger engine required".into(),
376 )),
377 }
378 }
379
380 fn execute_select(
381 &self,
382 select: &SelectStmt,
383 params: &[SochValue],
384 ) -> SqlResult<ExecutionResult> {
385 let from = select
387 .from
388 .as_ref()
389 .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
390
391 if from.tables.len() != 1 {
392 return Err(SqlError::NotImplemented(
393 "Multi-table queries (comma-separated) not yet supported".into(),
394 ));
395 }
396
397 let table_ref = &from.tables[0];
399 if self.contains_join(table_ref) {
400 return self.execute_join_select(select, table_ref, params);
401 }
402
403 let table_name = match table_ref {
405 TableRef::Table { name, .. } => name.name().to_string(),
406 TableRef::Subquery { .. } => {
407 return Err(SqlError::NotImplemented(
408 "Subqueries not yet supported".into(),
409 ));
410 }
411 TableRef::Function { .. } => {
412 return Err(SqlError::NotImplemented(
413 "Table functions not yet supported".into(),
414 ));
415 }
416 TableRef::Join { .. } => unreachable!("handled above"),
417 };
418
419 self.check_table_permission(&table_name, PermissionOp::Select)?;
421
422 let limit = self.extract_limit(&select.limit)?;
424 let offset = self.extract_limit(&select.offset)?;
425
426 if super::aggregate::is_aggregate_query(select) {
430 let input = self.conn.select(
431 &table_name,
432 &[],
433 select.where_clause.as_ref(),
434 &[],
435 None,
436 None,
437 params,
438 )?;
439 let rows = match input {
440 ExecutionResult::Rows { rows, .. } => rows,
441 _ => Vec::new(),
442 };
443 return super::aggregate::execute_aggregate(select, &rows, params, limit, offset);
444 }
445
446 let columns = self.extract_select_columns(&select.columns)?;
448
449 self.conn.select(
450 &table_name,
451 &columns,
452 select.where_clause.as_ref(),
453 &select.order_by,
454 limit,
455 offset,
456 params,
457 )
458 }
459
460 fn contains_join(&self, table_ref: &TableRef) -> bool {
462 matches!(table_ref, TableRef::Join { .. })
463 }
464
465 fn execute_join_select(
470 &self,
471 select: &SelectStmt,
472 table_ref: &TableRef,
473 params: &[SochValue],
474 ) -> SqlResult<ExecutionResult> {
475 let mut rows = self.resolve_table_ref(table_ref, params)?;
479
480 if let Some(ref expr) = select.where_clause {
482 rows.retain(|row| {
483 self.conn
484 .eval_join_predicate(expr, row, params)
485 .unwrap_or(false)
486 });
487 }
488
489 if super::aggregate::is_aggregate_query(select) {
491 let limit = self.extract_limit(&select.limit)?;
492 let offset = self.extract_limit(&select.offset)?;
493 return super::aggregate::execute_aggregate(select, &rows, params, limit, offset);
494 }
495
496 if !select.order_by.is_empty() {
498 rows.sort_by(|a, b| {
499 for item in &select.order_by {
500 let col = Self::extract_order_column(&item.expr);
501 let va = a.get(&col);
502 let vb = b.get(&col);
503 let cmp = Self::compare_optional_values(va, vb);
504 let cmp = if !item.asc { cmp.reverse() } else { cmp };
505 if cmp != std::cmp::Ordering::Equal {
506 return cmp;
507 }
508 }
509 std::cmp::Ordering::Equal
510 });
511 }
512
513 let offset = self.extract_limit(&select.offset)?;
515 if let Some(off) = offset {
516 rows = rows.into_iter().skip(off).collect();
517 }
518
519 let limit = self.extract_limit(&select.limit)?;
521 if let Some(lim) = limit {
522 rows.truncate(lim);
523 }
524
525 let select_columns = self.extract_select_columns(&select.columns)?;
527 let (result_columns, projected_rows) = self.project_join_rows(&select_columns, &rows)?;
528
529 Ok(ExecutionResult::Rows {
530 columns: result_columns,
531 rows: projected_rows,
532 })
533 }
534
535 fn resolve_table_ref(
540 &self,
541 table_ref: &TableRef,
542 params: &[SochValue],
543 ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
544 match table_ref {
545 TableRef::Table { name, alias } => {
546 let table_name = name.name().to_string();
547 let prefix = alias.as_deref().unwrap_or(&table_name);
548 let raw_rows = self.conn.scan_all(&table_name, &[])?;
549
550 let mut result = Vec::with_capacity(raw_rows.len());
552 for row in raw_rows {
553 let mut merged = HashMap::new();
554 for (k, v) in &row {
555 merged.insert(format!("{}.{}", prefix, k), v.clone());
556 merged.insert(k.clone(), v.clone());
559 }
560 result.push(merged);
561 }
562 Ok(result)
563 }
564 TableRef::Join {
565 left,
566 join_type,
567 right,
568 condition,
569 } => {
570 let left_rows = self.resolve_table_ref(left, params)?;
571 let right_rows = self.resolve_table_ref(right, params)?;
572 self.execute_join(
573 &left_rows,
574 &right_rows,
575 *join_type,
576 condition.as_ref(),
577 params,
578 )
579 }
580 TableRef::Subquery { .. } => Err(SqlError::NotImplemented(
581 "Subqueries in FROM not yet supported".into(),
582 )),
583 TableRef::Function { .. } => Err(SqlError::NotImplemented(
584 "Table functions not yet supported".into(),
585 )),
586 }
587 }
588
589 fn execute_join(
593 &self,
594 left_rows: &[HashMap<String, SochValue>],
595 right_rows: &[HashMap<String, SochValue>],
596 join_type: JoinType,
597 condition: Option<&JoinCondition>,
598 params: &[SochValue],
599 ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
600 let (on_expr, using_cols) = match condition {
602 Some(JoinCondition::On(expr)) => (Some(expr), None),
603 Some(JoinCondition::Using(cols)) => (None, Some(cols.as_slice())),
604 Some(JoinCondition::Natural) => {
605 return Err(SqlError::NotImplemented(
606 "NATURAL JOIN not yet supported".into(),
607 ));
608 }
609 None => (None, None), };
611
612 if let Some(expr) = on_expr {
614 if let Some((left_key, right_key)) = Self::extract_equi_join_keys(expr) {
615 return self.hash_join(
616 left_rows, right_rows, &left_key, &right_key, join_type, params,
617 );
618 }
619 }
620
621 let mut result = Vec::new();
623 let null_right: HashMap<String, SochValue> = Self::null_row(right_rows);
624 let null_left: HashMap<String, SochValue> = Self::null_row(left_rows);
625
626 let mut right_matched = vec![false; right_rows.len()];
627
628 for left in left_rows {
629 let mut found_match = false;
630
631 for (ri, right) in right_rows.iter().enumerate() {
632 let merged = Self::merge_rows(left, right);
633 let matches = match (on_expr, using_cols) {
634 (Some(expr), _) => self
635 .conn
636 .eval_join_predicate(expr, &merged, params)
637 .unwrap_or(false),
638 (_, Some(cols)) => Self::using_matches(left, right, cols),
639 (None, None) => true, };
641
642 if matches {
643 result.push(merged);
644 found_match = true;
645 right_matched[ri] = true;
646 }
647 }
648
649 if !found_match && matches!(join_type, JoinType::Left | JoinType::Full) {
651 result.push(Self::merge_rows(left, &null_right));
652 }
653 }
654
655 if matches!(join_type, JoinType::Right | JoinType::Full) {
657 for (ri, right) in right_rows.iter().enumerate() {
658 if !right_matched[ri] {
659 result.push(Self::merge_rows(&null_left, right));
660 }
661 }
662 }
663
664 Ok(result)
665 }
666
667 fn hash_join(
669 &self,
670 left_rows: &[HashMap<String, SochValue>],
671 right_rows: &[HashMap<String, SochValue>],
672 left_key: &str,
673 right_key: &str,
674 join_type: JoinType,
675 _params: &[SochValue],
676 ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
677 let mut hash_table: HashMap<String, Vec<usize>> = HashMap::new();
679 for (i, row) in right_rows.iter().enumerate() {
680 if let Some(val) = row.get(right_key) {
681 let key = Self::value_to_hash_key(val);
682 hash_table.entry(key).or_default().push(i);
683 }
684 }
685
686 let null_right = Self::null_row(right_rows);
687 let null_left = Self::null_row(left_rows);
688 let mut right_matched = vec![false; right_rows.len()];
689 let mut result = Vec::new();
690
691 for left in left_rows {
693 let mut found_match = false;
694 if let Some(val) = left.get(left_key) {
695 let key = Self::value_to_hash_key(val);
696 if let Some(indices) = hash_table.get(&key) {
697 for &ri in indices {
698 result.push(Self::merge_rows(left, &right_rows[ri]));
699 found_match = true;
700 right_matched[ri] = true;
701 }
702 }
703 }
704 if !found_match && matches!(join_type, JoinType::Left | JoinType::Full) {
705 result.push(Self::merge_rows(left, &null_right));
706 }
707 }
708
709 if matches!(join_type, JoinType::Right | JoinType::Full) {
710 for (ri, right) in right_rows.iter().enumerate() {
711 if !right_matched[ri] {
712 result.push(Self::merge_rows(&null_left, right));
713 }
714 }
715 }
716
717 Ok(result)
718 }
719
720 fn extract_equi_join_keys(expr: &Expr) -> Option<(String, String)> {
723 if let Expr::BinaryOp { left, op, right } = expr {
724 if *op == BinaryOperator::Eq {
725 if let (Expr::Column(l), Expr::Column(r)) = (left.as_ref(), right.as_ref()) {
726 let lk = if let Some(ref t) = l.table {
727 format!("{}.{}", t, l.column)
728 } else {
729 l.column.clone()
730 };
731 let rk = if let Some(ref t) = r.table {
732 format!("{}.{}", t, r.column)
733 } else {
734 r.column.clone()
735 };
736 return Some((lk, rk));
737 }
738 }
739 }
740 None
741 }
742
743 fn merge_rows(
745 left: &HashMap<String, SochValue>,
746 right: &HashMap<String, SochValue>,
747 ) -> HashMap<String, SochValue> {
748 let mut merged = left.clone();
749 for (k, v) in right {
750 if !merged.contains_key(k) || k.contains('.') {
753 merged.insert(k.clone(), v.clone());
754 }
755 }
756 merged
757 }
758
759 fn null_row(rows: &[HashMap<String, SochValue>]) -> HashMap<String, SochValue> {
761 if let Some(sample) = rows.first() {
762 sample
763 .keys()
764 .map(|k| (k.clone(), SochValue::Null))
765 .collect()
766 } else {
767 HashMap::new()
768 }
769 }
770
771 fn using_matches(
773 left: &HashMap<String, SochValue>,
774 right: &HashMap<String, SochValue>,
775 cols: &[String],
776 ) -> bool {
777 cols.iter().all(|col| {
778 let lv = left.get(col);
779 let rv = right.get(col);
780 match (lv, rv) {
781 (Some(l), Some(r)) => l == r,
782 _ => false,
783 }
784 })
785 }
786
787 fn value_to_hash_key(val: &SochValue) -> String {
789 format!("{:?}", val)
790 }
791
792 fn extract_order_column(expr: &Expr) -> String {
794 match expr {
795 Expr::Column(col) => {
796 if let Some(ref t) = col.table {
797 format!("{}.{}", t, col.column)
798 } else {
799 col.column.clone()
800 }
801 }
802 _ => String::new(),
803 }
804 }
805
806 fn compare_optional_values(a: Option<&SochValue>, b: Option<&SochValue>) -> std::cmp::Ordering {
808 match (a, b) {
809 (None, None) => std::cmp::Ordering::Equal,
810 (None, Some(_)) => std::cmp::Ordering::Less,
811 (Some(_), None) => std::cmp::Ordering::Greater,
812 (Some(va), Some(vb)) => Self::compare_values(va, vb),
813 }
814 }
815
816 fn compare_values(a: &SochValue, b: &SochValue) -> std::cmp::Ordering {
818 match (a, b) {
819 (SochValue::Int(a), SochValue::Int(b)) => a.cmp(b),
820 (SochValue::Float(a), SochValue::Float(b)) => {
821 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
822 }
823 (SochValue::Text(a), SochValue::Text(b)) => a.cmp(b),
824 (SochValue::Bool(a), SochValue::Bool(b)) => a.cmp(b),
825 (SochValue::Null, SochValue::Null) => std::cmp::Ordering::Equal,
826 (SochValue::Null, _) => std::cmp::Ordering::Less,
827 (_, SochValue::Null) => std::cmp::Ordering::Greater,
828 _ => std::cmp::Ordering::Equal,
829 }
830 }
831
832 fn project_join_rows(
834 &self,
835 select_columns: &[String],
836 rows: &[HashMap<String, SochValue>],
837 ) -> SqlResult<(Vec<String>, Vec<HashMap<String, SochValue>>)> {
838 if select_columns.is_empty() || select_columns.iter().any(|c| c == "*") {
840 let all_cols: Vec<String> = rows
841 .first()
842 .map(|r| {
843 let mut cols: Vec<String> =
845 r.keys().filter(|k| k.contains('.')).cloned().collect();
846 cols.sort();
847 if cols.is_empty() {
848 cols = r.keys().cloned().collect();
850 cols.sort();
851 }
852 cols
853 })
854 .unwrap_or_default();
855
856 let projected: Vec<HashMap<String, SochValue>> = rows
857 .iter()
858 .map(|row| {
859 all_cols
860 .iter()
861 .map(|c| {
862 let short = c.rsplit('.').next().unwrap_or(c);
863 (
864 short.to_string(),
865 row.get(c).cloned().unwrap_or(SochValue::Null),
866 )
867 })
868 .collect()
869 })
870 .collect();
871 let short_cols: Vec<String> = all_cols
872 .iter()
873 .map(|c| c.rsplit('.').next().unwrap_or(c).to_string())
874 .collect();
875 return Ok((short_cols, projected));
876 }
877
878 let mut result_rows = Vec::with_capacity(rows.len());
880 for row in rows {
881 let mut projected = HashMap::new();
882 for col in select_columns {
883 let val = row
885 .get(col)
886 .or_else(|| {
887 row.iter()
889 .find(|(k, _)| k.ends_with(&format!(".{}", col)) || k.as_str() == col)
890 .map(|(_, v)| v)
891 })
892 .cloned()
893 .unwrap_or(SochValue::Null);
894 projected.insert(col.clone(), val);
895 }
896 result_rows.push(projected);
897 }
898
899 Ok((select_columns.to_vec(), result_rows))
900 }
901
902 fn execute_insert(
903 &mut self,
904 insert: &InsertStmt,
905 params: &[SochValue],
906 ) -> SqlResult<ExecutionResult> {
907 let table_name = insert.table.name();
908
909 self.check_table_permission(table_name, PermissionOp::Create)?;
911
912 let rows = match &insert.source {
913 InsertSource::Values(values) => values,
914 InsertSource::Query(_) => {
915 return Err(SqlError::NotImplemented(
916 "INSERT ... SELECT not yet supported".into(),
917 ));
918 }
919 InsertSource::Default => {
920 return Err(SqlError::NotImplemented(
921 "INSERT DEFAULT VALUES not yet supported".into(),
922 ));
923 }
924 };
925
926 self.conn.insert(
927 table_name,
928 insert.columns.as_deref(),
929 rows,
930 insert.on_conflict.as_ref(),
931 params,
932 )
933 }
934
935 fn execute_update(
936 &mut self,
937 update: &UpdateStmt,
938 params: &[SochValue],
939 ) -> SqlResult<ExecutionResult> {
940 let table_name = update.table.name();
941
942 self.check_table_permission(table_name, PermissionOp::Update)?;
944
945 self.conn.update(
946 table_name,
947 &update.assignments,
948 update.where_clause.as_ref(),
949 params,
950 )
951 }
952
953 fn execute_delete(
954 &mut self,
955 delete: &DeleteStmt,
956 params: &[SochValue],
957 ) -> SqlResult<ExecutionResult> {
958 let table_name = delete.table.name();
959
960 self.check_table_permission(table_name, PermissionOp::Delete)?;
962
963 self.conn
964 .delete(table_name, delete.where_clause.as_ref(), params)
965 }
966
967 fn execute_create_table(&mut self, stmt: &CreateTableStmt) -> SqlResult<ExecutionResult> {
968 if stmt.if_not_exists {
970 let table_name = stmt.name.name();
971 if self.conn.table_exists(table_name)? {
972 return Ok(ExecutionResult::Ok);
973 }
974 }
975
976 self.conn.create_table(stmt)
977 }
978
979 fn execute_drop_table(&mut self, stmt: &DropTableStmt) -> SqlResult<ExecutionResult> {
980 if stmt.if_exists {
982 for name in &stmt.names {
983 if !self.conn.table_exists(name.name())? {
984 return Ok(ExecutionResult::Ok);
985 }
986 }
987 }
988
989 self.conn.drop_table(stmt)
990 }
991
992 fn execute_create_index(&mut self, stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
993 if stmt.if_not_exists {
995 if self.conn.index_exists(&stmt.name)? {
996 return Ok(ExecutionResult::Ok);
997 }
998 }
999
1000 self.conn.create_index(stmt)
1001 }
1002
1003 fn execute_drop_index(&mut self, stmt: &DropIndexStmt) -> SqlResult<ExecutionResult> {
1004 if stmt.if_exists {
1006 if !self.conn.index_exists(&stmt.name)? {
1007 return Ok(ExecutionResult::Ok);
1008 }
1009 }
1010
1011 self.conn.drop_index(stmt)
1012 }
1013
1014 fn execute_alter_table(&mut self, stmt: &AlterTableStmt) -> SqlResult<ExecutionResult> {
1015 self.conn.alter_table(stmt)
1016 }
1017
1018 fn extract_select_columns(&self, items: &[SelectItem]) -> SqlResult<Vec<String>> {
1020 let mut columns = Vec::new();
1021
1022 for item in items {
1023 match item {
1024 SelectItem::Wildcard => columns.push("*".to_string()),
1025 SelectItem::QualifiedWildcard(table) => columns.push(format!("{}.*", table)),
1026 SelectItem::Expr { expr, alias } => {
1027 let name = alias.clone().unwrap_or_else(|| match expr {
1028 Expr::Column(col) => col.column.clone(),
1029 Expr::Function(func) => format!("{}()", func.name.name()),
1030 _ => "?column?".to_string(),
1031 });
1032 columns.push(name);
1033 }
1034 }
1035 }
1036
1037 Ok(columns)
1038 }
1039
1040 fn extract_limit(&self, expr: &Option<Expr>) -> SqlResult<Option<usize>> {
1042 match expr {
1043 Some(Expr::Literal(Literal::Integer(n))) => Ok(Some(*n as usize)),
1044 Some(_) => Err(SqlError::InvalidArgument(
1045 "LIMIT/OFFSET must be an integer literal".into(),
1046 )),
1047 None => Ok(None),
1048 }
1049 }
1050
1051 fn find_max_placeholder(&self, stmt: &Statement) -> u32 {
1053 let mut visitor = PlaceholderVisitor::new();
1054 visitor.visit_statement(stmt);
1055 visitor.max_placeholder
1056 }
1057}
1058
/// Walks a parsed statement and records the highest placeholder index it contains.
struct PlaceholderVisitor {
    /// Largest placeholder index seen so far; 0 when none has been seen.
    max_placeholder: u32,
}
1063
1064impl PlaceholderVisitor {
1065 fn new() -> Self {
1066 Self { max_placeholder: 0 }
1067 }
1068
1069 fn visit_statement(&mut self, stmt: &Statement) {
1070 match stmt {
1071 Statement::Select(s) => self.visit_select(s),
1072 Statement::Insert(i) => self.visit_insert(i),
1073 Statement::Update(u) => self.visit_update(u),
1074 Statement::Delete(d) => self.visit_delete(d),
1075 _ => {}
1076 }
1077 }
1078
1079 fn visit_select(&mut self, select: &SelectStmt) {
1080 for item in &select.columns {
1081 if let SelectItem::Expr { expr, .. } = item {
1082 self.visit_expr(expr);
1083 }
1084 }
1085 if let Some(where_clause) = &select.where_clause {
1086 self.visit_expr(where_clause);
1087 }
1088 if let Some(having) = &select.having {
1089 self.visit_expr(having);
1090 }
1091 for order in &select.order_by {
1092 self.visit_expr(&order.expr);
1093 }
1094 if let Some(limit) = &select.limit {
1095 self.visit_expr(limit);
1096 }
1097 if let Some(offset) = &select.offset {
1098 self.visit_expr(offset);
1099 }
1100 }
1101
1102 fn visit_insert(&mut self, insert: &InsertStmt) {
1103 if let InsertSource::Values(rows) = &insert.source {
1104 for row in rows {
1105 for expr in row {
1106 self.visit_expr(expr);
1107 }
1108 }
1109 }
1110 }
1111
1112 fn visit_update(&mut self, update: &UpdateStmt) {
1113 for assign in &update.assignments {
1114 self.visit_expr(&assign.value);
1115 }
1116 if let Some(where_clause) = &update.where_clause {
1117 self.visit_expr(where_clause);
1118 }
1119 }
1120
1121 fn visit_delete(&mut self, delete: &DeleteStmt) {
1122 if let Some(where_clause) = &delete.where_clause {
1123 self.visit_expr(where_clause);
1124 }
1125 }
1126
1127 fn visit_expr(&mut self, expr: &Expr) {
1128 match expr {
1129 Expr::Placeholder(n) => {
1130 self.max_placeholder = self.max_placeholder.max(*n);
1131 }
1132 Expr::BinaryOp { left, right, .. } => {
1133 self.visit_expr(left);
1134 self.visit_expr(right);
1135 }
1136 Expr::UnaryOp { expr, .. } => {
1137 self.visit_expr(expr);
1138 }
1139 Expr::Function(func) => {
1140 for arg in &func.args {
1141 self.visit_expr(arg);
1142 }
1143 }
1144 Expr::Case {
1145 operand,
1146 conditions,
1147 else_result,
1148 } => {
1149 if let Some(op) = operand {
1150 self.visit_expr(op);
1151 }
1152 for (when, then) in conditions {
1153 self.visit_expr(when);
1154 self.visit_expr(then);
1155 }
1156 if let Some(else_expr) = else_result {
1157 self.visit_expr(else_expr);
1158 }
1159 }
1160 Expr::InList { expr, list, .. } => {
1161 self.visit_expr(expr);
1162 for item in list {
1163 self.visit_expr(item);
1164 }
1165 }
1166 Expr::Between {
1167 expr, low, high, ..
1168 } => {
1169 self.visit_expr(expr);
1170 self.visit_expr(low);
1171 self.visit_expr(high);
1172 }
1173 Expr::Cast { expr, .. } => {
1174 self.visit_expr(expr);
1175 }
1176 _ => {}
1177 }
1178 }
1179}
1180
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and returns the highest placeholder index the visitor finds.
    fn max_placeholder_in(sql: &str) -> u32 {
        let stmt = Parser::parse(sql).unwrap();
        let mut visitor = PlaceholderVisitor::new();
        visitor.visit_statement(&stmt);
        visitor.max_placeholder
    }

    /// A bridge over a connection that accepts everything and stores nothing.
    fn make_mock_bridge() -> SqlBridge<MockPermConn> {
        SqlBridge::new(MockPermConn)
    }

    /// True when `op` on `table` passes the bridge's permission check.
    fn permits(bridge: &SqlBridge<MockPermConn>, table: &str, op: PermissionOp) -> bool {
        bridge.check_table_permission(table, op).is_ok()
    }

    #[test]
    fn test_placeholder_visitor() {
        assert_eq!(
            max_placeholder_in("SELECT * FROM users WHERE id = $1 AND name = $2"),
            2
        );
    }

    #[test]
    fn test_question_mark_placeholders() {
        assert_eq!(
            max_placeholder_in("SELECT * FROM users WHERE id = ? AND name = ?"),
            2
        );
    }

    #[test]
    fn test_dialect_detection() {
        let cases = [
            ("SELECT * FROM users", SqlDialect::Standard),
            ("INSERT IGNORE INTO users VALUES (1)", SqlDialect::MySQL),
            ("INSERT OR IGNORE INTO users VALUES (1)", SqlDialect::SQLite),
        ];
        for (sql, expected) in cases {
            assert_eq!(SqlDialect::detect(sql), expected, "{}", sql);
        }
    }

    #[test]
    fn test_define_scope_stores_definition() {
        let mut bridge = make_mock_bridge();
        bridge
            .execute("DEFINE SCOPE user_scope SESSION 86400")
            .unwrap();
        let scope = bridge.get_scope("user_scope").expect("Scope not stored");
        assert_eq!(scope.name, "user_scope");
        assert_eq!(scope.session_duration_secs, Some(86400));
    }

    #[test]
    fn test_remove_scope_deletes_definition() {
        let mut bridge = make_mock_bridge();
        bridge
            .execute("DEFINE SCOPE temp_scope SESSION 3600")
            .unwrap();
        assert!(bridge.get_scope("temp_scope").is_some());

        bridge.execute("REMOVE SCOPE temp_scope").unwrap();
        assert!(bridge.get_scope("temp_scope").is_none());
    }

    #[test]
    fn test_define_table_permissions_stores_rules() {
        let mut bridge = make_mock_bridge();
        let outcome = bridge
            .execute("DEFINE TABLE posts PERMISSIONS FOR select WHERE true FOR delete WHERE false");
        assert!(outcome.is_ok());

        let stored = bridge
            .get_table_permissions("posts")
            .expect("permissions not stored");
        assert_eq!(stored.permissions.len(), 2);
    }

    #[test]
    fn test_table_permission_check_allows_matching_true() {
        let mut bridge = make_mock_bridge();
        bridge.execute(
            "DEFINE TABLE docs PERMISSIONS FOR select WHERE true FOR insert WHERE true FOR update WHERE true FOR delete WHERE true"
        ).unwrap();

        assert!(permits(&bridge, "docs", PermissionOp::Select));
        assert!(permits(&bridge, "docs", PermissionOp::Create));
        assert!(permits(&bridge, "docs", PermissionOp::Update));
        assert!(permits(&bridge, "docs", PermissionOp::Delete));
    }

    #[test]
    fn test_table_permission_check_denies_matching_false() {
        let mut bridge = make_mock_bridge();
        bridge
            .execute(
                "DEFINE TABLE secrets PERMISSIONS FOR select WHERE false FOR delete WHERE false",
            )
            .unwrap();

        let denial = bridge
            .check_table_permission("secrets", PermissionOp::Select)
            .expect_err("select on secrets should be denied");
        assert!(denial.to_string().contains("Permission denied"));
    }

    #[test]
    fn test_table_permission_denies_undefined_op_when_rules_exist() {
        let mut bridge = make_mock_bridge();
        // Only SELECT is covered, so every other operation must be rejected.
        bridge
            .execute("DEFINE TABLE restricted PERMISSIONS FOR select WHERE true")
            .unwrap();

        assert!(permits(&bridge, "restricted", PermissionOp::Select));
        assert!(!permits(&bridge, "restricted", PermissionOp::Update));
    }

    #[test]
    fn test_no_permissions_allows_all() {
        let bridge = make_mock_bridge();
        // No DEFINE TABLE ... PERMISSIONS was issued, so the table is unrestricted.
        assert!(permits(&bridge, "any_table", PermissionOp::Select));
        assert!(permits(&bridge, "any_table", PermissionOp::Delete));
    }

    /// Connection stub: reads return no rows, writes report 0 rows, and schema calls succeed.
    ///
    /// Every table is reported as existing, no index is, and join predicates always match.
    struct MockPermConn;

    impl SqlConnection for MockPermConn {
        fn select(
            &self,
            _table: &str,
            _columns: &[String],
            _where_clause: Option<&Expr>,
            _order_by: &[OrderByItem],
            _limit: Option<usize>,
            _offset: Option<usize>,
            _params: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Rows {
                columns: Vec::new(),
                rows: Vec::new(),
            })
        }

        fn insert(
            &mut self,
            _table: &str,
            _columns: Option<&[String]>,
            _rows: &[Vec<Expr>],
            _on_conflict: Option<&OnConflict>,
            _params: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }

        fn update(
            &mut self,
            _table: &str,
            _assignments: &[Assignment],
            _where_clause: Option<&Expr>,
            _params: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }

        fn delete(
            &mut self,
            _table: &str,
            _where_clause: Option<&Expr>,
            _params: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }

        fn create_table(&mut self, _stmt: &CreateTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn drop_table(&mut self, _stmt: &DropTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn create_index(&mut self, _stmt: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn drop_index(&mut self, _stmt: &DropIndexStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn alter_table(&mut self, _stmt: &AlterTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }

        fn begin(&mut self, _stmt: &BeginStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }

        fn commit(&mut self) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }

        fn rollback(&mut self, _savepoint: Option<&str>) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }

        fn table_exists(&self, _table: &str) -> SqlResult<bool> {
            Ok(true)
        }

        fn index_exists(&self, _index: &str) -> SqlResult<bool> {
            Ok(false)
        }

        fn scan_all(
            &self,
            _table: &str,
            _columns: &[String],
        ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
            Ok(Vec::new())
        }

        fn eval_join_predicate(
            &self,
            _expr: &Expr,
            _row: &HashMap<String, SochValue>,
            _params: &[SochValue],
        ) -> Option<bool> {
            Some(true)
        }
    }
}