1use ironsbe_schema::ir::{
4 ResolvedField, ResolvedGroup, ResolvedMessage, SchemaIr, TypeKind, to_snake_case,
5};
6use ironsbe_schema::types::PrimitiveType;
7
8pub struct MessageGenerator<'a> {
10 ir: &'a SchemaIr,
11}
12
13impl<'a> MessageGenerator<'a> {
14 #[must_use]
16 pub fn new(ir: &'a SchemaIr) -> Self {
17 Self { ir }
18 }
19
20 #[must_use]
22 pub fn generate(&self) -> String {
23 let mut output = String::new();
24
25 for msg in &self.ir.messages {
26 output.push_str(&self.generate_decoder(msg));
27 output.push_str(&self.generate_encoder(msg));
28
29 if !msg.groups.is_empty() {
31 let mod_name = to_snake_case(&msg.name);
32 output.push_str(&format!("/// Types for {} repeating groups.\n", msg.name));
33 output.push_str(&format!("pub mod {} {{\n", mod_name));
34 output.push_str(" use super::*;\n\n");
35 for group in &msg.groups {
36 output.push_str(&self.generate_group_decoder(group));
37 output.push_str(&self.generate_group_encoder(group));
38 }
39 output.push_str("}\n\n");
40 }
41 }
42
43 output
44 }
45
46 fn generate_decoder(&self, msg: &ResolvedMessage) -> String {
48 let mut output = String::new();
49 let decoder_name = msg.decoder_name();
50
51 output.push_str(&format!("/// {} Decoder (zero-copy).\n", msg.name));
53 output.push_str("#[derive(Debug, Clone, Copy)]\n");
54 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
55 output.push_str(" buffer: &'a [u8],\n");
56 output.push_str(" offset: usize,\n");
57 output.push_str(" acting_version: u16,\n");
58 output.push_str("}\n\n");
59
60 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
62 output.push_str(&format!(
63 " /// Template ID for this message.\n\
64 pub const TEMPLATE_ID: u16 = {};\n",
65 msg.template_id
66 ));
67 output.push_str(&format!(
68 " /// Block length of the fixed portion.\n\
69 pub const BLOCK_LENGTH: u16 = {};\n\n",
70 msg.block_length
71 ));
72
73 output.push_str(" /// Wraps a buffer for zero-copy decoding.\n");
75 output.push_str(" ///\n");
76 output.push_str(" /// # Arguments\n");
77 output.push_str(" /// * `buffer` - Buffer containing the message\n");
78 output.push_str(
79 " /// * `offset` - Offset to the start of the root block (after header)\n",
80 );
81 output.push_str(" /// * `acting_version` - Schema version for compatibility\n");
82 output.push_str(" #[inline]\n");
83 output.push_str(" #[must_use]\n");
84 output.push_str(
85 " pub fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
86 );
87 output.push_str(" Self { buffer, offset, acting_version }\n");
88 output.push_str(" }\n\n");
89
90 for field in &msg.fields {
92 output.push_str(&self.generate_field_getter(field));
93 }
94
95 let mut group_offset = msg.block_length as usize;
97 for group in &msg.groups {
98 output.push_str(&self.generate_group_accessor(group, group_offset, &msg.name));
99 group_offset += 4; }
101
102 output.push_str("}\n\n");
103
104 output.push_str(&format!(
106 "impl<'a> SbeDecoder<'a> for {}<'a> {{\n",
107 decoder_name
108 ));
109 output.push_str(&format!(
110 " const TEMPLATE_ID: u16 = {};\n",
111 msg.template_id
112 ));
113 output.push_str(" const SCHEMA_ID: u16 = SCHEMA_ID;\n");
114 output.push_str(" const SCHEMA_VERSION: u16 = SCHEMA_VERSION;\n");
115 output.push_str(&format!(
116 " const BLOCK_LENGTH: u16 = {};\n\n",
117 msg.block_length
118 ));
119
120 output.push_str(
121 " fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
122 );
123 output.push_str(" Self::wrap(buffer, offset, acting_version)\n");
124 output.push_str(" }\n\n");
125
126 output.push_str(" fn encoded_length(&self) -> usize {\n");
127 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
128 output.push_str(" }\n");
129 output.push_str("}\n\n");
130
131 output
132 }
133
134 fn generate_field_getter(&self, field: &ResolvedField) -> String {
136 let mut output = String::new();
137
138 output.push_str(&format!(
139 " /// Field: {} (id={}, offset={}).\n",
140 field.name, field.id, field.offset
141 ));
142 output.push_str(" #[inline(always)]\n");
143 output.push_str(" #[must_use]\n");
144
145 if field.is_array {
146 let elem_type = field.primitive_type.map(|p| p.rust_type()).unwrap_or("u8");
148 let len = field.array_length.unwrap_or(1);
149
150 if elem_type == "u8" {
151 output.push_str(&format!(
153 " pub fn {}(&self) -> &'a [u8] {{\n",
154 field.getter_name
155 ));
156 output.push_str(&format!(
157 " &self.buffer[self.offset + {}..self.offset + {} + {}]\n",
158 field.offset, field.offset, len
159 ));
160 output.push_str(" }\n\n");
161
162 output.push_str(&format!(
164 " /// Field {} as string (trimmed).\n",
165 field.name
166 ));
167 output.push_str(" #[inline]\n");
168 output.push_str(" #[must_use]\n");
169 output.push_str(&format!(
170 " pub fn {}_as_str(&self) -> &'a str {{\n",
171 field.getter_name
172 ));
173 output.push_str(&format!(
174 " let bytes = &self.buffer[self.offset + {}..self.offset + {} + {}];\n",
175 field.offset, field.offset, len
176 ));
177 output.push_str(
178 " let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());\n",
179 );
180 output.push_str(" std::str::from_utf8(&bytes[..end]).unwrap_or(\"\")\n");
181 output.push_str(" }\n\n");
182 } else {
183 output.push_str(&format!(
185 " pub fn {}(&self) -> &'a [u8] {{\n",
186 field.getter_name
187 ));
188 output.push_str(&format!(
189 " &self.buffer[self.offset + {}..self.offset + {}]\n",
190 field.offset,
191 field.offset + field.encoded_length
192 ));
193 output.push_str(" }\n\n");
194 }
195 } else {
196 let rust_type = &field.rust_type;
198 let resolved_type = self.ir.get_type(&field.type_name);
199
200 match resolved_type.map(|t| &t.kind) {
201 Some(TypeKind::Enum { encoding, .. }) => {
202 let read_method = get_read_method(Some(*encoding));
204 output.push_str(&format!(
205 " pub fn {}(&self) -> {} {{\n",
206 field.getter_name, rust_type
207 ));
208 output.push_str(&format!(
209 " {}::from(self.buffer.{}(self.offset + {}))\n",
210 rust_type, read_method, field.offset
211 ));
212 output.push_str(" }\n\n");
213 }
214 Some(TypeKind::Set { encoding, .. }) => {
215 let read_method = get_read_method(Some(*encoding));
217 output.push_str(&format!(
218 " pub fn {}(&self) -> {} {{\n",
219 field.getter_name, rust_type
220 ));
221 output.push_str(&format!(
222 " {}::from_raw(self.buffer.{}(self.offset + {}))\n",
223 rust_type, read_method, field.offset
224 ));
225 output.push_str(" }\n\n");
226 }
227 Some(TypeKind::Composite { .. }) => {
228 output.push_str(&format!(
230 " pub fn {}(&self) -> {}<'a> {{\n",
231 field.getter_name, rust_type
232 ));
233 output.push_str(&format!(
234 " {}::wrap(self.buffer, self.offset + {})\n",
235 rust_type, field.offset
236 ));
237 output.push_str(" }\n\n");
238 }
239 _ => {
240 let read_method = get_read_method(field.primitive_type);
242 output.push_str(&format!(
243 " pub fn {}(&self) -> {} {{\n",
244 field.getter_name, rust_type
245 ));
246 output.push_str(&format!(
247 " self.buffer.{}(self.offset + {})\n",
248 read_method, field.offset
249 ));
250 output.push_str(" }\n\n");
251 }
252 }
253 }
254
255 output
256 }
257
258 fn generate_group_accessor(
260 &self,
261 group: &ResolvedGroup,
262 offset: usize,
263 msg_name: &str,
264 ) -> String {
265 let mut output = String::new();
266 let qualified = format!("{}::{}", to_snake_case(msg_name), group.decoder_name());
267
268 output.push_str(&format!(" /// Access {} repeating group.\n", group.name));
269 output.push_str(" #[inline]\n");
270 output.push_str(" #[must_use]\n");
271 output.push_str(&format!(
272 " pub fn {}(&self) -> {}<'a> {{\n",
273 to_snake_case(&group.name),
274 qualified
275 ));
276 output.push_str(&format!(
277 " {}::wrap(self.buffer, self.offset + {})\n",
278 qualified, offset
279 ));
280 output.push_str(" }\n\n");
281
282 output
283 }
284
285 fn generate_encoder(&self, msg: &ResolvedMessage) -> String {
287 let mut output = String::new();
288 let encoder_name = msg.encoder_name();
289
290 output.push_str(&format!("/// {} Encoder.\n", msg.name));
292 output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
293 output.push_str(" buffer: &'a mut [u8],\n");
294 output.push_str(" offset: usize,\n");
295 output.push_str("}\n\n");
296
297 output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
299 output.push_str(&format!(
300 " /// Template ID for this message.\n\
301 pub const TEMPLATE_ID: u16 = {};\n",
302 msg.template_id
303 ));
304 output.push_str(&format!(
305 " /// Block length of the fixed portion.\n\
306 pub const BLOCK_LENGTH: u16 = {};\n\n",
307 msg.block_length
308 ));
309
310 output.push_str(" /// Wraps a buffer for encoding, writing the header.\n");
312 output.push_str(" #[inline]\n");
313 output.push_str(" pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
314 output.push_str(" let mut encoder = Self { buffer, offset };\n");
315 output.push_str(" encoder.write_header();\n");
316 output.push_str(" encoder\n");
317 output.push_str(" }\n\n");
318
319 output.push_str(" fn write_header(&mut self) {\n");
321 output.push_str(" let header = MessageHeader {\n");
322 output.push_str(" block_length: Self::BLOCK_LENGTH,\n");
323 output.push_str(" template_id: Self::TEMPLATE_ID,\n");
324 output.push_str(" schema_id: SCHEMA_ID,\n");
325 output.push_str(" version: SCHEMA_VERSION,\n");
326 output.push_str(" };\n");
327 output.push_str(" header.encode(self.buffer, self.offset);\n");
328 output.push_str(" }\n\n");
329
330 output.push_str(" /// Returns the encoded length of the message.\n");
332 output.push_str(" #[must_use]\n");
333 output.push_str(" pub const fn encoded_length(&self) -> usize {\n");
334 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
335 output.push_str(" }\n\n");
336
337 for field in &msg.fields {
339 output.push_str(&self.generate_field_setter(field));
340 }
341
342 let mut group_offset = msg.block_length as usize;
344 for group in &msg.groups {
345 output.push_str(&self.generate_group_encoder_accessor(group, group_offset, &msg.name));
346 group_offset += 4; }
348
349 output.push_str("}\n\n");
350
351 output
352 }
353
354 fn generate_field_setter(&self, field: &ResolvedField) -> String {
356 let mut output = String::new();
357 let field_offset = format!("MessageHeader::ENCODED_LENGTH + {}", field.offset);
358
359 output.push_str(&format!(
360 " /// Set field: {} (id={}, offset={}).\n",
361 field.name, field.id, field.offset
362 ));
363 output.push_str(" #[inline(always)]\n");
364
365 if field.is_array {
366 let len = field.array_length.unwrap_or(field.encoded_length);
368
369 output.push_str(&format!(
370 " pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
371 field.setter_name
372 ));
373 output.push_str(&format!(
374 " let copy_len = value.len().min({});\n",
375 len
376 ));
377 output.push_str(&format!(
378 " self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
379 field_offset, field_offset
380 ));
381 output.push_str(" .copy_from_slice(&value[..copy_len]);\n");
382 output.push_str(&format!(" if copy_len < {} {{\n", len));
383 output.push_str(&format!(
384 " self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
385 field_offset, field_offset, len
386 ));
387 output.push_str(" }\n");
388 output.push_str(" self\n");
389 output.push_str(" }\n\n");
390 } else {
391 let rust_type = &field.rust_type;
393 let resolved_type = self.ir.get_type(&field.type_name);
394
395 match resolved_type.map(|t| &t.kind) {
396 Some(TypeKind::Enum { encoding, .. }) => {
397 let write_method = get_write_method(Some(*encoding));
399 let prim_type = encoding.rust_type();
400 output.push_str(&format!(
401 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
402 field.setter_name, rust_type
403 ));
404 output.push_str(&format!(
405 " self.buffer.{}(self.offset + {}, {}::from(value));\n",
406 write_method, field_offset, prim_type
407 ));
408 output.push_str(" self\n");
409 output.push_str(" }\n\n");
410 }
411 Some(TypeKind::Set { encoding, .. }) => {
412 let write_method = get_write_method(Some(*encoding));
414 output.push_str(&format!(
415 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
416 field.setter_name, rust_type
417 ));
418 output.push_str(&format!(
419 " self.buffer.{}(self.offset + {}, value.raw());\n",
420 write_method, field_offset
421 ));
422 output.push_str(" self\n");
423 output.push_str(" }\n\n");
424 }
425 Some(TypeKind::Composite { .. }) => {
426 output.push_str(&format!(
428 " pub fn {}(&mut self) -> {}Encoder<'_> {{\n",
429 field.setter_name, rust_type
430 ));
431 output.push_str(&format!(
432 " {}Encoder::wrap(self.buffer, self.offset + {})\n",
433 rust_type, field_offset
434 ));
435 output.push_str(" }\n\n");
436 }
437 _ => {
438 let write_method = get_write_method(field.primitive_type);
440 output.push_str(&format!(
441 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
442 field.setter_name, rust_type
443 ));
444 output.push_str(&format!(
445 " self.buffer.{}(self.offset + {}, value);\n",
446 write_method, field_offset
447 ));
448 output.push_str(" self\n");
449 output.push_str(" }\n\n");
450 }
451 }
452 }
453
454 output
455 }
456
457 fn generate_group_decoder(&self, group: &ResolvedGroup) -> String {
459 let mut output = String::new();
460 let decoder_name = group.decoder_name();
461 let entry_name = group.entry_decoder_name();
462
463 output.push_str(&format!("/// {} Group Decoder.\n", group.name));
465 output.push_str("#[derive(Debug, Clone, Copy)]\n");
466 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
467 output.push_str(" buffer: &'a [u8],\n");
468 output.push_str(" block_length: u16,\n");
469 output.push_str(" count: u16,\n");
470 output.push_str(" index: u16,\n");
471 output.push_str(" offset: usize,\n");
472 output.push_str("}\n\n");
473
474 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
476 output.push_str(" /// Wraps a buffer at the group header position.\n");
477 output.push_str(" #[must_use]\n");
478 output.push_str(" pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
479 output.push_str(" let header = GroupHeader::wrap(buffer, offset);\n");
480 output.push_str(" Self {\n");
481 output.push_str(" buffer,\n");
482 output.push_str(" block_length: header.block_length,\n");
483 output.push_str(" count: header.num_in_group,\n");
484 output.push_str(" index: 0,\n");
485 output.push_str(" offset: offset + GroupHeader::ENCODED_LENGTH,\n");
486 output.push_str(" }\n");
487 output.push_str(" }\n\n");
488
489 output.push_str(" /// Returns the number of entries in the group.\n");
490 output.push_str(" #[must_use]\n");
491 output.push_str(" pub const fn count(&self) -> u16 {\n");
492 output.push_str(" self.count\n");
493 output.push_str(" }\n\n");
494
495 output.push_str(" /// Returns true if the group is empty.\n");
496 output.push_str(" #[must_use]\n");
497 output.push_str(" pub const fn is_empty(&self) -> bool {\n");
498 output.push_str(" self.count == 0\n");
499 output.push_str(" }\n");
500 output.push_str("}\n\n");
501
502 output.push_str(&format!("impl<'a> Iterator for {}<'a> {{\n", decoder_name));
504 output.push_str(&format!(" type Item = {}<'a>;\n\n", entry_name));
505 output.push_str(" fn next(&mut self) -> Option<Self::Item> {\n");
506 output.push_str(" if self.index >= self.count {\n");
507 output.push_str(" return None;\n");
508 output.push_str(" }\n");
509 output.push_str(&format!(
510 " let entry = {}::wrap(self.buffer, self.offset);\n",
511 entry_name
512 ));
513 output.push_str(" self.offset += self.block_length as usize;\n");
514 output.push_str(" self.index += 1;\n");
515 output.push_str(" Some(entry)\n");
516 output.push_str(" }\n\n");
517
518 output.push_str(" fn size_hint(&self) -> (usize, Option<usize>) {\n");
519 output.push_str(" let remaining = (self.count - self.index) as usize;\n");
520 output.push_str(" (remaining, Some(remaining))\n");
521 output.push_str(" }\n");
522 output.push_str("}\n\n");
523
524 output.push_str(&format!(
525 "impl<'a> ExactSizeIterator for {}<'a> {{}}\n\n",
526 decoder_name
527 ));
528
529 output.push_str(&self.generate_entry_decoder(group));
531
532 for nested in &group.nested_groups {
534 output.push_str(&self.generate_group_decoder(nested));
535 }
536
537 output
538 }
539
540 fn generate_entry_decoder(&self, group: &ResolvedGroup) -> String {
542 let mut output = String::new();
543 let entry_name = group.entry_decoder_name();
544
545 output.push_str(&format!("/// {} Entry Decoder.\n", group.name));
546 output.push_str("#[derive(Debug, Clone, Copy)]\n");
547 output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
548 output.push_str(" buffer: &'a [u8],\n");
549 output.push_str(" offset: usize,\n");
550 output.push_str("}\n\n");
551
552 output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
553 output.push_str(" fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
554 output.push_str(" Self { buffer, offset }\n");
555 output.push_str(" }\n\n");
556
557 for field in &group.fields {
559 output.push_str(&self.generate_field_getter(field));
560 }
561
562 output.push_str("}\n\n");
563
564 output
565 }
566
567 fn generate_group_encoder(&self, group: &ResolvedGroup) -> String {
569 let mut output = String::new();
570 let encoder_name = group.encoder_name();
571 let entry_name = group.entry_encoder_name();
572
573 let effective_block_length = if group.block_length > 0 {
575 group.block_length
576 } else {
577 group
578 .fields
579 .iter()
580 .map(|f| f.offset + f.encoded_length)
581 .max()
582 .unwrap_or(0) as u16
583 };
584
585 output.push_str(&format!("/// {} Group Encoder.\n", group.name));
587 output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
588 output.push_str(" buffer: &'a mut [u8],\n");
589 output.push_str(" count: u16,\n");
590 output.push_str(" index: u16,\n");
591 output.push_str(" offset: usize,\n");
592 output.push_str("}\n\n");
593
594 output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
596 output.push_str(&format!(
597 " /// Block length of each entry.\n\
598 pub const BLOCK_LENGTH: u16 = {};\n\n",
599 effective_block_length
600 ));
601
602 output
604 .push_str(" /// Wraps a buffer at the group header position, writing the header.\n");
605 output.push_str(" ///\n");
606 output.push_str(" /// # Arguments\n");
607 output.push_str(" /// * `buffer` - Mutable buffer to write to\n");
608 output.push_str(" /// * `offset` - Offset of the group header\n");
609 output.push_str(" /// * `count` - Number of entries to encode\n");
610 output.push_str(
611 " pub fn wrap(buffer: &'a mut [u8], offset: usize, count: u16) -> Self {\n",
612 );
613 output.push_str(" let header = GroupHeader::new(Self::BLOCK_LENGTH, count);\n");
614 output.push_str(" header.encode(buffer, offset);\n");
615 output.push_str(" Self {\n");
616 output.push_str(" buffer,\n");
617 output.push_str(" count,\n");
618 output.push_str(" index: 0,\n");
619 output.push_str(" offset: offset + GroupHeader::ENCODED_LENGTH,\n");
620 output.push_str(" }\n");
621 output.push_str(" }\n\n");
622
623 output.push_str(
625 " /// Returns the next entry encoder, or `None` if all entries are written.\n",
626 );
627 output.push_str(&format!(
628 " pub fn next_entry(&mut self) -> Option<{}<'_>> {{\n",
629 entry_name
630 ));
631 output.push_str(" if self.index >= self.count {\n");
632 output.push_str(" return None;\n");
633 output.push_str(" }\n");
634 output.push_str(" let offset = self.offset;\n");
635 output.push_str(" self.offset += Self::BLOCK_LENGTH as usize;\n");
636 output.push_str(" self.index += 1;\n");
637 output.push_str(&format!(
638 " Some({}::wrap(&mut *self.buffer, offset))\n",
639 entry_name
640 ));
641 output.push_str(" }\n\n");
642
643 output.push_str(
645 " /// Returns the total encoded length of this group (header + all entries).\n",
646 );
647 output.push_str(" #[must_use]\n");
648 output.push_str(" pub const fn encoded_length(&self) -> usize {\n");
649 output.push_str(" GroupHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize * self.count as usize\n");
650 output.push_str(" }\n");
651 output.push_str("}\n\n");
652
653 output.push_str(&self.generate_entry_encoder(group));
655
656 for nested in &group.nested_groups {
658 output.push_str(&self.generate_group_encoder(nested));
659 }
660
661 output
662 }
663
664 fn generate_entry_encoder(&self, group: &ResolvedGroup) -> String {
666 let mut output = String::new();
667 let entry_name = group.entry_encoder_name();
668
669 output.push_str(&format!("/// {} Entry Encoder.\n", group.name));
670 output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
671 output.push_str(" buffer: &'a mut [u8],\n");
672 output.push_str(" offset: usize,\n");
673 output.push_str("}\n\n");
674
675 output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
676 output.push_str(" pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
677 output.push_str(" Self { buffer, offset }\n");
678 output.push_str(" }\n\n");
679
680 for field in &group.fields {
682 output.push_str(&self.generate_entry_field_setter(field));
683 }
684
685 output.push_str("}\n\n");
686
687 output
688 }
689
690 fn generate_entry_field_setter(&self, field: &ResolvedField) -> String {
696 let mut output = String::new();
697 let field_offset = field.offset;
698
699 output.push_str(&format!(
700 " /// Set field: {} (id={}, offset={}).\n",
701 field.name, field.id, field.offset
702 ));
703 output.push_str(" #[inline(always)]\n");
704
705 if field.is_array {
706 let len = field.array_length.unwrap_or(field.encoded_length);
707
708 output.push_str(&format!(
709 " pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
710 field.setter_name
711 ));
712 output.push_str(&format!(
713 " let copy_len = value.len().min({});\n",
714 len
715 ));
716 output.push_str(&format!(
717 " self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
718 field_offset, field_offset
719 ));
720 output.push_str(" .copy_from_slice(&value[..copy_len]);\n");
721 output.push_str(&format!(" if copy_len < {} {{\n", len));
722 output.push_str(&format!(
723 " self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
724 field_offset, field_offset, len
725 ));
726 output.push_str(" }\n");
727 output.push_str(" self\n");
728 output.push_str(" }\n\n");
729 } else {
730 let rust_type = &field.rust_type;
731 let resolved_type = self.ir.get_type(&field.type_name);
732
733 match resolved_type.map(|t| &t.kind) {
734 Some(TypeKind::Enum { encoding, .. }) => {
735 let write_method = get_write_method(Some(*encoding));
736 let prim_type = encoding.rust_type();
737 output.push_str(&format!(
738 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
739 field.setter_name, rust_type
740 ));
741 output.push_str(&format!(
742 " self.buffer.{}(self.offset + {}, {}::from(value));\n",
743 write_method, field_offset, prim_type
744 ));
745 output.push_str(" self\n");
746 output.push_str(" }\n\n");
747 }
748 Some(TypeKind::Set { encoding, .. }) => {
749 let write_method = get_write_method(Some(*encoding));
750 output.push_str(&format!(
751 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
752 field.setter_name, rust_type
753 ));
754 output.push_str(&format!(
755 " self.buffer.{}(self.offset + {}, value.raw());\n",
756 write_method, field_offset
757 ));
758 output.push_str(" self\n");
759 output.push_str(" }\n\n");
760 }
761 Some(TypeKind::Composite { .. }) => {
762 output.push_str(&format!(
763 " pub fn {}(&mut self) -> {}Encoder<'_> {{\n",
764 field.setter_name, rust_type
765 ));
766 output.push_str(&format!(
767 " {}Encoder::wrap(self.buffer, self.offset + {})\n",
768 rust_type, field_offset
769 ));
770 output.push_str(" }\n\n");
771 }
772 _ => {
773 let write_method = get_write_method(field.primitive_type);
774 output.push_str(&format!(
775 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
776 field.setter_name, rust_type
777 ));
778 output.push_str(&format!(
779 " self.buffer.{}(self.offset + {}, value);\n",
780 write_method, field_offset
781 ));
782 output.push_str(" self\n");
783 output.push_str(" }\n\n");
784 }
785 }
786 }
787
788 output
789 }
790
791 fn generate_group_encoder_accessor(
793 &self,
794 group: &ResolvedGroup,
795 offset: usize,
796 msg_name: &str,
797 ) -> String {
798 let mut output = String::new();
799 let qualified = format!("{}::{}", to_snake_case(msg_name), group.encoder_name());
800
801 output.push_str(&format!(
802 " /// Begin encoding the {} repeating group.\n",
803 group.name
804 ));
805 output.push_str(&format!(
806 " pub fn {}_count(&mut self, count: u16) -> {}<'_> {{\n",
807 to_snake_case(&group.name),
808 qualified
809 ));
810 output.push_str(&format!(
811 " {}::wrap(&mut *self.buffer, self.offset + MessageHeader::ENCODED_LENGTH + {}, count)\n",
812 qualified, offset
813 ));
814 output.push_str(" }\n\n");
815
816 output
817 }
818}
819
820fn get_read_method(prim: Option<PrimitiveType>) -> &'static str {
822 match prim {
823 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "get_u8",
824 Some(PrimitiveType::Int8) => "get_i8",
825 Some(PrimitiveType::Uint16) => "get_u16_le",
826 Some(PrimitiveType::Int16) => "get_i16_le",
827 Some(PrimitiveType::Uint32) => "get_u32_le",
828 Some(PrimitiveType::Int32) => "get_i32_le",
829 Some(PrimitiveType::Uint64) => "get_u64_le",
830 Some(PrimitiveType::Int64) => "get_i64_le",
831 Some(PrimitiveType::Float) => "get_f32_le",
832 Some(PrimitiveType::Double) => "get_f64_le",
833 None => "get_u64_le",
834 }
835}
836
837fn get_write_method(prim: Option<PrimitiveType>) -> &'static str {
839 match prim {
840 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "put_u8",
841 Some(PrimitiveType::Int8) => "put_i8",
842 Some(PrimitiveType::Uint16) => "put_u16_le",
843 Some(PrimitiveType::Int16) => "put_i16_le",
844 Some(PrimitiveType::Uint32) => "put_u32_le",
845 Some(PrimitiveType::Int32) => "put_i32_le",
846 Some(PrimitiveType::Uint64) => "put_u64_le",
847 Some(PrimitiveType::Int64) => "put_i64_le",
848 Some(PrimitiveType::Float) => "put_f32_le",
849 Some(PrimitiveType::Double) => "put_f64_le",
850 None => "put_u64_le",
851 }
852}
853
854#[cfg(test)]
855mod tests {
856 use super::*;
857 use ironsbe_schema::{SchemaIr, parse_schema};
858
859 fn schema_with_shared_group_name() -> String {
860 r#"<?xml version="1.0" encoding="UTF-8"?>
861<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
862 package="test" id="1" version="1" byteOrder="littleEndian">
863 <types>
864 <type name="uint64" primitiveType="uint64"/>
865 </types>
866 <sbe:message name="CreateRfqResponse" id="21" blockLength="8">
867 <field name="value" id="1" type="uint64" offset="0"/>
868 <group name="quotes" id="100" dimensionType="groupSizeEncoding" blockLength="8">
869 <field name="price" id="200" type="uint64" offset="0"/>
870 </group>
871 </sbe:message>
872 <sbe:message name="GetRfqResponse" id="23" blockLength="8">
873 <field name="value" id="1" type="uint64" offset="0"/>
874 <group name="quotes" id="100" dimensionType="groupSizeEncoding" blockLength="8">
875 <field name="price" id="200" type="uint64" offset="0"/>
876 </group>
877 </sbe:message>
878</sbe:messageSchema>"#
879 .to_string()
880 }
881
882 fn schema_with_group_no_offsets() -> String {
883 r#"<?xml version="1.0" encoding="UTF-8"?>
884<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
885 package="test" id="1" version="1" byteOrder="littleEndian">
886 <types>
887 <type name="uint64" primitiveType="uint64"/>
888 <type name="uint32" primitiveType="uint32"/>
889 </types>
890 <sbe:message name="ListOrders" id="19" blockLength="0">
891 <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="20">
892 <field name="orderId" id="1" type="uint64" offset="0"/>
893 <field name="instrumentId" id="2" type="uint32"/>
894 <field name="quantity" id="3" type="uint64"/>
895 </group>
896 </sbe:message>
897</sbe:messageSchema>"#
898 .to_string()
899 }
900
901 fn schema_with_group_explicit_offsets() -> String {
902 r#"<?xml version="1.0" encoding="UTF-8"?>
903<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
904 package="test" id="1" version="1" byteOrder="littleEndian">
905 <types>
906 <type name="uint64" primitiveType="uint64"/>
907 <type name="uint32" primitiveType="uint32"/>
908 </types>
909 <sbe:message name="ListOrders" id="19" blockLength="0">
910 <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="20">
911 <field name="orderId" id="1" type="uint64" offset="0"/>
912 <field name="instrumentId" id="2" type="uint32" offset="8"/>
913 <field name="quantity" id="3" type="uint64" offset="12"/>
914 </group>
915 </sbe:message>
916</sbe:messageSchema>"#
917 .to_string()
918 }
919
920 #[test]
921 fn test_duplicate_group_name_generates_scoped_modules() {
922 let xml = schema_with_shared_group_name();
923 let schema = parse_schema(&xml).expect("Failed to parse schema");
924 let ir = SchemaIr::from_schema(&schema);
925 let msg_gen = MessageGenerator::new(&ir);
926 let code = msg_gen.generate();
927
928 assert!(
929 code.contains("pub mod create_rfq_response {"),
930 "expected module for CreateRfqResponse groups"
931 );
932 assert!(
933 code.contains("pub mod get_rfq_response {"),
934 "expected module for GetRfqResponse groups"
935 );
936
937 let occurrences = code.matches("pub struct QuotesGroupDecoder").count();
938 assert_eq!(
939 occurrences, 2,
940 "expected one QuotesGroupDecoder per message module, got {occurrences}"
941 );
942 }
943
944 #[test]
945 fn test_group_accessor_uses_qualified_path() {
946 let xml = schema_with_shared_group_name();
947 let schema = parse_schema(&xml).expect("Failed to parse schema");
948 let ir = SchemaIr::from_schema(&schema);
949 let msg_gen = MessageGenerator::new(&ir);
950 let code = msg_gen.generate();
951
952 assert!(
953 code.contains("create_rfq_response::QuotesGroupDecoder"),
954 "accessor in CreateRfqResponse must reference module-qualified type"
955 );
956 assert!(
957 code.contains("get_rfq_response::QuotesGroupDecoder"),
958 "accessor in GetRfqResponse must reference module-qualified type"
959 );
960 }
961
962 #[test]
963 fn test_entry_decoder_field_offsets_auto_computed() {
964 let xml = schema_with_group_no_offsets();
965 let schema = parse_schema(&xml).expect("Failed to parse schema");
966 let ir = SchemaIr::from_schema(&schema);
967 let msg_gen = MessageGenerator::new(&ir);
968 let code = msg_gen.generate();
969
970 assert!(
972 code.contains("self.offset + 0)"),
973 "orderId should be at offset 0"
974 );
975 assert!(
977 code.contains("self.offset + 8)"),
978 "instrumentId should be at offset 8, not 0"
979 );
980 assert!(
982 code.contains("self.offset + 12)"),
983 "quantity should be at offset 12, not 0"
984 );
985 }
986
987 #[test]
988 fn test_entry_decoder_field_offsets_explicit() {
989 let xml = schema_with_group_explicit_offsets();
990 let schema = parse_schema(&xml).expect("Failed to parse schema");
991 let ir = SchemaIr::from_schema(&schema);
992 let msg_gen = MessageGenerator::new(&ir);
993 let code = msg_gen.generate();
994
995 assert!(
996 code.contains("self.offset + 8)"),
997 "instrumentId should be at explicit offset 8"
998 );
999 assert!(
1000 code.contains("self.offset + 12)"),
1001 "quantity should be at explicit offset 12"
1002 );
1003 }
1004
1005 #[test]
1006 fn test_group_encoder_emitted() {
1007 let xml = schema_with_group_no_offsets();
1008 let schema = parse_schema(&xml).expect("Failed to parse schema");
1009 let ir = SchemaIr::from_schema(&schema);
1010 let msg_gen = MessageGenerator::new(&ir);
1011 let code = msg_gen.generate();
1012
1013 assert!(
1014 code.contains("pub struct OrdersGroupEncoder"),
1015 "expected OrdersGroupEncoder struct"
1016 );
1017 assert!(
1018 code.contains("pub struct OrdersEntryEncoder"),
1019 "expected OrdersEntryEncoder struct"
1020 );
1021 }
1022
1023 #[test]
1024 fn test_group_encoder_has_next_entry() {
1025 let xml = schema_with_group_no_offsets();
1026 let schema = parse_schema(&xml).expect("Failed to parse schema");
1027 let ir = SchemaIr::from_schema(&schema);
1028 let msg_gen = MessageGenerator::new(&ir);
1029 let code = msg_gen.generate();
1030
1031 assert!(
1032 code.contains("fn next_entry(&mut self)"),
1033 "expected next_entry method on group encoder"
1034 );
1035 }
1036
1037 #[test]
1038 fn test_entry_encoder_has_field_setters() {
1039 let xml = schema_with_group_no_offsets();
1040 let schema = parse_schema(&xml).expect("Failed to parse schema");
1041 let ir = SchemaIr::from_schema(&schema);
1042 let msg_gen = MessageGenerator::new(&ir);
1043 let code = msg_gen.generate();
1044
1045 assert!(
1046 code.contains("fn set_order_id(&mut self, value: u64)"),
1047 "expected set_order_id setter"
1048 );
1049 assert!(
1050 code.contains("fn set_instrument_id(&mut self, value: u32)"),
1051 "expected set_instrument_id setter"
1052 );
1053 assert!(
1054 code.contains("fn set_quantity(&mut self, value: u64)"),
1055 "expected set_quantity setter"
1056 );
1057 }
1058
1059 #[test]
1060 fn test_parent_encoder_has_group_accessor() {
1061 let xml = schema_with_group_no_offsets();
1062 let schema = parse_schema(&xml).expect("Failed to parse schema");
1063 let ir = SchemaIr::from_schema(&schema);
1064 let msg_gen = MessageGenerator::new(&ir);
1065 let code = msg_gen.generate();
1066
1067 assert!(
1068 code.contains("fn orders_count(&mut self, count: u16)"),
1069 "expected orders_count accessor on parent encoder"
1070 );
1071 }
1072
1073 #[test]
1074 fn test_roundtrip_group_codegen_structure() {
1075 let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
1076<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
1077 package="test" id="1" version="1" byteOrder="littleEndian">
1078 <types>
1079 <type name="uint64" primitiveType="uint64"/>
1080 <type name="uint32" primitiveType="uint32"/>
1081 <type name="uint8" primitiveType="uint8"/>
1082 </types>
1083 <sbe:message name="ListOrders" id="19" blockLength="8">
1084 <field name="requestId" id="1" type="uint64" offset="0"/>
1085 <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="29">
1086 <field name="orderId" id="10" type="uint64" offset="0"/>
1087 <field name="instrumentId" id="11" type="uint32"/>
1088 <field name="price" id="12" type="uint64"/>
1089 <field name="quantity" id="13" type="uint64"/>
1090 <field name="side" id="14" type="uint8"/>
1091 </group>
1092 </sbe:message>
1093</sbe:messageSchema>"#;
1094
1095 let schema = parse_schema(xml).expect("Failed to parse schema");
1096 let ir = SchemaIr::from_schema(&schema);
1097 let msg_gen = MessageGenerator::new(&ir);
1098 let code = msg_gen.generate();
1099
1100 let decoder_pos = code
1102 .find("impl<'a> OrdersEntryDecoder<'a>")
1103 .expect("entry decoder impl");
1104 let decoder_section = &code[decoder_pos..];
1105 assert!(decoder_section.contains("self.offset + 0)"));
1107 assert!(decoder_section.contains("self.offset + 8)"));
1108 assert!(decoder_section.contains("self.offset + 12)"));
1109 assert!(decoder_section.contains("self.offset + 20)"));
1110 assert!(decoder_section.contains("self.offset + 28)"));
1111
1112 let encoder_pos = code
1114 .find("impl<'a> OrdersEntryEncoder<'a>")
1115 .expect("entry encoder impl");
1116 let encoder_section = &code[encoder_pos..];
1117 assert!(encoder_section.contains("self.offset + 0,"));
1119 assert!(encoder_section.contains("self.offset + 8,"));
1120 assert!(encoder_section.contains("self.offset + 12,"));
1121 assert!(encoder_section.contains("self.offset + 20,"));
1122 assert!(encoder_section.contains("self.offset + 28,"));
1123
1124 assert!(
1126 code.contains("BLOCK_LENGTH: u16 = 29"),
1127 "group encoder BLOCK_LENGTH"
1128 );
1129 assert!(
1130 code.contains("fn orders_count(&mut self, count: u16)"),
1131 "parent encoder group accessor"
1132 );
1133 assert!(
1134 code.contains("list_orders::OrdersGroupEncoder::wrap(&mut *self.buffer"),
1135 "parent encoder delegates to module-qualified group encoder"
1136 );
1137
1138 assert!(
1140 code.contains("list_orders::OrdersGroupDecoder"),
1141 "parent decoder uses module-qualified group decoder"
1142 );
1143 }
1144
1145 #[test]
1146 fn test_entry_encoder_setter_offsets_correct() {
1147 let xml = schema_with_group_no_offsets();
1148 let schema = parse_schema(&xml).expect("Failed to parse schema");
1149 let ir = SchemaIr::from_schema(&schema);
1150 let msg_gen = MessageGenerator::new(&ir);
1151 let code = msg_gen.generate();
1152
1153 let entry_encoder_start = code
1155 .find("impl<'a> OrdersEntryEncoder<'a>")
1156 .expect("EntryEncoder impl not found");
1157 let entry_code = &code[entry_encoder_start..];
1158
1159 assert!(
1161 entry_code.contains("self.offset + 0,"),
1162 "set_order_id should write at offset 0"
1163 );
1164 assert!(
1166 entry_code.contains("self.offset + 8,"),
1167 "set_instrument_id should write at offset 8"
1168 );
1169 assert!(
1171 entry_code.contains("self.offset + 12,"),
1172 "set_quantity should write at offset 12"
1173 );
1174 }
1175
1176 fn schema_with_group_zero_block_length() -> String {
1177 r#"<?xml version="1.0" encoding="UTF-8"?>
1178<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
1179 package="test" id="1" version="1" byteOrder="littleEndian">
1180 <types>
1181 <type name="uint64" primitiveType="uint64"/>
1182 <type name="uint32" primitiveType="uint32"/>
1183 </types>
1184 <sbe:message name="ListOrders" id="19" blockLength="0">
1185 <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="0">
1186 <field name="orderId" id="1" type="uint64" offset="0"/>
1187 <field name="instrumentId" id="2" type="uint32"/>
1188 <field name="quantity" id="3" type="uint64"/>
1189 </group>
1190 </sbe:message>
1191</sbe:messageSchema>"#
1192 .to_string()
1193 }
1194
1195 #[test]
1196 fn test_group_encoder_block_length_from_xml() {
1197 let xml = schema_with_group_no_offsets();
1198 let schema = parse_schema(&xml).expect("Failed to parse schema");
1199 let ir = SchemaIr::from_schema(&schema);
1200 let msg_gen = MessageGenerator::new(&ir);
1201 let code = msg_gen.generate();
1202
1203 assert!(
1204 code.contains("BLOCK_LENGTH: u16 = 20"),
1205 "BLOCK_LENGTH should use the explicit XML blockLength=20"
1206 );
1207 }
1208
1209 #[test]
1210 fn test_group_encoder_block_length_computed() {
1211 let xml = schema_with_group_zero_block_length();
1212 let schema = parse_schema(&xml).expect("Failed to parse schema");
1213 let ir = SchemaIr::from_schema(&schema);
1214 let msg_gen = MessageGenerator::new(&ir);
1215 let code = msg_gen.generate();
1216
1217 assert!(
1219 code.contains("BLOCK_LENGTH: u16 = 20"),
1220 "BLOCK_LENGTH should be auto-computed as 20 when XML blockLength=0"
1221 );
1222 }
1223
1224 #[test]
1225 fn test_entry_encoder_wrap_is_pub() {
1226 let xml = schema_with_group_no_offsets();
1227 let schema = parse_schema(&xml).expect("Failed to parse schema");
1228 let ir = SchemaIr::from_schema(&schema);
1229 let msg_gen = MessageGenerator::new(&ir);
1230 let code = msg_gen.generate();
1231
1232 let entry_pos = code
1233 .find("impl<'a> OrdersEntryEncoder<'a>")
1234 .expect("EntryEncoder impl not found");
1235 let entry_section = &code[entry_pos..];
1236
1237 assert!(
1238 entry_section.contains("pub fn wrap("),
1239 "EntryEncoder::wrap should be pub"
1240 );
1241 }
1242
1243 #[test]
1244 fn test_roundtrip_multi_entry_codegen() {
1245 let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
1246<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
1247 package="test" id="1" version="1" byteOrder="littleEndian">
1248 <types>
1249 <type name="uint64" primitiveType="uint64"/>
1250 <type name="uint32" primitiveType="uint32"/>
1251 </types>
1252 <sbe:message name="ListOrders" id="19" blockLength="8">
1253 <field name="requestId" id="1" type="uint64" offset="0"/>
1254 <group name="orders" id="100" dimensionType="groupSizeEncoding" blockLength="0">
1255 <field name="orderId" id="10" type="uint64" offset="0"/>
1256 <field name="instrumentId" id="11" type="uint32"/>
1257 <field name="quantity" id="12" type="uint64"/>
1258 </group>
1259 </sbe:message>
1260</sbe:messageSchema>"#;
1261
1262 let schema = parse_schema(xml).expect("Failed to parse schema");
1263 let ir = SchemaIr::from_schema(&schema);
1264 let msg_gen = MessageGenerator::new(&ir);
1265 let code = msg_gen.generate();
1266
1267 assert!(
1269 code.contains("BLOCK_LENGTH: u16 = 20"),
1270 "group encoder BLOCK_LENGTH should be 20, not 0"
1271 );
1272
1273 assert!(
1275 code.contains("self.offset += Self::BLOCK_LENGTH as usize"),
1276 "next_entry should advance offset by BLOCK_LENGTH"
1277 );
1278
1279 assert!(
1281 code.contains(
1282 "GroupHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize * self.count as usize"
1283 ),
1284 "encoded_length should use BLOCK_LENGTH * count"
1285 );
1286
1287 assert!(
1289 code.contains("GroupHeader::new(Self::BLOCK_LENGTH, count)"),
1290 "group header should be written with BLOCK_LENGTH"
1291 );
1292
1293 assert!(
1295 code.contains("fn orders_count(&mut self, count: u16)"),
1296 "parent encoder should have group accessor"
1297 );
1298
1299 let entry_pos = code
1301 .find("impl<'a> OrdersEntryEncoder<'a>")
1302 .expect("EntryEncoder impl not found");
1303 let entry_section = &code[entry_pos..];
1304 assert!(
1305 entry_section.contains("pub fn wrap("),
1306 "EntryEncoder::wrap should be pub for external consumers"
1307 );
1308 }
1309}