1use crate::ast::{Action, Expr, Qail, Value};
11use std::fmt;
12
/// Error returned when an AST fails sanitization.
///
/// Carries enough context to tell which part of the AST was rejected and why.
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SanitizeError {
    /// Logical path of the offending AST element (e.g. `table`, `columns[0]`).
    pub field: String,
    /// The rejected value as stored by the code that built the error.
    pub value: String,
    /// Human-readable explanation of why the value was rejected.
    pub reason: String,
}
20
21impl fmt::Display for SanitizeError {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 write!(
24 f,
25 "AST validation failed: {} '{}' — {}",
26 self.field, self.value, self.reason
27 )
28 }
29}
30
31impl std::error::Error for SanitizeError {}
32
/// Maximum accepted identifier length, in bytes.
const MAX_IDENT_LEN: usize = 63;

/// Returns `true` when `s` is non-empty, at most [`MAX_IDENT_LEN`] bytes long,
/// and made up solely of ASCII letters, ASCII digits, `_` or `.`.
fn is_safe_identifier(s: &str) -> bool {
    // Reject on length first so the byte scan only runs on plausible input.
    if s.is_empty() || s.len() > MAX_IDENT_LEN {
        return false;
    }

    s.bytes().all(|byte| {
        matches!(
            byte,
            b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'.'
        )
    })
}
45
46fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
48 if is_safe_identifier(value) {
49 Ok(())
50 } else {
51 Err(SanitizeError {
52 field: field.to_string(),
53 value: value.chars().take(40).collect(),
54 reason: "identifiers must match [a-zA-Z0-9_.] and be ≤63 chars".to_string(),
55 })
56 }
57}
58
59fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
64 match expr {
65 Expr::Star => Ok(()),
66 Expr::Named(name) => check_ident(field, name),
67 Expr::Aliased { name, alias } => {
68 check_ident(field, name)?;
69 check_ident(&format!("{field}.alias"), alias)
70 }
71 Expr::Aggregate {
72 col, alias, filter, ..
73 } => {
74 check_ident(field, col)?;
75 if let Some(a) = alias {
76 check_ident(&format!("{field}.alias"), a)?;
77 }
78 if let Some(conditions) = filter {
79 for cond in conditions {
80 check_expr(&format!("{field}.filter"), &cond.left)?;
81 }
82 }
83 Ok(())
84 }
85 Expr::FunctionCall { name, args, alias } => {
86 check_ident(field, name)?;
87 if let Some(a) = alias {
88 check_ident(&format!("{field}.alias"), a)?;
89 }
90 for arg in args {
91 check_expr(&format!("{field}.arg"), arg)?;
92 }
93 Ok(())
94 }
95 Expr::Cast {
96 expr,
97 target_type,
98 alias,
99 } => {
100 check_expr(field, expr)?;
101 check_ident(&format!("{field}.cast_type"), target_type)?;
102 if let Some(a) = alias {
103 check_ident(&format!("{field}.alias"), a)?;
104 }
105 Ok(())
106 }
107 Expr::Binary {
108 left, right, alias, ..
109 } => {
110 check_expr(field, left)?;
111 check_expr(field, right)?;
112 if let Some(a) = alias {
113 check_ident(&format!("{field}.alias"), a)?;
114 }
115 Ok(())
116 }
117 Expr::Literal(_) => Ok(()),
118 Expr::JsonAccess {
119 column,
120 alias,
121 path_segments,
122 ..
123 } => {
124 check_ident(field, column)?;
125 for (key, _) in path_segments {
126 if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
128 return Err(SanitizeError {
129 field: format!("{field}.json_path"),
130 value: key.chars().take(40).collect(),
131 reason: "JSON path key must be a safe identifier or integer".to_string(),
132 });
133 }
134 }
135 if let Some(a) = alias {
136 check_ident(&format!("{field}.alias"), a)?;
137 }
138 Ok(())
139 }
140 Expr::Subquery { query, alias } => {
141 validate_ast(query)?;
142 if let Some(a) = alias {
143 check_ident(&format!("{field}.alias"), a)?;
144 }
145 Ok(())
146 }
147 Expr::Exists { query, alias, .. } => {
148 validate_ast(query)?;
149 if let Some(a) = alias {
150 check_ident(&format!("{field}.alias"), a)?;
151 }
152 Ok(())
153 }
154 Expr::Window {
156 name,
157 func,
158 partition,
159 params,
160 order,
161 ..
162 } => {
163 if !name.is_empty() {
164 check_ident(&format!("{field}.window_alias"), name)?;
165 }
166 check_ident(&format!("{field}.window_func"), func)?;
167 for p in partition {
168 check_ident(&format!("{field}.partition"), p)?;
169 }
170 for p in params {
171 check_expr(&format!("{field}.window_param"), p)?;
172 }
173 for cage in order {
174 for cond in &cage.conditions {
175 check_expr(&format!("{field}.window_order"), &cond.left)?;
176 check_value(&format!("{field}.window_order"), &cond.value)?;
177 }
178 }
179 Ok(())
180 }
181 Expr::Case {
182 when_clauses,
183 else_value,
184 alias,
185 } => {
186 for (cond, val) in when_clauses {
187 check_expr(
188 &format!("{field}.case_when"),
189 &Expr::Named(cond.left.to_string()),
190 )?;
191 check_expr(&format!("{field}.case_then"), val)?;
192 }
193 if let Some(e) = else_value {
194 check_expr(&format!("{field}.case_else"), e)?;
195 }
196 if let Some(a) = alias {
197 check_ident(&format!("{field}.alias"), a)?;
198 }
199 Ok(())
200 }
201 Expr::SpecialFunction { args, alias, name } => {
202 check_ident(&format!("{field}.special_func"), name)?;
203 for (_, arg) in args {
204 check_expr(&format!("{field}.special_func_arg"), arg)?;
205 }
206 if let Some(a) = alias {
207 check_ident(&format!("{field}.alias"), a)?;
208 }
209 Ok(())
210 }
211 Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
212 for elem in elements {
213 check_expr(&format!("{field}.element"), elem)?;
214 }
215 if let Some(a) = alias {
216 check_ident(&format!("{field}.alias"), a)?;
217 }
218 Ok(())
219 }
220 Expr::Subscript { expr, index, alias } => {
221 check_expr(&format!("{field}.subscript_expr"), expr)?;
222 check_expr(&format!("{field}.subscript_index"), index)?;
223 if let Some(a) = alias {
224 check_ident(&format!("{field}.alias"), a)?;
225 }
226 Ok(())
227 }
228 Expr::Collate {
229 expr,
230 collation,
231 alias,
232 } => {
233 check_expr(&format!("{field}.collate_expr"), expr)?;
234 check_ident(&format!("{field}.collation"), collation)?;
235 if let Some(a) = alias {
236 check_ident(&format!("{field}.alias"), a)?;
237 }
238 Ok(())
239 }
240 Expr::FieldAccess {
241 expr,
242 field: f,
243 alias,
244 } => {
245 check_expr(&format!("{field}.field_access_expr"), expr)?;
246 check_ident(&format!("{field}.field"), f)?;
247 if let Some(a) = alias {
248 check_ident(&format!("{field}.alias"), a)?;
249 }
250 Ok(())
251 }
252 Expr::Def { name, .. } => check_ident(field, name),
253 Expr::Mod { col, .. } => check_expr(field, col),
254 }
255}
256
257fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
259 match value {
260 Value::Subquery(q) => validate_ast(q),
261 Value::Array(vals) => {
262 for v in vals {
263 check_value(field, v)?;
264 }
265 Ok(())
266 }
267 Value::Expr(expr) => check_expr(field, expr),
268 _ => Ok(()),
269 }
270}
271
272pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
281 match cmd.action {
283 Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
284 return Err(SanitizeError {
285 field: "action".to_string(),
286 value: format!("{:?}", cmd.action),
287 reason: "procedural/session actions are not allowed via binary AST".to_string(),
288 });
289 }
290 _ => {}
291 }
292
293 if !cmd.table.is_empty() {
295 check_ident("table", &cmd.table)?;
296 }
297
298 for (i, col) in cmd.columns.iter().enumerate() {
300 check_expr(&format!("columns[{i}]"), col)?;
301 }
302
303 for (i, join) in cmd.joins.iter().enumerate() {
305 for token in join.table.split_whitespace() {
308 check_ident(&format!("joins[{i}].table"), token)?;
309 }
310 if let Some(ref conditions) = join.on {
311 for cond in conditions {
312 check_expr(&format!("joins[{i}].on"), &cond.left)?;
313 check_value(&format!("joins[{i}].on"), &cond.value)?;
314 }
315 }
316 }
317
318 for cage in &cmd.cages {
320 for cond in &cage.conditions {
321 check_expr("cage.condition.left", &cond.left)?;
322 check_value("cage.condition.value", &cond.value)?;
323 }
324 }
325
326 for cte in &cmd.ctes {
328 check_ident("cte.name", &cte.name)?;
329 for col in &cte.columns {
330 check_ident("cte.column", col)?;
331 }
332 validate_ast(&cte.base_query)?;
333 if let Some(ref rq) = cte.recursive_query {
334 validate_ast(rq)?;
335 }
336 }
337
338 for expr in &cmd.distinct_on {
340 check_expr("distinct_on", expr)?;
341 }
342
343 if let Some(ref cols) = cmd.returning {
345 for col in cols {
346 check_expr("returning", col)?;
347 }
348 }
349
350 if let Some(ref oc) = cmd.on_conflict {
352 for col in &oc.columns {
353 check_ident("on_conflict.column", col)?;
354 }
355 }
356
357 for t in &cmd.from_tables {
359 check_ident("from_tables", t)?;
360 }
361 for t in &cmd.using_tables {
362 check_ident("using_tables", t)?;
363 }
364
365 for (_, sub) in &cmd.set_ops {
367 validate_ast(sub)?;
368 }
369
370 if let Some(ref sq) = cmd.source_query {
372 validate_ast(sq)?;
373 }
374
375 for cond in &cmd.having {
377 check_expr("having", &cond.left)?;
378 check_value("having", &cond.value)?;
379 }
380
381 if let Some(ref ch) = cmd.channel {
383 check_ident("channel", ch)?;
384 }
385
386 Ok(())
387}
388
/// Unit tests for the AST sanitizer: accepted shapes, rejected injections,
/// forbidden actions and the identifier length boundary.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Qail;

    #[test]
    fn valid_simple_query_passes() {
        let cmd = Qail::get("users").columns(["id", "name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn sql_injection_in_table_rejected() {
        let cmd = Qail::get("users; DROP TABLE users; --");
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "table");
    }

    #[test]
    fn call_action_rejected() {
        let cmd = Qail {
            action: Action::Call,
            table: "my_proc()".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn do_action_rejected() {
        let cmd = Qail {
            action: Action::Do,
            table: "plpgsql".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    // Session actions are rejected by the same guard as Call/Do.
    #[test]
    fn session_set_action_rejected() {
        let cmd = Qail {
            action: Action::SessionSet,
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn session_reset_action_rejected() {
        let cmd = Qail {
            action: Action::SessionReset,
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn valid_qualified_name_passes() {
        let cmd = Qail::get("public.users").columns(["users.id", "users.name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;
        let cmd = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("joins"));
    }

    #[test]
    fn injection_in_column_rejected() {
        let cmd = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("columns"));
    }

    #[test]
    fn empty_table_name_passes() {
        let cmd = Qail {
            action: Action::TxnStart,
            table: String::new(),
            ..Default::default()
        };
        assert!(validate_ast(&cmd).is_ok());
    }

    // Boundary: exactly 63 bytes is the longest identifier still accepted.
    #[test]
    fn max_length_identifier_passes() {
        let name = "a".repeat(63);
        let cmd = Qail::get(&name);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn oversized_identifier_rejected() {
        let long_name = "a".repeat(64);
        let cmd = Qail::get(&long_name);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.reason.contains("63"));
    }
}
473}