// qail_core/sanitize.rs
//! AST structural sanitization for untrusted input.
//!
//! The text parser enforces identifier constraints (alphanumeric + `_` + `.`),
//! but the binary/postcard path deserializes directly into `Qail`, bypassing the
//! parser entirely — so an attacker can craft identifiers that inject SQL fragments.
//!
//! Call [`validate_ast`] on any `Qail` obtained from an untrusted source
//! (binary endpoint, external API, etc.) before execution.
9
10use crate::ast::{Action, Expr, Qail, Value};
11use std::fmt;
12
/// Error returned when AST structural validation fails.
///
/// Describes the first violation found by [`validate_ast`]. Derives
/// `PartialEq`/`Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SanitizeError {
    /// Path of the offending AST field, e.g. `"table"` or `"joins[0].table"`.
    pub field: String,
    /// The offending value (identifier values are truncated to 40 characters),
    /// or a placeholder such as `"(raw SQL)"`.
    pub value: String,
    /// Human-readable description of the violated rule.
    pub reason: String,
}
20
21impl fmt::Display for SanitizeError {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        write!(
24            f,
25            "AST validation failed: {} '{}' — {}",
26            self.field, self.value, self.reason
27        )
28    }
29}
30
31impl std::error::Error for SanitizeError {}
32
33/// Maximum identifier length (PostgreSQL NAMEDATALEN - 1).
34const MAX_IDENT_LEN: usize = 63;
35
36/// Validate that an identifier matches the parser grammar: `[a-zA-Z0-9_.]`.
37///
38/// Also rejects empty identifiers and those exceeding PostgreSQL's 63-char limit.
39fn is_safe_identifier(s: &str) -> bool {
40    !s.is_empty()
41        && s.len() <= MAX_IDENT_LEN
42        && s.bytes()
43            .all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'.')
44}
45
46/// Validate an identifier, returning a `SanitizeError` if invalid.
47fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
48    if is_safe_identifier(value) {
49        Ok(())
50    } else {
51        Err(SanitizeError {
52            field: field.to_string(),
53            value: value.chars().take(40).collect(),
54            reason: "identifiers must match [a-zA-Z0-9_.] and be ≤63 chars".to_string(),
55        })
56    }
57}
58
59/// Validate an `Expr` node for unsafe patterns.
60///
61/// - `Expr::Named` must be a safe identifier.
62/// - `Expr::Raw` is rejected outright (binary path must not carry raw SQL).
63/// - Recursive variants (Cast, Binary, etc.) are validated recursively.
64fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
65    match expr {
66        Expr::Star => Ok(()),
67        Expr::Named(name) => check_ident(field, name),
68        Expr::Aliased { name, alias } => {
69            check_ident(field, name)?;
70            check_ident(&format!("{field}.alias"), alias)
71        }
72        Expr::Aggregate {
73            col, alias, filter, ..
74        } => {
75            check_ident(field, col)?;
76            if let Some(a) = alias {
77                check_ident(&format!("{field}.alias"), a)?;
78            }
79            if let Some(conditions) = filter {
80                for cond in conditions {
81                    check_expr(&format!("{field}.filter"), &cond.left)?;
82                }
83            }
84            Ok(())
85        }
86        Expr::FunctionCall { name, args, alias } => {
87            check_ident(field, name)?;
88            if let Some(a) = alias {
89                check_ident(&format!("{field}.alias"), a)?;
90            }
91            for arg in args {
92                check_expr(&format!("{field}.arg"), arg)?;
93            }
94            Ok(())
95        }
96        Expr::Cast {
97            expr,
98            target_type,
99            alias,
100        } => {
101            check_expr(field, expr)?;
102            check_ident(&format!("{field}.cast_type"), target_type)?;
103            if let Some(a) = alias {
104                check_ident(&format!("{field}.alias"), a)?;
105            }
106            Ok(())
107        }
108        Expr::Binary {
109            left, right, alias, ..
110        } => {
111            check_expr(field, left)?;
112            check_expr(field, right)?;
113            if let Some(a) = alias {
114                check_ident(&format!("{field}.alias"), a)?;
115            }
116            Ok(())
117        }
118        Expr::Raw(_) => Err(SanitizeError {
119            field: field.to_string(),
120            value: "(raw SQL)".to_string(),
121            reason: "Expr::Raw is not allowed in binary AST".to_string(),
122        }),
123        Expr::Literal(_) => Ok(()),
124        Expr::JsonAccess {
125            column,
126            alias,
127            path_segments,
128            ..
129        } => {
130            check_ident(field, column)?;
131            for (key, _) in path_segments {
132                // Integer indices are fine; string keys must be safe identifiers
133                if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
134                    return Err(SanitizeError {
135                        field: format!("{field}.json_path"),
136                        value: key.chars().take(40).collect(),
137                        reason: "JSON path key must be a safe identifier or integer".to_string(),
138                    });
139                }
140            }
141            if let Some(a) = alias {
142                check_ident(&format!("{field}.alias"), a)?;
143            }
144            Ok(())
145        }
146        Expr::Subquery { query, alias } => {
147            validate_ast(query)?;
148            if let Some(a) = alias {
149                check_ident(&format!("{field}.alias"), a)?;
150            }
151            Ok(())
152        }
153        Expr::Exists { query, alias, .. } => {
154            validate_ast(query)?;
155            if let Some(a) = alias {
156                check_ident(&format!("{field}.alias"), a)?;
157            }
158            Ok(())
159        }
160        // For all other complex Expr variants, validate aliases where present
161        Expr::Window {
162            name,
163            func,
164            partition,
165            params,
166            order,
167            ..
168        } => {
169            if !name.is_empty() {
170                check_ident(&format!("{field}.window_alias"), name)?;
171            }
172            check_ident(&format!("{field}.window_func"), func)?;
173            for p in partition {
174                check_ident(&format!("{field}.partition"), p)?;
175            }
176            for p in params {
177                check_expr(&format!("{field}.window_param"), p)?;
178            }
179            for cage in order {
180                for cond in &cage.conditions {
181                    check_expr(&format!("{field}.window_order"), &cond.left)?;
182                    check_value(&format!("{field}.window_order"), &cond.value)?;
183                }
184            }
185            Ok(())
186        }
187        Expr::Case {
188            when_clauses,
189            else_value,
190            alias,
191        } => {
192            for (cond, val) in when_clauses {
193                check_expr(
194                    &format!("{field}.case_when"),
195                    &Expr::Named(cond.left.to_string()),
196                )?;
197                check_expr(&format!("{field}.case_then"), val)?;
198            }
199            if let Some(e) = else_value {
200                check_expr(&format!("{field}.case_else"), e)?;
201            }
202            if let Some(a) = alias {
203                check_ident(&format!("{field}.alias"), a)?;
204            }
205            Ok(())
206        }
207        Expr::SpecialFunction { args, alias, name } => {
208            check_ident(&format!("{field}.special_func"), name)?;
209            for (_, arg) in args {
210                check_expr(&format!("{field}.special_func_arg"), arg)?;
211            }
212            if let Some(a) = alias {
213                check_ident(&format!("{field}.alias"), a)?;
214            }
215            Ok(())
216        }
217        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
218            for elem in elements {
219                check_expr(&format!("{field}.element"), elem)?;
220            }
221            if let Some(a) = alias {
222                check_ident(&format!("{field}.alias"), a)?;
223            }
224            Ok(())
225        }
226        Expr::Subscript { expr, index, alias } => {
227            check_expr(&format!("{field}.subscript_expr"), expr)?;
228            check_expr(&format!("{field}.subscript_index"), index)?;
229            if let Some(a) = alias {
230                check_ident(&format!("{field}.alias"), a)?;
231            }
232            Ok(())
233        }
234        Expr::Collate {
235            expr,
236            collation,
237            alias,
238        } => {
239            check_expr(&format!("{field}.collate_expr"), expr)?;
240            check_ident(&format!("{field}.collation"), collation)?;
241            if let Some(a) = alias {
242                check_ident(&format!("{field}.alias"), a)?;
243            }
244            Ok(())
245        }
246        Expr::FieldAccess {
247            expr,
248            field: f,
249            alias,
250        } => {
251            check_expr(&format!("{field}.field_access_expr"), expr)?;
252            check_ident(&format!("{field}.field"), f)?;
253            if let Some(a) = alias {
254                check_ident(&format!("{field}.alias"), a)?;
255            }
256            Ok(())
257        }
258        Expr::Def { name, .. } => check_ident(field, name),
259        Expr::Mod { col, .. } => check_expr(field, col),
260    }
261}
262
263/// Check a `Value` for embedded subqueries.
264fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
265    match value {
266        Value::Subquery(q) => validate_ast(q),
267        Value::Array(vals) => {
268            for v in vals {
269                check_value(field, v)?;
270            }
271            Ok(())
272        }
273        Value::Expr(expr) => check_expr(field, expr),
274        _ => Ok(()),
275    }
276}
277
278/// Validate a `Qail` AST from an untrusted source.
279///
280/// Checks all identifier fields against the parser grammar (`[a-zA-Z0-9_.]`)
281/// and rejects dangerous constructs like `Expr::Raw` and procedural actions.
282///
283/// # Errors
284///
285/// Returns `SanitizeError` on the first violation found.
286pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
287    // ── Block dangerous actions from binary path ─────────────────────
288    match cmd.action {
289        Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
290            return Err(SanitizeError {
291                field: "action".to_string(),
292                value: format!("{:?}", cmd.action),
293                reason: "procedural/session actions are not allowed via binary AST".to_string(),
294            });
295        }
296        _ => {}
297    }
298
299    // ── Raw SQL pass-through ─────────────────────────────────────────
300    if cmd.is_raw_sql() {
301        return Err(SanitizeError {
302            field: "table".to_string(),
303            value: "(raw SQL)".to_string(),
304            reason: "raw SQL pass-through is not allowed via binary AST".to_string(),
305        });
306    }
307
308    // ── Table name ───────────────────────────────────────────────────
309    if !cmd.table.is_empty() {
310        check_ident("table", &cmd.table)?;
311    }
312
313    // ── Columns ──────────────────────────────────────────────────────
314    for (i, col) in cmd.columns.iter().enumerate() {
315        check_expr(&format!("columns[{i}]"), col)?;
316    }
317
318    // ── Joins ────────────────────────────────────────────────────────
319    for (i, join) in cmd.joins.iter().enumerate() {
320        // Join table may include alias: "users u"
321        // Validate each space-separated token
322        for token in join.table.split_whitespace() {
323            check_ident(&format!("joins[{i}].table"), token)?;
324        }
325        if let Some(ref conditions) = join.on {
326            for cond in conditions {
327                check_expr(&format!("joins[{i}].on"), &cond.left)?;
328                check_value(&format!("joins[{i}].on"), &cond.value)?;
329            }
330        }
331    }
332
333    // ── Cages (filters, sorts, etc.) ─────────────────────────────────
334    for cage in &cmd.cages {
335        for cond in &cage.conditions {
336            check_expr("cage.condition.left", &cond.left)?;
337            check_value("cage.condition.value", &cond.value)?;
338        }
339    }
340
341    // ── CTEs ─────────────────────────────────────────────────────────
342    for cte in &cmd.ctes {
343        check_ident("cte.name", &cte.name)?;
344        for col in &cte.columns {
345            check_ident("cte.column", col)?;
346        }
347        validate_ast(&cte.base_query)?;
348        if let Some(ref rq) = cte.recursive_query {
349            validate_ast(rq)?;
350        }
351    }
352
353    // ── DISTINCT ON ──────────────────────────────────────────────────
354    for expr in &cmd.distinct_on {
355        check_expr("distinct_on", expr)?;
356    }
357
358    // ── RETURNING ────────────────────────────────────────────────────
359    if let Some(ref cols) = cmd.returning {
360        for col in cols {
361            check_expr("returning", col)?;
362        }
363    }
364
365    // ── ON CONFLICT ──────────────────────────────────────────────────
366    if let Some(ref oc) = cmd.on_conflict {
367        for col in &oc.columns {
368            check_ident("on_conflict.column", col)?;
369        }
370    }
371
372    // ── FROM / USING tables ──────────────────────────────────────────
373    for t in &cmd.from_tables {
374        check_ident("from_tables", t)?;
375    }
376    for t in &cmd.using_tables {
377        check_ident("using_tables", t)?;
378    }
379
380    // ── SET ops ──────────────────────────────────────────────────────
381    for (_, sub) in &cmd.set_ops {
382        validate_ast(sub)?;
383    }
384
385    // ── Source query (INSERT … SELECT) ───────────────────────────────
386    if let Some(ref sq) = cmd.source_query {
387        validate_ast(sq)?;
388    }
389
390    // ── HAVING ───────────────────────────────────────────────────────
391    for cond in &cmd.having {
392        check_expr("having", &cond.left)?;
393        check_value("having", &cond.value)?;
394    }
395
396    // ── Channel (LISTEN/NOTIFY) ──────────────────────────────────────
397    if let Some(ref ch) = cmd.channel {
398        check_ident("channel", ch)?;
399    }
400
401    Ok(())
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Qail;

    /// Run `validate_ast` on `cmd` and return the error, failing if it passed.
    fn rejected(cmd: &Qail) -> SanitizeError {
        match validate_ast(cmd) {
            Ok(()) => panic!("expected validation to fail"),
            Err(e) => e,
        }
    }

    /// A bare `Qail` with the given action and table; all else defaulted.
    fn with_action(action: Action, table: &str) -> Qail {
        Qail {
            action,
            table: table.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn valid_simple_query_passes() {
        let cmd = Qail::get("users").columns(["id", "name"]);
        validate_ast(&cmd).expect("simple query should validate");
    }

    #[test]
    fn sql_injection_in_table_rejected() {
        let err = rejected(&Qail::get("users; DROP TABLE users; --"));
        assert_eq!(err.field, "table");
    }

    #[test]
    fn raw_sql_rejected() {
        let err = rejected(&Qail::raw_sql("SELECT 1"));
        assert_eq!(err.field, "table");
    }

    #[test]
    fn raw_expr_rejected() {
        let cmd = Qail::get("users").columns_expr(vec![Expr::Raw("NOW()".to_string())]);
        assert!(rejected(&cmd).reason.contains("Raw"));
    }

    #[test]
    fn call_action_rejected() {
        let err = rejected(&with_action(Action::Call, "my_proc()"));
        assert_eq!(err.field, "action");
    }

    #[test]
    fn do_action_rejected() {
        let err = rejected(&with_action(Action::Do, "plpgsql"));
        assert_eq!(err.field, "action");
    }

    #[test]
    fn valid_qualified_name_passes() {
        let cmd = Qail::get("public.users").columns(["users.id", "users.name"]);
        validate_ast(&cmd).expect("schema-qualified names should validate");
    }

    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;
        let cmd = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        assert!(rejected(&cmd).field.contains("joins"));
    }

    #[test]
    fn injection_in_column_rejected() {
        let cmd = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        assert!(rejected(&cmd).field.contains("columns"));
    }

    #[test]
    fn empty_table_name_passes() {
        // Some actions like TxnStart have empty table
        let cmd = with_action(Action::TxnStart, "");
        validate_ast(&cmd).expect("empty table should be accepted");
    }

    #[test]
    fn oversized_identifier_rejected() {
        let long_name = "a".repeat(64);
        let err = rejected(&Qail::get(&long_name));
        assert!(err.reason.contains("63"));
    }
}