Skip to main content

qail_core/
sanitize.rs

1//! AST structural sanitization for untrusted input.
2//!
3//! The text parser enforces identifier constraints (alphanumeric + `_` + `.`),
4//! but the binary/postcard path deserializes directly into `Qail` — an attacker
5//! can craft identifiers that inject SQL fragments.
6//!
7//! Call [`validate_ast`] on any `Qail` obtained from an untrusted source
8//! (binary endpoint, external API, etc.) before execution.
9
10use crate::ast::{Action, Expr, Qail, Value};
11use std::fmt;
12
/// Error produced when structural validation of an untrusted AST fails.
///
/// Carries enough context to locate the offending node without echoing the
/// full (potentially large or hostile) input back to the caller.
#[derive(Debug, Clone)]
pub struct SanitizeError {
    /// Dotted path of the AST slot that failed, e.g. `columns[0].alias`.
    pub field: String,
    /// The rejected value (callers in this module truncate it before storing).
    pub value: String,
    /// Human-readable explanation of the rule that was violated.
    pub reason: String,
}

impl fmt::Display for SanitizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            field,
            value,
            reason,
        } = self;
        write!(f, "AST validation failed: {field} '{value}' — {reason}")
    }
}

impl std::error::Error for SanitizeError {}
32
33/// Maximum identifier length (PostgreSQL NAMEDATALEN - 1).
34const MAX_IDENT_LEN: usize = 63;
35
36/// Validate that an identifier matches the parser grammar: `[a-zA-Z0-9_.]`.
37///
38/// Also rejects empty identifiers and those exceeding PostgreSQL's 63-char limit.
39fn is_safe_identifier(s: &str) -> bool {
40    !s.is_empty()
41        && s.len() <= MAX_IDENT_LEN
42        && s.bytes()
43            .all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'.')
44}
45
46/// Validate an identifier, returning a `SanitizeError` if invalid.
47fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
48    if is_safe_identifier(value) {
49        Ok(())
50    } else {
51        Err(SanitizeError {
52            field: field.to_string(),
53            value: value.chars().take(40).collect(),
54            reason: "identifiers must match [a-zA-Z0-9_.] and be ≤63 chars".to_string(),
55        })
56    }
57}
58
59/// Validate an `Expr` node for unsafe patterns.
60///
61/// - `Expr::Named` must be a safe identifier.
62/// - Recursive variants (Cast, Binary, etc.) are validated recursively.
63fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
64    match expr {
65        Expr::Star => Ok(()),
66        Expr::Named(name) => check_ident(field, name),
67        Expr::Aliased { name, alias } => {
68            check_ident(field, name)?;
69            check_ident(&format!("{field}.alias"), alias)
70        }
71        Expr::Aggregate {
72            col, alias, filter, ..
73        } => {
74            check_ident(field, col)?;
75            if let Some(a) = alias {
76                check_ident(&format!("{field}.alias"), a)?;
77            }
78            if let Some(conditions) = filter {
79                for cond in conditions {
80                    check_expr(&format!("{field}.filter"), &cond.left)?;
81                }
82            }
83            Ok(())
84        }
85        Expr::FunctionCall { name, args, alias } => {
86            check_ident(field, name)?;
87            if let Some(a) = alias {
88                check_ident(&format!("{field}.alias"), a)?;
89            }
90            for arg in args {
91                check_expr(&format!("{field}.arg"), arg)?;
92            }
93            Ok(())
94        }
95        Expr::Cast {
96            expr,
97            target_type,
98            alias,
99        } => {
100            check_expr(field, expr)?;
101            check_ident(&format!("{field}.cast_type"), target_type)?;
102            if let Some(a) = alias {
103                check_ident(&format!("{field}.alias"), a)?;
104            }
105            Ok(())
106        }
107        Expr::Binary {
108            left, right, alias, ..
109        } => {
110            check_expr(field, left)?;
111            check_expr(field, right)?;
112            if let Some(a) = alias {
113                check_ident(&format!("{field}.alias"), a)?;
114            }
115            Ok(())
116        }
117        Expr::Literal(_) => Ok(()),
118        Expr::JsonAccess {
119            column,
120            alias,
121            path_segments,
122            ..
123        } => {
124            check_ident(field, column)?;
125            for (key, _) in path_segments {
126                // Integer indices are fine; string keys must be safe identifiers
127                if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
128                    return Err(SanitizeError {
129                        field: format!("{field}.json_path"),
130                        value: key.chars().take(40).collect(),
131                        reason: "JSON path key must be a safe identifier or integer".to_string(),
132                    });
133                }
134            }
135            if let Some(a) = alias {
136                check_ident(&format!("{field}.alias"), a)?;
137            }
138            Ok(())
139        }
140        Expr::Subquery { query, alias } => {
141            validate_ast(query)?;
142            if let Some(a) = alias {
143                check_ident(&format!("{field}.alias"), a)?;
144            }
145            Ok(())
146        }
147        Expr::Exists { query, alias, .. } => {
148            validate_ast(query)?;
149            if let Some(a) = alias {
150                check_ident(&format!("{field}.alias"), a)?;
151            }
152            Ok(())
153        }
154        // For all other complex Expr variants, validate aliases where present
155        Expr::Window {
156            name,
157            func,
158            partition,
159            params,
160            order,
161            ..
162        } => {
163            if !name.is_empty() {
164                check_ident(&format!("{field}.window_alias"), name)?;
165            }
166            check_ident(&format!("{field}.window_func"), func)?;
167            for p in partition {
168                check_ident(&format!("{field}.partition"), p)?;
169            }
170            for p in params {
171                check_expr(&format!("{field}.window_param"), p)?;
172            }
173            for cage in order {
174                for cond in &cage.conditions {
175                    check_expr(&format!("{field}.window_order"), &cond.left)?;
176                    check_value(&format!("{field}.window_order"), &cond.value)?;
177                }
178            }
179            Ok(())
180        }
181        Expr::Case {
182            when_clauses,
183            else_value,
184            alias,
185        } => {
186            for (cond, val) in when_clauses {
187                check_expr(
188                    &format!("{field}.case_when"),
189                    &Expr::Named(cond.left.to_string()),
190                )?;
191                check_expr(&format!("{field}.case_then"), val)?;
192            }
193            if let Some(e) = else_value {
194                check_expr(&format!("{field}.case_else"), e)?;
195            }
196            if let Some(a) = alias {
197                check_ident(&format!("{field}.alias"), a)?;
198            }
199            Ok(())
200        }
201        Expr::SpecialFunction { args, alias, name } => {
202            check_ident(&format!("{field}.special_func"), name)?;
203            for (_, arg) in args {
204                check_expr(&format!("{field}.special_func_arg"), arg)?;
205            }
206            if let Some(a) = alias {
207                check_ident(&format!("{field}.alias"), a)?;
208            }
209            Ok(())
210        }
211        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
212            for elem in elements {
213                check_expr(&format!("{field}.element"), elem)?;
214            }
215            if let Some(a) = alias {
216                check_ident(&format!("{field}.alias"), a)?;
217            }
218            Ok(())
219        }
220        Expr::Subscript { expr, index, alias } => {
221            check_expr(&format!("{field}.subscript_expr"), expr)?;
222            check_expr(&format!("{field}.subscript_index"), index)?;
223            if let Some(a) = alias {
224                check_ident(&format!("{field}.alias"), a)?;
225            }
226            Ok(())
227        }
228        Expr::Collate {
229            expr,
230            collation,
231            alias,
232        } => {
233            check_expr(&format!("{field}.collate_expr"), expr)?;
234            check_ident(&format!("{field}.collation"), collation)?;
235            if let Some(a) = alias {
236                check_ident(&format!("{field}.alias"), a)?;
237            }
238            Ok(())
239        }
240        Expr::FieldAccess {
241            expr,
242            field: f,
243            alias,
244        } => {
245            check_expr(&format!("{field}.field_access_expr"), expr)?;
246            check_ident(&format!("{field}.field"), f)?;
247            if let Some(a) = alias {
248                check_ident(&format!("{field}.alias"), a)?;
249            }
250            Ok(())
251        }
252        Expr::Def { name, .. } => check_ident(field, name),
253        Expr::Mod { col, .. } => check_expr(field, col),
254    }
255}
256
257/// Check a `Value` for embedded subqueries.
258fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
259    match value {
260        Value::Subquery(q) => validate_ast(q),
261        Value::Array(vals) => {
262            for v in vals {
263                check_value(field, v)?;
264            }
265            Ok(())
266        }
267        Value::Expr(expr) => check_expr(field, expr),
268        _ => Ok(()),
269    }
270}
271
272/// Validate a `Qail` AST from an untrusted source.
273///
274/// Checks all identifier fields against the parser grammar (`[a-zA-Z0-9_.]`)
275/// and rejects dangerous procedural actions.
276///
277/// # Errors
278///
279/// Returns `SanitizeError` on the first violation found.
280pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
281    // ── Block dangerous actions from binary path ─────────────────────
282    match cmd.action {
283        Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
284            return Err(SanitizeError {
285                field: "action".to_string(),
286                value: format!("{:?}", cmd.action),
287                reason: "procedural/session actions are not allowed via binary AST".to_string(),
288            });
289        }
290        _ => {}
291    }
292
293    // ── Table name ───────────────────────────────────────────────────
294    if !cmd.table.is_empty() {
295        check_ident("table", &cmd.table)?;
296    }
297
298    // ── Columns ──────────────────────────────────────────────────────
299    for (i, col) in cmd.columns.iter().enumerate() {
300        check_expr(&format!("columns[{i}]"), col)?;
301    }
302
303    // ── Joins ────────────────────────────────────────────────────────
304    for (i, join) in cmd.joins.iter().enumerate() {
305        // Join table may include alias: "users u"
306        // Validate each space-separated token
307        for token in join.table.split_whitespace() {
308            check_ident(&format!("joins[{i}].table"), token)?;
309        }
310        if let Some(ref conditions) = join.on {
311            for cond in conditions {
312                check_expr(&format!("joins[{i}].on"), &cond.left)?;
313                check_value(&format!("joins[{i}].on"), &cond.value)?;
314            }
315        }
316    }
317
318    // ── Cages (filters, sorts, etc.) ─────────────────────────────────
319    for cage in &cmd.cages {
320        for cond in &cage.conditions {
321            check_expr("cage.condition.left", &cond.left)?;
322            check_value("cage.condition.value", &cond.value)?;
323        }
324    }
325
326    // ── CTEs ─────────────────────────────────────────────────────────
327    for cte in &cmd.ctes {
328        check_ident("cte.name", &cte.name)?;
329        for col in &cte.columns {
330            check_ident("cte.column", col)?;
331        }
332        validate_ast(&cte.base_query)?;
333        if let Some(ref rq) = cte.recursive_query {
334            validate_ast(rq)?;
335        }
336    }
337
338    // ── DISTINCT ON ──────────────────────────────────────────────────
339    for expr in &cmd.distinct_on {
340        check_expr("distinct_on", expr)?;
341    }
342
343    // ── RETURNING ────────────────────────────────────────────────────
344    if let Some(ref cols) = cmd.returning {
345        for col in cols {
346            check_expr("returning", col)?;
347        }
348    }
349
350    // ── ON CONFLICT ──────────────────────────────────────────────────
351    if let Some(ref oc) = cmd.on_conflict {
352        for col in &oc.columns {
353            check_ident("on_conflict.column", col)?;
354        }
355    }
356
357    // ── FROM / USING tables ──────────────────────────────────────────
358    for t in &cmd.from_tables {
359        check_ident("from_tables", t)?;
360    }
361    for t in &cmd.using_tables {
362        check_ident("using_tables", t)?;
363    }
364
365    // ── SET ops ──────────────────────────────────────────────────────
366    for (_, sub) in &cmd.set_ops {
367        validate_ast(sub)?;
368    }
369
370    // ── Source query (INSERT … SELECT) ───────────────────────────────
371    if let Some(ref sq) = cmd.source_query {
372        validate_ast(sq)?;
373    }
374
375    // ── HAVING ───────────────────────────────────────────────────────
376    for cond in &cmd.having {
377        check_expr("having", &cond.left)?;
378        check_value("having", &cond.value)?;
379    }
380
381    // ── Channel (LISTEN/NOTIFY) ──────────────────────────────────────
382    if let Some(ref ch) = cmd.channel {
383        check_ident("channel", ch)?;
384    }
385
386    Ok(())
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Qail;

    /// Validate `query` and return the error it produced (panics if it passed).
    fn rejection(query: &Qail) -> SanitizeError {
        validate_ast(query).unwrap_err()
    }

    #[test]
    fn valid_simple_query_passes() {
        let query = Qail::get("users").columns(["id", "name"]);
        assert!(validate_ast(&query).is_ok());
    }

    #[test]
    fn sql_injection_in_table_rejected() {
        let query = Qail::get("users; DROP TABLE users; --");
        assert_eq!(rejection(&query).field, "table");
    }

    #[test]
    fn call_action_rejected() {
        let query = Qail {
            action: Action::Call,
            table: "my_proc()".to_string(),
            ..Default::default()
        };
        assert_eq!(rejection(&query).field, "action");
    }

    #[test]
    fn do_action_rejected() {
        let query = Qail {
            action: Action::Do,
            table: "plpgsql".to_string(),
            ..Default::default()
        };
        assert_eq!(rejection(&query).field, "action");
    }

    #[test]
    fn valid_qualified_name_passes() {
        let query = Qail::get("public.users").columns(["users.id", "users.name"]);
        assert!(validate_ast(&query).is_ok());
    }

    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;
        let query = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        assert!(rejection(&query).field.contains("joins"));
    }

    #[test]
    fn injection_in_column_rejected() {
        let query = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        assert!(rejection(&query).field.contains("columns"));
    }

    #[test]
    fn empty_table_name_passes() {
        // Transaction-control actions carry no table name at all.
        let query = Qail {
            action: Action::TxnStart,
            table: String::new(),
            ..Default::default()
        };
        assert!(validate_ast(&query).is_ok());
    }

    #[test]
    fn oversized_identifier_rejected() {
        let too_long = "a".repeat(64);
        let query = Qail::get(&too_long);
        assert!(rejection(&query).reason.contains("63"));
    }
}