1use crate::ast::{Action, Expr, Qail, Value};
11use std::fmt;
12
/// Error produced when a decoded AST contains something that must not reach
/// the SQL generator.
///
/// * `field` — where in the AST the violation was found (e.g. `table`,
///   `columns[2].alias`).
/// * `value` — the offending text, truncated by the validators, or a
///   placeholder such as `(raw SQL)`.
/// * `reason` — the rule that was broken.
#[derive(Debug, Clone)]
pub struct SanitizeError {
    pub field: String,
    pub value: String,
    pub reason: String,
}

impl fmt::Display for SanitizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            field,
            value,
            reason,
        } = self;
        write!(f, "AST validation failed: {field} '{value}' — {reason}")
    }
}

impl std::error::Error for SanitizeError {}
32
/// Maximum length, in bytes, of one `.`-separated identifier part.
///
/// Presumably chosen to match PostgreSQL's default identifier limit
/// (`NAMEDATALEN - 1` = 63 bytes); longer names would be truncated server-side.
const MAX_IDENT_LEN: usize = 63;

/// Returns `true` if `s` is a plain, optionally dot-qualified identifier.
///
/// Every `.`-separated part must be non-empty, at most [`MAX_IDENT_LEN`]
/// bytes long, and consist only of ASCII letters, digits and `_`. Quotes,
/// whitespace, operators, `;` and non-ASCII characters are therefore rejected,
/// so an accepted string cannot break out of an identifier position.
fn is_safe_identifier(s: &str) -> bool {
    // The limit applies per part: `schema.table.column` may legitimately be
    // longer than MAX_IDENT_LEN in total. Empty parts (`""`, `.a`, `a.`,
    // `a..b`) are never valid names. Non-ASCII bytes fail the byte check, so
    // `part.len()` (bytes) equals the character count for accepted parts.
    s.split('.').all(|part| {
        !part.is_empty()
            && part.len() <= MAX_IDENT_LEN
            && part.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_')
    })
}
45
46fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
48 if is_safe_identifier(value) {
49 Ok(())
50 } else {
51 Err(SanitizeError {
52 field: field.to_string(),
53 value: value.chars().take(40).collect(),
54 reason: "identifiers must match [a-zA-Z0-9_.] and be ≤63 chars".to_string(),
55 })
56 }
57}
58
59fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
65 match expr {
66 Expr::Star => Ok(()),
67 Expr::Named(name) => check_ident(field, name),
68 Expr::Aliased { name, alias } => {
69 check_ident(field, name)?;
70 check_ident(&format!("{field}.alias"), alias)
71 }
72 Expr::Aggregate {
73 col, alias, filter, ..
74 } => {
75 check_ident(field, col)?;
76 if let Some(a) = alias {
77 check_ident(&format!("{field}.alias"), a)?;
78 }
79 if let Some(conditions) = filter {
80 for cond in conditions {
81 check_expr(&format!("{field}.filter"), &cond.left)?;
82 }
83 }
84 Ok(())
85 }
86 Expr::FunctionCall { name, args, alias } => {
87 check_ident(field, name)?;
88 if let Some(a) = alias {
89 check_ident(&format!("{field}.alias"), a)?;
90 }
91 for arg in args {
92 check_expr(&format!("{field}.arg"), arg)?;
93 }
94 Ok(())
95 }
96 Expr::Cast {
97 expr,
98 target_type,
99 alias,
100 } => {
101 check_expr(field, expr)?;
102 check_ident(&format!("{field}.cast_type"), target_type)?;
103 if let Some(a) = alias {
104 check_ident(&format!("{field}.alias"), a)?;
105 }
106 Ok(())
107 }
108 Expr::Binary {
109 left, right, alias, ..
110 } => {
111 check_expr(field, left)?;
112 check_expr(field, right)?;
113 if let Some(a) = alias {
114 check_ident(&format!("{field}.alias"), a)?;
115 }
116 Ok(())
117 }
118 Expr::Raw(_) => Err(SanitizeError {
119 field: field.to_string(),
120 value: "(raw SQL)".to_string(),
121 reason: "Expr::Raw is not allowed in binary AST".to_string(),
122 }),
123 Expr::Literal(_) => Ok(()),
124 Expr::JsonAccess {
125 column,
126 alias,
127 path_segments,
128 ..
129 } => {
130 check_ident(field, column)?;
131 for (key, _) in path_segments {
132 if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
134 return Err(SanitizeError {
135 field: format!("{field}.json_path"),
136 value: key.chars().take(40).collect(),
137 reason: "JSON path key must be a safe identifier or integer".to_string(),
138 });
139 }
140 }
141 if let Some(a) = alias {
142 check_ident(&format!("{field}.alias"), a)?;
143 }
144 Ok(())
145 }
146 Expr::Subquery { query, alias } => {
147 validate_ast(query)?;
148 if let Some(a) = alias {
149 check_ident(&format!("{field}.alias"), a)?;
150 }
151 Ok(())
152 }
153 Expr::Exists { query, alias, .. } => {
154 validate_ast(query)?;
155 if let Some(a) = alias {
156 check_ident(&format!("{field}.alias"), a)?;
157 }
158 Ok(())
159 }
160 Expr::Window {
162 name,
163 func,
164 partition,
165 params,
166 order,
167 ..
168 } => {
169 if !name.is_empty() {
170 check_ident(&format!("{field}.window_alias"), name)?;
171 }
172 check_ident(&format!("{field}.window_func"), func)?;
173 for p in partition {
174 check_ident(&format!("{field}.partition"), p)?;
175 }
176 for p in params {
177 check_expr(&format!("{field}.window_param"), p)?;
178 }
179 for cage in order {
180 for cond in &cage.conditions {
181 check_expr(&format!("{field}.window_order"), &cond.left)?;
182 check_value(&format!("{field}.window_order"), &cond.value)?;
183 }
184 }
185 Ok(())
186 }
187 Expr::Case {
188 when_clauses,
189 else_value,
190 alias,
191 } => {
192 for (cond, val) in when_clauses {
193 check_expr(
194 &format!("{field}.case_when"),
195 &Expr::Named(cond.left.to_string()),
196 )?;
197 check_expr(&format!("{field}.case_then"), val)?;
198 }
199 if let Some(e) = else_value {
200 check_expr(&format!("{field}.case_else"), e)?;
201 }
202 if let Some(a) = alias {
203 check_ident(&format!("{field}.alias"), a)?;
204 }
205 Ok(())
206 }
207 Expr::SpecialFunction { args, alias, name } => {
208 check_ident(&format!("{field}.special_func"), name)?;
209 for (_, arg) in args {
210 check_expr(&format!("{field}.special_func_arg"), arg)?;
211 }
212 if let Some(a) = alias {
213 check_ident(&format!("{field}.alias"), a)?;
214 }
215 Ok(())
216 }
217 Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
218 for elem in elements {
219 check_expr(&format!("{field}.element"), elem)?;
220 }
221 if let Some(a) = alias {
222 check_ident(&format!("{field}.alias"), a)?;
223 }
224 Ok(())
225 }
226 Expr::Subscript { expr, index, alias } => {
227 check_expr(&format!("{field}.subscript_expr"), expr)?;
228 check_expr(&format!("{field}.subscript_index"), index)?;
229 if let Some(a) = alias {
230 check_ident(&format!("{field}.alias"), a)?;
231 }
232 Ok(())
233 }
234 Expr::Collate {
235 expr,
236 collation,
237 alias,
238 } => {
239 check_expr(&format!("{field}.collate_expr"), expr)?;
240 check_ident(&format!("{field}.collation"), collation)?;
241 if let Some(a) = alias {
242 check_ident(&format!("{field}.alias"), a)?;
243 }
244 Ok(())
245 }
246 Expr::FieldAccess {
247 expr,
248 field: f,
249 alias,
250 } => {
251 check_expr(&format!("{field}.field_access_expr"), expr)?;
252 check_ident(&format!("{field}.field"), f)?;
253 if let Some(a) = alias {
254 check_ident(&format!("{field}.alias"), a)?;
255 }
256 Ok(())
257 }
258 Expr::Def { name, .. } => check_ident(field, name),
259 Expr::Mod { col, .. } => check_expr(field, col),
260 }
261}
262
263fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
265 match value {
266 Value::Subquery(q) => validate_ast(q),
267 Value::Array(vals) => {
268 for v in vals {
269 check_value(field, v)?;
270 }
271 Ok(())
272 }
273 Value::Expr(expr) => check_expr(field, expr),
274 _ => Ok(()),
275 }
276}
277
278pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
287 match cmd.action {
289 Action::Call | Action::Do | Action::SessionSet | Action::SessionReset => {
290 return Err(SanitizeError {
291 field: "action".to_string(),
292 value: format!("{:?}", cmd.action),
293 reason: "procedural/session actions are not allowed via binary AST".to_string(),
294 });
295 }
296 _ => {}
297 }
298
299 if cmd.is_raw_sql() {
301 return Err(SanitizeError {
302 field: "table".to_string(),
303 value: "(raw SQL)".to_string(),
304 reason: "raw SQL pass-through is not allowed via binary AST".to_string(),
305 });
306 }
307
308 if !cmd.table.is_empty() {
310 check_ident("table", &cmd.table)?;
311 }
312
313 for (i, col) in cmd.columns.iter().enumerate() {
315 check_expr(&format!("columns[{i}]"), col)?;
316 }
317
318 for (i, join) in cmd.joins.iter().enumerate() {
320 for token in join.table.split_whitespace() {
323 check_ident(&format!("joins[{i}].table"), token)?;
324 }
325 if let Some(ref conditions) = join.on {
326 for cond in conditions {
327 check_expr(&format!("joins[{i}].on"), &cond.left)?;
328 check_value(&format!("joins[{i}].on"), &cond.value)?;
329 }
330 }
331 }
332
333 for cage in &cmd.cages {
335 for cond in &cage.conditions {
336 check_expr("cage.condition.left", &cond.left)?;
337 check_value("cage.condition.value", &cond.value)?;
338 }
339 }
340
341 for cte in &cmd.ctes {
343 check_ident("cte.name", &cte.name)?;
344 for col in &cte.columns {
345 check_ident("cte.column", col)?;
346 }
347 validate_ast(&cte.base_query)?;
348 if let Some(ref rq) = cte.recursive_query {
349 validate_ast(rq)?;
350 }
351 }
352
353 for expr in &cmd.distinct_on {
355 check_expr("distinct_on", expr)?;
356 }
357
358 if let Some(ref cols) = cmd.returning {
360 for col in cols {
361 check_expr("returning", col)?;
362 }
363 }
364
365 if let Some(ref oc) = cmd.on_conflict {
367 for col in &oc.columns {
368 check_ident("on_conflict.column", col)?;
369 }
370 }
371
372 for t in &cmd.from_tables {
374 check_ident("from_tables", t)?;
375 }
376 for t in &cmd.using_tables {
377 check_ident("using_tables", t)?;
378 }
379
380 for (_, sub) in &cmd.set_ops {
382 validate_ast(sub)?;
383 }
384
385 if let Some(ref sq) = cmd.source_query {
387 validate_ast(sq)?;
388 }
389
390 for cond in &cmd.having {
392 check_expr("having", &cond.left)?;
393 check_value("having", &cond.value)?;
394 }
395
396 if let Some(ref ch) = cmd.channel {
398 check_ident("channel", ch)?;
399 }
400
401 Ok(())
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Qail;

    #[test]
    fn valid_simple_query_passes() {
        let cmd = Qail::get("users").columns(["id", "name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn sql_injection_in_table_rejected() {
        let cmd = Qail::get("users; DROP TABLE users; --");
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "table");
    }

    #[test]
    fn raw_sql_rejected() {
        let cmd = Qail::raw_sql("SELECT 1");
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "table");
    }

    #[test]
    fn raw_expr_rejected() {
        let cmd = Qail::get("users").columns_expr(vec![Expr::Raw("NOW()".to_string())]);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.reason.contains("Raw"));
    }

    #[test]
    fn call_action_rejected() {
        let cmd = Qail {
            action: Action::Call,
            table: "my_proc()".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn do_action_rejected() {
        let cmd = Qail {
            action: Action::Do,
            table: "plpgsql".to_string(),
            ..Default::default()
        };
        let err = validate_ast(&cmd).unwrap_err();
        assert_eq!(err.field, "action");
    }

    #[test]
    fn session_actions_rejected() {
        // Both session actions must be refused before any field is inspected.
        for action in [Action::SessionSet, Action::SessionReset] {
            let cmd = Qail {
                action,
                ..Default::default()
            };
            let err = validate_ast(&cmd).unwrap_err();
            assert_eq!(err.field, "action");
        }
    }

    #[test]
    fn valid_qualified_name_passes() {
        let cmd = Qail::get("public.users").columns(["users.id", "users.name"]);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;
        let cmd = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("joins"));
    }

    #[test]
    fn injection_in_column_rejected() {
        let cmd = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.field.contains("columns"));
    }

    #[test]
    fn empty_table_name_passes() {
        // Utility commands such as transaction start carry no table.
        let cmd = Qail {
            action: Action::TxnStart,
            table: String::new(),
            ..Default::default()
        };
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn identifier_at_max_length_passes() {
        let name = "a".repeat(63);
        let cmd = Qail::get(&name);
        assert!(validate_ast(&cmd).is_ok());
    }

    #[test]
    fn oversized_identifier_rejected() {
        let long_name = "a".repeat(64);
        let cmd = Qail::get(&long_name);
        let err = validate_ast(&cmd).unwrap_err();
        assert!(err.reason.contains("63"));
    }

    #[test]
    fn rejected_value_is_truncated_in_error() {
        let long_name = format!("{};", "a".repeat(100));
        let err = validate_ast(&Qail::get(&long_name)).unwrap_err();
        assert_eq!(err.value.chars().count(), 40);
    }

    #[test]
    fn error_display_includes_field_value_and_reason() {
        let err = SanitizeError {
            field: "table".to_string(),
            value: "x;".to_string(),
            reason: "bad".to_string(),
        };
        assert_eq!(err.to_string(), "AST validation failed: table 'x;' — bad");
    }
}