Skip to main content

ironfix_codegen/
generator.rs

/******************************************************************************
   Author: Joaquín Béjar García
   Email: jb@taunais.com
   Date: 27/1/26
******************************************************************************/

//! Code generator for FIX dictionaries.
//!
//! Generates Rust source code from FIX dictionary definitions.
10
11use ironfix_dictionary::schema::{Dictionary, FieldDef, FieldType, MessageDef};
12use std::fmt::Write;
13
/// Configuration for code generation.
///
/// Controls which modules [`CodeGenerator`] emits and the visibility used for them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratorConfig {
    /// Whether to generate the `fields` module of tag constants.
    pub generate_fields: bool,
    /// Whether to generate the `messages` module of message structs.
    pub generate_messages: bool,
    /// Whether to generate component traits.
    ///
    /// NOTE: not currently consulted by `CodeGenerator::generate`; setting it has no
    /// effect on the generated output.
    pub generate_components: bool,
    /// Module visibility (e.g., "pub", "pub(crate)").
    ///
    /// Written verbatim in front of each generated `mod`; it is not validated, so an
    /// invalid value produces generated code that does not compile.
    pub visibility: String,
}
26
27impl Default for GeneratorConfig {
28    fn default() -> Self {
29        Self {
30            generate_fields: true,
31            generate_messages: true,
32            generate_components: true,
33            visibility: "pub".to_string(),
34        }
35    }
36}
37
38/// Code generator for FIX dictionaries.
39#[derive(Debug)]
40pub struct CodeGenerator {
41    config: GeneratorConfig,
42}
43
44impl CodeGenerator {
45    /// Creates a new code generator with default configuration.
46    #[must_use]
47    pub fn new() -> Self {
48        Self {
49            config: GeneratorConfig::default(),
50        }
51    }
52
53    /// Creates a new code generator with the specified configuration.
54    #[must_use]
55    pub fn with_config(config: GeneratorConfig) -> Self {
56        Self { config }
57    }
58
59    /// Generates Rust source code from a dictionary.
60    ///
61    /// # Arguments
62    /// * `dict` - The FIX dictionary to generate code from
63    ///
64    /// # Returns
65    /// The generated Rust source code as a string.
66    #[must_use]
67    pub fn generate(&self, dict: &Dictionary) -> String {
68        let mut code = String::new();
69
70        // File header
71        writeln!(code, "//! Generated FIX {} definitions.", dict.version).unwrap();
72        writeln!(code, "//!").unwrap();
73        writeln!(
74            code,
75            "//! This file was automatically generated. Do not edit."
76        )
77        .unwrap();
78        writeln!(code).unwrap();
79
80        if self.config.generate_fields {
81            self.generate_fields_module(&mut code, dict);
82        }
83
84        if self.config.generate_messages {
85            self.generate_messages_module(&mut code, dict);
86        }
87
88        code
89    }
90
91    /// Generates the fields module with tag constants.
92    fn generate_fields_module(&self, code: &mut String, dict: &Dictionary) {
93        writeln!(code, "/// Field tag constants.").unwrap();
94        writeln!(code, "{} mod fields {{", self.config.visibility).unwrap();
95
96        let mut fields: Vec<_> = dict.fields().collect();
97        fields.sort_by_key(|f| f.tag);
98
99        for field in fields {
100            self.generate_field_constant(code, field);
101        }
102
103        writeln!(code, "}}").unwrap();
104        writeln!(code).unwrap();
105    }
106
107    /// Generates a single field constant.
108    fn generate_field_constant(&self, code: &mut String, field: &FieldDef) {
109        let const_name = to_screaming_snake_case(&field.name);
110
111        if let Some(ref desc) = field.description {
112            writeln!(code, "    /// {}", desc).unwrap();
113        }
114        writeln!(code, "    pub const {}: u32 = {};", const_name, field.tag).unwrap();
115    }
116
117    /// Generates the messages module with message structs.
118    fn generate_messages_module(&self, code: &mut String, dict: &Dictionary) {
119        writeln!(code, "/// Message type definitions.").unwrap();
120        writeln!(code, "{} mod messages {{", self.config.visibility).unwrap();
121        writeln!(code, "    use super::fields;").unwrap();
122        writeln!(code).unwrap();
123
124        let mut messages: Vec<_> = dict.messages().collect();
125        messages.sort_by(|a, b| a.msg_type.cmp(&b.msg_type));
126
127        for msg in messages {
128            self.generate_message_struct(code, msg, dict);
129        }
130
131        writeln!(code, "}}").unwrap();
132    }
133
134    /// Generates a message struct.
135    fn generate_message_struct(&self, code: &mut String, msg: &MessageDef, dict: &Dictionary) {
136        let struct_name = to_pascal_case(&msg.name);
137
138        writeln!(
139            code,
140            "    /// {} message (MsgType={}).",
141            msg.name, msg.msg_type
142        )
143        .unwrap();
144        writeln!(code, "    #[derive(Debug, Clone)]").unwrap();
145        writeln!(code, "    pub struct {} {{", struct_name).unwrap();
146
147        for field_ref in &msg.fields {
148            if let Some(field_def) = dict.get_field(field_ref.tag) {
149                let field_name = to_snake_case(&field_ref.name);
150                let rust_type = field_type_to_rust(&field_def.field_type);
151
152                if field_ref.required {
153                    writeln!(code, "        pub {}: {},", field_name, rust_type).unwrap();
154                } else {
155                    writeln!(code, "        pub {}: Option<{}>,", field_name, rust_type).unwrap();
156                }
157            }
158        }
159
160        writeln!(code, "    }}").unwrap();
161        writeln!(code).unwrap();
162    }
163}
164
165impl Default for CodeGenerator {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
/// Converts a string to SCREAMING_SNAKE_CASE.
///
/// Examples:
/// * `"ClOrdID"` becomes `"CL_ORD_ID"`.
/// * `"SecurityIDSource"` becomes `"SECURITY_ID_SOURCE"`.
fn to_screaming_snake_case(s: &str) -> String {
    recase_with_underscores(s, char::to_ascii_uppercase)
}

/// Converts a string to snake_case.
///
/// Examples:
/// * `"ClOrdID"` becomes `"cl_ord_id"`.
/// * `"MDEntryType"` becomes `"md_entry_type"`.
fn to_snake_case(s: &str) -> String {
    recase_with_underscores(s, char::to_ascii_lowercase)
}

/// Re-cases every character of `s` with `recase` and inserts `_` at each camel-case word
/// boundary (see [`is_word_boundary`]).
///
/// Non-letter characters (digits, existing underscores) are kept.
fn recase_with_underscores(s: &str, recase: fn(&char) -> char) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, c) in chars.iter().enumerate() {
        if is_word_boundary(&chars, i) {
            result.push('_');
        }
        result.push(recase(c));
    }

    result
}

/// Returns `true` when a `_` separator belongs before `chars[i]`.
///
/// A new word starts at an uppercase letter that:
/// * follows a lowercase letter or a digit (`MsgType`, `Leg2Side`), or
/// * ends an acronym: it follows an uppercase letter and is itself followed by a
///   lowercase letter (`IDSource` → `ID` + `Source`). A lone plural `s` after an acronym
///   is not treated as a new word (`PartyIDs` → `Party` + `IDs`).
fn is_word_boundary(chars: &[char], i: usize) -> bool {
    if i == 0 || !chars[i].is_uppercase() {
        return false;
    }

    let prev = chars[i - 1];
    if prev.is_lowercase() || prev.is_ascii_digit() {
        return true;
    }
    if !prev.is_uppercase() {
        return false;
    }

    // `prev` is uppercase: split only where the acronym hands over to a lowercase word.
    match (chars.get(i + 1).copied(), chars.get(i + 2).copied()) {
        (Some('s'), after) => matches!(after, Some(a) if a.is_lowercase()),
        (Some(next), _) => next.is_lowercase(),
        (None, _) => false,
    }
}
202
/// Converts a string to PascalCase.
///
/// Underscores and spaces separate words and are dropped. The first character of each
/// word is ASCII-uppercased; the remaining characters are copied unchanged.
fn to_pascal_case(s: &str) -> String {
    s.split(|c: char| c == '_' || c == ' ')
        .flat_map(|word| {
            let mut letters = word.chars();
            letters
                .next()
                .map(|first| first.to_ascii_uppercase())
                .into_iter()
                .chain(letters)
        })
        .collect()
}
221
222/// Maps FIX field types to Rust types.
223fn field_type_to_rust(field_type: &FieldType) -> &'static str {
224    match field_type {
225        FieldType::Int
226        | FieldType::Length
227        | FieldType::SeqNum
228        | FieldType::NumInGroup
229        | FieldType::TagNum
230        | FieldType::DayOfMonth => "i64",
231        FieldType::Float
232        | FieldType::Qty
233        | FieldType::Price
234        | FieldType::PriceOffset
235        | FieldType::Amt
236        | FieldType::Percentage => "rust_decimal::Decimal",
237        FieldType::Char => "char",
238        FieldType::Boolean => "bool",
239        FieldType::String
240        | FieldType::MultipleCharValue
241        | FieldType::MultipleStringValue
242        | FieldType::Country
243        | FieldType::Currency
244        | FieldType::Exchange
245        | FieldType::Language
246        | FieldType::Pattern
247        | FieldType::Tenor => "String",
248        FieldType::MonthYear
249        | FieldType::UtcTimestamp
250        | FieldType::UtcTimeOnly
251        | FieldType::UtcDateOnly
252        | FieldType::LocalMktDate
253        | FieldType::LocalMktTime
254        | FieldType::TzTimeOnly
255        | FieldType::TzTimestamp => "String",
256        FieldType::Data | FieldType::XmlData => "Vec<u8>",
257        FieldType::Reserved => "String",
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `convert(input) == expected` for every `(input, expected)` pair.
    fn check(convert: fn(&str) -> String, cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(convert(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_to_screaming_snake_case() {
        check(
            to_screaming_snake_case,
            &[
                ("MsgType", "MSG_TYPE"),
                ("ClOrdID", "CL_ORD_ID"),
                ("BeginString", "BEGIN_STRING"),
            ],
        );
    }

    #[test]
    fn test_to_snake_case() {
        check(
            to_snake_case,
            &[("MsgType", "msg_type"), ("ClOrdID", "cl_ord_id")],
        );
    }

    #[test]
    fn test_to_pascal_case() {
        check(
            to_pascal_case,
            &[
                ("new_order_single", "NewOrderSingle"),
                ("execution_report", "ExecutionReport"),
            ],
        );
    }

    #[test]
    fn test_generator_new() {
        let CodeGenerator { config } = CodeGenerator::new();
        assert!(config.generate_fields);
        assert!(config.generate_messages);
    }
}