Skip to main content

ironsbe_codegen/rust/
messages.rs

1//! Message encoder/decoder code generation.
2
3use ironsbe_schema::ir::{
4    ResolvedField, ResolvedGroup, ResolvedMessage, SchemaIr, TypeKind, to_snake_case,
5};
6use ironsbe_schema::types::PrimitiveType;
7
8/// Generator for message encoders and decoders.
9pub struct MessageGenerator<'a> {
10    ir: &'a SchemaIr,
11}
12
13impl<'a> MessageGenerator<'a> {
14    /// Creates a new message generator.
15    #[must_use]
16    pub fn new(ir: &'a SchemaIr) -> Self {
17        Self { ir }
18    }
19
20    /// Generates all message definitions.
21    #[must_use]
22    pub fn generate(&self) -> String {
23        let mut output = String::new();
24
25        for msg in &self.ir.messages {
26            output.push_str(&self.generate_decoder(msg));
27            output.push_str(&self.generate_encoder(msg));
28
29            // Generate group decoders and encoders in a message-scoped module
30            if !msg.groups.is_empty() {
31                let mod_name = to_snake_case(&msg.name);
32                output.push_str(&format!("/// Types for {} repeating groups.\n", msg.name));
33                output.push_str(&format!("pub mod {} {{\n", mod_name));
34                output.push_str("    use super::*;\n\n");
35                for group in &msg.groups {
36                    output.push_str(&self.generate_group_decoder(group));
37                    output.push_str(&self.generate_group_encoder(group));
38                }
39                output.push_str("}\n\n");
40            }
41        }
42
43        output
44    }
45
46    /// Generates a message decoder.
47    fn generate_decoder(&self, msg: &ResolvedMessage) -> String {
48        let mut output = String::new();
49        let decoder_name = msg.decoder_name();
50
51        // Struct definition
52        output.push_str(&format!("/// {} Decoder (zero-copy).\n", msg.name));
53        output.push_str("#[derive(Debug, Clone, Copy)]\n");
54        output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
55        output.push_str("    buffer: &'a [u8],\n");
56        output.push_str("    offset: usize,\n");
57        output.push_str("    acting_version: u16,\n");
58        output.push_str("}\n\n");
59
60        // Implementation
61        output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
62        output.push_str(&format!(
63            "    /// Template ID for this message.\n\
64             pub const TEMPLATE_ID: u16 = {};\n",
65            msg.template_id
66        ));
67        output.push_str(&format!(
68            "    /// Block length of the fixed portion.\n\
69             pub const BLOCK_LENGTH: u16 = {};\n\n",
70            msg.block_length
71        ));
72
73        // Constructor
74        output.push_str("    /// Wraps a buffer for zero-copy decoding.\n");
75        output.push_str("    ///\n");
76        output.push_str("    /// # Arguments\n");
77        output.push_str("    /// * `buffer` - Buffer containing the message\n");
78        output.push_str(
79            "    /// * `offset` - Offset to the start of the root block (after header)\n",
80        );
81        output.push_str("    /// * `acting_version` - Schema version for compatibility\n");
82        output.push_str("    #[inline]\n");
83        output.push_str("    #[must_use]\n");
84        output.push_str(
85            "    pub fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
86        );
87        output.push_str("        Self { buffer, offset, acting_version }\n");
88        output.push_str("    }\n\n");
89
90        // Field getters
91        for field in &msg.fields {
92            output.push_str(&self.generate_field_getter(field));
93        }
94
95        // Group accessors
96        let mut group_offset = msg.block_length as usize;
97        for group in &msg.groups {
98            output.push_str(&self.generate_group_accessor(group, group_offset, &msg.name));
99            group_offset += 4; // Group header size
100        }
101
102        output.push_str("}\n\n");
103
104        // SbeDecoder trait implementation
105        output.push_str(&format!(
106            "impl<'a> SbeDecoder<'a> for {}<'a> {{\n",
107            decoder_name
108        ));
109        output.push_str(&format!(
110            "    const TEMPLATE_ID: u16 = {};\n",
111            msg.template_id
112        ));
113        output.push_str("    const SCHEMA_ID: u16 = SCHEMA_ID;\n");
114        output.push_str("    const SCHEMA_VERSION: u16 = SCHEMA_VERSION;\n");
115        output.push_str(&format!(
116            "    const BLOCK_LENGTH: u16 = {};\n\n",
117            msg.block_length
118        ));
119
120        output.push_str(
121            "    fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
122        );
123        output.push_str("        Self::wrap(buffer, offset, acting_version)\n");
124        output.push_str("    }\n\n");
125
126        output.push_str("    fn encoded_length(&self) -> usize {\n");
127        output.push_str("        MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
128        output.push_str("    }\n");
129        output.push_str("}\n\n");
130
131        output
132    }
133
134    /// Generates a field getter method.
135    fn generate_field_getter(&self, field: &ResolvedField) -> String {
136        let mut output = String::new();
137
138        output.push_str(&format!(
139            "    /// Field: {} (id={}, offset={}).\n",
140            field.name, field.id, field.offset
141        ));
142        output.push_str("    #[inline(always)]\n");
143        output.push_str("    #[must_use]\n");
144
145        if field.is_array {
146            // Array field - return slice
147            let elem_type = field.primitive_type.map(|p| p.rust_type()).unwrap_or("u8");
148            let len = field.array_length.unwrap_or(1);
149
150            if elem_type == "u8" {
151                // Byte array - return &[u8]
152                output.push_str(&format!(
153                    "    pub fn {}(&self) -> &'a [u8] {{\n",
154                    field.getter_name
155                ));
156                output.push_str(&format!(
157                    "        &self.buffer[self.offset + {}..self.offset + {} + {}]\n",
158                    field.offset, field.offset, len
159                ));
160                output.push_str("    }\n\n");
161
162                // Also generate a string accessor for char arrays
163                output.push_str(&format!(
164                    "    /// Field {} as string (trimmed).\n",
165                    field.name
166                ));
167                output.push_str("    #[inline]\n");
168                output.push_str("    #[must_use]\n");
169                output.push_str(&format!(
170                    "    pub fn {}_as_str(&self) -> &'a str {{\n",
171                    field.getter_name
172                ));
173                output.push_str(&format!(
174                    "        let bytes = &self.buffer[self.offset + {}..self.offset + {} + {}];\n",
175                    field.offset, field.offset, len
176                ));
177                output.push_str(
178                    "        let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());\n",
179                );
180                output.push_str("        std::str::from_utf8(&bytes[..end]).unwrap_or(\"\")\n");
181                output.push_str("    }\n\n");
182            } else {
183                // Other array types
184                output.push_str(&format!(
185                    "    pub fn {}(&self) -> &'a [u8] {{\n",
186                    field.getter_name
187                ));
188                output.push_str(&format!(
189                    "        &self.buffer[self.offset + {}..self.offset + {}]\n",
190                    field.offset,
191                    field.offset + field.encoded_length
192                ));
193                output.push_str("    }\n\n");
194            }
195        } else {
196            // Scalar field - check if it's an enum/set type
197            let rust_type = &field.rust_type;
198            let resolved_type = self.ir.get_type(&field.type_name);
199
200            match resolved_type.map(|t| &t.kind) {
201                Some(TypeKind::Enum { encoding, .. }) => {
202                    // Enum field - use encoding primitive and wrap with From
203                    let read_method = get_read_method(Some(*encoding));
204                    output.push_str(&format!(
205                        "    pub fn {}(&self) -> {} {{\n",
206                        field.getter_name, rust_type
207                    ));
208                    output.push_str(&format!(
209                        "        {}::from(self.buffer.{}(self.offset + {}))\n",
210                        rust_type, read_method, field.offset
211                    ));
212                    output.push_str("    }\n\n");
213                }
214                Some(TypeKind::Set { encoding, .. }) => {
215                    // Set field - use encoding primitive and wrap with from_raw
216                    let read_method = get_read_method(Some(*encoding));
217                    output.push_str(&format!(
218                        "    pub fn {}(&self) -> {} {{\n",
219                        field.getter_name, rust_type
220                    ));
221                    output.push_str(&format!(
222                        "        {}::from_raw(self.buffer.{}(self.offset + {}))\n",
223                        rust_type, read_method, field.offset
224                    ));
225                    output.push_str("    }\n\n");
226                }
227                Some(TypeKind::Composite { .. }) => {
228                    // Composite field - return wrapper struct
229                    output.push_str(&format!(
230                        "    pub fn {}(&self) -> {}<'a> {{\n",
231                        field.getter_name, rust_type
232                    ));
233                    output.push_str(&format!(
234                        "        {}::wrap(self.buffer, self.offset + {})\n",
235                        rust_type, field.offset
236                    ));
237                    output.push_str("    }\n\n");
238                }
239                _ => {
240                    // Primitive field
241                    let read_method = get_read_method(field.primitive_type);
242                    output.push_str(&format!(
243                        "    pub fn {}(&self) -> {} {{\n",
244                        field.getter_name, rust_type
245                    ));
246                    output.push_str(&format!(
247                        "        self.buffer.{}(self.offset + {})\n",
248                        read_method, field.offset
249                    ));
250                    output.push_str("    }\n\n");
251                }
252            }
253        }
254
255        output
256    }
257
258    /// Generates a group accessor method.
259    fn generate_group_accessor(
260        &self,
261        group: &ResolvedGroup,
262        offset: usize,
263        msg_name: &str,
264    ) -> String {
265        let mut output = String::new();
266        let qualified = format!("{}::{}", to_snake_case(msg_name), group.decoder_name());
267
268        output.push_str(&format!("    /// Access {} repeating group.\n", group.name));
269        output.push_str("    #[inline]\n");
270        output.push_str("    #[must_use]\n");
271        output.push_str(&format!(
272            "    pub fn {}(&self) -> {}<'a> {{\n",
273            to_snake_case(&group.name),
274            qualified
275        ));
276        output.push_str(&format!(
277            "        {}::wrap(self.buffer, self.offset + {})\n",
278            qualified, offset
279        ));
280        output.push_str("    }\n\n");
281
282        output
283    }
284
285    /// Generates a message encoder.
286    fn generate_encoder(&self, msg: &ResolvedMessage) -> String {
287        let mut output = String::new();
288        let encoder_name = msg.encoder_name();
289
290        // Struct definition
291        output.push_str(&format!("/// {} Encoder.\n", msg.name));
292        output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
293        output.push_str("    buffer: &'a mut [u8],\n");
294        output.push_str("    offset: usize,\n");
295        output.push_str("}\n\n");
296
297        // Implementation
298        output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
299        output.push_str(&format!(
300            "    /// Template ID for this message.\n\
301             pub const TEMPLATE_ID: u16 = {};\n",
302            msg.template_id
303        ));
304        output.push_str(&format!(
305            "    /// Block length of the fixed portion.\n\
306             pub const BLOCK_LENGTH: u16 = {};\n\n",
307            msg.block_length
308        ));
309
310        // Constructor
311        output.push_str("    /// Wraps a buffer for encoding, writing the header.\n");
312        output.push_str("    #[inline]\n");
313        output.push_str("    pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
314        output.push_str("        let mut encoder = Self { buffer, offset };\n");
315        output.push_str("        encoder.write_header();\n");
316        output.push_str("        encoder\n");
317        output.push_str("    }\n\n");
318
319        // Write header
320        output.push_str("    fn write_header(&mut self) {\n");
321        output.push_str("        let header = MessageHeader {\n");
322        output.push_str("            block_length: Self::BLOCK_LENGTH,\n");
323        output.push_str("            template_id: Self::TEMPLATE_ID,\n");
324        output.push_str("            schema_id: SCHEMA_ID,\n");
325        output.push_str("            version: SCHEMA_VERSION,\n");
326        output.push_str("        };\n");
327        output.push_str("        header.encode(self.buffer, self.offset);\n");
328        output.push_str("    }\n\n");
329
330        // Encoded length
331        output.push_str("    /// Returns the encoded length of the message.\n");
332        output.push_str("    #[must_use]\n");
333        output.push_str("    pub const fn encoded_length(&self) -> usize {\n");
334        output.push_str("        MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
335        output.push_str("    }\n\n");
336
337        // Field setters
338        for field in &msg.fields {
339            output.push_str(&self.generate_field_setter(field));
340        }
341
342        // Group encoder accessors
343        let mut group_offset = msg.block_length as usize;
344        for group in &msg.groups {
345            output.push_str(&self.generate_group_encoder_accessor(group, group_offset, &msg.name));
346            group_offset += 4; // Group header size
347        }
348
349        output.push_str("}\n\n");
350
351        output
352    }
353
354    /// Generates a field setter method.
355    fn generate_field_setter(&self, field: &ResolvedField) -> String {
356        let mut output = String::new();
357        let field_offset = format!("MessageHeader::ENCODED_LENGTH + {}", field.offset);
358
359        output.push_str(&format!(
360            "    /// Set field: {} (id={}, offset={}).\n",
361            field.name, field.id, field.offset
362        ));
363        output.push_str("    #[inline(always)]\n");
364
365        if field.is_array {
366            // Array field - accept slice
367            let len = field.array_length.unwrap_or(field.encoded_length);
368
369            output.push_str(&format!(
370                "    pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
371                field.setter_name
372            ));
373            output.push_str(&format!(
374                "        let copy_len = value.len().min({});\n",
375                len
376            ));
377            output.push_str(&format!(
378                "        self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
379                field_offset, field_offset
380            ));
381            output.push_str("            .copy_from_slice(&value[..copy_len]);\n");
382            output.push_str(&format!("        if copy_len < {} {{\n", len));
383            output.push_str(&format!(
384                "            self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
385                field_offset, field_offset, len
386            ));
387            output.push_str("        }\n");
388            output.push_str("        self\n");
389            output.push_str("    }\n\n");
390        } else {
391            // Scalar field - check if it's an enum/set type
392            let rust_type = &field.rust_type;
393            let resolved_type = self.ir.get_type(&field.type_name);
394
395            match resolved_type.map(|t| &t.kind) {
396                Some(TypeKind::Enum { encoding, .. }) => {
397                    // Enum field - convert enum to primitive before writing
398                    let write_method = get_write_method(Some(*encoding));
399                    let prim_type = encoding.rust_type();
400                    output.push_str(&format!(
401                        "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
402                        field.setter_name, rust_type
403                    ));
404                    output.push_str(&format!(
405                        "        self.buffer.{}(self.offset + {}, {}::from(value));\n",
406                        write_method, field_offset, prim_type
407                    ));
408                    output.push_str("        self\n");
409                    output.push_str("    }\n\n");
410                }
411                Some(TypeKind::Set { encoding, .. }) => {
412                    // Set field - use raw() to get the primitive value
413                    let write_method = get_write_method(Some(*encoding));
414                    output.push_str(&format!(
415                        "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
416                        field.setter_name, rust_type
417                    ));
418                    output.push_str(&format!(
419                        "        self.buffer.{}(self.offset + {}, value.raw());\n",
420                        write_method, field_offset
421                    ));
422                    output.push_str("        self\n");
423                    output.push_str("    }\n\n");
424                }
425                Some(TypeKind::Composite { .. }) => {
426                    // Composite field - return encoder for nested writes
427                    output.push_str(&format!(
428                        "    pub fn {}(&mut self) -> {}Encoder<'_> {{\n",
429                        field.setter_name, rust_type
430                    ));
431                    output.push_str(&format!(
432                        "        {}Encoder::wrap(self.buffer, self.offset + {})\n",
433                        rust_type, field_offset
434                    ));
435                    output.push_str("    }\n\n");
436                }
437                _ => {
438                    // Primitive field
439                    let write_method = get_write_method(field.primitive_type);
440                    output.push_str(&format!(
441                        "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
442                        field.setter_name, rust_type
443                    ));
444                    output.push_str(&format!(
445                        "        self.buffer.{}(self.offset + {}, value);\n",
446                        write_method, field_offset
447                    ));
448                    output.push_str("        self\n");
449                    output.push_str("    }\n\n");
450                }
451            }
452        }
453
454        output
455    }
456
457    /// Generates a group decoder.
458    fn generate_group_decoder(&self, group: &ResolvedGroup) -> String {
459        let mut output = String::new();
460        let decoder_name = group.decoder_name();
461        let entry_name = group.entry_decoder_name();
462
463        // Group decoder struct
464        output.push_str(&format!("/// {} Group Decoder.\n", group.name));
465        output.push_str("#[derive(Debug, Clone, Copy)]\n");
466        output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
467        output.push_str("    buffer: &'a [u8],\n");
468        output.push_str("    block_length: u16,\n");
469        output.push_str("    count: u16,\n");
470        output.push_str("    index: u16,\n");
471        output.push_str("    offset: usize,\n");
472        output.push_str("}\n\n");
473
474        // Group decoder implementation
475        output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
476        output.push_str("    /// Wraps a buffer at the group header position.\n");
477        output.push_str("    #[must_use]\n");
478        output.push_str("    pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
479        output.push_str("        let header = GroupHeader::wrap(buffer, offset);\n");
480        output.push_str("        Self {\n");
481        output.push_str("            buffer,\n");
482        output.push_str("            block_length: header.block_length,\n");
483        output.push_str("            count: header.num_in_group,\n");
484        output.push_str("            index: 0,\n");
485        output.push_str("            offset: offset + GroupHeader::ENCODED_LENGTH,\n");
486        output.push_str("        }\n");
487        output.push_str("    }\n\n");
488
489        output.push_str("    /// Returns the number of entries in the group.\n");
490        output.push_str("    #[must_use]\n");
491        output.push_str("    pub const fn count(&self) -> u16 {\n");
492        output.push_str("        self.count\n");
493        output.push_str("    }\n\n");
494
495        output.push_str("    /// Returns true if the group is empty.\n");
496        output.push_str("    #[must_use]\n");
497        output.push_str("    pub const fn is_empty(&self) -> bool {\n");
498        output.push_str("        self.count == 0\n");
499        output.push_str("    }\n");
500        output.push_str("}\n\n");
501
502        // Iterator implementation
503        output.push_str(&format!("impl<'a> Iterator for {}<'a> {{\n", decoder_name));
504        output.push_str(&format!("    type Item = {}<'a>;\n\n", entry_name));
505        output.push_str("    fn next(&mut self) -> Option<Self::Item> {\n");
506        output.push_str("        if self.index >= self.count {\n");
507        output.push_str("            return None;\n");
508        output.push_str("        }\n");
509        output.push_str(&format!(
510            "        let entry = {}::wrap(self.buffer, self.offset);\n",
511            entry_name
512        ));
513        output.push_str("        self.offset += self.block_length as usize;\n");
514        output.push_str("        self.index += 1;\n");
515        output.push_str("        Some(entry)\n");
516        output.push_str("    }\n\n");
517
518        output.push_str("    fn size_hint(&self) -> (usize, Option<usize>) {\n");
519        output.push_str("        let remaining = (self.count - self.index) as usize;\n");
520        output.push_str("        (remaining, Some(remaining))\n");
521        output.push_str("    }\n");
522        output.push_str("}\n\n");
523
524        output.push_str(&format!(
525            "impl<'a> ExactSizeIterator for {}<'a> {{}}\n\n",
526            decoder_name
527        ));
528
529        // Entry decoder
530        output.push_str(&self.generate_entry_decoder(group));
531
532        // Nested groups
533        for nested in &group.nested_groups {
534            output.push_str(&self.generate_group_decoder(nested));
535        }
536
537        output
538    }
539
540    /// Generates a group entry decoder.
541    fn generate_entry_decoder(&self, group: &ResolvedGroup) -> String {
542        let mut output = String::new();
543        let entry_name = group.entry_decoder_name();
544
545        output.push_str(&format!("/// {} Entry Decoder.\n", group.name));
546        output.push_str("#[derive(Debug, Clone, Copy)]\n");
547        output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
548        output.push_str("    buffer: &'a [u8],\n");
549        output.push_str("    offset: usize,\n");
550        output.push_str("}\n\n");
551
552        output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
553        output.push_str("    fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
554        output.push_str("        Self { buffer, offset }\n");
555        output.push_str("    }\n\n");
556
557        // Field getters
558        for field in &group.fields {
559            output.push_str(&self.generate_field_getter(field));
560        }
561
562        output.push_str("}\n\n");
563
564        output
565    }
566
567    /// Generates a group encoder.
568    fn generate_group_encoder(&self, group: &ResolvedGroup) -> String {
569        let mut output = String::new();
570        let encoder_name = group.encoder_name();
571        let entry_name = group.entry_encoder_name();
572
573        // Group encoder struct
574        output.push_str(&format!("/// {} Group Encoder.\n", group.name));
575        output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
576        output.push_str("    buffer: &'a mut [u8],\n");
577        output.push_str("    count: u16,\n");
578        output.push_str("    index: u16,\n");
579        output.push_str("    offset: usize,\n");
580        output.push_str("}\n\n");
581
582        // Group encoder implementation
583        output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
584        output.push_str(&format!(
585            "    /// Block length of each entry.\n\
586             pub const BLOCK_LENGTH: u16 = {};\n\n",
587            group.block_length
588        ));
589
590        // wrap constructor
591        output
592            .push_str("    /// Wraps a buffer at the group header position, writing the header.\n");
593        output.push_str("    ///\n");
594        output.push_str("    /// # Arguments\n");
595        output.push_str("    /// * `buffer` - Mutable buffer to write to\n");
596        output.push_str("    /// * `offset` - Offset of the group header\n");
597        output.push_str("    /// * `count` - Number of entries to encode\n");
598        output.push_str(
599            "    pub fn wrap(buffer: &'a mut [u8], offset: usize, count: u16) -> Self {\n",
600        );
601        output.push_str("        let header = GroupHeader::new(Self::BLOCK_LENGTH, count);\n");
602        output.push_str("        header.encode(buffer, offset);\n");
603        output.push_str("        Self {\n");
604        output.push_str("            buffer,\n");
605        output.push_str("            count,\n");
606        output.push_str("            index: 0,\n");
607        output.push_str("            offset: offset + GroupHeader::ENCODED_LENGTH,\n");
608        output.push_str("        }\n");
609        output.push_str("    }\n\n");
610
611        // next_entry
612        output.push_str(
613            "    /// Returns the next entry encoder, or `None` if all entries are written.\n",
614        );
615        output.push_str(&format!(
616            "    pub fn next_entry(&mut self) -> Option<{}<'_>> {{\n",
617            entry_name
618        ));
619        output.push_str("        if self.index >= self.count {\n");
620        output.push_str("            return None;\n");
621        output.push_str("        }\n");
622        output.push_str("        let offset = self.offset;\n");
623        output.push_str("        self.offset += Self::BLOCK_LENGTH as usize;\n");
624        output.push_str("        self.index += 1;\n");
625        output.push_str(&format!(
626            "        Some({}::wrap(&mut *self.buffer, offset))\n",
627            entry_name
628        ));
629        output.push_str("    }\n\n");
630
631        // encoded_length
632        output.push_str(
633            "    /// Returns the total encoded length of this group (header + all entries).\n",
634        );
635        output.push_str("    #[must_use]\n");
636        output.push_str("    pub const fn encoded_length(&self) -> usize {\n");
637        output.push_str("        GroupHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize * self.count as usize\n");
638        output.push_str("    }\n");
639        output.push_str("}\n\n");
640
641        // Entry encoder
642        output.push_str(&self.generate_entry_encoder(group));
643
644        // Nested group encoders
645        for nested in &group.nested_groups {
646            output.push_str(&self.generate_group_encoder(nested));
647        }
648
649        output
650    }
651
652    /// Generates a group entry encoder.
653    fn generate_entry_encoder(&self, group: &ResolvedGroup) -> String {
654        let mut output = String::new();
655        let entry_name = group.entry_encoder_name();
656
657        output.push_str(&format!("/// {} Entry Encoder.\n", group.name));
658        output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
659        output.push_str("    buffer: &'a mut [u8],\n");
660        output.push_str("    offset: usize,\n");
661        output.push_str("}\n\n");
662
663        output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
664        output.push_str("    fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
665        output.push_str("        Self { buffer, offset }\n");
666        output.push_str("    }\n\n");
667
668        // Field setters
669        for field in &group.fields {
670            output.push_str(&self.generate_entry_field_setter(field));
671        }
672
673        output.push_str("}\n\n");
674
675        output
676    }
677
678    /// Generates a field setter for a group entry encoder.
679    ///
680    /// Unlike the message-level `generate_field_setter`, this uses the raw field
681    /// offset (relative to the entry start) without a `MessageHeader::ENCODED_LENGTH`
682    /// prefix.
683    fn generate_entry_field_setter(&self, field: &ResolvedField) -> String {
684        let mut output = String::new();
685        let field_offset = field.offset;
686
687        output.push_str(&format!(
688            "    /// Set field: {} (id={}, offset={}).\n",
689            field.name, field.id, field.offset
690        ));
691        output.push_str("    #[inline(always)]\n");
692
693        if field.is_array {
694            let len = field.array_length.unwrap_or(field.encoded_length);
695
696            output.push_str(&format!(
697                "    pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
698                field.setter_name
699            ));
700            output.push_str(&format!(
701                "        let copy_len = value.len().min({});\n",
702                len
703            ));
704            output.push_str(&format!(
705                "        self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
706                field_offset, field_offset
707            ));
708            output.push_str("            .copy_from_slice(&value[..copy_len]);\n");
709            output.push_str(&format!("        if copy_len < {} {{\n", len));
710            output.push_str(&format!(
711                "            self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
712                field_offset, field_offset, len
713            ));
714            output.push_str("        }\n");
715            output.push_str("        self\n");
716            output.push_str("    }\n\n");
717        } else {
718            let rust_type = &field.rust_type;
719            let resolved_type = self.ir.get_type(&field.type_name);
720
721            match resolved_type.map(|t| &t.kind) {
722                Some(TypeKind::Enum { encoding, .. }) => {
723                    let write_method = get_write_method(Some(*encoding));
724                    let prim_type = encoding.rust_type();
725                    output.push_str(&format!(
726                        "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
727                        field.setter_name, rust_type
728                    ));
729                    output.push_str(&format!(
730                        "        self.buffer.{}(self.offset + {}, {}::from(value));\n",
731                        write_method, field_offset, prim_type
732                    ));
733                    output.push_str("        self\n");
734                    output.push_str("    }\n\n");
735                }
736                Some(TypeKind::Set { encoding, .. }) => {
737                    let write_method = get_write_method(Some(*encoding));
738                    output.push_str(&format!(
739                        "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
740                        field.setter_name, rust_type
741                    ));
742                    output.push_str(&format!(
743                        "        self.buffer.{}(self.offset + {}, value.raw());\n",
744                        write_method, field_offset
745                    ));
746                    output.push_str("        self\n");
747                    output.push_str("    }\n\n");
748                }
749                Some(TypeKind::Composite { .. }) => {
750                    output.push_str(&format!(
751                        "    pub fn {}(&mut self) -> {}Encoder<'_> {{\n",
752                        field.setter_name, rust_type
753                    ));
754                    output.push_str(&format!(
755                        "        {}Encoder::wrap(self.buffer, self.offset + {})\n",
756                        rust_type, field_offset
757                    ));
758                    output.push_str("    }\n\n");
759                }
760                _ => {
761                    let write_method = get_write_method(field.primitive_type);
762                    output.push_str(&format!(
763                        "    pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
764                        field.setter_name, rust_type
765                    ));
766                    output.push_str(&format!(
767                        "        self.buffer.{}(self.offset + {}, value);\n",
768                        write_method, field_offset
769                    ));
770                    output.push_str("        self\n");
771                    output.push_str("    }\n\n");
772                }
773            }
774        }
775
776        output
777    }
778
779    /// Generates a group encoder accessor on the parent message encoder.
780    fn generate_group_encoder_accessor(
781        &self,
782        group: &ResolvedGroup,
783        offset: usize,
784        msg_name: &str,
785    ) -> String {
786        let mut output = String::new();
787        let qualified = format!("{}::{}", to_snake_case(msg_name), group.encoder_name());
788
789        output.push_str(&format!(
790            "    /// Begin encoding the {} repeating group.\n",
791            group.name
792        ));
793        output.push_str(&format!(
794            "    pub fn {}_count(&mut self, count: u16) -> {}<'_> {{\n",
795            to_snake_case(&group.name),
796            qualified
797        ));
798        output.push_str(&format!(
799            "        {}::wrap(&mut *self.buffer, self.offset + MessageHeader::ENCODED_LENGTH + {}, count)\n",
800            qualified, offset
801        ));
802        output.push_str("    }\n\n");
803
804        output
805    }
806}
807
808/// Gets the read method name for a primitive type.
809fn get_read_method(prim: Option<PrimitiveType>) -> &'static str {
810    match prim {
811        Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "get_u8",
812        Some(PrimitiveType::Int8) => "get_i8",
813        Some(PrimitiveType::Uint16) => "get_u16_le",
814        Some(PrimitiveType::Int16) => "get_i16_le",
815        Some(PrimitiveType::Uint32) => "get_u32_le",
816        Some(PrimitiveType::Int32) => "get_i32_le",
817        Some(PrimitiveType::Uint64) => "get_u64_le",
818        Some(PrimitiveType::Int64) => "get_i64_le",
819        Some(PrimitiveType::Float) => "get_f32_le",
820        Some(PrimitiveType::Double) => "get_f64_le",
821        None => "get_u64_le",
822    }
823}
824
825/// Gets the write method name for a primitive type.
826fn get_write_method(prim: Option<PrimitiveType>) -> &'static str {
827    match prim {
828        Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "put_u8",
829        Some(PrimitiveType::Int8) => "put_i8",
830        Some(PrimitiveType::Uint16) => "put_u16_le",
831        Some(PrimitiveType::Int16) => "put_i16_le",
832        Some(PrimitiveType::Uint32) => "put_u32_le",
833        Some(PrimitiveType::Int32) => "put_i32_le",
834        Some(PrimitiveType::Uint64) => "put_u64_le",
835        Some(PrimitiveType::Int64) => "put_i64_le",
836        Some(PrimitiveType::Float) => "put_f32_le",
837        Some(PrimitiveType::Double) => "put_f64_le",
838        None => "put_u64_le",
839    }
840}
841
#[cfg(test)]
mod tests {
    use super::*;
    use ironsbe_schema::{SchemaIr, parse_schema};

    /// Parses `xml`, lowers it to the schema IR and returns the generated message code.
    ///
    /// Panics (failing the calling test) if the schema does not parse.
    fn generate_code(xml: &str) -> String {
        let schema = parse_schema(xml).expect("Failed to parse schema");
        let ir = SchemaIr::from_schema(&schema);
        MessageGenerator::new(&ir).generate()
    }

    /// Returns `code` from the first occurrence of `marker` to the end.
    ///
    /// Substring assertions are scoped this way so that matches occurring earlier
    /// in the generated output cannot satisfy them. Panics if `marker` is absent.
    fn section_from<'c>(code: &'c str, marker: &str) -> &'c str {
        let start = code
            .find(marker)
            .unwrap_or_else(|| panic!("`{marker}` not found in generated code"));
        &code[start..]
    }

    /// Two messages that each declare a repeating group named `quotes`.
    fn schema_with_shared_group_name() -> &'static str {
        r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <type name="uint64" primitiveType="uint64"/>
    </types>
    <sbe:message name="CreateRfqResponse" id="21" blockLength="8">
        <field name="value" id="1" type="uint64" offset="0"/>
        <group name="quotes" id="100" dimensionType="groupSizeEncoding" blockLength="8">
            <field name="price" id="200" type="uint64" offset="0"/>
        </group>
    </sbe:message>
    <sbe:message name="GetRfqResponse" id="23" blockLength="8">
        <field name="value" id="1" type="uint64" offset="0"/>
        <group name="quotes" id="100" dimensionType="groupSizeEncoding" blockLength="8">
            <field name="price" id="200" type="uint64" offset="0"/>
        </group>
    </sbe:message>
</sbe:messageSchema>"#
    }

    /// A group whose fields after the first omit `offset`, so the generator must
    /// derive them from the preceding field sizes (expected: 0, 8, 12).
    fn schema_with_group_no_offsets() -> &'static str {
        r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <type name="uint64" primitiveType="uint64"/>
        <type name="uint32" primitiveType="uint32"/>
    </types>
    <sbe:message name="ListOrders" id="19" blockLength="0">
        <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="20">
            <field name="orderId" id="1" type="uint64" offset="0"/>
            <field name="instrumentId" id="2" type="uint32"/>
            <field name="quantity" id="3" type="uint64"/>
        </group>
    </sbe:message>
</sbe:messageSchema>"#
    }

    /// The same group layout as `schema_with_group_no_offsets`, with every offset explicit.
    fn schema_with_group_explicit_offsets() -> &'static str {
        r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <type name="uint64" primitiveType="uint64"/>
        <type name="uint32" primitiveType="uint32"/>
    </types>
    <sbe:message name="ListOrders" id="19" blockLength="0">
        <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="20">
            <field name="orderId" id="1" type="uint64" offset="0"/>
            <field name="instrumentId" id="2" type="uint32" offset="8"/>
            <field name="quantity" id="3" type="uint64" offset="12"/>
        </group>
    </sbe:message>
</sbe:messageSchema>"#
    }

    #[test]
    fn test_duplicate_group_name_generates_scoped_modules() {
        let code = generate_code(schema_with_shared_group_name());

        assert!(
            code.contains("pub mod create_rfq_response {"),
            "expected module for CreateRfqResponse groups"
        );
        assert!(
            code.contains("pub mod get_rfq_response {"),
            "expected module for GetRfqResponse groups"
        );

        let occurrences = code.matches("pub struct QuotesGroupDecoder").count();
        assert_eq!(
            occurrences, 2,
            "expected one QuotesGroupDecoder per message module, got {occurrences}"
        );
    }

    #[test]
    fn test_group_accessor_uses_qualified_path() {
        let code = generate_code(schema_with_shared_group_name());

        assert!(
            code.contains("create_rfq_response::QuotesGroupDecoder"),
            "accessor in CreateRfqResponse must reference module-qualified type"
        );
        assert!(
            code.contains("get_rfq_response::QuotesGroupDecoder"),
            "accessor in GetRfqResponse must reference module-qualified type"
        );
    }

    #[test]
    fn test_entry_decoder_field_offsets_auto_computed() {
        let code = generate_code(schema_with_group_no_offsets());
        // Only the entry decoder onwards: the getters under test live there.
        let decoder = section_from(&code, "impl<'a> OrdersEntryDecoder<'a>");

        // orderId at offset 0
        assert!(
            decoder.contains("self.offset + 0)"),
            "orderId should be at offset 0"
        );
        // instrumentId at offset 8 (after uint64)
        assert!(
            decoder.contains("self.offset + 8)"),
            "instrumentId should be at offset 8, not 0"
        );
        // quantity at offset 12 (after uint64 + uint32)
        assert!(
            decoder.contains("self.offset + 12)"),
            "quantity should be at offset 12, not 0"
        );
    }

    #[test]
    fn test_entry_decoder_field_offsets_explicit() {
        let code = generate_code(schema_with_group_explicit_offsets());
        let decoder = section_from(&code, "impl<'a> OrdersEntryDecoder<'a>");

        assert!(
            decoder.contains("self.offset + 8)"),
            "instrumentId should be at explicit offset 8"
        );
        assert!(
            decoder.contains("self.offset + 12)"),
            "quantity should be at explicit offset 12"
        );
    }

    #[test]
    fn test_group_encoder_emitted() {
        let code = generate_code(schema_with_group_no_offsets());

        assert!(
            code.contains("pub struct OrdersGroupEncoder"),
            "expected OrdersGroupEncoder struct"
        );
        assert!(
            code.contains("pub struct OrdersEntryEncoder"),
            "expected OrdersEntryEncoder struct"
        );
    }

    #[test]
    fn test_group_encoder_has_next_entry() {
        let code = generate_code(schema_with_group_no_offsets());

        assert!(
            code.contains("fn next_entry(&mut self)"),
            "expected next_entry method on group encoder"
        );
    }

    #[test]
    fn test_entry_encoder_has_field_setters() {
        let code = generate_code(schema_with_group_no_offsets());

        assert!(
            code.contains("fn set_order_id(&mut self, value: u64)"),
            "expected set_order_id setter"
        );
        assert!(
            code.contains("fn set_instrument_id(&mut self, value: u32)"),
            "expected set_instrument_id setter"
        );
        assert!(
            code.contains("fn set_quantity(&mut self, value: u64)"),
            "expected set_quantity setter"
        );
    }

    #[test]
    fn test_parent_encoder_has_group_accessor() {
        let code = generate_code(schema_with_group_no_offsets());

        assert!(
            code.contains("fn orders_count(&mut self, count: u16)"),
            "expected orders_count accessor on parent encoder"
        );
    }

    #[test]
    fn test_roundtrip_group_codegen_structure() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <type name="uint64" primitiveType="uint64"/>
        <type name="uint32" primitiveType="uint32"/>
        <type name="uint8" primitiveType="uint8"/>
    </types>
    <sbe:message name="ListOrders" id="19" blockLength="8">
        <field name="requestId" id="1" type="uint64" offset="0"/>
        <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="29">
            <field name="orderId" id="10" type="uint64" offset="0"/>
            <field name="instrumentId" id="11" type="uint32"/>
            <field name="price" id="12" type="uint64"/>
            <field name="quantity" id="13" type="uint64"/>
            <field name="side" id="14" type="uint8"/>
        </group>
    </sbe:message>
</sbe:messageSchema>"#;

        let code = generate_code(xml);

        // --- Decoder side ---
        let decoder_section = section_from(&code, "impl<'a> OrdersEntryDecoder<'a>");
        // Verify all five fields have distinct offsets
        assert!(decoder_section.contains("self.offset + 0)"));
        assert!(decoder_section.contains("self.offset + 8)"));
        assert!(decoder_section.contains("self.offset + 12)"));
        assert!(decoder_section.contains("self.offset + 20)"));
        assert!(decoder_section.contains("self.offset + 28)"));

        // --- Encoder side ---
        let encoder_section = section_from(&code, "impl<'a> OrdersEntryEncoder<'a>");
        // Verify setter offsets match decoder offsets
        assert!(encoder_section.contains("self.offset + 0,"));
        assert!(encoder_section.contains("self.offset + 8,"));
        assert!(encoder_section.contains("self.offset + 12,"));
        assert!(encoder_section.contains("self.offset + 20,"));
        assert!(encoder_section.contains("self.offset + 28,"));

        // --- Group encoder wiring ---
        assert!(
            code.contains("BLOCK_LENGTH: u16 = 29"),
            "group encoder BLOCK_LENGTH"
        );
        assert!(
            code.contains("fn orders_count(&mut self, count: u16)"),
            "parent encoder group accessor"
        );
        assert!(
            code.contains("list_orders::OrdersGroupEncoder::wrap(&mut *self.buffer"),
            "parent encoder delegates to module-qualified group encoder"
        );

        // --- Group decoder wiring ---
        assert!(
            code.contains("list_orders::OrdersGroupDecoder"),
            "parent decoder uses module-qualified group decoder"
        );
    }

    #[test]
    fn test_entry_encoder_setter_offsets_correct() {
        let code = generate_code(schema_with_group_no_offsets());
        // Find the EntryEncoder section and verify offsets in setters
        let entry_code = section_from(&code, "impl<'a> OrdersEntryEncoder<'a>");

        // set_order_id at offset 0
        assert!(
            entry_code.contains("self.offset + 0,"),
            "set_order_id should write at offset 0"
        );
        // set_instrument_id at offset 8
        assert!(
            entry_code.contains("self.offset + 8,"),
            "set_instrument_id should write at offset 8"
        );
        // set_quantity at offset 12
        assert!(
            entry_code.contains("self.offset + 12,"),
            "set_quantity should write at offset 12"
        );
    }
}