Skip to main content

ironsbe_codegen/rust/
types.rs

1//! Type code generation.
2
3use ironsbe_schema::ir::{CompositeFieldInfo, SchemaIr, TypeKind, to_pascal_case, to_snake_case};
4use ironsbe_schema::types::PrimitiveType;
5
6/// Generator for type definitions.
7pub struct TypeGenerator<'a> {
8    ir: &'a SchemaIr,
9}
10
11impl<'a> TypeGenerator<'a> {
12    /// Creates a new type generator.
13    #[must_use]
14    pub fn new(ir: &'a SchemaIr) -> Self {
15        Self { ir }
16    }
17
18    /// Generates all type definitions.
19    #[must_use]
20    pub fn generate(&self) -> String {
21        let mut output = String::new();
22
23        for resolved_type in self.ir.types.values() {
24            if let TypeKind::Composite { fields } = &resolved_type.kind {
25                // Skip messageHeader - it's provided by ironsbe_core::header::MessageHeader
26                if resolved_type.name.eq_ignore_ascii_case("messageHeader") {
27                    continue;
28                }
29                output.push_str(&self.generate_composite(
30                    &resolved_type.name,
31                    fields,
32                    resolved_type.encoded_length,
33                ));
34            }
35        }
36
37        output
38    }
39
40    /// Generates a composite type struct with zero-copy decoder and encoder.
41    fn generate_composite(
42        &self,
43        name: &str,
44        fields: &[CompositeFieldInfo],
45        encoded_length: usize,
46    ) -> String {
47        let mut output = String::new();
48        let struct_name = to_pascal_case(name);
49
50        // Generate decoder struct
51        output.push_str(&format!("/// {} Decoder (zero-copy).\n", struct_name));
52        output.push_str("#[derive(Debug, Clone, Copy)]\n");
53        output.push_str(&format!("pub struct {}<'a> {{\n", struct_name));
54        output.push_str("    buffer: &'a [u8],\n");
55        output.push_str("    offset: usize,\n");
56        output.push_str("}\n\n");
57
58        output.push_str(&format!("impl<'a> {}<'a> {{\n", struct_name));
59        output.push_str(&format!(
60            "    /// Encoded length of {} in bytes.\n",
61            struct_name
62        ));
63        output.push_str(&format!(
64            "    pub const ENCODED_LENGTH: usize = {};\n\n",
65            encoded_length
66        ));
67
68        // Constructor
69        output.push_str("    /// Wraps a buffer for zero-copy decoding.\n");
70        output.push_str("    #[inline]\n");
71        output.push_str("    #[must_use]\n");
72        output.push_str("    pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {\n");
73        output.push_str("        Self { buffer, offset }\n");
74        output.push_str("    }\n\n");
75
76        // Field getters
77        for field in fields {
78            let field_name = to_snake_case(&field.name);
79            let rust_type = field.primitive_type.rust_type();
80            let read_method = get_read_method(field.primitive_type);
81
82            output.push_str(&format!("    /// Gets the {} field.\n", field.name));
83            output.push_str("    #[inline(always)]\n");
84            output.push_str("    #[must_use]\n");
85            output.push_str(&format!(
86                "    pub fn {}(&self) -> {} {{\n",
87                field_name, rust_type
88            ));
89            output.push_str(&format!(
90                "        self.buffer.{}(self.offset + {})\n",
91                read_method, field.offset
92            ));
93            output.push_str("    }\n\n");
94        }
95
96        output.push_str("}\n\n");
97
98        // Generate encoder struct
99        output.push_str(&format!("/// {} Encoder.\n", struct_name));
100        output.push_str(&format!("pub struct {}Encoder<'a> {{\n", struct_name));
101        output.push_str("    buffer: &'a mut [u8],\n");
102        output.push_str("    offset: usize,\n");
103        output.push_str("}\n\n");
104
105        output.push_str(&format!("impl<'a> {}Encoder<'a> {{\n", struct_name));
106        output.push_str(&format!(
107            "    /// Encoded length of {} in bytes.\n",
108            struct_name
109        ));
110        output.push_str(&format!(
111            "    pub const ENCODED_LENGTH: usize = {};\n\n",
112            encoded_length
113        ));
114
115        // Constructor
116        output.push_str("    /// Wraps a buffer for encoding.\n");
117        output.push_str("    #[inline]\n");
118        output.push_str("    pub fn wrap(buffer: &'a mut [u8], offset: usize) -> Self {\n");
119        output.push_str("        Self { buffer, offset }\n");
120        output.push_str("    }\n\n");
121
122        // Field setters
123        for field in fields {
124            let field_name = to_snake_case(&field.name);
125            let rust_type = field.primitive_type.rust_type();
126            let write_method = get_write_method(field.primitive_type);
127
128            output.push_str(&format!("    /// Sets the {} field.\n", field.name));
129            output.push_str("    #[inline(always)]\n");
130            output.push_str(&format!(
131                "    pub fn set_{}(&mut self, value: {}) -> &mut Self {{\n",
132                field_name, rust_type
133            ));
134            output.push_str(&format!(
135                "        self.buffer.{}(self.offset + {}, value);\n",
136                write_method, field.offset
137            ));
138            output.push_str("        self\n");
139            output.push_str("    }\n\n");
140        }
141
142        output.push_str("}\n\n");
143
144        output
145    }
146}
147
148/// Gets the read method name for a primitive type.
149fn get_read_method(prim: PrimitiveType) -> &'static str {
150    match prim {
151        PrimitiveType::Char | PrimitiveType::Uint8 => "get_u8",
152        PrimitiveType::Int8 => "get_i8",
153        PrimitiveType::Uint16 => "get_u16_le",
154        PrimitiveType::Int16 => "get_i16_le",
155        PrimitiveType::Uint32 => "get_u32_le",
156        PrimitiveType::Int32 => "get_i32_le",
157        PrimitiveType::Uint64 => "get_u64_le",
158        PrimitiveType::Int64 => "get_i64_le",
159        PrimitiveType::Float => "get_f32_le",
160        PrimitiveType::Double => "get_f64_le",
161    }
162}
163
164/// Gets the write method name for a primitive type.
165fn get_write_method(prim: PrimitiveType) -> &'static str {
166    match prim {
167        PrimitiveType::Char | PrimitiveType::Uint8 => "put_u8",
168        PrimitiveType::Int8 => "put_i8",
169        PrimitiveType::Uint16 => "put_u16_le",
170        PrimitiveType::Int16 => "put_i16_le",
171        PrimitiveType::Uint32 => "put_u32_le",
172        PrimitiveType::Int32 => "put_i32_le",
173        PrimitiveType::Uint64 => "put_u64_le",
174        PrimitiveType::Int64 => "put_i64_le",
175        PrimitiveType::Float => "put_f32_le",
176        PrimitiveType::Double => "put_f64_le",
177    }
178}