Skip to main content

vexil_codegen_rust/
message.rs

1use std::collections::HashSet;
2
3use vexil_lang::ast::{PrimitiveType, SemanticType};
4use vexil_lang::ir::{
5    Encoding, FieldEncoding, MessageDef, ResolvedType, TypeDef, TypeId, TypeRegistry,
6};
7
8use crate::annotations::{emit_field_annotations, emit_tombstones, emit_type_annotations};
9use crate::emit::CodeWriter;
10use crate::types::rust_type;
11
12// ---------------------------------------------------------------------------
13// Byte-alignment helper
14// ---------------------------------------------------------------------------
15
16/// Returns true if the type is byte-aligned (i.e., not sub-byte).
17///
18/// Returns false for: Bool, SubByte, exhaustive enum with wire_bits < 8.
19pub fn is_byte_aligned(ty: &ResolvedType, registry: &TypeRegistry) -> bool {
20    match ty {
21        ResolvedType::Primitive(PrimitiveType::Bool) => false,
22        ResolvedType::SubByte(_) => false,
23        ResolvedType::Named(id) => {
24            // Check if this is an exhaustive enum with small wire_bits
25            if let Some(TypeDef::Enum(e)) = registry.get(*id) {
26                e.wire_bits >= 8
27            } else {
28                true
29            }
30        }
31        ResolvedType::Optional(inner) => is_byte_aligned(inner, registry),
32        _ => true,
33    }
34}
35
36// ---------------------------------------------------------------------------
37// Primitive type bits helper
38// ---------------------------------------------------------------------------
39
40fn primitive_bits(p: &PrimitiveType) -> u8 {
41    match p {
42        PrimitiveType::I8 | PrimitiveType::U8 => 8,
43        PrimitiveType::I16 | PrimitiveType::U16 => 16,
44        PrimitiveType::I32 | PrimitiveType::U32 | PrimitiveType::F32 => 32,
45        PrimitiveType::I64 | PrimitiveType::U64 | PrimitiveType::F64 => 64,
46        _ => 0,
47    }
48}
49
50// ---------------------------------------------------------------------------
51// Copy-type helper
52// ---------------------------------------------------------------------------
53
54/// Returns true if the type is `Copy` in Rust — primitives and sub-byte types.
55/// Used to determine whether array/optional/result items need dereferencing
56/// when accessed through a reference (`*item` vs `item`).
57fn is_copy_type(ty: &ResolvedType) -> bool {
58    matches!(ty, ResolvedType::Primitive(_) | ResolvedType::SubByte(_))
59}
60
61// ---------------------------------------------------------------------------
62// emit_write
63// ---------------------------------------------------------------------------
64
65/// Emit code to write a field to `w: &mut BitWriter`.
66///
67/// `access` is the Rust expression for the value (e.g. `self.name` or `&self.data`).
68/// For `Encoding::Delta`, this function is a no-op (the delta module handles it).
69pub fn emit_write(
70    w: &mut CodeWriter,
71    access: &str,
72    ty: &ResolvedType,
73    enc: &FieldEncoding,
74    registry: &TypeRegistry,
75    field_name: &str,
76) {
77    // Check non-default encoding first
78    match &enc.encoding {
79        Encoding::Varint => {
80            if let Some(limit) = enc.limit {
81                w.line(&format!(
82                    "if ({access} as u64) > {limit}_u64 {{ return Err(vexil_runtime::EncodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: {access} as u64 }}); }}"
83                ));
84            }
85            w.line(&format!("w.write_leb128({access} as u64);"));
86            return;
87        }
88        Encoding::ZigZag => {
89            if let Some(limit) = enc.limit {
90                w.line(&format!(
91                    "if ({access} as i64).unsigned_abs() > {limit}_u64 {{ return Err(vexil_runtime::EncodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: ({access} as i64).unsigned_abs() }}); }}"
92                ));
93            }
94            let type_bits = match ty {
95                ResolvedType::Primitive(p) => primitive_bits(p),
96                _ => 64,
97            };
98            w.line(&format!("w.write_zigzag({access} as i64, {type_bits}_u8);"));
99            return;
100        }
101        Encoding::Delta(inner) => {
102            // For standard Pack, write the field using the inner (base) encoding.
103            // The DeltaEncoder handles delta sequences separately.
104            let base_enc = FieldEncoding {
105                encoding: *inner.clone(),
106                limit: enc.limit,
107            };
108            emit_write(w, access, ty, &base_enc, registry, field_name);
109            return;
110        }
111        Encoding::Default => {} // fall through to type dispatch
112        _ => {}                 // non_exhaustive guard
113    }
114
115    // Emit limit check for default encoding on collections/strings
116    if let Some(limit) = enc.limit {
117        match ty {
118            ResolvedType::Array(_)
119            | ResolvedType::Map(_, _)
120            | ResolvedType::Semantic(SemanticType::String)
121            | ResolvedType::Semantic(SemanticType::Bytes) => {
122                w.line(&format!(
123                    "if ({access}).len() as u64 > {limit}_u64 {{ return Err(vexil_runtime::EncodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: ({access}).len() as u64 }}); }}"
124                ));
125            }
126            _ => {}
127        }
128    }
129
130    emit_write_type(w, access, ty, registry, field_name);
131}
132
133#[allow(clippy::only_used_in_recursion)]
134fn emit_write_type(
135    w: &mut CodeWriter,
136    access: &str,
137    ty: &ResolvedType,
138    registry: &TypeRegistry,
139    field_name: &str,
140) {
141    match ty {
142        ResolvedType::Primitive(p) => match p {
143            PrimitiveType::Bool => w.line(&format!("w.write_bool({access});")),
144            PrimitiveType::U8 => w.line(&format!("w.write_u8({access});")),
145            PrimitiveType::U16 => w.line(&format!("w.write_u16({access});")),
146            PrimitiveType::U32 => w.line(&format!("w.write_u32({access});")),
147            PrimitiveType::U64 => w.line(&format!("w.write_u64({access});")),
148            PrimitiveType::I8 => w.line(&format!("w.write_i8({access});")),
149            PrimitiveType::I16 => w.line(&format!("w.write_i16({access});")),
150            PrimitiveType::I32 => w.line(&format!("w.write_i32({access});")),
151            PrimitiveType::I64 => w.line(&format!("w.write_i64({access});")),
152            PrimitiveType::F32 => w.line(&format!("w.write_f32({access});")),
153            PrimitiveType::F64 => w.line(&format!("w.write_f64({access});")),
154            PrimitiveType::Void => {} // 0 bits — nothing to write
155        },
156        ResolvedType::SubByte(s) => {
157            let bits = s.bits;
158            if s.signed {
159                w.line(&format!("w.write_bits({access} as u8 as u64, {bits}_u8);"));
160            } else {
161                w.line(&format!("w.write_bits({access} as u64, {bits}_u8);"));
162            }
163        }
164        ResolvedType::Semantic(s) => match s {
165            SemanticType::String => w.line(&format!("w.write_string(&{access});")),
166            SemanticType::Bytes => w.line(&format!("w.write_bytes(&{access});")),
167            SemanticType::Rgb => {
168                w.line(&format!("w.write_u8({access}.0);"));
169                w.line(&format!("w.write_u8({access}.1);"));
170                w.line(&format!("w.write_u8({access}.2);"));
171            }
172            SemanticType::Uuid => w.line(&format!("w.write_raw_bytes(&{access});")),
173            SemanticType::Timestamp => w.line(&format!("w.write_i64({access});")),
174            SemanticType::Hash => w.line(&format!("w.write_raw_bytes(&{access});")),
175        },
176        ResolvedType::Named(_) => {
177            w.line("w.enter_recursive()?;");
178            w.line(&format!("{access}.pack(w)?;"));
179            w.line("w.leave_recursive();");
180        }
181        ResolvedType::Optional(inner) => {
182            // Presence bit
183            w.line(&format!("w.write_bool({access}.is_some());"));
184            // If inner is byte-aligned, flush before conditional
185            if is_byte_aligned(inner, registry) {
186                w.line("w.flush_to_byte_boundary();");
187                // Hmm, actually flush is on writer side. We only flush after the presence
188                // bit if the inner type requires byte alignment. The spec says flush the
189                // bit-stream before writing the inner value. Let's keep the flush here
190                // only when needed.
191            }
192            w.open_block(&format!("if let Some(ref inner_val) = {access}"));
193            let inner_access = if is_copy_type(inner) {
194                "*inner_val"
195            } else {
196                "inner_val"
197            };
198            emit_write_type(w, inner_access, inner, registry, field_name);
199            w.close_block();
200        }
201        ResolvedType::Array(inner) => {
202            w.line(&format!("w.write_leb128({access}.len() as u64);"));
203            w.open_block(&format!("for item in &{access}"));
204            let item_access = if is_copy_type(inner) { "*item" } else { "item" };
205            emit_write_type(w, item_access, inner, registry, field_name);
206            w.close_block();
207        }
208        ResolvedType::Map(k, v) => {
209            w.line(&format!("w.write_leb128({access}.len() as u64);"));
210            w.open_block(&format!("for (map_k, map_v) in &{access}"));
211            let k_access = if is_copy_type(k) { "*map_k" } else { "map_k" };
212            let v_access = if is_copy_type(v) { "*map_v" } else { "map_v" };
213            emit_write_type(w, k_access, k, registry, field_name);
214            emit_write_type(w, v_access, v, registry, field_name);
215            w.close_block();
216        }
217        ResolvedType::Result(ok, err) => {
218            w.open_block(&format!("match &{access}"));
219            w.open_block("Ok(ok_val) =>");
220            w.line("w.write_bool(true);");
221            let ok_access = if is_copy_type(ok) {
222                "*ok_val"
223            } else {
224                "ok_val"
225            };
226            emit_write_type(w, ok_access, ok, registry, field_name);
227            w.close_block();
228            w.open_block("Err(err_val) =>");
229            w.line("w.write_bool(false);");
230            let err_access = if is_copy_type(err) {
231                "*err_val"
232            } else {
233                "err_val"
234            };
235            emit_write_type(w, err_access, err, registry, field_name);
236            w.close_block();
237            w.close_block();
238        }
239        _ => {} // non_exhaustive guard
240    }
241}
242
243// ---------------------------------------------------------------------------
244// emit_read
245// ---------------------------------------------------------------------------
246
247/// Emit code to read a field from `r: &mut BitReader<'_>`.
248///
249/// Binds the result to `var_name`.
250pub fn emit_read(
251    w: &mut CodeWriter,
252    var_name: &str,
253    ty: &ResolvedType,
254    enc: &FieldEncoding,
255    registry: &TypeRegistry,
256    field_name: &str,
257) {
258    match &enc.encoding {
259        Encoding::Varint => {
260            // max_bytes: 10 covers u64 LEB128
261            w.line(&format!("let {var_name}_raw = r.read_leb128(10_u8)?;"));
262            if let Some(limit) = enc.limit {
263                w.line(&format!(
264                    "if {var_name}_raw > {limit}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: {var_name}_raw }}); }}"
265                ));
266            }
267            // Cast to the appropriate Rust type
268            let rust_ty = read_cast_for_varint(ty);
269            w.line(&format!(
270                "let {var_name}: {rust_ty} = {var_name}_raw as {rust_ty};"
271            ));
272            return;
273        }
274        Encoding::ZigZag => {
275            let type_bits = match ty {
276                ResolvedType::Primitive(p) => primitive_bits(p),
277                _ => 64,
278            };
279            // max_bytes: 10 for i64 zigzag
280            w.line(&format!(
281                "let {var_name}_raw = r.read_zigzag({type_bits}_u8, 10_u8)?;"
282            ));
283            if let Some(limit) = enc.limit {
284                w.line(&format!(
285                    "if {var_name}_raw.unsigned_abs() > {limit}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: {var_name}_raw.unsigned_abs() }}); }}"
286                ));
287            }
288            let rust_ty = read_cast_for_zigzag(ty);
289            w.line(&format!(
290                "let {var_name}: {rust_ty} = {var_name}_raw as {rust_ty};"
291            ));
292            return;
293        }
294        Encoding::Delta(inner) => {
295            // For standard Unpack, read the field using the inner (base) encoding.
296            // The DeltaDecoder handles delta sequences separately.
297            let base_enc = FieldEncoding {
298                encoding: *inner.clone(),
299                limit: enc.limit,
300            };
301            emit_read(w, var_name, ty, &base_enc, registry, field_name);
302            return;
303        }
304        Encoding::Default => {}
305        _ => {} // non_exhaustive guard
306    }
307
308    emit_read_type(w, var_name, ty, registry, field_name, enc.limit);
309}
310
311fn emit_read_type(
312    w: &mut CodeWriter,
313    var_name: &str,
314    ty: &ResolvedType,
315    registry: &TypeRegistry,
316    field_name: &str,
317    limit: Option<u64>,
318) {
319    match ty {
320        ResolvedType::Primitive(p) => match p {
321            PrimitiveType::Bool => w.line(&format!("let {var_name} = r.read_bool()?;")),
322            PrimitiveType::U8 => w.line(&format!("let {var_name} = r.read_u8()?;")),
323            PrimitiveType::U16 => w.line(&format!("let {var_name} = r.read_u16()?;")),
324            PrimitiveType::U32 => w.line(&format!("let {var_name} = r.read_u32()?;")),
325            PrimitiveType::U64 => w.line(&format!("let {var_name} = r.read_u64()?;")),
326            PrimitiveType::I8 => w.line(&format!("let {var_name} = r.read_i8()?;")),
327            PrimitiveType::I16 => w.line(&format!("let {var_name} = r.read_i16()?;")),
328            PrimitiveType::I32 => w.line(&format!("let {var_name} = r.read_i32()?;")),
329            PrimitiveType::I64 => w.line(&format!("let {var_name} = r.read_i64()?;")),
330            PrimitiveType::F32 => w.line(&format!("let {var_name} = r.read_f32()?;")),
331            PrimitiveType::F64 => w.line(&format!("let {var_name} = r.read_f64()?;")),
332            PrimitiveType::Void => w.line(&format!("let {var_name} = ();")),
333        },
334        ResolvedType::SubByte(s) => {
335            let bits = s.bits;
336            if s.signed {
337                w.line(&format!(
338                    "let {var_name} = r.read_bits({bits}_u8)? as u8 as i8;"
339                ));
340            } else {
341                w.line(&format!("let {var_name} = r.read_bits({bits}_u8)? as u8;"));
342            }
343        }
344        ResolvedType::Semantic(s) => match s {
345            SemanticType::String => {
346                w.line(&format!("let {var_name} = r.read_string()?;"));
347                if let Some(lim) = limit {
348                    w.line(&format!(
349                        "if {var_name}.len() as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}.len() as u64 }}); }}"
350                    ));
351                }
352            }
353            SemanticType::Bytes => {
354                w.line(&format!("let {var_name} = r.read_bytes()?;"));
355                if let Some(lim) = limit {
356                    w.line(&format!(
357                        "if {var_name}.len() as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}.len() as u64 }}); }}"
358                    ));
359                }
360            }
361            SemanticType::Rgb => {
362                w.line(&format!("let {var_name}_0 = r.read_u8()?;"));
363                w.line(&format!("let {var_name}_1 = r.read_u8()?;"));
364                w.line(&format!("let {var_name}_2 = r.read_u8()?;"));
365                w.line(&format!(
366                    "let {var_name} = ({var_name}_0, {var_name}_1, {var_name}_2);"
367                ));
368            }
369            SemanticType::Uuid => {
370                w.line(&format!(
371                    "let {var_name}_bytes = r.read_raw_bytes(16_usize)?;"
372                ));
373                w.line(&format!(
374                    "let {var_name}: [u8; 16] = {var_name}_bytes.try_into().map_err(|_| vexil_runtime::DecodeError::UnexpectedEof)?;"
375                ));
376            }
377            SemanticType::Timestamp => {
378                w.line(&format!("let {var_name} = r.read_i64()?;"));
379            }
380            SemanticType::Hash => {
381                w.line(&format!(
382                    "let {var_name}_bytes = r.read_raw_bytes(32_usize)?;"
383                ));
384                w.line(&format!(
385                    "let {var_name}: [u8; 32] = {var_name}_bytes.try_into().map_err(|_| vexil_runtime::DecodeError::UnexpectedEof)?;"
386                ));
387            }
388        },
389        ResolvedType::Named(_) => {
390            w.line("r.enter_recursive()?;");
391            w.line(&format!(
392                "let {var_name} = vexil_runtime::Unpack::unpack(r)?;"
393            ));
394            w.line("r.leave_recursive();");
395        }
396        ResolvedType::Optional(inner) => {
397            w.line(&format!("let {var_name}_present = r.read_bool()?;"));
398            if is_byte_aligned(inner, registry) {
399                w.line("r.flush_to_byte_boundary();");
400            }
401            w.open_block(&format!("let {var_name} = if {var_name}_present"));
402            emit_read_type(
403                w,
404                &format!("{var_name}_inner"),
405                inner,
406                registry,
407                field_name,
408                None,
409            );
410            w.line(&format!("Some({var_name}_inner)"));
411            w.close_block();
412            w.open_block("else");
413            w.line("None");
414            w.close_block();
415            w.append(";");
416            w.append("\n");
417        }
418        ResolvedType::Array(inner) => {
419            w.line(&format!(
420                "let {var_name}_len = r.read_leb128(10_u8)? as usize;"
421            ));
422            if let Some(lim) = limit {
423                w.line(&format!(
424                    "if {var_name}_len as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}_len as u64 }}); }}"
425                ));
426            }
427            w.line(&format!(
428                "let mut {var_name} = Vec::with_capacity({var_name}_len);"
429            ));
430            w.open_block(&format!("for _ in 0..{var_name}_len"));
431            emit_read_type(
432                w,
433                &format!("{var_name}_item"),
434                inner,
435                registry,
436                field_name,
437                None,
438            );
439            w.line(&format!("{var_name}.push({var_name}_item);"));
440            w.close_block();
441        }
442        ResolvedType::Map(k, v) => {
443            w.line(&format!(
444                "let {var_name}_len = r.read_leb128(10_u8)? as usize;"
445            ));
446            if let Some(lim) = limit {
447                w.line(&format!(
448                    "if {var_name}_len as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}_len as u64 }}); }}"
449                ));
450            }
451            w.line(&format!(
452                "let mut {var_name} = std::collections::BTreeMap::new();"
453            ));
454            w.open_block(&format!("for _ in 0..{var_name}_len"));
455            emit_read_type(w, &format!("{var_name}_k"), k, registry, field_name, None);
456            emit_read_type(w, &format!("{var_name}_v"), v, registry, field_name, None);
457            w.line(&format!("{var_name}.insert({var_name}_k, {var_name}_v);"));
458            w.close_block();
459        }
460        ResolvedType::Result(ok, err) => {
461            w.line(&format!("let {var_name}_is_ok = r.read_bool()?;"));
462            w.open_block(&format!("let {var_name} = if {var_name}_is_ok"));
463            emit_read_type(w, &format!("{var_name}_ok"), ok, registry, field_name, None);
464            w.line(&format!("Ok({var_name}_ok)"));
465            w.close_block();
466            w.open_block("else");
467            emit_read_type(
468                w,
469                &format!("{var_name}_err"),
470                err,
471                registry,
472                field_name,
473                None,
474            );
475            w.line(&format!("Err({var_name}_err)"));
476            w.close_block();
477            w.append(";");
478            w.append("\n");
479        }
480        _ => {} // non_exhaustive guard
481    }
482}
483
484fn read_cast_for_varint(ty: &ResolvedType) -> &'static str {
485    match ty {
486        ResolvedType::Primitive(p) => match p {
487            PrimitiveType::U8 => "u8",
488            PrimitiveType::U16 => "u16",
489            PrimitiveType::U32 => "u32",
490            PrimitiveType::U64 => "u64",
491            _ => "u64",
492        },
493        _ => "u64",
494    }
495}
496
497fn read_cast_for_zigzag(ty: &ResolvedType) -> &'static str {
498    match ty {
499        ResolvedType::Primitive(p) => match p {
500            PrimitiveType::I8 => "i8",
501            PrimitiveType::I16 => "i16",
502            PrimitiveType::I32 => "i32",
503            PrimitiveType::I64 => "i64",
504            _ => "i64",
505        },
506        _ => "i64",
507    }
508}
509
510// ---------------------------------------------------------------------------
511// emit_message
512// ---------------------------------------------------------------------------
513
514/// Emit a complete message struct with Pack and Unpack implementations.
515pub fn emit_message(
516    w: &mut CodeWriter,
517    msg: &MessageDef,
518    registry: &TypeRegistry,
519    needs_box: &HashSet<(TypeId, usize)>,
520    type_id: TypeId,
521) {
522    let name = msg.name.as_str();
523
524    // Tombstone comments
525    emit_tombstones(w, name, &msg.tombstones);
526
527    // Type annotations
528    emit_type_annotations(w, &msg.annotations);
529    w.line("#[derive(Debug, Clone, PartialEq)]");
530
531    // Struct definition
532    w.open_block(&format!("pub struct {name}"));
533    for (fi, field) in msg.fields.iter().enumerate() {
534        emit_field_annotations(w, &field.annotations);
535        let field_rust_type = rust_type(
536            &field.resolved_type,
537            registry,
538            needs_box,
539            Some((type_id, fi)),
540        );
541        w.line(&format!("pub {}: {},", field.name, field_rust_type));
542    }
543    w.close_block();
544    w.blank();
545
546    // Pack impl
547    w.open_block(&format!("impl vexil_runtime::Pack for {name}"));
548    w.open_block("fn pack(&self, w: &mut vexil_runtime::BitWriter) -> Result<(), vexil_runtime::EncodeError>");
549    for field in &msg.fields {
550        let access = format!("self.{}", field.name);
551        emit_write(
552            w,
553            &access,
554            &field.resolved_type,
555            &field.encoding,
556            registry,
557            field.name.as_str(),
558        );
559    }
560    w.line("w.flush_to_byte_boundary();");
561    w.line("Ok(())");
562    w.close_block();
563    w.close_block();
564    w.blank();
565
566    // Unpack impl
567    w.open_block(&format!("impl vexil_runtime::Unpack for {name}"));
568    w.open_block("fn unpack(r: &mut vexil_runtime::BitReader<'_>) -> Result<Self, vexil_runtime::DecodeError>");
569    for field in &msg.fields {
570        let var_name = field.name.as_str();
571        emit_read(
572            w,
573            var_name,
574            &field.resolved_type,
575            &field.encoding,
576            registry,
577            var_name,
578        );
579    }
580    w.line("r.flush_to_byte_boundary();");
581    w.open_block("Ok(Self");
582    for field in &msg.fields {
583        w.line(&format!("{},", field.name));
584    }
585    w.dedent();
586    w.line("})");
587
588    w.close_block();
589    w.close_block();
590    w.blank();
591}