ironfix_codegen/
generator.rs1use ironfix_dictionary::schema::{Dictionary, FieldDef, FieldType, MessageDef};
12use std::fmt::Write;
13
/// Options controlling which parts of a FIX dictionary the code generator emits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratorConfig {
    /// Emit a `fields` module with one `u32` tag constant per dictionary field.
    pub generate_fields: bool,
    /// Emit a `messages` module with one struct per message definition.
    pub generate_messages: bool,
    /// Component generation flag.
    ///
    /// NOTE(review): the generator in this file does not currently read this flag;
    /// no component code is emitted regardless of its value.
    pub generate_components: bool,
    /// Visibility text written verbatim before each generated `mod` (e.g. `"pub"`).
    pub visibility: String,
}
26
27impl Default for GeneratorConfig {
28 fn default() -> Self {
29 Self {
30 generate_fields: true,
31 generate_messages: true,
32 generate_components: true,
33 visibility: "pub".to_string(),
34 }
35 }
36}
37
38#[derive(Debug)]
40pub struct CodeGenerator {
41 config: GeneratorConfig,
42}
43
44impl CodeGenerator {
45 #[must_use]
47 pub fn new() -> Self {
48 Self {
49 config: GeneratorConfig::default(),
50 }
51 }
52
53 #[must_use]
55 pub fn with_config(config: GeneratorConfig) -> Self {
56 Self { config }
57 }
58
59 #[must_use]
67 pub fn generate(&self, dict: &Dictionary) -> String {
68 let mut code = String::new();
69
70 writeln!(code, "//! Generated FIX {} definitions.", dict.version).unwrap();
72 writeln!(code, "//!").unwrap();
73 writeln!(
74 code,
75 "//! This file was automatically generated. Do not edit."
76 )
77 .unwrap();
78 writeln!(code).unwrap();
79
80 if self.config.generate_fields {
81 self.generate_fields_module(&mut code, dict);
82 }
83
84 if self.config.generate_messages {
85 self.generate_messages_module(&mut code, dict);
86 }
87
88 code
89 }
90
91 fn generate_fields_module(&self, code: &mut String, dict: &Dictionary) {
93 writeln!(code, "/// Field tag constants.").unwrap();
94 writeln!(code, "{} mod fields {{", self.config.visibility).unwrap();
95
96 let mut fields: Vec<_> = dict.fields().collect();
97 fields.sort_by_key(|f| f.tag);
98
99 for field in fields {
100 self.generate_field_constant(code, field);
101 }
102
103 writeln!(code, "}}").unwrap();
104 writeln!(code).unwrap();
105 }
106
107 fn generate_field_constant(&self, code: &mut String, field: &FieldDef) {
109 let const_name = to_screaming_snake_case(&field.name);
110
111 if let Some(ref desc) = field.description {
112 writeln!(code, " /// {}", desc).unwrap();
113 }
114 writeln!(code, " pub const {}: u32 = {};", const_name, field.tag).unwrap();
115 }
116
117 fn generate_messages_module(&self, code: &mut String, dict: &Dictionary) {
119 writeln!(code, "/// Message type definitions.").unwrap();
120 writeln!(code, "{} mod messages {{", self.config.visibility).unwrap();
121 writeln!(code, " use super::fields;").unwrap();
122 writeln!(code).unwrap();
123
124 let mut messages: Vec<_> = dict.messages().collect();
125 messages.sort_by(|a, b| a.msg_type.cmp(&b.msg_type));
126
127 for msg in messages {
128 self.generate_message_struct(code, msg, dict);
129 }
130
131 writeln!(code, "}}").unwrap();
132 }
133
134 fn generate_message_struct(&self, code: &mut String, msg: &MessageDef, dict: &Dictionary) {
136 let struct_name = to_pascal_case(&msg.name);
137
138 writeln!(
139 code,
140 " /// {} message (MsgType={}).",
141 msg.name, msg.msg_type
142 )
143 .unwrap();
144 writeln!(code, " #[derive(Debug, Clone)]").unwrap();
145 writeln!(code, " pub struct {} {{", struct_name).unwrap();
146
147 for field_ref in &msg.fields {
148 if let Some(field_def) = dict.get_field(field_ref.tag) {
149 let field_name = to_snake_case(&field_ref.name);
150 let rust_type = field_type_to_rust(&field_def.field_type);
151
152 if field_ref.required {
153 writeln!(code, " pub {}: {},", field_name, rust_type).unwrap();
154 } else {
155 writeln!(code, " pub {}: Option<{}>,", field_name, rust_type).unwrap();
156 }
157 }
158 }
159
160 writeln!(code, " }}").unwrap();
161 writeln!(code).unwrap();
162 }
163}
164
165impl Default for CodeGenerator {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
/// Converts a PascalCase FIX name to `SCREAMING_SNAKE_CASE`.
///
/// An underscore is inserted before an uppercase character in two cases:
/// - it follows a lowercase character (`MsgType` -> `MSG_TYPE`);
/// - it ends a run of uppercase characters and starts a new word
///   (`SecurityIDSource` -> `SECURITY_ID_SOURCE`, `NoMDEntries` -> `NO_MD_ENTRIES`).
///
/// A single lowercase `s` that closes an acronym is treated as a plural, not a new
/// word (`NoPartyIDs` -> `NO_PARTY_IDS`). Characters are uppercased with ASCII rules.
fn to_screaming_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if i > 0 && c.is_uppercase() {
            let prev = chars[i - 1];
            let next = chars.get(i + 1).copied();
            let after_next = chars.get(i + 2).copied();
            // `IDs` at a word end pluralises the acronym; don't split it into `I_DS`.
            let plural_s =
                next == Some('s') && !matches!(after_next, Some(a) if a.is_lowercase());
            let ends_acronym = prev.is_uppercase()
                && matches!(next, Some(n) if n.is_lowercase())
                && !plural_s;
            if prev.is_lowercase() || ends_acronym {
                result.push('_');
            }
        }
        result.push(c.to_ascii_uppercase());
    }

    result
}
186
/// Converts a PascalCase FIX name to `snake_case`.
///
/// An underscore is inserted before an uppercase character in two cases:
/// - it follows a lowercase character (`MsgType` -> `msg_type`);
/// - it ends a run of uppercase characters and starts a new word
///   (`SecurityIDSource` -> `security_id_source`, `NoMDEntries` -> `no_md_entries`).
///
/// A single lowercase `s` that closes an acronym is treated as a plural, not a new
/// word (`NoPartyIDs` -> `no_party_ids`). Characters are lowercased with ASCII rules.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if i > 0 && c.is_uppercase() {
            let prev = chars[i - 1];
            let next = chars.get(i + 1).copied();
            let after_next = chars.get(i + 2).copied();
            // `IDs` at a word end pluralises the acronym; don't split it into `i_ds`.
            let plural_s =
                next == Some('s') && !matches!(after_next, Some(a) if a.is_lowercase());
            let ends_acronym = prev.is_uppercase()
                && matches!(next, Some(n) if n.is_lowercase())
                && !plural_s;
            if prev.is_lowercase() || ends_acronym {
                result.push('_');
            }
        }
        result.push(c.to_ascii_lowercase());
    }

    result
}
202
/// Converts a `snake_case` or space-separated name to `PascalCase`.
///
/// Underscores and spaces act as word separators and are dropped. The first
/// character of each word is uppercased with ASCII rules and the rest is copied
/// unchanged, so a name that is already PascalCase passes through intact.
fn to_pascal_case(s: &str) -> String {
    s.split(|ch: char| ch == '_' || ch == ' ')
        .flat_map(|word| {
            let mut rest = word.chars();
            rest.next()
                .map(|head| head.to_ascii_uppercase())
                .into_iter()
                .chain(rest)
        })
        .collect()
}
221
222fn field_type_to_rust(field_type: &FieldType) -> &'static str {
224 match field_type {
225 FieldType::Int
226 | FieldType::Length
227 | FieldType::SeqNum
228 | FieldType::NumInGroup
229 | FieldType::TagNum
230 | FieldType::DayOfMonth => "i64",
231 FieldType::Float
232 | FieldType::Qty
233 | FieldType::Price
234 | FieldType::PriceOffset
235 | FieldType::Amt
236 | FieldType::Percentage => "rust_decimal::Decimal",
237 FieldType::Char => "char",
238 FieldType::Boolean => "bool",
239 FieldType::String
240 | FieldType::MultipleCharValue
241 | FieldType::MultipleStringValue
242 | FieldType::Country
243 | FieldType::Currency
244 | FieldType::Exchange
245 | FieldType::Language
246 | FieldType::Pattern
247 | FieldType::Tenor => "String",
248 FieldType::MonthYear
249 | FieldType::UtcTimestamp
250 | FieldType::UtcTimeOnly
251 | FieldType::UtcDateOnly
252 | FieldType::LocalMktDate
253 | FieldType::LocalMktTime
254 | FieldType::TzTimeOnly
255 | FieldType::TzTimestamp => "String",
256 FieldType::Data | FieldType::XmlData => "Vec<u8>",
257 FieldType::Reserved => "String",
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_screaming_snake_case() {
        let cases = [
            ("MsgType", "MSG_TYPE"),
            ("ClOrdID", "CL_ORD_ID"),
            ("BeginString", "BEGIN_STRING"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(to_screaming_snake_case(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [("MsgType", "msg_type"), ("ClOrdID", "cl_ord_id")];
        for &(input, expected) in &cases {
            assert_eq!(to_snake_case(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("new_order_single", "NewOrderSingle"),
            ("execution_report", "ExecutionReport"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(to_pascal_case(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_generator_new() {
        let config = CodeGenerator::new().config;
        assert!(config.generate_fields);
        assert!(config.generate_messages);
    }
}