1use crate::ast::{
11 Action, ConflictAction, Expr, MergeAction, MergeSource, Qail, TableConstraint, Value,
12};
13use std::fmt;
14
/// Error produced when a [`Qail`] AST fails validation in [`validate_ast`].
///
/// `field` is a dotted path to the offending part of the AST (for example
/// `cage.condition.value.column`), `value` is a preview of the rejected input,
/// and `reason` describes the rule that was violated.
#[derive(Debug, Clone)]
pub struct SanitizeError {
    /// Path of the AST element that failed validation.
    pub field: String,
    /// Preview of the rejected input; the checks in this module usually keep
    /// only its first 40 characters, and some leave it empty.
    pub value: String,
    /// Human-readable description of the violated rule.
    pub reason: String,
}

impl fmt::Display for SanitizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            field,
            value,
            reason,
        } = self;
        write!(f, "AST validation failed: {field} '{value}' — {reason}")
    }
}

impl std::error::Error for SanitizeError {}
34
/// Maximum byte length of a single identifier part (presumably mirrors
/// PostgreSQL's default `NAMEDATALEN - 1` limit).
const MAX_IDENT_LEN: usize = 63;
/// Maximum byte length accepted for the raw text of a `Value::Function`.
const MAX_RAW_FUNCTION_LEN: usize = 1024;

/// Reports whether `s` is a plain, optionally dot-qualified identifier.
///
/// Every `.`-separated part must be 1..=[`MAX_IDENT_LEN`] bytes long and made
/// only of ASCII letters, digits and `_`. No quoting or whitespace is
/// accepted. An empty string splits into a single empty part and is
/// therefore rejected, as are leading, trailing or doubled dots.
fn is_safe_identifier(s: &str) -> bool {
    let part_is_safe = |part: &str| {
        (1..=MAX_IDENT_LEN).contains(&part.len())
            && part
                .bytes()
                .all(|b| matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_'))
    };
    s.split('.').all(part_is_safe)
}
50
51fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
53 if is_safe_identifier(value) {
54 Ok(())
55 } else {
56 Err(SanitizeError {
57 field: field.to_string(),
58 value: value.chars().take(40).collect(),
59 reason: "identifier parts must match [a-zA-Z0-9_] and be ≤63 chars".to_string(),
60 })
61 }
62}
63
64fn check_named_param(field: &str, value: &str) -> Result<(), SanitizeError> {
65 let mut chars = value.chars();
66 let Some(first) = chars.next() else {
67 return Err(SanitizeError {
68 field: field.to_string(),
69 value: String::new(),
70 reason: "named parameter cannot be empty".to_string(),
71 });
72 };
73
74 if (first.is_ascii_alphabetic() || first == '_')
75 && chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_')
76 {
77 Ok(())
78 } else {
79 Err(SanitizeError {
80 field: field.to_string(),
81 value: value.chars().take(40).collect(),
82 reason: "named parameters must match [a-zA-Z_][a-zA-Z0-9_]*".to_string(),
83 })
84 }
85}
86
87fn check_raw_function_value(field: &str, value: &str) -> Result<(), SanitizeError> {
88 if value.len() <= MAX_RAW_FUNCTION_LEN
89 && !value.contains('\0')
90 && !value.contains(';')
91 && !value.contains("--")
92 && !value.contains("/*")
93 && !value.contains("*/")
94 {
95 Ok(())
96 } else {
97 Err(SanitizeError {
98 field: field.to_string(),
99 value: value.chars().take(40).collect(),
100 reason: "raw function values cannot contain NUL, statement separators, or comments"
101 .to_string(),
102 })
103 }
104}
105
/// Conservatively reports whether a SQL expression fragment could break out of
/// the statement it is spliced into.
///
/// The scan tracks single-quoted string literals and double-quoted
/// identifiers, including their doubled-quote escapes (`''` / `""`). It
/// returns `true` when, outside of those, it finds:
///
/// * a statement separator `;`;
/// * the start of a comment (`--` or `/*`);
/// * a `$`. Dollar-quoted strings (`$$...$$`, `$tag$...$tag$`) are not
///   tracked, and a quote character inside one would otherwise desynchronise
///   this scan from the database lexer. For example,
///   `$$'$$); DROP TABLE x; SELECT ($$'$$` looks fully quoted to a naive scan.
///
/// It also returns `true` for:
///
/// * a NUL byte anywhere;
/// * a backslash directly followed by `'` inside a single-quoted literal
///   (after `\\` pairs are consumed). That quote ends a standard string, but
///   is escaped inside an `E'...'` string or when `standard_conforming_strings`
///   is off, so where the literal ends is ambiguous. For example,
///   `E'\''); DROP TABLE t; SELECT (E'\''` hides a `;` from a scan that only
///   understands `''`;
/// * a quoted section still open at the end of the input, which could swallow
///   SQL emitted after the fragment.
fn contains_unquoted_statement_delimiter(value: &str) -> bool {
    let bytes = value.as_bytes();
    let mut i = 0;
    let mut in_single = false;
    let mut in_double = false;

    while i < bytes.len() {
        let b = bytes[i];
        let next = bytes.get(i + 1).copied();

        if b == 0 {
            return true;
        }

        if in_single {
            match (b, next) {
                // `''` is an escaped quote; the literal continues.
                (b'\'', Some(b'\'')) => {
                    i += 2;
                    continue;
                }
                (b'\'', _) => in_single = false,
                // A `\\` pair means the same thing in standard and escape
                // strings, so skipping it as a unit keeps both views in sync.
                (b'\\', Some(b'\\')) => {
                    i += 2;
                    continue;
                }
                // `\'` closes a standard string but is an escaped quote in an
                // `E'...'` string: refuse to guess which one the server sees.
                (b'\\', Some(b'\'')) => return true,
                _ => {}
            }
            i += 1;
            continue;
        }

        if in_double {
            if b == b'"' {
                // `""` is an escaped quote inside a quoted identifier.
                if next == Some(b'"') {
                    i += 2;
                    continue;
                }
                in_double = false;
            }
            i += 1;
            continue;
        }

        match (b, next) {
            (b'\'', _) => in_single = true,
            (b'"', _) => in_double = true,
            (b';', _) | (b'$', _) => return true,
            (b'-', Some(b'-')) | (b'/', Some(b'*')) => return true,
            _ => {}
        }
        i += 1;
    }

    // An unterminated quote would swallow whatever SQL follows the fragment.
    in_single || in_double
}
155
156fn check_sql_expr_fragment(field: &str, value: &str) -> Result<(), SanitizeError> {
157 let expr = value.trim();
158 if !expr.is_empty() && !contains_unquoted_statement_delimiter(expr) {
159 Ok(())
160 } else {
161 Err(SanitizeError {
162 field: field.to_string(),
163 value: value.chars().take(40).collect(),
164 reason: "SQL expression fragments cannot be empty or contain unquoted NUL, statement separators, or comments".to_string(),
165 })
166 }
167}
168
169fn table_ref_error(field: &str, value: &str) -> SanitizeError {
170 SanitizeError {
171 field: field.to_string(),
172 value: value.chars().take(40).collect(),
173 reason: "table references must be identifier or identifier [AS] alias".to_string(),
174 }
175}
176
177fn check_table_ref(field: &str, value: &str) -> Result<(), SanitizeError> {
178 let parts = value.split_whitespace().collect::<Vec<_>>();
179 match parts.as_slice() {
180 [table] => check_ident(field, table),
181 [table, alias] => {
182 check_ident(field, table)?;
183 check_ident(field, alias)
184 }
185 [table, as_keyword, alias] if as_keyword.eq_ignore_ascii_case("as") => {
186 check_ident(field, table)?;
187 check_ident(field, alias)
188 }
189 _ => Err(table_ref_error(field, value)),
190 }
191}
192
193fn action_allows_table_alias(action: Action) -> bool {
194 matches!(
195 action,
196 Action::Get
197 | Action::Cnt
198 | Action::Set
199 | Action::Del
200 | Action::Export
201 | Action::Explain
202 | Action::ExplainAnalyze
203 | Action::Over
204 )
205}
206
207fn check_fk_action(field: &str, value: &str) -> Result<(), SanitizeError> {
208 let normalized = value.trim().to_ascii_lowercase().replace('_', " ");
209 if matches!(
210 normalized.as_str(),
211 "cascade" | "restrict" | "no action" | "set null" | "set default"
212 ) {
213 Ok(())
214 } else {
215 Err(SanitizeError {
216 field: field.to_string(),
217 value: value.chars().take(40).collect(),
218 reason:
219 "foreign key action must be cascade, restrict, no_action, set_null, or set_default"
220 .to_string(),
221 })
222 }
223}
224
225fn check_fk_deferrable(field: &str, value: &str) -> Result<(), SanitizeError> {
226 let normalized = value.trim().to_ascii_lowercase().replace('_', " ");
227 if matches!(
228 normalized.as_str(),
229 "deferrable"
230 | "initially deferred"
231 | "initially immediate"
232 | "deferrable initially deferred"
233 | "deferrable initially immediate"
234 ) {
235 Ok(())
236 } else {
237 Err(SanitizeError {
238 field: field.to_string(),
239 value: value.chars().take(40).collect(),
240 reason: "foreign key deferrable clause must be deferrable, initially_deferred, or initially_immediate".to_string(),
241 })
242 }
243}
244
245fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
250 match expr {
251 Expr::Star => Ok(()),
252 Expr::Named(name) => check_ident(field, name),
253 Expr::Aliased { name, alias } => {
254 check_ident(field, name)?;
255 check_ident(&format!("{field}.alias"), alias)
256 }
257 Expr::Aggregate {
258 col, alias, filter, ..
259 } => {
260 if col != "*" {
261 check_ident(field, col)?;
262 }
263 if let Some(a) = alias {
264 check_ident(&format!("{field}.alias"), a)?;
265 }
266 if let Some(conditions) = filter {
267 for cond in conditions {
268 check_expr(&format!("{field}.filter"), &cond.left)?;
269 check_value(&format!("{field}.filter"), &cond.value)?;
270 }
271 }
272 Ok(())
273 }
274 Expr::FunctionCall { name, args, alias } => {
275 check_ident(field, name)?;
276 if let Some(a) = alias {
277 check_ident(&format!("{field}.alias"), a)?;
278 }
279 for arg in args {
280 check_expr(&format!("{field}.arg"), arg)?;
281 }
282 Ok(())
283 }
284 Expr::Cast {
285 expr,
286 target_type,
287 alias,
288 } => {
289 check_expr(field, expr)?;
290 check_ident(&format!("{field}.cast_type"), target_type)?;
291 if let Some(a) = alias {
292 check_ident(&format!("{field}.alias"), a)?;
293 }
294 Ok(())
295 }
296 Expr::Binary {
297 left, right, alias, ..
298 } => {
299 check_expr(field, left)?;
300 check_expr(field, right)?;
301 if let Some(a) = alias {
302 check_ident(&format!("{field}.alias"), a)?;
303 }
304 Ok(())
305 }
306 Expr::Literal(_) => Ok(()),
307 Expr::JsonAccess {
308 column,
309 alias,
310 path_segments,
311 ..
312 } => {
313 check_ident(field, column)?;
314 for (key, _) in path_segments {
315 if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
317 return Err(SanitizeError {
318 field: format!("{field}.json_path"),
319 value: key.chars().take(40).collect(),
320 reason: "JSON path key must be a safe identifier or integer".to_string(),
321 });
322 }
323 }
324 if let Some(a) = alias {
325 check_ident(&format!("{field}.alias"), a)?;
326 }
327 Ok(())
328 }
329 Expr::Subquery { query, alias } => {
330 validate_ast(query)?;
331 if let Some(a) = alias {
332 check_ident(&format!("{field}.alias"), a)?;
333 }
334 Ok(())
335 }
336 Expr::Exists { query, alias, .. } => {
337 validate_ast(query)?;
338 if let Some(a) = alias {
339 check_ident(&format!("{field}.alias"), a)?;
340 }
341 Ok(())
342 }
343 Expr::Window {
345 name,
346 func,
347 partition,
348 params,
349 order,
350 ..
351 } => {
352 if !name.is_empty() {
353 check_ident(&format!("{field}.window_alias"), name)?;
354 }
355 check_ident(&format!("{field}.window_func"), func)?;
356 for p in partition {
357 check_ident(&format!("{field}.partition"), p)?;
358 }
359 for p in params {
360 check_expr(&format!("{field}.window_param"), p)?;
361 }
362 for cage in order {
363 for cond in &cage.conditions {
364 check_expr(&format!("{field}.window_order"), &cond.left)?;
365 check_value(&format!("{field}.window_order"), &cond.value)?;
366 }
367 }
368 Ok(())
369 }
370 Expr::Case {
371 when_clauses,
372 else_value,
373 alias,
374 } => {
375 for (cond, val) in when_clauses {
376 check_expr(&format!("{field}.case_when"), &cond.left)?;
377 check_value(&format!("{field}.case_when"), &cond.value)?;
378 check_expr(&format!("{field}.case_then"), val)?;
379 }
380 if let Some(e) = else_value {
381 check_expr(&format!("{field}.case_else"), e)?;
382 }
383 if let Some(a) = alias {
384 check_ident(&format!("{field}.alias"), a)?;
385 }
386 Ok(())
387 }
388 Expr::SpecialFunction { args, alias, name } => {
389 check_ident(&format!("{field}.special_func"), name)?;
390 for (_, arg) in args {
391 check_expr(&format!("{field}.special_func_arg"), arg)?;
392 }
393 if let Some(a) = alias {
394 check_ident(&format!("{field}.alias"), a)?;
395 }
396 Ok(())
397 }
398 Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
399 for elem in elements {
400 check_expr(&format!("{field}.element"), elem)?;
401 }
402 if let Some(a) = alias {
403 check_ident(&format!("{field}.alias"), a)?;
404 }
405 Ok(())
406 }
407 Expr::Subscript { expr, index, alias } => {
408 check_expr(&format!("{field}.subscript_expr"), expr)?;
409 check_expr(&format!("{field}.subscript_index"), index)?;
410 if let Some(a) = alias {
411 check_ident(&format!("{field}.alias"), a)?;
412 }
413 Ok(())
414 }
415 Expr::Collate {
416 expr,
417 collation,
418 alias,
419 } => {
420 check_expr(&format!("{field}.collate_expr"), expr)?;
421 check_ident(&format!("{field}.collation"), collation)?;
422 if let Some(a) = alias {
423 check_ident(&format!("{field}.alias"), a)?;
424 }
425 Ok(())
426 }
427 Expr::FieldAccess {
428 expr,
429 field: f,
430 alias,
431 } => {
432 check_expr(&format!("{field}.field_access_expr"), expr)?;
433 check_ident(&format!("{field}.field"), f)?;
434 if let Some(a) = alias {
435 check_ident(&format!("{field}.alias"), a)?;
436 }
437 Ok(())
438 }
439 Expr::Def { name, .. } => check_ident(field, name),
440 Expr::Mod { col, .. } => check_expr(field, col),
441 }
442}
443
444fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
446 match value {
447 Value::Column(column) => check_ident(&format!("{field}.column"), column),
448 Value::NamedParam(name) => check_named_param(&format!("{field}.named_param"), name),
449 Value::Function(function) => {
450 check_raw_function_value(&format!("{field}.function"), function)
451 }
452 Value::Subquery(q) => validate_ast(q),
453 Value::Array(vals) => {
454 for v in vals {
455 check_value(field, v)?;
456 }
457 Ok(())
458 }
459 Value::Expr(expr) => check_expr(field, expr),
460 _ => Ok(()),
461 }
462}
463
464pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
473 match cmd.action {
475 Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
476 return Err(SanitizeError {
477 field: "action".to_string(),
478 value: format!("{:?}", cmd.action),
479 reason: "procedural/session actions are not allowed via binary AST".to_string(),
480 });
481 }
482 _ => {}
483 }
484
485 if !cmd.table.is_empty() {
487 if action_allows_table_alias(cmd.action) {
488 check_table_ref("table", &cmd.table)?;
489 } else {
490 check_ident("table", &cmd.table)?;
491 }
492 }
493
494 for (i, col) in cmd.columns.iter().enumerate() {
496 check_expr(&format!("columns[{i}]"), col)?;
497 }
498
499 for (i, constraint) in cmd.table_constraints.iter().enumerate() {
501 match constraint {
502 TableConstraint::Unique(cols) | TableConstraint::PrimaryKey(cols) => {
503 for col in cols {
504 check_ident(&format!("table_constraints[{i}].column"), col)?;
505 }
506 }
507 TableConstraint::ForeignKey {
508 name,
509 columns,
510 ref_table,
511 ref_columns,
512 on_delete,
513 on_update,
514 deferrable,
515 } => {
516 if let Some(name) = name {
517 check_ident(&format!("table_constraints[{i}].name"), name)?;
518 }
519 for col in columns {
520 check_ident(&format!("table_constraints[{i}].column"), col)?;
521 }
522 check_ident(&format!("table_constraints[{i}].ref_table"), ref_table)?;
523 for col in ref_columns {
524 check_ident(&format!("table_constraints[{i}].ref_column"), col)?;
525 }
526 if let Some(action) = on_delete {
527 check_fk_action(&format!("table_constraints[{i}].on_delete"), action)?;
528 }
529 if let Some(action) = on_update {
530 check_fk_action(&format!("table_constraints[{i}].on_update"), action)?;
531 }
532 if let Some(clause) = deferrable {
533 check_fk_deferrable(&format!("table_constraints[{i}].deferrable"), clause)?;
534 }
535 }
536 }
537 }
538
539 for (i, join) in cmd.joins.iter().enumerate() {
541 check_table_ref(&format!("joins[{i}].table"), &join.table)?;
542 if let Some(ref conditions) = join.on {
543 for cond in conditions {
544 check_expr(&format!("joins[{i}].on"), &cond.left)?;
545 check_value(&format!("joins[{i}].on"), &cond.value)?;
546 }
547 }
548 }
549
550 for cage in &cmd.cages {
552 for cond in &cage.conditions {
553 check_expr("cage.condition.left", &cond.left)?;
554 check_value("cage.condition.value", &cond.value)?;
555 }
556 }
557
558 for cte in &cmd.ctes {
560 check_ident("cte.name", &cte.name)?;
561 for col in &cte.columns {
562 check_ident("cte.column", col)?;
563 }
564 validate_ast(&cte.base_query)?;
565 if let Some(ref rq) = cte.recursive_query {
566 validate_ast(rq)?;
567 }
568 }
569
570 for expr in &cmd.distinct_on {
572 check_expr("distinct_on", expr)?;
573 }
574
575 if let Some(ref cols) = cmd.returning {
577 for col in cols {
578 check_expr("returning", col)?;
579 }
580 }
581
582 if let Some(ref oc) = cmd.on_conflict {
584 for col in &oc.columns {
585 check_ident("on_conflict.column", col)?;
586 }
587 if let ConflictAction::DoUpdate { assignments } = &oc.action {
588 for (col, expr) in assignments {
589 check_ident("on_conflict.assignment.column", col)?;
590 check_expr("on_conflict.assignment.expr", expr)?;
591 }
592 }
593 }
594
595 if let Some(ref merge) = cmd.merge {
597 if let Some(alias) = &merge.target_alias {
598 check_ident("merge.target_alias", alias)?;
599 }
600 match &merge.source {
601 MergeSource::Table { name, alias } => {
602 if let Some(alias) = alias {
603 check_ident("merge.source.table", name)?;
604 check_ident("merge.source.alias", alias)?;
605 } else {
606 check_table_ref("merge.source.table", name)?;
607 }
608 }
609 MergeSource::Query { query, alias } => {
610 validate_ast(query)?;
611 if let Some(alias) = alias {
612 check_ident("merge.source.alias", alias)?;
613 }
614 }
615 }
616 for cond in &merge.on {
617 check_expr("merge.on.left", &cond.left)?;
618 check_value("merge.on.value", &cond.value)?;
619 }
620 for clause in &merge.clauses {
621 for cond in &clause.condition {
622 check_expr("merge.clause.condition.left", &cond.left)?;
623 check_value("merge.clause.condition.value", &cond.value)?;
624 }
625 match &clause.action {
626 MergeAction::Update { assignments } => {
627 for (col, expr) in assignments {
628 check_ident("merge.update.column", col)?;
629 check_expr("merge.update.expr", expr)?;
630 }
631 }
632 MergeAction::Insert { columns, values } => {
633 for col in columns {
634 check_ident("merge.insert.column", col)?;
635 }
636 for expr in values {
637 check_expr("merge.insert.expr", expr)?;
638 }
639 }
640 MergeAction::Delete | MergeAction::DoNothing => {}
641 }
642 }
643 }
644
645 for t in &cmd.from_tables {
647 check_table_ref("from_tables", t)?;
648 }
649 for t in &cmd.using_tables {
650 check_table_ref("using_tables", t)?;
651 }
652
653 for (_, sub) in &cmd.set_ops {
655 validate_ast(sub)?;
656 }
657
658 if let Some(ref sq) = cmd.source_query {
660 validate_ast(sq)?;
661 }
662
663 for cond in &cmd.having {
665 check_expr("having", &cond.left)?;
666 check_value("having", &cond.value)?;
667 }
668
669 match cmd.action {
671 Action::AlterAddConstraint | Action::AlterDropConstraint => {
672 let Some(ref name) = cmd.channel else {
673 return Err(SanitizeError {
674 field: "channel".to_string(),
675 value: String::new(),
676 reason: "constraint actions require a constraint name".to_string(),
677 });
678 };
679 check_ident("channel", name)?;
680
681 if matches!(cmd.action, Action::AlterAddConstraint) {
682 let Some(ref expr) = cmd.payload else {
683 return Err(SanitizeError {
684 field: "payload".to_string(),
685 value: String::new(),
686 reason: "add constraint requires a check expression".to_string(),
687 });
688 };
689 check_sql_expr_fragment("payload", expr)?;
690 }
691 }
692 _ => {}
693 }
694
695 if let Some(ref ch) = cmd.channel {
696 check_ident("channel", ch)?;
697 }
698
699 Ok(())
700}
701
// Regression tests for `validate_ast`. Each test either builds a command that
// must pass, or injects unsafe text into one part of the AST and asserts the
// `field` path reported by the resulting `SanitizeError`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Operator, Qail};

    #[test]
    fn valid_simple_query_passes() {
        let cmd = Qail::get("users").columns(["id", "name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn sql_injection_in_table_rejected() {
        let cmd = Qail::get("users; DROP TABLE users; --");
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "table");
    }

    // Procedural actions are rejected before any other field is inspected.
    #[test]
    fn call_action_rejected() {
        let cmd = Qail {
            action: Action::Call,
            table: "my_proc()".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn do_action_rejected() {
        let cmd = Qail {
            action: Action::Do,
            table: "plpgsql".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn valid_qualified_name_passes() {
        let cmd = Qail::get("public.users").columns(["users.id", "users.name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    // The length limit applies per dot-separated part, not to the whole name.
    #[test]
    fn valid_long_qualified_identifier_parts_pass() {
        let schema = "s".repeat(MAX_IDENT_LEN);
        let table = "t".repeat(MAX_IDENT_LEN);
        let cmd = Qail::get(format!("{schema}.{table}")).columns(["id"]);

        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn empty_qualified_identifier_part_is_rejected() {
        let err = validate_ast(&Qail::get("public..users")).unwrap_err();

        assert_eq!(err.field, "table");
    }

    // Get/Set/Del accept `table alias` and `table AS alias` forms.
    #[test]
    fn query_and_mutation_table_aliases_pass_sanitizer() {
        assert!(validate_ast(&Qail::get("public.users u")).is_ok());
        assert!(validate_ast(&Qail::set("public.users AS u").set_value("active", true)).is_ok());
        assert!(validate_ast(&Qail::del("public.users u")).is_ok());
    }

    // Actions outside `action_allows_table_alias` (here `make`) need a bare identifier.
    #[test]
    fn ddl_table_alias_shape_is_rejected() {
        let err = validate_ast(&Qail::make("public.users u")).unwrap_err();
        assert_eq!(err.field, "table");
    }

    #[test]
    fn alter_add_constraint_rejects_unsafe_payload() {
        let cmd = Qail {
            action: crate::ast::Action::AlterAddConstraint,
            table: "users".to_string(),
            channel: Some("users_active_check".to_string()),
            payload: Some("active); DROP TABLE users; --".to_string()),
            ..Default::default()
        };

        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "payload");
    }

    // A `;` inside a single-quoted literal is not a statement separator.
    #[test]
    fn alter_add_constraint_allows_quoted_delimiter_payload() {
        let cmd = Qail {
            action: crate::ast::Action::AlterAddConstraint,
            table: "events".to_string(),
            channel: Some("events_kind_check".to_string()),
            payload: Some("kind <> 'semi;inside'".to_string()),
            ..Default::default()
        };

        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;
        let cmd = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("joins"));
    }

    // More than `table [AS] alias` tokens is rejected even if each word is an identifier.
    #[test]
    fn malformed_join_table_reference_rejected() {
        use crate::ast::JoinKind;
        let cmd = Qail::get("users").join(
            JoinKind::Left,
            "orders DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("joins"));
    }

    #[test]
    fn update_from_and_delete_using_aliases_pass_sanitizer() {
        let update = Qail::set("orders")
            .set_value("status", "paid")
            .update_from(["accounts a"]);
        assert!(validate_ast(&update).is_ok());

        let delete = Qail::del("orders").delete_using(["accounts a"]);
        assert!(validate_ast(&delete).is_ok());
    }

    #[test]
    fn merge_inline_source_alias_passes_sanitizer() {
        let cmd = Qail::merge_into("orders")
            .target_alias("o")
            .using_table("stage_orders s")
            .merge_on_column("o.id", Operator::Eq, "s.order_id")
            .when_matched_do_nothing();

        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn malformed_merge_source_table_reference_rejected() {
        let cmd = Qail::merge_into("orders")
            .using_table("stage_orders DROP TABLE x")
            .merge_on_column("orders.id", Operator::Eq, "stage_orders.order_id")
            .when_matched_do_nothing();

        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "merge.source.table");
    }

    #[test]
    fn injection_in_update_from_and_delete_using_rejected() {
        let update = Qail::set("orders")
            .set_value("status", "paid")
            .update_from(["accounts; DROP TABLE accounts"]);
        let err = validate_ast(&update).unwrap_err();
        assert_eq!(err.field, "from_tables");

        let delete = Qail::del("orders").delete_using(["accounts; DROP TABLE accounts"]);
        let err = validate_ast(&delete).unwrap_err();
        assert_eq!(err.field, "using_tables");
    }

    #[test]
    fn injection_in_column_rejected() {
        let cmd = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("columns"));
    }

    // Filter values are reported under `cage.condition.value.<variant>`.
    #[test]
    fn unsafe_column_value_rejected() {
        use crate::ast::Value;

        let cmd = Qail::get("orders").filter(
            "user_id",
            Operator::Eq,
            Value::Column("users.id; DROP TABLE users; --".to_string()),
        );

        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "cage.condition.value.column");
    }

    #[test]
    fn unsafe_named_parameter_rejected() {
        use crate::ast::Value;

        let cmd = Qail::get("users").filter(
            "id",
            Operator::Eq,
            Value::NamedParam("id); DROP TABLE users; --".to_string()),
        );

        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "cage.condition.value.named_param");
    }

    #[test]
    fn unsafe_raw_function_value_rejected() {
        use crate::ast::Value;

        let cmd = Qail::get("users").filter(
            "updated_at",
            Operator::Lt,
            Value::Function("NOW(); DROP TABLE users; --".to_string()),
        );

        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "cage.condition.value.function");
    }

    #[test]
    fn safe_raw_function_value_passes_sanitizer() {
        use crate::ast::Value;

        let cmd = Qail::get("users").filter(
            "updated_at",
            Operator::Lt,
            Value::Function("NOW()".to_string()),
        );

        assert!(validate_ast(&cmd).is_ok());
    }

    // An assignment expression smuggling a second `col = value` pair must be rejected.
    #[test]
    fn on_conflict_update_assignment_expression_injection_rejected() {
        let cmd = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(
                &["id"],
                &[(
                    "name",
                    Expr::Named("EXCLUDED.name, is_admin = true".to_string()),
                )],
            );

        let err = validate_ast(&cmd).unwrap_err();

        assert_eq!(err.field, "on_conflict.assignment.expr");
        assert!(
            err.value.contains("EXCLUDED.name"),
            "unexpected rejected value: {}",
            err.value
        );
    }

    #[test]
    fn on_conflict_update_assignment_column_injection_rejected() {
        let cmd = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(
                &["id"],
                &[("name, is_admin", Expr::Named("EXCLUDED.name".to_string()))],
            );

        let err = validate_ast(&cmd).unwrap_err();

        assert_eq!(err.field, "on_conflict.assignment.column");
    }

    // Values nested inside an aggregate FILTER clause are validated too.
    #[test]
    fn aggregate_filter_value_expression_injection_rejected() {
        use crate::ast::{AggregateFunc, Condition, Operator, Value};

        let mut cmd = Qail::get("events");
        cmd.columns.push(Expr::Aggregate {
            col: "id".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: Some(vec![Condition {
                left: Expr::Named("direction".to_string()),
                op: Operator::Eq,
                value: Value::Expr(Box::new(Expr::Named("bad;DROP".to_string()))),
                is_array_unnest: false,
            }]),
            alias: None,
        });

        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "columns[0].filter");
    }

    #[test]
    fn count_star_aggregate_passes_sanitizer() {
        use crate::ast::AggregateFunc;

        let mut cmd = Qail::get("events");
        cmd.columns.push(Expr::Aggregate {
            col: "*".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: None,
            alias: Some("total".to_string()),
        });

        assert!(validate_ast(&cmd).is_ok());
    }

    // CASE branches mixing casts, JSON access and string literals stay valid.
    #[test]
    fn case_when_complex_condition_expression_passes_sanitizer() {
        use crate::ast::{Condition, Operator, Value};

        let mut cmd = Qail::get("users");
        cmd.columns.push(Expr::Case {
            when_clauses: vec![(
                Condition {
                    left: Expr::Cast {
                        expr: Box::new(Expr::JsonAccess {
                            column: "profile".to_string(),
                            path_segments: vec![("active".to_string(), true)],
                            alias: None,
                        }),
                        target_type: "integer".to_string(),
                        alias: None,
                    },
                    op: Operator::Gt,
                    value: Value::Int(0),
                    is_array_unnest: false,
                },
                Box::new(Expr::Literal(Value::String("active".to_string()))),
            )],
            else_value: Some(Box::new(Expr::Literal(Value::String(
                "inactive".to_string(),
            )))),
            alias: Some("status_label".to_string()),
        });

        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn empty_table_name_passes() {
        let cmd = Qail {
            // Table-less actions skip the table check entirely.
            action: Action::TxnStart,
            table: String::new(),
            ..Default::default()
        };
        assert!(validate_ast(&cmd).is_ok());
    }

    // 64 bytes exceeds MAX_IDENT_LEN (63) and the reason names the limit.
    #[test]
    fn oversized_identifier_rejected() {
        let long_name = "a".repeat(64);
        let cmd = Qail::get(&long_name);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.reason.contains("63"));
    }
}
1064}