1use crate::ast::Qail;
7
/// First line of a single-command text frame (see [`encode_cmd_text`]).
const CMD_TEXT_MAGIC: &str = "QAIL-CMD/1";
/// First line of a multi-command text frame (see [`encode_cmds_text`]).
const CMDS_TEXT_MAGIC: &str = "QAIL-CMDS/1";
/// Four-byte magic that opens a binary (JSON-encoded AST) frame.
const CMD_BIN_MAGIC: [u8; 4] = [b'Q', b'W', b'B', b'2'];
/// Legacy `QWB1` magic; [`decode_cmd_binary_payload`] recognises it only to reject it with a
/// dedicated error message.
const CMD_BIN_LEGACY_MAGIC: [u8; 4] = [b'Q', b'W', b'B', b'1'];

/// Upper bound, in bytes, on a binary frame's payload (everything after the 8-byte header).
/// Enforced both when encoding and when decoding.
pub const MAX_CMD_BINARY_PAYLOAD_BYTES: usize = 64 << 10;

// Limits applied while walking ASTs decoded from binary frames.
/// Deepest nesting level the limit walk accepts.
const MAX_AST_DEPTH: usize = 64;
/// Maximum number of nodes charged during one limit walk.
const MAX_AST_NODES: usize = 16 << 10;
/// Maximum element count of any single AST collection.
const MAX_AST_COLLECTION_LEN: usize = 2 << 10;
/// Maximum number of commands in one text batch; shares the AST collection limit.
const MAX_CMD_TEXT_BATCH_COMMANDS: usize = MAX_AST_COLLECTION_LEN;
/// Maximum byte length of any string stored in the AST.
const MAX_AST_STRING_LEN: usize = 32 << 10;
/// Maximum element count of vector-valued fields.
const MAX_AST_VECTOR_LEN: usize = 8 << 10;
/// Maximum byte length of a binary (`bytes`) value.
const MAX_AST_BINARY_VALUE_LEN: usize = 32 << 10;
22
23pub fn encode_cmd_text(cmd: &Qail) -> String {
25 let payload = cmd.to_string();
26 let mut out = String::with_capacity(CMD_TEXT_MAGIC.len() + payload.len() + 32);
27 out.push_str(CMD_TEXT_MAGIC);
28 out.push('\n');
29 out.push_str(&payload.len().to_string());
30 out.push('\n');
31 out.push_str(&payload);
32 out
33}
34
35pub fn decode_cmd_text(input: &str) -> Result<Qail, String> {
39 let bytes = input.as_bytes();
40 let mut idx = 0usize;
41
42 let Ok(magic) = read_line(bytes, &mut idx) else {
43 return crate::parse(input).map_err(|e| e.to_string());
44 };
45
46 if magic != CMD_TEXT_MAGIC {
47 return crate::parse(input).map_err(|e| e.to_string());
48 }
49
50 let len_line = read_line(bytes, &mut idx)?;
51 let payload_len = parse_usize("payload length", len_line)?;
52 let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
53 if idx != bytes.len() {
54 return Err("trailing bytes after command payload".to_string());
55 }
56
57 crate::parse(payload).map_err(|e| e.to_string())
58}
59
60pub fn encode_cmds_text(cmds: &[Qail]) -> String {
62 let mut out = String::new();
63 out.push_str(CMDS_TEXT_MAGIC);
64 out.push('\n');
65 out.push_str(&cmds.len().to_string());
66 out.push('\n');
67
68 for cmd in cmds {
69 let payload = cmd.to_string();
70 out.push_str(&payload.len().to_string());
71 out.push('\n');
72 out.push_str(&payload);
73 }
74
75 out
76}
77
78pub fn decode_cmds_text(input: &str) -> Result<Vec<Qail>, String> {
80 let bytes = input.as_bytes();
81 let mut idx = 0usize;
82
83 let magic = read_line(bytes, &mut idx)?;
84 if magic != CMDS_TEXT_MAGIC {
85 return Err(format!(
86 "invalid wire magic: expected {CMDS_TEXT_MAGIC}, got {magic}"
87 ));
88 }
89
90 let count_line = read_line(bytes, &mut idx)?;
91 let count = parse_usize("command count", count_line)?;
92 if count > MAX_CMD_TEXT_BATCH_COMMANDS {
93 return Err(format!(
94 "command count exceeds limit: {count} > {MAX_CMD_TEXT_BATCH_COMMANDS}"
95 ));
96 }
97 let mut out = Vec::with_capacity(count);
98
99 for _ in 0..count {
100 let len_line = read_line(bytes, &mut idx)?;
101 let payload_len = parse_usize("payload length", len_line)?;
102 let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
103 let cmd = crate::parse(payload).map_err(|e| e.to_string())?;
104 out.push(cmd);
105 }
106
107 if idx != bytes.len() {
108 return Err("trailing bytes after batch payload".to_string());
109 }
110
111 Ok(out)
112}
113
114pub fn encode_cmd_binary(cmd: &Qail) -> Result<Vec<u8>, String> {
116 let payload = serde_json::to_vec(cmd).map_err(|e| format!("binary AST encode failed: {e}"))?;
117 if payload.len() > MAX_CMD_BINARY_PAYLOAD_BYTES {
118 return Err(format!(
119 "binary AST payload too large: {} bytes (max {})",
120 payload.len(),
121 MAX_CMD_BINARY_PAYLOAD_BYTES
122 ));
123 }
124
125 let payload_len = u32::try_from(payload.len())
126 .map_err(|_| format!("binary AST payload exceeds u32 length: {}", payload.len()))?;
127 let mut out = Vec::with_capacity(8 + payload.len());
128 out.extend_from_slice(&CMD_BIN_MAGIC);
129 out.extend_from_slice(&payload_len.to_be_bytes());
130 out.extend_from_slice(&payload);
131 Ok(out)
132}
133
134pub fn decode_cmd_binary(input: &[u8]) -> Result<Qail, String> {
138 let payload = decode_cmd_binary_payload(input)?;
139 let mut deserializer = serde_json::Deserializer::from_slice(payload);
140 let cmd = serde::Deserialize::deserialize(&mut deserializer)
141 .map_err(|e| format!("binary AST decode failed: {e}"))?;
142 deserializer
143 .end()
144 .map_err(|_| "trailing bytes after AST payload".to_string())?;
145 validate_binary_ast_limits(&cmd)?;
146 Ok(cmd)
147}
148
149pub fn decode_cmd_binary_payload(input: &[u8]) -> Result<&[u8], String> {
153 if input.len() < 8 {
154 return Err("invalid wire header".to_string());
155 }
156 if input[0..4] != CMD_BIN_MAGIC {
157 if input[0..4] == CMD_BIN_LEGACY_MAGIC {
158 return Err(
159 "legacy QWB1 text payload is not supported on parse-free binary path".to_string(),
160 );
161 }
162 return Err("invalid wire header".to_string());
163 }
164
165 let len = u32::from_be_bytes([input[4], input[5], input[6], input[7]]) as usize;
166 if len > MAX_CMD_BINARY_PAYLOAD_BYTES {
167 return Err(format!(
168 "binary AST payload too large: header={len}, max={MAX_CMD_BINARY_PAYLOAD_BYTES}"
169 ));
170 }
171 if input.len() != 8 + len {
172 return Err(format!(
173 "invalid payload length: header={len}, actual={}",
174 input.len().saturating_sub(8)
175 ));
176 }
177 Ok(&input[8..])
178}
179
180#[derive(Default)]
181struct AstLimitState {
182 nodes: usize,
183}
184
185impl AstLimitState {
186 fn bump(&mut self, kind: &str) -> Result<(), String> {
187 self.nodes = self
188 .nodes
189 .checked_add(1)
190 .ok_or_else(|| "AST node counter overflow".to_string())?;
191 if self.nodes > MAX_AST_NODES {
192 return Err(format!(
193 "AST node limit exceeded while walking {kind}: {} > {}",
194 self.nodes, MAX_AST_NODES
195 ));
196 }
197 Ok(())
198 }
199}
200
201fn ensure_depth(depth: usize, kind: &str) -> Result<(), String> {
202 if depth > MAX_AST_DEPTH {
203 return Err(format!(
204 "AST depth limit exceeded while walking {kind}: {depth} > {MAX_AST_DEPTH}"
205 ));
206 }
207 Ok(())
208}
209
/// Fails when `len` is strictly greater than `max`; `kind` labels the checked field in the
/// error message.
fn ensure_len(kind: &str, len: usize, max: usize) -> Result<(), String> {
    if len <= max {
        Ok(())
    } else {
        Err(format!("{kind} exceeds limit: {len} > {max}"))
    }
}
216
217fn ensure_str(kind: &str, value: &str) -> Result<(), String> {
218 ensure_len(kind, value.len(), MAX_AST_STRING_LEN)
219}
220
221fn validate_binary_ast_limits(cmd: &Qail) -> Result<(), String> {
222 let mut state = AstLimitState::default();
223 validate_qail_limits(cmd, 0, &mut state)
224}
225
226fn validate_qail_limits(cmd: &Qail, depth: usize, state: &mut AstLimitState) -> Result<(), String> {
227 use crate::ast::GroupByMode;
228
229 ensure_depth(depth, "Qail")?;
230 state.bump("Qail")?;
231
232 ensure_str("qail.table", &cmd.table)?;
233 ensure_len("qail.columns", cmd.columns.len(), MAX_AST_COLLECTION_LEN)?;
234 for expr in &cmd.columns {
235 validate_expr_limits(expr, depth + 1, state)?;
236 }
237
238 ensure_len("qail.joins", cmd.joins.len(), MAX_AST_COLLECTION_LEN)?;
239 for join in &cmd.joins {
240 validate_join_limits(join, depth + 1, state)?;
241 }
242
243 ensure_len("qail.cages", cmd.cages.len(), MAX_AST_COLLECTION_LEN)?;
244 for cage in &cmd.cages {
245 validate_cage_limits(cage, depth + 1, state)?;
246 }
247
248 if let Some(index_def) = &cmd.index_def {
249 validate_index_def_limits(index_def)?;
250 }
251
252 ensure_len(
253 "qail.table_constraints",
254 cmd.table_constraints.len(),
255 MAX_AST_COLLECTION_LEN,
256 )?;
257 for constraint in &cmd.table_constraints {
258 match constraint {
259 crate::ast::TableConstraint::Unique(cols)
260 | crate::ast::TableConstraint::PrimaryKey(cols) => {
261 ensure_len(
262 "qail.table_constraint.columns",
263 cols.len(),
264 MAX_AST_COLLECTION_LEN,
265 )?;
266 for col in cols {
267 ensure_str("qail.table_constraint.column", col)?;
268 }
269 }
270 crate::ast::TableConstraint::ForeignKey {
271 name,
272 columns,
273 ref_table,
274 ref_columns,
275 } => {
276 if let Some(name) = name {
277 ensure_str("qail.table_constraint.name", name)?;
278 }
279 ensure_len(
280 "qail.table_constraint.columns",
281 columns.len(),
282 MAX_AST_COLLECTION_LEN,
283 )?;
284 for col in columns {
285 ensure_str("qail.table_constraint.column", col)?;
286 }
287 ensure_str("qail.table_constraint.ref_table", ref_table)?;
288 ensure_len(
289 "qail.table_constraint.ref_columns",
290 ref_columns.len(),
291 MAX_AST_COLLECTION_LEN,
292 )?;
293 for col in ref_columns {
294 ensure_str("qail.table_constraint.ref_column", col)?;
295 }
296 }
297 }
298 }
299
300 ensure_len("qail.set_ops", cmd.set_ops.len(), MAX_AST_COLLECTION_LEN)?;
301 for (_, rhs) in &cmd.set_ops {
302 validate_qail_limits(rhs, depth + 1, state)?;
303 }
304
305 ensure_len("qail.having", cmd.having.len(), MAX_AST_COLLECTION_LEN)?;
306 for cond in &cmd.having {
307 validate_condition_limits(cond, depth + 1, state)?;
308 }
309
310 if let GroupByMode::GroupingSets(groups) = &cmd.group_by_mode {
311 ensure_len("qail.grouping_sets", groups.len(), MAX_AST_COLLECTION_LEN)?;
312 for group in groups {
313 ensure_len("qail.grouping_set", group.len(), MAX_AST_COLLECTION_LEN)?;
314 for col in group {
315 ensure_str("qail.grouping_set.column", col)?;
316 }
317 }
318 }
319
320 ensure_len("qail.ctes", cmd.ctes.len(), MAX_AST_COLLECTION_LEN)?;
321 for cte in &cmd.ctes {
322 ensure_str("qail.cte.name", &cte.name)?;
323 ensure_len(
324 "qail.cte.columns",
325 cte.columns.len(),
326 MAX_AST_COLLECTION_LEN,
327 )?;
328 for col in &cte.columns {
329 ensure_str("qail.cte.column", col)?;
330 }
331 validate_qail_limits(&cte.base_query, depth + 1, state)?;
332 if let Some(recursive) = &cte.recursive_query {
333 validate_qail_limits(recursive, depth + 1, state)?;
334 }
335 if let Some(source_table) = &cte.source_table {
336 ensure_str("qail.cte.source_table", source_table)?;
337 }
338 }
339
340 ensure_len(
341 "qail.distinct_on",
342 cmd.distinct_on.len(),
343 MAX_AST_COLLECTION_LEN,
344 )?;
345 for expr in &cmd.distinct_on {
346 validate_expr_limits(expr, depth + 1, state)?;
347 }
348
349 if let Some(returning) = &cmd.returning {
350 ensure_len("qail.returning", returning.len(), MAX_AST_COLLECTION_LEN)?;
351 for expr in returning {
352 validate_expr_limits(expr, depth + 1, state)?;
353 }
354 }
355
356 if let Some(on_conflict) = &cmd.on_conflict {
357 ensure_len(
358 "qail.on_conflict.columns",
359 on_conflict.columns.len(),
360 MAX_AST_COLLECTION_LEN,
361 )?;
362 for col in &on_conflict.columns {
363 ensure_str("qail.on_conflict.column", col)?;
364 }
365 if let Some(assignments) = on_conflict.action.update_assignments() {
366 ensure_len(
367 "qail.on_conflict.assignments",
368 assignments.len(),
369 MAX_AST_COLLECTION_LEN,
370 )?;
371 for (col, expr) in assignments {
372 ensure_str("qail.on_conflict.assignment.column", col)?;
373 validate_expr_limits(expr, depth + 1, state)?;
374 }
375 }
376 }
377
378 if let Some(merge) = &cmd.merge {
379 if let Some(alias) = &merge.target_alias {
380 ensure_str("qail.merge.target_alias", alias)?;
381 }
382 match &merge.source {
383 crate::ast::MergeSource::Table { name, alias } => {
384 ensure_str("qail.merge.source.table", name)?;
385 if let Some(alias) = alias {
386 ensure_str("qail.merge.source.alias", alias)?;
387 }
388 }
389 crate::ast::MergeSource::Query { query, alias } => {
390 validate_qail_limits(query, depth + 1, state)?;
391 if let Some(alias) = alias {
392 ensure_str("qail.merge.source.alias", alias)?;
393 }
394 }
395 }
396 ensure_len("qail.merge.on", merge.on.len(), MAX_AST_COLLECTION_LEN)?;
397 for condition in &merge.on {
398 validate_condition_limits(condition, depth + 1, state)?;
399 }
400 ensure_len(
401 "qail.merge.clauses",
402 merge.clauses.len(),
403 MAX_AST_COLLECTION_LEN,
404 )?;
405 for clause in &merge.clauses {
406 ensure_len(
407 "qail.merge.clause.condition",
408 clause.condition.len(),
409 MAX_AST_COLLECTION_LEN,
410 )?;
411 for condition in &clause.condition {
412 validate_condition_limits(condition, depth + 1, state)?;
413 }
414 match &clause.action {
415 crate::ast::MergeAction::Update { assignments } => {
416 ensure_len(
417 "qail.merge.update.assignments",
418 assignments.len(),
419 MAX_AST_COLLECTION_LEN,
420 )?;
421 for (col, expr) in assignments {
422 ensure_str("qail.merge.update.column", col)?;
423 validate_expr_limits(expr, depth + 1, state)?;
424 }
425 }
426 crate::ast::MergeAction::Insert { columns, values } => {
427 ensure_len(
428 "qail.merge.insert.columns",
429 columns.len(),
430 MAX_AST_COLLECTION_LEN,
431 )?;
432 for col in columns {
433 ensure_str("qail.merge.insert.column", col)?;
434 }
435 ensure_len(
436 "qail.merge.insert.values",
437 values.len(),
438 MAX_AST_COLLECTION_LEN,
439 )?;
440 for expr in values {
441 validate_expr_limits(expr, depth + 1, state)?;
442 }
443 }
444 crate::ast::MergeAction::Delete | crate::ast::MergeAction::DoNothing => {}
445 }
446 }
447 }
448
449 if let Some(source_query) = &cmd.source_query {
450 validate_qail_limits(source_query, depth + 1, state)?;
451 }
452
453 if let Some(channel) = &cmd.channel {
454 ensure_str("qail.channel", channel)?;
455 }
456 if let Some(payload) = &cmd.payload {
457 ensure_str("qail.payload", payload)?;
458 }
459 if let Some(savepoint_name) = &cmd.savepoint_name {
460 ensure_str("qail.savepoint_name", savepoint_name)?;
461 }
462
463 ensure_len(
464 "qail.from_tables",
465 cmd.from_tables.len(),
466 MAX_AST_COLLECTION_LEN,
467 )?;
468 for table in &cmd.from_tables {
469 ensure_str("qail.from_table", table)?;
470 }
471
472 ensure_len(
473 "qail.using_tables",
474 cmd.using_tables.len(),
475 MAX_AST_COLLECTION_LEN,
476 )?;
477 for table in &cmd.using_tables {
478 ensure_str("qail.using_table", table)?;
479 }
480
481 if let Some((_, percent, _seed)) = cmd.sample
482 && !percent.is_finite()
483 {
484 return Err("qail.sample.percent must be finite".to_string());
485 }
486
487 if let Some(vector) = &cmd.vector {
488 ensure_len("qail.vector", vector.len(), MAX_AST_VECTOR_LEN)?;
489 }
490 if let Some(vector_name) = &cmd.vector_name {
491 ensure_str("qail.vector_name", vector_name)?;
492 }
493 if let Some(function_def) = &cmd.function_def {
494 validate_function_def_limits(function_def)?;
495 }
496 if let Some(trigger_def) = &cmd.trigger_def {
497 validate_trigger_def_limits(trigger_def)?;
498 }
499 if let Some(policy_def) = &cmd.policy_def {
500 validate_policy_def_limits(policy_def, depth + 1, state)?;
501 }
502
503 Ok(())
504}
505
506fn validate_join_limits(
507 join: &crate::ast::Join,
508 depth: usize,
509 state: &mut AstLimitState,
510) -> Result<(), String> {
511 ensure_depth(depth, "Join")?;
512 state.bump("Join")?;
513 ensure_str("join.table", &join.table)?;
514 if let Some(on) = &join.on {
515 ensure_len("join.on", on.len(), MAX_AST_COLLECTION_LEN)?;
516 for cond in on {
517 validate_condition_limits(cond, depth + 1, state)?;
518 }
519 }
520 Ok(())
521}
522
523fn validate_cage_limits(
524 cage: &crate::ast::Cage,
525 depth: usize,
526 state: &mut AstLimitState,
527) -> Result<(), String> {
528 use crate::ast::CageKind;
529
530 ensure_depth(depth, "Cage")?;
531 state.bump("Cage")?;
532 ensure_len(
533 "cage.conditions",
534 cage.conditions.len(),
535 MAX_AST_COLLECTION_LEN,
536 )?;
537 for cond in &cage.conditions {
538 validate_condition_limits(cond, depth + 1, state)?;
539 }
540 match cage.kind {
541 CageKind::Limit(v) | CageKind::Offset(v) | CageKind::Sample(v) => {
542 ensure_len("cage.numeric", v, usize::MAX)?;
543 }
544 _ => {}
545 }
546 Ok(())
547}
548
549fn validate_condition_limits(
550 cond: &crate::ast::Condition,
551 depth: usize,
552 state: &mut AstLimitState,
553) -> Result<(), String> {
554 ensure_depth(depth, "Condition")?;
555 state.bump("Condition")?;
556 validate_expr_limits(&cond.left, depth + 1, state)?;
557 validate_value_limits(&cond.value, depth + 1, state)
558}
559
560fn validate_expr_limits(
561 expr: &crate::ast::Expr,
562 depth: usize,
563 state: &mut AstLimitState,
564) -> Result<(), String> {
565 use crate::ast::{ColumnGeneration, Constraint, Expr, WindowFrame};
566
567 ensure_depth(depth, "Expr")?;
568 state.bump("Expr")?;
569
570 match expr {
571 Expr::Star => {}
572 Expr::Named(name) => ensure_str("expr.named", name)?,
573 Expr::Aliased { name, alias } => {
574 ensure_str("expr.aliased.name", name)?;
575 ensure_str("expr.aliased.alias", alias)?;
576 }
577 Expr::Aggregate {
578 col, filter, alias, ..
579 } => {
580 ensure_str("expr.aggregate.col", col)?;
581 if let Some(filters) = filter {
582 ensure_len(
583 "expr.aggregate.filter",
584 filters.len(),
585 MAX_AST_COLLECTION_LEN,
586 )?;
587 for cond in filters {
588 validate_condition_limits(cond, depth + 1, state)?;
589 }
590 }
591 if let Some(alias) = alias {
592 ensure_str("expr.aggregate.alias", alias)?;
593 }
594 }
595 Expr::Cast {
596 expr,
597 target_type,
598 alias,
599 } => {
600 validate_expr_limits(expr, depth + 1, state)?;
601 ensure_str("expr.cast.target_type", target_type)?;
602 if let Some(alias) = alias {
603 ensure_str("expr.cast.alias", alias)?;
604 }
605 }
606 Expr::Def {
607 name,
608 data_type,
609 constraints,
610 } => {
611 ensure_str("expr.def.name", name)?;
612 ensure_str("expr.def.data_type", data_type)?;
613 ensure_len(
614 "expr.def.constraints",
615 constraints.len(),
616 MAX_AST_COLLECTION_LEN,
617 )?;
618 for constraint in constraints {
619 match constraint {
620 Constraint::PrimaryKey | Constraint::Unique | Constraint::Nullable => {}
621 Constraint::Default(v) => ensure_str("expr.def.default", v)?,
622 Constraint::Check(values) => {
623 ensure_len("expr.def.check", values.len(), MAX_AST_COLLECTION_LEN)?;
624 for value in values {
625 ensure_str("expr.def.check.value", value)?;
626 }
627 }
628 Constraint::Comment(v) | Constraint::References(v) => {
629 ensure_str("expr.def.constraint", v)?;
630 }
631 Constraint::Generated(ColumnGeneration::Stored(v))
632 | Constraint::Generated(ColumnGeneration::Virtual(v)) => {
633 ensure_str("expr.def.generated", v)?;
634 }
635 }
636 }
637 }
638 Expr::Mod { col, .. } => validate_expr_limits(col, depth + 1, state)?,
639 Expr::Window {
640 name,
641 func,
642 params,
643 partition,
644 order,
645 frame,
646 } => {
647 ensure_str("expr.window.name", name)?;
648 ensure_str("expr.window.func", func)?;
649 ensure_len("expr.window.params", params.len(), MAX_AST_COLLECTION_LEN)?;
650 for param in params {
651 validate_expr_limits(param, depth + 1, state)?;
652 }
653 ensure_len(
654 "expr.window.partition",
655 partition.len(),
656 MAX_AST_COLLECTION_LEN,
657 )?;
658 for col in partition {
659 ensure_str("expr.window.partition.column", col)?;
660 }
661 ensure_len("expr.window.order", order.len(), MAX_AST_COLLECTION_LEN)?;
662 for cage in order {
663 validate_cage_limits(cage, depth + 1, state)?;
664 }
665 if let Some(frame) = frame {
666 match frame {
667 WindowFrame::Rows { .. } | WindowFrame::Range { .. } => {}
668 }
669 }
670 }
671 Expr::Case {
672 when_clauses,
673 else_value,
674 alias,
675 } => {
676 ensure_len("expr.case.when", when_clauses.len(), MAX_AST_COLLECTION_LEN)?;
677 for (cond, then_expr) in when_clauses {
678 validate_condition_limits(cond, depth + 1, state)?;
679 validate_expr_limits(then_expr, depth + 1, state)?;
680 }
681 if let Some(else_expr) = else_value {
682 validate_expr_limits(else_expr, depth + 1, state)?;
683 }
684 if let Some(alias) = alias {
685 ensure_str("expr.case.alias", alias)?;
686 }
687 }
688 Expr::JsonAccess {
689 column,
690 path_segments,
691 alias,
692 } => {
693 ensure_str("expr.json_access.column", column)?;
694 ensure_len(
695 "expr.json_access.path_segments",
696 path_segments.len(),
697 MAX_AST_COLLECTION_LEN,
698 )?;
699 for (segment, _) in path_segments {
700 ensure_str("expr.json_access.segment", segment)?;
701 }
702 if let Some(alias) = alias {
703 ensure_str("expr.json_access.alias", alias)?;
704 }
705 }
706 Expr::FunctionCall { name, args, alias } => {
707 ensure_str("expr.function_call.name", name)?;
708 ensure_len(
709 "expr.function_call.args",
710 args.len(),
711 MAX_AST_COLLECTION_LEN,
712 )?;
713 for arg in args {
714 validate_expr_limits(arg, depth + 1, state)?;
715 }
716 if let Some(alias) = alias {
717 ensure_str("expr.function_call.alias", alias)?;
718 }
719 }
720 Expr::SpecialFunction { name, args, alias } => {
721 ensure_str("expr.special_function.name", name)?;
722 ensure_len(
723 "expr.special_function.args",
724 args.len(),
725 MAX_AST_COLLECTION_LEN,
726 )?;
727 for (keyword, arg) in args {
728 if let Some(keyword) = keyword {
729 ensure_str("expr.special_function.keyword", keyword)?;
730 }
731 validate_expr_limits(arg, depth + 1, state)?;
732 }
733 if let Some(alias) = alias {
734 ensure_str("expr.special_function.alias", alias)?;
735 }
736 }
737 Expr::Binary {
738 left, right, alias, ..
739 } => {
740 validate_expr_limits(left, depth + 1, state)?;
741 validate_expr_limits(right, depth + 1, state)?;
742 if let Some(alias) = alias {
743 ensure_str("expr.binary.alias", alias)?;
744 }
745 }
746 Expr::Literal(v) => validate_value_limits(v, depth + 1, state)?,
747 Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
748 ensure_len("expr.elements", elements.len(), MAX_AST_COLLECTION_LEN)?;
749 for el in elements {
750 validate_expr_limits(el, depth + 1, state)?;
751 }
752 if let Some(alias) = alias {
753 ensure_str("expr.elements.alias", alias)?;
754 }
755 }
756 Expr::Subscript { expr, index, alias } => {
757 validate_expr_limits(expr, depth + 1, state)?;
758 validate_expr_limits(index, depth + 1, state)?;
759 if let Some(alias) = alias {
760 ensure_str("expr.subscript.alias", alias)?;
761 }
762 }
763 Expr::Collate {
764 expr,
765 collation,
766 alias,
767 } => {
768 validate_expr_limits(expr, depth + 1, state)?;
769 ensure_str("expr.collate.collation", collation)?;
770 if let Some(alias) = alias {
771 ensure_str("expr.collate.alias", alias)?;
772 }
773 }
774 Expr::FieldAccess { expr, field, alias } => {
775 validate_expr_limits(expr, depth + 1, state)?;
776 ensure_str("expr.field_access.field", field)?;
777 if let Some(alias) = alias {
778 ensure_str("expr.field_access.alias", alias)?;
779 }
780 }
781 Expr::Subquery { query, alias } => {
782 validate_qail_limits(query, depth + 1, state)?;
783 if let Some(alias) = alias {
784 ensure_str("expr.subquery.alias", alias)?;
785 }
786 }
787 Expr::Exists { query, alias, .. } => {
788 validate_qail_limits(query, depth + 1, state)?;
789 if let Some(alias) = alias {
790 ensure_str("expr.exists.alias", alias)?;
791 }
792 }
793 }
794
795 Ok(())
796}
797
798fn validate_value_limits(
799 value: &crate::ast::Value,
800 depth: usize,
801 state: &mut AstLimitState,
802) -> Result<(), String> {
803 use crate::ast::Value;
804
805 ensure_depth(depth, "Value")?;
806 state.bump("Value")?;
807
808 match value {
809 Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::Param(_) => {}
810 Value::String(v)
811 | Value::NamedParam(v)
812 | Value::Function(v)
813 | Value::Column(v)
814 | Value::Timestamp(v)
815 | Value::Json(v) => ensure_str("value.string", v)?,
816 Value::Array(values) => {
817 ensure_len("value.array", values.len(), MAX_AST_COLLECTION_LEN)?;
818 for v in values {
819 validate_value_limits(v, depth + 1, state)?;
820 }
821 }
822 Value::Subquery(q) => validate_qail_limits(q, depth + 1, state)?,
823 Value::Uuid(_) | Value::NullUuid | Value::Interval { .. } => {}
824 Value::Bytes(bytes) => ensure_len("value.bytes", bytes.len(), MAX_AST_BINARY_VALUE_LEN)?,
825 Value::Expr(expr) => validate_expr_limits(expr, depth + 1, state)?,
826 Value::Vector(values) => ensure_len("value.vector", values.len(), MAX_AST_VECTOR_LEN)?,
827 }
828
829 Ok(())
830}
831
832fn validate_index_def_limits(index_def: &crate::ast::IndexDef) -> Result<(), String> {
833 ensure_str("index_def.name", &index_def.name)?;
834 ensure_str("index_def.table", &index_def.table)?;
835 ensure_len(
836 "index_def.columns",
837 index_def.columns.len(),
838 MAX_AST_COLLECTION_LEN,
839 )?;
840 for col in &index_def.columns {
841 ensure_str("index_def.column", col)?;
842 }
843 if let Some(index_type) = &index_def.index_type {
844 ensure_str("index_def.index_type", index_type)?;
845 }
846 if let Some(where_clause) = &index_def.where_clause {
847 ensure_str("index_def.where_clause", where_clause)?;
848 }
849 Ok(())
850}
851
852fn validate_function_def_limits(function_def: &crate::ast::FunctionDef) -> Result<(), String> {
853 ensure_str("function_def.name", &function_def.name)?;
854 ensure_len(
855 "function_def.args",
856 function_def.args.len(),
857 MAX_AST_COLLECTION_LEN,
858 )?;
859 for arg in &function_def.args {
860 ensure_str("function_def.arg", arg)?;
861 }
862 ensure_str("function_def.returns", &function_def.returns)?;
863 ensure_str("function_def.body", &function_def.body)?;
864 if let Some(language) = &function_def.language {
865 ensure_str("function_def.language", language)?;
866 }
867 if let Some(volatility) = &function_def.volatility {
868 ensure_str("function_def.volatility", volatility)?;
869 }
870 Ok(())
871}
872
873fn validate_trigger_def_limits(trigger_def: &crate::ast::TriggerDef) -> Result<(), String> {
874 ensure_str("trigger_def.name", &trigger_def.name)?;
875 ensure_str("trigger_def.table", &trigger_def.table)?;
876 ensure_len(
877 "trigger_def.events",
878 trigger_def.events.len(),
879 MAX_AST_COLLECTION_LEN,
880 )?;
881 ensure_len(
882 "trigger_def.update_columns",
883 trigger_def.update_columns.len(),
884 MAX_AST_COLLECTION_LEN,
885 )?;
886 for col in &trigger_def.update_columns {
887 ensure_str("trigger_def.update_column", col)?;
888 }
889 ensure_str(
890 "trigger_def.execute_function",
891 &trigger_def.execute_function,
892 )?;
893 Ok(())
894}
895
896fn validate_policy_def_limits(
897 policy_def: &crate::migrate::policy::RlsPolicy,
898 depth: usize,
899 state: &mut AstLimitState,
900) -> Result<(), String> {
901 ensure_str("policy_def.name", &policy_def.name)?;
902 ensure_str("policy_def.table", &policy_def.table)?;
903 if let Some(using_expr) = &policy_def.using {
904 validate_expr_limits(using_expr, depth + 1, state)?;
905 }
906 if let Some(with_check_expr) = &policy_def.with_check {
907 validate_expr_limits(with_check_expr, depth + 1, state)?;
908 }
909 if let Some(role) = &policy_def.role {
910 ensure_str("policy_def.role", role)?;
911 }
912 Ok(())
913}
914
/// Reads one `\n`-terminated line starting at `*idx` and advances `*idx` past the newline.
///
/// The returned slice excludes the newline. On error `*idx` may still move:
/// - to the end of `bytes` when no newline is found;
/// - onto the newline when the line is not valid UTF-8.
///
/// # Errors
/// - `"unexpected EOF"` when `*idx` is already at or past the end.
/// - `"unterminated header line"` when no newline follows.
/// - `"header is not UTF-8"` when the line bytes are invalid UTF-8.
fn read_line<'a>(bytes: &'a [u8], idx: &mut usize) -> Result<&'a str, String> {
    let start = *idx;
    let rest = match bytes.get(start..) {
        Some(rest) if !rest.is_empty() => rest,
        _ => return Err("unexpected EOF".to_string()),
    };

    let Some(offset) = rest.iter().position(|&b| b == b'\n') else {
        *idx = bytes.len();
        return Err("unterminated header line".to_string());
    };

    let newline_at = start + offset;
    *idx = newline_at;
    let line =
        std::str::from_utf8(&rest[..offset]).map_err(|_| "header is not UTF-8".to_string())?;
    *idx = newline_at + 1;
    Ok(line)
}
934
/// Parses a decimal length/count header field.
///
/// Only plain ASCII digits are accepted, which is exactly what the encoders emit.
/// `usize::from_str` on its own also accepts a leading `+`, which would let the same frame be
/// spelled several ways. `field` names the header field in the error message.
///
/// # Errors
/// Returns `"invalid {field}: {line}"` for any of these:
/// - an empty line;
/// - any non-digit character, including a sign;
/// - a value that overflows `usize`.
fn parse_usize(field: &str, line: &str) -> Result<usize, String> {
    let invalid = || format!("invalid {field}: {line}");

    if line.is_empty() || !line.bytes().all(|b| b.is_ascii_digit()) {
        return Err(invalid());
    }
    line.parse::<usize>().map_err(|_| invalid())
}
939
/// Takes exactly `len` bytes starting at `*idx` and returns them as UTF-8 text.
///
/// `*idx` is left untouched when the length overflows or the input is too short. Once the
/// bytes are available, `*idx` moves past them even if the UTF-8 check then fails.
///
/// # Errors
/// - `"payload length overflow"` when `*idx + len` overflows `usize`.
/// - `"payload truncated"` when fewer than `len` bytes remain.
/// - `"payload is not UTF-8"` when the bytes are invalid UTF-8.
fn read_exact_utf8<'a>(bytes: &'a [u8], idx: &mut usize, len: usize) -> Result<&'a str, String> {
    let start = *idx;
    let Some(end) = start.checked_add(len) else {
        return Err("payload length overflow".to_string());
    };
    let Some(chunk) = bytes.get(start..end) else {
        return Err("payload truncated".to_string());
    };

    *idx = end;
    std::str::from_utf8(chunk).map_err(|_| "payload is not UTF-8".to_string())
}
951
#[cfg(test)]
mod tests {
    //! Round-trip, framing and hostile-input tests for the text and binary wire codecs.

    use super::*;
    use proptest::prelude::*;

    #[test]
    fn cmd_text_roundtrip() {
        let cmd = crate::ast::Qail::get("users")
            .columns(["id", "email"])
            .where_eq("active", true)
            .limit(10);

        let encoded = encode_cmd_text(&cmd);
        let decoded = decode_cmd_text(&encoded).unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    #[test]
    fn cmd_text_rejects_trailing_bytes() {
        let cmd = crate::ast::Qail::get("users").limit(1);
        let mut encoded = encode_cmd_text(&cmd);
        encoded.push('x');

        let err = decode_cmd_text(&encoded).unwrap_err();
        assert!(err.contains("trailing bytes after command payload"));
    }

    #[test]
    fn cmd_binary_roundtrip() {
        let cmd = crate::ast::Qail::set("users")
            .set_value("active", true)
            .where_eq("id", 7);

        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let decoded = decode_cmd_binary(&encoded).unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    #[test]
    fn cmd_binary_roundtrip_preserves_merge_ast() {
        let cmd = crate::ast::Qail::merge_into("users")
            .target_alias("u")
            .using_table_as("staging_users", "s")
            .merge_on_column("u.id", crate::ast::Operator::Eq, "s.id")
            .when_matched_update(&[("name", crate::ast::Expr::Named("s.name".to_string()))])
            .when_not_matched_insert(
                &["id", "name"],
                &[
                    crate::ast::Expr::Named("s.id".to_string()),
                    crate::ast::Expr::Named("s.name".to_string()),
                ],
            );

        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let decoded = decode_cmd_binary(&encoded).unwrap();
        assert_eq!(decoded, cmd);
    }

    #[test]
    fn cmd_binary_payload_rejects_oversized_header() {
        // Header only: the oversized declared length must be rejected before any body check.
        let mut payload = Vec::new();
        payload.extend_from_slice(&CMD_BIN_MAGIC);
        payload.extend_from_slice(&((MAX_CMD_BINARY_PAYLOAD_BYTES + 1) as u32).to_be_bytes());

        let err = decode_cmd_binary_payload(&payload).unwrap_err();
        assert!(err.contains("binary AST payload too large"));
    }

    #[test]
    fn cmd_binary_payload_rejects_short_header() {
        let err = decode_cmd_binary_payload(b"QWB2").unwrap_err();
        assert_eq!(err, "invalid wire header");
    }

    #[test]
    fn cmd_binary_payload_roundtrip() {
        let cmd = crate::ast::Qail::get("users").limit(3);
        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let payload = decode_cmd_binary_payload(&encoded).unwrap();
        let mut deserializer = serde_json::Deserializer::from_slice(payload);
        let decoded: crate::ast::Qail = serde::Deserialize::deserialize(&mut deserializer).unwrap();
        deserializer.end().unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    #[test]
    fn cmd_binary_payload_rejects_legacy_qwb1() {
        let legacy_text = b"get users limit 1";
        let mut payload = Vec::new();
        payload.extend_from_slice(&CMD_BIN_LEGACY_MAGIC);
        payload.extend_from_slice(&(legacy_text.len() as u32).to_be_bytes());
        payload.extend_from_slice(legacy_text);

        let err = decode_cmd_binary_payload(&payload).unwrap_err();
        assert!(err.contains("legacy QWB1"));
    }

    #[test]
    fn cmd_binary_decode_rejects_raw_text_without_qwb2_header() {
        let err = decode_cmd_binary(b"get users limit 1").unwrap_err();
        assert!(err.contains("invalid wire header"));
    }

    #[test]
    fn cmd_binary_decode_rejects_trailing_bytes() {
        let cmd = crate::ast::Qail::get("users").limit(1);
        let mut encoded = encode_cmd_binary(&cmd).expect("binary encode");
        encoded.extend_from_slice(&[0xAA, 0xBB]);
        let err = decode_cmd_binary(&encoded).unwrap_err();
        assert!(err.contains("invalid payload length"));
    }

    #[test]
    fn cmd_binary_decode_enforces_depth_limits() {
        let mut nested = crate::ast::Qail::get("users").limit(1);
        for _ in 0..(MAX_AST_DEPTH + 2) {
            nested = crate::ast::Qail {
                action: crate::ast::Action::Get,
                table: "users".to_string(),
                columns: vec![crate::ast::Expr::Subquery {
                    query: Box::new(nested),
                    alias: None,
                }],
                ..crate::ast::Qail::default()
            };
        }

        let encoded = encode_cmd_binary(&nested).expect("binary encode");
        let err = decode_cmd_binary(&encoded).unwrap_err();
        // Either the JSON deserializer's own recursion guard or the AST walk may trip first.
        assert!(
            err.contains("AST depth limit exceeded")
                || err.contains("binary AST decode failed")
                || err.contains("recursion limit exceeded")
        );
    }

    #[test]
    fn cmd_binary_decode_bitflip_corpus_no_panic() {
        let seeds = [
            encode_cmd_binary(&crate::ast::Qail::get("users").limit(1)).expect("binary encode"),
            encode_cmd_binary(&crate::ast::Qail::set("users").set_value("active", true))
                .expect("binary encode"),
            vec![],
            b"QWB2garbage".to_vec(),
            vec![0u8; 32],
        ];

        for seed in seeds {
            for i in 0..seed.len().min(128) {
                for bit in 0..8u8 {
                    let mut mutated = seed.clone();
                    mutated[i] ^= 1 << bit;
                    let _ = decode_cmd_binary(&mutated);
                }
            }
            let _ = decode_cmd_binary(&seed);
        }
    }

    proptest! {
        #[test]
        fn cmd_binary_decode_fuzz_never_panics(data in proptest::collection::vec(any::<u8>(), 0..4096)) {
            let _ = decode_cmd_binary(&data);
        }
    }

    #[test]
    fn cmds_text_roundtrip() {
        let cmds = vec![
            crate::ast::Qail::get("users").columns(["id", "email"]),
            crate::ast::Qail::get("users").limit(1),
            crate::ast::Qail::del("users").where_eq("id", 99),
        ];

        let encoded = encode_cmds_text(&cmds);
        let decoded = decode_cmds_text(&encoded).unwrap();
        assert_eq!(decoded.len(), cmds.len());
        for (lhs, rhs) in decoded.iter().zip(cmds.iter()) {
            assert_eq!(lhs.to_string(), rhs.to_string());
        }
    }

    #[test]
    fn cmds_text_empty_batch_roundtrip() {
        let encoded = encode_cmds_text(&[]);
        assert_eq!(encoded, format!("{CMDS_TEXT_MAGIC}\n0\n"));
        assert!(decode_cmds_text(&encoded).unwrap().is_empty());
    }

    #[test]
    fn cmds_text_rejects_wrong_magic() {
        let err = decode_cmds_text(&format!("{CMD_TEXT_MAGIC}\n0\n")).unwrap_err();
        assert!(err.contains("invalid wire magic"));
    }

    #[test]
    fn decode_cmd_text_falls_back_to_raw_qail() {
        let decoded = decode_cmd_text("get users limit 1").unwrap();
        assert_eq!(decoded.action, crate::ast::Action::Get);
        assert_eq!(decoded.table, "users");
        assert!(
            decoded
                .cages
                .iter()
                .any(|c| matches!(c.kind, crate::ast::CageKind::Limit(1)))
        );
    }

    #[test]
    fn cmd_text_rejects_overflowing_payload_length_without_panic() {
        let input = format!("{CMD_TEXT_MAGIC}\n{}\n", usize::MAX);

        let err = decode_cmd_text(&input).expect_err("oversized length should fail closed");
        assert!(err.contains("payload length overflow") || err.contains("payload truncated"));
    }

    #[test]
    fn cmds_text_rejects_oversized_command_count_without_allocation() {
        let input = format!("{CMDS_TEXT_MAGIC}\n{}\n", MAX_CMD_TEXT_BATCH_COMMANDS + 1);

        let err = decode_cmds_text(&input).expect_err("oversized command count should fail closed");
        assert!(err.contains("command count exceeds limit"));
    }
}
1147}