Skip to main content

qail_core/
wire.rs

//! QAIL wire codecs for command transport.
//!
//! - Text codecs (`QAIL-CMD/1`, `QAIL-CMDS/1`) round-trip through canonical text.
//! - Binary codec (`QWB2`) transports a length-prefixed, `serde_json`-serialized
//!   AST, so decoding does not go through the QAIL text parser.
5
6use crate::ast::Qail;
7
/// Header line that opens a single-command text frame.
const CMD_TEXT_MAGIC: &str = "QAIL-CMD/1";
/// Header line that opens a multi-command (batch) text frame.
const CMDS_TEXT_MAGIC: &str = "QAIL-CMDS/1";
/// Four-byte magic that opens every current binary frame.
const CMD_BIN_MAGIC: [u8; 4] = [b'Q', b'W', b'B', b'2'];
/// Magic of the retired QWB1 framing; only recognised to report a clearer error.
const CMD_BIN_LEGACY_MAGIC: [u8; 4] = [b'Q', b'W', b'B', b'1'];

/// Maximum allowed QWB2 payload size (bytes, 64 KiB).
pub const MAX_CMD_BINARY_PAYLOAD_BYTES: usize = 65_536;
/// Deepest nesting level accepted while walking a decoded AST.
const MAX_AST_DEPTH: usize = 64;
/// Total number of AST nodes accepted in one decoded command.
const MAX_AST_NODES: usize = 16_384;
/// Maximum element count of any AST collection (columns, joins, args, ...).
const MAX_AST_COLLECTION_LEN: usize = 2_048;
/// Maximum byte length of any string field in the AST (32 KiB).
const MAX_AST_STRING_LEN: usize = 32_768;
/// Maximum element count of vector-valued fields.
const MAX_AST_VECTOR_LEN: usize = 8_192;
/// Maximum byte length of a binary (`Bytes`) value (32 KiB).
const MAX_AST_BINARY_VALUE_LEN: usize = 32_768;
21
22/// Encode one command into versioned text wire format.
23pub fn encode_cmd_text(cmd: &Qail) -> String {
24    let payload = cmd.to_string();
25    let mut out = String::with_capacity(CMD_TEXT_MAGIC.len() + payload.len() + 32);
26    out.push_str(CMD_TEXT_MAGIC);
27    out.push('\n');
28    out.push_str(&payload.len().to_string());
29    out.push('\n');
30    out.push_str(&payload);
31    out
32}
33
34/// Decode one command from text wire format.
35///
36/// Also accepts raw QAIL query text as fallback for convenience.
37pub fn decode_cmd_text(input: &str) -> Result<Qail, String> {
38    let bytes = input.as_bytes();
39    let mut idx = 0usize;
40
41    let Ok(magic) = read_line(bytes, &mut idx) else {
42        return crate::parse(input).map_err(|e| e.to_string());
43    };
44
45    if magic != CMD_TEXT_MAGIC {
46        return crate::parse(input).map_err(|e| e.to_string());
47    }
48
49    let len_line = read_line(bytes, &mut idx)?;
50    let payload_len = parse_usize("payload length", len_line)?;
51    let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
52    if idx != bytes.len() {
53        return Err("trailing bytes after command payload".to_string());
54    }
55
56    crate::parse(payload).map_err(|e| e.to_string())
57}
58
59/// Encode multiple commands into versioned text wire format.
60pub fn encode_cmds_text(cmds: &[Qail]) -> String {
61    let mut out = String::new();
62    out.push_str(CMDS_TEXT_MAGIC);
63    out.push('\n');
64    out.push_str(&cmds.len().to_string());
65    out.push('\n');
66
67    for cmd in cmds {
68        let payload = cmd.to_string();
69        out.push_str(&payload.len().to_string());
70        out.push('\n');
71        out.push_str(&payload);
72    }
73
74    out
75}
76
77/// Decode multiple commands from text wire format.
78pub fn decode_cmds_text(input: &str) -> Result<Vec<Qail>, String> {
79    let bytes = input.as_bytes();
80    let mut idx = 0usize;
81
82    let magic = read_line(bytes, &mut idx)?;
83    if magic != CMDS_TEXT_MAGIC {
84        return Err(format!(
85            "invalid wire magic: expected {CMDS_TEXT_MAGIC}, got {magic}"
86        ));
87    }
88
89    let count_line = read_line(bytes, &mut idx)?;
90    let count = parse_usize("command count", count_line)?;
91    let mut out = Vec::with_capacity(count);
92
93    for _ in 0..count {
94        let len_line = read_line(bytes, &mut idx)?;
95        let payload_len = parse_usize("payload length", len_line)?;
96        let payload = read_exact_utf8(bytes, &mut idx, payload_len)?;
97        let cmd = crate::parse(payload).map_err(|e| e.to_string())?;
98        out.push(cmd);
99    }
100
101    if idx != bytes.len() {
102        return Err("trailing bytes after batch payload".to_string());
103    }
104
105    Ok(out)
106}
107
108/// Encode one command into compact binary wire format (QWB2 AST binary).
109pub fn encode_cmd_binary(cmd: &Qail) -> Result<Vec<u8>, String> {
110    let payload = serde_json::to_vec(cmd).map_err(|e| format!("binary AST encode failed: {e}"))?;
111    if payload.len() > MAX_CMD_BINARY_PAYLOAD_BYTES {
112        return Err(format!(
113            "binary AST payload too large: {} bytes (max {})",
114            payload.len(),
115            MAX_CMD_BINARY_PAYLOAD_BYTES
116        ));
117    }
118
119    let payload_len = u32::try_from(payload.len())
120        .map_err(|_| format!("binary AST payload exceeds u32 length: {}", payload.len()))?;
121    let mut out = Vec::with_capacity(8 + payload.len());
122    out.extend_from_slice(&CMD_BIN_MAGIC);
123    out.extend_from_slice(&payload_len.to_be_bytes());
124    out.extend_from_slice(&payload);
125    Ok(out)
126}
127
128/// Decode one command from strict QWB2 AST-binary wire format.
129///
130/// This path rejects legacy QWB1/raw-text payloads.
131pub fn decode_cmd_binary(input: &[u8]) -> Result<Qail, String> {
132    let payload = decode_cmd_binary_payload(input)?;
133    let mut deserializer = serde_json::Deserializer::from_slice(payload);
134    let cmd = serde::Deserialize::deserialize(&mut deserializer)
135        .map_err(|e| format!("binary AST decode failed: {e}"))?;
136    deserializer
137        .end()
138        .map_err(|_| "trailing bytes after AST payload".to_string())?;
139    validate_binary_ast_limits(&cmd)?;
140    Ok(cmd)
141}
142
143/// Decode and validate strict QWB2-framed payload bytes.
144///
145/// This validates framing and payload-size limits only.
146pub fn decode_cmd_binary_payload(input: &[u8]) -> Result<&[u8], String> {
147    if input.len() < 8 {
148        return Err("invalid wire header".to_string());
149    }
150    if input[0..4] != CMD_BIN_MAGIC {
151        if input[0..4] == CMD_BIN_LEGACY_MAGIC {
152            return Err(
153                "legacy QWB1 text payload is not supported on parse-free binary path".to_string(),
154            );
155        }
156        return Err("invalid wire header".to_string());
157    }
158
159    let len = u32::from_be_bytes([input[4], input[5], input[6], input[7]]) as usize;
160    if len > MAX_CMD_BINARY_PAYLOAD_BYTES {
161        return Err(format!(
162            "binary AST payload too large: header={len}, max={MAX_CMD_BINARY_PAYLOAD_BYTES}"
163        ));
164    }
165    if input.len() != 8 + len {
166        return Err(format!(
167            "invalid payload length: header={len}, actual={}",
168            input.len().saturating_sub(8)
169        ));
170    }
171    Ok(&input[8..])
172}
173
174#[derive(Default)]
175struct AstLimitState {
176    nodes: usize,
177}
178
179impl AstLimitState {
180    fn bump(&mut self, kind: &str) -> Result<(), String> {
181        self.nodes = self
182            .nodes
183            .checked_add(1)
184            .ok_or_else(|| "AST node counter overflow".to_string())?;
185        if self.nodes > MAX_AST_NODES {
186            return Err(format!(
187                "AST node limit exceeded while walking {kind}: {} > {}",
188                self.nodes, MAX_AST_NODES
189            ));
190        }
191        Ok(())
192    }
193}
194
195fn ensure_depth(depth: usize, kind: &str) -> Result<(), String> {
196    if depth > MAX_AST_DEPTH {
197        return Err(format!(
198            "AST depth limit exceeded while walking {kind}: {depth} > {MAX_AST_DEPTH}"
199        ));
200    }
201    Ok(())
202}
203
/// Reject a length `len` greater than `max`, naming the field `kind`.
fn ensure_len(kind: &str, len: usize, max: usize) -> Result<(), String> {
    if len <= max {
        Ok(())
    } else {
        Err(format!("{kind} exceeds limit: {len} > {max}"))
    }
}
209}
210
211fn ensure_str(kind: &str, value: &str) -> Result<(), String> {
212    ensure_len(kind, value.len(), MAX_AST_STRING_LEN)
213}
214
215fn validate_binary_ast_limits(cmd: &Qail) -> Result<(), String> {
216    let mut state = AstLimitState::default();
217    validate_qail_limits(cmd, 0, &mut state)
218}
219
/// Recursively check one `Qail` command against the AST limits.
///
/// The command itself counts as one node at `depth`. Its nested expressions,
/// joins, cages and conditions are walked at `depth + 1` and share the node
/// budget in `state`. So are its sub-commands: set-op right-hand sides, CTE
/// queries, `source_query`, and policy expressions. Checks run in the order
/// written below, and the first violation is returned.
///
/// # Errors
/// Returns a message in any of these cases:
/// - `MAX_AST_DEPTH` or `MAX_AST_NODES` is exceeded;
/// - a collection, string or vector exceeds its length limit;
/// - `sample` carries a non-finite percentage.
fn validate_qail_limits(cmd: &Qail, depth: usize, state: &mut AstLimitState) -> Result<(), String> {
    use crate::ast::GroupByMode;

    // Depth is checked before the node is counted against the budget.
    ensure_depth(depth, "Qail")?;
    state.bump("Qail")?;

    ensure_str("qail.table", &cmd.table)?;
    ensure_len("qail.columns", cmd.columns.len(), MAX_AST_COLLECTION_LEN)?;
    for expr in &cmd.columns {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    ensure_len("qail.joins", cmd.joins.len(), MAX_AST_COLLECTION_LEN)?;
    for join in &cmd.joins {
        validate_join_limits(join, depth + 1, state)?;
    }

    ensure_len("qail.cages", cmd.cages.len(), MAX_AST_COLLECTION_LEN)?;
    for cage in &cmd.cages {
        validate_cage_limits(cage, depth + 1, state)?;
    }

    // The index-definition validator takes no `state`, so it only bounds sizes
    // and does not consume the node budget.
    if let Some(index_def) = &cmd.index_def {
        validate_index_def_limits(index_def)?;
    }

    ensure_len(
        "qail.table_constraints",
        cmd.table_constraints.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for constraint in &cmd.table_constraints {
        match constraint {
            crate::ast::TableConstraint::Unique(cols)
            | crate::ast::TableConstraint::PrimaryKey(cols) => {
                ensure_len(
                    "qail.table_constraint.columns",
                    cols.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for col in cols {
                    ensure_str("qail.table_constraint.column", col)?;
                }
            }
        }
    }

    // Each set-op entry carries a right-hand command that is walked recursively.
    ensure_len("qail.set_ops", cmd.set_ops.len(), MAX_AST_COLLECTION_LEN)?;
    for (_, rhs) in &cmd.set_ops {
        validate_qail_limits(rhs, depth + 1, state)?;
    }

    ensure_len("qail.having", cmd.having.len(), MAX_AST_COLLECTION_LEN)?;
    for cond in &cmd.having {
        validate_condition_limits(cond, depth + 1, state)?;
    }

    // Only the grouping-sets mode carries nested column lists to bound.
    if let GroupByMode::GroupingSets(groups) = &cmd.group_by_mode {
        ensure_len("qail.grouping_sets", groups.len(), MAX_AST_COLLECTION_LEN)?;
        for group in groups {
            ensure_len("qail.grouping_set", group.len(), MAX_AST_COLLECTION_LEN)?;
            for col in group {
                ensure_str("qail.grouping_set.column", col)?;
            }
        }
    }

    // Each CTE contributes its base query and, if present, its recursive query
    // as full sub-commands.
    ensure_len("qail.ctes", cmd.ctes.len(), MAX_AST_COLLECTION_LEN)?;
    for cte in &cmd.ctes {
        ensure_str("qail.cte.name", &cte.name)?;
        ensure_len(
            "qail.cte.columns",
            cte.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &cte.columns {
            ensure_str("qail.cte.column", col)?;
        }
        validate_qail_limits(&cte.base_query, depth + 1, state)?;
        if let Some(recursive) = &cte.recursive_query {
            validate_qail_limits(recursive, depth + 1, state)?;
        }
        if let Some(source_table) = &cte.source_table {
            ensure_str("qail.cte.source_table", source_table)?;
        }
    }

    ensure_len(
        "qail.distinct_on",
        cmd.distinct_on.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for expr in &cmd.distinct_on {
        validate_expr_limits(expr, depth + 1, state)?;
    }

    if let Some(returning) = &cmd.returning {
        ensure_len("qail.returning", returning.len(), MAX_AST_COLLECTION_LEN)?;
        for expr in returning {
            validate_expr_limits(expr, depth + 1, state)?;
        }
    }

    // Conflict target columns are always bounded. Assignment expressions are
    // walked only when the action exposes update assignments.
    if let Some(on_conflict) = &cmd.on_conflict {
        ensure_len(
            "qail.on_conflict.columns",
            on_conflict.columns.len(),
            MAX_AST_COLLECTION_LEN,
        )?;
        for col in &on_conflict.columns {
            ensure_str("qail.on_conflict.column", col)?;
        }
        if let Some(assignments) = on_conflict.action.update_assignments() {
            ensure_len(
                "qail.on_conflict.assignments",
                assignments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for (col, expr) in assignments {
                ensure_str("qail.on_conflict.assignment.column", col)?;
                validate_expr_limits(expr, depth + 1, state)?;
            }
        }
    }

    if let Some(source_query) = &cmd.source_query {
        validate_qail_limits(source_query, depth + 1, state)?;
    }

    if let Some(channel) = &cmd.channel {
        ensure_str("qail.channel", channel)?;
    }
    if let Some(payload) = &cmd.payload {
        ensure_str("qail.payload", payload)?;
    }
    if let Some(savepoint_name) = &cmd.savepoint_name {
        ensure_str("qail.savepoint_name", savepoint_name)?;
    }

    ensure_len(
        "qail.from_tables",
        cmd.from_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.from_tables {
        ensure_str("qail.from_table", table)?;
    }

    ensure_len(
        "qail.using_tables",
        cmd.using_tables.len(),
        MAX_AST_COLLECTION_LEN,
    )?;
    for table in &cmd.using_tables {
        ensure_str("qail.using_table", table)?;
    }

    // Only the sample percentage is checked, and only for finiteness; the
    // other tuple fields are not validated here.
    if let Some((_, percent, _seed)) = cmd.sample
        && !percent.is_finite()
    {
        return Err("qail.sample.percent must be finite".to_string());
    }

    // Only the vector's length is bounded; its element values are not inspected.
    if let Some(vector) = &cmd.vector {
        ensure_len("qail.vector", vector.len(), MAX_AST_VECTOR_LEN)?;
    }
    if let Some(vector_name) = &cmd.vector_name {
        ensure_str("qail.vector_name", vector_name)?;
    }
    // Function and trigger definitions are size-bounded only (no node counting).
    if let Some(function_def) = &cmd.function_def {
        validate_function_def_limits(function_def)?;
    }
    if let Some(trigger_def) = &cmd.trigger_def {
        validate_trigger_def_limits(trigger_def)?;
    }
    // Policy expressions are walked with the shared node budget.
    if let Some(policy_def) = &cmd.policy_def {
        validate_policy_def_limits(policy_def, depth + 1, state)?;
    }

    Ok(())
}
401
402fn validate_join_limits(
403    join: &crate::ast::Join,
404    depth: usize,
405    state: &mut AstLimitState,
406) -> Result<(), String> {
407    ensure_depth(depth, "Join")?;
408    state.bump("Join")?;
409    ensure_str("join.table", &join.table)?;
410    if let Some(on) = &join.on {
411        ensure_len("join.on", on.len(), MAX_AST_COLLECTION_LEN)?;
412        for cond in on {
413            validate_condition_limits(cond, depth + 1, state)?;
414        }
415    }
416    Ok(())
417}
418
419fn validate_cage_limits(
420    cage: &crate::ast::Cage,
421    depth: usize,
422    state: &mut AstLimitState,
423) -> Result<(), String> {
424    use crate::ast::CageKind;
425
426    ensure_depth(depth, "Cage")?;
427    state.bump("Cage")?;
428    ensure_len(
429        "cage.conditions",
430        cage.conditions.len(),
431        MAX_AST_COLLECTION_LEN,
432    )?;
433    for cond in &cage.conditions {
434        validate_condition_limits(cond, depth + 1, state)?;
435    }
436    match cage.kind {
437        CageKind::Limit(v) | CageKind::Offset(v) | CageKind::Sample(v) => {
438            ensure_len("cage.numeric", v, usize::MAX)?;
439        }
440        _ => {}
441    }
442    Ok(())
443}
444
445fn validate_condition_limits(
446    cond: &crate::ast::Condition,
447    depth: usize,
448    state: &mut AstLimitState,
449) -> Result<(), String> {
450    ensure_depth(depth, "Condition")?;
451    state.bump("Condition")?;
452    validate_expr_limits(&cond.left, depth + 1, state)?;
453    validate_value_limits(&cond.value, depth + 1, state)
454}
455
/// Recursively check one expression against the AST limits.
///
/// The expression counts as one node at `depth`. Nested expressions,
/// conditions, cages, values and subqueries are walked at `depth + 1` with the
/// shared node budget. String fields are bounded by `MAX_AST_STRING_LEN` and
/// lists by `MAX_AST_COLLECTION_LEN`. The first violation is returned.
fn validate_expr_limits(
    expr: &crate::ast::Expr,
    depth: usize,
    state: &mut AstLimitState,
) -> Result<(), String> {
    use crate::ast::{ColumnGeneration, Constraint, Expr, WindowFrame};

    ensure_depth(depth, "Expr")?;
    state.bump("Expr")?;

    match expr {
        Expr::Star => {}
        Expr::Named(name) => ensure_str("expr.named", name)?,
        Expr::Aliased { name, alias } => {
            ensure_str("expr.aliased.name", name)?;
            ensure_str("expr.aliased.alias", alias)?;
        }
        Expr::Aggregate {
            col, filter, alias, ..
        } => {
            ensure_str("expr.aggregate.col", col)?;
            // Optional FILTER conditions are walked like any other condition.
            if let Some(filters) = filter {
                ensure_len(
                    "expr.aggregate.filter",
                    filters.len(),
                    MAX_AST_COLLECTION_LEN,
                )?;
                for cond in filters {
                    validate_condition_limits(cond, depth + 1, state)?;
                }
            }
            if let Some(alias) = alias {
                ensure_str("expr.aggregate.alias", alias)?;
            }
        }
        Expr::Cast {
            expr,
            target_type,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.cast.target_type", target_type)?;
            if let Some(alias) = alias {
                ensure_str("expr.cast.alias", alias)?;
            }
        }
        Expr::Def {
            name,
            data_type,
            constraints,
        } => {
            ensure_str("expr.def.name", name)?;
            ensure_str("expr.def.data_type", data_type)?;
            ensure_len(
                "expr.def.constraints",
                constraints.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            // Column constraints hold only strings or nothing; they are
            // size-bounded but not counted as nodes.
            for constraint in constraints {
                match constraint {
                    Constraint::PrimaryKey | Constraint::Unique | Constraint::Nullable => {}
                    Constraint::Default(v) => ensure_str("expr.def.default", v)?,
                    Constraint::Check(values) => {
                        ensure_len("expr.def.check", values.len(), MAX_AST_COLLECTION_LEN)?;
                        for value in values {
                            ensure_str("expr.def.check.value", value)?;
                        }
                    }
                    Constraint::Comment(v) | Constraint::References(v) => {
                        ensure_str("expr.def.constraint", v)?;
                    }
                    Constraint::Generated(ColumnGeneration::Stored(v))
                    | Constraint::Generated(ColumnGeneration::Virtual(v)) => {
                        ensure_str("expr.def.generated", v)?;
                    }
                }
            }
        }
        Expr::Mod { col, .. } => validate_expr_limits(col, depth + 1, state)?,
        Expr::Window {
            name,
            func,
            params,
            partition,
            order,
            frame,
        } => {
            ensure_str("expr.window.name", name)?;
            ensure_str("expr.window.func", func)?;
            ensure_len("expr.window.params", params.len(), MAX_AST_COLLECTION_LEN)?;
            for param in params {
                validate_expr_limits(param, depth + 1, state)?;
            }
            ensure_len(
                "expr.window.partition",
                partition.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for col in partition {
                ensure_str("expr.window.partition.column", col)?;
            }
            // Window ordering is expressed as cages.
            ensure_len("expr.window.order", order.len(), MAX_AST_COLLECTION_LEN)?;
            for cage in order {
                validate_cage_limits(cage, depth + 1, state)?;
            }
            // No checks on the frame. NOTE(review): presumably this exhaustive
            // match is kept so a new `WindowFrame` variant forces this
            // validator to be revisited; confirm intent.
            if let Some(frame) = frame {
                match frame {
                    WindowFrame::Rows { .. } | WindowFrame::Range { .. } => {}
                }
            }
        }
        Expr::Case {
            when_clauses,
            else_value,
            alias,
        } => {
            ensure_len("expr.case.when", when_clauses.len(), MAX_AST_COLLECTION_LEN)?;
            for (cond, then_expr) in when_clauses {
                validate_condition_limits(cond, depth + 1, state)?;
                validate_expr_limits(then_expr, depth + 1, state)?;
            }
            if let Some(else_expr) = else_value {
                validate_expr_limits(else_expr, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.case.alias", alias)?;
            }
        }
        Expr::JsonAccess {
            column,
            path_segments,
            alias,
        } => {
            ensure_str("expr.json_access.column", column)?;
            ensure_len(
                "expr.json_access.path_segments",
                path_segments.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            // Only the string part of each path segment is bounded.
            for (segment, _) in path_segments {
                ensure_str("expr.json_access.segment", segment)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.json_access.alias", alias)?;
            }
        }
        Expr::FunctionCall { name, args, alias } => {
            ensure_str("expr.function_call.name", name)?;
            ensure_len(
                "expr.function_call.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            for arg in args {
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.function_call.alias", alias)?;
            }
        }
        Expr::SpecialFunction { name, args, alias } => {
            ensure_str("expr.special_function.name", name)?;
            ensure_len(
                "expr.special_function.args",
                args.len(),
                MAX_AST_COLLECTION_LEN,
            )?;
            // Each argument may be preceded by an optional keyword string.
            for (keyword, arg) in args {
                if let Some(keyword) = keyword {
                    ensure_str("expr.special_function.keyword", keyword)?;
                }
                validate_expr_limits(arg, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.special_function.alias", alias)?;
            }
        }
        Expr::Binary {
            left, right, alias, ..
        } => {
            validate_expr_limits(left, depth + 1, state)?;
            validate_expr_limits(right, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.binary.alias", alias)?;
            }
        }
        Expr::Literal(v) => validate_value_limits(v, depth + 1, state)?,
        Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
            ensure_len("expr.elements", elements.len(), MAX_AST_COLLECTION_LEN)?;
            for el in elements {
                validate_expr_limits(el, depth + 1, state)?;
            }
            if let Some(alias) = alias {
                ensure_str("expr.elements.alias", alias)?;
            }
        }
        Expr::Subscript { expr, index, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            validate_expr_limits(index, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subscript.alias", alias)?;
            }
        }
        Expr::Collate {
            expr,
            collation,
            alias,
        } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.collate.collation", collation)?;
            if let Some(alias) = alias {
                ensure_str("expr.collate.alias", alias)?;
            }
        }
        Expr::FieldAccess { expr, field, alias } => {
            validate_expr_limits(expr, depth + 1, state)?;
            ensure_str("expr.field_access.field", field)?;
            if let Some(alias) = alias {
                ensure_str("expr.field_access.alias", alias)?;
            }
        }
        // Subqueries re-enter the command validator, so nested commands share
        // the same depth and node budget.
        Expr::Subquery { query, alias } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.subquery.alias", alias)?;
            }
        }
        Expr::Exists { query, alias, .. } => {
            validate_qail_limits(query, depth + 1, state)?;
            if let Some(alias) = alias {
                ensure_str("expr.exists.alias", alias)?;
            }
        }
    }

    Ok(())
}
693
694fn validate_value_limits(
695    value: &crate::ast::Value,
696    depth: usize,
697    state: &mut AstLimitState,
698) -> Result<(), String> {
699    use crate::ast::Value;
700
701    ensure_depth(depth, "Value")?;
702    state.bump("Value")?;
703
704    match value {
705        Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::Param(_) => {}
706        Value::String(v)
707        | Value::NamedParam(v)
708        | Value::Function(v)
709        | Value::Column(v)
710        | Value::Timestamp(v)
711        | Value::Json(v) => ensure_str("value.string", v)?,
712        Value::Array(values) => {
713            ensure_len("value.array", values.len(), MAX_AST_COLLECTION_LEN)?;
714            for v in values {
715                validate_value_limits(v, depth + 1, state)?;
716            }
717        }
718        Value::Subquery(q) => validate_qail_limits(q, depth + 1, state)?,
719        Value::Uuid(_) | Value::NullUuid | Value::Interval { .. } => {}
720        Value::Bytes(bytes) => ensure_len("value.bytes", bytes.len(), MAX_AST_BINARY_VALUE_LEN)?,
721        Value::Expr(expr) => validate_expr_limits(expr, depth + 1, state)?,
722        Value::Vector(values) => ensure_len("value.vector", values.len(), MAX_AST_VECTOR_LEN)?,
723    }
724
725    Ok(())
726}
727
728fn validate_index_def_limits(index_def: &crate::ast::IndexDef) -> Result<(), String> {
729    ensure_str("index_def.name", &index_def.name)?;
730    ensure_str("index_def.table", &index_def.table)?;
731    ensure_len(
732        "index_def.columns",
733        index_def.columns.len(),
734        MAX_AST_COLLECTION_LEN,
735    )?;
736    for col in &index_def.columns {
737        ensure_str("index_def.column", col)?;
738    }
739    if let Some(index_type) = &index_def.index_type {
740        ensure_str("index_def.index_type", index_type)?;
741    }
742    if let Some(where_clause) = &index_def.where_clause {
743        ensure_str("index_def.where_clause", where_clause)?;
744    }
745    Ok(())
746}
747
748fn validate_function_def_limits(function_def: &crate::ast::FunctionDef) -> Result<(), String> {
749    ensure_str("function_def.name", &function_def.name)?;
750    ensure_len(
751        "function_def.args",
752        function_def.args.len(),
753        MAX_AST_COLLECTION_LEN,
754    )?;
755    for arg in &function_def.args {
756        ensure_str("function_def.arg", arg)?;
757    }
758    ensure_str("function_def.returns", &function_def.returns)?;
759    ensure_str("function_def.body", &function_def.body)?;
760    if let Some(language) = &function_def.language {
761        ensure_str("function_def.language", language)?;
762    }
763    if let Some(volatility) = &function_def.volatility {
764        ensure_str("function_def.volatility", volatility)?;
765    }
766    Ok(())
767}
768
769fn validate_trigger_def_limits(trigger_def: &crate::ast::TriggerDef) -> Result<(), String> {
770    ensure_str("trigger_def.name", &trigger_def.name)?;
771    ensure_str("trigger_def.table", &trigger_def.table)?;
772    ensure_len(
773        "trigger_def.events",
774        trigger_def.events.len(),
775        MAX_AST_COLLECTION_LEN,
776    )?;
777    ensure_len(
778        "trigger_def.update_columns",
779        trigger_def.update_columns.len(),
780        MAX_AST_COLLECTION_LEN,
781    )?;
782    for col in &trigger_def.update_columns {
783        ensure_str("trigger_def.update_column", col)?;
784    }
785    ensure_str(
786        "trigger_def.execute_function",
787        &trigger_def.execute_function,
788    )?;
789    Ok(())
790}
791
792fn validate_policy_def_limits(
793    policy_def: &crate::migrate::policy::RlsPolicy,
794    depth: usize,
795    state: &mut AstLimitState,
796) -> Result<(), String> {
797    ensure_str("policy_def.name", &policy_def.name)?;
798    ensure_str("policy_def.table", &policy_def.table)?;
799    if let Some(using_expr) = &policy_def.using {
800        validate_expr_limits(using_expr, depth + 1, state)?;
801    }
802    if let Some(with_check_expr) = &policy_def.with_check {
803        validate_expr_limits(with_check_expr, depth + 1, state)?;
804    }
805    if let Some(role) = &policy_def.role {
806        ensure_str("policy_def.role", role)?;
807    }
808    Ok(())
809}
810
/// Read one `\n`-terminated line starting at `*idx` and return it without the
/// newline, advancing `*idx` past the newline.
///
/// # Errors
/// - `"unexpected EOF"` if `*idx` is at or past the end of `bytes`.
/// - `"unterminated header line"` if no newline follows; `*idx` is left at the
///   end of `bytes`.
/// - `"header is not UTF-8"` if the line bytes are invalid UTF-8; `*idx` is
///   left on the newline.
fn read_line<'a>(bytes: &'a [u8], idx: &mut usize) -> Result<&'a str, String> {
    let rest = match bytes.get(*idx..) {
        Some(rest) if !rest.is_empty() => rest,
        _ => return Err("unexpected EOF".to_string()),
    };

    let start = *idx;
    match rest.iter().position(|&b| b == b'\n') {
        None => {
            *idx = bytes.len();
            Err("unterminated header line".to_string())
        }
        Some(offset) => {
            let newline_at = start + offset;
            *idx = newline_at;
            let line = std::str::from_utf8(&rest[..offset])
                .map_err(|_| "header is not UTF-8".to_string())?;
            *idx = newline_at + 1;
            Ok(line)
        }
    }
}
830
/// Parse a header line as a `usize`, reporting the field name and the raw
/// line on failure.
fn parse_usize(field: &str, line: &str) -> Result<usize, String> {
    match line.parse::<usize>() {
        Ok(value) => Ok(value),
        Err(_) => Err(format!("invalid {field}: {line}")),
    }
}
835
/// Read exactly `len` bytes starting at `*idx` as UTF-8 and advance `*idx`
/// past them.
///
/// `len` normally comes from an untrusted length header. The end offset is
/// therefore computed with `checked_add`, so an absurd length is reported as
/// truncation. Plain `*idx + len` would overflow and panic instead.
///
/// # Errors
/// - `"payload truncated"` if fewer than `len` bytes remain, including when
///   `*idx + len` overflows. `*idx` is not moved.
/// - `"payload is not UTF-8"` if the bytes are not valid UTF-8. `*idx` has
///   already been advanced past them.
fn read_exact_utf8<'a>(bytes: &'a [u8], idx: &mut usize, len: usize) -> Result<&'a str, String> {
    let start = *idx;
    let chunk = start
        .checked_add(len)
        .and_then(|end| bytes.get(start..end))
        .ok_or_else(|| "payload truncated".to_string())?;
    *idx = start + len;
    std::str::from_utf8(chunk).map_err(|_| "payload is not UTF-8".to_string())
}
844
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Text codec: a command survives `encode_cmd_text` → `decode_cmd_text` with identical
    /// canonical text.
    #[test]
    fn cmd_text_roundtrip() {
        let cmd = crate::ast::Qail::get("users")
            .columns(["id", "email"])
            .where_eq("active", true)
            .limit(10);

        let encoded = encode_cmd_text(&cmd);
        let decoded = decode_cmd_text(&encoded).unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    /// Binary codec: a command survives `encode_cmd_binary` → `decode_cmd_binary` with
    /// identical canonical text.
    #[test]
    fn cmd_binary_roundtrip() {
        let cmd = crate::ast::Qail::set("users")
            .set_value("active", true)
            .where_eq("id", 7);

        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let decoded = decode_cmd_binary(&encoded).unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    /// A QWB2 header whose declared length exceeds `MAX_CMD_BINARY_PAYLOAD_BYTES` is rejected.
    #[test]
    fn cmd_binary_payload_rejects_oversized_header() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&CMD_BIN_MAGIC);
        payload.extend_from_slice(&((MAX_CMD_BINARY_PAYLOAD_BYTES + 1) as u32).to_be_bytes());
        // Deliberately header-only: the oversized length field by itself must be enough to
        // trigger rejection, so the test never has to build a body of that size.

        let err = decode_cmd_binary_payload(&payload).unwrap_err();
        assert!(
            err.contains("binary AST payload too large"),
            "unexpected error: {err:?}"
        );
    }

    /// The framed payload returned by `decode_cmd_binary_payload` deserializes (via
    /// `serde_json`) back into the original command, with no trailing data.
    #[test]
    fn cmd_binary_payload_roundtrip() {
        let cmd = crate::ast::Qail::get("users").limit(3);
        let encoded = encode_cmd_binary(&cmd).expect("binary encode");
        let payload = decode_cmd_binary_payload(&encoded).unwrap();
        let mut deserializer = serde_json::Deserializer::from_slice(payload);
        let decoded: crate::ast::Qail = serde::Deserialize::deserialize(&mut deserializer).unwrap();
        deserializer.end().unwrap();
        assert_eq!(decoded.to_string(), cmd.to_string());
    }

    /// Frames carrying the legacy `QWB1` magic (raw query text body) are refused outright.
    #[test]
    fn cmd_binary_payload_rejects_legacy_qwb1() {
        let legacy_text = b"get users limit 1";
        let mut payload = Vec::new();
        payload.extend_from_slice(&CMD_BIN_LEGACY_MAGIC);
        payload.extend_from_slice(&(legacy_text.len() as u32).to_be_bytes());
        payload.extend_from_slice(legacy_text);

        let err = decode_cmd_binary_payload(&payload).unwrap_err();
        assert!(err.contains("legacy QWB1"), "unexpected error: {err:?}");
    }

    /// Unlike the text codec, the binary decoder has no raw-query fallback: input without a
    /// QWB2 header is an error.
    #[test]
    fn cmd_binary_decode_rejects_raw_text_without_qwb2_header() {
        let err = decode_cmd_binary(b"get users limit 1").unwrap_err();
        assert!(err.contains("invalid wire header"), "unexpected error: {err:?}");
    }

    /// Extra bytes appended after a valid frame are rejected; the expected failure is a
    /// mismatch between the header length and the bytes actually present.
    #[test]
    fn cmd_binary_decode_rejects_trailing_bytes() {
        let cmd = crate::ast::Qail::get("users").limit(1);
        let mut encoded = encode_cmd_binary(&cmd).expect("binary encode");
        encoded.extend_from_slice(&[0xAA, 0xBB]);
        let err = decode_cmd_binary(&encoded).unwrap_err();
        assert!(err.contains("invalid payload length"), "unexpected error: {err:?}");
    }

    /// A command nested through subqueries deeper than `MAX_AST_DEPTH` must fail to decode.
    #[test]
    fn cmd_binary_decode_enforces_depth_limits() {
        let mut nested = crate::ast::Qail::get("users").limit(1);
        for _ in 0..(MAX_AST_DEPTH + 2) {
            nested = crate::ast::Qail {
                action: crate::ast::Action::Get,
                table: "users".to_string(),
                columns: vec![crate::ast::Expr::Subquery {
                    query: Box::new(nested),
                    alias: None,
                }],
                ..crate::ast::Qail::default()
            };
        }

        let encoded = encode_cmd_binary(&nested).expect("binary encode");
        let err = decode_cmd_binary(&encoded).unwrap_err();
        // Any of these messages is acceptable. Which guard fires first (the AST depth check,
        // a generic decode failure, or presumably the deserializer's own recursion guard)
        // is an implementation detail.
        assert!(
            err.contains("AST depth limit exceeded")
                || err.contains("binary AST decode failed")
                || err.contains("recursion limit exceeded"),
            "unexpected error: {err:?}"
        );
    }

    /// Single-bit mutations of valid and invalid seed frames must never panic the decoder;
    /// only the absence of a panic is checked, not the decode result.
    #[test]
    fn cmd_binary_decode_bitflip_corpus_no_panic() {
        let seeds = [
            encode_cmd_binary(&crate::ast::Qail::get("users").limit(1)).expect("binary encode"),
            encode_cmd_binary(&crate::ast::Qail::set("users").set_value("active", true))
                .expect("binary encode"),
            vec![],
            b"QWB2garbage".to_vec(),
            vec![0u8; 32],
        ];

        for seed in seeds {
            // Bound the mutated prefix so the test stays fast for larger seeds.
            for i in 0..seed.len().min(128) {
                for bit in 0..8u8 {
                    let mut mutated = seed.clone();
                    mutated[i] ^= 1 << bit;
                    let _ = decode_cmd_binary(&mutated);
                }
            }
            let _ = decode_cmd_binary(&seed);
        }
    }

    proptest! {
        // Arbitrary byte strings up to 4 KiB must never panic the binary decoder.
        #[test]
        fn cmd_binary_decode_fuzz_never_panics(data in proptest::collection::vec(any::<u8>(), 0..4096)) {
            let _ = decode_cmd_binary(&data);
        }
    }

    /// Multi-command text codec: every command survives the round trip, in order.
    #[test]
    fn cmds_text_roundtrip() {
        let cmds = vec![
            crate::ast::Qail::get("users").columns(["id", "email"]),
            crate::ast::Qail::get("users").limit(1),
            crate::ast::Qail::del("users").where_eq("id", 99),
        ];

        let encoded = encode_cmds_text(&cmds);
        let decoded = decode_cmds_text(&encoded).unwrap();
        assert_eq!(decoded.len(), cmds.len());
        for (lhs, rhs) in decoded.iter().zip(cmds.iter()) {
            assert_eq!(lhs.to_string(), rhs.to_string());
        }
    }

    /// Input without the `QAIL-CMD/1` magic line is parsed as plain QAIL query text.
    #[test]
    fn decode_cmd_text_falls_back_to_raw_qail() {
        let decoded = decode_cmd_text("get users limit 1").unwrap();
        assert_eq!(decoded.action, crate::ast::Action::Get);
        assert_eq!(decoded.table, "users");
        assert!(
            decoded
                .cages
                .iter()
                .any(|c| matches!(c.kind, crate::ast::CageKind::Limit(1))),
            "expected a `limit 1` cage, got {:?}",
            decoded.cages.iter().map(|c| &c.kind).collect::<Vec<_>>()
        );
    }
}