Skip to main content

qail_core/
sanitize.rs

//! AST structural sanitization for untrusted input.
//!
//! The text parser enforces identifier constraints (ASCII alphanumerics, `_`, and `.`),
//! but externally provided ASTs (binary endpoints, APIs, generated payloads, etc.)
//! may bypass those parser-level checks and reach execution unvalidated.
//!
//! Call [`validate_ast`](crate::sanitize::validate_ast) on any `Qail` obtained from an
//! untrusted source (binary endpoint, external API, etc.) before executing it.
9
10use crate::ast::{
11    Action, ConflictAction, Expr, MergeAction, MergeSource, Qail, TableConstraint, Value,
12};
13use std::fmt;
14
/// Error returned when AST structural validation fails.
///
/// * `field` — the AST location that was rejected (e.g. `"table"`, `"columns[0]"`).
/// * `value` — the offending input; callers in this module may truncate it.
/// * `reason` — the rule that was violated.
#[derive(Debug, Clone)]
pub struct SanitizeError {
    pub field: String,
    pub value: String,
    pub reason: String,
}

impl fmt::Display for SanitizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            field,
            value,
            reason,
        } = self;
        write!(f, "AST validation failed: {field} '{value}' — {reason}")
    }
}

impl std::error::Error for SanitizeError {}
34
/// Maximum identifier length (PostgreSQL NAMEDATALEN - 1).
const MAX_IDENT_LEN: usize = 63;

/// Returns `true` when `s` matches the parser's identifier grammar `[a-zA-Z0-9_.]`.
///
/// Empty strings and strings longer than [`MAX_IDENT_LEN`] bytes are rejected, as is
/// any non-ASCII character.
fn is_safe_identifier(s: &str) -> bool {
    if s.is_empty() || s.len() > MAX_IDENT_LEN {
        return false;
    }
    let allowed = |c: char| c.is_ascii_alphanumeric() || matches!(c, '_' | '.');
    s.chars().all(allowed)
}
47
48/// Validate an identifier, returning a `SanitizeError` if invalid.
49fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
50    if is_safe_identifier(value) {
51        Ok(())
52    } else {
53        Err(SanitizeError {
54            field: field.to_string(),
55            value: value.chars().take(40).collect(),
56            reason: "identifiers must match [a-zA-Z0-9_.] and be ≤63 chars".to_string(),
57        })
58    }
59}
60
61fn check_fk_action(field: &str, value: &str) -> Result<(), SanitizeError> {
62    let normalized = value.trim().to_ascii_lowercase().replace('_', " ");
63    if matches!(
64        normalized.as_str(),
65        "cascade" | "restrict" | "no action" | "set null" | "set default"
66    ) {
67        Ok(())
68    } else {
69        Err(SanitizeError {
70            field: field.to_string(),
71            value: value.chars().take(40).collect(),
72            reason:
73                "foreign key action must be cascade, restrict, no_action, set_null, or set_default"
74                    .to_string(),
75        })
76    }
77}
78
79fn check_fk_deferrable(field: &str, value: &str) -> Result<(), SanitizeError> {
80    let normalized = value.trim().to_ascii_lowercase().replace('_', " ");
81    if matches!(
82        normalized.as_str(),
83        "deferrable"
84            | "initially deferred"
85            | "initially immediate"
86            | "deferrable initially deferred"
87            | "deferrable initially immediate"
88    ) {
89        Ok(())
90    } else {
91        Err(SanitizeError {
92            field: field.to_string(),
93            value: value.chars().take(40).collect(),
94            reason: "foreign key deferrable clause must be deferrable, initially_deferred, or initially_immediate".to_string(),
95        })
96    }
97}
98
99/// Validate an `Expr` node for unsafe patterns.
100///
101/// - `Expr::Named` must be a safe identifier.
102/// - Recursive variants (Cast, Binary, etc.) are validated recursively.
103fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
104    match expr {
105        Expr::Star => Ok(()),
106        Expr::Named(name) => check_ident(field, name),
107        Expr::Aliased { name, alias } => {
108            check_ident(field, name)?;
109            check_ident(&format!("{field}.alias"), alias)
110        }
111        Expr::Aggregate {
112            col, alias, filter, ..
113        } => {
114            if col != "*" {
115                check_ident(field, col)?;
116            }
117            if let Some(a) = alias {
118                check_ident(&format!("{field}.alias"), a)?;
119            }
120            if let Some(conditions) = filter {
121                for cond in conditions {
122                    check_expr(&format!("{field}.filter"), &cond.left)?;
123                    check_value(&format!("{field}.filter"), &cond.value)?;
124                }
125            }
126            Ok(())
127        }
128        Expr::FunctionCall { name, args, alias } => {
129            check_ident(field, name)?;
130            if let Some(a) = alias {
131                check_ident(&format!("{field}.alias"), a)?;
132            }
133            for arg in args {
134                check_expr(&format!("{field}.arg"), arg)?;
135            }
136            Ok(())
137        }
138        Expr::Cast {
139            expr,
140            target_type,
141            alias,
142        } => {
143            check_expr(field, expr)?;
144            check_ident(&format!("{field}.cast_type"), target_type)?;
145            if let Some(a) = alias {
146                check_ident(&format!("{field}.alias"), a)?;
147            }
148            Ok(())
149        }
150        Expr::Binary {
151            left, right, alias, ..
152        } => {
153            check_expr(field, left)?;
154            check_expr(field, right)?;
155            if let Some(a) = alias {
156                check_ident(&format!("{field}.alias"), a)?;
157            }
158            Ok(())
159        }
160        Expr::Literal(_) => Ok(()),
161        Expr::JsonAccess {
162            column,
163            alias,
164            path_segments,
165            ..
166        } => {
167            check_ident(field, column)?;
168            for (key, _) in path_segments {
169                // Integer indices are fine; string keys must be safe identifiers
170                if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
171                    return Err(SanitizeError {
172                        field: format!("{field}.json_path"),
173                        value: key.chars().take(40).collect(),
174                        reason: "JSON path key must be a safe identifier or integer".to_string(),
175                    });
176                }
177            }
178            if let Some(a) = alias {
179                check_ident(&format!("{field}.alias"), a)?;
180            }
181            Ok(())
182        }
183        Expr::Subquery { query, alias } => {
184            validate_ast(query)?;
185            if let Some(a) = alias {
186                check_ident(&format!("{field}.alias"), a)?;
187            }
188            Ok(())
189        }
190        Expr::Exists { query, alias, .. } => {
191            validate_ast(query)?;
192            if let Some(a) = alias {
193                check_ident(&format!("{field}.alias"), a)?;
194            }
195            Ok(())
196        }
197        // For all other complex Expr variants, validate aliases where present
198        Expr::Window {
199            name,
200            func,
201            partition,
202            params,
203            order,
204            ..
205        } => {
206            if !name.is_empty() {
207                check_ident(&format!("{field}.window_alias"), name)?;
208            }
209            check_ident(&format!("{field}.window_func"), func)?;
210            for p in partition {
211                check_ident(&format!("{field}.partition"), p)?;
212            }
213            for p in params {
214                check_expr(&format!("{field}.window_param"), p)?;
215            }
216            for cage in order {
217                for cond in &cage.conditions {
218                    check_expr(&format!("{field}.window_order"), &cond.left)?;
219                    check_value(&format!("{field}.window_order"), &cond.value)?;
220                }
221            }
222            Ok(())
223        }
224        Expr::Case {
225            when_clauses,
226            else_value,
227            alias,
228        } => {
229            for (cond, val) in when_clauses {
230                check_expr(&format!("{field}.case_when"), &cond.left)?;
231                check_value(&format!("{field}.case_when"), &cond.value)?;
232                check_expr(&format!("{field}.case_then"), val)?;
233            }
234            if let Some(e) = else_value {
235                check_expr(&format!("{field}.case_else"), e)?;
236            }
237            if let Some(a) = alias {
238                check_ident(&format!("{field}.alias"), a)?;
239            }
240            Ok(())
241        }
242        Expr::SpecialFunction { args, alias, name } => {
243            check_ident(&format!("{field}.special_func"), name)?;
244            for (_, arg) in args {
245                check_expr(&format!("{field}.special_func_arg"), arg)?;
246            }
247            if let Some(a) = alias {
248                check_ident(&format!("{field}.alias"), a)?;
249            }
250            Ok(())
251        }
252        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
253            for elem in elements {
254                check_expr(&format!("{field}.element"), elem)?;
255            }
256            if let Some(a) = alias {
257                check_ident(&format!("{field}.alias"), a)?;
258            }
259            Ok(())
260        }
261        Expr::Subscript { expr, index, alias } => {
262            check_expr(&format!("{field}.subscript_expr"), expr)?;
263            check_expr(&format!("{field}.subscript_index"), index)?;
264            if let Some(a) = alias {
265                check_ident(&format!("{field}.alias"), a)?;
266            }
267            Ok(())
268        }
269        Expr::Collate {
270            expr,
271            collation,
272            alias,
273        } => {
274            check_expr(&format!("{field}.collate_expr"), expr)?;
275            check_ident(&format!("{field}.collation"), collation)?;
276            if let Some(a) = alias {
277                check_ident(&format!("{field}.alias"), a)?;
278            }
279            Ok(())
280        }
281        Expr::FieldAccess {
282            expr,
283            field: f,
284            alias,
285        } => {
286            check_expr(&format!("{field}.field_access_expr"), expr)?;
287            check_ident(&format!("{field}.field"), f)?;
288            if let Some(a) = alias {
289                check_ident(&format!("{field}.alias"), a)?;
290            }
291            Ok(())
292        }
293        Expr::Def { name, .. } => check_ident(field, name),
294        Expr::Mod { col, .. } => check_expr(field, col),
295    }
296}
297
298/// Check a `Value` for embedded subqueries.
299fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
300    match value {
301        Value::Subquery(q) => validate_ast(q),
302        Value::Array(vals) => {
303            for v in vals {
304                check_value(field, v)?;
305            }
306            Ok(())
307        }
308        Value::Expr(expr) => check_expr(field, expr),
309        _ => Ok(()),
310    }
311}
312
313/// Validate a `Qail` AST from an untrusted source.
314///
315/// Checks all identifier fields against the parser grammar (`[a-zA-Z0-9_.]`)
316/// and rejects dangerous procedural actions.
317///
318/// # Errors
319///
320/// Returns `SanitizeError` on the first violation found.
321pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
322    // ── Block dangerous actions from binary path ─────────────────────
323    match cmd.action {
324        Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
325            return Err(SanitizeError {
326                field: "action".to_string(),
327                value: format!("{:?}", cmd.action),
328                reason: "procedural/session actions are not allowed via binary AST".to_string(),
329            });
330        }
331        _ => {}
332    }
333
334    // ── Table name ───────────────────────────────────────────────────
335    if !cmd.table.is_empty() {
336        check_ident("table", &cmd.table)?;
337    }
338
339    // ── Columns ──────────────────────────────────────────────────────
340    for (i, col) in cmd.columns.iter().enumerate() {
341        check_expr(&format!("columns[{i}]"), col)?;
342    }
343
344    // ── Table Constraints ────────────────────────────────────────────
345    for (i, constraint) in cmd.table_constraints.iter().enumerate() {
346        match constraint {
347            TableConstraint::Unique(cols) | TableConstraint::PrimaryKey(cols) => {
348                for col in cols {
349                    check_ident(&format!("table_constraints[{i}].column"), col)?;
350                }
351            }
352            TableConstraint::ForeignKey {
353                name,
354                columns,
355                ref_table,
356                ref_columns,
357                on_delete,
358                on_update,
359                deferrable,
360            } => {
361                if let Some(name) = name {
362                    check_ident(&format!("table_constraints[{i}].name"), name)?;
363                }
364                for col in columns {
365                    check_ident(&format!("table_constraints[{i}].column"), col)?;
366                }
367                check_ident(&format!("table_constraints[{i}].ref_table"), ref_table)?;
368                for col in ref_columns {
369                    check_ident(&format!("table_constraints[{i}].ref_column"), col)?;
370                }
371                if let Some(action) = on_delete {
372                    check_fk_action(&format!("table_constraints[{i}].on_delete"), action)?;
373                }
374                if let Some(action) = on_update {
375                    check_fk_action(&format!("table_constraints[{i}].on_update"), action)?;
376                }
377                if let Some(clause) = deferrable {
378                    check_fk_deferrable(&format!("table_constraints[{i}].deferrable"), clause)?;
379                }
380            }
381        }
382    }
383
384    // ── Joins ────────────────────────────────────────────────────────
385    for (i, join) in cmd.joins.iter().enumerate() {
386        // Join table may include alias: "users u"
387        // Validate each space-separated token
388        for token in join.table.split_whitespace() {
389            check_ident(&format!("joins[{i}].table"), token)?;
390        }
391        if let Some(ref conditions) = join.on {
392            for cond in conditions {
393                check_expr(&format!("joins[{i}].on"), &cond.left)?;
394                check_value(&format!("joins[{i}].on"), &cond.value)?;
395            }
396        }
397    }
398
399    // ── Cages (filters, sorts, etc.) ─────────────────────────────────
400    for cage in &cmd.cages {
401        for cond in &cage.conditions {
402            check_expr("cage.condition.left", &cond.left)?;
403            check_value("cage.condition.value", &cond.value)?;
404        }
405    }
406
407    // ── CTEs ─────────────────────────────────────────────────────────
408    for cte in &cmd.ctes {
409        check_ident("cte.name", &cte.name)?;
410        for col in &cte.columns {
411            check_ident("cte.column", col)?;
412        }
413        validate_ast(&cte.base_query)?;
414        if let Some(ref rq) = cte.recursive_query {
415            validate_ast(rq)?;
416        }
417    }
418
419    // ── DISTINCT ON ──────────────────────────────────────────────────
420    for expr in &cmd.distinct_on {
421        check_expr("distinct_on", expr)?;
422    }
423
424    // ── RETURNING ────────────────────────────────────────────────────
425    if let Some(ref cols) = cmd.returning {
426        for col in cols {
427            check_expr("returning", col)?;
428        }
429    }
430
431    // ── ON CONFLICT ──────────────────────────────────────────────────
432    if let Some(ref oc) = cmd.on_conflict {
433        for col in &oc.columns {
434            check_ident("on_conflict.column", col)?;
435        }
436        if let ConflictAction::DoUpdate { assignments } = &oc.action {
437            for (col, expr) in assignments {
438                check_ident("on_conflict.assignment.column", col)?;
439                check_expr("on_conflict.assignment.expr", expr)?;
440            }
441        }
442    }
443
444    // ── MERGE ────────────────────────────────────────────────────────
445    if let Some(ref merge) = cmd.merge {
446        if let Some(alias) = &merge.target_alias {
447            check_ident("merge.target_alias", alias)?;
448        }
449        match &merge.source {
450            MergeSource::Table { name, alias } => {
451                check_ident("merge.source.table", name)?;
452                if let Some(alias) = alias {
453                    check_ident("merge.source.alias", alias)?;
454                }
455            }
456            MergeSource::Query { query, alias } => {
457                validate_ast(query)?;
458                if let Some(alias) = alias {
459                    check_ident("merge.source.alias", alias)?;
460                }
461            }
462        }
463        for cond in &merge.on {
464            check_expr("merge.on.left", &cond.left)?;
465            check_value("merge.on.value", &cond.value)?;
466        }
467        for clause in &merge.clauses {
468            for cond in &clause.condition {
469                check_expr("merge.clause.condition.left", &cond.left)?;
470                check_value("merge.clause.condition.value", &cond.value)?;
471            }
472            match &clause.action {
473                MergeAction::Update { assignments } => {
474                    for (col, expr) in assignments {
475                        check_ident("merge.update.column", col)?;
476                        check_expr("merge.update.expr", expr)?;
477                    }
478                }
479                MergeAction::Insert { columns, values } => {
480                    for col in columns {
481                        check_ident("merge.insert.column", col)?;
482                    }
483                    for expr in values {
484                        check_expr("merge.insert.expr", expr)?;
485                    }
486                }
487                MergeAction::Delete | MergeAction::DoNothing => {}
488            }
489        }
490    }
491
492    // ── FROM / USING tables ──────────────────────────────────────────
493    for t in &cmd.from_tables {
494        check_ident("from_tables", t)?;
495    }
496    for t in &cmd.using_tables {
497        check_ident("using_tables", t)?;
498    }
499
500    // ── SET ops ──────────────────────────────────────────────────────
501    for (_, sub) in &cmd.set_ops {
502        validate_ast(sub)?;
503    }
504
505    // ── Source query (INSERT … SELECT) ───────────────────────────────
506    if let Some(ref sq) = cmd.source_query {
507        validate_ast(sq)?;
508    }
509
510    // ── HAVING ───────────────────────────────────────────────────────
511    for cond in &cmd.having {
512        check_expr("having", &cond.left)?;
513        check_value("having", &cond.value)?;
514    }
515
516    // ── Channel (LISTEN/NOTIFY) ──────────────────────────────────────
517    if let Some(ref ch) = cmd.channel {
518        check_ident("channel", ch)?;
519    }
520
521    Ok(())
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Qail;

    /// Runs the sanitizer, fails the test if the AST was accepted, and returns the error.
    fn rejected(cmd: &Qail) -> SanitizeError {
        match validate_ast(cmd) {
            Ok(()) => panic!("expected validate_ast to reject the AST"),
            Err(err) => err,
        }
    }

    #[test]
    fn valid_simple_query_passes() {
        let query = Qail::get("users").columns(["id", "name"]);
        assert!(validate_ast(&query).is_ok());
    }

    #[test]
    fn sql_injection_in_table_rejected() {
        let query = Qail::get("users; DROP TABLE users; --");
        assert_eq!(rejected(&query).field, "table");
    }

    #[test]
    fn call_action_rejected() {
        let query = Qail {
            action: Action::Call,
            table: "my_proc()".to_string(),
            ..Default::default()
        };
        assert_eq!(rejected(&query).field, "action");
    }

    #[test]
    fn do_action_rejected() {
        let query = Qail {
            action: Action::Do,
            table: "plpgsql".to_string(),
            ..Default::default()
        };
        assert_eq!(rejected(&query).field, "action");
    }

    #[test]
    fn valid_qualified_name_passes() {
        let query = Qail::get("public.users").columns(["users.id", "users.name"]);
        assert!(validate_ast(&query).is_ok());
    }

    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;

        let query = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        assert!(rejected(&query).field.contains("joins"));
    }

    #[test]
    fn injection_in_column_rejected() {
        let query = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        assert!(rejected(&query).field.contains("columns"));
    }

    #[test]
    fn on_conflict_update_assignment_expression_injection_rejected() {
        let hostile = Expr::Named("EXCLUDED.name, is_admin = true".to_string());
        let query = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(&["id"], &[("name", hostile)]);

        let err = rejected(&query);
        assert_eq!(err.field, "on_conflict.assignment.expr");
        assert!(
            err.value.contains("EXCLUDED.name"),
            "unexpected rejected value: {}",
            err.value
        );
    }

    #[test]
    fn on_conflict_update_assignment_column_injection_rejected() {
        let assignment = ("name, is_admin", Expr::Named("EXCLUDED.name".to_string()));
        let query = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(&["id"], &[assignment]);

        assert_eq!(rejected(&query).field, "on_conflict.assignment.column");
    }

    #[test]
    fn aggregate_filter_value_expression_injection_rejected() {
        use crate::ast::{AggregateFunc, Condition, Operator, Value};

        let hostile_filter = Condition {
            left: Expr::Named("direction".to_string()),
            op: Operator::Eq,
            value: Value::Expr(Box::new(Expr::Named("bad;DROP".to_string()))),
            is_array_unnest: false,
        };

        let mut query = Qail::get("events");
        query.columns.push(Expr::Aggregate {
            col: "id".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: Some(vec![hostile_filter]),
            alias: None,
        });

        assert_eq!(rejected(&query).field, "columns[0].filter");
    }

    #[test]
    fn count_star_aggregate_passes_sanitizer() {
        use crate::ast::AggregateFunc;

        let mut query = Qail::get("events");
        query.columns.push(Expr::Aggregate {
            col: "*".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: None,
            alias: Some("total".to_string()),
        });

        assert!(validate_ast(&query).is_ok());
    }

    #[test]
    fn case_when_complex_condition_expression_passes_sanitizer() {
        use crate::ast::{Condition, Operator, Value};

        let json_flag = Expr::JsonAccess {
            column: "profile".to_string(),
            path_segments: vec![("active".to_string(), true)],
            alias: None,
        };
        let condition = Condition {
            left: Expr::Cast {
                expr: Box::new(json_flag),
                target_type: "integer".to_string(),
                alias: None,
            },
            op: Operator::Gt,
            value: Value::Int(0),
            is_array_unnest: false,
        };
        let label = |text: &str| Box::new(Expr::Literal(Value::String(text.to_string())));

        let mut query = Qail::get("users");
        query.columns.push(Expr::Case {
            when_clauses: vec![(condition, label("active"))],
            else_value: Some(label("inactive")),
            alias: Some("status_label".to_string()),
        });

        assert!(validate_ast(&query).is_ok());
    }

    #[test]
    fn empty_table_name_passes() {
        // Actions such as TxnStart carry no table name at all.
        let query = Qail {
            action: Action::TxnStart,
            table: String::new(),
            ..Default::default()
        };
        assert!(validate_ast(&query).is_ok());
    }

    #[test]
    fn oversized_identifier_rejected() {
        let long_name = "a".repeat(64);
        let query = Qail::get(&long_name);
        assert!(rejected(&query).reason.contains("63"));
    }
}