1use crate::ast::Qail;
7
// First header line of a single-command text frame (written by `encode_cmd_text`).
const CMD_TEXT_MAGIC: &str = "QAIL-CMD/1";
// First header line of a multi-command text frame (written by `encode_cmds_text`).
const CMDS_TEXT_MAGIC: &str = "QAIL-CMDS/1";
// Four-byte prefix of the binary frame produced by `encode_cmd_binary`.
const CMD_BIN_MAGIC: [u8; 4] = *b"QWB2";
// Older four-byte prefix; `decode_cmd_binary_payload` rejects it with a dedicated error.
const CMD_BIN_LEGACY_MAGIC: [u8; 4] = *b"QWB1";

/// Largest payload, in bytes, accepted by both the binary encoder and decoder
/// (the 8-byte frame header is not counted).
pub const MAX_CMD_BINARY_PAYLOAD_BYTES: usize = 64 * 1024;
// Deepest nesting level `ensure_depth` accepts during the AST walk.
const MAX_AST_DEPTH: usize = 64;
// Upper bound on the number of nodes counted by `AstLimitState::bump` in one walk.
const MAX_AST_NODES: usize = 16_384;
// Maximum element count for any single AST collection checked with `ensure_len`.
const MAX_AST_COLLECTION_LEN: usize = 2_048;
// Maximum number of commands `decode_cmds_text` accepts in one batch header.
const MAX_CMD_TEXT_BATCH_COMMANDS: usize = MAX_AST_COLLECTION_LEN;
// Maximum byte length (`str::len`) of any string checked with `ensure_str`.
const MAX_AST_STRING_LEN: usize = 32 * 1024;
// Maximum element count for `qail.vector` and `Value::Vector`.
const MAX_AST_VECTOR_LEN: usize = 8_192;
// Maximum byte length of a `Value::Bytes` payload.
const MAX_AST_BINARY_VALUE_LEN: usize = 32 * 1024;
22
23pub fn encode_cmd_text(cmd: &Qail) -> String {
25 let payload = cmd.to_string();
26 let mut out = String::with_capacity(CMD_TEXT_MAGIC.len() + payload.len() + 32);
27 out.push_str(CMD_TEXT_MAGIC);
28 out.push('\n');
29 out.push_str(&payload.len().to_string());
30 out.push('\n');
31 out.push_str(&payload);
32 out
33}
34
35pub fn decode_cmd_text(input: &str) -> Result<Qail, String> {
39 let bytes = input.as_bytes();
40 let mut idx = 0usize;
41
42 let Ok(magic) = read_line(bytes, &mut idx) else {
43 return crate::parse(input).map_err(|e| e.to_string());
44 };
45
46 if magic != CMD_TEXT_MAGIC {
47 return crate::parse(input).map_err(|e| e.to_string());
48 }
49
50 let len_line = read_line(bytes, &mut idx)?;
51 let payload_len = parse_usize("payload length", len_line)?;
52 let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
53 if idx != bytes.len() {
54 return Err("trailing bytes after command payload".to_string());
55 }
56
57 crate::parse(payload).map_err(|e| e.to_string())
58}
59
60pub fn encode_cmds_text(cmds: &[Qail]) -> String {
62 let mut out = String::new();
63 out.push_str(CMDS_TEXT_MAGIC);
64 out.push('\n');
65 out.push_str(&cmds.len().to_string());
66 out.push('\n');
67
68 for cmd in cmds {
69 let payload = cmd.to_string();
70 out.push_str(&payload.len().to_string());
71 out.push('\n');
72 out.push_str(&payload);
73 }
74
75 out
76}
77
78pub fn decode_cmds_text(input: &str) -> Result<Vec<Qail>, String> {
80 let bytes = input.as_bytes();
81 let mut idx = 0usize;
82
83 let magic = read_line(bytes, &mut idx)?;
84 if magic != CMDS_TEXT_MAGIC {
85 return Err(format!(
86 "invalid wire magic: expected {CMDS_TEXT_MAGIC}, got {magic}"
87 ));
88 }
89
90 let count_line = read_line(bytes, &mut idx)?;
91 let count = parse_usize("command count", count_line)?;
92 if count > MAX_CMD_TEXT_BATCH_COMMANDS {
93 return Err(format!(
94 "command count exceeds limit: {count} > {MAX_CMD_TEXT_BATCH_COMMANDS}"
95 ));
96 }
97 let mut out = Vec::with_capacity(count);
98
99 for _ in 0..count {
100 let len_line = read_line(bytes, &mut idx)?;
101 let payload_len = parse_usize("payload length", len_line)?;
102 let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
103 let cmd = crate::parse(payload).map_err(|e| e.to_string())?;
104 out.push(cmd);
105 }
106
107 if idx != bytes.len() {
108 return Err("trailing bytes after batch payload".to_string());
109 }
110
111 Ok(out)
112}
113
114pub fn encode_cmd_binary(cmd: &Qail) -> Result<Vec<u8>, String> {
116 let payload = serde_json::to_vec(cmd).map_err(|e| format!("binary AST encode failed: {e}"))?;
117 if payload.len() > MAX_CMD_BINARY_PAYLOAD_BYTES {
118 return Err(format!(
119 "binary AST payload too large: {} bytes (max {})",
120 payload.len(),
121 MAX_CMD_BINARY_PAYLOAD_BYTES
122 ));
123 }
124
125 let payload_len = u32::try_from(payload.len())
126 .map_err(|_| format!("binary AST payload exceeds u32 length: {}", payload.len()))?;
127 let mut out = Vec::with_capacity(8 + payload.len());
128 out.extend_from_slice(&CMD_BIN_MAGIC);
129 out.extend_from_slice(&payload_len.to_be_bytes());
130 out.extend_from_slice(&payload);
131 Ok(out)
132}
133
134pub fn decode_cmd_binary(input: &[u8]) -> Result<Qail, String> {
138 let payload = decode_cmd_binary_payload(input)?;
139 let mut deserializer = serde_json::Deserializer::from_slice(payload);
140 let cmd = serde::Deserialize::deserialize(&mut deserializer)
141 .map_err(|e| format!("binary AST decode failed: {e}"))?;
142 deserializer
143 .end()
144 .map_err(|_| "trailing bytes after AST payload".to_string())?;
145 validate_binary_ast_limits(&cmd)?;
146 Ok(cmd)
147}
148
149pub fn decode_cmd_binary_payload(input: &[u8]) -> Result<&[u8], String> {
153 if input.len() < 8 {
154 return Err("invalid wire header".to_string());
155 }
156 if input[0..4] != CMD_BIN_MAGIC {
157 if input[0..4] == CMD_BIN_LEGACY_MAGIC {
158 return Err(
159 "legacy QWB1 text payload is not supported on parse-free binary path".to_string(),
160 );
161 }
162 return Err("invalid wire header".to_string());
163 }
164
165 let len = u32::from_be_bytes([input[4], input[5], input[6], input[7]]) as usize;
166 if len > MAX_CMD_BINARY_PAYLOAD_BYTES {
167 return Err(format!(
168 "binary AST payload too large: header={len}, max={MAX_CMD_BINARY_PAYLOAD_BYTES}"
169 ));
170 }
171 if input.len() != 8 + len {
172 return Err(format!(
173 "invalid payload length: header={len}, actual={}",
174 input.len().saturating_sub(8)
175 ));
176 }
177 Ok(&input[8..])
178}
179
180#[derive(Default)]
181struct AstLimitState {
182 nodes: usize,
183}
184
185impl AstLimitState {
186 fn bump(&mut self, kind: &str) -> Result<(), String> {
187 self.nodes = self
188 .nodes
189 .checked_add(1)
190 .ok_or_else(|| "AST node counter overflow".to_string())?;
191 if self.nodes > MAX_AST_NODES {
192 return Err(format!(
193 "AST node limit exceeded while walking {kind}: {} > {}",
194 self.nodes, MAX_AST_NODES
195 ));
196 }
197 Ok(())
198 }
199}
200
201fn ensure_depth(depth: usize, kind: &str) -> Result<(), String> {
202 if depth > MAX_AST_DEPTH {
203 return Err(format!(
204 "AST depth limit exceeded while walking {kind}: {depth} > {MAX_AST_DEPTH}"
205 ));
206 }
207 Ok(())
208}
209
/// Fails when `len` is greater than `max`; `kind` labels the error message.
fn ensure_len(kind: &str, len: usize, max: usize) -> Result<(), String> {
    (len <= max)
        .then_some(())
        .ok_or_else(|| format!("{kind} exceeds limit: {len} > {max}"))
}
216
217fn ensure_str(kind: &str, value: &str) -> Result<(), String> {
218 ensure_len(kind, value.len(), MAX_AST_STRING_LEN)
219}
220
221fn validate_binary_ast_limits(cmd: &Qail) -> Result<(), String> {
222 let mut state = AstLimitState::default();
223 validate_qail_limits(cmd, 0, &mut state)
224}
225
/// Recursively checks one `Qail` command against the AST size limits.
///
/// `depth` is the nesting level of this command; every child expression, join,
/// cage, condition and nested query is visited with `depth + 1`. `state` holds
/// the node counter shared by the whole walk.
///
/// Checks run in the order written and stop at the first violation, so the
/// returned message describes the first limit that was exceeded.
///
/// # Errors
/// Returns a message for an exceeded depth, node count, collection length,
/// string length or vector length, or for a non-finite sample percentage.
fn validate_qail_limits(cmd: &Qail, depth: usize, state: &mut AstLimitState) -> Result<(), String> {
    use crate::ast::GroupByMode;

    // Account for this command itself before looking at its contents.
    ensure_depth(depth, "Qail")?;
    state.bump("Qail")?;

    ensure_str("qail.table", &cmd.table)?;
    ensure_len("qail.columns", cmd.columns.len(), MAX_AST_COLLECTION_LEN)?;
    for expr in &cmd.columns {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    ensure_len("qail.joins", cmd.joins.len(), MAX_AST_COLLECTION_LEN)?;
    for join in &cmd.joins {
        validate_join_limits(join, depth + 1, state)?;
    }

    ensure_len("qail.cages", cmd.cages.len(), MAX_AST_COLLECTION_LEN)?;
    for cage in &cmd.cages {
        validate_cage_limits(cage, depth + 1, state)?;
    }

    // The index definition helper takes no depth/state: it only checks
    // string and list lengths.
    if let Some(index_def) = &cmd.index_def {
        validate_index_def_limits(index_def)?;
    }

    // Table constraints hold only names and column lists, so they are checked
    // inline without touching the depth or node counters.
    ensure_len(
        "qail.table_constraints",
        cmd.table_constraints.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for constraint in &cmd.table_constraints {
        match constraint {
            crate::ast::TableConstraint::Unique(cols)
            | crate::ast::TableConstraint::PrimaryKey(cols) => {
                ensure_len(
                    "qail.table_constraint.columns",
                    cols.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in cols {
                    ensure_str("qail.table_constraint.column", col)?;
                }
            }
            crate::ast::TableConstraint::ForeignKey {
                name,
                columns,
                ref_table,
                ref_columns,
                on_delete,
                on_update,
                deferrable,
            } => {
                if let Some(name) = name {
                    ensure_str("qail.table_constraint.name", name)?;
                }
                ensure_len(
                    "qail.table_constraint.columns",
                    columns.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in columns {
                    ensure_str("qail.table_constraint.column", col)?;
                }
                ensure_str("qail.table_constraint.ref_table", ref_table)?;
                ensure_len(
                    "qail.table_constraint.ref_columns",
                    ref_columns.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in ref_columns {
                    ensure_str("qail.table_constraint.ref_column", col)?;
                }
                if let Some(action) = on_delete {
                    ensure_str("qail.table_constraint.on_delete", action)?;
                }
                if let Some(action) = on_update {
                    ensure_str("qail.table_constraint.on_update", action)?;
                }
                if let Some(clause) = deferrable {
                    ensure_str("qail.table_constraint.deferrable", clause)?;
                }
            }
        }
    }

    // Each set operation carries a full right-hand command; the operator
    // itself (first tuple element) is not inspected.
    ensure_len("qail.set_ops", cmd.set_ops.len(), MAX_AST_COLLECTION_LEN)?;
    for (_, rhs) in &cmd.set_ops {
        validate_qail_limits(rhs, depth + 1, state)?;
    }

    ensure_len("qail.having", cmd.having.len(), MAX_AST_COLLECTION_LEN)?;
    for cond in &cmd.having {
        validate_condition_limits(cond, depth + 1, state)?;
    }

    // Only the `GroupingSets` mode is inspected here.
    if let GroupByMode::GroupingSets(groups) = &cmd.group_by_mode {
        ensure_len("qail.grouping_sets", groups.len(), MAX_AST_COLLECTION_LEN)?;
        for group in groups {
            ensure_len("qail.grouping_set", group.len(), MAX_AST_COLLECTION_LEN)?;
            for col in group {
                ensure_str("qail.grouping_set.column", col)?;
            }
        }
    }

    // A CTE contributes its base query and, when present, a recursive query,
    // both walked as nested commands.
    ensure_len("qail.ctes", cmd.ctes.len(), MAX_AST_COLLECTION_LEN)?;
    for cte in &cmd.ctes {
        ensure_str("qail.cte.name", &cte.name)?;
        ensure_len(
            "qail.cte.columns",
            cte.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &cte.columns {
            ensure_str("qail.cte.column", col)?;
        }
        validate_qail_limits(&cte.base_query, depth + 1, state)?;
        if let Some(recursive) = &cte.recursive_query {
            validate_qail_limits(recursive, depth + 1, state)?;
        }
        if let Some(source_table) = &cte.source_table {
            ensure_str("qail.cte.source_table", source_table)?;
        }
    }

    ensure_len(
        "qail.distinct_on",
        cmd.distinct_on.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for expr in &cmd.distinct_on {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    if let Some(returning) = &cmd.returning {
        ensure_len("qail.returning", returning.len(), MAX_AST_COLLECTION_LEN)?;
        for expr in returning {
            validate_expr_limits(expr, depth + 1, state)?;
        }
    }

    if let Some(on_conflict) = &cmd.on_conflict {
        ensure_len(
            "qail.on_conflict.columns",
            on_conflict.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &on_conflict.columns {
            ensure_str("qail.on_conflict.column", col)?;
        }
        // Assignments are checked only when the action exposes them via
        // `update_assignments()`.
        if let Some(assignments) = on_conflict.action.update_assignments() {
            ensure_len(
                "qail.on_conflict.assignments",
                assignments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for (col, expr) in assignments {
                ensure_str("qail.on_conflict.assignment.column", col)?;
                validate_expr_limits(expr, depth + 1, state)?;
            }
        }
    }

    if let Some(merge) = &cmd.merge {
        if let Some(alias) = &merge.target_alias {
            ensure_str("qail.merge.target_alias", alias)?;
        }
        // The merge source is either a named table or a nested query.
        match &merge.source {
            crate::ast::MergeSource::Table { name, alias } => {
                ensure_str("qail.merge.source.table", name)?;
                if let Some(alias) = alias {
                    ensure_str("qail.merge.source.alias", alias)?;
                }
            }
            crate::ast::MergeSource::Query { query, alias } => {
                validate_qail_limits(query, depth + 1, state)?;
                if let Some(alias) = alias {
                    ensure_str("qail.merge.source.alias", alias)?;
                }
            }
        }
        ensure_len("qail.merge.on", merge.on.len(), MAX_AST_COLLECTION_LEN)?;
        for condition in &merge.on {
            validate_condition_limits(condition, depth + 1, state)?;
        }
        ensure_len(
            "qail.merge.clauses",
            merge.clauses.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for clause in &merge.clauses {
            ensure_len(
                "qail.merge.clause.condition",
                clause.condition.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for condition in &clause.condition {
                validate_condition_limits(condition, depth + 1, state)?;
            }
            match &clause.action {
                crate::ast::MergeAction::Update { assignments } => {
                    ensure_len(
                        "qail.merge.update.assignments",
                        assignments.len(),
                        MAX_AST_COLLECTION_LEN,
                    )?;
                    for (col, expr) in assignments {
                        ensure_str("qail.merge.update.column", col)?;
                        validate_expr_limits(expr, depth + 1, state)?;
                    }
                }
                crate::ast::MergeAction::Insert { columns, values } => {
                    ensure_len(
                        "qail.merge.insert.columns",
                        columns.len(),
                        MAX_AST_COLLECTION_LEN,
                    )?;
                    for col in columns {
                        ensure_str("qail.merge.insert.column", col)?;
                    }
                    ensure_len(
                        "qail.merge.insert.values",
                        values.len(),
                        MAX_AST_COLLECTION_LEN,
                    )?;
                    for expr in values {
                        validate_expr_limits(expr, depth + 1, state)?;
                    }
                }
                // These actions carry no data to check.
                crate::ast::MergeAction::Delete | crate::ast::MergeAction::DoNothing => {}
            }
        }
    }

    if let Some(source_query) = &cmd.source_query {
        validate_qail_limits(source_query, depth + 1, state)?;
    }

    if let Some(channel) = &cmd.channel {
        ensure_str("qail.channel", channel)?;
    }
    if let Some(payload) = &cmd.payload {
        ensure_str("qail.payload", payload)?;
    }
    if let Some(savepoint_name) = &cmd.savepoint_name {
        ensure_str("qail.savepoint_name", savepoint_name)?;
    }

    ensure_len(
        "qail.from_tables",
        cmd.from_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.from_tables {
        ensure_str("qail.from_table", table)?;
    }

    ensure_len(
        "qail.using_tables",
        cmd.using_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.using_tables {
        ensure_str("qail.using_table", table)?;
    }

    // Only the second tuple element (`percent`) is checked: NaN and
    // infinities are rejected.
    if let Some((_, percent, _seed)) = cmd.sample
        && !percent.is_finite()
    {
        return Err("qail.sample.percent must be finite".to_string());
    }

    // The vector is bounded by element count only; its values are not inspected.
    if let Some(vector) = &cmd.vector {
        ensure_len("qail.vector", vector.len(), MAX_AST_VECTOR_LEN)?;
    }
    if let Some(vector_name) = &cmd.vector_name {
        ensure_str("qail.vector_name", vector_name)?;
    }
    if let Some(function_def) = &cmd.function_def {
        validate_function_def_limits(function_def)?;
    }
    if let Some(trigger_def) = &cmd.trigger_def {
        validate_trigger_def_limits(trigger_def)?;
    }
    // Policies can contain expressions, so this helper takes depth and state.
    if let Some(policy_def) = &cmd.policy_def {
        validate_policy_def_limits(policy_def, depth + 1, state)?;
    }

    Ok(())
}
517
518fn validate_join_limits(
519 join: &crate::ast::Join,
520 depth: usize,
521 state: &mut AstLimitState,
522) -> Result<(), String> {
523 ensure_depth(depth, "Join")?;
524 state.bump("Join")?;
525 ensure_str("join.table", &join.table)?;
526 if let Some(on) = &join.on {
527 ensure_len("join.on", on.len(), MAX_AST_COLLECTION_LEN)?;
528 for cond in on {
529 validate_condition_limits(cond, depth + 1, state)?;
530 }
531 }
532 Ok(())
533}
534
535fn validate_cage_limits(
536 cage: &crate::ast::Cage,
537 depth: usize,
538 state: &mut AstLimitState,
539) -> Result<(), String> {
540 use crate::ast::CageKind;
541
542 ensure_depth(depth, "Cage")?;
543 state.bump("Cage")?;
544 ensure_len(
545 "cage.conditions",
546 cage.conditions.len(),
547 MAX_AST_COLLECTION_LEN,
548 )?;
549 for cond in &cage.conditions {
550 validate_condition_limits(cond, depth + 1, state)?;
551 }
552 match cage.kind {
553 CageKind::Limit(v) | CageKind::Offset(v) | CageKind::Sample(v) => {
554 ensure_len("cage.numeric", v, usize::MAX)?;
555 }
556 _ => {}
557 }
558 Ok(())
559}
560
561fn validate_condition_limits(
562 cond: &crate::ast::Condition,
563 depth: usize,
564 state: &mut AstLimitState,
565) -> Result<(), String> {
566 ensure_depth(depth, "Condition")?;
567 state.bump("Condition")?;
568 validate_expr_limits(&cond.left, depth + 1, state)?;
569 validate_value_limits(&cond.value, depth + 1, state)
570}
571
/// Recursively checks one expression against the AST size limits.
///
/// The expression is counted as a node at `depth`; nested expressions,
/// conditions, cages, values and subqueries are visited with `depth + 1`.
/// Strings are bounded with `ensure_str` and collections with `ensure_len`.
///
/// # Errors
/// Returns the message of the first limit that is exceeded.
fn validate_expr_limits(
    expr: &crate::ast::Expr,
    depth: usize,
    state: &mut AstLimitState,
) -> Result<(), String> {
    use crate::ast::{ColumnGeneration, Constraint, Expr, WindowFrame};

    ensure_depth(depth, "Expr")?;
    state.bump("Expr")?;

    match expr {
        Expr::Star => {}
        Expr::Named(name) => ensure_str("expr.named", name)?,
        Expr::Aliased { name, alias } => {
            ensure_str("expr.aliased.name", name)?;
            ensure_str("expr.aliased.alias", alias)?;
        }
        // Fields hidden behind `..` are not inspected.
        Expr::Aggregate {
            col, filter, alias, ..
        } => {
            ensure_str("expr.aggregate.col", col)?;
            if let Some(filters) = filter {
                ensure_len(
                    "expr.aggregate.filter",
                    filters.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for cond in filters {
                    validate_condition_limits(cond, depth + 1, state)?;
                }
            }
            if let Some(alias) = alias {
                ensure_str("expr.aggregate.alias", alias)?;
            }
        }
        Expr::Cast {
            expr,
            target_type,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.cast.target_type", target_type)?;
            if let Some(alias) = alias {
                ensure_str("expr.cast.alias", alias)?;
            }
        }
        // Column definition: only names, types and constraint strings.
        Expr::Def {
            name,
            data_type,
            constraints,
        } => {
            ensure_str("expr.def.name", name)?;
            ensure_str("expr.def.data_type", data_type)?;
            ensure_len(
                "expr.def.constraints",
                constraints.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for constraint in constraints {
                match constraint {
                    Constraint::PrimaryKey | Constraint::Unique | Constraint::Nullable => {}
                    Constraint::Default(v) => ensure_str("expr.def.default", v)?,
                    Constraint::Check(values) => {
                        ensure_len("expr.def.check", values.len(), MAX_AST_COLLECTION_LEN)?;
                        for value in values {
                            ensure_str("expr.def.check.value", value)?;
                        }
                    }
                    Constraint::Comment(v) | Constraint::References(v) => {
                        ensure_str("expr.def.constraint", v)?;
                    }
                    Constraint::Generated(ColumnGeneration::Stored(v))
                    | Constraint::Generated(ColumnGeneration::Virtual(v)) => {
                        ensure_str("expr.def.generated", v)?;
                    }
                }
            }
        }
        Expr::Mod { col, .. } => validate_expr_limits(col, depth + 1, state)?,
        Expr::Window {
            name,
            func,
            params,
            partition,
            order,
            frame,
        } => {
            ensure_str("expr.window.name", name)?;
            ensure_str("expr.window.func", func)?;
            ensure_len("expr.window.params", params.len(), MAX_AST_COLLECTION_LEN)?;
            for param in params {
                validate_expr_limits(param, depth + 1, state)?;
            }
            ensure_len(
                "expr.window.partition",
                partition.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for col in partition {
                ensure_str("expr.window.partition.column", col)?;
            }
            ensure_len("expr.window.order", order.len(), MAX_AST_COLLECTION_LEN)?;
            for cage in order {
                validate_cage_limits(cage, depth + 1, state)?;
            }
            // Frames are matched exhaustively but nothing inside them is
            // checked; the match has no effect at runtime.
            if let Some(frame) = frame {
                match frame {
                    WindowFrame::Rows { .. } | WindowFrame::Range { .. } => {}
                }
            }
        }
        Expr::Case {
            when_clauses,
            else_value,
            alias,
        } => {
            ensure_len("expr.case.when", when_clauses.len(), MAX_AST_COLLECTION_LEN)?;
            for (cond, then_expr) in when_clauses {
                validate_condition_limits(cond, depth + 1, state)?;
                validate_expr_limits(then_expr, depth + 1, state)?;
            }
            if let Some(else_expr) = else_value {
                validate_expr_limits(else_expr, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.case.alias", alias)?;
            }
        }
        Expr::JsonAccess {
            column,
            path_segments,
            alias,
        } => {
            ensure_str("expr.json_access.column", column)?;
            ensure_len(
                "expr.json_access.path_segments",
                path_segments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            // Only the first element of each path segment is length-checked.
            for (segment, _) in path_segments {
                ensure_str("expr.json_access.segment", segment)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.json_access.alias", alias)?;
            }
        }
        Expr::FunctionCall { name, args, alias } => {
            ensure_str("expr.function_call.name", name)?;
            ensure_len(
                "expr.function_call.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for arg in args {
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.function_call.alias", alias)?;
            }
        }
        // Each argument is an optional keyword paired with an expression.
        Expr::SpecialFunction { name, args, alias } => {
            ensure_str("expr.special_function.name", name)?;
            ensure_len(
                "expr.special_function.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for (keyword, arg) in args {
                if let Some(keyword) = keyword {
                    ensure_str("expr.special_function.keyword", keyword)?;
                }
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.special_function.alias", alias)?;
            }
        }
        Expr::Binary {
            left, right, alias, ..
        } => {
            validate_expr_limits(left, depth + 1, state)?;
            validate_expr_limits(right, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.binary.alias", alias)?;
            }
        }
        Expr::Literal(v) => validate_value_limits(v, depth + 1, state)?,
        // Array and row constructors share the same shape and error labels.
        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
            ensure_len("expr.elements", elements.len(), MAX_AST_COLLECTION_LEN)?;
            for el in elements {
                validate_expr_limits(el, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.elements.alias", alias)?;
            }
        }
        Expr::Subscript { expr, index, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            validate_expr_limits(index, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subscript.alias", alias)?;
            }
        }
        Expr::Collate {
            expr,
            collation,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.collate.collation", collation)?;
            if let Some(alias) = alias {
                ensure_str("expr.collate.alias", alias)?;
            }
        }
        Expr::FieldAccess { expr, field, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.field_access.field", field)?;
            if let Some(alias) = alias {
                ensure_str("expr.field_access.alias", alias)?;
            }
        }
        // Subqueries re-enter the command validator one level deeper.
        Expr::Subquery { query, alias } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subquery.alias", alias)?;
            }
        }
        Expr::Exists { query, alias, .. } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.exists.alias", alias)?;
            }
        }
    }

    Ok(())
}
809
810fn validate_value_limits(
811 value: &crate::ast::Value,
812 depth: usize,
813 state: &mut AstLimitState,
814) -> Result<(), String> {
815 use crate::ast::Value;
816
817 ensure_depth(depth, "Value")?;
818 state.bump("Value")?;
819
820 match value {
821 Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::Param(_) => {}
822 Value::String(v)
823 | Value::NamedParam(v)
824 | Value::Function(v)
825 | Value::Column(v)
826 | Value::Timestamp(v)
827 | Value::Json(v) => ensure_str("value.string", v)?,
828 Value::Array(values) => {
829 ensure_len("value.array", values.len(), MAX_AST_COLLECTION_LEN)?;
830 for v in values {
831 validate_value_limits(v, depth + 1, state)?;
832 }
833 }
834 Value::Subquery(q) => validate_qail_limits(q, depth + 1, state)?,
835 Value::Uuid(_) | Value::NullUuid | Value::Interval { .. } => {}
836 Value::Bytes(bytes) => ensure_len("value.bytes", bytes.len(), MAX_AST_BINARY_VALUE_LEN)?,
837 Value::Expr(expr) => validate_expr_limits(expr, depth + 1, state)?,
838 Value::Vector(values) => ensure_len("value.vector", values.len(), MAX_AST_VECTOR_LEN)?,
839 }
840
841 Ok(())
842}
843
844fn validate_index_def_limits(index_def: &crate::ast::IndexDef) -> Result<(), String> {
845 ensure_str("index_def.name", &index_def.name)?;
846 ensure_str("index_def.table", &index_def.table)?;
847 ensure_len(
848 "index_def.columns",
849 index_def.columns.len(),
850 MAX_AST_COLLECTION_LEN,
851 )?;
852 for col in &index_def.columns {
853 ensure_str("index_def.column", col)?;
854 }
855 if let Some(index_type) = &index_def.index_type {
856 ensure_str("index_def.index_type", index_type)?;
857 }
858 if let Some(where_clause) = &index_def.where_clause {
859 ensure_str("index_def.where_clause", where_clause)?;
860 }
861 Ok(())
862}
863
864fn validate_function_def_limits(function_def: &crate::ast::FunctionDef) -> Result<(), String> {
865 ensure_str("function_def.name", &function_def.name)?;
866 ensure_len(
867 "function_def.args",
868 function_def.args.len(),
869 MAX_AST_COLLECTION_LEN,
870 )?;
871 for arg in &function_def.args {
872 ensure_str("function_def.arg", arg)?;
873 }
874 ensure_str("function_def.returns", &function_def.returns)?;
875 ensure_str("function_def.body", &function_def.body)?;
876 if let Some(language) = &function_def.language {
877 ensure_str("function_def.language", language)?;
878 }
879 if let Some(volatility) = &function_def.volatility {
880 ensure_str("function_def.volatility", volatility)?;
881 }
882 Ok(())
883}
884
885fn validate_trigger_def_limits(trigger_def: &crate::ast::TriggerDef) -> Result<(), String> {
886 ensure_str("trigger_def.name", &trigger_def.name)?;
887 ensure_str("trigger_def.table", &trigger_def.table)?;
888 ensure_len(
889 "trigger_def.events",
890 trigger_def.events.len(),
891 MAX_AST_COLLECTION_LEN,
892 )?;
893 ensure_len(
894 "trigger_def.update_columns",
895 trigger_def.update_columns.len(),
896 MAX_AST_COLLECTION_LEN,
897 )?;
898 for col in &trigger_def.update_columns {
899 ensure_str("trigger_def.update_column", col)?;
900 }
901 ensure_str(
902 "trigger_def.execute_function",
903 &trigger_def.execute_function,
904 )?;
905 Ok(())
906}
907
908fn validate_policy_def_limits(
909 policy_def: &crate::migrate::policy::RlsPolicy,
910 depth: usize,
911 state: &mut AstLimitState,
912) -> Result<(), String> {
913 ensure_str("policy_def.name", &policy_def.name)?;
914 ensure_str("policy_def.table", &policy_def.table)?;
915 if let Some(using_expr) = &policy_def.using {
916 validate_expr_limits(using_expr, depth + 1, state)?;
917 }
918 if let Some(with_check_expr) = &policy_def.with_check {
919 validate_expr_limits(with_check_expr, depth + 1, state)?;
920 }
921 if let Some(role) = &policy_def.role {
922 ensure_str("policy_def.role", role)?;
923 }
924 Ok(())
925}
926
/// Reads one `\n`-terminated line starting at `*idx` and returns it without
/// the terminator.
///
/// On success `*idx` is moved just past the newline. On failure the cursor is
/// left where the scan stopped: unchanged at end of input, at `bytes.len()`
/// when no newline is found, and on the newline itself when the line is not
/// valid UTF-8.
///
/// # Errors
/// `"unexpected EOF"` when the cursor is already at or past the end,
/// `"unterminated header line"` when no newline follows, and
/// `"header is not UTF-8"` when the line bytes are not valid UTF-8.
fn read_line<'a>(bytes: &'a [u8], idx: &mut usize) -> Result<&'a str, String> {
    let start = *idx;
    if start >= bytes.len() {
        return Err("unexpected EOF".to_string());
    }

    let Some(offset) = bytes[start..].iter().position(|&b| b == b'\n') else {
        // The scan ran off the end of the input.
        *idx = bytes.len();
        return Err("unterminated header line".to_string());
    };

    let newline = start + offset;
    *idx = newline;
    let line = match std::str::from_utf8(&bytes[start..newline]) {
        Ok(text) => text,
        Err(_) => return Err("header is not UTF-8".to_string()),
    };

    // Step over the terminator only once the line is known to be valid.
    *idx = newline + 1;
    Ok(line)
}
946
/// Parses a header line as a `usize`.
///
/// `field` names the header in the error message, which also echoes the
/// offending line.
fn parse_usize(field: &str, line: &str) -> Result<usize, String> {
    match line.parse::<usize>() {
        Ok(value) => Ok(value),
        Err(_) => Err(format!("invalid {field}: {line}")),
    }
}
951
/// Reads exactly `len` bytes starting at `*idx` and returns them as a `&str`.
///
/// The cursor advances past the bytes as soon as the range is known to fit,
/// so it has already moved when the UTF-8 check fails. It is left unchanged
/// when the range overflows or runs past the end of the input.
///
/// # Errors
/// `"payload length overflow"` when `*idx + len` overflows,
/// `"payload truncated"` when the range extends past `bytes`, and
/// `"payload is not UTF-8"` when the bytes are not valid UTF-8.
fn read_exact_utf8<'a>(bytes: &'a [u8], idx: &mut usize, len: usize) -> Result<&'a str, String> {
    let start = *idx;
    let Some(end) = start.checked_add(len) else {
        return Err("payload length overflow".to_string());
    };
    let Some(chunk) = bytes.get(start..end) else {
        return Err("payload truncated".to_string());
    };

    *idx = end;
    match std::str::from_utf8(chunk) {
        Ok(text) => Ok(text),
        Err(_) => Err("payload is not UTF-8".to_string()),
    }
}
963
// Unit and property tests for the text and binary wire codecs.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // A framed single command decodes back to a command with the same text form.
    #[test]
    fn cmd_text_roundtrip() {
        let cmd = crate::ast::Qail::get("users")
            .columns(["id", "email"])
            .where_eq("active", true)
            .limit(10);

        let encoded = encode_cmd_text(&cmd);
        let decoded = decode_cmd_text(&encoded).unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    // Binary encode followed by decode preserves the command's text form.
    #[test]
    fn cmd_binary_roundtrip() {
        let cmd = crate::ast::Qail::set("users")
            .set_value("active", true)
            .where_eq("id", 7);

        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let decoded = decode_cmd_binary(&encoded).unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    // A merge command survives the binary round trip with full AST equality,
    // not just equal text.
    #[test]
    fn cmd_binary_roundtrip_preserves_merge_ast() {
        let cmd = crate::ast::Qail::merge_into("users")
            .target_alias("u")
            .using_table_as("staging_users", "s")
            .merge_on_column("u.id", crate::ast::Operator::Eq, "s.id")
            .when_matched_update(&[("name", crate::ast::Expr::Named("s.name".to_string()))])
            .when_not_matched_insert(
                &["id", "name"],
                &[
                    crate::ast::Expr::Named("s.id".to_string()),
                    crate::ast::Expr::Named("s.name".to_string()),
                ],
            );

        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let decoded = decode_cmd_binary(&encoded).unwrap();
        assert_eq!(decoded, cmd);
    }

    // A header declaring one byte more than the maximum is rejected, even
    // though no payload bytes follow.
    #[test]
    fn cmd_binary_payload_rejects_oversized_header() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&CMD_BIN_MAGIC);
        payload.extend_from_slice(&((MAX_CMD_BINARY_PAYLOAD_BYTES + 1) as u32).to_be_bytes());
        payload.extend_from_slice(&[]);

        let err = decode_cmd_binary_payload(&payload).unwrap_err();
        assert!(err.contains("binary AST payload too large"));
    }

    // The payload slice returned by the header check deserializes back to the
    // original command.
    #[test]
    fn cmd_binary_payload_roundtrip() {
        let cmd = crate::ast::Qail::get("users").limit(3);
        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let payload = decode_cmd_binary_payload(&encoded).unwrap();
        let mut deserializer = serde_json::Deserializer::from_slice(payload);
        let decoded: crate::ast::Qail = serde::Deserialize::deserialize(&mut deserializer).unwrap();
        deserializer.end().unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    // A frame using the legacy `QWB1` magic gets the dedicated legacy error.
    #[test]
    fn cmd_binary_payload_rejects_legacy_qwb1() {
        let legacy_text = b"get users limit 1";
        let mut payload = Vec::new();
        payload.extend_from_slice(&CMD_BIN_LEGACY_MAGIC);
        payload.extend_from_slice(&(legacy_text.len() as u32).to_be_bytes());
        payload.extend_from_slice(legacy_text);

        let err = decode_cmd_binary_payload(&payload).unwrap_err();
        assert!(err.contains("legacy QWB1"));
    }

    // Unlike the text decoder, the binary decoder has no raw-text fallback.
    #[test]
    fn cmd_binary_decode_rejects_raw_text_without_qwb2_header() {
        let err = decode_cmd_binary(b"get users limit 1").unwrap_err();
        assert!(err.contains("invalid wire header"));
    }

    // Bytes appended after a valid frame break the header length check.
    #[test]
    fn cmd_binary_decode_rejects_trailing_bytes() {
        let cmd = crate::ast::Qail::get("users").limit(1);
        let mut encoded = encode_cmd_binary(&cmd).expect("binary encode");
        encoded.extend_from_slice(&[0xAA, 0xBB]);
        let err = decode_cmd_binary(&encoded).unwrap_err();
        assert!(err.contains("invalid payload length"));
    }

    // Nests subqueries past `MAX_AST_DEPTH` and expects decoding to fail; any
    // of the three listed error messages is accepted.
    #[test]
    fn cmd_binary_decode_enforces_depth_limits() {
        let mut nested = crate::ast::Qail::get("users").limit(1);
        for _ in 0..(MAX_AST_DEPTH + 2) {
            nested = crate::ast::Qail {
                action: crate::ast::Action::Get,
                table: "users".to_string(),
                columns: vec![crate::ast::Expr::Subquery {
                    query: Box::new(nested),
                    alias: None,
                }],
                ..crate::ast::Qail::default()
            };
        }

        let encoded = encode_cmd_binary(&nested).expect("binary encode");
        let err = decode_cmd_binary(&encoded).unwrap_err();
        assert!(
            err.contains("AST depth limit exceeded")
                || err.contains("binary AST decode failed")
                || err.contains("recursion limit exceeded")
        );
    }

    // Flips every bit in the first 128 bytes of each seed and decodes the
    // result; the outcome is ignored, so the test only fails on a panic.
    #[test]
    fn cmd_binary_decode_bitflip_corpus_no_panic() {
        let seeds = vec![
            encode_cmd_binary(&crate::ast::Qail::get("users").limit(1)).expect("binary encode"),
            encode_cmd_binary(&crate::ast::Qail::set("users").set_value("active", true))
                .expect("binary encode"),
            vec![],
            b"QWB2garbage".to_vec(),
            vec![0u8; 32],
        ];

        for seed in seeds {
            for i in 0..seed.len().min(128) {
                for bit in 0..8u8 {
                    let mut mutated = seed.clone();
                    mutated[i] ^= 1 << bit;
                    let _ = decode_cmd_binary(&mutated);
                }
            }
            let _ = decode_cmd_binary(&seed);
        }
    }

    // Property test: arbitrary byte vectors of length 0..4096 are decoded and
    // the result discarded, so only a panic fails the test.
    proptest! {
        #[test]
        fn cmd_binary_decode_fuzz_never_panics(data in proptest::collection::vec(any::<u8>(), 0..4096)) {
            let _ = decode_cmd_binary(&data);
        }
    }

    // A batch frame decodes to the same number of commands with matching text forms.
    #[test]
    fn cmds_text_roundtrip() {
        let cmds = vec![
            crate::ast::Qail::get("users").columns(["id", "email"]),
            crate::ast::Qail::get("users").limit(1),
            crate::ast::Qail::del("users").where_eq("id", 99),
        ];

        let encoded = encode_cmds_text(&cmds);
        let decoded = decode_cmds_text(&encoded).unwrap();
        assert_eq!(decoded.len(), cmds.len());
        for (lhs, rhs) in decoded.iter().zip(cmds.iter()) {
            assert_eq!(lhs.to_string(), rhs.to_string());
        }
    }

    // Input without the framing magic is parsed directly as command text.
    #[test]
    fn decode_cmd_text_falls_back_to_raw_qail() {
        let decoded = decode_cmd_text("get users limit 1").unwrap();
        assert_eq!(decoded.action, crate::ast::Action::Get);
        assert_eq!(decoded.table, "users");
        assert!(
            decoded
                .cages
                .iter()
                .any(|c| matches!(c.kind, crate::ast::CageKind::Limit(1)))
        );
    }

    // A declared payload length of `usize::MAX` must produce an error, not a
    // panic; either the overflow or the truncation message is accepted.
    #[test]
    fn cmd_text_rejects_overflowing_payload_length_without_panic() {
        let input = format!("{CMD_TEXT_MAGIC}\n{}\n", usize::MAX);

        let err = decode_cmd_text(&input).expect_err("oversized length should fail closed");
        assert!(err.contains("payload length overflow") || err.contains("payload truncated"));
    }

    // A batch header declaring one command more than the limit is rejected.
    #[test]
    fn cmds_text_rejects_oversized_command_count_without_allocation() {
        let input = format!("{CMDS_TEXT_MAGIC}\n{}\n", MAX_CMD_TEXT_BATCH_COMMANDS + 1);

        let err = decode_cmds_text(&input).expect_err("oversized command count should fail closed");
        assert!(err.contains("command count exceeds limit"));
    }
}