1use ironsbe_schema::ir::{ResolvedField, ResolvedGroup, ResolvedMessage, SchemaIr, to_snake_case};
4use ironsbe_schema::types::PrimitiveType;
5
6pub struct MessageGenerator<'a> {
8 ir: &'a SchemaIr,
9}
10
11impl<'a> MessageGenerator<'a> {
12 #[must_use]
14 pub fn new(ir: &'a SchemaIr) -> Self {
15 Self { ir }
16 }
17
18 #[must_use]
20 pub fn generate(&self) -> String {
21 let mut output = String::new();
22
23 for msg in &self.ir.messages {
24 output.push_str(&self.generate_decoder(msg));
25 output.push_str(&self.generate_encoder(msg));
26
27 for group in &msg.groups {
29 output.push_str(&self.generate_group_decoder(group));
30 }
31 }
32
33 output
34 }
35
36 fn generate_decoder(&self, msg: &ResolvedMessage) -> String {
38 let mut output = String::new();
39 let decoder_name = msg.decoder_name();
40
41 output.push_str(&format!("/// {} Decoder (zero-copy).\n", msg.name));
43 output.push_str("#[derive(Debug, Clone, Copy)]\n");
44 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
45 output.push_str(" buffer: &'a [u8],\n");
46 output.push_str(" offset: usize,\n");
47 output.push_str(" acting_version: u16,\n");
48 output.push_str("}\n\n");
49
50 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
52 output.push_str(&format!(
53 " /// Template ID for this message.\n\
54 pub const TEMPLATE_ID: u16 = {};\n",
55 msg.template_id
56 ));
57 output.push_str(&format!(
58 " /// Block length of the fixed portion.\n\
59 pub const BLOCK_LENGTH: u16 = {};\n\n",
60 msg.block_length
61 ));
62
63 output.push_str(" /// Wraps a buffer for zero-copy decoding.\n");
65 output.push_str(" ///\n");
66 output.push_str(" /// # Arguments\n");
67 output.push_str(" /// * `buffer` - Buffer containing the message\n");
68 output.push_str(
69 " /// * `offset` - Offset to the start of the root block (after header)\n",
70 );
71 output.push_str(" /// * `acting_version` - Schema version for compatibility\n");
72 output.push_str(" #[inline]\n");
73 output.push_str(" #[must_use]\n");
74 output.push_str(
75 " pub fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
76 );
77 output.push_str(" Self { buffer, offset, acting_version }\n");
78 output.push_str(" }\n\n");
79
80 for field in &msg.fields {
82 output.push_str(&self.generate_field_getter(field));
83 }
84
85 let mut group_offset = msg.block_length as usize;
87 for group in &msg.groups {
88 output.push_str(&self.generate_group_accessor(group, group_offset));
89 group_offset += 4; }
91
92 output.push_str("}\n\n");
93
94 output.push_str(&format!(
96 "impl<'a> SbeDecoder<'a> for {}<'a> {{\n",
97 decoder_name
98 ));
99 output.push_str(&format!(
100 " const TEMPLATE_ID: u16 = {};\n",
101 msg.template_id
102 ));
103 output.push_str(" const SCHEMA_ID: u16 = SCHEMA_ID;\n");
104 output.push_str(" const SCHEMA_VERSION: u16 = SCHEMA_VERSION;\n");
105 output.push_str(&format!(
106 " const BLOCK_LENGTH: u16 = {};\n\n",
107 msg.block_length
108 ));
109
110 output.push_str(
111 " fn wrap(buffer: &'a [u8], offset: usize, acting_version: u16) -> Self {\n",
112 );
113 output.push_str(" Self::wrap(buffer, offset, acting_version)\n");
114 output.push_str(" }\n\n");
115
116 output.push_str(" fn encoded_length(&self) -> usize {\n");
117 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
118 output.push_str(" }\n");
119 output.push_str("}\n\n");
120
121 output
122 }
123
124 fn generate_field_getter(&self, field: &ResolvedField) -> String {
126 let mut output = String::new();
127
128 output.push_str(&format!(
129 " /// Field: {} (id={}, offset={}).\n",
130 field.name, field.id, field.offset
131 ));
132 output.push_str(" #[inline(always)]\n");
133 output.push_str(" #[must_use]\n");
134
135 if field.is_array {
136 let elem_type = field.primitive_type.map(|p| p.rust_type()).unwrap_or("u8");
138 let len = field.array_length.unwrap_or(1);
139
140 if elem_type == "u8" {
141 output.push_str(&format!(
143 " pub fn {}(&self) -> &'a [u8] {{\n",
144 field.getter_name
145 ));
146 output.push_str(&format!(
147 " &self.buffer[self.offset + {}..self.offset + {} + {}]\n",
148 field.offset, field.offset, len
149 ));
150 output.push_str(" }\n\n");
151
152 output.push_str(&format!(
154 " /// Field {} as string (trimmed).\n",
155 field.name
156 ));
157 output.push_str(" #[inline]\n");
158 output.push_str(" #[must_use]\n");
159 output.push_str(&format!(
160 " pub fn {}_as_str(&self) -> &'a str {{\n",
161 field.getter_name
162 ));
163 output.push_str(&format!(
164 " let bytes = &self.buffer[self.offset + {}..self.offset + {} + {}];\n",
165 field.offset, field.offset, len
166 ));
167 output.push_str(
168 " let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());\n",
169 );
170 output.push_str(" std::str::from_utf8(&bytes[..end]).unwrap_or(\"\")\n");
171 output.push_str(" }\n\n");
172 } else {
173 output.push_str(&format!(
175 " pub fn {}(&self) -> &'a [u8] {{\n",
176 field.getter_name
177 ));
178 output.push_str(&format!(
179 " &self.buffer[self.offset + {}..self.offset + {}]\n",
180 field.offset,
181 field.offset + field.encoded_length
182 ));
183 output.push_str(" }\n\n");
184 }
185 } else {
186 let rust_type = &field.rust_type;
188 let read_method = get_read_method(field.primitive_type);
189
190 output.push_str(&format!(
191 " pub fn {}(&self) -> {} {{\n",
192 field.getter_name, rust_type
193 ));
194 output.push_str(&format!(
195 " self.buffer.{}(self.offset + {})\n",
196 read_method, field.offset
197 ));
198 output.push_str(" }\n\n");
199 }
200
201 output
202 }
203
204 fn generate_group_accessor(&self, group: &ResolvedGroup, offset: usize) -> String {
206 let mut output = String::new();
207 let group_decoder = group.decoder_name();
208
209 output.push_str(&format!(" /// Access {} repeating group.\n", group.name));
210 output.push_str(" #[inline]\n");
211 output.push_str(" #[must_use]\n");
212 output.push_str(&format!(
213 " pub fn {}(&self) -> {}<'a> {{\n",
214 to_snake_case(&group.name),
215 group_decoder
216 ));
217 output.push_str(&format!(
218 " {}::wrap(self.buffer, self.offset + {})\n",
219 group_decoder, offset
220 ));
221 output.push_str(" }\n\n");
222
223 output
224 }
225
226 fn generate_encoder(&self, msg: &ResolvedMessage) -> String {
228 let mut output = String::new();
229 let encoder_name = msg.encoder_name();
230
231 output.push_str(&format!("/// {} Encoder.\n", msg.name));
233 output.push_str(&format!("pub struct {}<'a> {{\n", encoder_name));
234 output.push_str(" buffer: &'a mut [u8],\n");
235 output.push_str(" offset: usize,\n");
236 output.push_str("}\n\n");
237
238 output.push_str(&format!("impl<'a> {}<'a> {{\n", encoder_name));
240 output.push_str(&format!(
241 " /// Template ID for this message.\n\
242 pub const TEMPLATE_ID: u16 = {};\n",
243 msg.template_id
244 ));
245 output.push_str(&format!(
246 " /// Block length of the fixed portion.\n\
247 pub const BLOCK_LENGTH: u16 = {};\n\n",
248 msg.block_length
249 ));
250
251 output.push_str(" /// Wraps a buffer for encoding, writing the header.\n");
253 output.push_str(" #[inline]\n");
254 output.push_str(" pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
255 output.push_str(" let mut encoder = Self { buffer, offset };\n");
256 output.push_str(" encoder.write_header();\n");
257 output.push_str(" encoder\n");
258 output.push_str(" }\n\n");
259
260 output.push_str(" fn write_header(&mut self) {\n");
262 output.push_str(" let header = MessageHeader {\n");
263 output.push_str(" block_length: Self::BLOCK_LENGTH,\n");
264 output.push_str(" template_id: Self::TEMPLATE_ID,\n");
265 output.push_str(" schema_id: SCHEMA_ID,\n");
266 output.push_str(" version: SCHEMA_VERSION,\n");
267 output.push_str(" };\n");
268 output.push_str(" header.encode(self.buffer, self.offset);\n");
269 output.push_str(" }\n\n");
270
271 output.push_str(" /// Returns the encoded length of the message.\n");
273 output.push_str(" #[must_use]\n");
274 output.push_str(" pub const fn encoded_length(&self) -> usize {\n");
275 output.push_str(" MessageHeader::ENCODED_LENGTH + Self::BLOCK_LENGTH as usize\n");
276 output.push_str(" }\n\n");
277
278 for field in &msg.fields {
280 output.push_str(&self.generate_field_setter(field));
281 }
282
283 output.push_str("}\n\n");
284
285 output
286 }
287
288 fn generate_field_setter(&self, field: &ResolvedField) -> String {
290 let mut output = String::new();
291 let field_offset = format!("MessageHeader::ENCODED_LENGTH + {}", field.offset);
292
293 output.push_str(&format!(
294 " /// Set field: {} (id={}, offset={}).\n",
295 field.name, field.id, field.offset
296 ));
297 output.push_str(" #[inline(always)]\n");
298
299 if field.is_array {
300 let len = field.array_length.unwrap_or(field.encoded_length);
302
303 output.push_str(&format!(
304 " pub fn {}(&mut self, value: &[u8]) -> &mut Self {{\n",
305 field.setter_name
306 ));
307 output.push_str(&format!(
308 " let copy_len = value.len().min({});\n",
309 len
310 ));
311 output.push_str(&format!(
312 " self.buffer[self.offset + {}..self.offset + {} + copy_len]\n",
313 field_offset, field_offset
314 ));
315 output.push_str(" .copy_from_slice(&value[..copy_len]);\n");
316 output.push_str(&format!(" if copy_len < {} {{\n", len));
317 output.push_str(&format!(
318 " self.buffer[self.offset + {} + copy_len..self.offset + {} + {}].fill(0);\n",
319 field_offset, field_offset, len
320 ));
321 output.push_str(" }\n");
322 output.push_str(" self\n");
323 output.push_str(" }\n\n");
324 } else {
325 let rust_type = &field.rust_type;
327 let write_method = get_write_method(field.primitive_type);
328
329 output.push_str(&format!(
330 " pub fn {}(&mut self, value: {}) -> &mut Self {{\n",
331 field.setter_name, rust_type
332 ));
333 output.push_str(&format!(
334 " self.buffer.{}(self.offset + {}, value);\n",
335 write_method, field_offset
336 ));
337 output.push_str(" self\n");
338 output.push_str(" }\n\n");
339 }
340
341 output
342 }
343
344 fn generate_group_decoder(&self, group: &ResolvedGroup) -> String {
346 let mut output = String::new();
347 let decoder_name = group.decoder_name();
348 let entry_name = group.entry_decoder_name();
349
350 output.push_str(&format!("/// {} Group Decoder.\n", group.name));
352 output.push_str("#[derive(Debug, Clone, Copy)]\n");
353 output.push_str(&format!("pub struct {}<'a> {{\n", decoder_name));
354 output.push_str(" buffer: &'a [u8],\n");
355 output.push_str(" block_length: u16,\n");
356 output.push_str(" count: u16,\n");
357 output.push_str(" index: u16,\n");
358 output.push_str(" offset: usize,\n");
359 output.push_str("}\n\n");
360
361 output.push_str(&format!("impl<'a> {}<'a> {{\n", decoder_name));
363 output.push_str(" /// Wraps a buffer at the group header position.\n");
364 output.push_str(" #[must_use]\n");
365 output.push_str(" pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
366 output.push_str(" let header = GroupHeader::wrap(buffer, offset);\n");
367 output.push_str(" Self {\n");
368 output.push_str(" buffer,\n");
369 output.push_str(" block_length: header.block_length,\n");
370 output.push_str(" count: header.num_in_group,\n");
371 output.push_str(" index: 0,\n");
372 output.push_str(" offset: offset + GroupHeader::ENCODED_LENGTH,\n");
373 output.push_str(" }\n");
374 output.push_str(" }\n\n");
375
376 output.push_str(" /// Returns the number of entries in the group.\n");
377 output.push_str(" #[must_use]\n");
378 output.push_str(" pub const fn count(&self) -> u16 {\n");
379 output.push_str(" self.count\n");
380 output.push_str(" }\n\n");
381
382 output.push_str(" /// Returns true if the group is empty.\n");
383 output.push_str(" #[must_use]\n");
384 output.push_str(" pub const fn is_empty(&self) -> bool {\n");
385 output.push_str(" self.count == 0\n");
386 output.push_str(" }\n");
387 output.push_str("}\n\n");
388
389 output.push_str(&format!("impl<'a> Iterator for {}<'a> {{\n", decoder_name));
391 output.push_str(&format!(" type Item = {}<'a>;\n\n", entry_name));
392 output.push_str(" fn next(&mut self) -> Option<Self::Item> {\n");
393 output.push_str(" if self.index >= self.count {\n");
394 output.push_str(" return None;\n");
395 output.push_str(" }\n");
396 output.push_str(&format!(
397 " let entry = {}::wrap(self.buffer, self.offset);\n",
398 entry_name
399 ));
400 output.push_str(" self.offset += self.block_length as usize;\n");
401 output.push_str(" self.index += 1;\n");
402 output.push_str(" Some(entry)\n");
403 output.push_str(" }\n\n");
404
405 output.push_str(" fn size_hint(&self) -> (usize, Option<usize>) {\n");
406 output.push_str(" let remaining = (self.count - self.index) as usize;\n");
407 output.push_str(" (remaining, Some(remaining))\n");
408 output.push_str(" }\n");
409 output.push_str("}\n\n");
410
411 output.push_str(&format!(
412 "impl<'a> ExactSizeIterator for {}<'a> {{}}\n\n",
413 decoder_name
414 ));
415
416 output.push_str(&self.generate_entry_decoder(group));
418
419 for nested in &group.nested_groups {
421 output.push_str(&self.generate_group_decoder(nested));
422 }
423
424 output
425 }
426
427 fn generate_entry_decoder(&self, group: &ResolvedGroup) -> String {
429 let mut output = String::new();
430 let entry_name = group.entry_decoder_name();
431
432 output.push_str(&format!("/// {} Entry Decoder.\n", group.name));
433 output.push_str("#[derive(Debug, Clone, Copy)]\n");
434 output.push_str(&format!("pub struct {}<'a> {{\n", entry_name));
435 output.push_str(" buffer: &'a [u8],\n");
436 output.push_str(" offset: usize,\n");
437 output.push_str("}\n\n");
438
439 output.push_str(&format!("impl<'a> {}<'a> {{\n", entry_name));
440 output.push_str(" fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
441 output.push_str(" Self { buffer, offset }\n");
442 output.push_str(" }\n\n");
443
444 for field in &group.fields {
446 output.push_str(&self.generate_field_getter(field));
447 }
448
449 output.push_str("}\n\n");
450
451 output
452 }
453}
454
455fn get_read_method(prim: Option<PrimitiveType>) -> &'static str {
457 match prim {
458 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "get_u8",
459 Some(PrimitiveType::Int8) => "get_i8",
460 Some(PrimitiveType::Uint16) => "get_u16_le",
461 Some(PrimitiveType::Int16) => "get_i16_le",
462 Some(PrimitiveType::Uint32) => "get_u32_le",
463 Some(PrimitiveType::Int32) => "get_i32_le",
464 Some(PrimitiveType::Uint64) => "get_u64_le",
465 Some(PrimitiveType::Int64) => "get_i64_le",
466 Some(PrimitiveType::Float) => "get_f32_le",
467 Some(PrimitiveType::Double) => "get_f64_le",
468 None => "get_u64_le",
469 }
470}
471
472fn get_write_method(prim: Option<PrimitiveType>) -> &'static str {
474 match prim {
475 Some(PrimitiveType::Char) | Some(PrimitiveType::Uint8) => "put_u8",
476 Some(PrimitiveType::Int8) => "put_i8",
477 Some(PrimitiveType::Uint16) => "put_u16_le",
478 Some(PrimitiveType::Int16) => "put_i16_le",
479 Some(PrimitiveType::Uint32) => "put_u32_le",
480 Some(PrimitiveType::Int32) => "put_i32_le",
481 Some(PrimitiveType::Uint64) => "put_u64_le",
482 Some(PrimitiveType::Int64) => "put_i64_le",
483 Some(PrimitiveType::Float) => "put_f32_le",
484 Some(PrimitiveType::Double) => "put_f64_le",
485 None => "put_u64_le",
486 }
487}