1use crate::ast::Qail;
7
/// Magic first line of a framed single-command text message.
const CMD_TEXT_MAGIC: &str = "QAIL-CMD/1";
/// Magic first line of a framed multi-command (batch) text message.
const CMDS_TEXT_MAGIC: &str = "QAIL-CMDS/1";
/// Four-byte magic that opens the current binary frame.
const CMD_BIN_MAGIC: [u8; 4] = *b"QWB2";
/// Four-byte magic of the legacy binary frame; the binary decoder recognises it
/// only to reject it with a dedicated error message.
const CMD_BIN_LEGACY_MAGIC: [u8; 4] = *b"QWB1";

/// Largest payload, in bytes, accepted after the 8-byte binary header (64 KiB).
pub const MAX_CMD_BINARY_PAYLOAD_BYTES: usize = 64 * 1024;
/// Deepest nesting level the AST limit walk accepts.
const MAX_AST_DEPTH: usize = 64;
/// Maximum number of nodes the AST limit walk may count for one command.
const MAX_AST_NODES: usize = 16_384;
/// Maximum number of elements in any single AST collection.
const MAX_AST_COLLECTION_LEN: usize = 2_048;
/// Maximum number of commands in one text batch; tied to the collection limit.
const MAX_CMD_TEXT_BATCH_COMMANDS: usize = MAX_AST_COLLECTION_LEN;
/// Maximum byte length of any single string held in the AST (32 KiB).
const MAX_AST_STRING_LEN: usize = 32 * 1024;
/// Maximum number of components in a vector value.
const MAX_AST_VECTOR_LEN: usize = 8_192;
/// Maximum byte length of a binary (`Bytes`) value (32 KiB).
const MAX_AST_BINARY_VALUE_LEN: usize = 32 * 1024;
22
23pub fn encode_cmd_text(cmd: &Qail) -> String {
25 let payload = cmd.to_string();
26 let mut out = String::with_capacity(CMD_TEXT_MAGIC.len() + payload.len() + 32);
27 out.push_str(CMD_TEXT_MAGIC);
28 out.push('\n');
29 out.push_str(&payload.len().to_string());
30 out.push('\n');
31 out.push_str(&payload);
32 out
33}
34
35pub fn decode_cmd_text(input: &str) -> Result<Qail, String> {
39 let bytes = input.as_bytes();
40 let mut idx = 0usize;
41
42 let Ok(magic) = read_line(bytes, &mut idx) else {
43 return crate::parse(input).map_err(|e| e.to_string());
44 };
45
46 if magic != CMD_TEXT_MAGIC {
47 return crate::parse(input).map_err(|e| e.to_string());
48 }
49
50 let len_line = read_line(bytes, &mut idx)?;
51 let payload_len = parse_usize("payload length", len_line)?;
52 let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
53 if idx != bytes.len() {
54 return Err("trailing bytes after command payload".to_string());
55 }
56
57 crate::parse(payload).map_err(|e| e.to_string())
58}
59
60pub fn encode_cmds_text(cmds: &[Qail]) -> String {
62 let mut out = String::new();
63 out.push_str(CMDS_TEXT_MAGIC);
64 out.push('\n');
65 out.push_str(&cmds.len().to_string());
66 out.push('\n');
67
68 for cmd in cmds {
69 let payload = cmd.to_string();
70 out.push_str(&payload.len().to_string());
71 out.push('\n');
72 out.push_str(&payload);
73 }
74
75 out
76}
77
78pub fn decode_cmds_text(input: &str) -> Result<Vec<Qail>, String> {
80 let bytes = input.as_bytes();
81 let mut idx = 0usize;
82
83 let magic = read_line(bytes, &mut idx)?;
84 if magic != CMDS_TEXT_MAGIC {
85 return Err(format!(
86 "invalid wire magic: expected {CMDS_TEXT_MAGIC}, got {magic}"
87 ));
88 }
89
90 let count_line = read_line(bytes, &mut idx)?;
91 let count = parse_usize("command count", count_line)?;
92 if count > MAX_CMD_TEXT_BATCH_COMMANDS {
93 return Err(format!(
94 "command count exceeds limit: {count} > {MAX_CMD_TEXT_BATCH_COMMANDS}"
95 ));
96 }
97 let mut out = Vec::with_capacity(count);
98
99 for _ in 0..count {
100 let len_line = read_line(bytes, &mut idx)?;
101 let payload_len = parse_usize("payload length", len_line)?;
102 let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
103 let cmd = crate::parse(payload).map_err(|e| e.to_string())?;
104 out.push(cmd);
105 }
106
107 if idx != bytes.len() {
108 return Err("trailing bytes after batch payload".to_string());
109 }
110
111 Ok(out)
112}
113
114pub fn encode_cmd_binary(cmd: &Qail) -> Result<Vec<u8>, String> {
116 validate_binary_ast_limits(cmd)?;
117 crate::sanitize::validate_ast(cmd).map_err(|e| e.to_string())?;
118
119 let payload = serde_json::to_vec(cmd).map_err(|e| format!("binary AST encode failed: {e}"))?;
120 if payload.len() > MAX_CMD_BINARY_PAYLOAD_BYTES {
121 return Err(format!(
122 "binary AST payload too large: {} bytes (max {})",
123 payload.len(),
124 MAX_CMD_BINARY_PAYLOAD_BYTES
125 ));
126 }
127
128 let payload_len = u32::try_from(payload.len())
129 .map_err(|_| format!("binary AST payload exceeds u32 length: {}", payload.len()))?;
130 let mut out = Vec::with_capacity(8 + payload.len());
131 out.extend_from_slice(&CMD_BIN_MAGIC);
132 out.extend_from_slice(&payload_len.to_be_bytes());
133 out.extend_from_slice(&payload);
134 Ok(out)
135}
136
137pub fn decode_cmd_binary(input: &[u8]) -> Result<Qail, String> {
141 let payload = decode_cmd_binary_payload(input)?;
142 let mut deserializer = serde_json::Deserializer::from_slice(payload);
143 let cmd = serde::Deserialize::deserialize(&mut deserializer)
144 .map_err(|e| format!("binary AST decode failed: {e}"))?;
145 deserializer
146 .end()
147 .map_err(|_| "trailing bytes after AST payload".to_string())?;
148 validate_binary_ast_limits(&cmd)?;
149 crate::sanitize::validate_ast(&cmd).map_err(|e| e.to_string())?;
150 Ok(cmd)
151}
152
153pub fn decode_cmd_binary_payload(input: &[u8]) -> Result<&[u8], String> {
157 if input.len() < 8 {
158 return Err("invalid wire header".to_string());
159 }
160 if input[0..4] != CMD_BIN_MAGIC {
161 if input[0..4] == CMD_BIN_LEGACY_MAGIC {
162 return Err(
163 "legacy QWB1 text payload is not supported on parse-free binary path".to_string(),
164 );
165 }
166 return Err("invalid wire header".to_string());
167 }
168
169 let len = u32::from_be_bytes([input[4], input[5], input[6], input[7]]) as usize;
170 if len > MAX_CMD_BINARY_PAYLOAD_BYTES {
171 return Err(format!(
172 "binary AST payload too large: header={len}, max={MAX_CMD_BINARY_PAYLOAD_BYTES}"
173 ));
174 }
175 if input.len() != 8 + len {
176 return Err(format!(
177 "invalid payload length: header={len}, actual={}",
178 input.len().saturating_sub(8)
179 ));
180 }
181 Ok(&input[8..])
182}
183
/// Mutable state threaded through the AST limit walk.
#[derive(Default)]
struct AstLimitState {
    /// Number of nodes counted so far; incremented by `bump`.
    nodes: usize,
}
188
189impl AstLimitState {
190 fn bump(&mut self, kind: &str) -> Result<(), String> {
191 self.nodes = self
192 .nodes
193 .checked_add(1)
194 .ok_or_else(|| "AST node counter overflow".to_string())?;
195 if self.nodes > MAX_AST_NODES {
196 return Err(format!(
197 "AST node limit exceeded while walking {kind}: {} > {}",
198 self.nodes, MAX_AST_NODES
199 ));
200 }
201 Ok(())
202 }
203}
204
205fn ensure_depth(depth: usize, kind: &str) -> Result<(), String> {
206 if depth > MAX_AST_DEPTH {
207 return Err(format!(
208 "AST depth limit exceeded while walking {kind}: {depth} > {MAX_AST_DEPTH}"
209 ));
210 }
211 Ok(())
212}
213
/// Fails when `len` is strictly greater than `max`.
///
/// `kind` labels the measured field in the error message.
fn ensure_len(kind: &str, len: usize, max: usize) -> Result<(), String> {
    if len <= max {
        Ok(())
    } else {
        Err(format!("{kind} exceeds limit: {len} > {max}"))
    }
}
220
221fn ensure_str(kind: &str, value: &str) -> Result<(), String> {
222 ensure_len(kind, value.len(), MAX_AST_STRING_LEN)
223}
224
225fn validate_binary_ast_limits(cmd: &Qail) -> Result<(), String> {
226 let mut state = AstLimitState::default();
227 validate_qail_limits(cmd, 0, &mut state)
228}
229
/// Recursively checks one command against the AST size limits.
///
/// The command itself is counted as one node at `depth`; nested expressions,
/// joins, cages, conditions and sub-commands are visited at `depth + 1`.
/// Collections are bounded by `MAX_AST_COLLECTION_LEN` and strings by
/// `ensure_str`. `state` carries the node counter shared by the whole walk.
///
/// # Errors
/// Returns the first violation found, as a human-readable message. Checks run
/// in the order written below, so that order decides which error is reported.
fn validate_qail_limits(cmd: &Qail, depth: usize, state: &mut AstLimitState) -> Result<(), String> {
    use crate::ast::GroupByMode;

    // Count this command before descending into its children.
    ensure_depth(depth, "Qail")?;
    state.bump("Qail")?;

    ensure_str("qail.table", &cmd.table)?;
    ensure_len("qail.columns", cmd.columns.len(), MAX_AST_COLLECTION_LEN)?;
    for expr in &cmd.columns {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    ensure_len("qail.joins", cmd.joins.len(), MAX_AST_COLLECTION_LEN)?;
    for join in &cmd.joins {
        validate_join_limits(join, depth + 1, state)?;
    }

    ensure_len("qail.cages", cmd.cages.len(), MAX_AST_COLLECTION_LEN)?;
    for cage in &cmd.cages {
        validate_cage_limits(cage, depth + 1, state)?;
    }

    // The index definition helper takes no depth/state: it only checks
    // string and collection lengths.
    if let Some(index_def) = &cmd.index_def {
        validate_index_def_limits(index_def)?;
    }

    // Table constraints hold only strings and string lists.
    ensure_len(
        "qail.table_constraints",
        cmd.table_constraints.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for constraint in &cmd.table_constraints {
        match constraint {
            crate::ast::TableConstraint::Unique(cols)
            | crate::ast::TableConstraint::PrimaryKey(cols) => {
                ensure_len(
                    "qail.table_constraint.columns",
                    cols.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in cols {
                    ensure_str("qail.table_constraint.column", col)?;
                }
            }
            crate::ast::TableConstraint::ForeignKey {
                name,
                columns,
                ref_table,
                ref_columns,
                on_delete,
                on_update,
                deferrable,
            } => {
                if let Some(name) = name {
                    ensure_str("qail.table_constraint.name", name)?;
                }
                ensure_len(
                    "qail.table_constraint.columns",
                    columns.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in columns {
                    ensure_str("qail.table_constraint.column", col)?;
                }
                ensure_str("qail.table_constraint.ref_table", ref_table)?;
                ensure_len(
                    "qail.table_constraint.ref_columns",
                    ref_columns.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in ref_columns {
                    ensure_str("qail.table_constraint.ref_column", col)?;
                }
                if let Some(action) = on_delete {
                    ensure_str("qail.table_constraint.on_delete", action)?;
                }
                if let Some(action) = on_update {
                    ensure_str("qail.table_constraint.on_update", action)?;
                }
                if let Some(clause) = deferrable {
                    ensure_str("qail.table_constraint.deferrable", clause)?;
                }
            }
        }
    }

    // Right-hand sides of set operations are full commands; the operator
    // itself (first tuple element) is not inspected.
    ensure_len("qail.set_ops", cmd.set_ops.len(), MAX_AST_COLLECTION_LEN)?;
    for (_, rhs) in &cmd.set_ops {
        validate_qail_limits(rhs, depth + 1, state)?;
    }

    ensure_len("qail.having", cmd.having.len(), MAX_AST_COLLECTION_LEN)?;
    for cond in &cmd.having {
        validate_condition_limits(cond, depth + 1, state)?;
    }

    // Only the grouping-sets mode is inspected here; other modes are skipped.
    if let GroupByMode::GroupingSets(groups) = &cmd.group_by_mode {
        ensure_len("qail.grouping_sets", groups.len(), MAX_AST_COLLECTION_LEN)?;
        for group in groups {
            ensure_len("qail.grouping_set", group.len(), MAX_AST_COLLECTION_LEN)?;
            for col in group {
                ensure_str("qail.grouping_set.column", col)?;
            }
        }
    }

    // Each CTE contributes its base query and, if present, its recursive query.
    ensure_len("qail.ctes", cmd.ctes.len(), MAX_AST_COLLECTION_LEN)?;
    for cte in &cmd.ctes {
        ensure_str("qail.cte.name", &cte.name)?;
        ensure_len(
            "qail.cte.columns",
            cte.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &cte.columns {
            ensure_str("qail.cte.column", col)?;
        }
        validate_qail_limits(&cte.base_query, depth + 1, state)?;
        if let Some(recursive) = &cte.recursive_query {
            validate_qail_limits(recursive, depth + 1, state)?;
        }
        if let Some(source_table) = &cte.source_table {
            ensure_str("qail.cte.source_table", source_table)?;
        }
    }

    ensure_len(
        "qail.distinct_on",
        cmd.distinct_on.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for expr in &cmd.distinct_on {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    if let Some(returning) = &cmd.returning {
        ensure_len("qail.returning", returning.len(), MAX_AST_COLLECTION_LEN)?;
        for expr in returning {
            validate_expr_limits(expr, depth + 1, state)?;
        }
    }

    // Conflict target columns, plus update assignments when the action has any.
    if let Some(on_conflict) = &cmd.on_conflict {
        ensure_len(
            "qail.on_conflict.columns",
            on_conflict.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &on_conflict.columns {
            ensure_str("qail.on_conflict.column", col)?;
        }
        if let Some(assignments) = on_conflict.action.update_assignments() {
            ensure_len(
                "qail.on_conflict.assignments",
                assignments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for (col, expr) in assignments {
                ensure_str("qail.on_conflict.assignment.column", col)?;
                validate_expr_limits(expr, depth + 1, state)?;
            }
        }
    }

    // MERGE: target alias, source (table or sub-command), ON conditions, then
    // each WHEN clause with its conditions and action.
    if let Some(merge) = &cmd.merge {
        if let Some(alias) = &merge.target_alias {
            ensure_str("qail.merge.target_alias", alias)?;
        }
        match &merge.source {
            crate::ast::MergeSource::Table { name, alias } => {
                ensure_str("qail.merge.source.table", name)?;
                if let Some(alias) = alias {
                    ensure_str("qail.merge.source.alias", alias)?;
                }
            }
            crate::ast::MergeSource::Query { query, alias } => {
                validate_qail_limits(query, depth + 1, state)?;
                if let Some(alias) = alias {
                    ensure_str("qail.merge.source.alias", alias)?;
                }
            }
        }
        ensure_len("qail.merge.on", merge.on.len(), MAX_AST_COLLECTION_LEN)?;
        for condition in &merge.on {
            validate_condition_limits(condition, depth + 1, state)?;
        }
        ensure_len(
            "qail.merge.clauses",
            merge.clauses.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for clause in &merge.clauses {
            ensure_len(
                "qail.merge.clause.condition",
                clause.condition.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for condition in &clause.condition {
                validate_condition_limits(condition, depth + 1, state)?;
            }
            match &clause.action {
                crate::ast::MergeAction::Update { assignments } => {
                    ensure_len(
                        "qail.merge.update.assignments",
                        assignments.len(),
                        MAX_AST_COLLECTION_LEN,
                    )?;
                    for (col, expr) in assignments {
                        ensure_str("qail.merge.update.column", col)?;
                        validate_expr_limits(expr, depth + 1, state)?;
                    }
                }
                crate::ast::MergeAction::Insert { columns, values } => {
                    ensure_len(
                        "qail.merge.insert.columns",
                        columns.len(),
                        MAX_AST_COLLECTION_LEN,
                    )?;
                    for col in columns {
                        ensure_str("qail.merge.insert.column", col)?;
                    }
                    ensure_len(
                        "qail.merge.insert.values",
                        values.len(),
                        MAX_AST_COLLECTION_LEN,
                    )?;
                    for expr in values {
                        validate_expr_limits(expr, depth + 1, state)?;
                    }
                }
                // These actions carry no data to measure.
                crate::ast::MergeAction::Delete | crate::ast::MergeAction::DoNothing => {}
            }
        }
    }

    if let Some(source_query) = &cmd.source_query {
        validate_qail_limits(source_query, depth + 1, state)?;
    }

    // Plain optional string fields.
    if let Some(channel) = &cmd.channel {
        ensure_str("qail.channel", channel)?;
    }
    if let Some(payload) = &cmd.payload {
        ensure_str("qail.payload", payload)?;
    }
    if let Some(savepoint_name) = &cmd.savepoint_name {
        ensure_str("qail.savepoint_name", savepoint_name)?;
    }

    ensure_len(
        "qail.from_tables",
        cmd.from_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.from_tables {
        ensure_str("qail.from_table", table)?;
    }

    ensure_len(
        "qail.using_tables",
        cmd.using_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.using_tables {
        ensure_str("qail.using_table", table)?;
    }

    // Reject NaN and infinite sample percentages; the other tuple fields are
    // not inspected.
    if let Some((_, percent, _seed)) = cmd.sample
        && !percent.is_finite()
    {
        return Err("qail.sample.percent must be finite".to_string());
    }

    // Vectors have their own (larger) length limit; components are not inspected.
    if let Some(vector) = &cmd.vector {
        ensure_len("qail.vector", vector.len(), MAX_AST_VECTOR_LEN)?;
    }
    if let Some(vector_name) = &cmd.vector_name {
        ensure_str("qail.vector_name", vector_name)?;
    }
    if let Some(function_def) = &cmd.function_def {
        validate_function_def_limits(function_def)?;
    }
    if let Some(trigger_def) = &cmd.trigger_def {
        validate_trigger_def_limits(trigger_def)?;
    }
    // Policies contain expressions, so this helper needs depth and state.
    if let Some(policy_def) = &cmd.policy_def {
        validate_policy_def_limits(policy_def, depth + 1, state)?;
    }

    Ok(())
}
521
522fn validate_join_limits(
523 join: &crate::ast::Join,
524 depth: usize,
525 state: &mut AstLimitState,
526) -> Result<(), String> {
527 ensure_depth(depth, "Join")?;
528 state.bump("Join")?;
529 ensure_str("join.table", &join.table)?;
530 if let Some(on) = &join.on {
531 ensure_len("join.on", on.len(), MAX_AST_COLLECTION_LEN)?;
532 for cond in on {
533 validate_condition_limits(cond, depth + 1, state)?;
534 }
535 }
536 Ok(())
537}
538
539fn validate_cage_limits(
540 cage: &crate::ast::Cage,
541 depth: usize,
542 state: &mut AstLimitState,
543) -> Result<(), String> {
544 use crate::ast::CageKind;
545
546 ensure_depth(depth, "Cage")?;
547 state.bump("Cage")?;
548 ensure_len(
549 "cage.conditions",
550 cage.conditions.len(),
551 MAX_AST_COLLECTION_LEN,
552 )?;
553 for cond in &cage.conditions {
554 validate_condition_limits(cond, depth + 1, state)?;
555 }
556 match cage.kind {
557 CageKind::Limit(v) | CageKind::Offset(v) | CageKind::Sample(v) => {
558 ensure_len("cage.numeric", v, usize::MAX)?;
559 }
560 _ => {}
561 }
562 Ok(())
563}
564
565fn validate_condition_limits(
566 cond: &crate::ast::Condition,
567 depth: usize,
568 state: &mut AstLimitState,
569) -> Result<(), String> {
570 ensure_depth(depth, "Condition")?;
571 state.bump("Condition")?;
572 validate_expr_limits(&cond.left, depth + 1, state)?;
573 validate_value_limits(&cond.value, depth + 1, state)
574}
575
/// Recursively checks one expression against the AST size limits.
///
/// The expression counts as one node at `depth`. Strings it carries are
/// length-checked with `ensure_str`, collections are bounded by
/// `MAX_AST_COLLECTION_LEN`, and nested expressions, conditions, cages,
/// values and sub-commands are validated at `depth + 1`.
///
/// # Errors
/// Returns the first violation found, as a human-readable message.
fn validate_expr_limits(
    expr: &crate::ast::Expr,
    depth: usize,
    state: &mut AstLimitState,
) -> Result<(), String> {
    use crate::ast::{ColumnGeneration, Constraint, Expr, WindowFrame};

    // Count this expression before looking inside it.
    ensure_depth(depth, "Expr")?;
    state.bump("Expr")?;

    match expr {
        // `*` carries no data.
        Expr::Star => {}
        Expr::Named(name) => ensure_str("expr.named", name)?,
        Expr::Aliased { name, alias } => {
            ensure_str("expr.aliased.name", name)?;
            ensure_str("expr.aliased.alias", alias)?;
        }
        // Remaining aggregate fields (matched by `..`) are not inspected.
        Expr::Aggregate {
            col, filter, alias, ..
        } => {
            ensure_str("expr.aggregate.col", col)?;
            if let Some(filters) = filter {
                ensure_len(
                    "expr.aggregate.filter",
                    filters.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for cond in filters {
                    validate_condition_limits(cond, depth + 1, state)?;
                }
            }
            if let Some(alias) = alias {
                ensure_str("expr.aggregate.alias", alias)?;
            }
        }
        Expr::Cast {
            expr,
            target_type,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.cast.target_type", target_type)?;
            if let Some(alias) = alias {
                ensure_str("expr.cast.alias", alias)?;
            }
        }
        // Column definition: name, type, and every string held by a constraint.
        Expr::Def {
            name,
            data_type,
            constraints,
        } => {
            ensure_str("expr.def.name", name)?;
            ensure_str("expr.def.data_type", data_type)?;
            ensure_len(
                "expr.def.constraints",
                constraints.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for constraint in constraints {
                match constraint {
                    // Marker constraints carry no data.
                    Constraint::PrimaryKey | Constraint::Unique | Constraint::Nullable => {}
                    Constraint::Default(v) => ensure_str("expr.def.default", v)?,
                    Constraint::Check(values) => {
                        ensure_len("expr.def.check", values.len(), MAX_AST_COLLECTION_LEN)?;
                        for value in values {
                            ensure_str("expr.def.check.value", value)?;
                        }
                    }
                    Constraint::Comment(v) | Constraint::References(v) => {
                        ensure_str("expr.def.constraint", v)?;
                    }
                    Constraint::Generated(ColumnGeneration::Stored(v))
                    | Constraint::Generated(ColumnGeneration::Virtual(v)) => {
                        ensure_str("expr.def.generated", v)?;
                    }
                }
            }
        }
        // Only the wrapped column expression is inspected.
        Expr::Mod { col, .. } => validate_expr_limits(col, depth + 1, state)?,
        Expr::Window {
            name,
            func,
            params,
            partition,
            order,
            frame,
        } => {
            ensure_str("expr.window.name", name)?;
            ensure_str("expr.window.func", func)?;
            ensure_len("expr.window.params", params.len(), MAX_AST_COLLECTION_LEN)?;
            for param in params {
                validate_expr_limits(param, depth + 1, state)?;
            }
            ensure_len(
                "expr.window.partition",
                partition.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for col in partition {
                ensure_str("expr.window.partition.column", col)?;
            }
            ensure_len("expr.window.order", order.len(), MAX_AST_COLLECTION_LEN)?;
            for cage in order {
                validate_cage_limits(cage, depth + 1, state)?;
            }
            // Frame contents are not checked. NOTE(review): the variants are
            // listed explicitly, presumably so that adding a new one forces a
            // compile error here — confirm intent.
            if let Some(frame) = frame {
                match frame {
                    WindowFrame::Rows { .. } | WindowFrame::Range { .. } => {}
                }
            }
        }
        Expr::Case {
            when_clauses,
            else_value,
            alias,
        } => {
            ensure_len("expr.case.when", when_clauses.len(), MAX_AST_COLLECTION_LEN)?;
            for (cond, then_expr) in when_clauses {
                validate_condition_limits(cond, depth + 1, state)?;
                validate_expr_limits(then_expr, depth + 1, state)?;
            }
            if let Some(else_expr) = else_value {
                validate_expr_limits(else_expr, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.case.alias", alias)?;
            }
        }
        Expr::JsonAccess {
            column,
            path_segments,
            alias,
        } => {
            ensure_str("expr.json_access.column", column)?;
            ensure_len(
                "expr.json_access.path_segments",
                path_segments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            // Only the segment text is checked; the second tuple field is ignored.
            for (segment, _) in path_segments {
                ensure_str("expr.json_access.segment", segment)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.json_access.alias", alias)?;
            }
        }
        Expr::FunctionCall { name, args, alias } => {
            ensure_str("expr.function_call.name", name)?;
            ensure_len(
                "expr.function_call.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for arg in args {
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.function_call.alias", alias)?;
            }
        }
        // Arguments are (optional keyword, expression) pairs.
        Expr::SpecialFunction { name, args, alias } => {
            ensure_str("expr.special_function.name", name)?;
            ensure_len(
                "expr.special_function.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for (keyword, arg) in args {
                if let Some(keyword) = keyword {
                    ensure_str("expr.special_function.keyword", keyword)?;
                }
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.special_function.alias", alias)?;
            }
        }
        // The operator (matched by `..`) is not inspected.
        Expr::Binary {
            left, right, alias, ..
        } => {
            validate_expr_limits(left, depth + 1, state)?;
            validate_expr_limits(right, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.binary.alias", alias)?;
            }
        }
        Expr::Literal(v) => validate_value_limits(v, depth + 1, state)?,
        // Both constructors share the same shape and the same error labels.
        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
            ensure_len("expr.elements", elements.len(), MAX_AST_COLLECTION_LEN)?;
            for el in elements {
                validate_expr_limits(el, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.elements.alias", alias)?;
            }
        }
        Expr::Subscript { expr, index, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            validate_expr_limits(index, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subscript.alias", alias)?;
            }
        }
        Expr::Collate {
            expr,
            collation,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.collate.collation", collation)?;
            if let Some(alias) = alias {
                ensure_str("expr.collate.alias", alias)?;
            }
        }
        Expr::FieldAccess { expr, field, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.field_access.field", field)?;
            if let Some(alias) = alias {
                ensure_str("expr.field_access.alias", alias)?;
            }
        }
        // Sub-commands are validated with the same shared node counter.
        Expr::Subquery { query, alias } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subquery.alias", alias)?;
            }
        }
        Expr::Exists { query, alias, .. } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.exists.alias", alias)?;
            }
        }
    }

    Ok(())
}
813
814fn validate_value_limits(
815 value: &crate::ast::Value,
816 depth: usize,
817 state: &mut AstLimitState,
818) -> Result<(), String> {
819 use crate::ast::Value;
820
821 ensure_depth(depth, "Value")?;
822 state.bump("Value")?;
823
824 match value {
825 Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::Param(_) => {}
826 Value::String(v)
827 | Value::NamedParam(v)
828 | Value::Function(v)
829 | Value::Column(v)
830 | Value::Timestamp(v)
831 | Value::Json(v) => ensure_str("value.string", v)?,
832 Value::Array(values) => {
833 ensure_len("value.array", values.len(), MAX_AST_COLLECTION_LEN)?;
834 for v in values {
835 validate_value_limits(v, depth + 1, state)?;
836 }
837 }
838 Value::Subquery(q) => validate_qail_limits(q, depth + 1, state)?,
839 Value::Uuid(_) | Value::NullUuid | Value::Interval { .. } => {}
840 Value::Bytes(bytes) => ensure_len("value.bytes", bytes.len(), MAX_AST_BINARY_VALUE_LEN)?,
841 Value::Expr(expr) => validate_expr_limits(expr, depth + 1, state)?,
842 Value::Vector(values) => ensure_len("value.vector", values.len(), MAX_AST_VECTOR_LEN)?,
843 }
844
845 Ok(())
846}
847
848fn validate_index_def_limits(index_def: &crate::ast::IndexDef) -> Result<(), String> {
849 ensure_str("index_def.name", &index_def.name)?;
850 ensure_str("index_def.table", &index_def.table)?;
851 ensure_len(
852 "index_def.columns",
853 index_def.columns.len(),
854 MAX_AST_COLLECTION_LEN,
855 )?;
856 for col in &index_def.columns {
857 ensure_str("index_def.column", col)?;
858 }
859 if let Some(index_type) = &index_def.index_type {
860 ensure_str("index_def.index_type", index_type)?;
861 }
862 if let Some(where_clause) = &index_def.where_clause {
863 ensure_str("index_def.where_clause", where_clause)?;
864 }
865 Ok(())
866}
867
868fn validate_function_def_limits(function_def: &crate::ast::FunctionDef) -> Result<(), String> {
869 ensure_str("function_def.name", &function_def.name)?;
870 ensure_len(
871 "function_def.args",
872 function_def.args.len(),
873 MAX_AST_COLLECTION_LEN,
874 )?;
875 for arg in &function_def.args {
876 ensure_str("function_def.arg", arg)?;
877 }
878 ensure_str("function_def.returns", &function_def.returns)?;
879 ensure_str("function_def.body", &function_def.body)?;
880 if let Some(language) = &function_def.language {
881 ensure_str("function_def.language", language)?;
882 }
883 if let Some(volatility) = &function_def.volatility {
884 ensure_str("function_def.volatility", volatility)?;
885 }
886 Ok(())
887}
888
889fn validate_trigger_def_limits(trigger_def: &crate::ast::TriggerDef) -> Result<(), String> {
890 ensure_str("trigger_def.name", &trigger_def.name)?;
891 ensure_str("trigger_def.table", &trigger_def.table)?;
892 ensure_len(
893 "trigger_def.events",
894 trigger_def.events.len(),
895 MAX_AST_COLLECTION_LEN,
896 )?;
897 ensure_len(
898 "trigger_def.update_columns",
899 trigger_def.update_columns.len(),
900 MAX_AST_COLLECTION_LEN,
901 )?;
902 for col in &trigger_def.update_columns {
903 ensure_str("trigger_def.update_column", col)?;
904 }
905 ensure_str(
906 "trigger_def.execute_function",
907 &trigger_def.execute_function,
908 )?;
909 Ok(())
910}
911
912fn validate_policy_def_limits(
913 policy_def: &crate::migrate::policy::RlsPolicy,
914 depth: usize,
915 state: &mut AstLimitState,
916) -> Result<(), String> {
917 ensure_str("policy_def.name", &policy_def.name)?;
918 ensure_str("policy_def.table", &policy_def.table)?;
919 if let Some(using_expr) = &policy_def.using {
920 validate_expr_limits(using_expr, depth + 1, state)?;
921 }
922 if let Some(with_check_expr) = &policy_def.with_check {
923 validate_expr_limits(with_check_expr, depth + 1, state)?;
924 }
925 if let Some(role) = &policy_def.role {
926 ensure_str("policy_def.role", role)?;
927 }
928 Ok(())
929}
930
/// Reads one `\n`-terminated line starting at `*idx` and moves the cursor past it.
///
/// The returned slice excludes the terminating `\n`. On success `*idx` points
/// at the byte after the newline.
///
/// # Errors
/// - `"unexpected EOF"` if the cursor is already at or past the end.
/// - `"unterminated header line"` if no `\n` follows; the cursor is left at
///   the end of the input.
/// - `"header is not UTF-8"` if the line bytes are not valid UTF-8; the cursor
///   is left on the newline.
fn read_line<'a>(bytes: &'a [u8], idx: &mut usize) -> Result<&'a str, String> {
    let start = *idx;
    if start >= bytes.len() {
        return Err("unexpected EOF".to_string());
    }

    match bytes[start..].iter().position(|&b| b == b'\n') {
        None => {
            *idx = bytes.len();
            Err("unterminated header line".to_string())
        }
        Some(offset) => {
            let newline_at = start + offset;
            *idx = newline_at;
            let line = std::str::from_utf8(&bytes[start..newline_at])
                .map_err(|_| "header is not UTF-8".to_string())?;
            *idx = newline_at + 1;
            Ok(line)
        }
    }
}
950
/// Parses a decimal header field (a length or a count) into a `usize`.
///
/// Only ASCII digits are accepted. `str::parse::<usize>` on its own also
/// accepts a leading `+`, which would let non-canonical headers such as `+12`
/// through; those are rejected here.
///
/// `field` names the header field and only appears in the error message.
///
/// # Errors
/// Returns `"invalid {field}: {line}"` when `line` is empty, contains anything
/// other than ASCII digits, or does not fit in a `usize`.
fn parse_usize(field: &str, line: &str) -> Result<usize, String> {
    let invalid = || format!("invalid {field}: {line}");

    // An empty line passes this check but is rejected by `parse` below.
    if !line.bytes().all(|b| b.is_ascii_digit()) {
        return Err(invalid());
    }
    line.parse::<usize>().map_err(|_| invalid())
}
955
/// Reads exactly `len` bytes starting at `*idx` and returns them as UTF-8 text.
///
/// The cursor is advanced past the bytes once they are known to be available,
/// before the UTF-8 check runs.
///
/// # Errors
/// - `"payload length overflow"` if `*idx + len` overflows (cursor unchanged).
/// - `"payload truncated"` if fewer than `len` bytes remain (cursor unchanged).
/// - `"payload is not UTF-8"` if the bytes are not valid UTF-8 (cursor already
///   advanced).
fn read_exact_utf8<'a>(bytes: &'a [u8], idx: &mut usize, len: usize) -> Result<&'a str, String> {
    let start = *idx;
    let Some(end) = start.checked_add(len) else {
        return Err("payload length overflow".to_string());
    };
    let Some(chunk) = bytes.get(start..end) else {
        return Err("payload truncated".to_string());
    };

    *idx = end;
    match std::str::from_utf8(chunk) {
        Ok(text) => Ok(text),
        Err(_) => Err("payload is not UTF-8".to_string()),
    }
}
967
968#[cfg(test)]
969mod tests {
970 use super::*;
971 use proptest::prelude::*;
972
973 fn encode_cmd_binary_unchecked_for_test(cmd: &Qail) -> Vec<u8> {
974 let payload = serde_json::to_vec(cmd).expect("test binary AST encode");
975 let mut out = Vec::with_capacity(8 + payload.len());
976 out.extend_from_slice(&CMD_BIN_MAGIC);
977 out.extend_from_slice(&(payload.len() as u32).to_be_bytes());
978 out.extend_from_slice(&payload);
979 out
980 }
981
982 #[test]
983 fn cmd_text_roundtrip() {
984 let cmd = crate::ast::Qail::get("users")
985 .columns(["id", "email"])
986 .where_eq("active", true)
987 .limit(10);
988
989 let encoded = encode_cmd_text(&cmd);
990 let decoded = decode_cmd_text(&encoded).unwrap();
991 assert_eq!(decoded.to_string(), cmd.to_string());
992 }
993
994 #[test]
995 fn cmd_binary_roundtrip() {
996 let cmd = crate::ast::Qail::set("users")
997 .set_value("active", true)
998 .where_eq("id", 7);
999
1000 let encoded = encode_cmd_binary(&cmd).expect("binary encode");
1001 let decoded = decode_cmd_binary(&encoded).unwrap();
1002 assert_eq!(decoded.to_string(), cmd.to_string());
1003 }
1004
1005 #[test]
1006 fn cmd_binary_roundtrip_preserves_merge_ast() {
1007 let cmd = crate::ast::Qail::merge_into("users")
1008 .target_alias("u")
1009 .using_table_as("staging_users", "s")
1010 .merge_on_column("u.id", crate::ast::Operator::Eq, "s.id")
1011 .when_matched_update(&[("name", crate::ast::Expr::Named("s.name".to_string()))])
1012 .when_not_matched_insert(
1013 &["id", "name"],
1014 &[
1015 crate::ast::Expr::Named("s.id".to_string()),
1016 crate::ast::Expr::Named("s.name".to_string()),
1017 ],
1018 );
1019
1020 let encoded = encode_cmd_binary(&cmd).expect("binary encode");
1021 let decoded = decode_cmd_binary(&encoded).unwrap();
1022 assert_eq!(decoded, cmd);
1023 }
1024
1025 #[test]
1026 fn cmd_binary_payload_rejects_oversized_header() {
1027 let mut payload = Vec::new();
1028 payload.extend_from_slice(&CMD_BIN_MAGIC);
1029 payload.extend_from_slice(&((MAX_CMD_BINARY_PAYLOAD_BYTES + 1) as u32).to_be_bytes());
1030 payload.extend_from_slice(&[]);
1031
1032 let err = decode_cmd_binary_payload(&payload).unwrap_err();
1033 assert!(err.contains("binary AST payload too large"));
1034 }
1035
1036 #[test]
1037 fn cmd_binary_payload_roundtrip() {
1038 let cmd = crate::ast::Qail::get("users").limit(3);
1039 let encoded = encode_cmd_binary(&cmd).expect("binary encode");
1040 let payload = decode_cmd_binary_payload(&encoded).unwrap();
1041 let mut deserializer = serde_json::Deserializer::from_slice(payload);
1042 let decoded: crate::ast::Qail = serde::Deserialize::deserialize(&mut deserializer).unwrap();
1043 deserializer.end().unwrap();
1044 assert_eq!(decoded.to_string(), cmd.to_string());
1045 }
1046
1047 #[test]
1048 fn cmd_binary_payload_rejects_legacy_qwb1() {
1049 let legacy_text = b"get users limit 1";
1050 let mut payload = Vec::new();
1051 payload.extend_from_slice(&CMD_BIN_LEGACY_MAGIC);
1052 payload.extend_from_slice(&(legacy_text.len() as u32).to_be_bytes());
1053 payload.extend_from_slice(legacy_text);
1054
1055 let err = decode_cmd_binary_payload(&payload).unwrap_err();
1056 assert!(err.contains("legacy QWB1"));
1057 }
1058
1059 #[test]
1060 fn cmd_binary_decode_rejects_raw_text_without_qwb2_header() {
1061 let err = decode_cmd_binary(b"get users limit 1").unwrap_err();
1062 assert!(err.contains("invalid wire header"));
1063 }
1064
1065 #[test]
1066 fn cmd_binary_decode_rejects_trailing_bytes() {
1067 let cmd = crate::ast::Qail::get("users").limit(1);
1068 let mut encoded = encode_cmd_binary(&cmd).expect("binary encode");
1069 encoded.extend_from_slice(&[0xAA, 0xBB]);
1070 let err = decode_cmd_binary(&encoded).unwrap_err();
1071 assert!(err.contains("invalid payload length"));
1072 }
1073
1074 #[test]
1075 fn cmd_binary_decode_rejects_unsafe_identifiers() {
1076 let cmd = crate::ast::Qail::get("users; DROP TABLE users; --").limit(1);
1077 let encoded = encode_cmd_binary_unchecked_for_test(&cmd);
1078
1079 let err = decode_cmd_binary(&encoded).unwrap_err();
1080
1081 assert!(err.contains("AST validation failed"));
1082 assert!(err.contains("table"));
1083 }
1084
/// The checked encoder refuses a command whose table name carries an SQL
/// injection payload.
#[test]
fn cmd_binary_encode_rejects_unsafe_identifiers() {
    let hostile = crate::ast::Qail::get("users; DROP TABLE users; --").limit(1);

    let message = encode_cmd_binary(&hostile).unwrap_err();

    for fragment in ["AST validation failed", "table"] {
        assert!(message.contains(fragment));
    }
}
1094
/// A `call` command that reaches the decoder via the unchecked test encoder
/// is rejected as a procedural/session action.
#[test]
fn cmd_binary_decode_rejects_procedural_actions() {
    let procedure_call = crate::ast::Qail::call("refresh_materialized_views()");
    let wire = encode_cmd_binary_unchecked_for_test(&procedure_call);

    let message = decode_cmd_binary(&wire).unwrap_err();

    for fragment in ["AST validation failed", "procedural/session actions"] {
        assert!(message.contains(fragment));
    }
}
1105
/// The checked encoder refuses a `call` command as a procedural/session action.
#[test]
fn cmd_binary_encode_rejects_procedural_actions() {
    let procedure_call = crate::ast::Qail::call("refresh_materialized_views()");

    let message = encode_cmd_binary(&procedure_call).unwrap_err();

    for fragment in ["AST validation failed", "procedural/session actions"] {
        assert!(message.contains(fragment));
    }
}
1115
/// A filter whose right-hand side is a raw function value containing an SQL
/// injection payload is rejected when decoded.
#[test]
fn cmd_binary_decode_rejects_unsafe_raw_function_values() {
    let injected = crate::ast::Value::Function(String::from("NOW(); DROP TABLE users; --"));
    let hostile =
        crate::ast::Qail::get("users").filter("updated_at", crate::ast::Operator::Lt, injected);
    // Use the unchecked test encoder so the hostile AST actually reaches the decoder.
    let wire = encode_cmd_binary_unchecked_for_test(&hostile);

    let message = decode_cmd_binary(&wire).unwrap_err();

    for fragment in ["AST validation failed", "raw function values"] {
        assert!(message.contains(fragment));
    }
}
1130
/// The checked encoder refuses a filter whose right-hand side is a raw
/// function value containing an SQL injection payload.
#[test]
fn cmd_binary_encode_rejects_unsafe_raw_function_values() {
    let injected = crate::ast::Value::Function(String::from("NOW(); DROP TABLE users; --"));
    let hostile =
        crate::ast::Qail::get("users").filter("updated_at", crate::ast::Operator::Lt, injected);

    let message = encode_cmd_binary(&hostile).unwrap_err();

    for fragment in ["AST validation failed", "raw function values"] {
        assert!(message.contains(fragment));
    }
}
1144
/// An `AlterAddConstraint` command whose payload smuggles extra SQL is
/// refused by the checked encoder.
#[test]
fn cmd_binary_encode_rejects_unsafe_check_constraint_payload() {
    let constraint_name = String::from("users_active_check");
    let hostile_expression = String::from("active); DROP TABLE users; --");

    let hostile = crate::ast::Qail {
        action: crate::ast::Action::AlterAddConstraint,
        table: String::from("users"),
        channel: Some(constraint_name),
        payload: Some(hostile_expression),
        ..Default::default()
    };

    let message = encode_cmd_binary(&hostile).unwrap_err();

    for fragment in ["AST validation failed", "SQL expression fragments"] {
        assert!(message.contains(fragment));
    }
}
1160
/// A query nested `MAX_AST_DEPTH + 2` subqueries deep must be rejected by the
/// decoder; any of three error texts is accepted as evidence of rejection.
#[test]
fn cmd_binary_decode_enforces_depth_limits() {
    // Wrap the seed query in one more subquery column per iteration.
    let too_deep = (0..(MAX_AST_DEPTH + 2)).fold(
        crate::ast::Qail::get("users").limit(1),
        |inner, _| crate::ast::Qail {
            action: crate::ast::Action::Get,
            table: "users".to_string(),
            columns: vec![crate::ast::Expr::Subquery {
                query: Box::new(inner),
                alias: None,
            }],
            ..crate::ast::Qail::default()
        },
    );

    let wire = encode_cmd_binary_unchecked_for_test(&too_deep);
    let message = decode_cmd_binary(&wire).unwrap_err();

    let accepted_reasons = [
        "AST depth limit exceeded",
        "binary AST decode failed",
        "recursion limit exceeded",
    ];
    assert!(accepted_reasons.iter().any(|&reason| message.contains(reason)));
}
1184
/// Flips every bit of the first (up to) 128 bytes of each seed and feeds the
/// result to the decoder. Results are discarded: the only requirement is that
/// decoding never panics.
#[test]
fn cmd_binary_decode_bitflip_corpus_no_panic() {
    let get_frame =
        encode_cmd_binary(&crate::ast::Qail::get("users").limit(1)).expect("binary encode");
    let set_frame = encode_cmd_binary(&crate::ast::Qail::set("users").set_value("active", true))
        .expect("binary encode");

    let corpus = vec![
        get_frame,
        set_frame,
        Vec::new(),
        b"QWB2garbage".to_vec(),
        vec![0u8; 32],
    ];

    // One single-bit mask per bit position, lowest bit first.
    let masks: Vec<u8> = (0..8u8).map(|bit| 1 << bit).collect();

    for original in corpus {
        let flip_span = original.len().min(128);
        for position in 0..flip_span {
            for &mask in &masks {
                let mut flipped = original.clone();
                flipped[position] ^= mask;
                let _ = decode_cmd_binary(&flipped);
            }
        }
        // The unmodified seed goes through the decoder as well.
        let _ = decode_cmd_binary(&original);
    }
}
1207
1208 proptest! {
1209 #[test]
1210 fn cmd_binary_decode_fuzz_never_panics(data in proptest::collection::vec(any::<u8>(), 0..4096)) {
1211 let _ = decode_cmd_binary(&data);
1212 }
1213 }
1214
/// A batch of commands survives text encode followed by decode: the count is
/// unchanged and every command renders to the same string as before.
#[test]
fn cmds_text_roundtrip() {
    let batch = [
        crate::ast::Qail::get("users").columns(["id", "email"]),
        crate::ast::Qail::get("users").limit(1),
        crate::ast::Qail::del("users").where_eq("id", 99),
    ];

    let wire = encode_cmds_text(&batch);
    let parsed = decode_cmds_text(&wire).unwrap();

    assert_eq!(parsed.len(), batch.len());
    for (got, want) in parsed.iter().zip(&batch) {
        assert_eq!(got.to_string(), want.to_string());
    }
}
1230
/// Input without the text envelope is parsed directly as QAIL source.
#[test]
fn decode_cmd_text_falls_back_to_raw_qail() {
    let parsed = decode_cmd_text("get users limit 1").unwrap();

    assert_eq!(parsed.action, crate::ast::Action::Get);
    assert_eq!(parsed.table, "users");

    let has_limit_one = parsed
        .cages
        .iter()
        .any(|cage| matches!(cage.kind, crate::ast::CageKind::Limit(1)));
    assert!(has_limit_one);
}
1243
/// A declared payload length of `usize::MAX` with no payload behind it must
/// produce an error instead of panicking.
#[test]
fn cmd_text_rejects_overflowing_payload_length_without_panic() {
    let envelope = format!("{}\n{}\n", CMD_TEXT_MAGIC, usize::MAX);

    let message = decode_cmd_text(&envelope).expect_err("oversized length should fail closed");

    let accepted_reasons = ["payload length overflow", "payload truncated"];
    assert!(accepted_reasons.iter().any(|&reason| message.contains(reason)));
}
1251
/// A batch header announcing one command more than
/// `MAX_CMD_TEXT_BATCH_COMMANDS` is rejected with a limit error.
#[test]
fn cmds_text_rejects_oversized_command_count_without_allocation() {
    let announced_count = MAX_CMD_TEXT_BATCH_COMMANDS + 1;
    let envelope = format!("{}\n{}\n", CMDS_TEXT_MAGIC, announced_count);

    let message =
        decode_cmds_text(&envelope).expect_err("oversized command count should fail closed");
    assert!(message.contains("command count exceeds limit"));
}
1259}