1use ironsbe_schema::ir::{
4 ResolvedField, ResolvedGroup, ResolvedMessage, SchemaIr, TypeKind, to_snake_case,
5};
6use ironsbe_schema::types::PrimitiveType;
7
/// Generates Rust source code for message decoders and encoders from a schema IR.
///
/// The generator only borrows the IR; all output is returned as a `String` by
/// [`MessageGenerator::generate`].
pub struct MessageGenerator<'a> {
    /// Schema IR whose `messages` are rendered and whose types are looked up
    /// when emitting field getters and setters.
    ir: &'a SchemaIr,
}
12
13impl<'a> MessageGenerator<'a> {
14 #[must_use]
16 pub fn new(ir: &'a SchemaIr) -> Self {
17 Self { ir }
18 }
19
20 #[must_use]
22 pub fn generate(&self) -> String {
23 let mut output = String::new();
24
25 for msg in &self.ir.messages {
26 output.push_str(&self.generate_decoder(msg));
27 output.push_str(&self.generate_encoder(msg));
28
29 if !msg.groups.is_empty() {
31 let mod_name = to_snake_case(&msg.name);
32 output.push_str(&format!("/// Types for {} repeating groups.\n", msg.name));
33 output.push_str(&format!("pub mod {} {{\n", mod_name));
34 output.push_str(" use super::*;\n\n");
35 for group in &msg.groups {
36 output.push_str(&self.generate_group_decoder(group));
37 output.push_str(&self.generate_group_encoder(group));
38 }
39 output.push_str("}\n\n");
40 }
41 }
42
43 output
44 }
45
46 fn generate_decoder(&self, msg: &ResolvedMessage) -> String {
48 let mut output = String::new();
49 let decoder_name = msg.decoder_name();
50
51 output.push_str(&format!("/// {} Decoder (zero-copy).\n", msg.name));
53 output.push_str("#[derive(Debug, Clone, Copy)]\n");
54 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
55 output.push_str(" buffer: &'a [u8],\n");
56 output.push_str(" offset: usize,\n");
57 output.push_str(" acting_version: u16,\n");
58 output.push_str("}\n\n");
59
60 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
62 output.push_str(&format!(
63 " /// Template ID for this message.\n\
64 pub const TEMPLATE_ID: u16 = {};\n",
65 msg.template_id
66 ));
67 output.push_str(&format!(
68 " /// Block length of the fixed portion.\n\
69 pub const BLOCK_LENGTH: u16 = {};\n\n",
70 msg.block_length
71 ));
72
73 output.push_str(" /// Wraps a buffer for zero-copy decoding.\n");
75 output.push_str(" ///\n");
76 output.push_str(" /// # Arguments\n");
77 output.push_str(" /// * `buffer` - Buffer containing the message\n");
78 output.push_str(
79 " /// * `offset` - Offset to the start of the root block (after header)\n",
80 );
81 output.push_str(" /// * `acting_version` - Schema version for compatibility\n");
82 output.push_str(" #[inline]\n");
83 output.push_str(" #[must_use]\n");
84 output.push_str(
85 " pub fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
86 );
87 output.push_str(" Self { buffer, offset, acting_version }\n");
88 output.push_str(" }\n\n");
89
90 for field in &msg.fields {
92 output.push_str(&self.generate_field_getter(field));
93 }
94
95 let mut group_offset = msg.block_length as usize;
97 for group in &msg.groups {
98 output.push_str(&self.generate_group_accessor(group, group_offset, &msg.name));
99 group_offset += 4; }
101
102 output.push_str("}\n\n");
103
104 output.push_str(&format!(
106 "impl<'a> SbeDecoder<'a> for {}<'a> {{\n",
107 decoder_name
108 ));
109 output.push_str(&format!(
110 " const TEMPLATE_ID: u16 = {};\n",
111 msg.template_id
112 ));
113 output.push_str(" const SCHEMA_ID: u16 = SCHEMA_ID;\n");
114 output.push_str(" const SCHEMA_VERSION: u16 = SCHEMA_VERSION;\n");
115 output.push_str(&format!(
116 " const BLOCK_LENGTH: u16 = {};\n\n",
117 msg.block_length
118 ));
119
120 output.push_str(
121 " fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
122 );
123 output.push_str(" Self::wrap(buffer, offset, acting_version)\n");
124 output.push_str(" }\n\n");
125
126 output.push_str(" fn encoded_length(&self) -> usize {\n");
127 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
128 output.push_str(" }\n");
129 output.push_str("}\n\n");
130
131 output
132 }
133
134 fn generate_field_getter(&self, field: &ResolvedField) -> String {
136 let mut output = String::new();
137
138 output.push_str(&format!(
139 " /// Field: {} (id={}, offset={}).\n",
140 field.name, field.id, field.offset
141 ));
142 output.push_str(" #[inline(always)]\n");
143 output.push_str(" #[must_use]\n");
144
145 if field.is_array {
146 let elem_type = field.primitive_type.map(|p| p.rust_type()).unwrap_or("u8");
148 let len = field.array_length.unwrap_or(1);
149
150 if elem_type == "u8" {
151 output.push_str(&format!(
153 " pub fn {}(&self) -> &'a [u8] {{\n",
154 field.getter_name
155 ));
156 output.push_str(&format!(
157 " &self.buffer[self.offset + {}..self.offset + {} + {}]\n",
158 field.offset, field.offset, len
159 ));
160 output.push_str(" }\n\n");
161
162 output.push_str(&format!(
164 " /// Field {} as string (trimmed).\n",
165 field.name
166 ));
167 output.push_str(" #[inline]\n");
168 output.push_str(" #[must_use]\n");
169 output.push_str(&format!(
170 " pub fn {}_as_str(&self) -> &'a str {{\n",
171 field.getter_name
172 ));
173 output.push_str(&format!(
174 " let bytes = &self.buffer[self.offset + {}..self.offset + {} + {}];\n",
175 field.offset, field.offset, len
176 ));
177 output.push_str(
178 " let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());\n",
179 );
180 output.push_str(" std::str::from_utf8(&bytes[..end]).unwrap_or(\"\")\n");
181 output.push_str(" }\n\n");
182 } else {
183 output.push_str(&format!(
185 " pub fn {}(&self) -> &'a [u8] {{\n",
186 field.getter_name
187 ));
188 output.push_str(&format!(
189 " &self.buffer[self.offset + {}..self.offset + {}]\n",
190 field.offset,
191 field.offset + field.encoded_length
192 ));
193 output.push_str(" }\n\n");
194 }
195 } else {
196 let rust_type = &field.rust_type;
198 let resolved_type = self.ir.get_type(&field.type_name);
199
200 match resolved_type.map(|t| &t.kind) {
201 Some(TypeKind::Enum { encoding, .. }) => {
202 let read_method = get_read_method(Some(*encoding));
204 output.push_str(&format!(
205 " pub fn {}(&self) -> {} {{\n",
206 field.getter_name, rust_type
207 ));
208 output.push_str(&format!(
209 " {}::from(self.buffer.{}(self.offset + {}))\n",
210 rust_type, read_method, field.offset
211 ));
212 output.push_str(" }\n\n");
213 }
214 Some(TypeKind::Set { encoding, .. }) => {
215 let read_method = get_read_method(Some(*encoding));
217 output.push_str(&format!(
218 " pub fn {}(&self) -> {} {{\n",
219 field.getter_name, rust_type
220 ));
221 output.push_str(&format!(
222 " {}::from_raw(self.buffer.{}(self.offset + {}))\n",
223 rust_type, read_method, field.offset
224 ));
225 output.push_str(" }\n\n");
226 }
227 Some(TypeKind::Composite { .. }) => {
228 output.push_str(&format!(
230 " pub fn {}(&self) -> {}<'a> {{\n",
231 field.getter_name, rust_type
232 ));
233 output.push_str(&format!(
234 " {}::wrap(self.buffer, self.offset + {})\n",
235 rust_type, field.offset
236 ));
237 output.push_str(" }\n\n");
238 }
239 _ => {
240 let read_method = get_read_method(field.primitive_type);
242 output.push_str(&format!(
243 " pub fn {}(&self) -> {} {{\n",
244 field.getter_name, rust_type
245 ));
246 output.push_str(&format!(
247 " self.buffer.{}(self.offset + {})\n",
248 read_method, field.offset
249 ));
250 output.push_str(" }\n\n");
251 }
252 }
253 }
254
255 output
256 }
257
258 fn generate_group_accessor(
260 &self,
261 group: &ResolvedGroup,
262 offset: usize,
263 msg_name: &str,
264 ) -> String {
265 let mut output = String::new();
266 let qualified = format!("{}::{}", to_snake_case(msg_name), group.decoder_name());
267
268 output.push_str(&format!(" /// Access {} repeating group.\n", group.name));
269 output.push_str(" #[inline]\n");
270 output.push_str(" #[must_use]\n");
271 output.push_str(&format!(
272 " pub fn {}(&self) -> {}<'a> {{\n",
273 to_snake_case(&group.name),
274 qualified
275 ));
276 output.push_str(&format!(
277 " {}::wrap(self.buffer, self.offset + {})\n",
278 qualified, offset
279 ));
280 output.push_str(" }\n\n");
281
282 output
283 }
284
285 fn generate_encoder(&self, msg: &ResolvedMessage) -> String {
287 let mut output = String::new();
288 let encoder_name = msg.encoder_name();
289
290 output.push_str(&format!("/// {} Encoder.\n", msg.name));
292 output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
293 output.push_str(" buffer: &'a mut [u8],\n");
294 output.push_str(" offset: usize,\n");
295 output.push_str("}\n\n");
296
297 output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
299 output.push_str(&format!(
300 " /// Template ID for this message.\n\
301 pub const TEMPLATE_ID: u16 = {};\n",
302 msg.template_id
303 ));
304 output.push_str(&format!(
305 " /// Block length of the fixed portion.\n\
306 pub const BLOCK_LENGTH: u16 = {};\n\n",
307 msg.block_length
308 ));
309
310 output.push_str(" /// Wraps a buffer for encoding, writing the header.\n");
312 output.push_str(" #[inline]\n");
313 output.push_str(" pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
314 output.push_str(" let mut encoder = Self { buffer, offset };\n");
315 output.push_str(" encoder.write_header();\n");
316 output.push_str(" encoder\n");
317 output.push_str(" }\n\n");
318
319 output.push_str(" fn write_header(&mut self) {\n");
321 output.push_str(" let header = MessageHeader {\n");
322 output.push_str(" block_length: Self::BLOCK_LENGTH,\n");
323 output.push_str(" template_id: Self::TEMPLATE_ID,\n");
324 output.push_str(" schema_id: SCHEMA_ID,\n");
325 output.push_str(" version: SCHEMA_VERSION,\n");
326 output.push_str(" };\n");
327 output.push_str(" header.encode(self.buffer, self.offset);\n");
328 output.push_str(" }\n\n");
329
330 output.push_str(" /// Returns the encoded length of the message.\n");
332 output.push_str(" #[must_use]\n");
333 output.push_str(" pub const fn encoded_length(&self) -> usize {\n");
334 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
335 output.push_str(" }\n\n");
336
337 for field in &msg.fields {
339 output.push_str(&self.generate_field_setter(field));
340 }
341
342 let mut group_offset = msg.block_length as usize;
344 for group in &msg.groups {
345 output.push_str(&self.generate_group_encoder_accessor(group, group_offset, &msg.name));
346 group_offset += 4; }
348
349 output.push_str("}\n\n");
350
351 output
352 }
353
354 fn generate_field_setter(&self, field: &ResolvedField) -> String {
356 let mut output = String::new();
357 let field_offset = format!("MessageHeader::ENCODED_LENGTH + {}", field.offset);
358
359 output.push_str(&format!(
360 " /// Set field: {} (id={}, offset={}).\n",
361 field.name, field.id, field.offset
362 ));
363 output.push_str(" #[inline(always)]\n");
364
365 if field.is_array {
366 let len = field.array_length.unwrap_or(field.encoded_length);
368
369 output.push_str(&format!(
370 " pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
371 field.setter_name
372 ));
373 output.push_str(&format!(
374 " let copy_len = value.len().min({});\n",
375 len
376 ));
377 output.push_str(&format!(
378 " self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
379 field_offset, field_offset
380 ));
381 output.push_str(" .copy_from_slice(&value[..copy_len]);\n");
382 output.push_str(&format!(" if copy_len < {} {{\n", len));
383 output.push_str(&format!(
384 " self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
385 field_offset, field_offset, len
386 ));
387 output.push_str(" }\n");
388 output.push_str(" self\n");
389 output.push_str(" }\n\n");
390 } else {
391 let rust_type = &field.rust_type;
393 let resolved_type = self.ir.get_type(&field.type_name);
394
395 match resolved_type.map(|t| &t.kind) {
396 Some(TypeKind::Enum { encoding, .. }) => {
397 let write_method = get_write_method(Some(*encoding));
399 let prim_type = encoding.rust_type();
400 output.push_str(&format!(
401 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
402 field.setter_name, rust_type
403 ));
404 output.push_str(&format!(
405 " self.buffer.{}(self.offset + {}, {}::from(value));\n",
406 write_method, field_offset, prim_type
407 ));
408 output.push_str(" self\n");
409 output.push_str(" }\n\n");
410 }
411 Some(TypeKind::Set { encoding, .. }) => {
412 let write_method = get_write_method(Some(*encoding));
414 output.push_str(&format!(
415 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
416 field.setter_name, rust_type
417 ));
418 output.push_str(&format!(
419 " self.buffer.{}(self.offset + {}, value.raw());\n",
420 write_method, field_offset
421 ));
422 output.push_str(" self\n");
423 output.push_str(" }\n\n");
424 }
425 Some(TypeKind::Composite { .. }) => {
426 output.push_str(&format!(
428 " pub fn {}(&mut self) -> {}Encoder<'_> {{\n",
429 field.setter_name, rust_type
430 ));
431 output.push_str(&format!(
432 " {}Encoder::wrap(self.buffer, self.offset + {})\n",
433 rust_type, field_offset
434 ));
435 output.push_str(" }\n\n");
436 }
437 _ => {
438 let write_method = get_write_method(field.primitive_type);
440 output.push_str(&format!(
441 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
442 field.setter_name, rust_type
443 ));
444 output.push_str(&format!(
445 " self.buffer.{}(self.offset + {}, value);\n",
446 write_method, field_offset
447 ));
448 output.push_str(" self\n");
449 output.push_str(" }\n\n");
450 }
451 }
452 }
453
454 output
455 }
456
457 fn generate_group_decoder(&self, group: &ResolvedGroup) -> String {
459 let mut output = String::new();
460 let decoder_name = group.decoder_name();
461 let entry_name = group.entry_decoder_name();
462
463 output.push_str(&format!("/// {} Group Decoder.\n", group.name));
465 output.push_str("#[derive(Debug, Clone, Copy)]\n");
466 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
467 output.push_str(" buffer: &'a [u8],\n");
468 output.push_str(" block_length: u16,\n");
469 output.push_str(" count: u16,\n");
470 output.push_str(" index: u16,\n");
471 output.push_str(" offset: usize,\n");
472 output.push_str("}\n\n");
473
474 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
476 output.push_str(" /// Wraps a buffer at the group header position.\n");
477 output.push_str(" #[must_use]\n");
478 output.push_str(" pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
479 output.push_str(" let header = GroupHeader::wrap(buffer, offset);\n");
480 output.push_str(" Self {\n");
481 output.push_str(" buffer,\n");
482 output.push_str(" block_length: header.block_length,\n");
483 output.push_str(" count: header.num_in_group,\n");
484 output.push_str(" index: 0,\n");
485 output.push_str(" offset: offset + GroupHeader::ENCODED_LENGTH,\n");
486 output.push_str(" }\n");
487 output.push_str(" }\n\n");
488
489 output.push_str(" /// Returns the number of entries in the group.\n");
490 output.push_str(" #[must_use]\n");
491 output.push_str(" pub const fn count(&self) -> u16 {\n");
492 output.push_str(" self.count\n");
493 output.push_str(" }\n\n");
494
495 output.push_str(" /// Returns true if the group is empty.\n");
496 output.push_str(" #[must_use]\n");
497 output.push_str(" pub const fn is_empty(&self) -> bool {\n");
498 output.push_str(" self.count == 0\n");
499 output.push_str(" }\n");
500 output.push_str("}\n\n");
501
502 output.push_str(&format!("impl<'a> Iterator for {}<'a> {{\n", decoder_name));
504 output.push_str(&format!(" type Item = {}<'a>;\n\n", entry_name));
505 output.push_str(" fn next(&mut self) -> Option<Self::Item> {\n");
506 output.push_str(" if self.index >= self.count {\n");
507 output.push_str(" return None;\n");
508 output.push_str(" }\n");
509 output.push_str(&format!(
510 " let entry = {}::wrap(self.buffer, self.offset);\n",
511 entry_name
512 ));
513 output.push_str(" self.offset += self.block_length as usize;\n");
514 output.push_str(" self.index += 1;\n");
515 output.push_str(" Some(entry)\n");
516 output.push_str(" }\n\n");
517
518 output.push_str(" fn size_hint(&self) -> (usize, Option<usize>) {\n");
519 output.push_str(" let remaining = (self.count - self.index) as usize;\n");
520 output.push_str(" (remaining, Some(remaining))\n");
521 output.push_str(" }\n");
522 output.push_str("}\n\n");
523
524 output.push_str(&format!(
525 "impl<'a> ExactSizeIterator for {}<'a> {{}}\n\n",
526 decoder_name
527 ));
528
529 output.push_str(&self.generate_entry_decoder(group));
531
532 for nested in &group.nested_groups {
534 output.push_str(&self.generate_group_decoder(nested));
535 }
536
537 output
538 }
539
540 fn generate_entry_decoder(&self, group: &ResolvedGroup) -> String {
542 let mut output = String::new();
543 let entry_name = group.entry_decoder_name();
544
545 output.push_str(&format!("/// {} Entry Decoder.\n", group.name));
546 output.push_str("#[derive(Debug, Clone, Copy)]\n");
547 output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
548 output.push_str(" buffer: &'a [u8],\n");
549 output.push_str(" offset: usize,\n");
550 output.push_str("}\n\n");
551
552 output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
553 output.push_str(" fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
554 output.push_str(" Self { buffer, offset }\n");
555 output.push_str(" }\n\n");
556
557 for field in &group.fields {
559 output.push_str(&self.generate_field_getter(field));
560 }
561
562 output.push_str("}\n\n");
563
564 output
565 }
566
567 fn generate_group_encoder(&self, group: &ResolvedGroup) -> String {
569 let mut output = String::new();
570 let encoder_name = group.encoder_name();
571 let entry_name = group.entry_encoder_name();
572
573 output.push_str(&format!("/// {} Group Encoder.\n", group.name));
575 output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
576 output.push_str(" buffer: &'a mut [u8],\n");
577 output.push_str(" count: u16,\n");
578 output.push_str(" index: u16,\n");
579 output.push_str(" offset: usize,\n");
580 output.push_str("}\n\n");
581
582 output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
584 output.push_str(&format!(
585 " /// Block length of each entry.\n\
586 pub const BLOCK_LENGTH: u16 = {};\n\n",
587 group.block_length
588 ));
589
590 output
592 .push_str(" /// Wraps a buffer at the group header position, writing the header.\n");
593 output.push_str(" ///\n");
594 output.push_str(" /// # Arguments\n");
595 output.push_str(" /// * `buffer` - Mutable buffer to write to\n");
596 output.push_str(" /// * `offset` - Offset of the group header\n");
597 output.push_str(" /// * `count` - Number of entries to encode\n");
598 output.push_str(
599 " pub fn wrap(buffer: &'a mut [u8], offset: usize, count: u16) -> Self {\n",
600 );
601 output.push_str(" let header = GroupHeader::new(Self::BLOCK_LENGTH, count);\n");
602 output.push_str(" header.encode(buffer, offset);\n");
603 output.push_str(" Self {\n");
604 output.push_str(" buffer,\n");
605 output.push_str(" count,\n");
606 output.push_str(" index: 0,\n");
607 output.push_str(" offset: offset + GroupHeader::ENCODED_LENGTH,\n");
608 output.push_str(" }\n");
609 output.push_str(" }\n\n");
610
611 output.push_str(
613 " /// Returns the next entry encoder, or `None` if all entries are written.\n",
614 );
615 output.push_str(&format!(
616 " pub fn next_entry(&mut self) -> Option<{}<'_>> {{\n",
617 entry_name
618 ));
619 output.push_str(" if self.index >= self.count {\n");
620 output.push_str(" return None;\n");
621 output.push_str(" }\n");
622 output.push_str(" let offset = self.offset;\n");
623 output.push_str(" self.offset += Self::BLOCK_LENGTH as usize;\n");
624 output.push_str(" self.index += 1;\n");
625 output.push_str(&format!(
626 " Some({}::wrap(&mut *self.buffer, offset))\n",
627 entry_name
628 ));
629 output.push_str(" }\n\n");
630
631 output.push_str(
633 " /// Returns the total encoded length of this group (header + all entries).\n",
634 );
635 output.push_str(" #[must_use]\n");
636 output.push_str(" pub const fn encoded_length(&self) -> usize {\n");
637 output.push_str(" GroupHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize * self.count as usize\n");
638 output.push_str(" }\n");
639 output.push_str("}\n\n");
640
641 output.push_str(&self.generate_entry_encoder(group));
643
644 for nested in &group.nested_groups {
646 output.push_str(&self.generate_group_encoder(nested));
647 }
648
649 output
650 }
651
652 fn generate_entry_encoder(&self, group: &ResolvedGroup) -> String {
654 let mut output = String::new();
655 let entry_name = group.entry_encoder_name();
656
657 output.push_str(&format!("/// {} Entry Encoder.\n", group.name));
658 output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
659 output.push_str(" buffer: &'a mut [u8],\n");
660 output.push_str(" offset: usize,\n");
661 output.push_str("}\n\n");
662
663 output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
664 output.push_str(" fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
665 output.push_str(" Self { buffer, offset }\n");
666 output.push_str(" }\n\n");
667
668 for field in &group.fields {
670 output.push_str(&self.generate_entry_field_setter(field));
671 }
672
673 output.push_str("}\n\n");
674
675 output
676 }
677
678 fn generate_entry_field_setter(&self, field: &ResolvedField) -> String {
684 let mut output = String::new();
685 let field_offset = field.offset;
686
687 output.push_str(&format!(
688 " /// Set field: {} (id={}, offset={}).\n",
689 field.name, field.id, field.offset
690 ));
691 output.push_str(" #[inline(always)]\n");
692
693 if field.is_array {
694 let len = field.array_length.unwrap_or(field.encoded_length);
695
696 output.push_str(&format!(
697 " pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
698 field.setter_name
699 ));
700 output.push_str(&format!(
701 " let copy_len = value.len().min({});\n",
702 len
703 ));
704 output.push_str(&format!(
705 " self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
706 field_offset, field_offset
707 ));
708 output.push_str(" .copy_from_slice(&value[..copy_len]);\n");
709 output.push_str(&format!(" if copy_len < {} {{\n", len));
710 output.push_str(&format!(
711 " self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
712 field_offset, field_offset, len
713 ));
714 output.push_str(" }\n");
715 output.push_str(" self\n");
716 output.push_str(" }\n\n");
717 } else {
718 let rust_type = &field.rust_type;
719 let resolved_type = self.ir.get_type(&field.type_name);
720
721 match resolved_type.map(|t| &t.kind) {
722 Some(TypeKind::Enum { encoding, .. }) => {
723 let write_method = get_write_method(Some(*encoding));
724 let prim_type = encoding.rust_type();
725 output.push_str(&format!(
726 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
727 field.setter_name, rust_type
728 ));
729 output.push_str(&format!(
730 " self.buffer.{}(self.offset + {}, {}::from(value));\n",
731 write_method, field_offset, prim_type
732 ));
733 output.push_str(" self\n");
734 output.push_str(" }\n\n");
735 }
736 Some(TypeKind::Set { encoding, .. }) => {
737 let write_method = get_write_method(Some(*encoding));
738 output.push_str(&format!(
739 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
740 field.setter_name, rust_type
741 ));
742 output.push_str(&format!(
743 " self.buffer.{}(self.offset + {}, value.raw());\n",
744 write_method, field_offset
745 ));
746 output.push_str(" self\n");
747 output.push_str(" }\n\n");
748 }
749 Some(TypeKind::Composite { .. }) => {
750 output.push_str(&format!(
751 " pub fn {}(&mut self) -> {}Encoder<'_> {{\n",
752 field.setter_name, rust_type
753 ));
754 output.push_str(&format!(
755 " {}Encoder::wrap(self.buffer, self.offset + {})\n",
756 rust_type, field_offset
757 ));
758 output.push_str(" }\n\n");
759 }
760 _ => {
761 let write_method = get_write_method(field.primitive_type);
762 output.push_str(&format!(
763 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
764 field.setter_name, rust_type
765 ));
766 output.push_str(&format!(
767 " self.buffer.{}(self.offset + {}, value);\n",
768 write_method, field_offset
769 ));
770 output.push_str(" self\n");
771 output.push_str(" }\n\n");
772 }
773 }
774 }
775
776 output
777 }
778
779 fn generate_group_encoder_accessor(
781 &self,
782 group: &ResolvedGroup,
783 offset: usize,
784 msg_name: &str,
785 ) -> String {
786 let mut output = String::new();
787 let qualified = format!("{}::{}", to_snake_case(msg_name), group.encoder_name());
788
789 output.push_str(&format!(
790 " /// Begin encoding the {} repeating group.\n",
791 group.name
792 ));
793 output.push_str(&format!(
794 " pub fn {}_count(&mut self, count: u16) -> {}<'_> {{\n",
795 to_snake_case(&group.name),
796 qualified
797 ));
798 output.push_str(&format!(
799 " {}::wrap(&mut *self.buffer, self.offset + MessageHeader::ENCODED_LENGTH + {}, count)\n",
800 qualified, offset
801 ));
802 output.push_str(" }\n\n");
803
804 output
805 }
806}
807
808fn get_read_method(prim: Option<PrimitiveType>) -> &'static str {
810 match prim {
811 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "get_u8",
812 Some(PrimitiveType::Int8) => "get_i8",
813 Some(PrimitiveType::Uint16) => "get_u16_le",
814 Some(PrimitiveType::Int16) => "get_i16_le",
815 Some(PrimitiveType::Uint32) => "get_u32_le",
816 Some(PrimitiveType::Int32) => "get_i32_le",
817 Some(PrimitiveType::Uint64) => "get_u64_le",
818 Some(PrimitiveType::Int64) => "get_i64_le",
819 Some(PrimitiveType::Float) => "get_f32_le",
820 Some(PrimitiveType::Double) => "get_f64_le",
821 None => "get_u64_le",
822 }
823}
824
825fn get_write_method(prim: Option<PrimitiveType>) -> &'static str {
827 match prim {
828 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "put_u8",
829 Some(PrimitiveType::Int8) => "put_i8",
830 Some(PrimitiveType::Uint16) => "put_u16_le",
831 Some(PrimitiveType::Int16) => "put_i16_le",
832 Some(PrimitiveType::Uint32) => "put_u32_le",
833 Some(PrimitiveType::Int32) => "put_i32_le",
834 Some(PrimitiveType::Uint64) => "put_u64_le",
835 Some(PrimitiveType::Int64) => "put_i64_le",
836 Some(PrimitiveType::Float) => "put_f32_le",
837 Some(PrimitiveType::Double) => "put_f64_le",
838 None => "put_u64_le",
839 }
840}
841
842#[cfg(test)]
843mod tests {
844 use super::*;
845 use ironsbe_schema::{SchemaIr, parse_schema};
846
847 fn schema_with_shared_group_name() -> String {
848 r#"<?xml version="1.0" encoding="UTF-8"?>
849<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
850 package="test" id="1" version="1" byteOrder="littleEndian">
851 <types>
852 <type name="uint64" primitiveType="uint64"/>
853 </types>
854 <sbe:message name="CreateRfqResponse" id="21" blockLength="8">
855 <field name="value" id="1" type="uint64" offset="0"/>
856 <group name="quotes" id="100" dimensionType="groupSizeEncoding" blockLength="8">
857 <field name="price" id="200" type="uint64" offset="0"/>
858 </group>
859 </sbe:message>
860 <sbe:message name="GetRfqResponse" id="23" blockLength="8">
861 <field name="value" id="1" type="uint64" offset="0"/>
862 <group name="quotes" id="100" dimensionType="groupSizeEncoding" blockLength="8">
863 <field name="price" id="200" type="uint64" offset="0"/>
864 </group>
865 </sbe:message>
866</sbe:messageSchema>"#
867 .to_string()
868 }
869
870 fn schema_with_group_no_offsets() -> String {
871 r#"<?xml version="1.0" encoding="UTF-8"?>
872<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
873 package="test" id="1" version="1" byteOrder="littleEndian">
874 <types>
875 <type name="uint64" primitiveType="uint64"/>
876 <type name="uint32" primitiveType="uint32"/>
877 </types>
878 <sbe:message name="ListOrders" id="19" blockLength="0">
879 <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="20">
880 <field name="orderId" id="1" type="uint64" offset="0"/>
881 <field name="instrumentId" id="2" type="uint32"/>
882 <field name="quantity" id="3" type="uint64"/>
883 </group>
884 </sbe:message>
885</sbe:messageSchema>"#
886 .to_string()
887 }
888
889 fn schema_with_group_explicit_offsets() -> String {
890 r#"<?xml version="1.0" encoding="UTF-8"?>
891<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
892 package="test" id="1" version="1" byteOrder="littleEndian">
893 <types>
894 <type name="uint64" primitiveType="uint64"/>
895 <type name="uint32" primitiveType="uint32"/>
896 </types>
897 <sbe:message name="ListOrders" id="19" blockLength="0">
898 <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="20">
899 <field name="orderId" id="1" type="uint64" offset="0"/>
900 <field name="instrumentId" id="2" type="uint32" offset="8"/>
901 <field name="quantity" id="3" type="uint64" offset="12"/>
902 </group>
903 </sbe:message>
904</sbe:messageSchema>"#
905 .to_string()
906 }
907
908 #[test]
909 fn test_duplicate_group_name_generates_scoped_modules() {
910 let xml = schema_with_shared_group_name();
911 let schema = parse_schema(&xml).expect("Failed to parse schema");
912 let ir = SchemaIr::from_schema(&schema);
913 let msg_gen = MessageGenerator::new(&ir);
914 let code = msg_gen.generate();
915
916 assert!(
917 code.contains("pub mod create_rfq_response {"),
918 "expected module for CreateRfqResponse groups"
919 );
920 assert!(
921 code.contains("pub mod get_rfq_response {"),
922 "expected module for GetRfqResponse groups"
923 );
924
925 let occurrences = code.matches("pub struct QuotesGroupDecoder").count();
926 assert_eq!(
927 occurrences, 2,
928 "expected one QuotesGroupDecoder per message module, got {occurrences}"
929 );
930 }
931
932 #[test]
933 fn test_group_accessor_uses_qualified_path() {
934 let xml = schema_with_shared_group_name();
935 let schema = parse_schema(&xml).expect("Failed to parse schema");
936 let ir = SchemaIr::from_schema(&schema);
937 let msg_gen = MessageGenerator::new(&ir);
938 let code = msg_gen.generate();
939
940 assert!(
941 code.contains("create_rfq_response::QuotesGroupDecoder"),
942 "accessor in CreateRfqResponse must reference module-qualified type"
943 );
944 assert!(
945 code.contains("get_rfq_response::QuotesGroupDecoder"),
946 "accessor in GetRfqResponse must reference module-qualified type"
947 );
948 }
949
950 #[test]
951 fn test_entry_decoder_field_offsets_auto_computed() {
952 let xml = schema_with_group_no_offsets();
953 let schema = parse_schema(&xml).expect("Failed to parse schema");
954 let ir = SchemaIr::from_schema(&schema);
955 let msg_gen = MessageGenerator::new(&ir);
956 let code = msg_gen.generate();
957
958 assert!(
960 code.contains("self.offset + 0)"),
961 "orderId should be at offset 0"
962 );
963 assert!(
965 code.contains("self.offset + 8)"),
966 "instrumentId should be at offset 8, not 0"
967 );
968 assert!(
970 code.contains("self.offset + 12)"),
971 "quantity should be at offset 12, not 0"
972 );
973 }
974
975 #[test]
976 fn test_entry_decoder_field_offsets_explicit() {
977 let xml = schema_with_group_explicit_offsets();
978 let schema = parse_schema(&xml).expect("Failed to parse schema");
979 let ir = SchemaIr::from_schema(&schema);
980 let msg_gen = MessageGenerator::new(&ir);
981 let code = msg_gen.generate();
982
983 assert!(
984 code.contains("self.offset + 8)"),
985 "instrumentId should be at explicit offset 8"
986 );
987 assert!(
988 code.contains("self.offset + 12)"),
989 "quantity should be at explicit offset 12"
990 );
991 }
992
993 #[test]
994 fn test_group_encoder_emitted() {
995 let xml = schema_with_group_no_offsets();
996 let schema = parse_schema(&xml).expect("Failed to parse schema");
997 let ir = SchemaIr::from_schema(&schema);
998 let msg_gen = MessageGenerator::new(&ir);
999 let code = msg_gen.generate();
1000
1001 assert!(
1002 code.contains("pub struct OrdersGroupEncoder"),
1003 "expected OrdersGroupEncoder struct"
1004 );
1005 assert!(
1006 code.contains("pub struct OrdersEntryEncoder"),
1007 "expected OrdersEntryEncoder struct"
1008 );
1009 }
1010
1011 #[test]
1012 fn test_group_encoder_has_next_entry() {
1013 let xml = schema_with_group_no_offsets();
1014 let schema = parse_schema(&xml).expect("Failed to parse schema");
1015 let ir = SchemaIr::from_schema(&schema);
1016 let msg_gen = MessageGenerator::new(&ir);
1017 let code = msg_gen.generate();
1018
1019 assert!(
1020 code.contains("fn next_entry(&mut self)"),
1021 "expected next_entry method on group encoder"
1022 );
1023 }
1024
/// The generated output must contain a typed setter for every field of
/// the group fixture (`orderId`, `instrumentId`, `quantity`).
#[test]
fn test_entry_encoder_has_field_setters() {
    let schema_xml = schema_with_group_no_offsets();
    let parsed = parse_schema(&schema_xml).expect("Failed to parse schema");
    let ir = SchemaIr::from_schema(&parsed);
    let generated = MessageGenerator::new(&ir).generate();

    // Each pair is (setter signature to look for, failure message).
    let setters = [
        (
            "fn set_order_id(&mut self, value: u64)",
            "expected set_order_id setter",
        ),
        (
            "fn set_instrument_id(&mut self, value: u32)",
            "expected set_instrument_id setter",
        ),
        (
            "fn set_quantity(&mut self, value: u64)",
            "expected set_quantity setter",
        ),
    ];
    for (signature, reason) in setters {
        assert!(generated.contains(signature), "{}", reason);
    }
}
1046
/// The generated output must contain the `orders_count` accessor that
/// takes the group entry count as a `u16`.
#[test]
fn test_parent_encoder_has_group_accessor() {
    let schema_xml = schema_with_group_no_offsets();
    let parsed = parse_schema(&schema_xml).expect("Failed to parse schema");
    let ir = SchemaIr::from_schema(&parsed);
    let generated = MessageGenerator::new(&ir).generate();

    let has_accessor = generated.contains("fn orders_count(&mut self, count: u16)");
    assert!(
        has_accessor,
        "expected orders_count accessor on parent encoder"
    );
}
1060
/// Structural check of the code generated for a `ListOrders` message with
/// an `orders` repeating group: entry decoder/encoder field offsets, the
/// group block length constant, and module-qualified group delegation.
#[test]
fn test_roundtrip_group_codegen_structure() {
    // Only `orderId` declares an offset; the remaining group fields rely on
    // the offsets the schema layer assigns to them.
    const SCHEMA_XML: &str = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <type name="uint64" primitiveType="uint64"/>
        <type name="uint32" primitiveType="uint32"/>
        <type name="uint8" primitiveType="uint8"/>
    </types>
    <sbe:message name="ListOrders" id="19" blockLength="8">
        <field name="requestId" id="1" type="uint64" offset="0"/>
        <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="29">
            <field name="orderId" id="10" type="uint64" offset="0"/>
            <field name="instrumentId" id="11" type="uint32"/>
            <field name="price" id="12" type="uint64"/>
            <field name="quantity" id="13" type="uint64"/>
            <field name="side" id="14" type="uint8"/>
        </group>
    </sbe:message>
</sbe:messageSchema>"#;

    let parsed = parse_schema(SCHEMA_XML).expect("Failed to parse schema");
    let ir = SchemaIr::from_schema(&parsed);
    let code = MessageGenerator::new(&ir).generate();

    // Entry decoder: everything from its impl header onward must contain a
    // read at each expected field offset (reads end with `)`).
    let decoder_section = code
        .find("impl<'a> OrdersEntryDecoder<'a>")
        .map(|start| &code[start..])
        .expect("entry decoder impl");
    assert!(decoder_section.contains("self.offset + 0)"));
    assert!(decoder_section.contains("self.offset + 8)"));
    assert!(decoder_section.contains("self.offset + 12)"));
    assert!(decoder_section.contains("self.offset + 20)"));
    assert!(decoder_section.contains("self.offset + 28)"));

    // Entry encoder: same offsets, but writes pass the offset as an argument
    // followed by `,`.
    let encoder_section = code
        .find("impl<'a> OrdersEntryEncoder<'a>")
        .map(|start| &code[start..])
        .expect("entry encoder impl");
    assert!(encoder_section.contains("self.offset + 0,"));
    assert!(encoder_section.contains("self.offset + 8,"));
    assert!(encoder_section.contains("self.offset + 12,"));
    assert!(encoder_section.contains("self.offset + 20,"));
    assert!(encoder_section.contains("self.offset + 28,"));

    // Whole-output checks; each pair is (fragment, failure message).
    let whole_output_checks = [
        ("BLOCK_LENGTH: u16 = 29", "group encoder BLOCK_LENGTH"),
        (
            "fn orders_count(&mut self, count: u16)",
            "parent encoder group accessor",
        ),
        (
            "list_orders::OrdersGroupEncoder::wrap(&mut *self.buffer",
            "parent encoder delegates to module-qualified group encoder",
        ),
        (
            "list_orders::OrdersGroupDecoder",
            "parent decoder uses module-qualified group decoder",
        ),
    ];
    for (fragment, reason) in whole_output_checks {
        assert!(code.contains(fragment), "{}", reason);
    }
}
1132
/// Within the generated `OrdersEntryEncoder` impl (and everything after
/// it), setters must write at offsets 0, 8 and 12.
#[test]
fn test_entry_encoder_setter_offsets_correct() {
    let schema_xml = schema_with_group_no_offsets();
    let parsed = parse_schema(&schema_xml).expect("Failed to parse schema");
    let ir = SchemaIr::from_schema(&parsed);
    let generated = MessageGenerator::new(&ir).generate();

    // Restrict the search to the text starting at the entry encoder impl.
    let entry_code = generated
        .find("impl<'a> OrdersEntryEncoder<'a>")
        .map(|start| &generated[start..])
        .expect("EntryEncoder impl not found");

    // Each pair is (write-site fragment, failure message).
    let expected_writes = [
        ("self.offset + 0,", "set_order_id should write at offset 0"),
        (
            "self.offset + 8,",
            "set_instrument_id should write at offset 8",
        ),
        (
            "self.offset + 12,",
            "set_quantity should write at offset 12",
        ),
    ];
    for (fragment, reason) in expected_writes {
        assert!(entry_code.contains(fragment), "{}", reason);
    }
}
1163}