1use ironsbe_schema::ir::{
4 ResolvedField, ResolvedGroup, ResolvedMessage, SchemaIr, TypeKind, to_snake_case,
5};
6use ironsbe_schema::types::PrimitiveType;
7
8pub struct MessageGenerator<'a> {
10 ir: &'a SchemaIr,
11}
12
13impl<'a> MessageGenerator<'a> {
14 #[must_use]
16 pub fn new(ir: &'a SchemaIr) -> Self {
17 Self { ir }
18 }
19
20 #[must_use]
22 pub fn generate(&self) -> String {
23 let mut output = String::new();
24
25 for msg in &self.ir.messages {
26 output.push_str(&self.generate_decoder(msg));
27 output.push_str(&self.generate_encoder(msg));
28
29 if !msg.groups.is_empty() {
31 let mod_name = to_snake_case(&msg.name);
32 output.push_str(&format!("/// Types for {} repeating groups.\n", msg.name));
33 output.push_str(&format!("pub mod {} {{\n", mod_name));
34 output.push_str(" use super::*;\n\n");
35 for group in &msg.groups {
36 output.push_str(&self.generate_group_decoder(group));
37 }
38 output.push_str("}\n\n");
39 }
40 }
41
42 output
43 }
44
45 fn generate_decoder(&self, msg: &ResolvedMessage) -> String {
47 let mut output = String::new();
48 let decoder_name = msg.decoder_name();
49
50 output.push_str(&format!("/// {} Decoder (zero-copy).\n", msg.name));
52 output.push_str("#[derive(Debug, Clone, Copy)]\n");
53 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
54 output.push_str(" buffer: &'a [u8],\n");
55 output.push_str(" offset: usize,\n");
56 output.push_str(" acting_version: u16,\n");
57 output.push_str("}\n\n");
58
59 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
61 output.push_str(&format!(
62 " /// Template ID for this message.\n\
63 pub const TEMPLATE_ID: u16 = {};\n",
64 msg.template_id
65 ));
66 output.push_str(&format!(
67 " /// Block length of the fixed portion.\n\
68 pub const BLOCK_LENGTH: u16 = {};\n\n",
69 msg.block_length
70 ));
71
72 output.push_str(" /// Wraps a buffer for zero-copy decoding.\n");
74 output.push_str(" ///\n");
75 output.push_str(" /// # Arguments\n");
76 output.push_str(" /// * `buffer` - Buffer containing the message\n");
77 output.push_str(
78 " /// * `offset` - Offset to the start of the root block (after header)\n",
79 );
80 output.push_str(" /// * `acting_version` - Schema version for compatibility\n");
81 output.push_str(" #[inline]\n");
82 output.push_str(" #[must_use]\n");
83 output.push_str(
84 " pub fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
85 );
86 output.push_str(" Self { buffer, offset, acting_version }\n");
87 output.push_str(" }\n\n");
88
89 for field in &msg.fields {
91 output.push_str(&self.generate_field_getter(field));
92 }
93
94 let mut group_offset = msg.block_length as usize;
96 for group in &msg.groups {
97 output.push_str(&self.generate_group_accessor(group, group_offset, &msg.name));
98 group_offset += 4; }
100
101 output.push_str("}\n\n");
102
103 output.push_str(&format!(
105 "impl<'a> SbeDecoder<'a> for {}<'a> {{\n",
106 decoder_name
107 ));
108 output.push_str(&format!(
109 " const TEMPLATE_ID: u16 = {};\n",
110 msg.template_id
111 ));
112 output.push_str(" const SCHEMA_ID: u16 = SCHEMA_ID;\n");
113 output.push_str(" const SCHEMA_VERSION: u16 = SCHEMA_VERSION;\n");
114 output.push_str(&format!(
115 " const BLOCK_LENGTH: u16 = {};\n\n",
116 msg.block_length
117 ));
118
119 output.push_str(
120 " fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
121 );
122 output.push_str(" Self::wrap(buffer, offset, acting_version)\n");
123 output.push_str(" }\n\n");
124
125 output.push_str(" fn encoded_length(&self) -> usize {\n");
126 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
127 output.push_str(" }\n");
128 output.push_str("}\n\n");
129
130 output
131 }
132
133 fn generate_field_getter(&self, field: &ResolvedField) -> String {
135 let mut output = String::new();
136
137 output.push_str(&format!(
138 " /// Field: {} (id={}, offset={}).\n",
139 field.name, field.id, field.offset
140 ));
141 output.push_str(" #[inline(always)]\n");
142 output.push_str(" #[must_use]\n");
143
144 if field.is_array {
145 let elem_type = field.primitive_type.map(|p| p.rust_type()).unwrap_or("u8");
147 let len = field.array_length.unwrap_or(1);
148
149 if elem_type == "u8" {
150 output.push_str(&format!(
152 " pub fn {}(&self) -> &'a [u8] {{\n",
153 field.getter_name
154 ));
155 output.push_str(&format!(
156 " &self.buffer[self.offset + {}..self.offset + {} + {}]\n",
157 field.offset, field.offset, len
158 ));
159 output.push_str(" }\n\n");
160
161 output.push_str(&format!(
163 " /// Field {} as string (trimmed).\n",
164 field.name
165 ));
166 output.push_str(" #[inline]\n");
167 output.push_str(" #[must_use]\n");
168 output.push_str(&format!(
169 " pub fn {}_as_str(&self) -> &'a str {{\n",
170 field.getter_name
171 ));
172 output.push_str(&format!(
173 " let bytes = &self.buffer[self.offset + {}..self.offset + {} + {}];\n",
174 field.offset, field.offset, len
175 ));
176 output.push_str(
177 " let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());\n",
178 );
179 output.push_str(" std::str::from_utf8(&bytes[..end]).unwrap_or(\"\")\n");
180 output.push_str(" }\n\n");
181 } else {
182 output.push_str(&format!(
184 " pub fn {}(&self) -> &'a [u8] {{\n",
185 field.getter_name
186 ));
187 output.push_str(&format!(
188 " &self.buffer[self.offset + {}..self.offset + {}]\n",
189 field.offset,
190 field.offset + field.encoded_length
191 ));
192 output.push_str(" }\n\n");
193 }
194 } else {
195 let rust_type = &field.rust_type;
197 let resolved_type = self.ir.get_type(&field.type_name);
198
199 match resolved_type.map(|t| &t.kind) {
200 Some(TypeKind::Enum { encoding, .. }) => {
201 let read_method = get_read_method(Some(*encoding));
203 output.push_str(&format!(
204 " pub fn {}(&self) -> {} {{\n",
205 field.getter_name, rust_type
206 ));
207 output.push_str(&format!(
208 " {}::from(self.buffer.{}(self.offset + {}))\n",
209 rust_type, read_method, field.offset
210 ));
211 output.push_str(" }\n\n");
212 }
213 Some(TypeKind::Set { encoding, .. }) => {
214 let read_method = get_read_method(Some(*encoding));
216 output.push_str(&format!(
217 " pub fn {}(&self) -> {} {{\n",
218 field.getter_name, rust_type
219 ));
220 output.push_str(&format!(
221 " {}::from_raw(self.buffer.{}(self.offset + {}))\n",
222 rust_type, read_method, field.offset
223 ));
224 output.push_str(" }\n\n");
225 }
226 Some(TypeKind::Composite { .. }) => {
227 output.push_str(&format!(
229 " pub fn {}(&self) -> {}<'a> {{\n",
230 field.getter_name, rust_type
231 ));
232 output.push_str(&format!(
233 " {}::wrap(self.buffer, self.offset + {})\n",
234 rust_type, field.offset
235 ));
236 output.push_str(" }\n\n");
237 }
238 _ => {
239 let read_method = get_read_method(field.primitive_type);
241 output.push_str(&format!(
242 " pub fn {}(&self) -> {} {{\n",
243 field.getter_name, rust_type
244 ));
245 output.push_str(&format!(
246 " self.buffer.{}(self.offset + {})\n",
247 read_method, field.offset
248 ));
249 output.push_str(" }\n\n");
250 }
251 }
252 }
253
254 output
255 }
256
257 fn generate_group_accessor(
259 &self,
260 group: &ResolvedGroup,
261 offset: usize,
262 msg_name: &str,
263 ) -> String {
264 let mut output = String::new();
265 let qualified = format!("{}::{}", to_snake_case(msg_name), group.decoder_name());
266
267 output.push_str(&format!(" /// Access {} repeating group.\n", group.name));
268 output.push_str(" #[inline]\n");
269 output.push_str(" #[must_use]\n");
270 output.push_str(&format!(
271 " pub fn {}(&self) -> {}<'a> {{\n",
272 to_snake_case(&group.name),
273 qualified
274 ));
275 output.push_str(&format!(
276 " {}::wrap(self.buffer, self.offset + {})\n",
277 qualified, offset
278 ));
279 output.push_str(" }\n\n");
280
281 output
282 }
283
284 fn generate_encoder(&self, msg: &ResolvedMessage) -> String {
286 let mut output = String::new();
287 let encoder_name = msg.encoder_name();
288
289 output.push_str(&format!("/// {} Encoder.\n", msg.name));
291 output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
292 output.push_str(" buffer: &'a mut [u8],\n");
293 output.push_str(" offset: usize,\n");
294 output.push_str("}\n\n");
295
296 output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
298 output.push_str(&format!(
299 " /// Template ID for this message.\n\
300 pub const TEMPLATE_ID: u16 = {};\n",
301 msg.template_id
302 ));
303 output.push_str(&format!(
304 " /// Block length of the fixed portion.\n\
305 pub const BLOCK_LENGTH: u16 = {};\n\n",
306 msg.block_length
307 ));
308
309 output.push_str(" /// Wraps a buffer for encoding, writing the header.\n");
311 output.push_str(" #[inline]\n");
312 output.push_str(" pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
313 output.push_str(" let mut encoder = Self { buffer, offset };\n");
314 output.push_str(" encoder.write_header();\n");
315 output.push_str(" encoder\n");
316 output.push_str(" }\n\n");
317
318 output.push_str(" fn write_header(&mut self) {\n");
320 output.push_str(" let header = MessageHeader {\n");
321 output.push_str(" block_length: Self::BLOCK_LENGTH,\n");
322 output.push_str(" template_id: Self::TEMPLATE_ID,\n");
323 output.push_str(" schema_id: SCHEMA_ID,\n");
324 output.push_str(" version: SCHEMA_VERSION,\n");
325 output.push_str(" };\n");
326 output.push_str(" header.encode(self.buffer, self.offset);\n");
327 output.push_str(" }\n\n");
328
329 output.push_str(" /// Returns the encoded length of the message.\n");
331 output.push_str(" #[must_use]\n");
332 output.push_str(" pub const fn encoded_length(&self) -> usize {\n");
333 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
334 output.push_str(" }\n\n");
335
336 for field in &msg.fields {
338 output.push_str(&self.generate_field_setter(field));
339 }
340
341 output.push_str("}\n\n");
342
343 output
344 }
345
346 fn generate_field_setter(&self, field: &ResolvedField) -> String {
348 let mut output = String::new();
349 let field_offset = format!("MessageHeader::ENCODED_LENGTH + {}", field.offset);
350
351 output.push_str(&format!(
352 " /// Set field: {} (id={}, offset={}).\n",
353 field.name, field.id, field.offset
354 ));
355 output.push_str(" #[inline(always)]\n");
356
357 if field.is_array {
358 let len = field.array_length.unwrap_or(field.encoded_length);
360
361 output.push_str(&format!(
362 " pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
363 field.setter_name
364 ));
365 output.push_str(&format!(
366 " let copy_len = value.len().min({});\n",
367 len
368 ));
369 output.push_str(&format!(
370 " self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
371 field_offset, field_offset
372 ));
373 output.push_str(" .copy_from_slice(&value[..copy_len]);\n");
374 output.push_str(&format!(" if copy_len < {} {{\n", len));
375 output.push_str(&format!(
376 " self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
377 field_offset, field_offset, len
378 ));
379 output.push_str(" }\n");
380 output.push_str(" self\n");
381 output.push_str(" }\n\n");
382 } else {
383 let rust_type = &field.rust_type;
385 let resolved_type = self.ir.get_type(&field.type_name);
386
387 match resolved_type.map(|t| &t.kind) {
388 Some(TypeKind::Enum { encoding, .. }) => {
389 let write_method = get_write_method(Some(*encoding));
391 let prim_type = encoding.rust_type();
392 output.push_str(&format!(
393 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
394 field.setter_name, rust_type
395 ));
396 output.push_str(&format!(
397 " self.buffer.{}(self.offset + {}, {}::from(value));\n",
398 write_method, field_offset, prim_type
399 ));
400 output.push_str(" self\n");
401 output.push_str(" }\n\n");
402 }
403 Some(TypeKind::Set { encoding, .. }) => {
404 let write_method = get_write_method(Some(*encoding));
406 output.push_str(&format!(
407 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
408 field.setter_name, rust_type
409 ));
410 output.push_str(&format!(
411 " self.buffer.{}(self.offset + {}, value.raw());\n",
412 write_method, field_offset
413 ));
414 output.push_str(" self\n");
415 output.push_str(" }\n\n");
416 }
417 Some(TypeKind::Composite { .. }) => {
418 output.push_str(&format!(
420 " pub fn {}(&mut self) -> {}Encoder<'_> {{\n",
421 field.setter_name, rust_type
422 ));
423 output.push_str(&format!(
424 " {}Encoder::wrap(self.buffer, self.offset + {})\n",
425 rust_type, field_offset
426 ));
427 output.push_str(" }\n\n");
428 }
429 _ => {
430 let write_method = get_write_method(field.primitive_type);
432 output.push_str(&format!(
433 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
434 field.setter_name, rust_type
435 ));
436 output.push_str(&format!(
437 " self.buffer.{}(self.offset + {}, value);\n",
438 write_method, field_offset
439 ));
440 output.push_str(" self\n");
441 output.push_str(" }\n\n");
442 }
443 }
444 }
445
446 output
447 }
448
449 fn generate_group_decoder(&self, group: &ResolvedGroup) -> String {
451 let mut output = String::new();
452 let decoder_name = group.decoder_name();
453 let entry_name = group.entry_decoder_name();
454
455 output.push_str(&format!("/// {} Group Decoder.\n", group.name));
457 output.push_str("#[derive(Debug, Clone, Copy)]\n");
458 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
459 output.push_str(" buffer: &'a [u8],\n");
460 output.push_str(" block_length: u16,\n");
461 output.push_str(" count: u16,\n");
462 output.push_str(" index: u16,\n");
463 output.push_str(" offset: usize,\n");
464 output.push_str("}\n\n");
465
466 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
468 output.push_str(" /// Wraps a buffer at the group header position.\n");
469 output.push_str(" #[must_use]\n");
470 output.push_str(" pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
471 output.push_str(" let header = GroupHeader::wrap(buffer, offset);\n");
472 output.push_str(" Self {\n");
473 output.push_str(" buffer,\n");
474 output.push_str(" block_length: header.block_length,\n");
475 output.push_str(" count: header.num_in_group,\n");
476 output.push_str(" index: 0,\n");
477 output.push_str(" offset: offset + GroupHeader::ENCODED_LENGTH,\n");
478 output.push_str(" }\n");
479 output.push_str(" }\n\n");
480
481 output.push_str(" /// Returns the number of entries in the group.\n");
482 output.push_str(" #[must_use]\n");
483 output.push_str(" pub const fn count(&self) -> u16 {\n");
484 output.push_str(" self.count\n");
485 output.push_str(" }\n\n");
486
487 output.push_str(" /// Returns true if the group is empty.\n");
488 output.push_str(" #[must_use]\n");
489 output.push_str(" pub const fn is_empty(&self) -> bool {\n");
490 output.push_str(" self.count == 0\n");
491 output.push_str(" }\n");
492 output.push_str("}\n\n");
493
494 output.push_str(&format!("impl<'a> Iterator for {}<'a> {{\n", decoder_name));
496 output.push_str(&format!(" type Item = {}<'a>;\n\n", entry_name));
497 output.push_str(" fn next(&mut self) -> Option<Self::Item> {\n");
498 output.push_str(" if self.index >= self.count {\n");
499 output.push_str(" return None;\n");
500 output.push_str(" }\n");
501 output.push_str(&format!(
502 " let entry = {}::wrap(self.buffer, self.offset);\n",
503 entry_name
504 ));
505 output.push_str(" self.offset += self.block_length as usize;\n");
506 output.push_str(" self.index += 1;\n");
507 output.push_str(" Some(entry)\n");
508 output.push_str(" }\n\n");
509
510 output.push_str(" fn size_hint(&self) -> (usize, Option<usize>) {\n");
511 output.push_str(" let remaining = (self.count - self.index) as usize;\n");
512 output.push_str(" (remaining, Some(remaining))\n");
513 output.push_str(" }\n");
514 output.push_str("}\n\n");
515
516 output.push_str(&format!(
517 "impl<'a> ExactSizeIterator for {}<'a> {{}}\n\n",
518 decoder_name
519 ));
520
521 output.push_str(&self.generate_entry_decoder(group));
523
524 for nested in &group.nested_groups {
526 output.push_str(&self.generate_group_decoder(nested));
527 }
528
529 output
530 }
531
532 fn generate_entry_decoder(&self, group: &ResolvedGroup) -> String {
534 let mut output = String::new();
535 let entry_name = group.entry_decoder_name();
536
537 output.push_str(&format!("/// {} Entry Decoder.\n", group.name));
538 output.push_str("#[derive(Debug, Clone, Copy)]\n");
539 output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
540 output.push_str(" buffer: &'a [u8],\n");
541 output.push_str(" offset: usize,\n");
542 output.push_str("}\n\n");
543
544 output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
545 output.push_str(" fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
546 output.push_str(" Self { buffer, offset }\n");
547 output.push_str(" }\n\n");
548
549 for field in &group.fields {
551 output.push_str(&self.generate_field_getter(field));
552 }
553
554 output.push_str("}\n\n");
555
556 output
557 }
558}
559
560fn get_read_method(prim: Option<PrimitiveType>) -> &'static str {
562 match prim {
563 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "get_u8",
564 Some(PrimitiveType::Int8) => "get_i8",
565 Some(PrimitiveType::Uint16) => "get_u16_le",
566 Some(PrimitiveType::Int16) => "get_i16_le",
567 Some(PrimitiveType::Uint32) => "get_u32_le",
568 Some(PrimitiveType::Int32) => "get_i32_le",
569 Some(PrimitiveType::Uint64) => "get_u64_le",
570 Some(PrimitiveType::Int64) => "get_i64_le",
571 Some(PrimitiveType::Float) => "get_f32_le",
572 Some(PrimitiveType::Double) => "get_f64_le",
573 None => "get_u64_le",
574 }
575}
576
577fn get_write_method(prim: Option<PrimitiveType>) -> &'static str {
579 match prim {
580 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "put_u8",
581 Some(PrimitiveType::Int8) => "put_i8",
582 Some(PrimitiveType::Uint16) => "put_u16_le",
583 Some(PrimitiveType::Int16) => "put_i16_le",
584 Some(PrimitiveType::Uint32) => "put_u32_le",
585 Some(PrimitiveType::Int32) => "put_i32_le",
586 Some(PrimitiveType::Uint64) => "put_u64_le",
587 Some(PrimitiveType::Int64) => "put_i64_le",
588 Some(PrimitiveType::Float) => "put_f32_le",
589 Some(PrimitiveType::Double) => "put_f64_le",
590 None => "put_u64_le",
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use ironsbe_schema::{SchemaIr, parse_schema};

    /// Schema with two messages that both declare a `quotes` repeating group,
    /// used to check that group types are scoped per message.
    fn schema_with_shared_group_name() -> String {
        String::from(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
 package="test" id="1" version="1" byteOrder="littleEndian">
 <types>
 <type name="uint64" primitiveType="uint64"/>
 </types>
 <sbe:message name="CreateRfqResponse" id="21" blockLength="8">
 <field name="value" id="1" type="uint64" offset="0"/>
 <group name="quotes" id="100" dimensionType="groupSizeEncoding" blockLength="8">
 <field name="price" id="200" type="uint64" offset="0"/>
 </group>
 </sbe:message>
 <sbe:message name="GetRfqResponse" id="23" blockLength="8">
 <field name="value" id="1" type="uint64" offset="0"/>
 <group name="quotes" id="100" dimensionType="groupSizeEncoding" blockLength="8">
 <field name="price" id="200" type="uint64" offset="0"/>
 </group>
 </sbe:message>
</sbe:messageSchema>"#,
        )
    }

    /// Parses `xml`, builds the IR, and returns the generated message code.
    fn generate_code(xml: &str) -> String {
        let schema = parse_schema(xml).expect("Failed to parse schema");
        let ir = SchemaIr::from_schema(&schema);
        MessageGenerator::new(&ir).generate()
    }

    #[test]
    fn test_duplicate_group_name_generates_scoped_modules() {
        let code = generate_code(&schema_with_shared_group_name());

        assert!(
            code.contains("pub mod create_rfq_response {"),
            "expected module for CreateRfqResponse groups"
        );
        assert!(
            code.contains("pub mod get_rfq_response {"),
            "expected module for GetRfqResponse groups"
        );

        // One group decoder per message module, not a single shared one.
        let occurrences = code.matches("pub struct QuotesGroupDecoder").count();
        assert_eq!(
            occurrences, 2,
            "expected one QuotesGroupDecoder per message module, got {occurrences}"
        );
    }

    #[test]
    fn test_group_accessor_uses_qualified_path() {
        let code = generate_code(&schema_with_shared_group_name());

        assert!(
            code.contains("create_rfq_response::QuotesGroupDecoder"),
            "accessor in CreateRfqResponse must reference module-qualified type"
        );
        assert!(
            code.contains("get_rfq_response::QuotesGroupDecoder"),
            "accessor in GetRfqResponse must reference module-qualified type"
        );
    }
}