1use crate::ast::{
11 Action, ConflictAction, Expr, MergeAction, MergeSource, Qail, TableConstraint, Value,
12};
13use std::fmt;
14
/// Error produced when an AST fails sanitization.
///
/// Carries the offending location (`field`), a copy of the rejected input
/// (`value`), and a human-readable explanation (`reason`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SanitizeError {
    /// Dotted path of the rejected AST location, e.g. `"table"` or
    /// `"columns[0].alias"`.
    pub field: String,
    /// The rejected value. Identifier-style checks in this module truncate it
    /// to at most 40 characters; the action check stores its `Debug` form.
    pub value: String,
    /// Why the value was rejected.
    pub reason: String,
}

impl fmt::Display for SanitizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "AST validation failed: {} '{}' — {}",
            self.field, self.value, self.reason
        )
    }
}

impl std::error::Error for SanitizeError {}
34
/// Upper bound, in bytes, on any identifier accepted by the sanitizer.
///
/// NOTE(review): presumably chosen to match PostgreSQL's default identifier
/// limit (NAMEDATALEN - 1). It applies to the whole string, so a qualified
/// name such as `schema.table` must fit within it as a unit.
const MAX_IDENT_LEN: usize = 63;

/// Returns `true` when `s` is non-empty, at most [`MAX_IDENT_LEN`] bytes long,
/// and consists only of ASCII letters, digits, `_`, and `.`.
fn is_safe_identifier(s: &str) -> bool {
    let permitted_byte = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'.');
    (1..=MAX_IDENT_LEN).contains(&s.len()) && s.bytes().all(permitted_byte)
}
47
48fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
50 if is_safe_identifier(value) {
51 Ok(())
52 } else {
53 Err(SanitizeError {
54 field: field.to_string(),
55 value: value.chars().take(40).collect(),
56 reason: "identifiers must match [a-zA-Z0-9_.] and be ≤63 chars".to_string(),
57 })
58 }
59}
60
61fn check_fk_action(field: &str, value: &str) -> Result<(), SanitizeError> {
62 let normalized = value.trim().to_ascii_lowercase().replace('_', " ");
63 if matches!(
64 normalized.as_str(),
65 "cascade" | "restrict" | "no action" | "set null" | "set default"
66 ) {
67 Ok(())
68 } else {
69 Err(SanitizeError {
70 field: field.to_string(),
71 value: value.chars().take(40).collect(),
72 reason:
73 "foreign key action must be cascade, restrict, no_action, set_null, or set_default"
74 .to_string(),
75 })
76 }
77}
78
79fn check_fk_deferrable(field: &str, value: &str) -> Result<(), SanitizeError> {
80 let normalized = value.trim().to_ascii_lowercase().replace('_', " ");
81 if matches!(
82 normalized.as_str(),
83 "deferrable"
84 | "initially deferred"
85 | "initially immediate"
86 | "deferrable initially deferred"
87 | "deferrable initially immediate"
88 ) {
89 Ok(())
90 } else {
91 Err(SanitizeError {
92 field: field.to_string(),
93 value: value.chars().take(40).collect(),
94 reason: "foreign key deferrable clause must be deferrable, initially_deferred, or initially_immediate".to_string(),
95 })
96 }
97}
98
99fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
104 match expr {
105 Expr::Star => Ok(()),
106 Expr::Named(name) => check_ident(field, name),
107 Expr::Aliased { name, alias } => {
108 check_ident(field, name)?;
109 check_ident(&format!("{field}.alias"), alias)
110 }
111 Expr::Aggregate {
112 col, alias, filter, ..
113 } => {
114 if col != "*" {
115 check_ident(field, col)?;
116 }
117 if let Some(a) = alias {
118 check_ident(&format!("{field}.alias"), a)?;
119 }
120 if let Some(conditions) = filter {
121 for cond in conditions {
122 check_expr(&format!("{field}.filter"), &cond.left)?;
123 check_value(&format!("{field}.filter"), &cond.value)?;
124 }
125 }
126 Ok(())
127 }
128 Expr::FunctionCall { name, args, alias } => {
129 check_ident(field, name)?;
130 if let Some(a) = alias {
131 check_ident(&format!("{field}.alias"), a)?;
132 }
133 for arg in args {
134 check_expr(&format!("{field}.arg"), arg)?;
135 }
136 Ok(())
137 }
138 Expr::Cast {
139 expr,
140 target_type,
141 alias,
142 } => {
143 check_expr(field, expr)?;
144 check_ident(&format!("{field}.cast_type"), target_type)?;
145 if let Some(a) = alias {
146 check_ident(&format!("{field}.alias"), a)?;
147 }
148 Ok(())
149 }
150 Expr::Binary {
151 left, right, alias, ..
152 } => {
153 check_expr(field, left)?;
154 check_expr(field, right)?;
155 if let Some(a) = alias {
156 check_ident(&format!("{field}.alias"), a)?;
157 }
158 Ok(())
159 }
160 Expr::Literal(_) => Ok(()),
161 Expr::JsonAccess {
162 column,
163 alias,
164 path_segments,
165 ..
166 } => {
167 check_ident(field, column)?;
168 for (key, _) in path_segments {
169 if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
171 return Err(SanitizeError {
172 field: format!("{field}.json_path"),
173 value: key.chars().take(40).collect(),
174 reason: "JSON path key must be a safe identifier or integer".to_string(),
175 });
176 }
177 }
178 if let Some(a) = alias {
179 check_ident(&format!("{field}.alias"), a)?;
180 }
181 Ok(())
182 }
183 Expr::Subquery { query, alias } => {
184 validate_ast(query)?;
185 if let Some(a) = alias {
186 check_ident(&format!("{field}.alias"), a)?;
187 }
188 Ok(())
189 }
190 Expr::Exists { query, alias, .. } => {
191 validate_ast(query)?;
192 if let Some(a) = alias {
193 check_ident(&format!("{field}.alias"), a)?;
194 }
195 Ok(())
196 }
197 Expr::Window {
199 name,
200 func,
201 partition,
202 params,
203 order,
204 ..
205 } => {
206 if !name.is_empty() {
207 check_ident(&format!("{field}.window_alias"), name)?;
208 }
209 check_ident(&format!("{field}.window_func"), func)?;
210 for p in partition {
211 check_ident(&format!("{field}.partition"), p)?;
212 }
213 for p in params {
214 check_expr(&format!("{field}.window_param"), p)?;
215 }
216 for cage in order {
217 for cond in &cage.conditions {
218 check_expr(&format!("{field}.window_order"), &cond.left)?;
219 check_value(&format!("{field}.window_order"), &cond.value)?;
220 }
221 }
222 Ok(())
223 }
224 Expr::Case {
225 when_clauses,
226 else_value,
227 alias,
228 } => {
229 for (cond, val) in when_clauses {
230 check_expr(&format!("{field}.case_when"), &cond.left)?;
231 check_value(&format!("{field}.case_when"), &cond.value)?;
232 check_expr(&format!("{field}.case_then"), val)?;
233 }
234 if let Some(e) = else_value {
235 check_expr(&format!("{field}.case_else"), e)?;
236 }
237 if let Some(a) = alias {
238 check_ident(&format!("{field}.alias"), a)?;
239 }
240 Ok(())
241 }
242 Expr::SpecialFunction { args, alias, name } => {
243 check_ident(&format!("{field}.special_func"), name)?;
244 for (_, arg) in args {
245 check_expr(&format!("{field}.special_func_arg"), arg)?;
246 }
247 if let Some(a) = alias {
248 check_ident(&format!("{field}.alias"), a)?;
249 }
250 Ok(())
251 }
252 Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
253 for elem in elements {
254 check_expr(&format!("{field}.element"), elem)?;
255 }
256 if let Some(a) = alias {
257 check_ident(&format!("{field}.alias"), a)?;
258 }
259 Ok(())
260 }
261 Expr::Subscript { expr, index, alias } => {
262 check_expr(&format!("{field}.subscript_expr"), expr)?;
263 check_expr(&format!("{field}.subscript_index"), index)?;
264 if let Some(a) = alias {
265 check_ident(&format!("{field}.alias"), a)?;
266 }
267 Ok(())
268 }
269 Expr::Collate {
270 expr,
271 collation,
272 alias,
273 } => {
274 check_expr(&format!("{field}.collate_expr"), expr)?;
275 check_ident(&format!("{field}.collation"), collation)?;
276 if let Some(a) = alias {
277 check_ident(&format!("{field}.alias"), a)?;
278 }
279 Ok(())
280 }
281 Expr::FieldAccess {
282 expr,
283 field: f,
284 alias,
285 } => {
286 check_expr(&format!("{field}.field_access_expr"), expr)?;
287 check_ident(&format!("{field}.field"), f)?;
288 if let Some(a) = alias {
289 check_ident(&format!("{field}.alias"), a)?;
290 }
291 Ok(())
292 }
293 Expr::Def { name, .. } => check_ident(field, name),
294 Expr::Mod { col, .. } => check_expr(field, col),
295 }
296}
297
298fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
300 match value {
301 Value::Subquery(q) => validate_ast(q),
302 Value::Array(vals) => {
303 for v in vals {
304 check_value(field, v)?;
305 }
306 Ok(())
307 }
308 Value::Expr(expr) => check_expr(field, expr),
309 _ => Ok(()),
310 }
311}
312
313pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
322 match cmd.action {
324 Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
325 return Err(SanitizeError {
326 field: "action".to_string(),
327 value: format!("{:?}", cmd.action),
328 reason: "procedural/session actions are not allowed via binary AST".to_string(),
329 });
330 }
331 _ => {}
332 }
333
334 if !cmd.table.is_empty() {
336 check_ident("table", &cmd.table)?;
337 }
338
339 for (i, col) in cmd.columns.iter().enumerate() {
341 check_expr(&format!("columns[{i}]"), col)?;
342 }
343
344 for (i, constraint) in cmd.table_constraints.iter().enumerate() {
346 match constraint {
347 TableConstraint::Unique(cols) | TableConstraint::PrimaryKey(cols) => {
348 for col in cols {
349 check_ident(&format!("table_constraints[{i}].column"), col)?;
350 }
351 }
352 TableConstraint::ForeignKey {
353 name,
354 columns,
355 ref_table,
356 ref_columns,
357 on_delete,
358 on_update,
359 deferrable,
360 } => {
361 if let Some(name) = name {
362 check_ident(&format!("table_constraints[{i}].name"), name)?;
363 }
364 for col in columns {
365 check_ident(&format!("table_constraints[{i}].column"), col)?;
366 }
367 check_ident(&format!("table_constraints[{i}].ref_table"), ref_table)?;
368 for col in ref_columns {
369 check_ident(&format!("table_constraints[{i}].ref_column"), col)?;
370 }
371 if let Some(action) = on_delete {
372 check_fk_action(&format!("table_constraints[{i}].on_delete"), action)?;
373 }
374 if let Some(action) = on_update {
375 check_fk_action(&format!("table_constraints[{i}].on_update"), action)?;
376 }
377 if let Some(clause) = deferrable {
378 check_fk_deferrable(&format!("table_constraints[{i}].deferrable"), clause)?;
379 }
380 }
381 }
382 }
383
384 for (i, join) in cmd.joins.iter().enumerate() {
386 for token in join.table.split_whitespace() {
389 check_ident(&format!("joins[{i}].table"), token)?;
390 }
391 if let Some(ref conditions) = join.on {
392 for cond in conditions {
393 check_expr(&format!("joins[{i}].on"), &cond.left)?;
394 check_value(&format!("joins[{i}].on"), &cond.value)?;
395 }
396 }
397 }
398
399 for cage in &cmd.cages {
401 for cond in &cage.conditions {
402 check_expr("cage.condition.left", &cond.left)?;
403 check_value("cage.condition.value", &cond.value)?;
404 }
405 }
406
407 for cte in &cmd.ctes {
409 check_ident("cte.name", &cte.name)?;
410 for col in &cte.columns {
411 check_ident("cte.column", col)?;
412 }
413 validate_ast(&cte.base_query)?;
414 if let Some(ref rq) = cte.recursive_query {
415 validate_ast(rq)?;
416 }
417 }
418
419 for expr in &cmd.distinct_on {
421 check_expr("distinct_on", expr)?;
422 }
423
424 if let Some(ref cols) = cmd.returning {
426 for col in cols {
427 check_expr("returning", col)?;
428 }
429 }
430
431 if let Some(ref oc) = cmd.on_conflict {
433 for col in &oc.columns {
434 check_ident("on_conflict.column", col)?;
435 }
436 if let ConflictAction::DoUpdate { assignments } = &oc.action {
437 for (col, expr) in assignments {
438 check_ident("on_conflict.assignment.column", col)?;
439 check_expr("on_conflict.assignment.expr", expr)?;
440 }
441 }
442 }
443
444 if let Some(ref merge) = cmd.merge {
446 if let Some(alias) = &merge.target_alias {
447 check_ident("merge.target_alias", alias)?;
448 }
449 match &merge.source {
450 MergeSource::Table { name, alias } => {
451 check_ident("merge.source.table", name)?;
452 if let Some(alias) = alias {
453 check_ident("merge.source.alias", alias)?;
454 }
455 }
456 MergeSource::Query { query, alias } => {
457 validate_ast(query)?;
458 if let Some(alias) = alias {
459 check_ident("merge.source.alias", alias)?;
460 }
461 }
462 }
463 for cond in &merge.on {
464 check_expr("merge.on.left", &cond.left)?;
465 check_value("merge.on.value", &cond.value)?;
466 }
467 for clause in &merge.clauses {
468 for cond in &clause.condition {
469 check_expr("merge.clause.condition.left", &cond.left)?;
470 check_value("merge.clause.condition.value", &cond.value)?;
471 }
472 match &clause.action {
473 MergeAction::Update { assignments } => {
474 for (col, expr) in assignments {
475 check_ident("merge.update.column", col)?;
476 check_expr("merge.update.expr", expr)?;
477 }
478 }
479 MergeAction::Insert { columns, values } => {
480 for col in columns {
481 check_ident("merge.insert.column", col)?;
482 }
483 for expr in values {
484 check_expr("merge.insert.expr", expr)?;
485 }
486 }
487 MergeAction::Delete | MergeAction::DoNothing => {}
488 }
489 }
490 }
491
492 for t in &cmd.from_tables {
494 check_ident("from_tables", t)?;
495 }
496 for t in &cmd.using_tables {
497 check_ident("using_tables", t)?;
498 }
499
500 for (_, sub) in &cmd.set_ops {
502 validate_ast(sub)?;
503 }
504
505 if let Some(ref sq) = cmd.source_query {
507 validate_ast(sq)?;
508 }
509
510 for cond in &cmd.having {
512 check_expr("having", &cond.left)?;
513 check_value("having", &cond.value)?;
514 }
515
516 if let Some(ref ch) = cmd.channel {
518 check_ident("channel", ch)?;
519 }
520
521 Ok(())
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Qail;

    #[test]
    fn valid_simple_query_passes() {
        let cmd = Qail::get("users").columns(["id", "name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn sql_injection_in_table_rejected() {
        let cmd = Qail::get("users; DROP TABLE users; --");
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "table");
    }

    #[test]
    fn call_action_rejected() {
        let cmd = Qail {
            action: Action::Call,
            table: "my_proc()".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn do_action_rejected() {
        let cmd = Qail {
            action: Action::Do,
            table: "plpgsql".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn valid_qualified_name_passes() {
        let cmd = Qail::get("public.users").columns(["users.id", "users.name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;
        let cmd = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("joins"));
    }

    #[test]
    fn join_table_with_alias_passes() {
        use crate::ast::JoinKind;
        for table in ["orders o", "orders AS o", "public.orders"] {
            let cmd = Qail::get("users").join(JoinKind::Left, table, "users.id", "o.user_id");
            assert!(validate_ast(&cmd).is_ok(), "rejected join table {table:?}");
        }
    }

    #[test]
    fn injection_in_column_rejected() {
        let cmd = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("columns"));
    }

    #[test]
    fn on_conflict_update_assignment_expression_injection_rejected() {
        let cmd = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(
                &["id"],
                &[(
                    "name",
                    Expr::Named("EXCLUDED.name, is_admin = true".to_string()),
                )],
            );

        let err = validate_ast(&cmd).unwrap_err();

        assert_eq!(err.field, "on_conflict.assignment.expr");
        assert!(
            err.value.contains("EXCLUDED.name"),
            "unexpected rejected value: {}",
            err.value
        );
    }

    #[test]
    fn on_conflict_update_assignment_column_injection_rejected() {
        let cmd = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(
                &["id"],
                &[("name, is_admin", Expr::Named("EXCLUDED.name".to_string()))],
            );

        let err = validate_ast(&cmd).unwrap_err();

        assert_eq!(err.field, "on_conflict.assignment.column");
    }

    #[test]
    fn aggregate_filter_value_expression_injection_rejected() {
        use crate::ast::{AggregateFunc, Condition, Operator, Value};

        let mut cmd = Qail::get("events");
        cmd.columns.push(Expr::Aggregate {
            col: "id".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: Some(vec![Condition {
                left: Expr::Named("direction".to_string()),
                op: Operator::Eq,
                value: Value::Expr(Box::new(Expr::Named("bad;DROP".to_string()))),
                is_array_unnest: false,
            }]),
            alias: None,
        });

        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "columns[0].filter");
    }

    #[test]
    fn count_star_aggregate_passes_sanitizer() {
        use crate::ast::AggregateFunc;

        let mut cmd = Qail::get("events");
        cmd.columns.push(Expr::Aggregate {
            col: "*".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: None,
            alias: Some("total".to_string()),
        });

        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn case_when_complex_condition_expression_passes_sanitizer() {
        use crate::ast::{Condition, Operator, Value};

        let mut cmd = Qail::get("users");
        cmd.columns.push(Expr::Case {
            when_clauses: vec![(
                Condition {
                    left: Expr::Cast {
                        expr: Box::new(Expr::JsonAccess {
                            column: "profile".to_string(),
                            path_segments: vec![("active".to_string(), true)],
                            alias: None,
                        }),
                        target_type: "integer".to_string(),
                        alias: None,
                    },
                    op: Operator::Gt,
                    value: Value::Int(0),
                    is_array_unnest: false,
                },
                Box::new(Expr::Literal(Value::String("active".to_string()))),
            )],
            else_value: Some(Box::new(Expr::Literal(Value::String(
                "inactive".to_string(),
            )))),
            alias: Some("status_label".to_string()),
        });

        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn empty_table_name_passes() {
        let cmd = Qail {
            action: Action::TxnStart,
            table: String::new(),
            ..Default::default()
        };
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn oversized_identifier_rejected() {
        let long_name = "a".repeat(64);
        let cmd = Qail::get(&long_name);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.reason.contains("63"));
    }

    #[test]
    fn identifier_length_boundary() {
        assert!(is_safe_identifier(&"a".repeat(MAX_IDENT_LEN)));
        assert!(!is_safe_identifier(&"a".repeat(MAX_IDENT_LEN + 1)));
        assert!(!is_safe_identifier(""));
    }

    #[test]
    fn rejected_value_is_truncated() {
        let err = check_ident("table", &"x;".repeat(100)).unwrap_err();
        assert_eq!(err.field, "table");
        assert_eq!(err.value.chars().count(), 40);
    }

    #[test]
    fn fk_action_accepts_known_spellings() {
        for action in [
            "cascade",
            "RESTRICT",
            "no_action",
            "No Action",
            "set_null",
            " set default ",
        ] {
            assert!(check_fk_action("fk", action).is_ok(), "rejected {action:?}");
        }
    }

    #[test]
    fn fk_action_rejects_unknown_or_injected() {
        for action in ["", "drop", "cascade; DROP TABLE users", "set  null"] {
            let err = check_fk_action("fk.on_delete", action).unwrap_err();
            assert_eq!(err.field, "fk.on_delete");
        }
    }

    #[test]
    fn fk_deferrable_accepts_known_clauses_only() {
        for clause in [
            "deferrable",
            "INITIALLY_DEFERRED",
            "initially immediate",
            "deferrable_initially_deferred",
            "Deferrable Initially Immediate",
        ] {
            assert!(check_fk_deferrable("fk", clause).is_ok(), "rejected {clause:?}");
        }
        assert!(check_fk_deferrable("fk", "deferrable; DROP TABLE x").is_err());
    }
}