Skip to main content

ironsbe_codegen/rust/
messages.rs

1//! Message encoder/decoder code generation.
2
3use ironsbe_schema::ir::{
4    ResolvedField, ResolvedGroup, ResolvedMessage, SchemaIr, TypeKind, to_snake_case,
5};
6use ironsbe_schema::types::PrimitiveType;
7
8/// Generator for message encoders and decoders.
9pub struct MessageGenerator<'a> {
10    ir: &'a SchemaIr,
11}
12
13impl<'a> MessageGenerator<'a> {
14    /// Creates a new message generator.
15    #[must_use]
16    pub fn new(ir: &'a SchemaIr) -> Self {
17        Self { ir }
18    }
19
20    /// Generates all message definitions.
21    #[must_use]
22    pub fn generate(&self) -> String {
23        let mut output = String::new();
24
25        for msg in &self.ir.messages {
26            output.push_str(&self.generate_decoder(msg));
27            output.push_str(&self.generate_encoder(msg));
28
29            // Generate group decoders/encoders
30            for group in &msg.groups {
31                output.push_str(&self.generate_group_decoder(group));
32            }
33        }
34
35        output
36    }
37
38    /// Generates a message decoder.
39    fn generate_decoder(&self, msg: &ResolvedMessage) -> String {
40        let mut output = String::new();
41        let decoder_name = msg.decoder_name();
42
43        // Struct definition
44        output.push_str(&format!("/// {} Decoder (zero-copy).\n", msg.name));
45        output.push_str("#[derive(Debug, Clone, Copy)]\n");
46        output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
47        output.push_str("    buffer: &'a [u8],\n");
48        output.push_str("    offset: usize,\n");
49        output.push_str("    acting_version: u16,\n");
50        output.push_str("}\n\n");
51
52        // Implementation
53        output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
54        output.push_str(&format!(
55            "    /// Template ID for this message.\n\
56             pub const TEMPLATE_ID: u16 = {};\n",
57            msg.template_id
58        ));
59        output.push_str(&format!(
60            "    /// Block length of the fixed portion.\n\
61             pub const BLOCK_LENGTH: u16 = {};\n\n",
62            msg.block_length
63        ));
64
65        // Constructor
66        output.push_str("    /// Wraps a buffer for zero-copy decoding.\n");
67        output.push_str("    ///\n");
68        output.push_str("    /// # Arguments\n");
69        output.push_str("    /// * `buffer` - Buffer containing the message\n");
70        output.push_str(
71            "    /// * `offset` - Offset to the start of the root block (after header)\n",
72        );
73        output.push_str("    /// * `acting_version` - Schema version for compatibility\n");
74        output.push_str("    #[inline]\n");
75        output.push_str("    #[must_use]\n");
76        output.push_str(
77            "    pub fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
78        );
79        output.push_str("        Self { buffer, offset, acting_version }\n");
80        output.push_str("    }\n\n");
81
82        // Field getters
83        for field in &msg.fields {
84            output.push_str(&self.generate_field_getter(field));
85        }
86
87        // Group accessors
88        let mut group_offset = msg.block_length as usize;
89        for group in &msg.groups {
90            output.push_str(&self.generate_group_accessor(group, group_offset));
91            group_offset += 4; // Group header size
92        }
93
94        output.push_str("}\n\n");
95
96        // SbeDecoder trait implementation
97        output.push_str(&format!(
98            "impl<'a> SbeDecoder<'a> for {}<'a> {{\n",
99            decoder_name
100        ));
101        output.push_str(&format!(
102            "    const TEMPLATE_ID: u16 = {};\n",
103            msg.template_id
104        ));
105        output.push_str("    const SCHEMA_ID: u16 = SCHEMA_ID;\n");
106        output.push_str("    const SCHEMA_VERSION: u16 = SCHEMA_VERSION;\n");
107        output.push_str(&format!(
108            "    const BLOCK_LENGTH: u16 = {};\n\n",
109            msg.block_length
110        ));
111
112        output.push_str(
113            "    fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
114        );
115        output.push_str("        Self::wrap(buffer, offset, acting_version)\n");
116        output.push_str("    }\n\n");
117
118        output.push_str("    fn encoded_length(&self) -> usize {\n");
119        output.push_str("        MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
120        output.push_str("    }\n");
121        output.push_str("}\n\n");
122
123        output
124    }
125
126    /// Generates a field getter method.
127    fn generate_field_getter(&self, field: &ResolvedField) -> String {
128        let mut output = String::new();
129
130        output.push_str(&format!(
131            "    /// Field: {} (id={}, offset={}).\n",
132            field.name, field.id, field.offset
133        ));
134        output.push_str("    #[inline(always)]\n");
135        output.push_str("    #[must_use]\n");
136
137        if field.is_array {
138            // Array field - return slice
139            let elem_type = field.primitive_type.map(|p| p.rust_type()).unwrap_or("u8");
140            let len = field.array_length.unwrap_or(1);
141
142            if elem_type == "u8" {
143                // Byte array - return &[u8]
144                output.push_str(&format!(
145                    "    pub fn {}(&self) -> &'a [u8] {{\n",
146                    field.getter_name
147                ));
148                output.push_str(&format!(
149                    "        &self.buffer[self.offset + {}..self.offset + {} + {}]\n",
150                    field.offset, field.offset, len
151                ));
152                output.push_str("    }\n\n");
153
154                // Also generate a string accessor for char arrays
155                output.push_str(&format!(
156                    "    /// Field {} as string (trimmed).\n",
157                    field.name
158                ));
159                output.push_str("    #[inline]\n");
160                output.push_str("    #[must_use]\n");
161                output.push_str(&format!(
162                    "    pub fn {}_as_str(&self) -> &'a str {{\n",
163                    field.getter_name
164                ));
165                output.push_str(&format!(
166                    "        let bytes = &self.buffer[self.offset + {}..self.offset + {} + {}];\n",
167                    field.offset, field.offset, len
168                ));
169                output.push_str(
170                    "        let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());\n",
171                );
172                output.push_str("        std::str::from_utf8(&bytes[..end]).unwrap_or(\"\")\n");
173                output.push_str("    }\n\n");
174            } else {
175                // Other array types
176                output.push_str(&format!(
177                    "    pub fn {}(&self) -> &'a [u8] {{\n",
178                    field.getter_name
179                ));
180                output.push_str(&format!(
181                    "        &self.buffer[self.offset + {}..self.offset + {}]\n",
182                    field.offset,
183                    field.offset + field.encoded_length
184                ));
185                output.push_str("    }\n\n");
186            }
187        } else {
188            // Scalar field - check if it's an enum/set type
189            let rust_type = &field.rust_type;
190            let resolved_type = self.ir.get_type(&field.type_name);
191
192            match resolved_type.map(|t| &t.kind) {
193                Some(TypeKind::Enum { encoding, .. }) => {
194                    // Enum field - use encoding primitive and wrap with From
195                    let read_method = get_read_method(Some(*encoding));
196                    output.push_str(&format!(
197                        "    pub fn {}(&self) -> {} {{\n",
198                        field.getter_name, rust_type
199                    ));
200                    output.push_str(&format!(
201                        "        {}::from(self.buffer.{}(self.offset + {}))\n",
202                        rust_type, read_method, field.offset
203                    ));
204                    output.push_str("    }\n\n");
205                }
206                Some(TypeKind::Set { encoding, .. }) => {
207                    // Set field - use encoding primitive and wrap with from_raw
208                    let read_method = get_read_method(Some(*encoding));
209                    output.push_str(&format!(
210                        "    pub fn {}(&self) -> {} {{\n",
211                        field.getter_name, rust_type
212                    ));
213                    output.push_str(&format!(
214                        "        {}::from_raw(self.buffer.{}(self.offset + {}))\n",
215                        rust_type, read_method, field.offset
216                    ));
217                    output.push_str("    }\n\n");
218                }
219                Some(TypeKind::Composite { .. }) => {
220                    // Composite field - return wrapper struct
221                    output.push_str(&format!(
222                        "    pub fn {}(&self) -> {}<'a> {{\n",
223                        field.getter_name, rust_type
224                    ));
225                    output.push_str(&format!(
226                        "        {}::wrap(self.buffer, self.offset + {})\n",
227                        rust_type, field.offset
228                    ));
229                    output.push_str("    }\n\n");
230                }
231                _ => {
232                    // Primitive field
233                    let read_method = get_read_method(field.primitive_type);
234                    output.push_str(&format!(
235                        "    pub fn {}(&self) -> {} {{\n",
236                        field.getter_name, rust_type
237                    ));
238                    output.push_str(&format!(
239                        "        self.buffer.{}(self.offset + {})\n",
240                        read_method, field.offset
241                    ));
242                    output.push_str("    }\n\n");
243                }
244            }
245        }
246
247        output
248    }
249
250    /// Generates a group accessor method.
251    fn generate_group_accessor(&self, group: &ResolvedGroup, offset: usize) -> String {
252        let mut output = String::new();
253        let group_decoder = group.decoder_name();
254
255        output.push_str(&format!("    /// Access {} repeating group.\n", group.name));
256        output.push_str("    #[inline]\n");
257        output.push_str("    #[must_use]\n");
258        output.push_str(&format!(
259            "    pub fn {}(&self) -> {}<'a> {{\n",
260            to_snake_case(&group.name),
261            group_decoder
262        ));
263        output.push_str(&format!(
264            "        {}::wrap(self.buffer, self.offset + {})\n",
265            group_decoder, offset
266        ));
267        output.push_str("    }\n\n");
268
269        output
270    }
271
272    /// Generates a message encoder.
273    fn generate_encoder(&self, msg: &ResolvedMessage) -> String {
274        let mut output = String::new();
275        let encoder_name = msg.encoder_name();
276
277        // Struct definition
278        output.push_str(&format!("/// {} Encoder.\n", msg.name));
279        output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
280        output.push_str("    buffer: &'a mut [u8],\n");
281        output.push_str("    offset: usize,\n");
282        output.push_str("}\n\n");
283
284        // Implementation
285        output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
286        output.push_str(&format!(
287            "    /// Template ID for this message.\n\
288             pub const TEMPLATE_ID: u16 = {};\n",
289            msg.template_id
290        ));
291        output.push_str(&format!(
292            "    /// Block length of the fixed portion.\n\
293             pub const BLOCK_LENGTH: u16 = {};\n\n",
294            msg.block_length
295        ));
296
297        // Constructor
298        output.push_str("    /// Wraps a buffer for encoding, writing the header.\n");
299        output.push_str("    #[inline]\n");
300        output.push_str("    pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
301        output.push_str("        let mut encoder = Self { buffer, offset };\n");
302        output.push_str("        encoder.write_header();\n");
303        output.push_str("        encoder\n");
304        output.push_str("    }\n\n");
305
306        // Write header
307        output.push_str("    fn write_header(&mut self) {\n");
308        output.push_str("        let header = MessageHeader {\n");
309        output.push_str("            block_length: Self::BLOCK_LENGTH,\n");
310        output.push_str("            template_id: Self::TEMPLATE_ID,\n");
311        output.push_str("            schema_id: SCHEMA_ID,\n");
312        output.push_str("            version: SCHEMA_VERSION,\n");
313        output.push_str("        };\n");
314        output.push_str("        header.encode(self.buffer, self.offset);\n");
315        output.push_str("    }\n\n");
316
317        // Encoded length
318        output.push_str("    /// Returns the encoded length of the message.\n");
319        output.push_str("    #[must_use]\n");
320        output.push_str("    pub const fn encoded_length(&self) -> usize {\n");
321        output.push_str("        MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
322        output.push_str("    }\n\n");
323
324        // Field setters
325        for field in &msg.fields {
326            output.push_str(&self.generate_field_setter(field));
327        }
328
329        output.push_str("}\n\n");
330
331        output
332    }
333
334    /// Generates a field setter method.
335    fn generate_field_setter(&self, field: &ResolvedField) -> String {
336        let mut output = String::new();
337        let field_offset = format!("MessageHeader::ENCODED_LENGTH + {}", field.offset);
338
339        output.push_str(&format!(
340            "    /// Set field: {} (id={}, offset={}).\n",
341            field.name, field.id, field.offset
342        ));
343        output.push_str("    #[inline(always)]\n");
344
345        if field.is_array {
346            // Array field - accept slice
347            let len = field.array_length.unwrap_or(field.encoded_length);
348
349            output.push_str(&format!(
350                "    pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
351                field.setter_name
352            ));
353            output.push_str(&format!(
354                "        let copy_len = value.len().min({});\n",
355                len
356            ));
357            output.push_str(&format!(
358                "        self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
359                field_offset, field_offset
360            ));
361            output.push_str("            .copy_from_slice(&value[..copy_len]);\n");
362            output.push_str(&format!("        if copy_len < {} {{\n", len));
363            output.push_str(&format!(
364                "            self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
365                field_offset, field_offset, len
366            ));
367            output.push_str("        }\n");
368            output.push_str("        self\n");
369            output.push_str("    }\n\n");
370        } else {
371            // Scalar field - check if it's an enum/set type
372            let rust_type = &field.rust_type;
373            let resolved_type = self.ir.get_type(&field.type_name);
374
375            match resolved_type.map(|t| &t.kind) {
376                Some(TypeKind::Enum { encoding, .. }) => {
377                    // Enum field - convert enum to primitive before writing
378                    let write_method = get_write_method(Some(*encoding));
379                    let prim_type = encoding.rust_type();
380                    output.push_str(&format!(
381                        "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
382                        field.setter_name, rust_type
383                    ));
384                    output.push_str(&format!(
385                        "        self.buffer.{}(self.offset + {}, {}::from(value));\n",
386                        write_method, field_offset, prim_type
387                    ));
388                    output.push_str("        self\n");
389                    output.push_str("    }\n\n");
390                }
391                Some(TypeKind::Set { encoding, .. }) => {
392                    // Set field - use raw() to get the primitive value
393                    let write_method = get_write_method(Some(*encoding));
394                    output.push_str(&format!(
395                        "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
396                        field.setter_name, rust_type
397                    ));
398                    output.push_str(&format!(
399                        "        self.buffer.{}(self.offset + {}, value.raw());\n",
400                        write_method, field_offset
401                    ));
402                    output.push_str("        self\n");
403                    output.push_str("    }\n\n");
404                }
405                Some(TypeKind::Composite { .. }) => {
406                    // Composite field - return encoder for nested writes
407                    output.push_str(&format!(
408                        "    pub fn {}(&mut self) -> {}Encoder<'_> {{\n",
409                        field.setter_name, rust_type
410                    ));
411                    output.push_str(&format!(
412                        "        {}Encoder::wrap(self.buffer, self.offset + {})\n",
413                        rust_type, field_offset
414                    ));
415                    output.push_str("    }\n\n");
416                }
417                _ => {
418                    // Primitive field
419                    let write_method = get_write_method(field.primitive_type);
420                    output.push_str(&format!(
421                        "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
422                        field.setter_name, rust_type
423                    ));
424                    output.push_str(&format!(
425                        "        self.buffer.{}(self.offset + {}, value);\n",
426                        write_method, field_offset
427                    ));
428                    output.push_str("        self\n");
429                    output.push_str("    }\n\n");
430                }
431            }
432        }
433
434        output
435    }
436
437    /// Generates a group decoder.
438    fn generate_group_decoder(&self, group: &ResolvedGroup) -> String {
439        let mut output = String::new();
440        let decoder_name = group.decoder_name();
441        let entry_name = group.entry_decoder_name();
442
443        // Group decoder struct
444        output.push_str(&format!("/// {} Group Decoder.\n", group.name));
445        output.push_str("#[derive(Debug, Clone, Copy)]\n");
446        output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
447        output.push_str("    buffer: &'a [u8],\n");
448        output.push_str("    block_length: u16,\n");
449        output.push_str("    count: u16,\n");
450        output.push_str("    index: u16,\n");
451        output.push_str("    offset: usize,\n");
452        output.push_str("}\n\n");
453
454        // Group decoder implementation
455        output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
456        output.push_str("    /// Wraps a buffer at the group header position.\n");
457        output.push_str("    #[must_use]\n");
458        output.push_str("    pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
459        output.push_str("        let header = GroupHeader::wrap(buffer, offset);\n");
460        output.push_str("        Self {\n");
461        output.push_str("            buffer,\n");
462        output.push_str("            block_length: header.block_length,\n");
463        output.push_str("            count: header.num_in_group,\n");
464        output.push_str("            index: 0,\n");
465        output.push_str("            offset: offset + GroupHeader::ENCODED_LENGTH,\n");
466        output.push_str("        }\n");
467        output.push_str("    }\n\n");
468
469        output.push_str("    /// Returns the number of entries in the group.\n");
470        output.push_str("    #[must_use]\n");
471        output.push_str("    pub const fn count(&self) -> u16 {\n");
472        output.push_str("        self.count\n");
473        output.push_str("    }\n\n");
474
475        output.push_str("    /// Returns true if the group is empty.\n");
476        output.push_str("    #[must_use]\n");
477        output.push_str("    pub const fn is_empty(&self) -> bool {\n");
478        output.push_str("        self.count == 0\n");
479        output.push_str("    }\n");
480        output.push_str("}\n\n");
481
482        // Iterator implementation
483        output.push_str(&format!("impl<'a> Iterator for {}<'a> {{\n", decoder_name));
484        output.push_str(&format!("    type Item = {}<'a>;\n\n", entry_name));
485        output.push_str("    fn next(&mut self) -> Option<Self::Item> {\n");
486        output.push_str("        if self.index >= self.count {\n");
487        output.push_str("            return None;\n");
488        output.push_str("        }\n");
489        output.push_str(&format!(
490            "        let entry = {}::wrap(self.buffer, self.offset);\n",
491            entry_name
492        ));
493        output.push_str("        self.offset += self.block_length as usize;\n");
494        output.push_str("        self.index += 1;\n");
495        output.push_str("        Some(entry)\n");
496        output.push_str("    }\n\n");
497
498        output.push_str("    fn size_hint(&self) -> (usize, Option<usize>) {\n");
499        output.push_str("        let remaining = (self.count - self.index) as usize;\n");
500        output.push_str("        (remaining, Some(remaining))\n");
501        output.push_str("    }\n");
502        output.push_str("}\n\n");
503
504        output.push_str(&format!(
505            "impl<'a> ExactSizeIterator for {}<'a> {{}}\n\n",
506            decoder_name
507        ));
508
509        // Entry decoder
510        output.push_str(&self.generate_entry_decoder(group));
511
512        // Nested groups
513        for nested in &group.nested_groups {
514            output.push_str(&self.generate_group_decoder(nested));
515        }
516
517        output
518    }
519
520    /// Generates a group entry decoder.
521    fn generate_entry_decoder(&self, group: &ResolvedGroup) -> String {
522        let mut output = String::new();
523        let entry_name = group.entry_decoder_name();
524
525        output.push_str(&format!("/// {} Entry Decoder.\n", group.name));
526        output.push_str("#[derive(Debug, Clone, Copy)]\n");
527        output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
528        output.push_str("    buffer: &'a [u8],\n");
529        output.push_str("    offset: usize,\n");
530        output.push_str("}\n\n");
531
532        output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
533        output.push_str("    fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
534        output.push_str("        Self { buffer, offset }\n");
535        output.push_str("    }\n\n");
536
537        // Field getters
538        for field in &group.fields {
539            output.push_str(&self.generate_field_getter(field));
540        }
541
542        output.push_str("}\n\n");
543
544        output
545    }
546}
547
548/// Gets the read method name for a primitive type.
549fn get_read_method(prim: Option<PrimitiveType>) -> &'static str {
550    match prim {
551        Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "get_u8",
552        Some(PrimitiveType::Int8) => "get_i8",
553        Some(PrimitiveType::Uint16) => "get_u16_le",
554        Some(PrimitiveType::Int16) => "get_i16_le",
555        Some(PrimitiveType::Uint32) => "get_u32_le",
556        Some(PrimitiveType::Int32) => "get_i32_le",
557        Some(PrimitiveType::Uint64) => "get_u64_le",
558        Some(PrimitiveType::Int64) => "get_i64_le",
559        Some(PrimitiveType::Float) => "get_f32_le",
560        Some(PrimitiveType::Double) => "get_f64_le",
561        None => "get_u64_le",
562    }
563}
564
565/// Gets the write method name for a primitive type.
566fn get_write_method(prim: Option<PrimitiveType>) -> &'static str {
567    match prim {
568        Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "put_u8",
569        Some(PrimitiveType::Int8) => "put_i8",
570        Some(PrimitiveType::Uint16) => "put_u16_le",
571        Some(PrimitiveType::Int16) => "put_i16_le",
572        Some(PrimitiveType::Uint32) => "put_u32_le",
573        Some(PrimitiveType::Int32) => "put_i32_le",
574        Some(PrimitiveType::Uint64) => "put_u64_le",
575        Some(PrimitiveType::Int64) => "put_i64_le",
576        Some(PrimitiveType::Float) => "put_f32_le",
577        Some(PrimitiveType::Double) => "put_f64_le",
578        None => "put_u64_le",
579    }
580}