Skip to main content

qail_core/
sanitize.rs

1//! AST structural sanitization for untrusted input.
2//!
3//! The text parser enforces identifier constraints (alphanumeric + `_` + `.`),
4//! but externally provided ASTs (binary endpoints, APIs, generated payloads, etc.)
5//! may bypass parser-level identifier checks and still reach execution.
6//!
7//! Call [`validate_ast`](crate::sanitize::validate_ast) on any `Qail` obtained from an untrusted source
8//! (binary endpoint, external API, etc.) before execution.
9
10use crate::ast::{
11    Action, ConflictAction, Expr, MergeAction, MergeSource, Qail, TableConstraint, Value,
12};
13use std::fmt;
14
/// Error returned when AST structural validation fails.
///
/// Carries enough context to tell which part of the AST was rejected and why.
/// Implements `PartialEq`/`Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SanitizeError {
    /// Path of the AST field that failed validation (e.g. `table`, `columns[0]`).
    pub field: String,
    /// The offending value as recorded by the failing check (possibly truncated).
    pub value: String,
    /// Human-readable description of the rule that was violated.
    pub reason: String,
}

impl fmt::Display for SanitizeError {
    /// Renders the error as `AST validation failed: <field> '<value>' — <reason>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            field,
            value,
            reason,
        } = self;
        write!(f, "AST validation failed: {field} '{value}' — {reason}")
    }
}

impl std::error::Error for SanitizeError {}
34
/// Maximum identifier length (PostgreSQL NAMEDATALEN - 1).
const MAX_IDENT_LEN: usize = 63;
/// Maximum length, in bytes, accepted for a raw `Value::Function` string.
const MAX_RAW_FUNCTION_LEN: usize = 1024;
38
39/// Validate that an identifier matches the parser grammar: `[a-zA-Z0-9_.]`.
40///
41/// Also rejects empty identifiers and those exceeding PostgreSQL's 63-char limit.
42fn is_safe_identifier(s: &str) -> bool {
43    !s.is_empty()
44        && s.split('.').all(|part| {
45            !part.is_empty()
46                && part.len() <= MAX_IDENT_LEN
47                && part.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_')
48        })
49}
50
51/// Validate an identifier, returning a `SanitizeError` if invalid.
52fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
53    if is_safe_identifier(value) {
54        Ok(())
55    } else {
56        Err(SanitizeError {
57            field: field.to_string(),
58            value: value.chars().take(40).collect(),
59            reason: "identifier parts must match [a-zA-Z0-9_] and be ≤63 chars".to_string(),
60        })
61    }
62}
63
64fn check_named_param(field: &str, value: &str) -> Result<(), SanitizeError> {
65    let mut chars = value.chars();
66    let Some(first) = chars.next() else {
67        return Err(SanitizeError {
68            field: field.to_string(),
69            value: String::new(),
70            reason: "named parameter cannot be empty".to_string(),
71        });
72    };
73
74    if (first.is_ascii_alphabetic() || first == '_')
75        && chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_')
76    {
77        Ok(())
78    } else {
79        Err(SanitizeError {
80            field: field.to_string(),
81            value: value.chars().take(40).collect(),
82            reason: "named parameters must match [a-zA-Z_][a-zA-Z0-9_]*".to_string(),
83        })
84    }
85}
86
87fn check_raw_function_value(field: &str, value: &str) -> Result<(), SanitizeError> {
88    if value.len() <= MAX_RAW_FUNCTION_LEN
89        && !value.contains('\0')
90        && !value.contains(';')
91        && !value.contains("--")
92        && !value.contains("/*")
93        && !value.contains("*/")
94    {
95        Ok(())
96    } else {
97        Err(SanitizeError {
98            field: field.to_string(),
99            value: value.chars().take(40).collect(),
100            reason: "raw function values cannot contain NUL, statement separators, or comments"
101                .to_string(),
102        })
103    }
104}
105
/// Returns `true` when `value` is unsafe to splice into a SQL statement as an
/// expression fragment.
///
/// The scanner walks the bytes while tracking single-quoted strings and double-quoted
/// identifiers (a doubled quote character is an escaped quote). It reports `true` for:
///
/// - a NUL byte anywhere;
/// - `;`, `--` or `/*` outside quotes;
/// - an unquoted `$` that is not directly followed by an ASCII digit. Positional
///   parameters such as `$1` stay allowed, but a `$tag$ … $tag$` dollar-quoted string
///   could hide a quote character from this scanner, so it is refused outright;
/// - inside a single-quoted string, an odd run of backslashes directly before `'`.
///   PostgreSQL reads `\'` as an escaped quote in `E'…'` strings (and in every string
///   when `standard_conforming_strings` is off) but as a backslash followed by a
///   closing quote otherwise; the quoting state is ambiguous, so it is refused;
/// - a quote that is still open at the end of the input, because it would change how
///   any SQL text appended after the fragment is quoted.
fn contains_unquoted_statement_delimiter(value: &str) -> bool {
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Quote {
        None,
        Single,
        Double,
    }

    let bytes = value.as_bytes();
    let mut state = Quote::None;
    let mut i = 0;

    while i < bytes.len() {
        let b = bytes[i];
        let next = bytes.get(i + 1).copied();

        // NUL is never acceptable, quoted or not.
        if b == 0 {
            return true;
        }

        match state {
            Quote::Single => match (b, next) {
                // `''` is an escaped quote: consume both and stay inside the string.
                (b'\'', Some(b'\'')) => {
                    i += 2;
                    continue;
                }
                (b'\'', _) => state = Quote::None,
                // `\\` means the same for quote tracking under both string dialects.
                (b'\\', Some(b'\\')) => {
                    i += 2;
                    continue;
                }
                // `\'` closes the string in one dialect and not in the other.
                (b'\\', Some(b'\'')) => return true,
                _ => {}
            },
            Quote::Double => match (b, next) {
                // `""` is an escaped quote inside a quoted identifier.
                (b'"', Some(b'"')) => {
                    i += 2;
                    continue;
                }
                (b'"', _) => state = Quote::None,
                _ => {}
            },
            Quote::None => match (b, next) {
                (b'\'', _) => state = Quote::Single,
                (b'"', _) => state = Quote::Double,
                (b';', _) => return true,
                (b'-', Some(b'-')) | (b'/', Some(b'*')) => return true,
                // Anything but `$<digit>` may open a dollar-quoted string.
                (b'$', follower) if !matches!(follower, Some(d) if d.is_ascii_digit()) => {
                    return true;
                }
                _ => {}
            },
        }

        i += 1;
    }

    // An unterminated quote leaves the surrounding statement in the wrong state.
    state != Quote::None
}
155
156fn check_sql_expr_fragment(field: &str, value: &str) -> Result<(), SanitizeError> {
157    let expr = value.trim();
158    if !expr.is_empty() && !contains_unquoted_statement_delimiter(expr) {
159        Ok(())
160    } else {
161        Err(SanitizeError {
162            field: field.to_string(),
163            value: value.chars().take(40).collect(),
164            reason: "SQL expression fragments cannot be empty or contain unquoted NUL, statement separators, or comments".to_string(),
165        })
166    }
167}
168
169fn table_ref_error(field: &str, value: &str) -> SanitizeError {
170    SanitizeError {
171        field: field.to_string(),
172        value: value.chars().take(40).collect(),
173        reason: "table references must be identifier or identifier [AS] alias".to_string(),
174    }
175}
176
177fn check_table_ref(field: &str, value: &str) -> Result<(), SanitizeError> {
178    let parts = value.split_whitespace().collect::<Vec<_>>();
179    match parts.as_slice() {
180        [table] => check_ident(field, table),
181        [table, alias] => {
182            check_ident(field, table)?;
183            check_ident(field, alias)
184        }
185        [table, as_keyword, alias] if as_keyword.eq_ignore_ascii_case("as") => {
186            check_ident(field, table)?;
187            check_ident(field, alias)
188        }
189        _ => Err(table_ref_error(field, value)),
190    }
191}
192
193fn action_allows_table_alias(action: Action) -> bool {
194    matches!(
195        action,
196        Action::Get
197            | Action::Cnt
198            | Action::Set
199            | Action::Del
200            | Action::Export
201            | Action::Explain
202            | Action::ExplainAnalyze
203            | Action::Over
204    )
205}
206
207fn check_fk_action(field: &str, value: &str) -> Result<(), SanitizeError> {
208    let normalized = value.trim().to_ascii_lowercase().replace('_', " ");
209    if matches!(
210        normalized.as_str(),
211        "cascade" | "restrict" | "no action" | "set null" | "set default"
212    ) {
213        Ok(())
214    } else {
215        Err(SanitizeError {
216            field: field.to_string(),
217            value: value.chars().take(40).collect(),
218            reason:
219                "foreign key action must be cascade, restrict, no_action, set_null, or set_default"
220                    .to_string(),
221        })
222    }
223}
224
225fn check_fk_deferrable(field: &str, value: &str) -> Result<(), SanitizeError> {
226    let normalized = value.trim().to_ascii_lowercase().replace('_', " ");
227    if matches!(
228        normalized.as_str(),
229        "deferrable"
230            | "initially deferred"
231            | "initially immediate"
232            | "deferrable initially deferred"
233            | "deferrable initially immediate"
234    ) {
235        Ok(())
236    } else {
237        Err(SanitizeError {
238            field: field.to_string(),
239            value: value.chars().take(40).collect(),
240            reason: "foreign key deferrable clause must be deferrable, initially_deferred, or initially_immediate".to_string(),
241        })
242    }
243}
244
245/// Validate an `Expr` node for unsafe patterns.
246///
247/// - `Expr::Named` must be a safe identifier.
248/// - Recursive variants (Cast, Binary, etc.) are validated recursively.
249fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
250    match expr {
251        Expr::Star => Ok(()),
252        Expr::Named(name) => check_ident(field, name),
253        Expr::Aliased { name, alias } => {
254            check_ident(field, name)?;
255            check_ident(&format!("{field}.alias"), alias)
256        }
257        Expr::Aggregate {
258            col, alias, filter, ..
259        } => {
260            if col != "*" {
261                check_ident(field, col)?;
262            }
263            if let Some(a) = alias {
264                check_ident(&format!("{field}.alias"), a)?;
265            }
266            if let Some(conditions) = filter {
267                for cond in conditions {
268                    check_expr(&format!("{field}.filter"), &cond.left)?;
269                    check_value(&format!("{field}.filter"), &cond.value)?;
270                }
271            }
272            Ok(())
273        }
274        Expr::FunctionCall { name, args, alias } => {
275            check_ident(field, name)?;
276            if let Some(a) = alias {
277                check_ident(&format!("{field}.alias"), a)?;
278            }
279            for arg in args {
280                check_expr(&format!("{field}.arg"), arg)?;
281            }
282            Ok(())
283        }
284        Expr::Cast {
285            expr,
286            target_type,
287            alias,
288        } => {
289            check_expr(field, expr)?;
290            check_ident(&format!("{field}.cast_type"), target_type)?;
291            if let Some(a) = alias {
292                check_ident(&format!("{field}.alias"), a)?;
293            }
294            Ok(())
295        }
296        Expr::Binary {
297            left, right, alias, ..
298        } => {
299            check_expr(field, left)?;
300            check_expr(field, right)?;
301            if let Some(a) = alias {
302                check_ident(&format!("{field}.alias"), a)?;
303            }
304            Ok(())
305        }
306        Expr::Literal(_) => Ok(()),
307        Expr::JsonAccess {
308            column,
309            alias,
310            path_segments,
311            ..
312        } => {
313            check_ident(field, column)?;
314            for (key, _) in path_segments {
315                // Integer indices are fine; string keys must be safe identifiers
316                if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
317                    return Err(SanitizeError {
318                        field: format!("{field}.json_path"),
319                        value: key.chars().take(40).collect(),
320                        reason: "JSON path key must be a safe identifier or integer".to_string(),
321                    });
322                }
323            }
324            if let Some(a) = alias {
325                check_ident(&format!("{field}.alias"), a)?;
326            }
327            Ok(())
328        }
329        Expr::Subquery { query, alias } => {
330            validate_ast(query)?;
331            if let Some(a) = alias {
332                check_ident(&format!("{field}.alias"), a)?;
333            }
334            Ok(())
335        }
336        Expr::Exists { query, alias, .. } => {
337            validate_ast(query)?;
338            if let Some(a) = alias {
339                check_ident(&format!("{field}.alias"), a)?;
340            }
341            Ok(())
342        }
343        // For all other complex Expr variants, validate aliases where present
344        Expr::Window {
345            name,
346            func,
347            partition,
348            params,
349            order,
350            ..
351        } => {
352            if !name.is_empty() {
353                check_ident(&format!("{field}.window_alias"), name)?;
354            }
355            check_ident(&format!("{field}.window_func"), func)?;
356            for p in partition {
357                check_ident(&format!("{field}.partition"), p)?;
358            }
359            for p in params {
360                check_expr(&format!("{field}.window_param"), p)?;
361            }
362            for cage in order {
363                for cond in &cage.conditions {
364                    check_expr(&format!("{field}.window_order"), &cond.left)?;
365                    check_value(&format!("{field}.window_order"), &cond.value)?;
366                }
367            }
368            Ok(())
369        }
370        Expr::Case {
371            when_clauses,
372            else_value,
373            alias,
374        } => {
375            for (cond, val) in when_clauses {
376                check_expr(&format!("{field}.case_when"), &cond.left)?;
377                check_value(&format!("{field}.case_when"), &cond.value)?;
378                check_expr(&format!("{field}.case_then"), val)?;
379            }
380            if let Some(e) = else_value {
381                check_expr(&format!("{field}.case_else"), e)?;
382            }
383            if let Some(a) = alias {
384                check_ident(&format!("{field}.alias"), a)?;
385            }
386            Ok(())
387        }
388        Expr::SpecialFunction { args, alias, name } => {
389            check_ident(&format!("{field}.special_func"), name)?;
390            for (_, arg) in args {
391                check_expr(&format!("{field}.special_func_arg"), arg)?;
392            }
393            if let Some(a) = alias {
394                check_ident(&format!("{field}.alias"), a)?;
395            }
396            Ok(())
397        }
398        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
399            for elem in elements {
400                check_expr(&format!("{field}.element"), elem)?;
401            }
402            if let Some(a) = alias {
403                check_ident(&format!("{field}.alias"), a)?;
404            }
405            Ok(())
406        }
407        Expr::Subscript { expr, index, alias } => {
408            check_expr(&format!("{field}.subscript_expr"), expr)?;
409            check_expr(&format!("{field}.subscript_index"), index)?;
410            if let Some(a) = alias {
411                check_ident(&format!("{field}.alias"), a)?;
412            }
413            Ok(())
414        }
415        Expr::Collate {
416            expr,
417            collation,
418            alias,
419        } => {
420            check_expr(&format!("{field}.collate_expr"), expr)?;
421            check_ident(&format!("{field}.collation"), collation)?;
422            if let Some(a) = alias {
423                check_ident(&format!("{field}.alias"), a)?;
424            }
425            Ok(())
426        }
427        Expr::FieldAccess {
428            expr,
429            field: f,
430            alias,
431        } => {
432            check_expr(&format!("{field}.field_access_expr"), expr)?;
433            check_ident(&format!("{field}.field"), f)?;
434            if let Some(a) = alias {
435                check_ident(&format!("{field}.alias"), a)?;
436            }
437            Ok(())
438        }
439        Expr::Def { name, .. } => check_ident(field, name),
440        Expr::Mod { col, .. } => check_expr(field, col),
441    }
442}
443
444/// Check a `Value` for embedded subqueries.
445fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
446    match value {
447        Value::Column(column) => check_ident(&format!("{field}.column"), column),
448        Value::NamedParam(name) => check_named_param(&format!("{field}.named_param"), name),
449        Value::Function(function) => {
450            check_raw_function_value(&format!("{field}.function"), function)
451        }
452        Value::Subquery(q) => validate_ast(q),
453        Value::Array(vals) => {
454            for v in vals {
455                check_value(field, v)?;
456            }
457            Ok(())
458        }
459        Value::Expr(expr) => check_expr(field, expr),
460        _ => Ok(()),
461    }
462}
463
464/// Validate a `Qail` AST from an untrusted source.
465///
466/// Checks all identifier fields against the parser grammar (`[a-zA-Z0-9_.]`)
467/// and rejects dangerous procedural actions.
468///
469/// # Errors
470///
471/// Returns `SanitizeError` on the first violation found.
472pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
473    // ── Block dangerous actions from binary path ─────────────────────
474    match cmd.action {
475        Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
476            return Err(SanitizeError {
477                field: "action".to_string(),
478                value: format!("{:?}", cmd.action),
479                reason: "procedural/session actions are not allowed via binary AST".to_string(),
480            });
481        }
482        _ => {}
483    }
484
485    // ── Table name ───────────────────────────────────────────────────
486    if !cmd.table.is_empty() {
487        if action_allows_table_alias(cmd.action) {
488            check_table_ref("table", &cmd.table)?;
489        } else {
490            check_ident("table", &cmd.table)?;
491        }
492    }
493
494    // ── Columns ──────────────────────────────────────────────────────
495    for (i, col) in cmd.columns.iter().enumerate() {
496        check_expr(&format!("columns[{i}]"), col)?;
497    }
498
499    // ── Table Constraints ────────────────────────────────────────────
500    for (i, constraint) in cmd.table_constraints.iter().enumerate() {
501        match constraint {
502            TableConstraint::Unique(cols) | TableConstraint::PrimaryKey(cols) => {
503                for col in cols {
504                    check_ident(&format!("table_constraints[{i}].column"), col)?;
505                }
506            }
507            TableConstraint::ForeignKey {
508                name,
509                columns,
510                ref_table,
511                ref_columns,
512                on_delete,
513                on_update,
514                deferrable,
515            } => {
516                if let Some(name) = name {
517                    check_ident(&format!("table_constraints[{i}].name"), name)?;
518                }
519                for col in columns {
520                    check_ident(&format!("table_constraints[{i}].column"), col)?;
521                }
522                check_ident(&format!("table_constraints[{i}].ref_table"), ref_table)?;
523                for col in ref_columns {
524                    check_ident(&format!("table_constraints[{i}].ref_column"), col)?;
525                }
526                if let Some(action) = on_delete {
527                    check_fk_action(&format!("table_constraints[{i}].on_delete"), action)?;
528                }
529                if let Some(action) = on_update {
530                    check_fk_action(&format!("table_constraints[{i}].on_update"), action)?;
531                }
532                if let Some(clause) = deferrable {
533                    check_fk_deferrable(&format!("table_constraints[{i}].deferrable"), clause)?;
534                }
535            }
536        }
537    }
538
539    // ── Joins ────────────────────────────────────────────────────────
540    for (i, join) in cmd.joins.iter().enumerate() {
541        check_table_ref(&format!("joins[{i}].table"), &join.table)?;
542        if let Some(ref conditions) = join.on {
543            for cond in conditions {
544                check_expr(&format!("joins[{i}].on"), &cond.left)?;
545                check_value(&format!("joins[{i}].on"), &cond.value)?;
546            }
547        }
548    }
549
550    // ── Cages (filters, sorts, etc.) ─────────────────────────────────
551    for cage in &cmd.cages {
552        for cond in &cage.conditions {
553            check_expr("cage.condition.left", &cond.left)?;
554            check_value("cage.condition.value", &cond.value)?;
555        }
556    }
557
558    // ── CTEs ─────────────────────────────────────────────────────────
559    for cte in &cmd.ctes {
560        check_ident("cte.name", &cte.name)?;
561        for col in &cte.columns {
562            check_ident("cte.column", col)?;
563        }
564        validate_ast(&cte.base_query)?;
565        if let Some(ref rq) = cte.recursive_query {
566            validate_ast(rq)?;
567        }
568    }
569
570    // ── DISTINCT ON ──────────────────────────────────────────────────
571    for expr in &cmd.distinct_on {
572        check_expr("distinct_on", expr)?;
573    }
574
575    // ── RETURNING ────────────────────────────────────────────────────
576    if let Some(ref cols) = cmd.returning {
577        for col in cols {
578            check_expr("returning", col)?;
579        }
580    }
581
582    // ── ON CONFLICT ──────────────────────────────────────────────────
583    if let Some(ref oc) = cmd.on_conflict {
584        for col in &oc.columns {
585            check_ident("on_conflict.column", col)?;
586        }
587        if let ConflictAction::DoUpdate { assignments } = &oc.action {
588            for (col, expr) in assignments {
589                check_ident("on_conflict.assignment.column", col)?;
590                check_expr("on_conflict.assignment.expr", expr)?;
591            }
592        }
593    }
594
595    // ── MERGE ────────────────────────────────────────────────────────
596    if let Some(ref merge) = cmd.merge {
597        if let Some(alias) = &merge.target_alias {
598            check_ident("merge.target_alias", alias)?;
599        }
600        match &merge.source {
601            MergeSource::Table { name, alias } => {
602                if let Some(alias) = alias {
603                    check_ident("merge.source.table", name)?;
604                    check_ident("merge.source.alias", alias)?;
605                } else {
606                    check_table_ref("merge.source.table", name)?;
607                }
608            }
609            MergeSource::Query { query, alias } => {
610                validate_ast(query)?;
611                if let Some(alias) = alias {
612                    check_ident("merge.source.alias", alias)?;
613                }
614            }
615        }
616        for cond in &merge.on {
617            check_expr("merge.on.left", &cond.left)?;
618            check_value("merge.on.value", &cond.value)?;
619        }
620        for clause in &merge.clauses {
621            for cond in &clause.condition {
622                check_expr("merge.clause.condition.left", &cond.left)?;
623                check_value("merge.clause.condition.value", &cond.value)?;
624            }
625            match &clause.action {
626                MergeAction::Update { assignments } => {
627                    for (col, expr) in assignments {
628                        check_ident("merge.update.column", col)?;
629                        check_expr("merge.update.expr", expr)?;
630                    }
631                }
632                MergeAction::Insert { columns, values } => {
633                    for col in columns {
634                        check_ident("merge.insert.column", col)?;
635                    }
636                    for expr in values {
637                        check_expr("merge.insert.expr", expr)?;
638                    }
639                }
640                MergeAction::Delete | MergeAction::DoNothing => {}
641            }
642        }
643    }
644
645    // ── FROM / USING tables ──────────────────────────────────────────
646    for t in &cmd.from_tables {
647        check_table_ref("from_tables", t)?;
648    }
649    for t in &cmd.using_tables {
650        check_table_ref("using_tables", t)?;
651    }
652
653    // ── SET ops ──────────────────────────────────────────────────────
654    for (_, sub) in &cmd.set_ops {
655        validate_ast(sub)?;
656    }
657
658    // ── Source query (INSERT … SELECT) ───────────────────────────────
659    if let Some(ref sq) = cmd.source_query {
660        validate_ast(sq)?;
661    }
662
663    // ── HAVING ───────────────────────────────────────────────────────
664    for cond in &cmd.having {
665        check_expr("having", &cond.left)?;
666        check_value("having", &cond.value)?;
667    }
668
669    // ── Channel (LISTEN/NOTIFY) ──────────────────────────────────────
670    match cmd.action {
671        Action::AlterAddConstraint | Action::AlterDropConstraint => {
672            let Some(ref name) = cmd.channel else {
673                return Err(SanitizeError {
674                    field: "channel".to_string(),
675                    value: String::new(),
676                    reason: "constraint actions require a constraint name".to_string(),
677                });
678            };
679            check_ident("channel", name)?;
680
681            if matches!(cmd.action, Action::AlterAddConstraint) {
682                let Some(ref expr) = cmd.payload else {
683                    return Err(SanitizeError {
684                        field: "payload".to_string(),
685                        value: String::new(),
686                        reason: "add constraint requires a check expression".to_string(),
687                    });
688                };
689                check_sql_expr_fragment("payload", expr)?;
690            }
691        }
692        _ => {}
693    }
694
695    if let Some(ref ch) = cmd.channel {
696        check_ident("channel", ch)?;
697    }
698
699    Ok(())
700}
701
702#[cfg(test)]
703mod tests {
704    use super::*;
705    use crate::ast::{Operator, Qail};
706
707    #[test]
708    fn valid_simple_query_passes() {
709        let cmd = Qail::get("users").columns(["id", "name"]);
710        assert!(validate_ast(&cmd).is_ok());
711    }
712
713    #[test]
714    fn sql_injection_in_table_rejected() {
715        let cmd = Qail::get("users; DROP TABLE users; --");
716        let err = validate_ast(&cmd).unwrap_err();
717        assert_eq!(err.field, "table");
718    }
719
720    #[test]
721    fn call_action_rejected() {
722        let cmd = Qail {
723            action: Action::Call,
724            table: "my_proc()".to_string(),
725            ..Default::default()
726        };
727        let err = validate_ast(&cmd).unwrap_err();
728        assert_eq!(err.field, "action");
729    }
730
731    #[test]
732    fn do_action_rejected() {
733        let cmd = Qail {
734            action: Action::Do,
735            table: "plpgsql".to_string(),
736            ..Default::default()
737        };
738        let err = validate_ast(&cmd).unwrap_err();
739        assert_eq!(err.field, "action");
740    }
741
742    #[test]
743    fn valid_qualified_name_passes() {
744        let cmd = Qail::get("public.users").columns(["users.id", "users.name"]);
745        assert!(validate_ast(&cmd).is_ok());
746    }
747
748    #[test]
749    fn valid_long_qualified_identifier_parts_pass() {
750        let schema = "s".repeat(MAX_IDENT_LEN);
751        let table = "t".repeat(MAX_IDENT_LEN);
752        let cmd = Qail::get(format!("{schema}.{table}")).columns(["id"]);
753
754        assert!(validate_ast(&cmd).is_ok());
755    }
756
757    #[test]
758    fn empty_qualified_identifier_part_is_rejected() {
759        let err = validate_ast(&Qail::get("public..users")).unwrap_err();
760
761        assert_eq!(err.field, "table");
762    }
763
764    #[test]
765    fn query_and_mutation_table_aliases_pass_sanitizer() {
766        assert!(validate_ast(&Qail::get("public.users u")).is_ok());
767        assert!(validate_ast(&Qail::set("public.users AS u").set_value("active", true)).is_ok());
768        assert!(validate_ast(&Qail::del("public.users u")).is_ok());
769    }
770
771    #[test]
772    fn ddl_table_alias_shape_is_rejected() {
773        let err = validate_ast(&Qail::make("public.users u")).unwrap_err();
774        assert_eq!(err.field, "table");
775    }
776
777    #[test]
778    fn alter_add_constraint_rejects_unsafe_payload() {
779        let cmd = Qail {
780            action: crate::ast::Action::AlterAddConstraint,
781            table: "users".to_string(),
782            channel: Some("users_active_check".to_string()),
783            payload: Some("active); DROP TABLE users; --".to_string()),
784            ..Default::default()
785        };
786
787        let err = validate_ast(&cmd).unwrap_err();
788        assert_eq!(err.field, "payload");
789    }
790
791    #[test]
792    fn alter_add_constraint_allows_quoted_delimiter_payload() {
793        let cmd = Qail {
794            action: crate::ast::Action::AlterAddConstraint,
795            table: "events".to_string(),
796            channel: Some("events_kind_check".to_string()),
797            payload: Some("kind <> 'semi;inside'".to_string()),
798            ..Default::default()
799        };
800
801        assert!(validate_ast(&cmd).is_ok());
802    }
803
804    #[test]
805    fn injection_in_join_table_rejected() {
806        use crate::ast::JoinKind;
807        let cmd = Qail::get("users").join(
808            JoinKind::Left,
809            "orders; DROP TABLE x",
810            "users.id",
811            "orders.user_id",
812        );
813        let err = validate_ast(&cmd).unwrap_err();
814        assert!(err.field.contains("joins"));
815    }
816
817    #[test]
818    fn malformed_join_table_reference_rejected() {
819        use crate::ast::JoinKind;
820        let cmd = Qail::get("users").join(
821            JoinKind::Left,
822            "orders DROP TABLE x",
823            "users.id",
824            "orders.user_id",
825        );
826        let err = validate_ast(&cmd).unwrap_err();
827        assert!(err.field.contains("joins"));
828    }
829
830    #[test]
831    fn update_from_and_delete_using_aliases_pass_sanitizer() {
832        let update = Qail::set("orders")
833            .set_value("status", "paid")
834            .update_from(["accounts a"]);
835        assert!(validate_ast(&update).is_ok());
836
837        let delete = Qail::del("orders").delete_using(["accounts a"]);
838        assert!(validate_ast(&delete).is_ok());
839    }
840
841    #[test]
842    fn merge_inline_source_alias_passes_sanitizer() {
843        let cmd = Qail::merge_into("orders")
844            .target_alias("o")
845            .using_table("stage_orders s")
846            .merge_on_column("o.id", Operator::Eq, "s.order_id")
847            .when_matched_do_nothing();
848
849        assert!(validate_ast(&cmd).is_ok());
850    }
851
852    #[test]
853    fn malformed_merge_source_table_reference_rejected() {
854        let cmd = Qail::merge_into("orders")
855            .using_table("stage_orders DROP TABLE x")
856            .merge_on_column("orders.id", Operator::Eq, "stage_orders.order_id")
857            .when_matched_do_nothing();
858
859        let err = validate_ast(&cmd).unwrap_err();
860        assert_eq!(err.field, "merge.source.table");
861    }
862
863    #[test]
864    fn injection_in_update_from_and_delete_using_rejected() {
865        let update = Qail::set("orders")
866            .set_value("status", "paid")
867            .update_from(["accounts; DROP TABLE accounts"]);
868        let err = validate_ast(&update).unwrap_err();
869        assert_eq!(err.field, "from_tables");
870
871        let delete = Qail::del("orders").delete_using(["accounts; DROP TABLE accounts"]);
872        let err = validate_ast(&delete).unwrap_err();
873        assert_eq!(err.field, "using_tables");
874    }
875
876    #[test]
877    fn injection_in_column_rejected() {
878        let cmd = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
879        let err = validate_ast(&cmd).unwrap_err();
880        assert!(err.field.contains("columns"));
881    }
882
883    #[test]
884    fn unsafe_column_value_rejected() {
885        use crate::ast::Value;
886
887        let cmd = Qail::get("orders").filter(
888            "user_id",
889            Operator::Eq,
890            Value::Column("users.id; DROP TABLE users; --".to_string()),
891        );
892
893        let err = validate_ast(&cmd).unwrap_err();
894        assert_eq!(err.field, "cage.condition.value.column");
895    }
896
897    #[test]
898    fn unsafe_named_parameter_rejected() {
899        use crate::ast::Value;
900
901        let cmd = Qail::get("users").filter(
902            "id",
903            Operator::Eq,
904            Value::NamedParam("id); DROP TABLE users; --".to_string()),
905        );
906
907        let err = validate_ast(&cmd).unwrap_err();
908        assert_eq!(err.field, "cage.condition.value.named_param");
909    }
910
911    #[test]
912    fn unsafe_raw_function_value_rejected() {
913        use crate::ast::Value;
914
915        let cmd = Qail::get("users").filter(
916            "updated_at",
917            Operator::Lt,
918            Value::Function("NOW(); DROP TABLE users; --".to_string()),
919        );
920
921        let err = validate_ast(&cmd).unwrap_err();
922        assert_eq!(err.field, "cage.condition.value.function");
923    }
924
925    #[test]
926    fn safe_raw_function_value_passes_sanitizer() {
927        use crate::ast::Value;
928
929        let cmd = Qail::get("users").filter(
930            "updated_at",
931            Operator::Lt,
932            Value::Function("NOW()".to_string()),
933        );
934
935        assert!(validate_ast(&cmd).is_ok());
936    }
937
938    #[test]
939    fn on_conflict_update_assignment_expression_injection_rejected() {
940        let cmd = Qail::add("users")
941            .set_value("id", 1)
942            .set_value("name", "Alice")
943            .on_conflict_update(
944                &["id"],
945                &[(
946                    "name",
947                    Expr::Named("EXCLUDED.name, is_admin = true".to_string()),
948                )],
949            );
950
951        let err = validate_ast(&cmd).unwrap_err();
952
953        assert_eq!(err.field, "on_conflict.assignment.expr");
954        assert!(
955            err.value.contains("EXCLUDED.name"),
956            "unexpected rejected value: {}",
957            err.value
958        );
959    }
960
961    #[test]
962    fn on_conflict_update_assignment_column_injection_rejected() {
963        let cmd = Qail::add("users")
964            .set_value("id", 1)
965            .set_value("name", "Alice")
966            .on_conflict_update(
967                &["id"],
968                &[("name, is_admin", Expr::Named("EXCLUDED.name".to_string()))],
969            );
970
971        let err = validate_ast(&cmd).unwrap_err();
972
973        assert_eq!(err.field, "on_conflict.assignment.column");
974    }
975
976    #[test]
977    fn aggregate_filter_value_expression_injection_rejected() {
978        use crate::ast::{AggregateFunc, Condition, Operator, Value};
979
980        let mut cmd = Qail::get("events");
981        cmd.columns.push(Expr::Aggregate {
982            col: "id".to_string(),
983            func: AggregateFunc::Count,
984            distinct: false,
985            filter: Some(vec![Condition {
986                left: Expr::Named("direction".to_string()),
987                op: Operator::Eq,
988                value: Value::Expr(Box::new(Expr::Named("bad;DROP".to_string()))),
989                is_array_unnest: false,
990            }]),
991            alias: None,
992        });
993
994        let err = validate_ast(&cmd).unwrap_err();
995        assert_eq!(err.field, "columns[0].filter");
996    }
997
998    #[test]
999    fn count_star_aggregate_passes_sanitizer() {
1000        use crate::ast::AggregateFunc;
1001
1002        let mut cmd = Qail::get("events");
1003        cmd.columns.push(Expr::Aggregate {
1004            col: "*".to_string(),
1005            func: AggregateFunc::Count,
1006            distinct: false,
1007            filter: None,
1008            alias: Some("total".to_string()),
1009        });
1010
1011        assert!(validate_ast(&cmd).is_ok());
1012    }
1013
1014    #[test]
1015    fn case_when_complex_condition_expression_passes_sanitizer() {
1016        use crate::ast::{Condition, Operator, Value};
1017
1018        let mut cmd = Qail::get("users");
1019        cmd.columns.push(Expr::Case {
1020            when_clauses: vec![(
1021                Condition {
1022                    left: Expr::Cast {
1023                        expr: Box::new(Expr::JsonAccess {
1024                            column: "profile".to_string(),
1025                            path_segments: vec![("active".to_string(), true)],
1026                            alias: None,
1027                        }),
1028                        target_type: "integer".to_string(),
1029                        alias: None,
1030                    },
1031                    op: Operator::Gt,
1032                    value: Value::Int(0),
1033                    is_array_unnest: false,
1034                },
1035                Box::new(Expr::Literal(Value::String("active".to_string()))),
1036            )],
1037            else_value: Some(Box::new(Expr::Literal(Value::String(
1038                "inactive".to_string(),
1039            )))),
1040            alias: Some("status_label".to_string()),
1041        });
1042
1043        assert!(validate_ast(&cmd).is_ok());
1044    }
1045
1046    #[test]
1047    fn empty_table_name_passes() {
1048        // Some actions like TxnStart have empty table
1049        let cmd = Qail {
1050            action: Action::TxnStart,
1051            table: String::new(),
1052            ..Default::default()
1053        };
1054        assert!(validate_ast(&cmd).is_ok());
1055    }
1056
1057    #[test]
1058    fn oversized_identifier_rejected() {
1059        let long_name = "a".repeat(64);
1060        let cmd = Qail::get(&long_name);
1061        let err = validate_ast(&cmd).unwrap_err();
1062        assert!(err.reason.contains("63"));
1063    }
1064}