1use crate::ast::{
11 Action, ConflictAction, Expr, MergeAction, MergeSource, Qail, TableConstraint, Value,
12};
13use std::fmt;
14
/// Describes why one part of a command was rejected during validation.
#[derive(Debug, Clone)]
pub struct SanitizeError {
    /// Label of the part that failed, e.g. `table` or `columns[0].alias`.
    pub field: String,
    /// The rejected value as reported by the failing check.
    pub value: String,
    /// Human-readable explanation of the rejection.
    pub reason: String,
}

impl fmt::Display for SanitizeError {
    /// Renders the error as `AST validation failed: <field> '<value>' — <reason>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            field,
            value,
            reason,
        } = self;
        write!(f, "AST validation failed: {field} '{value}' — {reason}")
    }
}

impl std::error::Error for SanitizeError {}
34
/// Maximum accepted identifier length, in bytes.
// NOTE(review): presumably mirrors PostgreSQL's 63-byte identifier limit — confirm.
const MAX_IDENT_LEN: usize = 63;

/// Reports whether `s` may be used as an identifier.
///
/// An identifier is accepted when it is non-empty, at most `MAX_IDENT_LEN`
/// bytes long, and made up solely of ASCII letters, ASCII digits, `_`, and `.`
/// (the dot lets qualified names such as `public.users` through).
fn is_safe_identifier(s: &str) -> bool {
    if s.is_empty() || s.len() > MAX_IDENT_LEN {
        return false;
    }
    s.bytes().all(|byte| {
        matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'.')
    })
}
47
48fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
50 if is_safe_identifier(value) {
51 Ok(())
52 } else {
53 Err(SanitizeError {
54 field: field.to_string(),
55 value: value.chars().take(40).collect(),
56 reason: "identifiers must match [a-zA-Z0-9_.] and be ≤63 chars".to_string(),
57 })
58 }
59}
60
61fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
66 match expr {
67 Expr::Star => Ok(()),
68 Expr::Named(name) => check_ident(field, name),
69 Expr::Aliased { name, alias } => {
70 check_ident(field, name)?;
71 check_ident(&format!("{field}.alias"), alias)
72 }
73 Expr::Aggregate {
74 col, alias, filter, ..
75 } => {
76 if col != "*" {
77 check_ident(field, col)?;
78 }
79 if let Some(a) = alias {
80 check_ident(&format!("{field}.alias"), a)?;
81 }
82 if let Some(conditions) = filter {
83 for cond in conditions {
84 check_expr(&format!("{field}.filter"), &cond.left)?;
85 check_value(&format!("{field}.filter"), &cond.value)?;
86 }
87 }
88 Ok(())
89 }
90 Expr::FunctionCall { name, args, alias } => {
91 check_ident(field, name)?;
92 if let Some(a) = alias {
93 check_ident(&format!("{field}.alias"), a)?;
94 }
95 for arg in args {
96 check_expr(&format!("{field}.arg"), arg)?;
97 }
98 Ok(())
99 }
100 Expr::Cast {
101 expr,
102 target_type,
103 alias,
104 } => {
105 check_expr(field, expr)?;
106 check_ident(&format!("{field}.cast_type"), target_type)?;
107 if let Some(a) = alias {
108 check_ident(&format!("{field}.alias"), a)?;
109 }
110 Ok(())
111 }
112 Expr::Binary {
113 left, right, alias, ..
114 } => {
115 check_expr(field, left)?;
116 check_expr(field, right)?;
117 if let Some(a) = alias {
118 check_ident(&format!("{field}.alias"), a)?;
119 }
120 Ok(())
121 }
122 Expr::Literal(_) => Ok(()),
123 Expr::JsonAccess {
124 column,
125 alias,
126 path_segments,
127 ..
128 } => {
129 check_ident(field, column)?;
130 for (key, _) in path_segments {
131 if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
133 return Err(SanitizeError {
134 field: format!("{field}.json_path"),
135 value: key.chars().take(40).collect(),
136 reason: "JSON path key must be a safe identifier or integer".to_string(),
137 });
138 }
139 }
140 if let Some(a) = alias {
141 check_ident(&format!("{field}.alias"), a)?;
142 }
143 Ok(())
144 }
145 Expr::Subquery { query, alias } => {
146 validate_ast(query)?;
147 if let Some(a) = alias {
148 check_ident(&format!("{field}.alias"), a)?;
149 }
150 Ok(())
151 }
152 Expr::Exists { query, alias, .. } => {
153 validate_ast(query)?;
154 if let Some(a) = alias {
155 check_ident(&format!("{field}.alias"), a)?;
156 }
157 Ok(())
158 }
159 Expr::Window {
161 name,
162 func,
163 partition,
164 params,
165 order,
166 ..
167 } => {
168 if !name.is_empty() {
169 check_ident(&format!("{field}.window_alias"), name)?;
170 }
171 check_ident(&format!("{field}.window_func"), func)?;
172 for p in partition {
173 check_ident(&format!("{field}.partition"), p)?;
174 }
175 for p in params {
176 check_expr(&format!("{field}.window_param"), p)?;
177 }
178 for cage in order {
179 for cond in &cage.conditions {
180 check_expr(&format!("{field}.window_order"), &cond.left)?;
181 check_value(&format!("{field}.window_order"), &cond.value)?;
182 }
183 }
184 Ok(())
185 }
186 Expr::Case {
187 when_clauses,
188 else_value,
189 alias,
190 } => {
191 for (cond, val) in when_clauses {
192 check_expr(&format!("{field}.case_when"), &cond.left)?;
193 check_value(&format!("{field}.case_when"), &cond.value)?;
194 check_expr(&format!("{field}.case_then"), val)?;
195 }
196 if let Some(e) = else_value {
197 check_expr(&format!("{field}.case_else"), e)?;
198 }
199 if let Some(a) = alias {
200 check_ident(&format!("{field}.alias"), a)?;
201 }
202 Ok(())
203 }
204 Expr::SpecialFunction { args, alias, name } => {
205 check_ident(&format!("{field}.special_func"), name)?;
206 for (_, arg) in args {
207 check_expr(&format!("{field}.special_func_arg"), arg)?;
208 }
209 if let Some(a) = alias {
210 check_ident(&format!("{field}.alias"), a)?;
211 }
212 Ok(())
213 }
214 Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
215 for elem in elements {
216 check_expr(&format!("{field}.element"), elem)?;
217 }
218 if let Some(a) = alias {
219 check_ident(&format!("{field}.alias"), a)?;
220 }
221 Ok(())
222 }
223 Expr::Subscript { expr, index, alias } => {
224 check_expr(&format!("{field}.subscript_expr"), expr)?;
225 check_expr(&format!("{field}.subscript_index"), index)?;
226 if let Some(a) = alias {
227 check_ident(&format!("{field}.alias"), a)?;
228 }
229 Ok(())
230 }
231 Expr::Collate {
232 expr,
233 collation,
234 alias,
235 } => {
236 check_expr(&format!("{field}.collate_expr"), expr)?;
237 check_ident(&format!("{field}.collation"), collation)?;
238 if let Some(a) = alias {
239 check_ident(&format!("{field}.alias"), a)?;
240 }
241 Ok(())
242 }
243 Expr::FieldAccess {
244 expr,
245 field: f,
246 alias,
247 } => {
248 check_expr(&format!("{field}.field_access_expr"), expr)?;
249 check_ident(&format!("{field}.field"), f)?;
250 if let Some(a) = alias {
251 check_ident(&format!("{field}.alias"), a)?;
252 }
253 Ok(())
254 }
255 Expr::Def { name, .. } => check_ident(field, name),
256 Expr::Mod { col, .. } => check_expr(field, col),
257 }
258}
259
260fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
262 match value {
263 Value::Subquery(q) => validate_ast(q),
264 Value::Array(vals) => {
265 for v in vals {
266 check_value(field, v)?;
267 }
268 Ok(())
269 }
270 Value::Expr(expr) => check_expr(field, expr),
271 _ => Ok(()),
272 }
273}
274
275pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
284 match cmd.action {
286 Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
287 return Err(SanitizeError {
288 field: "action".to_string(),
289 value: format!("{:?}", cmd.action),
290 reason: "procedural/session actions are not allowed via binary AST".to_string(),
291 });
292 }
293 _ => {}
294 }
295
296 if !cmd.table.is_empty() {
298 check_ident("table", &cmd.table)?;
299 }
300
301 for (i, col) in cmd.columns.iter().enumerate() {
303 check_expr(&format!("columns[{i}]"), col)?;
304 }
305
306 for (i, constraint) in cmd.table_constraints.iter().enumerate() {
308 match constraint {
309 TableConstraint::Unique(cols) | TableConstraint::PrimaryKey(cols) => {
310 for col in cols {
311 check_ident(&format!("table_constraints[{i}].column"), col)?;
312 }
313 }
314 TableConstraint::ForeignKey {
315 name,
316 columns,
317 ref_table,
318 ref_columns,
319 } => {
320 if let Some(name) = name {
321 check_ident(&format!("table_constraints[{i}].name"), name)?;
322 }
323 for col in columns {
324 check_ident(&format!("table_constraints[{i}].column"), col)?;
325 }
326 check_ident(&format!("table_constraints[{i}].ref_table"), ref_table)?;
327 for col in ref_columns {
328 check_ident(&format!("table_constraints[{i}].ref_column"), col)?;
329 }
330 }
331 }
332 }
333
334 for (i, join) in cmd.joins.iter().enumerate() {
336 for token in join.table.split_whitespace() {
339 check_ident(&format!("joins[{i}].table"), token)?;
340 }
341 if let Some(ref conditions) = join.on {
342 for cond in conditions {
343 check_expr(&format!("joins[{i}].on"), &cond.left)?;
344 check_value(&format!("joins[{i}].on"), &cond.value)?;
345 }
346 }
347 }
348
349 for cage in &cmd.cages {
351 for cond in &cage.conditions {
352 check_expr("cage.condition.left", &cond.left)?;
353 check_value("cage.condition.value", &cond.value)?;
354 }
355 }
356
357 for cte in &cmd.ctes {
359 check_ident("cte.name", &cte.name)?;
360 for col in &cte.columns {
361 check_ident("cte.column", col)?;
362 }
363 validate_ast(&cte.base_query)?;
364 if let Some(ref rq) = cte.recursive_query {
365 validate_ast(rq)?;
366 }
367 }
368
369 for expr in &cmd.distinct_on {
371 check_expr("distinct_on", expr)?;
372 }
373
374 if let Some(ref cols) = cmd.returning {
376 for col in cols {
377 check_expr("returning", col)?;
378 }
379 }
380
381 if let Some(ref oc) = cmd.on_conflict {
383 for col in &oc.columns {
384 check_ident("on_conflict.column", col)?;
385 }
386 if let ConflictAction::DoUpdate { assignments } = &oc.action {
387 for (col, expr) in assignments {
388 check_ident("on_conflict.assignment.column", col)?;
389 check_expr("on_conflict.assignment.expr", expr)?;
390 }
391 }
392 }
393
394 if let Some(ref merge) = cmd.merge {
396 if let Some(alias) = &merge.target_alias {
397 check_ident("merge.target_alias", alias)?;
398 }
399 match &merge.source {
400 MergeSource::Table { name, alias } => {
401 check_ident("merge.source.table", name)?;
402 if let Some(alias) = alias {
403 check_ident("merge.source.alias", alias)?;
404 }
405 }
406 MergeSource::Query { query, alias } => {
407 validate_ast(query)?;
408 if let Some(alias) = alias {
409 check_ident("merge.source.alias", alias)?;
410 }
411 }
412 }
413 for cond in &merge.on {
414 check_expr("merge.on.left", &cond.left)?;
415 check_value("merge.on.value", &cond.value)?;
416 }
417 for clause in &merge.clauses {
418 for cond in &clause.condition {
419 check_expr("merge.clause.condition.left", &cond.left)?;
420 check_value("merge.clause.condition.value", &cond.value)?;
421 }
422 match &clause.action {
423 MergeAction::Update { assignments } => {
424 for (col, expr) in assignments {
425 check_ident("merge.update.column", col)?;
426 check_expr("merge.update.expr", expr)?;
427 }
428 }
429 MergeAction::Insert { columns, values } => {
430 for col in columns {
431 check_ident("merge.insert.column", col)?;
432 }
433 for expr in values {
434 check_expr("merge.insert.expr", expr)?;
435 }
436 }
437 MergeAction::Delete | MergeAction::DoNothing => {}
438 }
439 }
440 }
441
442 for t in &cmd.from_tables {
444 check_ident("from_tables", t)?;
445 }
446 for t in &cmd.using_tables {
447 check_ident("using_tables", t)?;
448 }
449
450 for (_, sub) in &cmd.set_ops {
452 validate_ast(sub)?;
453 }
454
455 if let Some(ref sq) = cmd.source_query {
457 validate_ast(sq)?;
458 }
459
460 for cond in &cmd.having {
462 check_expr("having", &cond.left)?;
463 check_value("having", &cond.value)?;
464 }
465
466 if let Some(ref ch) = cmd.channel {
468 check_ident("channel", ch)?;
469 }
470
471 Ok(())
472}
473
// Unit tests for `validate_ast`: accepted shapes, rejected actions, and
// injection attempts in identifier positions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Qail;

    // A plain table with plain column names is accepted.
    #[test]
    fn valid_simple_query_passes() {
        let cmd = Qail::get("users").columns(["id", "name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    // A table name containing `;`, spaces and `--` is rejected on the `table` field.
    #[test]
    fn sql_injection_in_table_rejected() {
        let cmd = Qail::get("users; DROP TABLE users; --");
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "table");
    }

    // `Action::Call` is rejected on the `action` field, before the table is looked at.
    #[test]
    fn call_action_rejected() {
        let cmd = Qail {
            action: Action::Call,
            table: "my_proc()".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    // `Action::Do` is rejected on the `action` field even with a harmless table value.
    #[test]
    fn do_action_rejected() {
        let cmd = Qail {
            action: Action::Do,
            table: "plpgsql".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    // Dots are allowed, so schema- and table-qualified names pass.
    #[test]
    fn valid_qualified_name_passes() {
        let cmd = Qail::get("public.users").columns(["users.id", "users.name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    // A join table containing `;` is rejected with a `joins[..]` field label.
    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;
        let cmd = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("joins"));
    }

    // A column name containing `;` is rejected with a `columns[..]` field label.
    #[test]
    fn injection_in_column_rejected() {
        let cmd = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("columns"));
    }

    // The right-hand expression of an ON CONFLICT assignment is validated:
    // a named expression smuggling `, is_admin = true` is rejected.
    #[test]
    fn on_conflict_update_assignment_expression_injection_rejected() {
        let cmd = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(
                &["id"],
                &[(
                    "name",
                    Expr::Named("EXCLUDED.name, is_admin = true".to_string()),
                )],
            );

        let err = validate_ast(&cmd).unwrap_err();

        assert_eq!(err.field, "on_conflict.assignment.expr");
        assert!(
            err.value.contains("EXCLUDED.name"),
            "unexpected rejected value: {}",
            err.value
        );
    }

    // The target column of an ON CONFLICT assignment is validated as an identifier.
    #[test]
    fn on_conflict_update_assignment_column_injection_rejected() {
        let cmd = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(
                &["id"],
                &[("name, is_admin", Expr::Named("EXCLUDED.name".to_string()))],
            );

        let err = validate_ast(&cmd).unwrap_err();

        assert_eq!(err.field, "on_conflict.assignment.column");
    }

    // An expression used as the *value* side of an aggregate FILTER condition
    // is validated and reported under the `.filter` suffix.
    #[test]
    fn aggregate_filter_value_expression_injection_rejected() {
        use crate::ast::{AggregateFunc, Condition, Operator, Value};

        let mut cmd = Qail::get("events");
        cmd.columns.push(Expr::Aggregate {
            col: "id".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: Some(vec![Condition {
                left: Expr::Named("direction".to_string()),
                op: Operator::Eq,
                value: Value::Expr(Box::new(Expr::Named("bad;DROP".to_string()))),
                is_array_unnest: false,
            }]),
            alias: None,
        });

        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "columns[0].filter");
    }

    // `*` is accepted as an aggregate column even though it is not an identifier.
    #[test]
    fn count_star_aggregate_passes_sanitizer() {
        use crate::ast::AggregateFunc;

        let mut cmd = Qail::get("events");
        cmd.columns.push(Expr::Aggregate {
            col: "*".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: None,
            alias: Some("total".to_string()),
        });

        assert!(validate_ast(&cmd).is_ok());
    }

    // A CASE whose WHEN condition is a cast of a JSON access, with literal
    // THEN / ELSE branches, is accepted.
    #[test]
    fn case_when_complex_condition_expression_passes_sanitizer() {
        use crate::ast::{Condition, Operator, Value};

        let mut cmd = Qail::get("users");
        cmd.columns.push(Expr::Case {
            when_clauses: vec![(
                Condition {
                    left: Expr::Cast {
                        expr: Box::new(Expr::JsonAccess {
                            column: "profile".to_string(),
                            path_segments: vec![("active".to_string(), true)],
                            alias: None,
                        }),
                        target_type: "integer".to_string(),
                        alias: None,
                    },
                    op: Operator::Gt,
                    value: Value::Int(0),
                    is_array_unnest: false,
                },
                Box::new(Expr::Literal(Value::String("active".to_string()))),
            )],
            else_value: Some(Box::new(Expr::Literal(Value::String(
                "inactive".to_string(),
            )))),
            alias: Some("status_label".to_string()),
        });

        assert!(validate_ast(&cmd).is_ok());
    }

    // An empty table name is skipped rather than rejected (here with `Action::TxnStart`).
    #[test]
    fn empty_table_name_passes() {
        let cmd = Qail {
            action: Action::TxnStart,
            table: String::new(),
            ..Default::default()
        };
        assert!(validate_ast(&cmd).is_ok());
    }

    // A 64-character name exceeds the 63-character limit and is rejected.
    #[test]
    fn oversized_identifier_rejected() {
        let long_name = "a".repeat(64);
        let cmd = Qail::get(&long_name);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.reason.contains("63"));
    }
}