1use ironsbe_schema::ir::{
4 ResolvedField, ResolvedGroup, ResolvedMessage, SchemaIr, TypeKind, to_snake_case,
5};
6use ironsbe_schema::types::PrimitiveType;
7
8pub struct MessageGenerator<'a> {
10 ir: &'a SchemaIr,
11}
12
13impl<'a> MessageGenerator<'a> {
14 #[must_use]
16 pub fn new(ir: &'a SchemaIr) -> Self {
17 Self { ir }
18 }
19
20 #[must_use]
22 pub fn generate(&self) -> String {
23 let mut output = String::new();
24
25 for msg in &self.ir.messages {
26 output.push_str(&self.generate_decoder(msg));
27 output.push_str(&self.generate_encoder(msg));
28
29 for group in &msg.groups {
31 output.push_str(&self.generate_group_decoder(group));
32 }
33 }
34
35 output
36 }
37
38 fn generate_decoder(&self, msg: &ResolvedMessage) -> String {
40 let mut output = String::new();
41 let decoder_name = msg.decoder_name();
42
43 output.push_str(&format!("/// {} Decoder (zero-copy).\n", msg.name));
45 output.push_str("#[derive(Debug, Clone, Copy)]\n");
46 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
47 output.push_str(" buffer: &'a [u8],\n");
48 output.push_str(" offset: usize,\n");
49 output.push_str(" acting_version: u16,\n");
50 output.push_str("}\n\n");
51
52 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
54 output.push_str(&format!(
55 " /// Template ID for this message.\n\
56 pub const TEMPLATE_ID: u16 = {};\n",
57 msg.template_id
58 ));
59 output.push_str(&format!(
60 " /// Block length of the fixed portion.\n\
61 pub const BLOCK_LENGTH: u16 = {};\n\n",
62 msg.block_length
63 ));
64
65 output.push_str(" /// Wraps a buffer for zero-copy decoding.\n");
67 output.push_str(" ///\n");
68 output.push_str(" /// # Arguments\n");
69 output.push_str(" /// * `buffer` - Buffer containing the message\n");
70 output.push_str(
71 " /// * `offset` - Offset to the start of the root block (after header)\n",
72 );
73 output.push_str(" /// * `acting_version` - Schema version for compatibility\n");
74 output.push_str(" #[inline]\n");
75 output.push_str(" #[must_use]\n");
76 output.push_str(
77 " pub fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
78 );
79 output.push_str(" Self { buffer, offset, acting_version }\n");
80 output.push_str(" }\n\n");
81
82 for field in &msg.fields {
84 output.push_str(&self.generate_field_getter(field));
85 }
86
87 let mut group_offset = msg.block_length as usize;
89 for group in &msg.groups {
90 output.push_str(&self.generate_group_accessor(group, group_offset));
91 group_offset += 4; }
93
94 output.push_str("}\n\n");
95
96 output.push_str(&format!(
98 "impl<'a> SbeDecoder<'a> for {}<'a> {{\n",
99 decoder_name
100 ));
101 output.push_str(&format!(
102 " const TEMPLATE_ID: u16 = {};\n",
103 msg.template_id
104 ));
105 output.push_str(" const SCHEMA_ID: u16 = SCHEMA_ID;\n");
106 output.push_str(" const SCHEMA_VERSION: u16 = SCHEMA_VERSION;\n");
107 output.push_str(&format!(
108 " const BLOCK_LENGTH: u16 = {};\n\n",
109 msg.block_length
110 ));
111
112 output.push_str(
113 " fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
114 );
115 output.push_str(" Self::wrap(buffer, offset, acting_version)\n");
116 output.push_str(" }\n\n");
117
118 output.push_str(" fn encoded_length(&self) -> usize {\n");
119 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
120 output.push_str(" }\n");
121 output.push_str("}\n\n");
122
123 output
124 }
125
126 fn generate_field_getter(&self, field: &ResolvedField) -> String {
128 let mut output = String::new();
129
130 output.push_str(&format!(
131 " /// Field: {} (id={}, offset={}).\n",
132 field.name, field.id, field.offset
133 ));
134 output.push_str(" #[inline(always)]\n");
135 output.push_str(" #[must_use]\n");
136
137 if field.is_array {
138 let elem_type = field.primitive_type.map(|p| p.rust_type()).unwrap_or("u8");
140 let len = field.array_length.unwrap_or(1);
141
142 if elem_type == "u8" {
143 output.push_str(&format!(
145 " pub fn {}(&self) -> &'a [u8] {{\n",
146 field.getter_name
147 ));
148 output.push_str(&format!(
149 " &self.buffer[self.offset + {}..self.offset + {} + {}]\n",
150 field.offset, field.offset, len
151 ));
152 output.push_str(" }\n\n");
153
154 output.push_str(&format!(
156 " /// Field {} as string (trimmed).\n",
157 field.name
158 ));
159 output.push_str(" #[inline]\n");
160 output.push_str(" #[must_use]\n");
161 output.push_str(&format!(
162 " pub fn {}_as_str(&self) -> &'a str {{\n",
163 field.getter_name
164 ));
165 output.push_str(&format!(
166 " let bytes = &self.buffer[self.offset + {}..self.offset + {} + {}];\n",
167 field.offset, field.offset, len
168 ));
169 output.push_str(
170 " let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());\n",
171 );
172 output.push_str(" std::str::from_utf8(&bytes[..end]).unwrap_or(\"\")\n");
173 output.push_str(" }\n\n");
174 } else {
175 output.push_str(&format!(
177 " pub fn {}(&self) -> &'a [u8] {{\n",
178 field.getter_name
179 ));
180 output.push_str(&format!(
181 " &self.buffer[self.offset + {}..self.offset + {}]\n",
182 field.offset,
183 field.offset + field.encoded_length
184 ));
185 output.push_str(" }\n\n");
186 }
187 } else {
188 let rust_type = &field.rust_type;
190 let resolved_type = self.ir.get_type(&field.type_name);
191
192 match resolved_type.map(|t| &t.kind) {
193 Some(TypeKind::Enum { encoding, .. }) => {
194 let read_method = get_read_method(Some(*encoding));
196 output.push_str(&format!(
197 " pub fn {}(&self) -> {} {{\n",
198 field.getter_name, rust_type
199 ));
200 output.push_str(&format!(
201 " {}::from(self.buffer.{}(self.offset + {}))\n",
202 rust_type, read_method, field.offset
203 ));
204 output.push_str(" }\n\n");
205 }
206 Some(TypeKind::Set { encoding, .. }) => {
207 let read_method = get_read_method(Some(*encoding));
209 output.push_str(&format!(
210 " pub fn {}(&self) -> {} {{\n",
211 field.getter_name, rust_type
212 ));
213 output.push_str(&format!(
214 " {}::from_raw(self.buffer.{}(self.offset + {}))\n",
215 rust_type, read_method, field.offset
216 ));
217 output.push_str(" }\n\n");
218 }
219 Some(TypeKind::Composite { .. }) => {
220 output.push_str(&format!(
222 " pub fn {}(&self) -> {}<'a> {{\n",
223 field.getter_name, rust_type
224 ));
225 output.push_str(&format!(
226 " {}::wrap(self.buffer, self.offset + {})\n",
227 rust_type, field.offset
228 ));
229 output.push_str(" }\n\n");
230 }
231 _ => {
232 let read_method = get_read_method(field.primitive_type);
234 output.push_str(&format!(
235 " pub fn {}(&self) -> {} {{\n",
236 field.getter_name, rust_type
237 ));
238 output.push_str(&format!(
239 " self.buffer.{}(self.offset + {})\n",
240 read_method, field.offset
241 ));
242 output.push_str(" }\n\n");
243 }
244 }
245 }
246
247 output
248 }
249
250 fn generate_group_accessor(&self, group: &ResolvedGroup, offset: usize) -> String {
252 let mut output = String::new();
253 let group_decoder = group.decoder_name();
254
255 output.push_str(&format!(" /// Access {} repeating group.\n", group.name));
256 output.push_str(" #[inline]\n");
257 output.push_str(" #[must_use]\n");
258 output.push_str(&format!(
259 " pub fn {}(&self) -> {}<'a> {{\n",
260 to_snake_case(&group.name),
261 group_decoder
262 ));
263 output.push_str(&format!(
264 " {}::wrap(self.buffer, self.offset + {})\n",
265 group_decoder, offset
266 ));
267 output.push_str(" }\n\n");
268
269 output
270 }
271
272 fn generate_encoder(&self, msg: &ResolvedMessage) -> String {
274 let mut output = String::new();
275 let encoder_name = msg.encoder_name();
276
277 output.push_str(&format!("/// {} Encoder.\n", msg.name));
279 output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
280 output.push_str(" buffer: &'a mut [u8],\n");
281 output.push_str(" offset: usize,\n");
282 output.push_str("}\n\n");
283
284 output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
286 output.push_str(&format!(
287 " /// Template ID for this message.\n\
288 pub const TEMPLATE_ID: u16 = {};\n",
289 msg.template_id
290 ));
291 output.push_str(&format!(
292 " /// Block length of the fixed portion.\n\
293 pub const BLOCK_LENGTH: u16 = {};\n\n",
294 msg.block_length
295 ));
296
297 output.push_str(" /// Wraps a buffer for encoding, writing the header.\n");
299 output.push_str(" #[inline]\n");
300 output.push_str(" pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
301 output.push_str(" let mut encoder = Self { buffer, offset };\n");
302 output.push_str(" encoder.write_header();\n");
303 output.push_str(" encoder\n");
304 output.push_str(" }\n\n");
305
306 output.push_str(" fn write_header(&mut self) {\n");
308 output.push_str(" let header = MessageHeader {\n");
309 output.push_str(" block_length: Self::BLOCK_LENGTH,\n");
310 output.push_str(" template_id: Self::TEMPLATE_ID,\n");
311 output.push_str(" schema_id: SCHEMA_ID,\n");
312 output.push_str(" version: SCHEMA_VERSION,\n");
313 output.push_str(" };\n");
314 output.push_str(" header.encode(self.buffer, self.offset);\n");
315 output.push_str(" }\n\n");
316
317 output.push_str(" /// Returns the encoded length of the message.\n");
319 output.push_str(" #[must_use]\n");
320 output.push_str(" pub const fn encoded_length(&self) -> usize {\n");
321 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
322 output.push_str(" }\n\n");
323
324 for field in &msg.fields {
326 output.push_str(&self.generate_field_setter(field));
327 }
328
329 output.push_str("}\n\n");
330
331 output
332 }
333
334 fn generate_field_setter(&self, field: &ResolvedField) -> String {
336 let mut output = String::new();
337 let field_offset = format!("MessageHeader::ENCODED_LENGTH + {}", field.offset);
338
339 output.push_str(&format!(
340 " /// Set field: {} (id={}, offset={}).\n",
341 field.name, field.id, field.offset
342 ));
343 output.push_str(" #[inline(always)]\n");
344
345 if field.is_array {
346 let len = field.array_length.unwrap_or(field.encoded_length);
348
349 output.push_str(&format!(
350 " pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
351 field.setter_name
352 ));
353 output.push_str(&format!(
354 " let copy_len = value.len().min({});\n",
355 len
356 ));
357 output.push_str(&format!(
358 " self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
359 field_offset, field_offset
360 ));
361 output.push_str(" .copy_from_slice(&value[..copy_len]);\n");
362 output.push_str(&format!(" if copy_len < {} {{\n", len));
363 output.push_str(&format!(
364 " self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
365 field_offset, field_offset, len
366 ));
367 output.push_str(" }\n");
368 output.push_str(" self\n");
369 output.push_str(" }\n\n");
370 } else {
371 let rust_type = &field.rust_type;
373 let resolved_type = self.ir.get_type(&field.type_name);
374
375 match resolved_type.map(|t| &t.kind) {
376 Some(TypeKind::Enum { encoding, .. }) => {
377 let write_method = get_write_method(Some(*encoding));
379 let prim_type = encoding.rust_type();
380 output.push_str(&format!(
381 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
382 field.setter_name, rust_type
383 ));
384 output.push_str(&format!(
385 " self.buffer.{}(self.offset + {}, {}::from(value));\n",
386 write_method, field_offset, prim_type
387 ));
388 output.push_str(" self\n");
389 output.push_str(" }\n\n");
390 }
391 Some(TypeKind::Set { encoding, .. }) => {
392 let write_method = get_write_method(Some(*encoding));
394 output.push_str(&format!(
395 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
396 field.setter_name, rust_type
397 ));
398 output.push_str(&format!(
399 " self.buffer.{}(self.offset + {}, value.raw());\n",
400 write_method, field_offset
401 ));
402 output.push_str(" self\n");
403 output.push_str(" }\n\n");
404 }
405 Some(TypeKind::Composite { .. }) => {
406 output.push_str(&format!(
408 " pub fn {}(&mut self) -> {}Encoder<'_> {{\n",
409 field.setter_name, rust_type
410 ));
411 output.push_str(&format!(
412 " {}Encoder::wrap(self.buffer, self.offset + {})\n",
413 rust_type, field_offset
414 ));
415 output.push_str(" }\n\n");
416 }
417 _ => {
418 let write_method = get_write_method(field.primitive_type);
420 output.push_str(&format!(
421 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
422 field.setter_name, rust_type
423 ));
424 output.push_str(&format!(
425 " self.buffer.{}(self.offset + {}, value);\n",
426 write_method, field_offset
427 ));
428 output.push_str(" self\n");
429 output.push_str(" }\n\n");
430 }
431 }
432 }
433
434 output
435 }
436
437 fn generate_group_decoder(&self, group: &ResolvedGroup) -> String {
439 let mut output = String::new();
440 let decoder_name = group.decoder_name();
441 let entry_name = group.entry_decoder_name();
442
443 output.push_str(&format!("/// {} Group Decoder.\n", group.name));
445 output.push_str("#[derive(Debug, Clone, Copy)]\n");
446 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
447 output.push_str(" buffer: &'a [u8],\n");
448 output.push_str(" block_length: u16,\n");
449 output.push_str(" count: u16,\n");
450 output.push_str(" index: u16,\n");
451 output.push_str(" offset: usize,\n");
452 output.push_str("}\n\n");
453
454 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
456 output.push_str(" /// Wraps a buffer at the group header position.\n");
457 output.push_str(" #[must_use]\n");
458 output.push_str(" pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
459 output.push_str(" let header = GroupHeader::wrap(buffer, offset);\n");
460 output.push_str(" Self {\n");
461 output.push_str(" buffer,\n");
462 output.push_str(" block_length: header.block_length,\n");
463 output.push_str(" count: header.num_in_group,\n");
464 output.push_str(" index: 0,\n");
465 output.push_str(" offset: offset + GroupHeader::ENCODED_LENGTH,\n");
466 output.push_str(" }\n");
467 output.push_str(" }\n\n");
468
469 output.push_str(" /// Returns the number of entries in the group.\n");
470 output.push_str(" #[must_use]\n");
471 output.push_str(" pub const fn count(&self) -> u16 {\n");
472 output.push_str(" self.count\n");
473 output.push_str(" }\n\n");
474
475 output.push_str(" /// Returns true if the group is empty.\n");
476 output.push_str(" #[must_use]\n");
477 output.push_str(" pub const fn is_empty(&self) -> bool {\n");
478 output.push_str(" self.count == 0\n");
479 output.push_str(" }\n");
480 output.push_str("}\n\n");
481
482 output.push_str(&format!("impl<'a> Iterator for {}<'a> {{\n", decoder_name));
484 output.push_str(&format!(" type Item = {}<'a>;\n\n", entry_name));
485 output.push_str(" fn next(&mut self) -> Option<Self::Item> {\n");
486 output.push_str(" if self.index >= self.count {\n");
487 output.push_str(" return None;\n");
488 output.push_str(" }\n");
489 output.push_str(&format!(
490 " let entry = {}::wrap(self.buffer, self.offset);\n",
491 entry_name
492 ));
493 output.push_str(" self.offset += self.block_length as usize;\n");
494 output.push_str(" self.index += 1;\n");
495 output.push_str(" Some(entry)\n");
496 output.push_str(" }\n\n");
497
498 output.push_str(" fn size_hint(&self) -> (usize, Option<usize>) {\n");
499 output.push_str(" let remaining = (self.count - self.index) as usize;\n");
500 output.push_str(" (remaining, Some(remaining))\n");
501 output.push_str(" }\n");
502 output.push_str("}\n\n");
503
504 output.push_str(&format!(
505 "impl<'a> ExactSizeIterator for {}<'a> {{}}\n\n",
506 decoder_name
507 ));
508
509 output.push_str(&self.generate_entry_decoder(group));
511
512 for nested in &group.nested_groups {
514 output.push_str(&self.generate_group_decoder(nested));
515 }
516
517 output
518 }
519
520 fn generate_entry_decoder(&self, group: &ResolvedGroup) -> String {
522 let mut output = String::new();
523 let entry_name = group.entry_decoder_name();
524
525 output.push_str(&format!("/// {} Entry Decoder.\n", group.name));
526 output.push_str("#[derive(Debug, Clone, Copy)]\n");
527 output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
528 output.push_str(" buffer: &'a [u8],\n");
529 output.push_str(" offset: usize,\n");
530 output.push_str("}\n\n");
531
532 output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
533 output.push_str(" fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
534 output.push_str(" Self { buffer, offset }\n");
535 output.push_str(" }\n\n");
536
537 for field in &group.fields {
539 output.push_str(&self.generate_field_getter(field));
540 }
541
542 output.push_str("}\n\n");
543
544 output
545 }
546}
547
548fn get_read_method(prim: Option<PrimitiveType>) -> &'static str {
550 match prim {
551 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "get_u8",
552 Some(PrimitiveType::Int8) => "get_i8",
553 Some(PrimitiveType::Uint16) => "get_u16_le",
554 Some(PrimitiveType::Int16) => "get_i16_le",
555 Some(PrimitiveType::Uint32) => "get_u32_le",
556 Some(PrimitiveType::Int32) => "get_i32_le",
557 Some(PrimitiveType::Uint64) => "get_u64_le",
558 Some(PrimitiveType::Int64) => "get_i64_le",
559 Some(PrimitiveType::Float) => "get_f32_le",
560 Some(PrimitiveType::Double) => "get_f64_le",
561 None => "get_u64_le",
562 }
563}
564
565fn get_write_method(prim: Option<PrimitiveType>) -> &'static str {
567 match prim {
568 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "put_u8",
569 Some(PrimitiveType::Int8) => "put_i8",
570 Some(PrimitiveType::Uint16) => "put_u16_le",
571 Some(PrimitiveType::Int16) => "put_i16_le",
572 Some(PrimitiveType::Uint32) => "put_u32_le",
573 Some(PrimitiveType::Int32) => "put_i32_le",
574 Some(PrimitiveType::Uint64) => "put_u64_le",
575 Some(PrimitiveType::Int64) => "put_i64_le",
576 Some(PrimitiveType::Float) => "put_f32_le",
577 Some(PrimitiveType::Double) => "put_f64_le",
578 None => "put_u64_le",
579 }
580}