Skip to main content

ironsbe_codegen/rust/
enums.rs

1//! Enum and set code generation.
2
3use ironsbe_schema::ir::{
4    EnumVariant, SchemaIr, SetVariant, TypeKind, to_pascal_case, to_screaming_snake_case,
5    to_snake_case,
6};
7use ironsbe_schema::types::PrimitiveType;
8
/// Generator for enum and set definitions.
///
/// Holds a borrowed [`SchemaIr`] for the lifetime `'a` and renders every
/// `TypeKind::Enum` and `TypeKind::Set` entry found in it as Rust source text.
pub struct EnumGenerator<'a> {
    /// Schema IR whose `types` collection is walked when generating code.
    ir: &'a SchemaIr,
}
13
14impl<'a> EnumGenerator<'a> {
15    /// Creates a new enum generator.
16    #[must_use]
17    pub fn new(ir: &'a SchemaIr) -> Self {
18        Self { ir }
19    }
20
21    /// Generates all enum and set definitions.
22    #[must_use]
23    pub fn generate(&self) -> String {
24        let mut output = String::new();
25
26        for resolved_type in self.ir.types.values() {
27            match &resolved_type.kind {
28                TypeKind::Enum { encoding, variants } => {
29                    output.push_str(&self.generate_enum(&resolved_type.name, *encoding, variants));
30                }
31                TypeKind::Set { encoding, choices } => {
32                    output.push_str(&self.generate_set(&resolved_type.name, *encoding, choices));
33                }
34                _ => {}
35            }
36        }
37
38        output
39    }
40
41    /// Generates an enum definition.
42    fn generate_enum(
43        &self,
44        name: &str,
45        encoding: PrimitiveType,
46        variants: &[EnumVariant],
47    ) -> String {
48        let mut output = String::new();
49        let rust_name = to_pascal_case(name);
50        let rust_type = encoding.rust_type();
51
52        output.push_str(&format!("/// {} enum.\n", rust_name));
53        output.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]\n");
54        output.push_str(&format!("#[repr({})]\n", rust_type));
55        output.push_str(&format!("pub enum {} {{\n", rust_name));
56
57        // Generate enum variants from schema
58        for variant in variants {
59            let variant_name = to_pascal_case(&variant.name);
60            output.push_str(&format!("    /// {} variant.\n", variant_name));
61            output.push_str(&format!("    {} = {},\n", variant_name, variant.value));
62        }
63
64        output.push_str("}\n\n");
65
66        // Implement From<primitive> -> Enum (safe match, no transmute)
67        // Only generate if variants exist to avoid empty match
68        if !variants.is_empty() {
69            output.push_str(&format!("impl From<{}> for {} {{\n", rust_type, rust_name));
70            output.push_str(&format!("    fn from(value: {}) -> Self {{\n", rust_type));
71            output.push_str("        match value {\n");
72            for variant in variants {
73                let variant_name = to_pascal_case(&variant.name);
74                output.push_str(&format!(
75                    "            {} => Self::{},\n",
76                    variant.value, variant_name
77                ));
78            }
79            // Default to first variant for unknown values
80            let first_name = to_pascal_case(&variants[0].name);
81            output.push_str(&format!("            _ => Self::{},\n", first_name));
82            output.push_str("        }\n");
83            output.push_str("    }\n");
84            output.push_str("}\n\n");
85        }
86
87        // Implement From<Enum> -> primitive
88        if !variants.is_empty() {
89            output.push_str(&format!("impl From<{}> for {} {{\n", rust_name, rust_type));
90            output.push_str(&format!("    fn from(value: {}) -> Self {{\n", rust_name));
91            output.push_str("        value as Self\n");
92            output.push_str("    }\n");
93            output.push_str("}\n\n");
94        }
95
96        output
97    }
98
99    /// Generates a set (bitfield) definition.
100    fn generate_set(&self, name: &str, encoding: PrimitiveType, choices: &[SetVariant]) -> String {
101        let mut output = String::new();
102        let rust_name = to_pascal_case(name);
103        let rust_type = encoding.rust_type();
104
105        output.push_str(&format!("/// {} bitfield set.\n", rust_name));
106        output.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]\n");
107        output.push_str(&format!("pub struct {}({});\n\n", rust_name, rust_type));
108
109        output.push_str(&format!("impl {} {{\n", rust_name));
110
111        // Generate bit position constants for each choice
112        for choice in choices {
113            let const_name = to_screaming_snake_case(&choice.name);
114            output.push_str(&format!(
115                "    /// Bit position for {} choice.\n",
116                choice.name
117            ));
118            output.push_str(&format!(
119                "    pub const {}: u8 = {};\n",
120                const_name, choice.bit_position
121            ));
122        }
123        if !choices.is_empty() {
124            output.push('\n');
125        }
126
127        output.push_str(&format!("    /// Creates a new empty {}.\n", rust_name));
128        output.push_str("    #[must_use]\n");
129        output.push_str("    pub const fn new() -> Self {\n");
130        output.push_str("        Self(0)\n");
131        output.push_str("    }\n\n");
132
133        output.push_str(&format!("    /// Creates from raw {} value.\n", rust_type));
134        output.push_str("    #[must_use]\n");
135        output.push_str(&format!(
136            "    pub const fn from_raw(value: {}) -> Self {{\n",
137            rust_type
138        ));
139        output.push_str("        Self(value)\n");
140        output.push_str("    }\n\n");
141
142        output.push_str("    /// Returns the raw value.\n");
143        output.push_str("    #[must_use]\n");
144        output.push_str(&format!(
145            "    pub const fn raw(&self) -> {} {{\n",
146            rust_type
147        ));
148        output.push_str("        self.0\n");
149        output.push_str("    }\n\n");
150
151        output.push_str("    /// Checks if a bit is set.\n");
152        output.push_str("    #[must_use]\n");
153        output.push_str("    pub const fn is_set(&self, bit: u8) -> bool {\n");
154        output.push_str("        (self.0 >> bit) & 1 != 0\n");
155        output.push_str("    }\n\n");
156
157        output.push_str("    /// Sets a bit.\n");
158        output.push_str("    pub fn set(&mut self, bit: u8) {\n");
159        output.push_str("        self.0 |= 1 << bit;\n");
160        output.push_str("    }\n\n");
161
162        output.push_str("    /// Clears a bit.\n");
163        output.push_str("    pub fn clear(&mut self, bit: u8) {\n");
164        output.push_str("        self.0 &= !(1 << bit);\n");
165        output.push_str("    }\n");
166
167        // Generate named methods for each choice
168        for choice in choices {
169            let method_name = to_snake_case(&choice.name);
170            output.push_str(&format!("\n    /// Checks if {} is set.\n", choice.name));
171            output.push_str("    #[must_use]\n");
172            output.push_str(&format!(
173                "    pub const fn is_{}(&self) -> bool {{\n",
174                method_name
175            ));
176            output.push_str(&format!("        self.is_set({})\n", choice.bit_position));
177            output.push_str("    }\n");
178
179            output.push_str(&format!("\n    /// Sets {}.\n", choice.name));
180            output.push_str(&format!("    pub fn set_{}(&mut self) {{\n", method_name));
181            output.push_str(&format!("        self.set({});\n", choice.bit_position));
182            output.push_str("    }\n");
183
184            output.push_str(&format!("\n    /// Clears {}.\n", choice.name));
185            output.push_str(&format!("    pub fn clear_{}(&mut self) {{\n", method_name));
186            output.push_str(&format!("        self.clear({});\n", choice.bit_position));
187            output.push_str("    }\n");
188        }
189
190        output.push_str("}\n\n");
191
192        output
193    }
194}
195
196#[cfg(test)]
197mod tests {
198    use super::*;
199    use ironsbe_schema::ir::SchemaIr;
200    use ironsbe_schema::parser::parse_schema;
201
202    fn create_test_ir_with_enum() -> SchemaIr {
203        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
204<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
205                   package="test" id="1" version="1" byteOrder="littleEndian">
206    <types>
207        <enum name="Side" encodingType="uint8">
208            <validValue name="Buy">1</validValue>
209            <validValue name="Sell">2</validValue>
210        </enum>
211    </types>
212    <sbe:message name="Test" id="1" blockLength="1">
213        <field name="side" id="1" type="Side" offset="0"/>
214    </sbe:message>
215</sbe:messageSchema>"#;
216
217        let schema = parse_schema(xml).expect("Failed to parse");
218        SchemaIr::from_schema(&schema)
219    }
220
221    fn create_test_ir_with_set() -> SchemaIr {
222        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
223<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
224                   package="test" id="1" version="1" byteOrder="littleEndian">
225    <types>
226        <set name="Flags" encodingType="uint8">
227            <choice name="Active">0</choice>
228            <choice name="Visible">1</choice>
229        </set>
230    </types>
231    <sbe:message name="Test" id="1" blockLength="1">
232        <field name="flags" id="1" type="Flags" offset="0"/>
233    </sbe:message>
234</sbe:messageSchema>"#;
235
236        let schema = parse_schema(xml).expect("Failed to parse");
237        SchemaIr::from_schema(&schema)
238    }
239
240    #[test]
241    fn test_enum_generator_new() {
242        let ir = create_test_ir_with_enum();
243        let generator = EnumGenerator::new(&ir);
244        assert!(!generator.ir.types.is_empty());
245    }
246
247    #[test]
248    fn test_generate_enum() {
249        let ir = create_test_ir_with_enum();
250        let generator = EnumGenerator::new(&ir);
251        let output = generator.generate();
252
253        // Check enum structure
254        assert!(output.contains("pub enum Side"));
255        assert!(output.contains("Buy = 1"));
256        assert!(output.contains("Sell = 2"));
257
258        // Check From impl with match arms
259        assert!(output.contains("impl From<u8> for Side"));
260        assert!(output.contains("1 => Self::Buy"));
261        assert!(output.contains("2 => Self::Sell"));
262        assert!(output.contains("_ => Self::Buy")); // fallback to first variant
263    }
264
265    #[test]
266    fn test_generate_set() {
267        let ir = create_test_ir_with_set();
268        let generator = EnumGenerator::new(&ir);
269        let output = generator.generate();
270
271        // Check struct and basic methods
272        assert!(output.contains("struct"));
273        assert!(output.contains("is_set"));
274        assert!(output.contains("set"));
275        assert!(output.contains("clear"));
276
277        // Check specific bit position constants are generated
278        assert!(
279            output.contains("pub const ACTIVE: u8 = 0"),
280            "Should generate ACTIVE constant with bit position 0"
281        );
282        assert!(
283            output.contains("pub const VISIBLE: u8 = 1"),
284            "Should generate VISIBLE constant with bit position 1"
285        );
286
287        // Check specific named methods are generated
288        assert!(
289            output.contains("fn is_active(&self)"),
290            "Should generate is_active method"
291        );
292        assert!(
293            output.contains("fn set_active(&mut self)"),
294            "Should generate set_active method"
295        );
296        assert!(
297            output.contains("fn clear_active(&mut self)"),
298            "Should generate clear_active method"
299        );
300    }
301
302    #[test]
303    fn test_generate_empty_ir() {
304        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
305<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
306                   package="test" id="1" version="1" byteOrder="littleEndian">
307    <types>
308        <type name="uint64" primitiveType="uint64"/>
309    </types>
310    <sbe:message name="Test" id="1" blockLength="8">
311        <field name="value" id="1" type="uint64" offset="0"/>
312    </sbe:message>
313</sbe:messageSchema>"#;
314
315        let schema = parse_schema(xml).expect("Failed to parse");
316        let ir = SchemaIr::from_schema(&schema);
317        let generator = EnumGenerator::new(&ir);
318        let output = generator.generate();
319
320        // No enums or sets, should be empty or minimal
321        assert!(!output.contains("enum Side"));
322    }
323}