Skip to main content

ironsbe_codegen/rust/
messages.rs

1//! Message encoder/decoder code generation.
2
3use ironsbe_schema::ir::{ResolvedField, ResolvedGroup, ResolvedMessage, SchemaIr, to_snake_case};
4use ironsbe_schema::types::PrimitiveType;
5
6/// Generator for message encoders and decoders.
7pub struct MessageGenerator<'a> {
8    ir: &'a SchemaIr,
9}
10
11impl<'a> MessageGenerator<'a> {
12    /// Creates a new message generator.
13    #[must_use]
14    pub fn new(ir: &'a SchemaIr) -> Self {
15        Self { ir }
16    }
17
18    /// Generates all message definitions.
19    #[must_use]
20    pub fn generate(&self) -> String {
21        let mut output = String::new();
22
23        for msg in &self.ir.messages {
24            output.push_str(&self.generate_decoder(msg));
25            output.push_str(&self.generate_encoder(msg));
26
27            // Generate group decoders/encoders
28            for group in &msg.groups {
29                output.push_str(&self.generate_group_decoder(group));
30            }
31        }
32
33        output
34    }
35
36    /// Generates a message decoder.
37    fn generate_decoder(&self, msg: &ResolvedMessage) -> String {
38        let mut output = String::new();
39        let decoder_name = msg.decoder_name();
40
41        // Struct definition
42        output.push_str(&format!("/// {} Decoder (zero-copy).\n", msg.name));
43        output.push_str("#[derive(Debug, Clone, Copy)]\n");
44        output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
45        output.push_str("    buffer: &'a [u8],\n");
46        output.push_str("    offset: usize,\n");
47        output.push_str("    acting_version: u16,\n");
48        output.push_str("}\n\n");
49
50        // Implementation
51        output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
52        output.push_str(&format!(
53            "    /// Template ID for this message.\n\
54             pub const TEMPLATE_ID: u16 = {};\n",
55            msg.template_id
56        ));
57        output.push_str(&format!(
58            "    /// Block length of the fixed portion.\n\
59             pub const BLOCK_LENGTH: u16 = {};\n\n",
60            msg.block_length
61        ));
62
63        // Constructor
64        output.push_str("    /// Wraps a buffer for zero-copy decoding.\n");
65        output.push_str("    ///\n");
66        output.push_str("    /// # Arguments\n");
67        output.push_str("    /// * `buffer` - Buffer containing the message\n");
68        output.push_str(
69            "    /// * `offset` - Offset to the start of the root block (after header)\n",
70        );
71        output.push_str("    /// * `acting_version` - Schema version for compatibility\n");
72        output.push_str("    #[inline]\n");
73        output.push_str("    #[must_use]\n");
74        output.push_str(
75            "    pub fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
76        );
77        output.push_str("        Self { buffer, offset, acting_version }\n");
78        output.push_str("    }\n\n");
79
80        // Field getters
81        for field in &msg.fields {
82            output.push_str(&self.generate_field_getter(field));
83        }
84
85        // Group accessors
86        let mut group_offset = msg.block_length as usize;
87        for group in &msg.groups {
88            output.push_str(&self.generate_group_accessor(group, group_offset));
89            group_offset += 4; // Group header size
90        }
91
92        output.push_str("}\n\n");
93
94        // SbeDecoder trait implementation
95        output.push_str(&format!(
96            "impl<'a> SbeDecoder<'a> for {}<'a> {{\n",
97            decoder_name
98        ));
99        output.push_str(&format!(
100            "    const TEMPLATE_ID: u16 = {};\n",
101            msg.template_id
102        ));
103        output.push_str("    const SCHEMA_ID: u16 = SCHEMA_ID;\n");
104        output.push_str("    const SCHEMA_VERSION: u16 = SCHEMA_VERSION;\n");
105        output.push_str(&format!(
106            "    const BLOCK_LENGTH: u16 = {};\n\n",
107            msg.block_length
108        ));
109
110        output.push_str(
111            "    fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
112        );
113        output.push_str("        Self::wrap(buffer, offset, acting_version)\n");
114        output.push_str("    }\n\n");
115
116        output.push_str("    fn encoded_length(&self) -> usize {\n");
117        output.push_str("        MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
118        output.push_str("    }\n");
119        output.push_str("}\n\n");
120
121        output
122    }
123
124    /// Generates a field getter method.
125    fn generate_field_getter(&self, field: &ResolvedField) -> String {
126        let mut output = String::new();
127
128        output.push_str(&format!(
129            "    /// Field: {} (id={}, offset={}).\n",
130            field.name, field.id, field.offset
131        ));
132        output.push_str("    #[inline(always)]\n");
133        output.push_str("    #[must_use]\n");
134
135        if field.is_array {
136            // Array field - return slice
137            let elem_type = field.primitive_type.map(|p| p.rust_type()).unwrap_or("u8");
138            let len = field.array_length.unwrap_or(1);
139
140            if elem_type == "u8" {
141                // Byte array - return &[u8]
142                output.push_str(&format!(
143                    "    pub fn {}(&self) -> &'a [u8] {{\n",
144                    field.getter_name
145                ));
146                output.push_str(&format!(
147                    "        &self.buffer[self.offset + {}..self.offset + {} + {}]\n",
148                    field.offset, field.offset, len
149                ));
150                output.push_str("    }\n\n");
151
152                // Also generate a string accessor for char arrays
153                output.push_str(&format!(
154                    "    /// Field {} as string (trimmed).\n",
155                    field.name
156                ));
157                output.push_str("    #[inline]\n");
158                output.push_str("    #[must_use]\n");
159                output.push_str(&format!(
160                    "    pub fn {}_as_str(&self) -> &'a str {{\n",
161                    field.getter_name
162                ));
163                output.push_str(&format!(
164                    "        let bytes = &self.buffer[self.offset + {}..self.offset + {} + {}];\n",
165                    field.offset, field.offset, len
166                ));
167                output.push_str(
168                    "        let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());\n",
169                );
170                output.push_str("        std::str::from_utf8(&bytes[..end]).unwrap_or(\"\")\n");
171                output.push_str("    }\n\n");
172            } else {
173                // Other array types
174                output.push_str(&format!(
175                    "    pub fn {}(&self) -> &'a [u8] {{\n",
176                    field.getter_name
177                ));
178                output.push_str(&format!(
179                    "        &self.buffer[self.offset + {}..self.offset + {}]\n",
180                    field.offset,
181                    field.offset + field.encoded_length
182                ));
183                output.push_str("    }\n\n");
184            }
185        } else {
186            // Scalar field
187            let rust_type = &field.rust_type;
188            let read_method = get_read_method(field.primitive_type);
189
190            output.push_str(&format!(
191                "    pub fn {}(&self) -> {} {{\n",
192                field.getter_name, rust_type
193            ));
194            output.push_str(&format!(
195                "        self.buffer.{}(self.offset + {})\n",
196                read_method, field.offset
197            ));
198            output.push_str("    }\n\n");
199        }
200
201        output
202    }
203
204    /// Generates a group accessor method.
205    fn generate_group_accessor(&self, group: &ResolvedGroup, offset: usize) -> String {
206        let mut output = String::new();
207        let group_decoder = group.decoder_name();
208
209        output.push_str(&format!("    /// Access {} repeating group.\n", group.name));
210        output.push_str("    #[inline]\n");
211        output.push_str("    #[must_use]\n");
212        output.push_str(&format!(
213            "    pub fn {}(&self) -> {}<'a> {{\n",
214            to_snake_case(&group.name),
215            group_decoder
216        ));
217        output.push_str(&format!(
218            "        {}::wrap(self.buffer, self.offset + {})\n",
219            group_decoder, offset
220        ));
221        output.push_str("    }\n\n");
222
223        output
224    }
225
226    /// Generates a message encoder.
227    fn generate_encoder(&self, msg: &ResolvedMessage) -> String {
228        let mut output = String::new();
229        let encoder_name = msg.encoder_name();
230
231        // Struct definition
232        output.push_str(&format!("/// {} Encoder.\n", msg.name));
233        output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
234        output.push_str("    buffer: &'a mut [u8],\n");
235        output.push_str("    offset: usize,\n");
236        output.push_str("}\n\n");
237
238        // Implementation
239        output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
240        output.push_str(&format!(
241            "    /// Template ID for this message.\n\
242             pub const TEMPLATE_ID: u16 = {};\n",
243            msg.template_id
244        ));
245        output.push_str(&format!(
246            "    /// Block length of the fixed portion.\n\
247             pub const BLOCK_LENGTH: u16 = {};\n\n",
248            msg.block_length
249        ));
250
251        // Constructor
252        output.push_str("    /// Wraps a buffer for encoding, writing the header.\n");
253        output.push_str("    #[inline]\n");
254        output.push_str("    pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
255        output.push_str("        let mut encoder = Self { buffer, offset };\n");
256        output.push_str("        encoder.write_header();\n");
257        output.push_str("        encoder\n");
258        output.push_str("    }\n\n");
259
260        // Write header
261        output.push_str("    fn write_header(&mut self) {\n");
262        output.push_str("        let header = MessageHeader {\n");
263        output.push_str("            block_length: Self::BLOCK_LENGTH,\n");
264        output.push_str("            template_id: Self::TEMPLATE_ID,\n");
265        output.push_str("            schema_id: SCHEMA_ID,\n");
266        output.push_str("            version: SCHEMA_VERSION,\n");
267        output.push_str("        };\n");
268        output.push_str("        header.encode(self.buffer, self.offset);\n");
269        output.push_str("    }\n\n");
270
271        // Encoded length
272        output.push_str("    /// Returns the encoded length of the message.\n");
273        output.push_str("    #[must_use]\n");
274        output.push_str("    pub const fn encoded_length(&self) -> usize {\n");
275        output.push_str("        MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
276        output.push_str("    }\n\n");
277
278        // Field setters
279        for field in &msg.fields {
280            output.push_str(&self.generate_field_setter(field));
281        }
282
283        output.push_str("}\n\n");
284
285        output
286    }
287
288    /// Generates a field setter method.
289    fn generate_field_setter(&self, field: &ResolvedField) -> String {
290        let mut output = String::new();
291        let field_offset = format!("MessageHeader::ENCODED_LENGTH + {}", field.offset);
292
293        output.push_str(&format!(
294            "    /// Set field: {} (id={}, offset={}).\n",
295            field.name, field.id, field.offset
296        ));
297        output.push_str("    #[inline(always)]\n");
298
299        if field.is_array {
300            // Array field - accept slice
301            let len = field.array_length.unwrap_or(field.encoded_length);
302
303            output.push_str(&format!(
304                "    pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
305                field.setter_name
306            ));
307            output.push_str(&format!(
308                "        let copy_len = value.len().min({});\n",
309                len
310            ));
311            output.push_str(&format!(
312                "        self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
313                field_offset, field_offset
314            ));
315            output.push_str("            .copy_from_slice(&value[..copy_len]);\n");
316            output.push_str(&format!("        if copy_len < {} {{\n", len));
317            output.push_str(&format!(
318                "            self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
319                field_offset, field_offset, len
320            ));
321            output.push_str("        }\n");
322            output.push_str("        self\n");
323            output.push_str("    }\n\n");
324        } else {
325            // Scalar field
326            let rust_type = &field.rust_type;
327            let write_method = get_write_method(field.primitive_type);
328
329            output.push_str(&format!(
330                "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
331                field.setter_name, rust_type
332            ));
333            output.push_str(&format!(
334                "        self.buffer.{}(self.offset + {}, value);\n",
335                write_method, field_offset
336            ));
337            output.push_str("        self\n");
338            output.push_str("    }\n\n");
339        }
340
341        output
342    }
343
344    /// Generates a group decoder.
345    fn generate_group_decoder(&self, group: &ResolvedGroup) -> String {
346        let mut output = String::new();
347        let decoder_name = group.decoder_name();
348        let entry_name = group.entry_decoder_name();
349
350        // Group decoder struct
351        output.push_str(&format!("/// {} Group Decoder.\n", group.name));
352        output.push_str("#[derive(Debug, Clone, Copy)]\n");
353        output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
354        output.push_str("    buffer: &'a [u8],\n");
355        output.push_str("    block_length: u16,\n");
356        output.push_str("    count: u16,\n");
357        output.push_str("    index: u16,\n");
358        output.push_str("    offset: usize,\n");
359        output.push_str("}\n\n");
360
361        // Group decoder implementation
362        output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
363        output.push_str("    /// Wraps a buffer at the group header position.\n");
364        output.push_str("    #[must_use]\n");
365        output.push_str("    pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
366        output.push_str("        let header = GroupHeader::wrap(buffer, offset);\n");
367        output.push_str("        Self {\n");
368        output.push_str("            buffer,\n");
369        output.push_str("            block_length: header.block_length,\n");
370        output.push_str("            count: header.num_in_group,\n");
371        output.push_str("            index: 0,\n");
372        output.push_str("            offset: offset + GroupHeader::ENCODED_LENGTH,\n");
373        output.push_str("        }\n");
374        output.push_str("    }\n\n");
375
376        output.push_str("    /// Returns the number of entries in the group.\n");
377        output.push_str("    #[must_use]\n");
378        output.push_str("    pub const fn count(&self) -> u16 {\n");
379        output.push_str("        self.count\n");
380        output.push_str("    }\n\n");
381
382        output.push_str("    /// Returns true if the group is empty.\n");
383        output.push_str("    #[must_use]\n");
384        output.push_str("    pub const fn is_empty(&self) -> bool {\n");
385        output.push_str("        self.count == 0\n");
386        output.push_str("    }\n");
387        output.push_str("}\n\n");
388
389        // Iterator implementation
390        output.push_str(&format!("impl<'a> Iterator for {}<'a> {{\n", decoder_name));
391        output.push_str(&format!("    type Item = {}<'a>;\n\n", entry_name));
392        output.push_str("    fn next(&mut self) -> Option<Self::Item> {\n");
393        output.push_str("        if self.index >= self.count {\n");
394        output.push_str("            return None;\n");
395        output.push_str("        }\n");
396        output.push_str(&format!(
397            "        let entry = {}::wrap(self.buffer, self.offset);\n",
398            entry_name
399        ));
400        output.push_str("        self.offset += self.block_length as usize;\n");
401        output.push_str("        self.index += 1;\n");
402        output.push_str("        Some(entry)\n");
403        output.push_str("    }\n\n");
404
405        output.push_str("    fn size_hint(&self) -> (usize, Option<usize>) {\n");
406        output.push_str("        let remaining = (self.count - self.index) as usize;\n");
407        output.push_str("        (remaining, Some(remaining))\n");
408        output.push_str("    }\n");
409        output.push_str("}\n\n");
410
411        output.push_str(&format!(
412            "impl<'a> ExactSizeIterator for {}<'a> {{}}\n\n",
413            decoder_name
414        ));
415
416        // Entry decoder
417        output.push_str(&self.generate_entry_decoder(group));
418
419        // Nested groups
420        for nested in &group.nested_groups {
421            output.push_str(&self.generate_group_decoder(nested));
422        }
423
424        output
425    }
426
427    /// Generates a group entry decoder.
428    fn generate_entry_decoder(&self, group: &ResolvedGroup) -> String {
429        let mut output = String::new();
430        let entry_name = group.entry_decoder_name();
431
432        output.push_str(&format!("/// {} Entry Decoder.\n", group.name));
433        output.push_str("#[derive(Debug, Clone, Copy)]\n");
434        output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
435        output.push_str("    buffer: &'a [u8],\n");
436        output.push_str("    offset: usize,\n");
437        output.push_str("}\n\n");
438
439        output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
440        output.push_str("    fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
441        output.push_str("        Self { buffer, offset }\n");
442        output.push_str("    }\n\n");
443
444        // Field getters
445        for field in &group.fields {
446            output.push_str(&self.generate_field_getter(field));
447        }
448
449        output.push_str("}\n\n");
450
451        output
452    }
453}
454
455/// Gets the read method name for a primitive type.
456fn get_read_method(prim: Option<PrimitiveType>) -> &'static str {
457    match prim {
458        Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "get_u8",
459        Some(PrimitiveType::Int8) => "get_i8",
460        Some(PrimitiveType::Uint16) => "get_u16_le",
461        Some(PrimitiveType::Int16) => "get_i16_le",
462        Some(PrimitiveType::Uint32) => "get_u32_le",
463        Some(PrimitiveType::Int32) => "get_i32_le",
464        Some(PrimitiveType::Uint64) => "get_u64_le",
465        Some(PrimitiveType::Int64) => "get_i64_le",
466        Some(PrimitiveType::Float) => "get_f32_le",
467        Some(PrimitiveType::Double) => "get_f64_le",
468        None => "get_u64_le",
469    }
470}
471
472/// Gets the write method name for a primitive type.
473fn get_write_method(prim: Option<PrimitiveType>) -> &'static str {
474    match prim {
475        Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "put_u8",
476        Some(PrimitiveType::Int8) => "put_i8",
477        Some(PrimitiveType::Uint16) => "put_u16_le",
478        Some(PrimitiveType::Int16) => "put_i16_le",
479        Some(PrimitiveType::Uint32) => "put_u32_le",
480        Some(PrimitiveType::Int32) => "put_i32_le",
481        Some(PrimitiveType::Uint64) => "put_u64_le",
482        Some(PrimitiveType::Int64) => "put_i64_le",
483        Some(PrimitiveType::Float) => "put_f32_le",
484        Some(PrimitiveType::Double) => "put_f64_le",
485        None => "put_u64_le",
486    }
487}