Skip to main content

ironsbe_schema/
ir.rs

1//! Intermediate representation for code generation.
2//!
3//! This module provides a flattened, resolved representation of the schema
4//! that is easier to use for code generation.
5
6use crate::types::{PrimitiveType, Schema, TypeDef};
7use std::collections::HashMap;
8
9/// Intermediate representation of a schema for code generation.
10#[derive(Debug, Clone)]
11pub struct SchemaIr {
12    /// Package name.
13    pub package: String,
14    /// Schema ID.
15    pub schema_id: u16,
16    /// Schema version.
17    pub schema_version: u16,
18    /// Resolved types with their full information.
19    pub types: HashMap<String, ResolvedType>,
20    /// Messages with resolved field types.
21    pub messages: Vec<ResolvedMessage>,
22}
23
24impl SchemaIr {
25    /// Creates an intermediate representation from a parsed schema.
26    #[must_use]
27    pub fn from_schema(schema: &Schema) -> Self {
28        let mut ir = Self {
29            package: schema.package.clone(),
30            schema_id: schema.id,
31            schema_version: schema.version,
32            types: HashMap::new(),
33            messages: Vec::new(),
34        };
35
36        // Resolve types
37        for type_def in &schema.types {
38            let resolved = ResolvedType::from_type_def(type_def);
39            ir.types.insert(resolved.name.clone(), resolved);
40        }
41
42        // Resolve messages
43        for msg in &schema.messages {
44            ir.messages
45                .push(ResolvedMessage::from_message_def(msg, &ir.types));
46        }
47
48        ir
49    }
50
51    /// Gets a resolved type by name.
52    #[must_use]
53    pub fn get_type(&self, name: &str) -> Option<&ResolvedType> {
54        self.types.get(name)
55    }
56}
57
58/// Resolved type information.
59#[derive(Debug, Clone)]
60pub struct ResolvedType {
61    /// Type name.
62    pub name: String,
63    /// Type kind.
64    pub kind: TypeKind,
65    /// Encoded length in bytes.
66    pub encoded_length: usize,
67    /// Rust type representation.
68    pub rust_type: String,
69    /// Whether this is an array type.
70    pub is_array: bool,
71    /// Array length (if array).
72    pub array_length: Option<usize>,
73}
74
75impl ResolvedType {
76    /// Creates a resolved type from a type definition.
77    #[must_use]
78    pub fn from_type_def(type_def: &TypeDef) -> Self {
79        match type_def {
80            TypeDef::Primitive(p) => Self {
81                name: p.name.clone(),
82                kind: TypeKind::Primitive(p.primitive_type),
83                encoded_length: p.encoded_length(),
84                rust_type: if p.is_array() {
85                    format!(
86                        "[{}; {}]",
87                        p.primitive_type.rust_type(),
88                        p.length.unwrap_or(1)
89                    )
90                } else {
91                    p.primitive_type.rust_type().to_string()
92                },
93                is_array: p.is_array(),
94                array_length: p.length,
95            },
96            TypeDef::Composite(c) => Self {
97                name: c.name.clone(),
98                kind: TypeKind::Composite,
99                encoded_length: c.encoded_length(),
100                rust_type: to_pascal_case(&c.name),
101                is_array: false,
102                array_length: None,
103            },
104            TypeDef::Enum(e) => Self {
105                name: e.name.clone(),
106                kind: TypeKind::Enum(e.encoding_type),
107                encoded_length: e.encoding_type.size(),
108                rust_type: to_pascal_case(&e.name),
109                is_array: false,
110                array_length: None,
111            },
112            TypeDef::Set(s) => Self {
113                name: s.name.clone(),
114                kind: TypeKind::Set(s.encoding_type),
115                encoded_length: s.encoding_type.size(),
116                rust_type: to_pascal_case(&s.name),
117                is_array: false,
118                array_length: None,
119            },
120        }
121    }
122
123    /// Creates a resolved type for a built-in primitive.
124    #[must_use]
125    pub fn from_primitive(prim: PrimitiveType) -> Self {
126        Self {
127            name: prim.sbe_name().to_string(),
128            kind: TypeKind::Primitive(prim),
129            encoded_length: prim.size(),
130            rust_type: prim.rust_type().to_string(),
131            is_array: false,
132            array_length: None,
133        }
134    }
135}
136
137/// Type kind enumeration.
138#[derive(Debug, Clone, Copy)]
139pub enum TypeKind {
140    /// Primitive type.
141    Primitive(PrimitiveType),
142    /// Composite type.
143    Composite,
144    /// Enum type with encoding.
145    Enum(PrimitiveType),
146    /// Set (bitfield) type with encoding.
147    Set(PrimitiveType),
148}
149
150/// Resolved message information.
151#[derive(Debug, Clone)]
152pub struct ResolvedMessage {
153    /// Message name.
154    pub name: String,
155    /// Template ID.
156    pub template_id: u16,
157    /// Block length.
158    pub block_length: u16,
159    /// Resolved fields.
160    pub fields: Vec<ResolvedField>,
161    /// Resolved groups.
162    pub groups: Vec<ResolvedGroup>,
163    /// Variable data fields.
164    pub var_data: Vec<ResolvedVarData>,
165}
166
167impl ResolvedMessage {
168    /// Creates a resolved message from a message definition.
169    #[must_use]
170    pub fn from_message_def(
171        msg: &crate::messages::MessageDef,
172        types: &HashMap<String, ResolvedType>,
173    ) -> Self {
174        let fields = msg
175            .fields
176            .iter()
177            .map(|f| ResolvedField::from_field_def(f, types))
178            .collect();
179
180        let groups = msg
181            .groups
182            .iter()
183            .map(|g| ResolvedGroup::from_group_def(g, types))
184            .collect();
185
186        let var_data = msg
187            .data_fields
188            .iter()
189            .map(|d| ResolvedVarData {
190                name: d.name.clone(),
191                id: d.id,
192                type_name: d.type_name.clone(),
193            })
194            .collect();
195
196        Self {
197            name: msg.name.clone(),
198            template_id: msg.id,
199            block_length: msg.block_length,
200            fields,
201            groups,
202            var_data,
203        }
204    }
205
206    /// Returns the decoder struct name.
207    #[must_use]
208    pub fn decoder_name(&self) -> String {
209        format!("{}Decoder", self.name)
210    }
211
212    /// Returns the encoder struct name.
213    #[must_use]
214    pub fn encoder_name(&self) -> String {
215        format!("{}Encoder", self.name)
216    }
217}
218
219/// Resolved field information.
220#[derive(Debug, Clone)]
221pub struct ResolvedField {
222    /// Field name.
223    pub name: String,
224    /// Field ID.
225    pub id: u16,
226    /// Type name.
227    pub type_name: String,
228    /// Offset in bytes.
229    pub offset: usize,
230    /// Encoded length in bytes.
231    pub encoded_length: usize,
232    /// Rust type.
233    pub rust_type: String,
234    /// Getter method name.
235    pub getter_name: String,
236    /// Setter method name.
237    pub setter_name: String,
238    /// Whether the field is optional.
239    pub is_optional: bool,
240    /// Whether the field is an array.
241    pub is_array: bool,
242    /// Array length (if array).
243    pub array_length: Option<usize>,
244    /// Primitive type (if applicable).
245    pub primitive_type: Option<PrimitiveType>,
246}
247
248impl ResolvedField {
249    /// Creates a resolved field from a field definition.
250    #[must_use]
251    pub fn from_field_def(
252        field: &crate::messages::FieldDef,
253        types: &HashMap<String, ResolvedType>,
254    ) -> Self {
255        let resolved_type = types.get(&field.type_name).cloned().or_else(|| {
256            PrimitiveType::from_sbe_name(&field.type_name).map(ResolvedType::from_primitive)
257        });
258
259        let (encoded_length, rust_type, is_array, array_length, primitive_type) =
260            if let Some(rt) = &resolved_type {
261                (
262                    rt.encoded_length,
263                    rt.rust_type.clone(),
264                    rt.is_array,
265                    rt.array_length,
266                    match rt.kind {
267                        TypeKind::Primitive(p) => Some(p),
268                        _ => None,
269                    },
270                )
271            } else {
272                (field.encoded_length, "u64".to_string(), false, None, None)
273            };
274
275        Self {
276            name: field.name.clone(),
277            id: field.id,
278            type_name: field.type_name.clone(),
279            offset: field.offset,
280            encoded_length,
281            rust_type,
282            getter_name: to_snake_case(&field.name),
283            setter_name: format!("set_{}", to_snake_case(&field.name)),
284            is_optional: field.is_optional(),
285            is_array,
286            array_length,
287            primitive_type,
288        }
289    }
290}
291
292/// Resolved group information.
293#[derive(Debug, Clone)]
294pub struct ResolvedGroup {
295    /// Group name.
296    pub name: String,
297    /// Group ID.
298    pub id: u16,
299    /// Block length per entry.
300    pub block_length: u16,
301    /// Resolved fields.
302    pub fields: Vec<ResolvedField>,
303    /// Nested groups.
304    pub nested_groups: Vec<ResolvedGroup>,
305    /// Variable data fields.
306    pub var_data: Vec<ResolvedVarData>,
307}
308
309impl ResolvedGroup {
310    /// Creates a resolved group from a group definition.
311    #[must_use]
312    pub fn from_group_def(
313        group: &crate::messages::GroupDef,
314        types: &HashMap<String, ResolvedType>,
315    ) -> Self {
316        let fields = group
317            .fields
318            .iter()
319            .map(|f| ResolvedField::from_field_def(f, types))
320            .collect();
321
322        let nested_groups = group
323            .nested_groups
324            .iter()
325            .map(|g| ResolvedGroup::from_group_def(g, types))
326            .collect();
327
328        let var_data = group
329            .data_fields
330            .iter()
331            .map(|d| ResolvedVarData {
332                name: d.name.clone(),
333                id: d.id,
334                type_name: d.type_name.clone(),
335            })
336            .collect();
337
338        Self {
339            name: group.name.clone(),
340            id: group.id,
341            block_length: group.block_length,
342            fields,
343            nested_groups,
344            var_data,
345        }
346    }
347
348    /// Returns the decoder struct name.
349    #[must_use]
350    pub fn decoder_name(&self) -> String {
351        format!("{}GroupDecoder", to_pascal_case(&self.name))
352    }
353
354    /// Returns the entry decoder struct name.
355    #[must_use]
356    pub fn entry_decoder_name(&self) -> String {
357        format!("{}EntryDecoder", to_pascal_case(&self.name))
358    }
359}
360
/// Resolved variable-length data field of a message or group.
///
/// Holds only plain data, so it derives the usual comparison and hashing
/// traits and can be compared in tests or used as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ResolvedVarData {
    /// Field name.
    pub name: String,
    /// Field ID.
    pub id: u16,
    /// Name of the type used to encode the data.
    pub type_name: String,
}
371
/// Converts an identifier to snake_case.
///
/// An underscore is inserted before every uppercase character except the
/// first, and uppercase characters are lowercased, so `"clOrdId"` becomes
/// `"cl_ord_id"`. Runs of capitals are split letter by letter (`"MDEntryPx"`
/// becomes `"m_d_entry_px"`).
///
/// Hyphens are treated as word separators and become underscores, matching
/// [`to_pascal_case`]. No extra underscore is inserted when the output
/// already ends with one, so `"foo_Bar"` yields `"foo_bar"` rather than
/// `"foo__bar"`.
#[must_use]
pub fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    for (i, c) in s.chars().enumerate() {
        if c == '-' {
            result.push('_');
        } else if c.is_uppercase() {
            if i > 0 && !result.ends_with('_') {
                result.push('_');
            }
            // Unicode-aware, to match the Unicode-aware `is_uppercase` check.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
384
/// Converts an identifier to PascalCase.
///
/// `_` and `-` act as word separators and are dropped. The first character
/// of each word is uppercased (ASCII only); all other characters are copied
/// unchanged.
#[must_use]
pub fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split(|c: char| c == '_' || c == '-') {
        let mut chars = word.chars();
        // Empty words come from leading, trailing or repeated separators.
        if let Some(first) = chars.next() {
            out.push(first.to_ascii_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_schema;

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("clOrdId"), "cl_ord_id");
        assert_eq!(to_snake_case("symbol"), "symbol");
        assert_eq!(to_snake_case("MDEntryPx"), "m_d_entry_px");
        assert_eq!(to_snake_case("MessageHeader"), "message_header");
        assert_eq!(to_snake_case(""), "");
    }

    #[test]
    fn test_to_pascal_case() {
        assert_eq!(to_pascal_case("message_header"), "MessageHeader");
        assert_eq!(to_pascal_case("side"), "Side");
        assert_eq!(to_pascal_case("order-type"), "OrderType");
        assert_eq!(to_pascal_case("Side"), "Side");
        assert_eq!(to_pascal_case(""), "");
    }

    #[test]
    fn test_schema_ir_from_schema() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <type name="uint64" primitiveType="uint64"/>
    </types>
    <sbe:message name="Test" id="1" blockLength="8">
        <field name="value" id="1" type="uint64" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#;

        let schema = parse_schema(xml).expect("Failed to parse");
        let ir = SchemaIr::from_schema(&schema);

        assert_eq!(ir.package, "test");
        assert_eq!(ir.schema_id, 1);
        assert_eq!(ir.schema_version, 1);
        assert!(!ir.messages.is_empty());

        let msg = &ir.messages[0];
        assert_eq!(msg.name, "Test");
        assert_eq!(msg.template_id, 1);
        assert_eq!(msg.decoder_name(), "TestDecoder");
        assert_eq!(msg.encoder_name(), "TestEncoder");

        let field = &msg.fields[0];
        assert_eq!(field.name, "value");
        assert_eq!(field.getter_name, "value");
        assert_eq!(field.setter_name, "set_value");
        assert!(matches!(field.primitive_type, Some(PrimitiveType::Uint64)));

        assert!(ir.get_type("uint64").is_some());
        assert!(ir.get_type("doesNotExist").is_none());
    }

    #[test]
    fn test_resolved_type_from_primitive() {
        let resolved = ResolvedType::from_primitive(PrimitiveType::Uint64);
        assert_eq!(resolved.name, "uint64");
        assert_eq!(resolved.encoded_length, 8);
        assert_eq!(resolved.rust_type, "u64");
        assert!(!resolved.is_array);
        assert!(resolved.array_length.is_none());
        assert!(matches!(
            resolved.kind,
            TypeKind::Primitive(PrimitiveType::Uint64)
        ));
    }

    #[test]
    fn test_type_kind_debug() {
        let kind = TypeKind::Primitive(PrimitiveType::Int32);
        let debug_str = format!("{:?}", kind);
        assert!(debug_str.contains("Primitive"));

        let kind = TypeKind::Composite;
        let debug_str = format!("{:?}", kind);
        assert!(debug_str.contains("Composite"));

        let kind = TypeKind::Enum(PrimitiveType::Uint8);
        let debug_str = format!("{:?}", kind);
        assert!(debug_str.contains("Enum"));

        let kind = TypeKind::Set(PrimitiveType::Uint16);
        let debug_str = format!("{:?}", kind);
        assert!(debug_str.contains("Set"));
    }

    #[test]
    fn test_resolved_type_clone() {
        let resolved = ResolvedType::from_primitive(PrimitiveType::Float);
        let cloned = resolved.clone();
        assert_eq!(resolved.name, cloned.name);
        assert_eq!(resolved.encoded_length, cloned.encoded_length);
    }

    #[test]
    fn test_schema_ir_with_enum() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <enum name="Side" encodingType="uint8">
            <validValue name="Buy">1</validValue>
            <validValue name="Sell">2</validValue>
        </enum>
    </types>
    <sbe:message name="Test" id="1" blockLength="1">
        <field name="side" id="1" type="Side" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#;

        let schema = parse_schema(xml).expect("Failed to parse");
        let ir = SchemaIr::from_schema(&schema);

        assert!(ir.types.contains_key("Side"));
        let side = ir.get_type("Side").expect("Side should be resolved");
        assert!(matches!(side.kind, TypeKind::Enum(PrimitiveType::Uint8)));
        assert_eq!(side.rust_type, "Side");
        assert!(!side.is_array);

        // A field of enum type takes the enum's Rust type and is not a primitive.
        let field = &ir.messages[0].fields[0];
        assert_eq!(field.rust_type, "Side");
        assert_eq!(field.getter_name, "side");
        assert_eq!(field.setter_name, "set_side");
        assert!(field.primitive_type.is_none());
    }

    #[test]
    fn test_schema_ir_with_composite() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <composite name="Decimal">
            <type name="mantissa" primitiveType="int64"/>
            <type name="exponent" primitiveType="int8"/>
        </composite>
    </types>
    <sbe:message name="Test" id="1" blockLength="9">
        <field name="price" id="1" type="Decimal" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#;

        let schema = parse_schema(xml).expect("Failed to parse");
        let ir = SchemaIr::from_schema(&schema);

        assert!(ir.types.contains_key("Decimal"));
        let decimal = ir.get_type("Decimal").expect("Decimal should be resolved");
        assert!(matches!(decimal.kind, TypeKind::Composite));
        assert_eq!(decimal.rust_type, "Decimal");
        assert!(!decimal.is_array);

        let field = &ir.messages[0].fields[0];
        assert_eq!(field.rust_type, "Decimal");
        assert!(field.primitive_type.is_none());
    }
}
525}