1use crate::ast::Qail;
7
/// First line of a single-command text frame.
const CMD_TEXT_MAGIC: &str = "QAIL-CMD/1";
/// First line of a multi-command (batch) text frame.
const CMDS_TEXT_MAGIC: &str = "QAIL-CMDS/1";
/// Four-byte magic opening a binary (serde-JSON AST) command frame.
const CMD_BIN_MAGIC: [u8; 4] = [b'Q', b'W', b'B', b'2'];
/// Magic of the legacy framing that carried QAIL text; recognised only to report a clear error.
const CMD_BIN_LEGACY_MAGIC: [u8; 4] = [b'Q', b'W', b'B', b'1'];

/// Largest payload (in bytes, after the 8-byte header) accepted in a binary command frame: 64 KiB.
pub const MAX_CMD_BINARY_PAYLOAD_BYTES: usize = 65_536;

// Structural limits enforced on an AST decoded from a binary frame.
/// Maximum nesting depth of AST nodes.
const MAX_AST_DEPTH: usize = 64;
/// Maximum total number of AST nodes visited in one command.
const MAX_AST_NODES: usize = 16_384;
/// Maximum length of any list-like AST field.
const MAX_AST_COLLECTION_LEN: usize = 2_048;
/// Maximum byte length of any string in the AST (32 KiB).
const MAX_AST_STRING_LEN: usize = 32_768;
/// Maximum number of elements in a vector field/value.
const MAX_AST_VECTOR_LEN: usize = 8_192;
/// Maximum length of a `Value::Bytes` payload (32 KiB).
const MAX_AST_BINARY_VALUE_LEN: usize = 32_768;
21
22pub fn encode_cmd_text(cmd: &Qail) -> String {
24 let payload = cmd.to_string();
25 let mut out = String::with_capacity(CMD_TEXT_MAGIC.len() + payload.len() + 32);
26 out.push_str(CMD_TEXT_MAGIC);
27 out.push('\n');
28 out.push_str(&payload.len().to_string());
29 out.push('\n');
30 out.push_str(&payload);
31 out
32}
33
34pub fn decode_cmd_text(input: &str) -> Result<Qail, String> {
38 let bytes = input.as_bytes();
39 let mut idx = 0usize;
40
41 let Ok(magic) = read_line(bytes, &mut idx) else {
42 return crate::parse(input).map_err(|e| e.to_string());
43 };
44
45 if magic != CMD_TEXT_MAGIC {
46 return crate::parse(input).map_err(|e| e.to_string());
47 }
48
49 let len_line = read_line(bytes, &mut idx)?;
50 let payload_len = parse_usize("payload length", len_line)?;
51 let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
52 if idx != bytes.len() {
53 return Err("trailing bytes after command payload".to_string());
54 }
55
56 crate::parse(payload).map_err(|e| e.to_string())
57}
58
59pub fn encode_cmds_text(cmds: &[Qail]) -> String {
61 let mut out = String::new();
62 out.push_str(CMDS_TEXT_MAGIC);
63 out.push('\n');
64 out.push_str(&cmds.len().to_string());
65 out.push('\n');
66
67 for cmd in cmds {
68 let payload = cmd.to_string();
69 out.push_str(&payload.len().to_string());
70 out.push('\n');
71 out.push_str(&payload);
72 }
73
74 out
75}
76
77pub fn decode_cmds_text(input: &str) -> Result<Vec<Qail>, String> {
79 let bytes = input.as_bytes();
80 let mut idx = 0usize;
81
82 let magic = read_line(bytes, &mut idx)?;
83 if magic != CMDS_TEXT_MAGIC {
84 return Err(format!(
85 "invalid wire magic: expected {CMDS_TEXT_MAGIC}, got {magic}"
86 ));
87 }
88
89 let count_line = read_line(bytes, &mut idx)?;
90 let count = parse_usize("command count", count_line)?;
91 let mut out = Vec::with_capacity(count);
92
93 for _ in 0..count {
94 let len_line = read_line(bytes, &mut idx)?;
95 let payload_len = parse_usize("payload length", len_line)?;
96 let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
97 let cmd = crate::parse(payload).map_err(|e| e.to_string())?;
98 out.push(cmd);
99 }
100
101 if idx != bytes.len() {
102 return Err("trailing bytes after batch payload".to_string());
103 }
104
105 Ok(out)
106}
107
108pub fn encode_cmd_binary(cmd: &Qail) -> Result<Vec<u8>, String> {
110 let payload = serde_json::to_vec(cmd).map_err(|e| format!("binary AST encode failed: {e}"))?;
111 if payload.len() > MAX_CMD_BINARY_PAYLOAD_BYTES {
112 return Err(format!(
113 "binary AST payload too large: {} bytes (max {})",
114 payload.len(),
115 MAX_CMD_BINARY_PAYLOAD_BYTES
116 ));
117 }
118
119 let payload_len = u32::try_from(payload.len())
120 .map_err(|_| format!("binary AST payload exceeds u32 length: {}", payload.len()))?;
121 let mut out = Vec::with_capacity(8 + payload.len());
122 out.extend_from_slice(&CMD_BIN_MAGIC);
123 out.extend_from_slice(&payload_len.to_be_bytes());
124 out.extend_from_slice(&payload);
125 Ok(out)
126}
127
128pub fn decode_cmd_binary(input: &[u8]) -> Result<Qail, String> {
132 let payload = decode_cmd_binary_payload(input)?;
133 let mut deserializer = serde_json::Deserializer::from_slice(payload);
134 let cmd = serde::Deserialize::deserialize(&mut deserializer)
135 .map_err(|e| format!("binary AST decode failed: {e}"))?;
136 deserializer
137 .end()
138 .map_err(|_| "trailing bytes after AST payload".to_string())?;
139 validate_binary_ast_limits(&cmd)?;
140 Ok(cmd)
141}
142
143pub fn decode_cmd_binary_payload(input: &[u8]) -> Result<&[u8], String> {
147 if input.len() < 8 {
148 return Err("invalid wire header".to_string());
149 }
150 if input[0..4] != CMD_BIN_MAGIC {
151 if input[0..4] == CMD_BIN_LEGACY_MAGIC {
152 return Err(
153 "legacy QWB1 text payload is not supported on parse-free binary path".to_string(),
154 );
155 }
156 return Err("invalid wire header".to_string());
157 }
158
159 let len = u32::from_be_bytes([input[4], input[5], input[6], input[7]]) as usize;
160 if len > MAX_CMD_BINARY_PAYLOAD_BYTES {
161 return Err(format!(
162 "binary AST payload too large: header={len}, max={MAX_CMD_BINARY_PAYLOAD_BYTES}"
163 ));
164 }
165 if input.len() != 8 + len {
166 return Err(format!(
167 "invalid payload length: header={len}, actual={}",
168 input.len().saturating_sub(8)
169 ));
170 }
171 Ok(&input[8..])
172}
173
174#[derive(Default)]
175struct AstLimitState {
176 nodes: usize,
177}
178
179impl AstLimitState {
180 fn bump(&mut self, kind: &str) -> Result<(), String> {
181 self.nodes = self
182 .nodes
183 .checked_add(1)
184 .ok_or_else(|| "AST node counter overflow".to_string())?;
185 if self.nodes > MAX_AST_NODES {
186 return Err(format!(
187 "AST node limit exceeded while walking {kind}: {} > {}",
188 self.nodes, MAX_AST_NODES
189 ));
190 }
191 Ok(())
192 }
193}
194
195fn ensure_depth(depth: usize, kind: &str) -> Result<(), String> {
196 if depth > MAX_AST_DEPTH {
197 return Err(format!(
198 "AST depth limit exceeded while walking {kind}: {depth} > {MAX_AST_DEPTH}"
199 ));
200 }
201 Ok(())
202}
203
/// Fail when `len` exceeds `max`, naming the offending field `kind` in the error.
fn ensure_len(kind: &str, len: usize, max: usize) -> Result<(), String> {
    match len <= max {
        true => Ok(()),
        false => Err(format!("{kind} exceeds limit: {len} > {max}")),
    }
}
210
211fn ensure_str(kind: &str, value: &str) -> Result<(), String> {
212 ensure_len(kind, value.len(), MAX_AST_STRING_LEN)
213}
214
215fn validate_binary_ast_limits(cmd: &Qail) -> Result<(), String> {
216 let mut state = AstLimitState::default();
217 validate_qail_limits(cmd, 0, &mut state)
218}
219
/// Recursively check one `Qail` command, and every nested query/expression it owns,
/// against the structural limits (depth, node count, string and collection sizes).
///
/// `depth` is this node's nesting level (0 for the root); children are walked at
/// `depth + 1`. `state` accumulates the node count across the whole walk.
///
/// # Errors
/// Returns the first violation encountered, in the field order below: depth/node-count
/// limits, `MAX_AST_STRING_LEN`, `MAX_AST_COLLECTION_LEN`, `MAX_AST_VECTOR_LEN`, or a
/// non-finite `sample` percentage.
fn validate_qail_limits(cmd: &Qail, depth: usize, state: &mut AstLimitState) -> Result<(), String> {
    use crate::ast::GroupByMode;

    ensure_depth(depth, "Qail")?;
    state.bump("Qail")?;

    ensure_str("qail.table", &cmd.table)?;
    ensure_len("qail.columns", cmd.columns.len(), MAX_AST_COLLECTION_LEN)?;
    for expr in &cmd.columns {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    ensure_len("qail.joins", cmd.joins.len(), MAX_AST_COLLECTION_LEN)?;
    for join in &cmd.joins {
        validate_join_limits(join, depth + 1, state)?;
    }

    ensure_len("qail.cages", cmd.cages.len(), MAX_AST_COLLECTION_LEN)?;
    for cage in &cmd.cages {
        validate_cage_limits(cage, depth + 1, state)?;
    }

    // Index definitions are size-checked only; they are not passed `state`/`depth`.
    if let Some(index_def) = &cmd.index_def {
        validate_index_def_limits(index_def)?;
    }

    ensure_len(
        "qail.table_constraints",
        cmd.table_constraints.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for constraint in &cmd.table_constraints {
        // Exhaustive match (no `_` arm): a new TableConstraint variant must be handled here.
        match constraint {
            crate::ast::TableConstraint::Unique(cols)
            | crate::ast::TableConstraint::PrimaryKey(cols) => {
                ensure_len(
                    "qail.table_constraint.columns",
                    cols.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in cols {
                    ensure_str("qail.table_constraint.column", col)?;
                }
            }
        }
    }

    // Right-hand sides of set operations are full queries: recurse.
    ensure_len("qail.set_ops", cmd.set_ops.len(), MAX_AST_COLLECTION_LEN)?;
    for (_, rhs) in &cmd.set_ops {
        validate_qail_limits(rhs, depth + 1, state)?;
    }

    ensure_len("qail.having", cmd.having.len(), MAX_AST_COLLECTION_LEN)?;
    for cond in &cmd.having {
        validate_condition_limits(cond, depth + 1, state)?;
    }

    // Only the GroupingSets mode is inspected; it holds nested lists of column names.
    if let GroupByMode::GroupingSets(groups) = &cmd.group_by_mode {
        ensure_len("qail.grouping_sets", groups.len(), MAX_AST_COLLECTION_LEN)?;
        for group in groups {
            ensure_len("qail.grouping_set", group.len(), MAX_AST_COLLECTION_LEN)?;
            for col in group {
                ensure_str("qail.grouping_set.column", col)?;
            }
        }
    }

    // Each CTE contributes its own base (and optional recursive) query to the walk.
    ensure_len("qail.ctes", cmd.ctes.len(), MAX_AST_COLLECTION_LEN)?;
    for cte in &cmd.ctes {
        ensure_str("qail.cte.name", &cte.name)?;
        ensure_len(
            "qail.cte.columns",
            cte.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &cte.columns {
            ensure_str("qail.cte.column", col)?;
        }
        validate_qail_limits(&cte.base_query, depth + 1, state)?;
        if let Some(recursive) = &cte.recursive_query {
            validate_qail_limits(recursive, depth + 1, state)?;
        }
        if let Some(source_table) = &cte.source_table {
            ensure_str("qail.cte.source_table", source_table)?;
        }
    }

    ensure_len(
        "qail.distinct_on",
        cmd.distinct_on.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for expr in &cmd.distinct_on {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    if let Some(returning) = &cmd.returning {
        ensure_len("qail.returning", returning.len(), MAX_AST_COLLECTION_LEN)?;
        for expr in returning {
            validate_expr_limits(expr, depth + 1, state)?;
        }
    }

    if let Some(on_conflict) = &cmd.on_conflict {
        ensure_len(
            "qail.on_conflict.columns",
            on_conflict.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &on_conflict.columns {
            ensure_str("qail.on_conflict.column", col)?;
        }
        // Assignments are only present for actions where `update_assignments()` returns Some.
        if let Some(assignments) = on_conflict.action.update_assignments() {
            ensure_len(
                "qail.on_conflict.assignments",
                assignments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for (col, expr) in assignments {
                ensure_str("qail.on_conflict.assignment.column", col)?;
                validate_expr_limits(expr, depth + 1, state)?;
            }
        }
    }

    if let Some(source_query) = &cmd.source_query {
        validate_qail_limits(source_query, depth + 1, state)?;
    }

    if let Some(channel) = &cmd.channel {
        ensure_str("qail.channel", channel)?;
    }
    if let Some(payload) = &cmd.payload {
        ensure_str("qail.payload", payload)?;
    }
    if let Some(savepoint_name) = &cmd.savepoint_name {
        ensure_str("qail.savepoint_name", savepoint_name)?;
    }

    ensure_len(
        "qail.from_tables",
        cmd.from_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.from_tables {
        ensure_str("qail.from_table", table)?;
    }

    ensure_len(
        "qail.using_tables",
        cmd.using_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.using_tables {
        ensure_str("qail.using_table", table)?;
    }

    // Reject NaN and ±infinity sample percentages; finite values are not range-checked here.
    if let Some((_, percent, _seed)) = cmd.sample
        && !percent.is_finite()
    {
        return Err("qail.sample.percent must be finite".to_string());
    }

    if let Some(vector) = &cmd.vector {
        ensure_len("qail.vector", vector.len(), MAX_AST_VECTOR_LEN)?;
    }
    if let Some(vector_name) = &cmd.vector_name {
        ensure_str("qail.vector_name", vector_name)?;
    }
    // Function and trigger definitions are size-checked only (no node counting or depth).
    if let Some(function_def) = &cmd.function_def {
        validate_function_def_limits(function_def)?;
    }
    if let Some(trigger_def) = &cmd.trigger_def {
        validate_trigger_def_limits(trigger_def)?;
    }
    if let Some(policy_def) = &cmd.policy_def {
        validate_policy_def_limits(policy_def, depth + 1, state)?;
    }

    Ok(())
}
401
402fn validate_join_limits(
403 join: &crate::ast::Join,
404 depth: usize,
405 state: &mut AstLimitState,
406) -> Result<(), String> {
407 ensure_depth(depth, "Join")?;
408 state.bump("Join")?;
409 ensure_str("join.table", &join.table)?;
410 if let Some(on) = &join.on {
411 ensure_len("join.on", on.len(), MAX_AST_COLLECTION_LEN)?;
412 for cond in on {
413 validate_condition_limits(cond, depth + 1, state)?;
414 }
415 }
416 Ok(())
417}
418
419fn validate_cage_limits(
420 cage: &crate::ast::Cage,
421 depth: usize,
422 state: &mut AstLimitState,
423) -> Result<(), String> {
424 use crate::ast::CageKind;
425
426 ensure_depth(depth, "Cage")?;
427 state.bump("Cage")?;
428 ensure_len(
429 "cage.conditions",
430 cage.conditions.len(),
431 MAX_AST_COLLECTION_LEN,
432 )?;
433 for cond in &cage.conditions {
434 validate_condition_limits(cond, depth + 1, state)?;
435 }
436 match cage.kind {
437 CageKind::Limit(v) | CageKind::Offset(v) | CageKind::Sample(v) => {
438 ensure_len("cage.numeric", v, usize::MAX)?;
439 }
440 _ => {}
441 }
442 Ok(())
443}
444
445fn validate_condition_limits(
446 cond: &crate::ast::Condition,
447 depth: usize,
448 state: &mut AstLimitState,
449) -> Result<(), String> {
450 ensure_depth(depth, "Condition")?;
451 state.bump("Condition")?;
452 validate_expr_limits(&cond.left, depth + 1, state)?;
453 validate_value_limits(&cond.value, depth + 1, state)
454}
455
/// Check one expression node against the AST limits, recursing into sub-expressions,
/// conditions, values, cages and subqueries.
///
/// The node itself is counted via `state.bump`; children are walked at `depth + 1`.
/// String fields are bounded by `MAX_AST_STRING_LEN` and list fields by
/// `MAX_AST_COLLECTION_LEN`. The first violation found is returned.
fn validate_expr_limits(
    expr: &crate::ast::Expr,
    depth: usize,
    state: &mut AstLimitState,
) -> Result<(), String> {
    use crate::ast::{ColumnGeneration, Constraint, Expr, WindowFrame};

    ensure_depth(depth, "Expr")?;
    state.bump("Expr")?;

    // No `_` arm: adding an Expr variant forces this walker to be updated.
    match expr {
        Expr::Star => {}
        Expr::Named(name) => ensure_str("expr.named", name)?,
        Expr::Aliased { name, alias } => {
            ensure_str("expr.aliased.name", name)?;
            ensure_str("expr.aliased.alias", alias)?;
        }
        Expr::Aggregate {
            col, filter, alias, ..
        } => {
            ensure_str("expr.aggregate.col", col)?;
            if let Some(filters) = filter {
                ensure_len(
                    "expr.aggregate.filter",
                    filters.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for cond in filters {
                    validate_condition_limits(cond, depth + 1, state)?;
                }
            }
            if let Some(alias) = alias {
                ensure_str("expr.aggregate.alias", alias)?;
            }
        }
        Expr::Cast {
            expr,
            target_type,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.cast.target_type", target_type)?;
            if let Some(alias) = alias {
                ensure_str("expr.cast.alias", alias)?;
            }
        }
        Expr::Def {
            name,
            data_type,
            constraints,
        } => {
            ensure_str("expr.def.name", name)?;
            ensure_str("expr.def.data_type", data_type)?;
            ensure_len(
                "expr.def.constraints",
                constraints.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for constraint in constraints {
                // Only constraints that carry text are bounded; flag-like ones need no check.
                match constraint {
                    Constraint::PrimaryKey | Constraint::Unique | Constraint::Nullable => {}
                    Constraint::Default(v) => ensure_str("expr.def.default", v)?,
                    Constraint::Check(values) => {
                        ensure_len("expr.def.check", values.len(), MAX_AST_COLLECTION_LEN)?;
                        for value in values {
                            ensure_str("expr.def.check.value", value)?;
                        }
                    }
                    Constraint::Comment(v) | Constraint::References(v) => {
                        ensure_str("expr.def.constraint", v)?;
                    }
                    Constraint::Generated(ColumnGeneration::Stored(v))
                    | Constraint::Generated(ColumnGeneration::Virtual(v)) => {
                        ensure_str("expr.def.generated", v)?;
                    }
                }
            }
        }
        Expr::Mod { col, .. } => validate_expr_limits(col, depth + 1, state)?,
        Expr::Window {
            name,
            func,
            params,
            partition,
            order,
            frame,
        } => {
            ensure_str("expr.window.name", name)?;
            ensure_str("expr.window.func", func)?;
            ensure_len("expr.window.params", params.len(), MAX_AST_COLLECTION_LEN)?;
            for param in params {
                validate_expr_limits(param, depth + 1, state)?;
            }
            ensure_len(
                "expr.window.partition",
                partition.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for col in partition {
                ensure_str("expr.window.partition.column", col)?;
            }
            // ORDER BY entries are represented as cages.
            ensure_len("expr.window.order", order.len(), MAX_AST_COLLECTION_LEN)?;
            for cage in order {
                validate_cage_limits(cage, depth + 1, state)?;
            }
            // Frame fields are not checked. The exhaustive match (no `_` arm) presumably
            // exists so a new WindowFrame variant causes a compile error here. TODO confirm.
            if let Some(frame) = frame {
                match frame {
                    WindowFrame::Rows { .. } | WindowFrame::Range { .. } => {}
                }
            }
        }
        Expr::Case {
            when_clauses,
            else_value,
            alias,
        } => {
            ensure_len("expr.case.when", when_clauses.len(), MAX_AST_COLLECTION_LEN)?;
            for (cond, then_expr) in when_clauses {
                validate_condition_limits(cond, depth + 1, state)?;
                validate_expr_limits(then_expr, depth + 1, state)?;
            }
            if let Some(else_expr) = else_value {
                validate_expr_limits(else_expr, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.case.alias", alias)?;
            }
        }
        Expr::JsonAccess {
            column,
            path_segments,
            alias,
        } => {
            ensure_str("expr.json_access.column", column)?;
            ensure_len(
                "expr.json_access.path_segments",
                path_segments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for (segment, _) in path_segments {
                ensure_str("expr.json_access.segment", segment)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.json_access.alias", alias)?;
            }
        }
        Expr::FunctionCall { name, args, alias } => {
            ensure_str("expr.function_call.name", name)?;
            ensure_len(
                "expr.function_call.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for arg in args {
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.function_call.alias", alias)?;
            }
        }
        Expr::SpecialFunction { name, args, alias } => {
            ensure_str("expr.special_function.name", name)?;
            ensure_len(
                "expr.special_function.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            // Each argument may be preceded by an optional keyword string.
            for (keyword, arg) in args {
                if let Some(keyword) = keyword {
                    ensure_str("expr.special_function.keyword", keyword)?;
                }
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.special_function.alias", alias)?;
            }
        }
        Expr::Binary {
            left, right, alias, ..
        } => {
            validate_expr_limits(left, depth + 1, state)?;
            validate_expr_limits(right, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.binary.alias", alias)?;
            }
        }
        Expr::Literal(v) => validate_value_limits(v, depth + 1, state)?,
        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
            ensure_len("expr.elements", elements.len(), MAX_AST_COLLECTION_LEN)?;
            for el in elements {
                validate_expr_limits(el, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.elements.alias", alias)?;
            }
        }
        Expr::Subscript { expr, index, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            validate_expr_limits(index, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subscript.alias", alias)?;
            }
        }
        Expr::Collate {
            expr,
            collation,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.collate.collation", collation)?;
            if let Some(alias) = alias {
                ensure_str("expr.collate.alias", alias)?;
            }
        }
        Expr::FieldAccess { expr, field, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.field_access.field", field)?;
            if let Some(alias) = alias {
                ensure_str("expr.field_access.alias", alias)?;
            }
        }
        // Subqueries re-enter the full command walker, sharing the same depth and node budget.
        Expr::Subquery { query, alias } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subquery.alias", alias)?;
            }
        }
        Expr::Exists { query, alias, .. } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.exists.alias", alias)?;
            }
        }
    }

    Ok(())
}
693
694fn validate_value_limits(
695 value: &crate::ast::Value,
696 depth: usize,
697 state: &mut AstLimitState,
698) -> Result<(), String> {
699 use crate::ast::Value;
700
701 ensure_depth(depth, "Value")?;
702 state.bump("Value")?;
703
704 match value {
705 Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::Param(_) => {}
706 Value::String(v)
707 | Value::NamedParam(v)
708 | Value::Function(v)
709 | Value::Column(v)
710 | Value::Timestamp(v)
711 | Value::Json(v) => ensure_str("value.string", v)?,
712 Value::Array(values) => {
713 ensure_len("value.array", values.len(), MAX_AST_COLLECTION_LEN)?;
714 for v in values {
715 validate_value_limits(v, depth + 1, state)?;
716 }
717 }
718 Value::Subquery(q) => validate_qail_limits(q, depth + 1, state)?,
719 Value::Uuid(_) | Value::NullUuid | Value::Interval { .. } => {}
720 Value::Bytes(bytes) => ensure_len("value.bytes", bytes.len(), MAX_AST_BINARY_VALUE_LEN)?,
721 Value::Expr(expr) => validate_expr_limits(expr, depth + 1, state)?,
722 Value::Vector(values) => ensure_len("value.vector", values.len(), MAX_AST_VECTOR_LEN)?,
723 }
724
725 Ok(())
726}
727
728fn validate_index_def_limits(index_def: &crate::ast::IndexDef) -> Result<(), String> {
729 ensure_str("index_def.name", &index_def.name)?;
730 ensure_str("index_def.table", &index_def.table)?;
731 ensure_len(
732 "index_def.columns",
733 index_def.columns.len(),
734 MAX_AST_COLLECTION_LEN,
735 )?;
736 for col in &index_def.columns {
737 ensure_str("index_def.column", col)?;
738 }
739 if let Some(index_type) = &index_def.index_type {
740 ensure_str("index_def.index_type", index_type)?;
741 }
742 if let Some(where_clause) = &index_def.where_clause {
743 ensure_str("index_def.where_clause", where_clause)?;
744 }
745 Ok(())
746}
747
748fn validate_function_def_limits(function_def: &crate::ast::FunctionDef) -> Result<(), String> {
749 ensure_str("function_def.name", &function_def.name)?;
750 ensure_len(
751 "function_def.args",
752 function_def.args.len(),
753 MAX_AST_COLLECTION_LEN,
754 )?;
755 for arg in &function_def.args {
756 ensure_str("function_def.arg", arg)?;
757 }
758 ensure_str("function_def.returns", &function_def.returns)?;
759 ensure_str("function_def.body", &function_def.body)?;
760 if let Some(language) = &function_def.language {
761 ensure_str("function_def.language", language)?;
762 }
763 if let Some(volatility) = &function_def.volatility {
764 ensure_str("function_def.volatility", volatility)?;
765 }
766 Ok(())
767}
768
769fn validate_trigger_def_limits(trigger_def: &crate::ast::TriggerDef) -> Result<(), String> {
770 ensure_str("trigger_def.name", &trigger_def.name)?;
771 ensure_str("trigger_def.table", &trigger_def.table)?;
772 ensure_len(
773 "trigger_def.events",
774 trigger_def.events.len(),
775 MAX_AST_COLLECTION_LEN,
776 )?;
777 ensure_len(
778 "trigger_def.update_columns",
779 trigger_def.update_columns.len(),
780 MAX_AST_COLLECTION_LEN,
781 )?;
782 for col in &trigger_def.update_columns {
783 ensure_str("trigger_def.update_column", col)?;
784 }
785 ensure_str(
786 "trigger_def.execute_function",
787 &trigger_def.execute_function,
788 )?;
789 Ok(())
790}
791
792fn validate_policy_def_limits(
793 policy_def: &crate::migrate::policy::RlsPolicy,
794 depth: usize,
795 state: &mut AstLimitState,
796) -> Result<(), String> {
797 ensure_str("policy_def.name", &policy_def.name)?;
798 ensure_str("policy_def.table", &policy_def.table)?;
799 if let Some(using_expr) = &policy_def.using {
800 validate_expr_limits(using_expr, depth + 1, state)?;
801 }
802 if let Some(with_check_expr) = &policy_def.with_check {
803 validate_expr_limits(with_check_expr, depth + 1, state)?;
804 }
805 if let Some(role) = &policy_def.role {
806 ensure_str("policy_def.role", role)?;
807 }
808 Ok(())
809}
810
/// Read one `\n`-terminated header line starting at `*idx`, excluding the newline.
///
/// On success `*idx` points just past the newline. Cursor effects on failure:
/// - at EOF: `*idx` is unchanged;
/// - no terminating newline: `*idx` is left at `bytes.len()`;
/// - invalid UTF-8: `*idx` is left on the newline byte.
fn read_line<'a>(bytes: &'a [u8], idx: &mut usize) -> Result<&'a str, String> {
    let start = *idx;
    if start >= bytes.len() {
        return Err("unexpected EOF".to_string());
    }

    match bytes[start..].iter().position(|&b| b == b'\n') {
        None => {
            *idx = bytes.len();
            Err("unterminated header line".to_string())
        }
        Some(offset) => {
            let end = start + offset;
            *idx = end;
            let line = std::str::from_utf8(&bytes[start..end])
                .map_err(|_| "header is not UTF-8".to_string())?;
            *idx = end + 1;
            Ok(line)
        }
    }
}
830
/// Parse a decimal `usize` header field, reporting the field name and raw text on failure.
fn parse_usize(field: &str, line: &str) -> Result<usize, String> {
    match line.parse::<usize>() {
        Ok(n) => Ok(n),
        Err(_) => Err(format!("invalid {field}: {line}")),
    }
}
835
/// Take exactly `len` bytes starting at `*idx` and return them as UTF-8 text.
///
/// `len` is wire-supplied and may be arbitrarily large. The end offset is computed with
/// `checked_add` so a huge length reports truncation instead of overflowing: the old
/// `*idx + len` panicked in debug builds and, in release, wrapped past the bounds check
/// and then panicked on slicing.
///
/// On success, or on a UTF-8 failure, `*idx` advances by `len`. On truncation it is
/// left unchanged.
///
/// # Errors
/// `"payload truncated"` if fewer than `len` bytes remain (or the end offset would
/// overflow); `"payload is not UTF-8"` if the bytes are not valid UTF-8.
fn read_exact_utf8<'a>(bytes: &'a [u8], idx: &mut usize, len: usize) -> Result<&'a str, String> {
    let start = *idx;
    let chunk = start
        .checked_add(len)
        .and_then(|end| bytes.get(start..end))
        .ok_or_else(|| "payload truncated".to_string())?;
    *idx = start + len;
    std::str::from_utf8(chunk).map_err(|_| "payload is not UTF-8".to_string())
}
844
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // A framed text command decodes back to an equivalent command.
    #[test]
    fn cmd_text_roundtrip() {
        let cmd = crate::ast::Qail::get("users")
            .columns(["id", "email"])
            .where_eq("active", true)
            .limit(10);

        let encoded = encode_cmd_text(&cmd);
        let decoded = decode_cmd_text(&encoded).unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    // A binary (QWB2) frame decodes back to an equivalent command.
    #[test]
    fn cmd_binary_roundtrip() {
        let cmd = crate::ast::Qail::set("users")
            .set_value("active", true)
            .where_eq("id", 7);

        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let decoded = decode_cmd_binary(&encoded).unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    // A header declaring one byte more than the maximum is rejected before the body is read.
    #[test]
    fn cmd_binary_payload_rejects_oversized_header() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&CMD_BIN_MAGIC);
        payload.extend_from_slice(&((MAX_CMD_BINARY_PAYLOAD_BYTES + 1) as u32).to_be_bytes());
        payload.extend_from_slice(&[]);

        let err = decode_cmd_binary_payload(&payload).unwrap_err();
        assert!(err.contains("binary AST payload too large"));
    }

    // The payload slice returned by the header check deserializes cleanly on its own.
    #[test]
    fn cmd_binary_payload_roundtrip() {
        let cmd = crate::ast::Qail::get("users").limit(3);
        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let payload = decode_cmd_binary_payload(&encoded).unwrap();
        let mut deserializer = serde_json::Deserializer::from_slice(payload);
        let decoded: crate::ast::Qail = serde::Deserialize::deserialize(&mut deserializer).unwrap();
        deserializer.end().unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    // Legacy QWB1 frames get a dedicated error message.
    #[test]
    fn cmd_binary_payload_rejects_legacy_qwb1() {
        let legacy_text = b"get users limit 1";
        let mut payload = Vec::new();
        payload.extend_from_slice(&CMD_BIN_LEGACY_MAGIC);
        payload.extend_from_slice(&(legacy_text.len() as u32).to_be_bytes());
        payload.extend_from_slice(legacy_text);

        let err = decode_cmd_binary_payload(&payload).unwrap_err();
        assert!(err.contains("legacy QWB1"));
    }

    // The binary path never falls back to parsing raw QAIL text.
    #[test]
    fn cmd_binary_decode_rejects_raw_text_without_qwb2_header() {
        let err = decode_cmd_binary(b"get users limit 1").unwrap_err();
        assert!(err.contains("invalid wire header"));
    }

    // Extra bytes after the declared payload fail the header length check.
    #[test]
    fn cmd_binary_decode_rejects_trailing_bytes() {
        let cmd = crate::ast::Qail::get("users").limit(1);
        let mut encoded = encode_cmd_binary(&cmd).expect("binary encode");
        encoded.extend_from_slice(&[0xAA, 0xBB]);
        let err = decode_cmd_binary(&encoded).unwrap_err();
        assert!(err.contains("invalid payload length"));
    }

    // Deeply nested subqueries are rejected. Either our depth check or serde_json's own
    // recursion limit may fire first, so any of those messages is accepted.
    #[test]
    fn cmd_binary_decode_enforces_depth_limits() {
        let mut nested = crate::ast::Qail::get("users").limit(1);
        for _ in 0..(MAX_AST_DEPTH + 2) {
            nested = crate::ast::Qail {
                action: crate::ast::Action::Get,
                table: "users".to_string(),
                columns: vec![crate::ast::Expr::Subquery {
                    query: Box::new(nested),
                    alias: None,
                }],
                ..crate::ast::Qail::default()
            };
        }

        let encoded = encode_cmd_binary(&nested).expect("binary encode");
        let err = decode_cmd_binary(&encoded).unwrap_err();
        assert!(
            err.contains("AST depth limit exceeded")
                || err.contains("binary AST decode failed")
                || err.contains("recursion limit exceeded")
        );
    }

    // Single-bit flips over the first 128 bytes of several seeds must never panic;
    // results are intentionally ignored.
    #[test]
    fn cmd_binary_decode_bitflip_corpus_no_panic() {
        let seeds = vec![
            encode_cmd_binary(&crate::ast::Qail::get("users").limit(1)).expect("binary encode"),
            encode_cmd_binary(&crate::ast::Qail::set("users").set_value("active", true))
                .expect("binary encode"),
            vec![],
            b"QWB2garbage".to_vec(),
            vec![0u8; 32],
        ];

        for seed in seeds {
            for i in 0..seed.len().min(128) {
                for bit in 0..8u8 {
                    let mut mutated = seed.clone();
                    mutated[i] ^= 1 << bit;
                    let _ = decode_cmd_binary(&mutated);
                }
            }
            let _ = decode_cmd_binary(&seed);
        }
    }

    proptest! {
        // Arbitrary byte strings up to 4 KiB must never panic the binary decoder.
        #[test]
        fn cmd_binary_decode_fuzz_never_panics(data in proptest::collection::vec(any::<u8>(), 0..4096)) {
            let _ = decode_cmd_binary(&data);
        }
    }

    // A batch text frame decodes to the same commands, in order.
    #[test]
    fn cmds_text_roundtrip() {
        let cmds = vec![
            crate::ast::Qail::get("users").columns(["id", "email"]),
            crate::ast::Qail::get("users").limit(1),
            crate::ast::Qail::del("users").where_eq("id", 99),
        ];

        let encoded = encode_cmds_text(&cmds);
        let decoded = decode_cmds_text(&encoded).unwrap();
        assert_eq!(decoded.len(), cmds.len());
        for (lhs, rhs) in decoded.iter().zip(cmds.iter()) {
            assert_eq!(lhs.to_string(), rhs.to_string());
        }
    }

    // Unframed input to the single-command text decoder is parsed as plain QAIL.
    #[test]
    fn decode_cmd_text_falls_back_to_raw_qail() {
        let decoded = decode_cmd_text("get users limit 1").unwrap();
        assert_eq!(decoded.action, crate::ast::Action::Get);
        assert_eq!(decoded.table, "users");
        assert!(
            decoded
                .cages
                .iter()
                .any(|c| matches!(c.kind, crate::ast::CageKind::Limit(1)))
        );
    }
}