Skip to main content

qail_core/
wire.rs

//! QAIL wire codecs for command transport.
//!
//! - Text codecs (`QAIL-CMD/1`, `QAIL-CMDS/1`) round-trip through canonical text.
//! - Binary codec (`QWB2`) transports a length-prefixed, `serde_json`-serialized
//!   AST payload directly, without re-parsing QAIL text.
5
6use crate::ast::Qail;
7
/// Magic header line of a single-command text frame.
const CMD_TEXT_MAGIC: &str = "QAIL-CMD/1";
/// Magic header line of a multi-command text frame.
const CMDS_TEXT_MAGIC: &str = "QAIL-CMDS/1";
/// Four-byte magic prefix of the current binary frame.
const CMD_BIN_MAGIC: [u8; 4] = [b'Q', b'W', b'B', b'2'];
/// Four-byte magic prefix of the retired binary frame; it is recognized only
/// so the decoder can reject it with a specific message.
const CMD_BIN_LEGACY_MAGIC: [u8; 4] = [b'Q', b'W', b'B', b'1'];

/// Maximum allowed QWB2 payload size (bytes).
pub const MAX_CMD_BINARY_PAYLOAD_BYTES: usize = 65_536;
/// Maximum nesting depth accepted while walking an AST.
const MAX_AST_DEPTH: usize = 64;
/// Maximum number of AST nodes visited in one limit walk.
const MAX_AST_NODES: usize = 16 * 1024;
/// Maximum element count of any single AST collection field.
const MAX_AST_COLLECTION_LEN: usize = 2 * 1024;
/// Maximum number of commands in one `QAIL-CMDS/1` batch.
const MAX_CMD_TEXT_BATCH_COMMANDS: usize = MAX_AST_COLLECTION_LEN;
/// Maximum length, in UTF-8 bytes, of any AST string field.
const MAX_AST_STRING_LEN: usize = 32_768;
/// Maximum element count of a vector-valued field.
const MAX_AST_VECTOR_LEN: usize = 8 * 1024;
/// Maximum length, in bytes, of a binary value.
const MAX_AST_BINARY_VALUE_LEN: usize = 32_768;
22
23/// Encode one command into versioned text wire format.
24pub fn encode_cmd_text(cmd: &Qail) -> String {
25    let payload = cmd.to_string();
26    let mut out = String::with_capacity(CMD_TEXT_MAGIC.len() + payload.len() + 32);
27    out.push_str(CMD_TEXT_MAGIC);
28    out.push('\n');
29    out.push_str(&payload.len().to_string());
30    out.push('\n');
31    out.push_str(&payload);
32    out
33}
34
35/// Decode one command from text wire format.
36///
37/// Also accepts raw QAIL query text as fallback for convenience.
38pub fn decode_cmd_text(input: &str) -> Result<Qail, String> {
39    let bytes = input.as_bytes();
40    let mut idx = 0usize;
41
42    let Ok(magic) = read_line(bytes, &mut idx) else {
43        return crate::parse(input).map_err(|e| e.to_string());
44    };
45
46    if magic != CMD_TEXT_MAGIC {
47        return crate::parse(input).map_err(|e| e.to_string());
48    }
49
50    let len_line = read_line(bytes, &mut idx)?;
51    let payload_len = parse_usize("payload length", len_line)?;
52    let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
53    if idx != bytes.len() {
54        return Err("trailing bytes after command payload".to_string());
55    }
56
57    crate::parse(payload).map_err(|e| e.to_string())
58}
59
60/// Encode multiple commands into versioned text wire format.
61pub fn encode_cmds_text(cmds: &[Qail]) -> String {
62    let mut out = String::new();
63    out.push_str(CMDS_TEXT_MAGIC);
64    out.push('\n');
65    out.push_str(&cmds.len().to_string());
66    out.push('\n');
67
68    for cmd in cmds {
69        let payload = cmd.to_string();
70        out.push_str(&payload.len().to_string());
71        out.push('\n');
72        out.push_str(&payload);
73    }
74
75    out
76}
77
78/// Decode multiple commands from text wire format.
79pub fn decode_cmds_text(input: &str) -> Result<Vec<Qail>, String> {
80    let bytes = input.as_bytes();
81    let mut idx = 0usize;
82
83    let magic = read_line(bytes, &mut idx)?;
84    if magic != CMDS_TEXT_MAGIC {
85        return Err(format!(
86            "invalid wire magic: expected {CMDS_TEXT_MAGIC}, got {magic}"
87        ));
88    }
89
90    let count_line = read_line(bytes, &mut idx)?;
91    let count = parse_usize("command count", count_line)?;
92    if count > MAX_CMD_TEXT_BATCH_COMMANDS {
93        return Err(format!(
94            "command count exceeds limit: {count} > {MAX_CMD_TEXT_BATCH_COMMANDS}"
95        ));
96    }
97    let mut out = Vec::with_capacity(count);
98
99    for _ in 0..count {
100        let len_line = read_line(bytes, &mut idx)?;
101        let payload_len = parse_usize("payload length", len_line)?;
102        let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
103        let cmd = crate::parse(payload).map_err(|e| e.to_string())?;
104        out.push(cmd);
105    }
106
107    if idx != bytes.len() {
108        return Err("trailing bytes after batch payload".to_string());
109    }
110
111    Ok(out)
112}
113
114/// Encode one command into compact binary wire format (QWB2 AST binary).
115pub fn encode_cmd_binary(cmd: &Qail) -> Result<Vec<u8>, String> {
116    validate_binary_ast_limits(cmd)?;
117    crate::sanitize::validate_ast(cmd).map_err(|e| e.to_string())?;
118
119    let payload = serde_json::to_vec(cmd).map_err(|e| format!("binary AST encode failed: {e}"))?;
120    if payload.len() > MAX_CMD_BINARY_PAYLOAD_BYTES {
121        return Err(format!(
122            "binary AST payload too large: {} bytes (max {})",
123            payload.len(),
124            MAX_CMD_BINARY_PAYLOAD_BYTES
125        ));
126    }
127
128    let payload_len = u32::try_from(payload.len())
129        .map_err(|_| format!("binary AST payload exceeds u32 length: {}", payload.len()))?;
130    let mut out = Vec::with_capacity(8 + payload.len());
131    out.extend_from_slice(&CMD_BIN_MAGIC);
132    out.extend_from_slice(&payload_len.to_be_bytes());
133    out.extend_from_slice(&payload);
134    Ok(out)
135}
136
137/// Decode one command from strict QWB2 AST-binary wire format.
138///
139/// This path rejects legacy QWB1/raw-text payloads.
140pub fn decode_cmd_binary(input: &[u8]) -> Result<Qail, String> {
141    let payload = decode_cmd_binary_payload(input)?;
142    let mut deserializer = serde_json::Deserializer::from_slice(payload);
143    let cmd = serde::Deserialize::deserialize(&mut deserializer)
144        .map_err(|e| format!("binary AST decode failed: {e}"))?;
145    deserializer
146        .end()
147        .map_err(|_| "trailing bytes after AST payload".to_string())?;
148    validate_binary_ast_limits(&cmd)?;
149    crate::sanitize::validate_ast(&cmd).map_err(|e| e.to_string())?;
150    Ok(cmd)
151}
152
153/// Decode and validate strict QWB2-framed payload bytes.
154///
155/// This validates framing and payload-size limits only.
156pub fn decode_cmd_binary_payload(input: &[u8]) -> Result<&[u8], String> {
157    if input.len() < 8 {
158        return Err("invalid wire header".to_string());
159    }
160    if input[0..4] != CMD_BIN_MAGIC {
161        if input[0..4] == CMD_BIN_LEGACY_MAGIC {
162            return Err(
163                "legacy QWB1 text payload is not supported on parse-free binary path".to_string(),
164            );
165        }
166        return Err("invalid wire header".to_string());
167    }
168
169    let len = u32::from_be_bytes([input[4], input[5], input[6], input[7]]) as usize;
170    if len > MAX_CMD_BINARY_PAYLOAD_BYTES {
171        return Err(format!(
172            "binary AST payload too large: header={len}, max={MAX_CMD_BINARY_PAYLOAD_BYTES}"
173        ));
174    }
175    if input.len() != 8 + len {
176        return Err(format!(
177            "invalid payload length: header={len}, actual={}",
178            input.len().saturating_sub(8)
179        ));
180    }
181    Ok(&input[8..])
182}
183
184#[derive(Default)]
185struct AstLimitState {
186    nodes: usize,
187}
188
189impl AstLimitState {
190    fn bump(&mut self, kind: &str) -> Result<(), String> {
191        self.nodes = self
192            .nodes
193            .checked_add(1)
194            .ok_or_else(|| "AST node counter overflow".to_string())?;
195        if self.nodes > MAX_AST_NODES {
196            return Err(format!(
197                "AST node limit exceeded while walking {kind}: {} > {}",
198                self.nodes, MAX_AST_NODES
199            ));
200        }
201        Ok(())
202    }
203}
204
205fn ensure_depth(depth: usize, kind: &str) -> Result<(), String> {
206    if depth > MAX_AST_DEPTH {
207        return Err(format!(
208            "AST depth limit exceeded while walking {kind}: {depth} > {MAX_AST_DEPTH}"
209        ));
210    }
211    Ok(())
212}
213
/// Fail when `len` exceeds `max`, naming the offending field `kind` in the
/// error message. A length equal to `max` is accepted.
fn ensure_len(kind: &str, len: usize, max: usize) -> Result<(), String> {
    if len <= max {
        Ok(())
    } else {
        Err(format!("{kind} exceeds limit: {len} > {max}"))
    }
}
220
221fn ensure_str(kind: &str, value: &str) -> Result<(), String> {
222    ensure_len(kind, value.len(), MAX_AST_STRING_LEN)
223}
224
225fn validate_binary_ast_limits(cmd: &Qail) -> Result<(), String> {
226    let mut state = AstLimitState::default();
227    validate_qail_limits(cmd, 0, &mut state)
228}
229
/// Recursively enforce wire-decoding limits on one `Qail` command node.
///
/// The walk checks:
/// - nesting depth (`MAX_AST_DEPTH`);
/// - the node budget shared through `state` (`MAX_AST_NODES`);
/// - the length of every collection field (`MAX_AST_COLLECTION_LEN`, or
///   `MAX_AST_VECTOR_LEN` for `vector`);
/// - the byte length of every string field (`MAX_AST_STRING_LEN`).
///
/// Nested commands (set-op operands, CTE queries, merge query sources,
/// `source_query`) and nested expressions, conditions, joins, cages and the
/// policy definition are walked at `depth + 1`.
///
/// Fields are checked in the order written below, and the first violation is
/// returned as an error message naming the offending field.
fn validate_qail_limits(cmd: &Qail, depth: usize, state: &mut AstLimitState) -> Result<(), String> {
    use crate::ast::GroupByMode;

    ensure_depth(depth, "Qail")?;
    state.bump("Qail")?;

    ensure_str("qail.table", &cmd.table)?;
    ensure_len("qail.columns", cmd.columns.len(), MAX_AST_COLLECTION_LEN)?;
    for expr in &cmd.columns {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    ensure_len("qail.joins", cmd.joins.len(), MAX_AST_COLLECTION_LEN)?;
    for join in &cmd.joins {
        validate_join_limits(join, depth + 1, state)?;
    }

    ensure_len("qail.cages", cmd.cages.len(), MAX_AST_COLLECTION_LEN)?;
    for cage in &cmd.cages {
        validate_cage_limits(cage, depth + 1, state)?;
    }

    // The index-definition validator only checks strings and list lengths, so
    // it takes no depth or node budget.
    if let Some(index_def) = &cmd.index_def {
        validate_index_def_limits(index_def)?;
    }

    ensure_len(
        "qail.table_constraints",
        cmd.table_constraints.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for constraint in &cmd.table_constraints {
        match constraint {
            crate::ast::TableConstraint::Unique(cols)
            | crate::ast::TableConstraint::PrimaryKey(cols) => {
                ensure_len(
                    "qail.table_constraint.columns",
                    cols.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in cols {
                    ensure_str("qail.table_constraint.column", col)?;
                }
            }
            crate::ast::TableConstraint::ForeignKey {
                name,
                columns,
                ref_table,
                ref_columns,
                on_delete,
                on_update,
                deferrable,
            } => {
                if let Some(name) = name {
                    ensure_str("qail.table_constraint.name", name)?;
                }
                ensure_len(
                    "qail.table_constraint.columns",
                    columns.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in columns {
                    ensure_str("qail.table_constraint.column", col)?;
                }
                ensure_str("qail.table_constraint.ref_table", ref_table)?;
                ensure_len(
                    "qail.table_constraint.ref_columns",
                    ref_columns.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in ref_columns {
                    ensure_str("qail.table_constraint.ref_column", col)?;
                }
                if let Some(action) = on_delete {
                    ensure_str("qail.table_constraint.on_delete", action)?;
                }
                if let Some(action) = on_update {
                    ensure_str("qail.table_constraint.on_update", action)?;
                }
                if let Some(clause) = deferrable {
                    ensure_str("qail.table_constraint.deferrable", clause)?;
                }
            }
        }
    }

    // Each set operation pairs an operator (not inspected) with a nested command.
    ensure_len("qail.set_ops", cmd.set_ops.len(), MAX_AST_COLLECTION_LEN)?;
    for (_, rhs) in &cmd.set_ops {
        validate_qail_limits(rhs, depth + 1, state)?;
    }

    ensure_len("qail.having", cmd.having.len(), MAX_AST_COLLECTION_LEN)?;
    for cond in &cmd.having {
        validate_condition_limits(cond, depth + 1, state)?;
    }

    // Only the `GroupingSets` mode is inspected; it holds nested column lists.
    if let GroupByMode::GroupingSets(groups) = &cmd.group_by_mode {
        ensure_len("qail.grouping_sets", groups.len(), MAX_AST_COLLECTION_LEN)?;
        for group in groups {
            ensure_len("qail.grouping_set", group.len(), MAX_AST_COLLECTION_LEN)?;
            for col in group {
                ensure_str("qail.grouping_set.column", col)?;
            }
        }
    }

    ensure_len("qail.ctes", cmd.ctes.len(), MAX_AST_COLLECTION_LEN)?;
    for cte in &cmd.ctes {
        ensure_str("qail.cte.name", &cte.name)?;
        ensure_len(
            "qail.cte.columns",
            cte.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &cte.columns {
            ensure_str("qail.cte.column", col)?;
        }
        validate_qail_limits(&cte.base_query, depth + 1, state)?;
        if let Some(recursive) = &cte.recursive_query {
            validate_qail_limits(recursive, depth + 1, state)?;
        }
        if let Some(source_table) = &cte.source_table {
            ensure_str("qail.cte.source_table", source_table)?;
        }
    }

    ensure_len(
        "qail.distinct_on",
        cmd.distinct_on.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for expr in &cmd.distinct_on {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    if let Some(returning) = &cmd.returning {
        ensure_len("qail.returning", returning.len(), MAX_AST_COLLECTION_LEN)?;
        for expr in returning {
            validate_expr_limits(expr, depth + 1, state)?;
        }
    }

    if let Some(on_conflict) = &cmd.on_conflict {
        ensure_len(
            "qail.on_conflict.columns",
            on_conflict.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &on_conflict.columns {
            ensure_str("qail.on_conflict.column", col)?;
        }
        // Assignments are only walked when the action exposes them via
        // `update_assignments()`; presumably only update-style actions do.
        if let Some(assignments) = on_conflict.action.update_assignments() {
            ensure_len(
                "qail.on_conflict.assignments",
                assignments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for (col, expr) in assignments {
                ensure_str("qail.on_conflict.assignment.column", col)?;
                validate_expr_limits(expr, depth + 1, state)?;
            }
        }
    }

    if let Some(merge) = &cmd.merge {
        if let Some(alias) = &merge.target_alias {
            ensure_str("qail.merge.target_alias", alias)?;
        }
        // A merge source is either a named table or a nested command.
        match &merge.source {
            crate::ast::MergeSource::Table { name, alias } => {
                ensure_str("qail.merge.source.table", name)?;
                if let Some(alias) = alias {
                    ensure_str("qail.merge.source.alias", alias)?;
                }
            }
            crate::ast::MergeSource::Query { query, alias } => {
                validate_qail_limits(query, depth + 1, state)?;
                if let Some(alias) = alias {
                    ensure_str("qail.merge.source.alias", alias)?;
                }
            }
        }
        ensure_len("qail.merge.on", merge.on.len(), MAX_AST_COLLECTION_LEN)?;
        for condition in &merge.on {
            validate_condition_limits(condition, depth + 1, state)?;
        }
        ensure_len(
            "qail.merge.clauses",
            merge.clauses.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for clause in &merge.clauses {
            ensure_len(
                "qail.merge.clause.condition",
                clause.condition.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for condition in &clause.condition {
                validate_condition_limits(condition, depth + 1, state)?;
            }
            match &clause.action {
                crate::ast::MergeAction::Update { assignments } => {
                    ensure_len(
                        "qail.merge.update.assignments",
                        assignments.len(),
                        MAX_AST_COLLECTION_LEN,
                    )?;
                    for (col, expr) in assignments {
                        ensure_str("qail.merge.update.column", col)?;
                        validate_expr_limits(expr, depth + 1, state)?;
                    }
                }
                crate::ast::MergeAction::Insert { columns, values } => {
                    ensure_len(
                        "qail.merge.insert.columns",
                        columns.len(),
                        MAX_AST_COLLECTION_LEN,
                    )?;
                    for col in columns {
                        ensure_str("qail.merge.insert.column", col)?;
                    }
                    ensure_len(
                        "qail.merge.insert.values",
                        values.len(),
                        MAX_AST_COLLECTION_LEN,
                    )?;
                    for expr in values {
                        validate_expr_limits(expr, depth + 1, state)?;
                    }
                }
                // These actions carry no payload to check.
                crate::ast::MergeAction::Delete | crate::ast::MergeAction::DoNothing => {}
            }
        }
    }

    if let Some(source_query) = &cmd.source_query {
        validate_qail_limits(source_query, depth + 1, state)?;
    }

    if let Some(channel) = &cmd.channel {
        ensure_str("qail.channel", channel)?;
    }
    if let Some(payload) = &cmd.payload {
        ensure_str("qail.payload", payload)?;
    }
    if let Some(savepoint_name) = &cmd.savepoint_name {
        ensure_str("qail.savepoint_name", savepoint_name)?;
    }

    ensure_len(
        "qail.from_tables",
        cmd.from_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.from_tables {
        ensure_str("qail.from_table", table)?;
    }

    ensure_len(
        "qail.using_tables",
        cmd.using_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.using_tables {
        ensure_str("qail.using_table", table)?;
    }

    // Only finiteness of the sample percentage is checked here, with no range
    // check. NOTE(review): presumably this is because non-finite floats do not
    // survive the JSON payload encoding — confirm.
    if let Some((_, percent, _seed)) = cmd.sample
        && !percent.is_finite()
    {
        return Err("qail.sample.percent must be finite".to_string());
    }

    // Vector length has its own limit (`MAX_AST_VECTOR_LEN`); element values
    // are not inspected.
    if let Some(vector) = &cmd.vector {
        ensure_len("qail.vector", vector.len(), MAX_AST_VECTOR_LEN)?;
    }
    if let Some(vector_name) = &cmd.vector_name {
        ensure_str("qail.vector_name", vector_name)?;
    }
    if let Some(function_def) = &cmd.function_def {
        validate_function_def_limits(function_def)?;
    }
    if let Some(trigger_def) = &cmd.trigger_def {
        validate_trigger_def_limits(trigger_def)?;
    }
    // The policy may embed expressions, so it shares the depth/node budget.
    if let Some(policy_def) = &cmd.policy_def {
        validate_policy_def_limits(policy_def, depth + 1, state)?;
    }

    Ok(())
}
521
522fn validate_join_limits(
523    join: &crate::ast::Join,
524    depth: usize,
525    state: &mut AstLimitState,
526) -> Result<(), String> {
527    ensure_depth(depth, "Join")?;
528    state.bump("Join")?;
529    ensure_str("join.table", &join.table)?;
530    if let Some(on) = &join.on {
531        ensure_len("join.on", on.len(), MAX_AST_COLLECTION_LEN)?;
532        for cond in on {
533            validate_condition_limits(cond, depth + 1, state)?;
534        }
535    }
536    Ok(())
537}
538
539fn validate_cage_limits(
540    cage: &crate::ast::Cage,
541    depth: usize,
542    state: &mut AstLimitState,
543) -> Result<(), String> {
544    use crate::ast::CageKind;
545
546    ensure_depth(depth, "Cage")?;
547    state.bump("Cage")?;
548    ensure_len(
549        "cage.conditions",
550        cage.conditions.len(),
551        MAX_AST_COLLECTION_LEN,
552    )?;
553    for cond in &cage.conditions {
554        validate_condition_limits(cond, depth + 1, state)?;
555    }
556    match cage.kind {
557        CageKind::Limit(v) | CageKind::Offset(v) | CageKind::Sample(v) => {
558            ensure_len("cage.numeric", v, usize::MAX)?;
559        }
560        _ => {}
561    }
562    Ok(())
563}
564
565fn validate_condition_limits(
566    cond: &crate::ast::Condition,
567    depth: usize,
568    state: &mut AstLimitState,
569) -> Result<(), String> {
570    ensure_depth(depth, "Condition")?;
571    state.bump("Condition")?;
572    validate_expr_limits(&cond.left, depth + 1, state)?;
573    validate_value_limits(&cond.value, depth + 1, state)
574}
575
/// Recursively enforce wire-decoding limits on one expression node.
///
/// Every string field is capped at `MAX_AST_STRING_LEN` bytes and every list
/// at `MAX_AST_COLLECTION_LEN` entries. Child expressions, conditions, cages,
/// values and nested commands are walked at `depth + 1` and share the node
/// budget in `state`. The match is exhaustive, so a new `Expr` variant fails
/// to compile here until it is handled.
fn validate_expr_limits(
    expr: &crate::ast::Expr,
    depth: usize,
    state: &mut AstLimitState,
) -> Result<(), String> {
    use crate::ast::{ColumnGeneration, Constraint, Expr, WindowFrame};

    ensure_depth(depth, "Expr")?;
    state.bump("Expr")?;

    match expr {
        Expr::Star => {}
        Expr::Named(name) => ensure_str("expr.named", name)?,
        Expr::Aliased { name, alias } => {
            ensure_str("expr.aliased.name", name)?;
            ensure_str("expr.aliased.alias", alias)?;
        }
        Expr::Aggregate {
            col, filter, alias, ..
        } => {
            ensure_str("expr.aggregate.col", col)?;
            // Optional FILTER conditions attached to the aggregate.
            if let Some(filters) = filter {
                ensure_len(
                    "expr.aggregate.filter",
                    filters.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for cond in filters {
                    validate_condition_limits(cond, depth + 1, state)?;
                }
            }
            if let Some(alias) = alias {
                ensure_str("expr.aggregate.alias", alias)?;
            }
        }
        Expr::Cast {
            expr,
            target_type,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.cast.target_type", target_type)?;
            if let Some(alias) = alias {
                ensure_str("expr.cast.alias", alias)?;
            }
        }
        // Column definition: name, type, and each constraint's string payload.
        Expr::Def {
            name,
            data_type,
            constraints,
        } => {
            ensure_str("expr.def.name", name)?;
            ensure_str("expr.def.data_type", data_type)?;
            ensure_len(
                "expr.def.constraints",
                constraints.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for constraint in constraints {
                match constraint {
                    Constraint::PrimaryKey | Constraint::Unique | Constraint::Nullable => {}
                    Constraint::Default(v) => ensure_str("expr.def.default", v)?,
                    Constraint::Check(values) => {
                        ensure_len("expr.def.check", values.len(), MAX_AST_COLLECTION_LEN)?;
                        for value in values {
                            ensure_str("expr.def.check.value", value)?;
                        }
                    }
                    Constraint::Comment(v) | Constraint::References(v) => {
                        ensure_str("expr.def.constraint", v)?;
                    }
                    Constraint::Generated(ColumnGeneration::Stored(v))
                    | Constraint::Generated(ColumnGeneration::Virtual(v)) => {
                        ensure_str("expr.def.generated", v)?;
                    }
                }
            }
        }
        Expr::Mod { col, .. } => validate_expr_limits(col, depth + 1, state)?,
        Expr::Window {
            name,
            func,
            params,
            partition,
            order,
            frame,
        } => {
            ensure_str("expr.window.name", name)?;
            ensure_str("expr.window.func", func)?;
            ensure_len("expr.window.params", params.len(), MAX_AST_COLLECTION_LEN)?;
            for param in params {
                validate_expr_limits(param, depth + 1, state)?;
            }
            ensure_len(
                "expr.window.partition",
                partition.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for col in partition {
                ensure_str("expr.window.partition.column", col)?;
            }
            // Window ordering is expressed as cages and validated as such.
            ensure_len("expr.window.order", order.len(), MAX_AST_COLLECTION_LEN)?;
            for cage in order {
                validate_cage_limits(cage, depth + 1, state)?;
            }
            // The frame contributes nothing to check. NOTE(review): presumably
            // the exhaustive match exists so that adding a frame variant with
            // strings or sub-trees forces an update here — confirm intent.
            if let Some(frame) = frame {
                match frame {
                    WindowFrame::Rows { .. } | WindowFrame::Range { .. } => {}
                }
            }
        }
        Expr::Case {
            when_clauses,
            else_value,
            alias,
        } => {
            ensure_len("expr.case.when", when_clauses.len(), MAX_AST_COLLECTION_LEN)?;
            for (cond, then_expr) in when_clauses {
                validate_condition_limits(cond, depth + 1, state)?;
                validate_expr_limits(then_expr, depth + 1, state)?;
            }
            if let Some(else_expr) = else_value {
                validate_expr_limits(else_expr, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.case.alias", alias)?;
            }
        }
        Expr::JsonAccess {
            column,
            path_segments,
            alias,
        } => {
            ensure_str("expr.json_access.column", column)?;
            ensure_len(
                "expr.json_access.path_segments",
                path_segments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            // Only the segment text is checked; the paired second field is ignored.
            for (segment, _) in path_segments {
                ensure_str("expr.json_access.segment", segment)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.json_access.alias", alias)?;
            }
        }
        Expr::FunctionCall { name, args, alias } => {
            ensure_str("expr.function_call.name", name)?;
            ensure_len(
                "expr.function_call.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for arg in args {
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.function_call.alias", alias)?;
            }
        }
        // Like `FunctionCall`, but each argument may carry an optional keyword.
        Expr::SpecialFunction { name, args, alias } => {
            ensure_str("expr.special_function.name", name)?;
            ensure_len(
                "expr.special_function.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for (keyword, arg) in args {
                if let Some(keyword) = keyword {
                    ensure_str("expr.special_function.keyword", keyword)?;
                }
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.special_function.alias", alias)?;
            }
        }
        Expr::Binary {
            left, right, alias, ..
        } => {
            validate_expr_limits(left, depth + 1, state)?;
            validate_expr_limits(right, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.binary.alias", alias)?;
            }
        }
        Expr::Literal(v) => validate_value_limits(v, depth + 1, state)?,
        // Array and row constructors share one shape and one set of limits.
        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
            ensure_len("expr.elements", elements.len(), MAX_AST_COLLECTION_LEN)?;
            for el in elements {
                validate_expr_limits(el, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.elements.alias", alias)?;
            }
        }
        Expr::Subscript { expr, index, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            validate_expr_limits(index, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subscript.alias", alias)?;
            }
        }
        Expr::Collate {
            expr,
            collation,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.collate.collation", collation)?;
            if let Some(alias) = alias {
                ensure_str("expr.collate.alias", alias)?;
            }
        }
        Expr::FieldAccess { expr, field, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.field_access.field", field)?;
            if let Some(alias) = alias {
                ensure_str("expr.field_access.alias", alias)?;
            }
        }
        // Nested commands re-enter the command walker and share the budget.
        Expr::Subquery { query, alias } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subquery.alias", alias)?;
            }
        }
        Expr::Exists { query, alias, .. } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.exists.alias", alias)?;
            }
        }
    }

    Ok(())
}
813
814fn validate_value_limits(
815    value: &crate::ast::Value,
816    depth: usize,
817    state: &mut AstLimitState,
818) -> Result<(), String> {
819    use crate::ast::Value;
820
821    ensure_depth(depth, "Value")?;
822    state.bump("Value")?;
823
824    match value {
825        Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::Param(_) => {}
826        Value::String(v)
827        | Value::NamedParam(v)
828        | Value::Function(v)
829        | Value::Column(v)
830        | Value::Timestamp(v)
831        | Value::Json(v) => ensure_str("value.string", v)?,
832        Value::Array(values) => {
833            ensure_len("value.array", values.len(), MAX_AST_COLLECTION_LEN)?;
834            for v in values {
835                validate_value_limits(v, depth + 1, state)?;
836            }
837        }
838        Value::Subquery(q) => validate_qail_limits(q, depth + 1, state)?,
839        Value::Uuid(_) | Value::NullUuid | Value::Interval { .. } => {}
840        Value::Bytes(bytes) => ensure_len("value.bytes", bytes.len(), MAX_AST_BINARY_VALUE_LEN)?,
841        Value::Expr(expr) => validate_expr_limits(expr, depth + 1, state)?,
842        Value::Vector(values) => ensure_len("value.vector", values.len(), MAX_AST_VECTOR_LEN)?,
843    }
844
845    Ok(())
846}
847
848fn validate_index_def_limits(index_def: &crate::ast::IndexDef) -> Result<(), String> {
849    ensure_str("index_def.name", &index_def.name)?;
850    ensure_str("index_def.table", &index_def.table)?;
851    ensure_len(
852        "index_def.columns",
853        index_def.columns.len(),
854        MAX_AST_COLLECTION_LEN,
855    )?;
856    for col in &index_def.columns {
857        ensure_str("index_def.column", col)?;
858    }
859    if let Some(index_type) = &index_def.index_type {
860        ensure_str("index_def.index_type", index_type)?;
861    }
862    if let Some(where_clause) = &index_def.where_clause {
863        ensure_str("index_def.where_clause", where_clause)?;
864    }
865    Ok(())
866}
867
868fn validate_function_def_limits(function_def: &crate::ast::FunctionDef) -> Result<(), String> {
869    ensure_str("function_def.name", &function_def.name)?;
870    ensure_len(
871        "function_def.args",
872        function_def.args.len(),
873        MAX_AST_COLLECTION_LEN,
874    )?;
875    for arg in &function_def.args {
876        ensure_str("function_def.arg", arg)?;
877    }
878    ensure_str("function_def.returns", &function_def.returns)?;
879    ensure_str("function_def.body", &function_def.body)?;
880    if let Some(language) = &function_def.language {
881        ensure_str("function_def.language", language)?;
882    }
883    if let Some(volatility) = &function_def.volatility {
884        ensure_str("function_def.volatility", volatility)?;
885    }
886    Ok(())
887}
888
889fn validate_trigger_def_limits(trigger_def: &crate::ast::TriggerDef) -> Result<(), String> {
890    ensure_str("trigger_def.name", &trigger_def.name)?;
891    ensure_str("trigger_def.table", &trigger_def.table)?;
892    ensure_len(
893        "trigger_def.events",
894        trigger_def.events.len(),
895        MAX_AST_COLLECTION_LEN,
896    )?;
897    ensure_len(
898        "trigger_def.update_columns",
899        trigger_def.update_columns.len(),
900        MAX_AST_COLLECTION_LEN,
901    )?;
902    for col in &trigger_def.update_columns {
903        ensure_str("trigger_def.update_column", col)?;
904    }
905    ensure_str(
906        "trigger_def.execute_function",
907        &trigger_def.execute_function,
908    )?;
909    Ok(())
910}
911
912fn validate_policy_def_limits(
913    policy_def: &crate::migrate::policy::RlsPolicy,
914    depth: usize,
915    state: &mut AstLimitState,
916) -> Result<(), String> {
917    ensure_str("policy_def.name", &policy_def.name)?;
918    ensure_str("policy_def.table", &policy_def.table)?;
919    if let Some(using_expr) = &policy_def.using {
920        validate_expr_limits(using_expr, depth + 1, state)?;
921    }
922    if let Some(with_check_expr) = &policy_def.with_check {
923        validate_expr_limits(with_check_expr, depth + 1, state)?;
924    }
925    if let Some(role) = &policy_def.role {
926        ensure_str("policy_def.role", role)?;
927    }
928    Ok(())
929}
930
/// Reads one `\n`-terminated header line starting at `*idx`.
///
/// On success, returns the line without its terminator and moves `*idx`
/// just past the `\n`.
///
/// # Errors
///
/// - `"unexpected EOF"` when `*idx` is already at or past the end of `bytes`
///   (`*idx` is left unchanged).
/// - `"unterminated header line"` when no `\n` follows (`*idx` is moved to
///   `bytes.len()`).
/// - `"header is not UTF-8"` when the line bytes are not valid UTF-8 (`*idx`
///   is left pointing at the `\n`).
fn read_line<'a>(bytes: &'a [u8], idx: &mut usize) -> Result<&'a str, String> {
    let start = *idx;
    let rest = match bytes.get(start..) {
        Some(rest) if !rest.is_empty() => rest,
        _ => return Err("unexpected EOF".to_string()),
    };

    let Some(offset) = rest.iter().position(|&b| b == b'\n') else {
        *idx = bytes.len();
        return Err("unterminated header line".to_string());
    };

    let newline_at = start + offset;
    *idx = newline_at;
    let line = std::str::from_utf8(&rest[..offset])
        .map_err(|_| "header is not UTF-8".to_string())?;
    *idx = newline_at + 1;
    Ok(line)
}
950
/// Parses a decimal length/count header line into a `usize`.
///
/// Only the canonical form emitted by the encoders is accepted: one or more
/// ASCII digits with no sign or whitespace. `str::parse::<usize>` alone would
/// also accept a leading `+`, giving the same frame more than one valid
/// encoding. Leading zeros are still tolerated.
///
/// # Errors
///
/// Returns `"invalid {field}: {line}"` when `line` is empty, contains a
/// non-digit, or overflows `usize`. The echoed `line` is escaped with
/// `char::escape_debug` and truncated, because header lines are untrusted
/// input and may be arbitrarily long or contain control characters such as
/// a stray `\r`.
fn parse_usize(field: &str, line: &str) -> Result<usize, String> {
    // Cap on how many characters of the offending line end up in the error.
    const MAX_ECHO_CHARS: usize = 64;

    let is_canonical = !line.is_empty() && line.bytes().all(|b| b.is_ascii_digit());
    let parsed = if is_canonical {
        line.parse::<usize>().ok()
    } else {
        None
    };

    parsed.ok_or_else(|| {
        let mut shown: String = line
            .chars()
            .take(MAX_ECHO_CHARS)
            .flat_map(char::escape_debug)
            .collect();
        if line.chars().nth(MAX_ECHO_CHARS).is_some() {
            shown.push_str("...");
        }
        format!("invalid {field}: {shown}")
    })
}
955
/// Reads exactly `len` bytes starting at `*idx` and returns them as UTF-8 text.
///
/// `*idx` is advanced past the span as soon as the bounds check passes,
/// before UTF-8 validation runs.
///
/// # Errors
///
/// - `"payload length overflow"` when `*idx + len` overflows `usize`.
/// - `"payload truncated"` when fewer than `len` bytes remain.
/// - `"payload is not UTF-8"` when the span is not valid UTF-8.
///
/// `*idx` is unchanged on the first two errors.
fn read_exact_utf8<'a>(bytes: &'a [u8], idx: &mut usize, len: usize) -> Result<&'a str, String> {
    let start = *idx;
    let Some(end) = start.checked_add(len) else {
        return Err("payload length overflow".to_string());
    };
    let span = bytes
        .get(start..end)
        .ok_or_else(|| "payload truncated".to_string())?;
    *idx = end;
    std::str::from_utf8(span).map_err(|_| "payload is not UTF-8".to_string())
}
967
968#[cfg(test)]
969mod tests {
970    use super::*;
971    use proptest::prelude::*;
972
973    fn encode_cmd_binary_unchecked_for_test(cmd: &Qail) -> Vec<u8> {
974        let payload = serde_json::to_vec(cmd).expect("test binary AST encode");
975        let mut out = Vec::with_capacity(8 + payload.len());
976        out.extend_from_slice(&CMD_BIN_MAGIC);
977        out.extend_from_slice(&(payload.len() as u32).to_be_bytes());
978        out.extend_from_slice(&payload);
979        out
980    }
981
982    #[test]
983    fn cmd_text_roundtrip() {
984        let cmd = crate::ast::Qail::get("users")
985            .columns(["id", "email"])
986            .where_eq("active", true)
987            .limit(10);
988
989        let encoded = encode_cmd_text(&cmd);
990        let decoded = decode_cmd_text(&encoded).unwrap();
991        assert_eq!(decoded.to_string(), cmd.to_string());
992    }
993
994    #[test]
995    fn cmd_binary_roundtrip() {
996        let cmd = crate::ast::Qail::set("users")
997            .set_value("active", true)
998            .where_eq("id", 7);
999
1000        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
1001        let decoded = decode_cmd_binary(&encoded).unwrap();
1002        assert_eq!(decoded.to_string(), cmd.to_string());
1003    }
1004
1005    #[test]
1006    fn cmd_binary_roundtrip_preserves_merge_ast() {
1007        let cmd = crate::ast::Qail::merge_into("users")
1008            .target_alias("u")
1009            .using_table_as("staging_users", "s")
1010            .merge_on_column("u.id", crate::ast::Operator::Eq, "s.id")
1011            .when_matched_update(&[("name", crate::ast::Expr::Named("s.name".to_string()))])
1012            .when_not_matched_insert(
1013                &["id", "name"],
1014                &[
1015                    crate::ast::Expr::Named("s.id".to_string()),
1016                    crate::ast::Expr::Named("s.name".to_string()),
1017                ],
1018            );
1019
1020        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
1021        let decoded = decode_cmd_binary(&encoded).unwrap();
1022        assert_eq!(decoded, cmd);
1023    }
1024
1025    #[test]
1026    fn cmd_binary_payload_rejects_oversized_header() {
1027        let mut payload = Vec::new();
1028        payload.extend_from_slice(&CMD_BIN_MAGIC);
1029        payload.extend_from_slice(&((MAX_CMD_BINARY_PAYLOAD_BYTES + 1) as u32).to_be_bytes());
1030        payload.extend_from_slice(&[]);
1031
1032        let err = decode_cmd_binary_payload(&payload).unwrap_err();
1033        assert!(err.contains("binary AST payload too large"));
1034    }
1035
1036    #[test]
1037    fn cmd_binary_payload_roundtrip() {
1038        let cmd = crate::ast::Qail::get("users").limit(3);
1039        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
1040        let payload = decode_cmd_binary_payload(&encoded).unwrap();
1041        let mut deserializer = serde_json::Deserializer::from_slice(payload);
1042        let decoded: crate::ast::Qail = serde::Deserialize::deserialize(&mut deserializer).unwrap();
1043        deserializer.end().unwrap();
1044        assert_eq!(decoded.to_string(), cmd.to_string());
1045    }
1046
1047    #[test]
1048    fn cmd_binary_payload_rejects_legacy_qwb1() {
1049        let legacy_text = b"get users limit 1";
1050        let mut payload = Vec::new();
1051        payload.extend_from_slice(&CMD_BIN_LEGACY_MAGIC);
1052        payload.extend_from_slice(&(legacy_text.len() as u32).to_be_bytes());
1053        payload.extend_from_slice(legacy_text);
1054
1055        let err = decode_cmd_binary_payload(&payload).unwrap_err();
1056        assert!(err.contains("legacy QWB1"));
1057    }
1058
1059    #[test]
1060    fn cmd_binary_decode_rejects_raw_text_without_qwb2_header() {
1061        let err = decode_cmd_binary(b"get users limit 1").unwrap_err();
1062        assert!(err.contains("invalid wire header"));
1063    }
1064
1065    #[test]
1066    fn cmd_binary_decode_rejects_trailing_bytes() {
1067        let cmd = crate::ast::Qail::get("users").limit(1);
1068        let mut encoded = encode_cmd_binary(&cmd).expect("binary encode");
1069        encoded.extend_from_slice(&[0xAA, 0xBB]);
1070        let err = decode_cmd_binary(&encoded).unwrap_err();
1071        assert!(err.contains("invalid payload length"));
1072    }
1073
1074    #[test]
1075    fn cmd_binary_decode_rejects_unsafe_identifiers() {
1076        let cmd = crate::ast::Qail::get("users; DROP TABLE users; --").limit(1);
1077        let encoded = encode_cmd_binary_unchecked_for_test(&cmd);
1078
1079        let err = decode_cmd_binary(&encoded).unwrap_err();
1080
1081        assert!(err.contains("AST validation failed"));
1082        assert!(err.contains("table"));
1083    }
1084
1085    #[test]
1086    fn cmd_binary_encode_rejects_unsafe_identifiers() {
1087        let cmd = crate::ast::Qail::get("users; DROP TABLE users; --").limit(1);
1088
1089        let err = encode_cmd_binary(&cmd).unwrap_err();
1090
1091        assert!(err.contains("AST validation failed"));
1092        assert!(err.contains("table"));
1093    }
1094
1095    #[test]
1096    fn cmd_binary_decode_rejects_procedural_actions() {
1097        let cmd = crate::ast::Qail::call("refresh_materialized_views()");
1098        let encoded = encode_cmd_binary_unchecked_for_test(&cmd);
1099
1100        let err = decode_cmd_binary(&encoded).unwrap_err();
1101
1102        assert!(err.contains("AST validation failed"));
1103        assert!(err.contains("procedural/session actions"));
1104    }
1105
1106    #[test]
1107    fn cmd_binary_encode_rejects_procedural_actions() {
1108        let cmd = crate::ast::Qail::call("refresh_materialized_views()");
1109
1110        let err = encode_cmd_binary(&cmd).unwrap_err();
1111
1112        assert!(err.contains("AST validation failed"));
1113        assert!(err.contains("procedural/session actions"));
1114    }
1115
1116    #[test]
1117    fn cmd_binary_decode_rejects_unsafe_raw_function_values() {
1118        let cmd = crate::ast::Qail::get("users").filter(
1119            "updated_at",
1120            crate::ast::Operator::Lt,
1121            crate::ast::Value::Function("NOW(); DROP TABLE users; --".to_string()),
1122        );
1123        let encoded = encode_cmd_binary_unchecked_for_test(&cmd);
1124
1125        let err = decode_cmd_binary(&encoded).unwrap_err();
1126
1127        assert!(err.contains("AST validation failed"));
1128        assert!(err.contains("raw function values"));
1129    }
1130
1131    #[test]
1132    fn cmd_binary_encode_rejects_unsafe_raw_function_values() {
1133        let cmd = crate::ast::Qail::get("users").filter(
1134            "updated_at",
1135            crate::ast::Operator::Lt,
1136            crate::ast::Value::Function("NOW(); DROP TABLE users; --".to_string()),
1137        );
1138
1139        let err = encode_cmd_binary(&cmd).unwrap_err();
1140
1141        assert!(err.contains("AST validation failed"));
1142        assert!(err.contains("raw function values"));
1143    }
1144
1145    #[test]
1146    fn cmd_binary_encode_rejects_unsafe_check_constraint_payload() {
1147        let cmd = crate::ast::Qail {
1148            action: crate::ast::Action::AlterAddConstraint,
1149            table: "users".to_string(),
1150            channel: Some("users_active_check".to_string()),
1151            payload: Some("active); DROP TABLE users; --".to_string()),
1152            ..Default::default()
1153        };
1154
1155        let err = encode_cmd_binary(&cmd).unwrap_err();
1156
1157        assert!(err.contains("AST validation failed"));
1158        assert!(err.contains("SQL expression fragments"));
1159    }
1160
1161    #[test]
1162    fn cmd_binary_decode_enforces_depth_limits() {
1163        let mut nested = crate::ast::Qail::get("users").limit(1);
1164        for _ in 0..(MAX_AST_DEPTH + 2) {
1165            nested = crate::ast::Qail {
1166                action: crate::ast::Action::Get,
1167                table: "users".to_string(),
1168                columns: vec![crate::ast::Expr::Subquery {
1169                    query: Box::new(nested),
1170                    alias: None,
1171                }],
1172                ..crate::ast::Qail::default()
1173            };
1174        }
1175
1176        let encoded = encode_cmd_binary_unchecked_for_test(&nested);
1177        let err = decode_cmd_binary(&encoded).unwrap_err();
1178        assert!(
1179            err.contains("AST depth limit exceeded")
1180                || err.contains("binary AST decode failed")
1181                || err.contains("recursion limit exceeded")
1182        );
1183    }
1184
1185    #[test]
1186    fn cmd_binary_decode_bitflip_corpus_no_panic() {
1187        let seeds = vec![
1188            encode_cmd_binary(&crate::ast::Qail::get("users").limit(1)).expect("binary encode"),
1189            encode_cmd_binary(&crate::ast::Qail::set("users").set_value("active", true))
1190                .expect("binary encode"),
1191            vec![],
1192            b"QWB2garbage".to_vec(),
1193            vec![0u8; 32],
1194        ];
1195
1196        for seed in seeds {
1197            for i in 0..seed.len().min(128) {
1198                for bit in 0..8u8 {
1199                    let mut mutated = seed.clone();
1200                    mutated[i] ^= 1 << bit;
1201                    let _ = decode_cmd_binary(&mutated);
1202                }
1203            }
1204            let _ = decode_cmd_binary(&seed);
1205        }
1206    }
1207
1208    proptest! {
1209        #[test]
1210        fn cmd_binary_decode_fuzz_never_panics(data in proptest::collection::vec(any::<u8>(), 0..4096)) {
1211            let _ = decode_cmd_binary(&data);
1212        }
1213    }
1214
1215    #[test]
1216    fn cmds_text_roundtrip() {
1217        let cmds = vec![
1218            crate::ast::Qail::get("users").columns(["id", "email"]),
1219            crate::ast::Qail::get("users").limit(1),
1220            crate::ast::Qail::del("users").where_eq("id", 99),
1221        ];
1222
1223        let encoded = encode_cmds_text(&cmds);
1224        let decoded = decode_cmds_text(&encoded).unwrap();
1225        assert_eq!(decoded.len(), cmds.len());
1226        for (lhs, rhs) in decoded.iter().zip(cmds.iter()) {
1227            assert_eq!(lhs.to_string(), rhs.to_string());
1228        }
1229    }
1230
1231    #[test]
1232    fn decode_cmd_text_falls_back_to_raw_qail() {
1233        let decoded = decode_cmd_text("get users limit 1").unwrap();
1234        assert_eq!(decoded.action, crate::ast::Action::Get);
1235        assert_eq!(decoded.table, "users");
1236        assert!(
1237            decoded
1238                .cages
1239                .iter()
1240                .any(|c| matches!(c.kind, crate::ast::CageKind::Limit(1)))
1241        );
1242    }
1243
1244    #[test]
1245    fn cmd_text_rejects_overflowing_payload_length_without_panic() {
1246        let input = format!("{CMD_TEXT_MAGIC}\n{}\n", usize::MAX);
1247
1248        let err = decode_cmd_text(&input).expect_err("oversized length should fail closed");
1249        assert!(err.contains("payload length overflow") || err.contains("payload truncated"));
1250    }
1251
1252    #[test]
1253    fn cmds_text_rejects_oversized_command_count_without_allocation() {
1254        let input = format!("{CMDS_TEXT_MAGIC}\n{}\n", MAX_CMD_TEXT_BATCH_COMMANDS + 1);
1255
1256        let err = decode_cmds_text(&input).expect_err("oversized command count should fail closed");
1257        assert!(err.contains("command count exceeds limit"));
1258    }
1259}