Skip to main content

vexil_codegen_rust/
message.rs

1use std::collections::HashSet;
2
3use vexil_lang::ast::{PrimitiveType, SemanticType};
4use vexil_lang::ir::{
5    Encoding, FieldEncoding, MessageDef, ResolvedType, TypeDef, TypeId, TypeRegistry,
6};
7
8use crate::annotations::{emit_field_annotations, emit_tombstones, emit_type_annotations};
9use crate::emit::CodeWriter;
10use crate::types::rust_type;
11
12// ---------------------------------------------------------------------------
13// Byte-alignment helper
14// ---------------------------------------------------------------------------
15
16/// Returns true if the type is byte-aligned (i.e., not sub-byte).
17///
18/// Returns false for: Bool, SubByte, exhaustive enum with wire_bits < 8.
19pub fn is_byte_aligned(ty: &ResolvedType, registry: &TypeRegistry) -> bool {
20    match ty {
21        ResolvedType::Primitive(PrimitiveType::Bool) => false,
22        ResolvedType::SubByte(_) => false,
23        ResolvedType::Named(id) => {
24            // Check if this is an exhaustive enum with small wire_bits
25            if let Some(TypeDef::Enum(e)) = registry.get(*id) {
26                e.wire_bits >= 8
27            } else {
28                true
29            }
30        }
31        ResolvedType::Optional(inner) => is_byte_aligned(inner, registry),
32        _ => true,
33    }
34}
35
36// ---------------------------------------------------------------------------
37// Primitive type bits helper
38// ---------------------------------------------------------------------------
39
40fn primitive_bits(p: &PrimitiveType) -> u8 {
41    match p {
42        PrimitiveType::I8 | PrimitiveType::U8 => 8,
43        PrimitiveType::I16 | PrimitiveType::U16 => 16,
44        PrimitiveType::I32 | PrimitiveType::U32 | PrimitiveType::F32 => 32,
45        PrimitiveType::I64 | PrimitiveType::U64 | PrimitiveType::F64 => 64,
46        _ => 0,
47    }
48}
49
50// ---------------------------------------------------------------------------
51// emit_write
52// ---------------------------------------------------------------------------
53
54/// Emit code to write a field to `w: &mut BitWriter`.
55///
56/// `access` is the Rust expression for the value (e.g. `self.name` or `&self.data`).
57/// For `Encoding::Delta`, this function is a no-op (the delta module handles it).
58pub fn emit_write(
59    w: &mut CodeWriter,
60    access: &str,
61    ty: &ResolvedType,
62    enc: &FieldEncoding,
63    registry: &TypeRegistry,
64    field_name: &str,
65) {
66    // Check non-default encoding first
67    match &enc.encoding {
68        Encoding::Varint => {
69            if let Some(limit) = enc.limit {
70                w.line(&format!(
71                    "if ({access} as u64) > {limit}_u64 {{ return Err(vexil_runtime::EncodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: {access} as u64 }}); }}"
72                ));
73            }
74            w.line(&format!("w.write_leb128({access} as u64);"));
75            return;
76        }
77        Encoding::ZigZag => {
78            if let Some(limit) = enc.limit {
79                w.line(&format!(
80                    "if ({access} as i64).unsigned_abs() > {limit}_u64 {{ return Err(vexil_runtime::EncodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: ({access} as i64).unsigned_abs() }}); }}"
81                ));
82            }
83            let type_bits = match ty {
84                ResolvedType::Primitive(p) => primitive_bits(p),
85                _ => 64,
86            };
87            w.line(&format!("w.write_zigzag({access} as i64, {type_bits}_u8);"));
88            return;
89        }
90        Encoding::Delta(inner) => {
91            // For standard Pack, write the field using the inner (base) encoding.
92            // The DeltaEncoder handles delta sequences separately.
93            let base_enc = FieldEncoding {
94                encoding: *inner.clone(),
95                limit: enc.limit,
96            };
97            emit_write(w, access, ty, &base_enc, registry, field_name);
98            return;
99        }
100        Encoding::Default => {} // fall through to type dispatch
101        _ => {}                 // non_exhaustive guard
102    }
103
104    // Emit limit check for default encoding on collections/strings
105    if let Some(limit) = enc.limit {
106        match ty {
107            ResolvedType::Array(_)
108            | ResolvedType::Map(_, _)
109            | ResolvedType::Semantic(SemanticType::String)
110            | ResolvedType::Semantic(SemanticType::Bytes) => {
111                w.line(&format!(
112                    "if ({access}).len() as u64 > {limit}_u64 {{ return Err(vexil_runtime::EncodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: ({access}).len() as u64 }}); }}"
113                ));
114            }
115            _ => {}
116        }
117    }
118
119    emit_write_type(w, access, ty, registry, field_name);
120}
121
122#[allow(clippy::only_used_in_recursion)]
123fn emit_write_type(
124    w: &mut CodeWriter,
125    access: &str,
126    ty: &ResolvedType,
127    registry: &TypeRegistry,
128    field_name: &str,
129) {
130    match ty {
131        ResolvedType::Primitive(p) => match p {
132            PrimitiveType::Bool => w.line(&format!("w.write_bool({access});")),
133            PrimitiveType::U8 => w.line(&format!("w.write_u8({access});")),
134            PrimitiveType::U16 => w.line(&format!("w.write_u16({access});")),
135            PrimitiveType::U32 => w.line(&format!("w.write_u32({access});")),
136            PrimitiveType::U64 => w.line(&format!("w.write_u64({access});")),
137            PrimitiveType::I8 => w.line(&format!("w.write_i8({access});")),
138            PrimitiveType::I16 => w.line(&format!("w.write_i16({access});")),
139            PrimitiveType::I32 => w.line(&format!("w.write_i32({access});")),
140            PrimitiveType::I64 => w.line(&format!("w.write_i64({access});")),
141            PrimitiveType::F32 => w.line(&format!("w.write_f32({access});")),
142            PrimitiveType::F64 => w.line(&format!("w.write_f64({access});")),
143            PrimitiveType::Void => {} // 0 bits — nothing to write
144        },
145        ResolvedType::SubByte(s) => {
146            let bits = s.bits;
147            if s.signed {
148                w.line(&format!("w.write_bits({access} as u8 as u64, {bits}_u8);"));
149            } else {
150                w.line(&format!("w.write_bits({access} as u64, {bits}_u8);"));
151            }
152        }
153        ResolvedType::Semantic(s) => match s {
154            SemanticType::String => w.line(&format!("w.write_string(&{access});")),
155            SemanticType::Bytes => w.line(&format!("w.write_bytes(&{access});")),
156            SemanticType::Rgb => {
157                w.line(&format!("w.write_u8({access}.0);"));
158                w.line(&format!("w.write_u8({access}.1);"));
159                w.line(&format!("w.write_u8({access}.2);"));
160            }
161            SemanticType::Uuid => w.line(&format!("w.write_raw_bytes(&{access});")),
162            SemanticType::Timestamp => w.line(&format!("w.write_i64({access});")),
163            SemanticType::Hash => w.line(&format!("w.write_raw_bytes(&{access});")),
164        },
165        ResolvedType::Named(_) => {
166            w.line(&format!("{access}.pack(w)?;"));
167        }
168        ResolvedType::Optional(inner) => {
169            // Presence bit
170            w.line(&format!("w.write_bool({access}.is_some());"));
171            // If inner is byte-aligned, flush before conditional
172            if is_byte_aligned(inner, registry) {
173                w.line("w.flush_to_byte_boundary();");
174                // Hmm, actually flush is on writer side. We only flush after the presence
175                // bit if the inner type requires byte alignment. The spec says flush the
176                // bit-stream before writing the inner value. Let's keep the flush here
177                // only when needed.
178            }
179            w.open_block(&format!("if let Some(ref inner_val) = {access}"));
180            emit_write_type(w, "inner_val", inner, registry, field_name);
181            w.close_block();
182        }
183        ResolvedType::Array(inner) => {
184            w.line(&format!("w.write_leb128({access}.len() as u64);"));
185            w.open_block(&format!("for item in &{access}"));
186            emit_write_type(w, "item", inner, registry, field_name);
187            w.close_block();
188        }
189        ResolvedType::Map(k, v) => {
190            w.line(&format!("w.write_leb128({access}.len() as u64);"));
191            w.open_block(&format!("for (map_k, map_v) in &{access}"));
192            emit_write_type(w, "map_k", k, registry, field_name);
193            emit_write_type(w, "map_v", v, registry, field_name);
194            w.close_block();
195        }
196        ResolvedType::Result(ok, err) => {
197            w.open_block(&format!("match &{access}"));
198            w.open_block("Ok(ok_val) =>");
199            w.line("w.write_bool(true);");
200            emit_write_type(w, "ok_val", ok, registry, field_name);
201            w.close_block();
202            w.open_block("Err(err_val) =>");
203            w.line("w.write_bool(false);");
204            emit_write_type(w, "err_val", err, registry, field_name);
205            w.close_block();
206            w.close_block();
207        }
208        _ => {} // non_exhaustive guard
209    }
210}
211
212// ---------------------------------------------------------------------------
213// emit_read
214// ---------------------------------------------------------------------------
215
216/// Emit code to read a field from `r: &mut BitReader<'_>`.
217///
218/// Binds the result to `var_name`.
219pub fn emit_read(
220    w: &mut CodeWriter,
221    var_name: &str,
222    ty: &ResolvedType,
223    enc: &FieldEncoding,
224    registry: &TypeRegistry,
225    field_name: &str,
226) {
227    match &enc.encoding {
228        Encoding::Varint => {
229            // max_bytes: 10 covers u64 LEB128
230            w.line(&format!("let {var_name}_raw = r.read_leb128(10_u8)?;"));
231            if let Some(limit) = enc.limit {
232                w.line(&format!(
233                    "if {var_name}_raw > {limit}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: {var_name}_raw }}); }}"
234                ));
235            }
236            // Cast to the appropriate Rust type
237            let rust_ty = read_cast_for_varint(ty);
238            w.line(&format!(
239                "let {var_name}: {rust_ty} = {var_name}_raw as {rust_ty};"
240            ));
241            return;
242        }
243        Encoding::ZigZag => {
244            let type_bits = match ty {
245                ResolvedType::Primitive(p) => primitive_bits(p),
246                _ => 64,
247            };
248            // max_bytes: 10 for i64 zigzag
249            w.line(&format!(
250                "let {var_name}_raw = r.read_zigzag({type_bits}_u8, 10_u8)?;"
251            ));
252            if let Some(limit) = enc.limit {
253                w.line(&format!(
254                    "if {var_name}_raw.unsigned_abs() > {limit}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {limit}_u64, actual: {var_name}_raw.unsigned_abs() }}); }}"
255                ));
256            }
257            let rust_ty = read_cast_for_zigzag(ty);
258            w.line(&format!(
259                "let {var_name}: {rust_ty} = {var_name}_raw as {rust_ty};"
260            ));
261            return;
262        }
263        Encoding::Delta(inner) => {
264            // For standard Unpack, read the field using the inner (base) encoding.
265            // The DeltaDecoder handles delta sequences separately.
266            let base_enc = FieldEncoding {
267                encoding: *inner.clone(),
268                limit: enc.limit,
269            };
270            emit_read(w, var_name, ty, &base_enc, registry, field_name);
271            return;
272        }
273        Encoding::Default => {}
274        _ => {} // non_exhaustive guard
275    }
276
277    emit_read_type(w, var_name, ty, registry, field_name, enc.limit);
278}
279
280fn emit_read_type(
281    w: &mut CodeWriter,
282    var_name: &str,
283    ty: &ResolvedType,
284    registry: &TypeRegistry,
285    field_name: &str,
286    limit: Option<u64>,
287) {
288    match ty {
289        ResolvedType::Primitive(p) => match p {
290            PrimitiveType::Bool => w.line(&format!("let {var_name} = r.read_bool()?;")),
291            PrimitiveType::U8 => w.line(&format!("let {var_name} = r.read_u8()?;")),
292            PrimitiveType::U16 => w.line(&format!("let {var_name} = r.read_u16()?;")),
293            PrimitiveType::U32 => w.line(&format!("let {var_name} = r.read_u32()?;")),
294            PrimitiveType::U64 => w.line(&format!("let {var_name} = r.read_u64()?;")),
295            PrimitiveType::I8 => w.line(&format!("let {var_name} = r.read_i8()?;")),
296            PrimitiveType::I16 => w.line(&format!("let {var_name} = r.read_i16()?;")),
297            PrimitiveType::I32 => w.line(&format!("let {var_name} = r.read_i32()?;")),
298            PrimitiveType::I64 => w.line(&format!("let {var_name} = r.read_i64()?;")),
299            PrimitiveType::F32 => w.line(&format!("let {var_name} = r.read_f32()?;")),
300            PrimitiveType::F64 => w.line(&format!("let {var_name} = r.read_f64()?;")),
301            PrimitiveType::Void => w.line(&format!("let {var_name} = ();")),
302        },
303        ResolvedType::SubByte(s) => {
304            let bits = s.bits;
305            if s.signed {
306                w.line(&format!(
307                    "let {var_name} = r.read_bits({bits}_u8)? as u8 as i8;"
308                ));
309            } else {
310                w.line(&format!("let {var_name} = r.read_bits({bits}_u8)? as u8;"));
311            }
312        }
313        ResolvedType::Semantic(s) => match s {
314            SemanticType::String => {
315                w.line(&format!("let {var_name} = r.read_string()?;"));
316                if let Some(lim) = limit {
317                    w.line(&format!(
318                        "if {var_name}.len() as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}.len() as u64 }}); }}"
319                    ));
320                }
321            }
322            SemanticType::Bytes => {
323                w.line(&format!("let {var_name} = r.read_bytes()?;"));
324                if let Some(lim) = limit {
325                    w.line(&format!(
326                        "if {var_name}.len() as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}.len() as u64 }}); }}"
327                    ));
328                }
329            }
330            SemanticType::Rgb => {
331                w.line(&format!("let {var_name}_0 = r.read_u8()?;"));
332                w.line(&format!("let {var_name}_1 = r.read_u8()?;"));
333                w.line(&format!("let {var_name}_2 = r.read_u8()?;"));
334                w.line(&format!(
335                    "let {var_name} = ({var_name}_0, {var_name}_1, {var_name}_2);"
336                ));
337            }
338            SemanticType::Uuid => {
339                w.line(&format!(
340                    "let {var_name}_bytes = r.read_raw_bytes(16_usize)?;"
341                ));
342                w.line(&format!(
343                    "let {var_name}: [u8; 16] = {var_name}_bytes.try_into().map_err(|_| vexil_runtime::DecodeError::UnexpectedEof)?;"
344                ));
345            }
346            SemanticType::Timestamp => {
347                w.line(&format!("let {var_name} = r.read_i64()?;"));
348            }
349            SemanticType::Hash => {
350                w.line(&format!(
351                    "let {var_name}_bytes = r.read_raw_bytes(32_usize)?;"
352                ));
353                w.line(&format!(
354                    "let {var_name}: [u8; 32] = {var_name}_bytes.try_into().map_err(|_| vexil_runtime::DecodeError::UnexpectedEof)?;"
355                ));
356            }
357        },
358        ResolvedType::Named(_) => {
359            w.line("r.enter_recursive()?;");
360            w.line(&format!(
361                "let {var_name} = vexil_runtime::Unpack::unpack(r)?;"
362            ));
363            w.line("r.leave_recursive();");
364        }
365        ResolvedType::Optional(inner) => {
366            w.line(&format!("let {var_name}_present = r.read_bool()?;"));
367            if is_byte_aligned(inner, registry) {
368                w.line("r.flush_to_byte_boundary();");
369            }
370            w.open_block(&format!("let {var_name} = if {var_name}_present"));
371            emit_read_type(
372                w,
373                &format!("{var_name}_inner"),
374                inner,
375                registry,
376                field_name,
377                None,
378            );
379            w.line(&format!("Some({var_name}_inner)"));
380            w.close_block();
381            w.open_block("else");
382            w.line("None");
383            w.close_block();
384            w.append(";");
385            w.append("\n");
386        }
387        ResolvedType::Array(inner) => {
388            w.line(&format!(
389                "let {var_name}_len = r.read_leb128(10_u8)? as usize;"
390            ));
391            if let Some(lim) = limit {
392                w.line(&format!(
393                    "if {var_name}_len as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}_len as u64 }}); }}"
394                ));
395            }
396            w.line(&format!(
397                "let mut {var_name} = Vec::with_capacity({var_name}_len);"
398            ));
399            w.open_block(&format!("for _ in 0..{var_name}_len"));
400            emit_read_type(
401                w,
402                &format!("{var_name}_item"),
403                inner,
404                registry,
405                field_name,
406                None,
407            );
408            w.line(&format!("{var_name}.push({var_name}_item);"));
409            w.close_block();
410        }
411        ResolvedType::Map(k, v) => {
412            w.line(&format!(
413                "let {var_name}_len = r.read_leb128(10_u8)? as usize;"
414            ));
415            if let Some(lim) = limit {
416                w.line(&format!(
417                    "if {var_name}_len as u64 > {lim}_u64 {{ return Err(vexil_runtime::DecodeError::LimitExceeded {{ field: \"{field_name}\", limit: {lim}_u64, actual: {var_name}_len as u64 }}); }}"
418                ));
419            }
420            w.line(&format!(
421                "let mut {var_name} = std::collections::BTreeMap::new();"
422            ));
423            w.open_block(&format!("for _ in 0..{var_name}_len"));
424            emit_read_type(w, &format!("{var_name}_k"), k, registry, field_name, None);
425            emit_read_type(w, &format!("{var_name}_v"), v, registry, field_name, None);
426            w.line(&format!("{var_name}.insert({var_name}_k, {var_name}_v);"));
427            w.close_block();
428        }
429        ResolvedType::Result(ok, err) => {
430            w.line(&format!("let {var_name}_is_ok = r.read_bool()?;"));
431            w.open_block(&format!("let {var_name} = if {var_name}_is_ok"));
432            emit_read_type(w, &format!("{var_name}_ok"), ok, registry, field_name, None);
433            w.line(&format!("Ok({var_name}_ok)"));
434            w.close_block();
435            w.open_block("else");
436            emit_read_type(
437                w,
438                &format!("{var_name}_err"),
439                err,
440                registry,
441                field_name,
442                None,
443            );
444            w.line(&format!("Err({var_name}_err)"));
445            w.close_block();
446            w.append(";");
447            w.append("\n");
448        }
449        _ => {} // non_exhaustive guard
450    }
451}
452
453fn read_cast_for_varint(ty: &ResolvedType) -> &'static str {
454    match ty {
455        ResolvedType::Primitive(p) => match p {
456            PrimitiveType::U8 => "u8",
457            PrimitiveType::U16 => "u16",
458            PrimitiveType::U32 => "u32",
459            PrimitiveType::U64 => "u64",
460            _ => "u64",
461        },
462        _ => "u64",
463    }
464}
465
466fn read_cast_for_zigzag(ty: &ResolvedType) -> &'static str {
467    match ty {
468        ResolvedType::Primitive(p) => match p {
469            PrimitiveType::I8 => "i8",
470            PrimitiveType::I16 => "i16",
471            PrimitiveType::I32 => "i32",
472            PrimitiveType::I64 => "i64",
473            _ => "i64",
474        },
475        _ => "i64",
476    }
477}
478
479// ---------------------------------------------------------------------------
480// emit_message
481// ---------------------------------------------------------------------------
482
483/// Emit a complete message struct with Pack and Unpack implementations.
484pub fn emit_message(
485    w: &mut CodeWriter,
486    msg: &MessageDef,
487    registry: &TypeRegistry,
488    needs_box: &HashSet<(TypeId, usize)>,
489    type_id: TypeId,
490) {
491    let name = msg.name.as_str();
492
493    // Tombstone comments
494    emit_tombstones(w, name, &msg.tombstones);
495
496    // Type annotations
497    emit_type_annotations(w, &msg.annotations);
498    w.line("#[derive(Debug, Clone, PartialEq)]");
499
500    // Struct definition
501    w.open_block(&format!("pub struct {name}"));
502    for (fi, field) in msg.fields.iter().enumerate() {
503        emit_field_annotations(w, &field.annotations);
504        let field_rust_type = rust_type(
505            &field.resolved_type,
506            registry,
507            needs_box,
508            Some((type_id, fi)),
509        );
510        w.line(&format!("pub {}: {},", field.name, field_rust_type));
511    }
512    w.close_block();
513    w.blank();
514
515    // Pack impl
516    w.open_block(&format!("impl vexil_runtime::Pack for {name}"));
517    w.open_block("fn pack(&self, w: &mut vexil_runtime::BitWriter) -> Result<(), vexil_runtime::EncodeError>");
518    for field in &msg.fields {
519        let access = format!("self.{}", field.name);
520        emit_write(
521            w,
522            &access,
523            &field.resolved_type,
524            &field.encoding,
525            registry,
526            field.name.as_str(),
527        );
528    }
529    w.line("w.flush_to_byte_boundary();");
530    w.line("Ok(())");
531    w.close_block();
532    w.close_block();
533    w.blank();
534
535    // Unpack impl
536    w.open_block(&format!("impl vexil_runtime::Unpack for {name}"));
537    w.open_block("fn unpack(r: &mut vexil_runtime::BitReader<'_>) -> Result<Self, vexil_runtime::DecodeError>");
538    for field in &msg.fields {
539        let var_name = field.name.as_str();
540        emit_read(
541            w,
542            var_name,
543            &field.resolved_type,
544            &field.encoding,
545            registry,
546            var_name,
547        );
548    }
549    w.line("r.flush_to_byte_boundary();");
550    w.open_block("Ok(Self");
551    for field in &msg.fields {
552        w.line(&format!("{},", field.name));
553    }
554    w.dedent();
555    w.line("})");
556
557    w.close_block();
558    w.close_block();
559    w.blank();
560}