Skip to main content

qail_core/
sanitize.rs

1//! AST structural sanitization for untrusted input.
2//!
3//! The text parser enforces identifier constraints (alphanumeric + `_` + `.`),
4//! but externally provided ASTs (binary endpoints, APIs, generated payloads, etc.)
5//! may bypass parser-level identifier checks and still reach execution.
6//!
7//! Call [`validate_ast`](crate::sanitize::validate_ast) on any `Qail` obtained from an untrusted source
8//! (binary endpoint, external API, etc.) before execution.
9
10use crate::ast::{
11    Action, ConflictAction, Expr, MergeAction, MergeSource, Qail, TableConstraint, Value,
12};
13use std::fmt;
14
/// Error produced when structural validation of an AST rejects a node.
///
/// `field` names the AST location that failed, `value` holds a textual
/// rendering of the rejected input (possibly truncated by the check that
/// produced it), and `reason` states the rule that was violated.
#[derive(Debug, Clone)]
pub struct SanitizeError {
    /// Path of the AST field that failed validation (e.g. `"table"`).
    pub field: String,
    /// The rejected input as reported by the failing check.
    pub value: String,
    /// Human-readable description of the violated rule.
    pub reason: String,
}

impl fmt::Display for SanitizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            field,
            value,
            reason,
        } = self;
        write!(f, "AST validation failed: {field} '{value}' — {reason}")
    }
}

/// `SanitizeError` wraps no underlying cause, so `source()` is always `None`.
impl std::error::Error for SanitizeError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}
34
/// Longest identifier accepted, mirroring PostgreSQL's `NAMEDATALEN - 1`.
const MAX_IDENT_LEN: usize = 63;

/// Returns `true` when `s` is an identifier the text parser would accept.
///
/// Accepted means all of the following hold:
/// - `s` is non-empty and at most [`MAX_IDENT_LEN`] bytes long;
/// - every byte is an ASCII letter, an ASCII digit, `_`, or `.`.
fn is_safe_identifier(s: &str) -> bool {
    let within_length = (1..=MAX_IDENT_LEN).contains(&s.len());
    let allowed = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'.');
    within_length && s.bytes().all(allowed)
}
47
48/// Validate an identifier, returning a `SanitizeError` if invalid.
49fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
50    if is_safe_identifier(value) {
51        Ok(())
52    } else {
53        Err(SanitizeError {
54            field: field.to_string(),
55            value: value.chars().take(40).collect(),
56            reason: "identifiers must match [a-zA-Z0-9_.] and be ≤63 chars".to_string(),
57        })
58    }
59}
60
61/// Validate an `Expr` node for unsafe patterns.
62///
63/// - `Expr::Named` must be a safe identifier.
64/// - Recursive variants (Cast, Binary, etc.) are validated recursively.
65fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
66    match expr {
67        Expr::Star => Ok(()),
68        Expr::Named(name) => check_ident(field, name),
69        Expr::Aliased { name, alias } => {
70            check_ident(field, name)?;
71            check_ident(&format!("{field}.alias"), alias)
72        }
73        Expr::Aggregate {
74            col, alias, filter, ..
75        } => {
76            if col != "*" {
77                check_ident(field, col)?;
78            }
79            if let Some(a) = alias {
80                check_ident(&format!("{field}.alias"), a)?;
81            }
82            if let Some(conditions) = filter {
83                for cond in conditions {
84                    check_expr(&format!("{field}.filter"), &cond.left)?;
85                    check_value(&format!("{field}.filter"), &cond.value)?;
86                }
87            }
88            Ok(())
89        }
90        Expr::FunctionCall { name, args, alias } => {
91            check_ident(field, name)?;
92            if let Some(a) = alias {
93                check_ident(&format!("{field}.alias"), a)?;
94            }
95            for arg in args {
96                check_expr(&format!("{field}.arg"), arg)?;
97            }
98            Ok(())
99        }
100        Expr::Cast {
101            expr,
102            target_type,
103            alias,
104        } => {
105            check_expr(field, expr)?;
106            check_ident(&format!("{field}.cast_type"), target_type)?;
107            if let Some(a) = alias {
108                check_ident(&format!("{field}.alias"), a)?;
109            }
110            Ok(())
111        }
112        Expr::Binary {
113            left, right, alias, ..
114        } => {
115            check_expr(field, left)?;
116            check_expr(field, right)?;
117            if let Some(a) = alias {
118                check_ident(&format!("{field}.alias"), a)?;
119            }
120            Ok(())
121        }
122        Expr::Literal(_) => Ok(()),
123        Expr::JsonAccess {
124            column,
125            alias,
126            path_segments,
127            ..
128        } => {
129            check_ident(field, column)?;
130            for (key, _) in path_segments {
131                // Integer indices are fine; string keys must be safe identifiers
132                if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
133                    return Err(SanitizeError {
134                        field: format!("{field}.json_path"),
135                        value: key.chars().take(40).collect(),
136                        reason: "JSON path key must be a safe identifier or integer".to_string(),
137                    });
138                }
139            }
140            if let Some(a) = alias {
141                check_ident(&format!("{field}.alias"), a)?;
142            }
143            Ok(())
144        }
145        Expr::Subquery { query, alias } => {
146            validate_ast(query)?;
147            if let Some(a) = alias {
148                check_ident(&format!("{field}.alias"), a)?;
149            }
150            Ok(())
151        }
152        Expr::Exists { query, alias, .. } => {
153            validate_ast(query)?;
154            if let Some(a) = alias {
155                check_ident(&format!("{field}.alias"), a)?;
156            }
157            Ok(())
158        }
159        // For all other complex Expr variants, validate aliases where present
160        Expr::Window {
161            name,
162            func,
163            partition,
164            params,
165            order,
166            ..
167        } => {
168            if !name.is_empty() {
169                check_ident(&format!("{field}.window_alias"), name)?;
170            }
171            check_ident(&format!("{field}.window_func"), func)?;
172            for p in partition {
173                check_ident(&format!("{field}.partition"), p)?;
174            }
175            for p in params {
176                check_expr(&format!("{field}.window_param"), p)?;
177            }
178            for cage in order {
179                for cond in &cage.conditions {
180                    check_expr(&format!("{field}.window_order"), &cond.left)?;
181                    check_value(&format!("{field}.window_order"), &cond.value)?;
182                }
183            }
184            Ok(())
185        }
186        Expr::Case {
187            when_clauses,
188            else_value,
189            alias,
190        } => {
191            for (cond, val) in when_clauses {
192                check_expr(&format!("{field}.case_when"), &cond.left)?;
193                check_value(&format!("{field}.case_when"), &cond.value)?;
194                check_expr(&format!("{field}.case_then"), val)?;
195            }
196            if let Some(e) = else_value {
197                check_expr(&format!("{field}.case_else"), e)?;
198            }
199            if let Some(a) = alias {
200                check_ident(&format!("{field}.alias"), a)?;
201            }
202            Ok(())
203        }
204        Expr::SpecialFunction { args, alias, name } => {
205            check_ident(&format!("{field}.special_func"), name)?;
206            for (_, arg) in args {
207                check_expr(&format!("{field}.special_func_arg"), arg)?;
208            }
209            if let Some(a) = alias {
210                check_ident(&format!("{field}.alias"), a)?;
211            }
212            Ok(())
213        }
214        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
215            for elem in elements {
216                check_expr(&format!("{field}.element"), elem)?;
217            }
218            if let Some(a) = alias {
219                check_ident(&format!("{field}.alias"), a)?;
220            }
221            Ok(())
222        }
223        Expr::Subscript { expr, index, alias } => {
224            check_expr(&format!("{field}.subscript_expr"), expr)?;
225            check_expr(&format!("{field}.subscript_index"), index)?;
226            if let Some(a) = alias {
227                check_ident(&format!("{field}.alias"), a)?;
228            }
229            Ok(())
230        }
231        Expr::Collate {
232            expr,
233            collation,
234            alias,
235        } => {
236            check_expr(&format!("{field}.collate_expr"), expr)?;
237            check_ident(&format!("{field}.collation"), collation)?;
238            if let Some(a) = alias {
239                check_ident(&format!("{field}.alias"), a)?;
240            }
241            Ok(())
242        }
243        Expr::FieldAccess {
244            expr,
245            field: f,
246            alias,
247        } => {
248            check_expr(&format!("{field}.field_access_expr"), expr)?;
249            check_ident(&format!("{field}.field"), f)?;
250            if let Some(a) = alias {
251                check_ident(&format!("{field}.alias"), a)?;
252            }
253            Ok(())
254        }
255        Expr::Def { name, .. } => check_ident(field, name),
256        Expr::Mod { col, .. } => check_expr(field, col),
257    }
258}
259
260/// Check a `Value` for embedded subqueries.
261fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
262    match value {
263        Value::Subquery(q) => validate_ast(q),
264        Value::Array(vals) => {
265            for v in vals {
266                check_value(field, v)?;
267            }
268            Ok(())
269        }
270        Value::Expr(expr) => check_expr(field, expr),
271        _ => Ok(()),
272    }
273}
274
275/// Validate a `Qail` AST from an untrusted source.
276///
277/// Checks all identifier fields against the parser grammar (`[a-zA-Z0-9_.]`)
278/// and rejects dangerous procedural actions.
279///
280/// # Errors
281///
282/// Returns `SanitizeError` on the first violation found.
283pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
284    // ── Block dangerous actions from binary path ─────────────────────
285    match cmd.action {
286        Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
287            return Err(SanitizeError {
288                field: "action".to_string(),
289                value: format!("{:?}", cmd.action),
290                reason: "procedural/session actions are not allowed via binary AST".to_string(),
291            });
292        }
293        _ => {}
294    }
295
296    // ── Table name ───────────────────────────────────────────────────
297    if !cmd.table.is_empty() {
298        check_ident("table", &cmd.table)?;
299    }
300
301    // ── Columns ──────────────────────────────────────────────────────
302    for (i, col) in cmd.columns.iter().enumerate() {
303        check_expr(&format!("columns[{i}]"), col)?;
304    }
305
306    // ── Table Constraints ────────────────────────────────────────────
307    for (i, constraint) in cmd.table_constraints.iter().enumerate() {
308        match constraint {
309            TableConstraint::Unique(cols) | TableConstraint::PrimaryKey(cols) => {
310                for col in cols {
311                    check_ident(&format!("table_constraints[{i}].column"), col)?;
312                }
313            }
314            TableConstraint::ForeignKey {
315                name,
316                columns,
317                ref_table,
318                ref_columns,
319            } => {
320                if let Some(name) = name {
321                    check_ident(&format!("table_constraints[{i}].name"), name)?;
322                }
323                for col in columns {
324                    check_ident(&format!("table_constraints[{i}].column"), col)?;
325                }
326                check_ident(&format!("table_constraints[{i}].ref_table"), ref_table)?;
327                for col in ref_columns {
328                    check_ident(&format!("table_constraints[{i}].ref_column"), col)?;
329                }
330            }
331        }
332    }
333
334    // ── Joins ────────────────────────────────────────────────────────
335    for (i, join) in cmd.joins.iter().enumerate() {
336        // Join table may include alias: "users u"
337        // Validate each space-separated token
338        for token in join.table.split_whitespace() {
339            check_ident(&format!("joins[{i}].table"), token)?;
340        }
341        if let Some(ref conditions) = join.on {
342            for cond in conditions {
343                check_expr(&format!("joins[{i}].on"), &cond.left)?;
344                check_value(&format!("joins[{i}].on"), &cond.value)?;
345            }
346        }
347    }
348
349    // ── Cages (filters, sorts, etc.) ─────────────────────────────────
350    for cage in &cmd.cages {
351        for cond in &cage.conditions {
352            check_expr("cage.condition.left", &cond.left)?;
353            check_value("cage.condition.value", &cond.value)?;
354        }
355    }
356
357    // ── CTEs ─────────────────────────────────────────────────────────
358    for cte in &cmd.ctes {
359        check_ident("cte.name", &cte.name)?;
360        for col in &cte.columns {
361            check_ident("cte.column", col)?;
362        }
363        validate_ast(&cte.base_query)?;
364        if let Some(ref rq) = cte.recursive_query {
365            validate_ast(rq)?;
366        }
367    }
368
369    // ── DISTINCT ON ──────────────────────────────────────────────────
370    for expr in &cmd.distinct_on {
371        check_expr("distinct_on", expr)?;
372    }
373
374    // ── RETURNING ────────────────────────────────────────────────────
375    if let Some(ref cols) = cmd.returning {
376        for col in cols {
377            check_expr("returning", col)?;
378        }
379    }
380
381    // ── ON CONFLICT ──────────────────────────────────────────────────
382    if let Some(ref oc) = cmd.on_conflict {
383        for col in &oc.columns {
384            check_ident("on_conflict.column", col)?;
385        }
386        if let ConflictAction::DoUpdate { assignments } = &oc.action {
387            for (col, expr) in assignments {
388                check_ident("on_conflict.assignment.column", col)?;
389                check_expr("on_conflict.assignment.expr", expr)?;
390            }
391        }
392    }
393
394    // ── MERGE ────────────────────────────────────────────────────────
395    if let Some(ref merge) = cmd.merge {
396        if let Some(alias) = &merge.target_alias {
397            check_ident("merge.target_alias", alias)?;
398        }
399        match &merge.source {
400            MergeSource::Table { name, alias } => {
401                check_ident("merge.source.table", name)?;
402                if let Some(alias) = alias {
403                    check_ident("merge.source.alias", alias)?;
404                }
405            }
406            MergeSource::Query { query, alias } => {
407                validate_ast(query)?;
408                if let Some(alias) = alias {
409                    check_ident("merge.source.alias", alias)?;
410                }
411            }
412        }
413        for cond in &merge.on {
414            check_expr("merge.on.left", &cond.left)?;
415            check_value("merge.on.value", &cond.value)?;
416        }
417        for clause in &merge.clauses {
418            for cond in &clause.condition {
419                check_expr("merge.clause.condition.left", &cond.left)?;
420                check_value("merge.clause.condition.value", &cond.value)?;
421            }
422            match &clause.action {
423                MergeAction::Update { assignments } => {
424                    for (col, expr) in assignments {
425                        check_ident("merge.update.column", col)?;
426                        check_expr("merge.update.expr", expr)?;
427                    }
428                }
429                MergeAction::Insert { columns, values } => {
430                    for col in columns {
431                        check_ident("merge.insert.column", col)?;
432                    }
433                    for expr in values {
434                        check_expr("merge.insert.expr", expr)?;
435                    }
436                }
437                MergeAction::Delete | MergeAction::DoNothing => {}
438            }
439        }
440    }
441
442    // ── FROM / USING tables ──────────────────────────────────────────
443    for t in &cmd.from_tables {
444        check_ident("from_tables", t)?;
445    }
446    for t in &cmd.using_tables {
447        check_ident("using_tables", t)?;
448    }
449
450    // ── SET ops ──────────────────────────────────────────────────────
451    for (_, sub) in &cmd.set_ops {
452        validate_ast(sub)?;
453    }
454
455    // ── Source query (INSERT … SELECT) ───────────────────────────────
456    if let Some(ref sq) = cmd.source_query {
457        validate_ast(sq)?;
458    }
459
460    // ── HAVING ───────────────────────────────────────────────────────
461    for cond in &cmd.having {
462        check_expr("having", &cond.left)?;
463        check_value("having", &cond.value)?;
464    }
465
466    // ── Channel (LISTEN/NOTIFY) ──────────────────────────────────────
467    if let Some(ref ch) = cmd.channel {
468        check_ident("channel", ch)?;
469    }
470
471    Ok(())
472}
473
// Unit tests for `validate_ast`: accepted shapes, rejected injections and
// rejected actions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Qail;

    #[test]
    fn valid_simple_query_passes() {
        let cmd = Qail::get("users").columns(["id", "name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn sql_injection_in_table_rejected() {
        let cmd = Qail::get("users; DROP TABLE users; --");
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "table");
    }

    #[test]
    fn call_action_rejected() {
        let cmd = Qail {
            action: Action::Call,
            table: "my_proc()".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn do_action_rejected() {
        let cmd = Qail {
            action: Action::Do,
            table: "plpgsql".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn session_actions_rejected() {
        // Session manipulation must never be reachable from an untrusted AST.
        for action in [Action::SessionSet, Action::SessionReset] {
            let cmd = Qail {
                action,
                ..Default::default()
            };
            let err = validate_ast(&cmd).unwrap_err();
            assert_eq!(err.field, "action");
        }
    }

    #[test]
    fn valid_qualified_name_passes() {
        let cmd = Qail::get("public.users").columns(["users.id", "users.name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn aliased_join_table_passes() {
        use crate::ast::JoinKind;
        // The `table alias` form is a supported join target.
        let cmd = Qail::get("users").join(JoinKind::Left, "orders o", "users.id", "o.user_id");
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;
        let cmd = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("joins"));
    }

    #[test]
    fn injection_in_column_rejected() {
        let cmd = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("columns"));
    }

    #[test]
    fn on_conflict_update_assignment_expression_injection_rejected() {
        let cmd = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(
                &["id"],
                &[(
                    "name",
                    Expr::Named("EXCLUDED.name, is_admin = true".to_string()),
                )],
            );

        let err = validate_ast(&cmd).unwrap_err();

        assert_eq!(err.field, "on_conflict.assignment.expr");
        assert!(
            err.value.contains("EXCLUDED.name"),
            "unexpected rejected value: {}",
            err.value
        );
    }

    #[test]
    fn on_conflict_update_assignment_column_injection_rejected() {
        let cmd = Qail::add("users")
            .set_value("id", 1)
            .set_value("name", "Alice")
            .on_conflict_update(
                &["id"],
                &[("name, is_admin", Expr::Named("EXCLUDED.name".to_string()))],
            );

        let err = validate_ast(&cmd).unwrap_err();

        assert_eq!(err.field, "on_conflict.assignment.column");
    }

    #[test]
    fn aggregate_filter_value_expression_injection_rejected() {
        use crate::ast::{AggregateFunc, Condition, Operator, Value};

        let mut cmd = Qail::get("events");
        cmd.columns.push(Expr::Aggregate {
            col: "id".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: Some(vec![Condition {
                left: Expr::Named("direction".to_string()),
                op: Operator::Eq,
                value: Value::Expr(Box::new(Expr::Named("bad;DROP".to_string()))),
                is_array_unnest: false,
            }]),
            alias: None,
        });

        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "columns[0].filter");
    }

    #[test]
    fn count_star_aggregate_passes_sanitizer() {
        use crate::ast::AggregateFunc;

        let mut cmd = Qail::get("events");
        cmd.columns.push(Expr::Aggregate {
            col: "*".to_string(),
            func: AggregateFunc::Count,
            distinct: false,
            filter: None,
            alias: Some("total".to_string()),
        });

        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn case_when_complex_condition_expression_passes_sanitizer() {
        use crate::ast::{Condition, Operator, Value};

        let mut cmd = Qail::get("users");
        cmd.columns.push(Expr::Case {
            when_clauses: vec![(
                Condition {
                    left: Expr::Cast {
                        expr: Box::new(Expr::JsonAccess {
                            column: "profile".to_string(),
                            path_segments: vec![("active".to_string(), true)],
                            alias: None,
                        }),
                        target_type: "integer".to_string(),
                        alias: None,
                    },
                    op: Operator::Gt,
                    value: Value::Int(0),
                    is_array_unnest: false,
                },
                Box::new(Expr::Literal(Value::String("active".to_string()))),
            )],
            else_value: Some(Box::new(Expr::Literal(Value::String(
                "inactive".to_string(),
            )))),
            alias: Some("status_label".to_string()),
        });

        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn empty_table_name_passes() {
        // Some actions like TxnStart have empty table
        let cmd = Qail {
            action: Action::TxnStart,
            table: String::new(),
            ..Default::default()
        };
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn oversized_identifier_rejected() {
        let long_name = "a".repeat(64);
        let cmd = Qail::get(&long_name);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.reason.contains("63"));
    }
}