Skip to main content

qail_core/
wire.rs

1//! QAIL wire codecs for command transport.
2//!
3//! - Text codecs (`QAIL-CMD/1`, `QAIL-CMDS/1`) round-trip through canonical text.
4//! - Binary codec (`QWB2`) transports framed AST bytes directly.
5
6use crate::ast::Qail;
7
/// First line of a single-command text frame.
const CMD_TEXT_MAGIC: &str = "QAIL-CMD/1";
/// First line of a multi-command (batch) text frame.
const CMDS_TEXT_MAGIC: &str = "QAIL-CMDS/1";
/// Four-byte tag that opens a QWB2 binary frame.
const CMD_BIN_MAGIC: [u8; 4] = *b"QWB2";
/// Tag of the legacy QWB1 frame; the binary decoder recognises it only to reject it with a
/// dedicated error message.
const CMD_BIN_LEGACY_MAGIC: [u8; 4] = *b"QWB1";

/// Maximum allowed QWB2 payload size (bytes).
pub const MAX_CMD_BINARY_PAYLOAD_BYTES: usize = 64 * 1024;
/// Largest nesting depth accepted while walking a decoded AST.
const MAX_AST_DEPTH: usize = 64;
/// Largest number of nodes counted across one AST walk.
const MAX_AST_NODES: usize = 16_384;
/// Largest element count accepted for any single AST collection.
const MAX_AST_COLLECTION_LEN: usize = 2_048;
/// Largest number of commands accepted in one text batch.
const MAX_CMD_TEXT_BATCH_COMMANDS: usize = MAX_AST_COLLECTION_LEN;
/// Largest byte length accepted for any single AST string.
const MAX_AST_STRING_LEN: usize = 32 * 1024;
/// Largest element count accepted for a vector value.
const MAX_AST_VECTOR_LEN: usize = 8_192;
/// Largest byte length accepted for a binary (`Bytes`) value.
const MAX_AST_BINARY_VALUE_LEN: usize = 32 * 1024;
22
23/// Encode one command into versioned text wire format.
24pub fn encode_cmd_text(cmd: &Qail) -> String {
25    let payload = cmd.to_string();
26    let mut out = String::with_capacity(CMD_TEXT_MAGIC.len() + payload.len() + 32);
27    out.push_str(CMD_TEXT_MAGIC);
28    out.push('\n');
29    out.push_str(&payload.len().to_string());
30    out.push('\n');
31    out.push_str(&payload);
32    out
33}
34
35/// Decode one command from text wire format.
36///
37/// Also accepts raw QAIL query text as fallback for convenience.
38pub fn decode_cmd_text(input: &str) -> Result<Qail, String> {
39    let bytes = input.as_bytes();
40    let mut idx = 0usize;
41
42    let Ok(magic) = read_line(bytes, &mut idx) else {
43        return crate::parse(input).map_err(|e| e.to_string());
44    };
45
46    if magic != CMD_TEXT_MAGIC {
47        return crate::parse(input).map_err(|e| e.to_string());
48    }
49
50    let len_line = read_line(bytes, &mut idx)?;
51    let payload_len = parse_usize("payload length", len_line)?;
52    let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
53    if idx != bytes.len() {
54        return Err("trailing bytes after command payload".to_string());
55    }
56
57    crate::parse(payload).map_err(|e| e.to_string())
58}
59
60/// Encode multiple commands into versioned text wire format.
61pub fn encode_cmds_text(cmds: &[Qail]) -> String {
62    let mut out = String::new();
63    out.push_str(CMDS_TEXT_MAGIC);
64    out.push('\n');
65    out.push_str(&cmds.len().to_string());
66    out.push('\n');
67
68    for cmd in cmds {
69        let payload = cmd.to_string();
70        out.push_str(&payload.len().to_string());
71        out.push('\n');
72        out.push_str(&payload);
73    }
74
75    out
76}
77
78/// Decode multiple commands from text wire format.
79pub fn decode_cmds_text(input: &str) -> Result<Vec<Qail>, String> {
80    let bytes = input.as_bytes();
81    let mut idx = 0usize;
82
83    let magic = read_line(bytes, &mut idx)?;
84    if magic != CMDS_TEXT_MAGIC {
85        return Err(format!(
86            "invalid wire magic: expected {CMDS_TEXT_MAGIC}, got {magic}"
87        ));
88    }
89
90    let count_line = read_line(bytes, &mut idx)?;
91    let count = parse_usize("command count", count_line)?;
92    if count > MAX_CMD_TEXT_BATCH_COMMANDS {
93        return Err(format!(
94            "command count exceeds limit: {count} > {MAX_CMD_TEXT_BATCH_COMMANDS}"
95        ));
96    }
97    let mut out = Vec::with_capacity(count);
98
99    for _ in 0..count {
100        let len_line = read_line(bytes, &mut idx)?;
101        let payload_len = parse_usize("payload length", len_line)?;
102        let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
103        let cmd = crate::parse(payload).map_err(|e| e.to_string())?;
104        out.push(cmd);
105    }
106
107    if idx != bytes.len() {
108        return Err("trailing bytes after batch payload".to_string());
109    }
110
111    Ok(out)
112}
113
114/// Encode one command into compact binary wire format (QWB2 AST binary).
115pub fn encode_cmd_binary(cmd: &Qail) -> Result<Vec<u8>, String> {
116    let payload = serde_json::to_vec(cmd).map_err(|e| format!("binary AST encode failed: {e}"))?;
117    if payload.len() > MAX_CMD_BINARY_PAYLOAD_BYTES {
118        return Err(format!(
119            "binary AST payload too large: {} bytes (max {})",
120            payload.len(),
121            MAX_CMD_BINARY_PAYLOAD_BYTES
122        ));
123    }
124
125    let payload_len = u32::try_from(payload.len())
126        .map_err(|_| format!("binary AST payload exceeds u32 length: {}", payload.len()))?;
127    let mut out = Vec::with_capacity(8 + payload.len());
128    out.extend_from_slice(&CMD_BIN_MAGIC);
129    out.extend_from_slice(&payload_len.to_be_bytes());
130    out.extend_from_slice(&payload);
131    Ok(out)
132}
133
134/// Decode one command from strict QWB2 AST-binary wire format.
135///
136/// This path rejects legacy QWB1/raw-text payloads.
137pub fn decode_cmd_binary(input: &[u8]) -> Result<Qail, String> {
138    let payload = decode_cmd_binary_payload(input)?;
139    let mut deserializer = serde_json::Deserializer::from_slice(payload);
140    let cmd = serde::Deserialize::deserialize(&mut deserializer)
141        .map_err(|e| format!("binary AST decode failed: {e}"))?;
142    deserializer
143        .end()
144        .map_err(|_| "trailing bytes after AST payload".to_string())?;
145    validate_binary_ast_limits(&cmd)?;
146    Ok(cmd)
147}
148
149/// Decode and validate strict QWB2-framed payload bytes.
150///
151/// This validates framing and payload-size limits only.
152pub fn decode_cmd_binary_payload(input: &[u8]) -> Result<&[u8], String> {
153    if input.len() < 8 {
154        return Err("invalid wire header".to_string());
155    }
156    if input[0..4] != CMD_BIN_MAGIC {
157        if input[0..4] == CMD_BIN_LEGACY_MAGIC {
158            return Err(
159                "legacy QWB1 text payload is not supported on parse-free binary path".to_string(),
160            );
161        }
162        return Err("invalid wire header".to_string());
163    }
164
165    let len = u32::from_be_bytes([input[4], input[5], input[6], input[7]]) as usize;
166    if len > MAX_CMD_BINARY_PAYLOAD_BYTES {
167        return Err(format!(
168            "binary AST payload too large: header={len}, max={MAX_CMD_BINARY_PAYLOAD_BYTES}"
169        ));
170    }
171    if input.len() != 8 + len {
172        return Err(format!(
173            "invalid payload length: header={len}, actual={}",
174            input.len().saturating_sub(8)
175        ));
176    }
177    Ok(&input[8..])
178}
179
180#[derive(Default)]
181struct AstLimitState {
182    nodes: usize,
183}
184
185impl AstLimitState {
186    fn bump(&mut self, kind: &str) -> Result<(), String> {
187        self.nodes = self
188            .nodes
189            .checked_add(1)
190            .ok_or_else(|| "AST node counter overflow".to_string())?;
191        if self.nodes > MAX_AST_NODES {
192            return Err(format!(
193                "AST node limit exceeded while walking {kind}: {} > {}",
194                self.nodes, MAX_AST_NODES
195            ));
196        }
197        Ok(())
198    }
199}
200
201fn ensure_depth(depth: usize, kind: &str) -> Result<(), String> {
202    if depth > MAX_AST_DEPTH {
203        return Err(format!(
204            "AST depth limit exceeded while walking {kind}: {depth} > {MAX_AST_DEPTH}"
205        ));
206    }
207    Ok(())
208}
209
/// Fail when `len` is greater than `max`; `kind` labels the offending field in the error text.
fn ensure_len(kind: &str, len: usize, max: usize) -> Result<(), String> {
    if len <= max {
        Ok(())
    } else {
        Err(format!("{kind} exceeds limit: {len} > {max}"))
    }
}
216
217fn ensure_str(kind: &str, value: &str) -> Result<(), String> {
218    ensure_len(kind, value.len(), MAX_AST_STRING_LEN)
219}
220
221fn validate_binary_ast_limits(cmd: &Qail) -> Result<(), String> {
222    let mut state = AstLimitState::default();
223    validate_qail_limits(cmd, 0, &mut state)
224}
225
226fn validate_qail_limits(cmd: &Qail, depth: usize, state: &mut AstLimitState) -> Result<(), String> {
227    use crate::ast::GroupByMode;
228
229    ensure_depth(depth, "Qail")?;
230    state.bump("Qail")?;
231
232    ensure_str("qail.table", &cmd.table)?;
233    ensure_len("qail.columns", cmd.columns.len(), MAX_AST_COLLECTION_LEN)?;
234    for expr in &cmd.columns {
235        validate_expr_limits(expr, depth + 1, state)?;
236    }
237
238    ensure_len("qail.joins", cmd.joins.len(), MAX_AST_COLLECTION_LEN)?;
239    for join in &cmd.joins {
240        validate_join_limits(join, depth + 1, state)?;
241    }
242
243    ensure_len("qail.cages", cmd.cages.len(), MAX_AST_COLLECTION_LEN)?;
244    for cage in &cmd.cages {
245        validate_cage_limits(cage, depth + 1, state)?;
246    }
247
248    if let Some(index_def) = &cmd.index_def {
249        validate_index_def_limits(index_def)?;
250    }
251
252    ensure_len(
253        "qail.table_constraints",
254        cmd.table_constraints.len(),
255        MAX_AST_COLLECTION_LEN,
256    )?;
257    for constraint in &cmd.table_constraints {
258        match constraint {
259            crate::ast::TableConstraint::Unique(cols)
260            | crate::ast::TableConstraint::PrimaryKey(cols) => {
261                ensure_len(
262                    "qail.table_constraint.columns",
263                    cols.len(),
264                    MAX_AST_COLLECTION_LEN,
265                )?;
266                for col in cols {
267                    ensure_str("qail.table_constraint.column", col)?;
268                }
269            }
270            crate::ast::TableConstraint::ForeignKey {
271                name,
272                columns,
273                ref_table,
274                ref_columns,
275                on_delete,
276                on_update,
277                deferrable,
278            } => {
279                if let Some(name) = name {
280                    ensure_str("qail.table_constraint.name", name)?;
281                }
282                ensure_len(
283                    "qail.table_constraint.columns",
284                    columns.len(),
285                    MAX_AST_COLLECTION_LEN,
286                )?;
287                for col in columns {
288                    ensure_str("qail.table_constraint.column", col)?;
289                }
290                ensure_str("qail.table_constraint.ref_table", ref_table)?;
291                ensure_len(
292                    "qail.table_constraint.ref_columns",
293                    ref_columns.len(),
294                    MAX_AST_COLLECTION_LEN,
295                )?;
296                for col in ref_columns {
297                    ensure_str("qail.table_constraint.ref_column", col)?;
298                }
299                if let Some(action) = on_delete {
300                    ensure_str("qail.table_constraint.on_delete", action)?;
301                }
302                if let Some(action) = on_update {
303                    ensure_str("qail.table_constraint.on_update", action)?;
304                }
305                if let Some(clause) = deferrable {
306                    ensure_str("qail.table_constraint.deferrable", clause)?;
307                }
308            }
309        }
310    }
311
312    ensure_len("qail.set_ops", cmd.set_ops.len(), MAX_AST_COLLECTION_LEN)?;
313    for (_, rhs) in &cmd.set_ops {
314        validate_qail_limits(rhs, depth + 1, state)?;
315    }
316
317    ensure_len("qail.having", cmd.having.len(), MAX_AST_COLLECTION_LEN)?;
318    for cond in &cmd.having {
319        validate_condition_limits(cond, depth + 1, state)?;
320    }
321
322    if let GroupByMode::GroupingSets(groups) = &cmd.group_by_mode {
323        ensure_len("qail.grouping_sets", groups.len(), MAX_AST_COLLECTION_LEN)?;
324        for group in groups {
325            ensure_len("qail.grouping_set", group.len(), MAX_AST_COLLECTION_LEN)?;
326            for col in group {
327                ensure_str("qail.grouping_set.column", col)?;
328            }
329        }
330    }
331
332    ensure_len("qail.ctes", cmd.ctes.len(), MAX_AST_COLLECTION_LEN)?;
333    for cte in &cmd.ctes {
334        ensure_str("qail.cte.name", &cte.name)?;
335        ensure_len(
336            "qail.cte.columns",
337            cte.columns.len(),
338            MAX_AST_COLLECTION_LEN,
339        )?;
340        for col in &cte.columns {
341            ensure_str("qail.cte.column", col)?;
342        }
343        validate_qail_limits(&cte.base_query, depth + 1, state)?;
344        if let Some(recursive) = &cte.recursive_query {
345            validate_qail_limits(recursive, depth + 1, state)?;
346        }
347        if let Some(source_table) = &cte.source_table {
348            ensure_str("qail.cte.source_table", source_table)?;
349        }
350    }
351
352    ensure_len(
353        "qail.distinct_on",
354        cmd.distinct_on.len(),
355        MAX_AST_COLLECTION_LEN,
356    )?;
357    for expr in &cmd.distinct_on {
358        validate_expr_limits(expr, depth + 1, state)?;
359    }
360
361    if let Some(returning) = &cmd.returning {
362        ensure_len("qail.returning", returning.len(), MAX_AST_COLLECTION_LEN)?;
363        for expr in returning {
364            validate_expr_limits(expr, depth + 1, state)?;
365        }
366    }
367
368    if let Some(on_conflict) = &cmd.on_conflict {
369        ensure_len(
370            "qail.on_conflict.columns",
371            on_conflict.columns.len(),
372            MAX_AST_COLLECTION_LEN,
373        )?;
374        for col in &on_conflict.columns {
375            ensure_str("qail.on_conflict.column", col)?;
376        }
377        if let Some(assignments) = on_conflict.action.update_assignments() {
378            ensure_len(
379                "qail.on_conflict.assignments",
380                assignments.len(),
381                MAX_AST_COLLECTION_LEN,
382            )?;
383            for (col, expr) in assignments {
384                ensure_str("qail.on_conflict.assignment.column", col)?;
385                validate_expr_limits(expr, depth + 1, state)?;
386            }
387        }
388    }
389
390    if let Some(merge) = &cmd.merge {
391        if let Some(alias) = &merge.target_alias {
392            ensure_str("qail.merge.target_alias", alias)?;
393        }
394        match &merge.source {
395            crate::ast::MergeSource::Table { name, alias } => {
396                ensure_str("qail.merge.source.table", name)?;
397                if let Some(alias) = alias {
398                    ensure_str("qail.merge.source.alias", alias)?;
399                }
400            }
401            crate::ast::MergeSource::Query { query, alias } => {
402                validate_qail_limits(query, depth + 1, state)?;
403                if let Some(alias) = alias {
404                    ensure_str("qail.merge.source.alias", alias)?;
405                }
406            }
407        }
408        ensure_len("qail.merge.on", merge.on.len(), MAX_AST_COLLECTION_LEN)?;
409        for condition in &merge.on {
410            validate_condition_limits(condition, depth + 1, state)?;
411        }
412        ensure_len(
413            "qail.merge.clauses",
414            merge.clauses.len(),
415            MAX_AST_COLLECTION_LEN,
416        )?;
417        for clause in &merge.clauses {
418            ensure_len(
419                "qail.merge.clause.condition",
420                clause.condition.len(),
421                MAX_AST_COLLECTION_LEN,
422            )?;
423            for condition in &clause.condition {
424                validate_condition_limits(condition, depth + 1, state)?;
425            }
426            match &clause.action {
427                crate::ast::MergeAction::Update { assignments } => {
428                    ensure_len(
429                        "qail.merge.update.assignments",
430                        assignments.len(),
431                        MAX_AST_COLLECTION_LEN,
432                    )?;
433                    for (col, expr) in assignments {
434                        ensure_str("qail.merge.update.column", col)?;
435                        validate_expr_limits(expr, depth + 1, state)?;
436                    }
437                }
438                crate::ast::MergeAction::Insert { columns, values } => {
439                    ensure_len(
440                        "qail.merge.insert.columns",
441                        columns.len(),
442                        MAX_AST_COLLECTION_LEN,
443                    )?;
444                    for col in columns {
445                        ensure_str("qail.merge.insert.column", col)?;
446                    }
447                    ensure_len(
448                        "qail.merge.insert.values",
449                        values.len(),
450                        MAX_AST_COLLECTION_LEN,
451                    )?;
452                    for expr in values {
453                        validate_expr_limits(expr, depth + 1, state)?;
454                    }
455                }
456                crate::ast::MergeAction::Delete | crate::ast::MergeAction::DoNothing => {}
457            }
458        }
459    }
460
461    if let Some(source_query) = &cmd.source_query {
462        validate_qail_limits(source_query, depth + 1, state)?;
463    }
464
465    if let Some(channel) = &cmd.channel {
466        ensure_str("qail.channel", channel)?;
467    }
468    if let Some(payload) = &cmd.payload {
469        ensure_str("qail.payload", payload)?;
470    }
471    if let Some(savepoint_name) = &cmd.savepoint_name {
472        ensure_str("qail.savepoint_name", savepoint_name)?;
473    }
474
475    ensure_len(
476        "qail.from_tables",
477        cmd.from_tables.len(),
478        MAX_AST_COLLECTION_LEN,
479    )?;
480    for table in &cmd.from_tables {
481        ensure_str("qail.from_table", table)?;
482    }
483
484    ensure_len(
485        "qail.using_tables",
486        cmd.using_tables.len(),
487        MAX_AST_COLLECTION_LEN,
488    )?;
489    for table in &cmd.using_tables {
490        ensure_str("qail.using_table", table)?;
491    }
492
493    if let Some((_, percent, _seed)) = cmd.sample
494        && !percent.is_finite()
495    {
496        return Err("qail.sample.percent must be finite".to_string());
497    }
498
499    if let Some(vector) = &cmd.vector {
500        ensure_len("qail.vector", vector.len(), MAX_AST_VECTOR_LEN)?;
501    }
502    if let Some(vector_name) = &cmd.vector_name {
503        ensure_str("qail.vector_name", vector_name)?;
504    }
505    if let Some(function_def) = &cmd.function_def {
506        validate_function_def_limits(function_def)?;
507    }
508    if let Some(trigger_def) = &cmd.trigger_def {
509        validate_trigger_def_limits(trigger_def)?;
510    }
511    if let Some(policy_def) = &cmd.policy_def {
512        validate_policy_def_limits(policy_def, depth + 1, state)?;
513    }
514
515    Ok(())
516}
517
518fn validate_join_limits(
519    join: &crate::ast::Join,
520    depth: usize,
521    state: &mut AstLimitState,
522) -> Result<(), String> {
523    ensure_depth(depth, "Join")?;
524    state.bump("Join")?;
525    ensure_str("join.table", &join.table)?;
526    if let Some(on) = &join.on {
527        ensure_len("join.on", on.len(), MAX_AST_COLLECTION_LEN)?;
528        for cond in on {
529            validate_condition_limits(cond, depth + 1, state)?;
530        }
531    }
532    Ok(())
533}
534
535fn validate_cage_limits(
536    cage: &crate::ast::Cage,
537    depth: usize,
538    state: &mut AstLimitState,
539) -> Result<(), String> {
540    use crate::ast::CageKind;
541
542    ensure_depth(depth, "Cage")?;
543    state.bump("Cage")?;
544    ensure_len(
545        "cage.conditions",
546        cage.conditions.len(),
547        MAX_AST_COLLECTION_LEN,
548    )?;
549    for cond in &cage.conditions {
550        validate_condition_limits(cond, depth + 1, state)?;
551    }
552    match cage.kind {
553        CageKind::Limit(v) | CageKind::Offset(v) | CageKind::Sample(v) => {
554            ensure_len("cage.numeric", v, usize::MAX)?;
555        }
556        _ => {}
557    }
558    Ok(())
559}
560
561fn validate_condition_limits(
562    cond: &crate::ast::Condition,
563    depth: usize,
564    state: &mut AstLimitState,
565) -> Result<(), String> {
566    ensure_depth(depth, "Condition")?;
567    state.bump("Condition")?;
568    validate_expr_limits(&cond.left, depth + 1, state)?;
569    validate_value_limits(&cond.value, depth + 1, state)
570}
571
572fn validate_expr_limits(
573    expr: &crate::ast::Expr,
574    depth: usize,
575    state: &mut AstLimitState,
576) -> Result<(), String> {
577    use crate::ast::{ColumnGeneration, Constraint, Expr, WindowFrame};
578
579    ensure_depth(depth, "Expr")?;
580    state.bump("Expr")?;
581
582    match expr {
583        Expr::Star => {}
584        Expr::Named(name) => ensure_str("expr.named", name)?,
585        Expr::Aliased { name, alias } => {
586            ensure_str("expr.aliased.name", name)?;
587            ensure_str("expr.aliased.alias", alias)?;
588        }
589        Expr::Aggregate {
590            col, filter, alias, ..
591        } => {
592            ensure_str("expr.aggregate.col", col)?;
593            if let Some(filters) = filter {
594                ensure_len(
595                    "expr.aggregate.filter",
596                    filters.len(),
597                    MAX_AST_COLLECTION_LEN,
598                )?;
599                for cond in filters {
600                    validate_condition_limits(cond, depth + 1, state)?;
601                }
602            }
603            if let Some(alias) = alias {
604                ensure_str("expr.aggregate.alias", alias)?;
605            }
606        }
607        Expr::Cast {
608            expr,
609            target_type,
610            alias,
611        } => {
612            validate_expr_limits(expr, depth + 1, state)?;
613            ensure_str("expr.cast.target_type", target_type)?;
614            if let Some(alias) = alias {
615                ensure_str("expr.cast.alias", alias)?;
616            }
617        }
618        Expr::Def {
619            name,
620            data_type,
621            constraints,
622        } => {
623            ensure_str("expr.def.name", name)?;
624            ensure_str("expr.def.data_type", data_type)?;
625            ensure_len(
626                "expr.def.constraints",
627                constraints.len(),
628                MAX_AST_COLLECTION_LEN,
629            )?;
630            for constraint in constraints {
631                match constraint {
632                    Constraint::PrimaryKey | Constraint::Unique | Constraint::Nullable => {}
633                    Constraint::Default(v) => ensure_str("expr.def.default", v)?,
634                    Constraint::Check(values) => {
635                        ensure_len("expr.def.check", values.len(), MAX_AST_COLLECTION_LEN)?;
636                        for value in values {
637                            ensure_str("expr.def.check.value", value)?;
638                        }
639                    }
640                    Constraint::Comment(v) | Constraint::References(v) => {
641                        ensure_str("expr.def.constraint", v)?;
642                    }
643                    Constraint::Generated(ColumnGeneration::Stored(v))
644                    | Constraint::Generated(ColumnGeneration::Virtual(v)) => {
645                        ensure_str("expr.def.generated", v)?;
646                    }
647                }
648            }
649        }
650        Expr::Mod { col, .. } => validate_expr_limits(col, depth + 1, state)?,
651        Expr::Window {
652            name,
653            func,
654            params,
655            partition,
656            order,
657            frame,
658        } => {
659            ensure_str("expr.window.name", name)?;
660            ensure_str("expr.window.func", func)?;
661            ensure_len("expr.window.params", params.len(), MAX_AST_COLLECTION_LEN)?;
662            for param in params {
663                validate_expr_limits(param, depth + 1, state)?;
664            }
665            ensure_len(
666                "expr.window.partition",
667                partition.len(),
668                MAX_AST_COLLECTION_LEN,
669            )?;
670            for col in partition {
671                ensure_str("expr.window.partition.column", col)?;
672            }
673            ensure_len("expr.window.order", order.len(), MAX_AST_COLLECTION_LEN)?;
674            for cage in order {
675                validate_cage_limits(cage, depth + 1, state)?;
676            }
677            if let Some(frame) = frame {
678                match frame {
679                    WindowFrame::Rows { .. } | WindowFrame::Range { .. } => {}
680                }
681            }
682        }
683        Expr::Case {
684            when_clauses,
685            else_value,
686            alias,
687        } => {
688            ensure_len("expr.case.when", when_clauses.len(), MAX_AST_COLLECTION_LEN)?;
689            for (cond, then_expr) in when_clauses {
690                validate_condition_limits(cond, depth + 1, state)?;
691                validate_expr_limits(then_expr, depth + 1, state)?;
692            }
693            if let Some(else_expr) = else_value {
694                validate_expr_limits(else_expr, depth + 1, state)?;
695            }
696            if let Some(alias) = alias {
697                ensure_str("expr.case.alias", alias)?;
698            }
699        }
700        Expr::JsonAccess {
701            column,
702            path_segments,
703            alias,
704        } => {
705            ensure_str("expr.json_access.column", column)?;
706            ensure_len(
707                "expr.json_access.path_segments",
708                path_segments.len(),
709                MAX_AST_COLLECTION_LEN,
710            )?;
711            for (segment, _) in path_segments {
712                ensure_str("expr.json_access.segment", segment)?;
713            }
714            if let Some(alias) = alias {
715                ensure_str("expr.json_access.alias", alias)?;
716            }
717        }
718        Expr::FunctionCall { name, args, alias } => {
719            ensure_str("expr.function_call.name", name)?;
720            ensure_len(
721                "expr.function_call.args",
722                args.len(),
723                MAX_AST_COLLECTION_LEN,
724            )?;
725            for arg in args {
726                validate_expr_limits(arg, depth + 1, state)?;
727            }
728            if let Some(alias) = alias {
729                ensure_str("expr.function_call.alias", alias)?;
730            }
731        }
732        Expr::SpecialFunction { name, args, alias } => {
733            ensure_str("expr.special_function.name", name)?;
734            ensure_len(
735                "expr.special_function.args",
736                args.len(),
737                MAX_AST_COLLECTION_LEN,
738            )?;
739            for (keyword, arg) in args {
740                if let Some(keyword) = keyword {
741                    ensure_str("expr.special_function.keyword", keyword)?;
742                }
743                validate_expr_limits(arg, depth + 1, state)?;
744            }
745            if let Some(alias) = alias {
746                ensure_str("expr.special_function.alias", alias)?;
747            }
748        }
749        Expr::Binary {
750            left, right, alias, ..
751        } => {
752            validate_expr_limits(left, depth + 1, state)?;
753            validate_expr_limits(right, depth + 1, state)?;
754            if let Some(alias) = alias {
755                ensure_str("expr.binary.alias", alias)?;
756            }
757        }
758        Expr::Literal(v) => validate_value_limits(v, depth + 1, state)?,
759        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
760            ensure_len("expr.elements", elements.len(), MAX_AST_COLLECTION_LEN)?;
761            for el in elements {
762                validate_expr_limits(el, depth + 1, state)?;
763            }
764            if let Some(alias) = alias {
765                ensure_str("expr.elements.alias", alias)?;
766            }
767        }
768        Expr::Subscript { expr, index, alias } => {
769            validate_expr_limits(expr, depth + 1, state)?;
770            validate_expr_limits(index, depth + 1, state)?;
771            if let Some(alias) = alias {
772                ensure_str("expr.subscript.alias", alias)?;
773            }
774        }
775        Expr::Collate {
776            expr,
777            collation,
778            alias,
779        } => {
780            validate_expr_limits(expr, depth + 1, state)?;
781            ensure_str("expr.collate.collation", collation)?;
782            if let Some(alias) = alias {
783                ensure_str("expr.collate.alias", alias)?;
784            }
785        }
786        Expr::FieldAccess { expr, field, alias } => {
787            validate_expr_limits(expr, depth + 1, state)?;
788            ensure_str("expr.field_access.field", field)?;
789            if let Some(alias) = alias {
790                ensure_str("expr.field_access.alias", alias)?;
791            }
792        }
793        Expr::Subquery { query, alias } => {
794            validate_qail_limits(query, depth + 1, state)?;
795            if let Some(alias) = alias {
796                ensure_str("expr.subquery.alias", alias)?;
797            }
798        }
799        Expr::Exists { query, alias, .. } => {
800            validate_qail_limits(query, depth + 1, state)?;
801            if let Some(alias) = alias {
802                ensure_str("expr.exists.alias", alias)?;
803            }
804        }
805    }
806
807    Ok(())
808}
809
810fn validate_value_limits(
811    value: &crate::ast::Value,
812    depth: usize,
813    state: &mut AstLimitState,
814) -> Result<(), String> {
815    use crate::ast::Value;
816
817    ensure_depth(depth, "Value")?;
818    state.bump("Value")?;
819
820    match value {
821        Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::Param(_) => {}
822        Value::String(v)
823        | Value::NamedParam(v)
824        | Value::Function(v)
825        | Value::Column(v)
826        | Value::Timestamp(v)
827        | Value::Json(v) => ensure_str("value.string", v)?,
828        Value::Array(values) => {
829            ensure_len("value.array", values.len(), MAX_AST_COLLECTION_LEN)?;
830            for v in values {
831                validate_value_limits(v, depth + 1, state)?;
832            }
833        }
834        Value::Subquery(q) => validate_qail_limits(q, depth + 1, state)?,
835        Value::Uuid(_) | Value::NullUuid | Value::Interval { .. } => {}
836        Value::Bytes(bytes) => ensure_len("value.bytes", bytes.len(), MAX_AST_BINARY_VALUE_LEN)?,
837        Value::Expr(expr) => validate_expr_limits(expr, depth + 1, state)?,
838        Value::Vector(values) => ensure_len("value.vector", values.len(), MAX_AST_VECTOR_LEN)?,
839    }
840
841    Ok(())
842}
843
844fn validate_index_def_limits(index_def: &crate::ast::IndexDef) -> Result<(), String> {
845    ensure_str("index_def.name", &index_def.name)?;
846    ensure_str("index_def.table", &index_def.table)?;
847    ensure_len(
848        "index_def.columns",
849        index_def.columns.len(),
850        MAX_AST_COLLECTION_LEN,
851    )?;
852    for col in &index_def.columns {
853        ensure_str("index_def.column", col)?;
854    }
855    if let Some(index_type) = &index_def.index_type {
856        ensure_str("index_def.index_type", index_type)?;
857    }
858    if let Some(where_clause) = &index_def.where_clause {
859        ensure_str("index_def.where_clause", where_clause)?;
860    }
861    Ok(())
862}
863
864fn validate_function_def_limits(function_def: &crate::ast::FunctionDef) -> Result<(), String> {
865    ensure_str("function_def.name", &function_def.name)?;
866    ensure_len(
867        "function_def.args",
868        function_def.args.len(),
869        MAX_AST_COLLECTION_LEN,
870    )?;
871    for arg in &function_def.args {
872        ensure_str("function_def.arg", arg)?;
873    }
874    ensure_str("function_def.returns", &function_def.returns)?;
875    ensure_str("function_def.body", &function_def.body)?;
876    if let Some(language) = &function_def.language {
877        ensure_str("function_def.language", language)?;
878    }
879    if let Some(volatility) = &function_def.volatility {
880        ensure_str("function_def.volatility", volatility)?;
881    }
882    Ok(())
883}
884
885fn validate_trigger_def_limits(trigger_def: &crate::ast::TriggerDef) -> Result<(), String> {
886    ensure_str("trigger_def.name", &trigger_def.name)?;
887    ensure_str("trigger_def.table", &trigger_def.table)?;
888    ensure_len(
889        "trigger_def.events",
890        trigger_def.events.len(),
891        MAX_AST_COLLECTION_LEN,
892    )?;
893    ensure_len(
894        "trigger_def.update_columns",
895        trigger_def.update_columns.len(),
896        MAX_AST_COLLECTION_LEN,
897    )?;
898    for col in &trigger_def.update_columns {
899        ensure_str("trigger_def.update_column", col)?;
900    }
901    ensure_str(
902        "trigger_def.execute_function",
903        &trigger_def.execute_function,
904    )?;
905    Ok(())
906}
907
908fn validate_policy_def_limits(
909    policy_def: &crate::migrate::policy::RlsPolicy,
910    depth: usize,
911    state: &mut AstLimitState,
912) -> Result<(), String> {
913    ensure_str("policy_def.name", &policy_def.name)?;
914    ensure_str("policy_def.table", &policy_def.table)?;
915    if let Some(using_expr) = &policy_def.using {
916        validate_expr_limits(using_expr, depth + 1, state)?;
917    }
918    if let Some(with_check_expr) = &policy_def.with_check {
919        validate_expr_limits(with_check_expr, depth + 1, state)?;
920    }
921    if let Some(role) = &policy_def.role {
922        ensure_str("policy_def.role", role)?;
923    }
924    Ok(())
925}
926
/// Reads one `\n`-terminated header line starting at `*idx`.
///
/// On success the returned slice excludes the newline and `*idx` points just
/// past it.
///
/// # Errors
///
/// - `"unexpected EOF"` when `*idx` is already at or beyond the end of `bytes`
///   (`*idx` is left untouched).
/// - `"unterminated header line"` when no `\n` follows (`*idx` ends up at
///   `bytes.len()`).
/// - `"header is not UTF-8"` when the line bytes are not valid UTF-8 (`*idx`
///   is left on the newline, which is not consumed).
fn read_line<'a>(bytes: &'a [u8], idx: &mut usize) -> Result<&'a str, String> {
    let start = *idx;

    // Nothing left to scan: either the cursor is at the end or past it.
    let Some(rest) = bytes.get(start..).filter(|tail| !tail.is_empty()) else {
        return Err("unexpected EOF".to_string());
    };

    let Some(offset) = rest.iter().position(|&byte| byte == b'\n') else {
        // The scan ran off the end of the input without finding a newline.
        *idx = bytes.len();
        return Err("unterminated header line".to_string());
    };

    // Park the cursor on the newline until the line is known to be valid.
    let newline_at = start + offset;
    *idx = newline_at;

    let text = match std::str::from_utf8(&rest[..offset]) {
        Ok(text) => text,
        Err(_) => return Err("header is not UTF-8".to_string()),
    };

    *idx = newline_at + 1; // consume '\n'
    Ok(text)
}
946
/// Parses `line` as a `usize`.
///
/// On failure the error reads `invalid {field}: {line}`, where `field` names
/// the header value being parsed.
fn parse_usize(field: &str, line: &str) -> Result<usize, String> {
    match line.parse::<usize>() {
        Ok(number) => Ok(number),
        Err(_) => Err(format!("invalid {field}: {line}")),
    }
}
951
/// Reads exactly `len` bytes starting at `*idx` and returns them as UTF-8 text.
///
/// The cursor is advanced past the payload before UTF-8 validation, so it
/// moves even when validation fails.
///
/// # Errors
///
/// - `"payload length overflow"` when `*idx + len` overflows `usize`.
/// - `"payload truncated"` when the requested range extends past `bytes`.
/// - `"payload is not UTF-8"` when the bytes are not valid UTF-8.
fn read_exact_utf8<'a>(bytes: &'a [u8], idx: &mut usize, len: usize) -> Result<&'a str, String> {
    let start = *idx;

    let Some(end) = start.checked_add(len) else {
        return Err("payload length overflow".to_string());
    };

    // `get` yields `None` for any range reaching past the end of the input.
    let Some(chunk) = bytes.get(start..end) else {
        return Err("payload truncated".to_string());
    };

    *idx = end;
    match std::str::from_utf8(chunk) {
        Ok(text) => Ok(text),
        Err(_) => Err("payload is not UTF-8".to_string()),
    }
}
963
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Action, CageKind, Expr, Operator, Qail};
    use proptest::prelude::*;

    /// A single command survives a text encode/decode cycle (compared by its
    /// canonical string form).
    #[test]
    fn cmd_text_roundtrip() {
        let original = Qail::get("users")
            .columns(["id", "email"])
            .where_eq("active", true)
            .limit(10);

        let wire = encode_cmd_text(&original);
        let restored = decode_cmd_text(&wire).unwrap();

        assert_eq!(restored.to_string(), original.to_string());
    }

    /// A single command survives a binary encode/decode cycle (compared by
    /// its canonical string form).
    #[test]
    fn cmd_binary_roundtrip() {
        let original = Qail::set("users")
            .set_value("active", true)
            .where_eq("id", 7);

        let wire = encode_cmd_binary(&original).expect("binary encode");
        let restored = decode_cmd_binary(&wire).unwrap();

        assert_eq!(restored.to_string(), original.to_string());
    }

    /// A MERGE command decodes to an AST equal to the one that was encoded.
    #[test]
    fn cmd_binary_roundtrip_preserves_merge_ast() {
        let source_id = Expr::Named("s.id".to_string());
        let source_name = Expr::Named("s.name".to_string());

        let original = Qail::merge_into("users")
            .target_alias("u")
            .using_table_as("staging_users", "s")
            .merge_on_column("u.id", Operator::Eq, "s.id")
            .when_matched_update(&[("name", Expr::Named("s.name".to_string()))])
            .when_not_matched_insert(&["id", "name"], &[source_id, source_name]);

        let wire = encode_cmd_binary(&original).expect("binary encode");
        let restored = decode_cmd_binary(&wire).unwrap();

        assert_eq!(restored, original);
    }

    /// A header declaring one byte more than the payload cap is rejected.
    #[test]
    fn cmd_binary_payload_rejects_oversized_header() {
        let declared_len = (MAX_CMD_BINARY_PAYLOAD_BYTES + 1) as u32;
        // Header only: magic followed by the declared length, with no body bytes.
        let mut frame = [CMD_BIN_MAGIC, declared_len.to_be_bytes()].concat();
        frame.extend_from_slice(&[]);

        let failure = decode_cmd_binary_payload(&frame).unwrap_err();

        assert!(failure.contains("binary AST payload too large"));
    }

    /// The payload extracted from an encoded frame deserializes with
    /// `serde_json` back into the original command, with no trailing data.
    #[test]
    fn cmd_binary_payload_roundtrip() {
        let original = Qail::get("users").limit(3);
        let wire = encode_cmd_binary(&original).expect("binary encode");

        let body = decode_cmd_binary_payload(&wire).unwrap();
        let mut json = serde_json::Deserializer::from_slice(body);
        let restored: Qail = serde::Deserialize::deserialize(&mut json).unwrap();
        json.end().unwrap();

        assert_eq!(restored.to_string(), original.to_string());
    }

    /// Frames carrying the legacy `QWB1` magic are refused.
    #[test]
    fn cmd_binary_payload_rejects_legacy_qwb1() {
        let legacy_text = b"get users limit 1";
        let declared_len = (legacy_text.len() as u32).to_be_bytes();
        let frame = [
            &CMD_BIN_LEGACY_MAGIC[..],
            &declared_len[..],
            &legacy_text[..],
        ]
        .concat();

        let failure = decode_cmd_binary_payload(&frame).unwrap_err();

        assert!(failure.contains("legacy QWB1"));
    }

    /// Raw query text without a binary header is not accepted by the binary decoder.
    #[test]
    fn cmd_binary_decode_rejects_raw_text_without_qwb2_header() {
        let failure = decode_cmd_binary(b"get users limit 1").unwrap_err();

        assert!(failure.contains("invalid wire header"));
    }

    /// Extra bytes appended after a valid frame cause a length error.
    #[test]
    fn cmd_binary_decode_rejects_trailing_bytes() {
        let original = Qail::get("users").limit(1);
        let mut wire = encode_cmd_binary(&original).expect("binary encode");
        wire.extend_from_slice(&[0xAA, 0xBB]);

        let failure = decode_cmd_binary(&wire).unwrap_err();

        assert!(failure.contains("invalid payload length"));
    }

    /// A query nested deeper than `MAX_AST_DEPTH` fails to decode. The exact
    /// layer that rejects it is not pinned, so any of the three messages passes.
    #[test]
    fn cmd_binary_decode_enforces_depth_limits() {
        let innermost = Qail::get("users").limit(1);
        let nested = (0..(MAX_AST_DEPTH + 2)).fold(innermost, |inner, _| Qail {
            action: Action::Get,
            table: "users".to_string(),
            columns: vec![Expr::Subquery {
                query: Box::new(inner),
                alias: None,
            }],
            ..Qail::default()
        });

        let wire = encode_cmd_binary(&nested).expect("binary encode");
        let failure = decode_cmd_binary(&wire).unwrap_err();

        let accepted_messages = [
            "AST depth limit exceeded",
            "binary AST decode failed",
            "recursion limit exceeded",
        ];
        assert!(accepted_messages.iter().any(|msg| failure.contains(msg)));
    }

    /// Flipping each bit of the first 128 bytes of several seed inputs never
    /// panics the decoder; results are deliberately ignored.
    #[test]
    fn cmd_binary_decode_bitflip_corpus_no_panic() {
        let corpus = vec![
            encode_cmd_binary(&Qail::get("users").limit(1)).expect("binary encode"),
            encode_cmd_binary(&Qail::set("users").set_value("active", true))
                .expect("binary encode"),
            vec![],
            b"QWB2garbage".to_vec(),
            vec![0u8; 32],
        ];

        for seed in corpus {
            let scan_len = seed.len().min(128);
            for pos in 0..scan_len {
                for mask in (0..8u8).map(|bit| 1u8 << bit) {
                    let mut mutated = seed.clone();
                    mutated[pos] ^= mask;
                    let _ = decode_cmd_binary(&mutated);
                }
            }
            let _ = decode_cmd_binary(&seed);
        }
    }

    proptest! {
        /// Arbitrary byte strings of up to 4095 bytes never panic the decoder.
        #[test]
        fn cmd_binary_decode_fuzz_never_panics(data in proptest::collection::vec(any::<u8>(), 0..4096)) {
            let _ = decode_cmd_binary(&data);
        }
    }

    /// A batch of commands survives a text encode/decode cycle, preserving
    /// count, order, and each command's canonical string form.
    #[test]
    fn cmds_text_roundtrip() {
        let batch = vec![
            Qail::get("users").columns(["id", "email"]),
            Qail::get("users").limit(1),
            Qail::del("users").where_eq("id", 99),
        ];

        let wire = encode_cmds_text(&batch);
        let restored = decode_cmds_text(&wire).unwrap();

        assert_eq!(restored.len(), batch.len());
        for (got, want) in restored.iter().zip(batch.iter()) {
            assert_eq!(got.to_string(), want.to_string());
        }
    }

    /// Input without the text magic is parsed directly as a raw query.
    #[test]
    fn decode_cmd_text_falls_back_to_raw_qail() {
        let parsed = decode_cmd_text("get users limit 1").unwrap();

        assert_eq!(parsed.action, Action::Get);
        assert_eq!(parsed.table, "users");

        let has_limit_one = parsed
            .cages
            .iter()
            .any(|cage| matches!(cage.kind, CageKind::Limit(1)));
        assert!(has_limit_one);
    }

    /// A declared payload length of `usize::MAX` yields an error rather than a panic.
    #[test]
    fn cmd_text_rejects_overflowing_payload_length_without_panic() {
        let wire = format!("{CMD_TEXT_MAGIC}\n{}\n", usize::MAX);

        let failure = decode_cmd_text(&wire).expect_err("oversized length should fail closed");

        let overflowed = failure.contains("payload length overflow");
        let truncated = failure.contains("payload truncated");
        assert!(overflowed || truncated);
    }

    /// A batch header declaring more commands than the cap is rejected.
    #[test]
    fn cmds_text_rejects_oversized_command_count_without_allocation() {
        let wire = format!("{CMDS_TEXT_MAGIC}\n{}\n", MAX_CMD_TEXT_BATCH_COMMANDS + 1);

        let failure =
            decode_cmds_text(&wire).expect_err("oversized command count should fail closed");

        assert!(failure.contains("command count exceeds limit"));
    }
}