Skip to main content

ironsbe_schema/
ir.rs

1//! Intermediate representation for code generation.
2//!
3//! This module provides a flattened, resolved representation of the schema
4//! that is easier to use for code generation.
5
6use crate::types::{PrimitiveType, Schema, TypeDef};
7use std::collections::HashMap;
8
9/// Intermediate representation of a schema for code generation.
10#[derive(Debug, Clone)]
11pub struct SchemaIr {
12    /// Package name.
13    pub package: String,
14    /// Schema ID.
15    pub schema_id: u16,
16    /// Schema version.
17    pub schema_version: u16,
18    /// Resolved types with their full information.
19    pub types: HashMap<String, ResolvedType>,
20    /// Messages with resolved field types.
21    pub messages: Vec<ResolvedMessage>,
22}
23
24impl SchemaIr {
25    /// Creates an intermediate representation from a parsed schema.
26    #[must_use]
27    pub fn from_schema(schema: &Schema) -> Self {
28        let mut ir = Self {
29            package: schema.package.clone(),
30            schema_id: schema.id,
31            schema_version: schema.version,
32            types: HashMap::new(),
33            messages: Vec::new(),
34        };
35
36        // Resolve types
37        for type_def in &schema.types {
38            let resolved = ResolvedType::from_type_def(type_def);
39            ir.types.insert(resolved.name.clone(), resolved);
40        }
41
42        // Resolve messages
43        for msg in &schema.messages {
44            ir.messages
45                .push(ResolvedMessage::from_message_def(msg, &ir.types));
46        }
47
48        ir
49    }
50
51    /// Gets a resolved type by name.
52    #[must_use]
53    pub fn get_type(&self, name: &str) -> Option<&ResolvedType> {
54        self.types.get(name)
55    }
56}
57
58/// Resolved type information.
59#[derive(Debug, Clone)]
60pub struct ResolvedType {
61    /// Type name.
62    pub name: String,
63    /// Type kind.
64    pub kind: TypeKind,
65    /// Encoded length in bytes.
66    pub encoded_length: usize,
67    /// Rust type representation.
68    pub rust_type: String,
69    /// Whether this is an array type.
70    pub is_array: bool,
71    /// Array length (if array).
72    pub array_length: Option<usize>,
73}
74
75impl ResolvedType {
76    /// Creates a resolved type from a type definition.
77    #[must_use]
78    pub fn from_type_def(type_def: &TypeDef) -> Self {
79        match type_def {
80            TypeDef::Primitive(p) => Self {
81                name: p.name.clone(),
82                kind: TypeKind::Primitive(p.primitive_type),
83                encoded_length: p.encoded_length(),
84                rust_type: if p.is_array() {
85                    format!(
86                        "[{}; {}]",
87                        p.primitive_type.rust_type(),
88                        p.length.unwrap_or(1)
89                    )
90                } else {
91                    p.primitive_type.rust_type().to_string()
92                },
93                is_array: p.is_array(),
94                array_length: p.length,
95            },
96            TypeDef::Composite(c) => {
97                let mut offset = 0usize;
98                let fields = c
99                    .fields
100                    .iter()
101                    .filter_map(|f| {
102                        let field_offset = offset;
103                        offset += f.encoded_length;
104                        f.primitive_type.map(|prim| CompositeFieldInfo {
105                            name: f.name.clone(),
106                            primitive_type: prim,
107                            offset: field_offset,
108                            encoded_length: f.encoded_length,
109                        })
110                    })
111                    .collect();
112                Self {
113                    name: c.name.clone(),
114                    kind: TypeKind::Composite { fields },
115                    encoded_length: c.encoded_length(),
116                    rust_type: to_pascal_case(&c.name),
117                    is_array: false,
118                    array_length: None,
119                }
120            }
121            TypeDef::Enum(e) => {
122                let variants: Vec<EnumVariant> = e
123                    .valid_values
124                    .iter()
125                    .filter_map(|v| {
126                        // Use signed parsing for signed types, unsigned for others
127                        let value = if e.encoding_type.is_signed() {
128                            v.as_i64()
129                        } else {
130                            v.as_u64().map(|u| u as i64)
131                        };
132                        value.map(|val| EnumVariant {
133                            name: v.name.clone(),
134                            value: val,
135                        })
136                    })
137                    .collect();
138                Self {
139                    name: e.name.clone(),
140                    kind: TypeKind::Enum {
141                        encoding: e.encoding_type,
142                        variants,
143                    },
144                    encoded_length: e.encoding_type.size(),
145                    rust_type: to_pascal_case(&e.name),
146                    is_array: false,
147                    array_length: None,
148                }
149            }
150            TypeDef::Set(s) => {
151                let choices = s
152                    .choices
153                    .iter()
154                    .map(|c| SetVariant {
155                        name: c.name.clone(),
156                        bit_position: c.bit_position,
157                    })
158                    .collect();
159                Self {
160                    name: s.name.clone(),
161                    kind: TypeKind::Set {
162                        encoding: s.encoding_type,
163                        choices,
164                    },
165                    encoded_length: s.encoding_type.size(),
166                    rust_type: to_pascal_case(&s.name),
167                    is_array: false,
168                    array_length: None,
169                }
170            }
171        }
172    }
173
174    /// Creates a resolved type for a built-in primitive.
175    #[must_use]
176    pub fn from_primitive(prim: PrimitiveType) -> Self {
177        Self {
178            name: prim.sbe_name().to_string(),
179            kind: TypeKind::Primitive(prim),
180            encoded_length: prim.size(),
181            rust_type: prim.rust_type().to_string(),
182            is_array: false,
183            array_length: None,
184        }
185    }
186}
187
/// One variant of a resolved enum: its name and discriminant.
#[derive(Debug, Clone)]
pub struct EnumVariant {
    /// Name of the variant (will be converted to PascalCase).
    pub name: String,
    /// Discriminant, held as `i64` so signed encodings are supported.
    pub value: i64,
}
196
/// One choice of a resolved set (bitfield): its name and bit position.
#[derive(Debug, Clone)]
pub struct SetVariant {
    /// Name of the choice (will be converted to PascalCase).
    pub name: String,
    /// Zero-based position of the bit.
    pub bit_position: u8,
}
205
206/// Composite field information for code generation.
207#[derive(Debug, Clone)]
208pub struct CompositeFieldInfo {
209    /// Field name.
210    pub name: String,
211    /// Primitive type of this field.
212    pub primitive_type: PrimitiveType,
213    /// Offset within the composite.
214    pub offset: usize,
215    /// Encoded length in bytes.
216    pub encoded_length: usize,
217}
218
219/// Type kind enumeration.
220#[derive(Debug, Clone)]
221pub enum TypeKind {
222    /// Primitive type.
223    Primitive(PrimitiveType),
224    /// Composite type with fields.
225    Composite {
226        /// Fields in the composite.
227        fields: Vec<CompositeFieldInfo>,
228    },
229    /// Enum type with encoding and variants.
230    Enum {
231        /// Underlying encoding type.
232        encoding: PrimitiveType,
233        /// Enum variants with discriminant values.
234        variants: Vec<EnumVariant>,
235    },
236    /// Set (bitfield) type with encoding and choices.
237    Set {
238        /// Underlying encoding type.
239        encoding: PrimitiveType,
240        /// Bit choices.
241        choices: Vec<SetVariant>,
242    },
243}
244
245/// Resolved message information.
246#[derive(Debug, Clone)]
247pub struct ResolvedMessage {
248    /// Message name.
249    pub name: String,
250    /// Template ID.
251    pub template_id: u16,
252    /// Block length.
253    pub block_length: u16,
254    /// Resolved fields.
255    pub fields: Vec<ResolvedField>,
256    /// Resolved groups.
257    pub groups: Vec<ResolvedGroup>,
258    /// Variable data fields.
259    pub var_data: Vec<ResolvedVarData>,
260}
261
262impl ResolvedMessage {
263    /// Creates a resolved message from a message definition.
264    #[must_use]
265    pub fn from_message_def(
266        msg: &crate::messages::MessageDef,
267        types: &HashMap<String, ResolvedType>,
268    ) -> Self {
269        let fields = msg
270            .fields
271            .iter()
272            .map(|f| ResolvedField::from_field_def(f, types))
273            .collect();
274
275        let groups = msg
276            .groups
277            .iter()
278            .map(|g| ResolvedGroup::from_group_def(g, types))
279            .collect();
280
281        let var_data = msg
282            .data_fields
283            .iter()
284            .map(|d| ResolvedVarData {
285                name: d.name.clone(),
286                id: d.id,
287                type_name: d.type_name.clone(),
288            })
289            .collect();
290
291        Self {
292            name: msg.name.clone(),
293            template_id: msg.id,
294            block_length: msg.block_length,
295            fields,
296            groups,
297            var_data,
298        }
299    }
300
301    /// Returns the decoder struct name.
302    #[must_use]
303    pub fn decoder_name(&self) -> String {
304        format!("{}Decoder", self.name)
305    }
306
307    /// Returns the encoder struct name.
308    #[must_use]
309    pub fn encoder_name(&self) -> String {
310        format!("{}Encoder", self.name)
311    }
312}
313
314/// Resolved field information.
315#[derive(Debug, Clone)]
316pub struct ResolvedField {
317    /// Field name.
318    pub name: String,
319    /// Field ID.
320    pub id: u16,
321    /// Type name.
322    pub type_name: String,
323    /// Offset in bytes.
324    pub offset: usize,
325    /// Encoded length in bytes.
326    pub encoded_length: usize,
327    /// Rust type.
328    pub rust_type: String,
329    /// Getter method name.
330    pub getter_name: String,
331    /// Setter method name.
332    pub setter_name: String,
333    /// Whether the field is optional.
334    pub is_optional: bool,
335    /// Whether the field is an array.
336    pub is_array: bool,
337    /// Array length (if array).
338    pub array_length: Option<usize>,
339    /// Primitive type (if applicable).
340    pub primitive_type: Option<PrimitiveType>,
341}
342
343impl ResolvedField {
344    /// Creates a resolved field from a field definition.
345    #[must_use]
346    pub fn from_field_def(
347        field: &crate::messages::FieldDef,
348        types: &HashMap<String, ResolvedType>,
349    ) -> Self {
350        let resolved_type = types.get(&field.type_name).cloned().or_else(|| {
351            PrimitiveType::from_sbe_name(&field.type_name).map(ResolvedType::from_primitive)
352        });
353
354        let (encoded_length, rust_type, is_array, array_length, primitive_type) =
355            if let Some(rt) = &resolved_type {
356                (
357                    rt.encoded_length,
358                    rt.rust_type.clone(),
359                    rt.is_array,
360                    rt.array_length,
361                    match &rt.kind {
362                        TypeKind::Primitive(p) => Some(*p),
363                        _ => None,
364                    },
365                )
366            } else {
367                (field.encoded_length, "u64".to_string(), false, None, None)
368            };
369
370        Self {
371            name: field.name.clone(),
372            id: field.id,
373            type_name: field.type_name.clone(),
374            offset: field.offset,
375            encoded_length,
376            rust_type,
377            getter_name: to_snake_case(&field.name),
378            setter_name: format!("set_{}", to_snake_case(&field.name)),
379            is_optional: field.is_optional(),
380            is_array,
381            array_length,
382            primitive_type,
383        }
384    }
385}
386
387/// Resolved group information.
388#[derive(Debug, Clone)]
389pub struct ResolvedGroup {
390    /// Group name.
391    pub name: String,
392    /// Group ID.
393    pub id: u16,
394    /// Block length per entry.
395    pub block_length: u16,
396    /// Resolved fields.
397    pub fields: Vec<ResolvedField>,
398    /// Nested groups.
399    pub nested_groups: Vec<ResolvedGroup>,
400    /// Variable data fields.
401    pub var_data: Vec<ResolvedVarData>,
402}
403
404impl ResolvedGroup {
405    /// Creates a resolved group from a group definition.
406    #[must_use]
407    pub fn from_group_def(
408        group: &crate::messages::GroupDef,
409        types: &HashMap<String, ResolvedType>,
410    ) -> Self {
411        let fields = group
412            .fields
413            .iter()
414            .map(|f| ResolvedField::from_field_def(f, types))
415            .collect();
416
417        let nested_groups = group
418            .nested_groups
419            .iter()
420            .map(|g| ResolvedGroup::from_group_def(g, types))
421            .collect();
422
423        let var_data = group
424            .data_fields
425            .iter()
426            .map(|d| ResolvedVarData {
427                name: d.name.clone(),
428                id: d.id,
429                type_name: d.type_name.clone(),
430            })
431            .collect();
432
433        Self {
434            name: group.name.clone(),
435            id: group.id,
436            block_length: group.block_length,
437            fields,
438            nested_groups,
439            var_data,
440        }
441    }
442
443    /// Returns the decoder struct name.
444    #[must_use]
445    pub fn decoder_name(&self) -> String {
446        format!("{}GroupDecoder", to_pascal_case(&self.name))
447    }
448
449    /// Returns the entry decoder struct name.
450    #[must_use]
451    pub fn entry_decoder_name(&self) -> String {
452        format!("{}EntryDecoder", to_pascal_case(&self.name))
453    }
454}
455
/// A variable-length data field of a message or group.
#[derive(Debug, Clone)]
pub struct ResolvedVarData {
    /// Name of the field.
    pub name: String,
    /// Identifier of the field.
    pub id: u16,
    /// Name of the field's type.
    pub type_name: String,
}
466
/// Converts a string to snake_case.
///
/// An underscore is inserted before every uppercase character, except when
/// that character is the first one or directly follows a separator. Both `-`
/// and `_` count as separators:
///
/// * `-` is normalized to `_`, and is dropped when the previous character
///   was already a separator;
/// * a literal `_` is always copied through unchanged.
///
/// Treating `_` as a separator keeps inputs such as `some_Name` from turning
/// into `some__name`. Only ASCII letters are lowercased.
#[must_use]
pub fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    let mut last_was_separator = false;
    for (i, c) in s.chars().enumerate() {
        match c {
            '-' => {
                if !last_was_separator {
                    result.push('_');
                }
                last_was_separator = true;
            }
            '_' => {
                result.push('_');
                last_was_separator = true;
            }
            _ => {
                // Start a new word at an uppercase letter, unless a separator
                // has just been written or this is the very first character.
                if c.is_uppercase() && i > 0 && !last_was_separator {
                    result.push('_');
                }
                result.push(c.to_ascii_lowercase());
                last_was_separator = false;
            }
        }
    }
    result
}
490
/// Converts a string to SCREAMING_SNAKE_CASE.
///
/// An underscore is inserted before every uppercase character, except when
/// that character is the first one or directly follows a separator. Both `-`
/// and `_` count as separators:
///
/// * `-` is normalized to `_`, and is dropped when the previous character
///   was already a separator;
/// * a literal `_` is always copied through unchanged.
///
/// Treating `_` as a separator keeps inputs such as `some_Name` from turning
/// into `SOME__NAME`. Only ASCII letters are uppercased.
#[must_use]
pub fn to_screaming_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    let mut last_was_separator = false;
    for (i, c) in s.chars().enumerate() {
        match c {
            '-' => {
                if !last_was_separator {
                    result.push('_');
                }
                last_was_separator = true;
            }
            '_' => {
                result.push('_');
                last_was_separator = true;
            }
            _ => {
                // Start a new word at an uppercase letter, unless a separator
                // has just been written or this is the very first character.
                if c.is_uppercase() && i > 0 && !last_was_separator {
                    result.push('_');
                }
                result.push(c.to_ascii_uppercase());
                last_was_separator = false;
            }
        }
    }
    result
}
514
/// Converts a string to PascalCase.
///
/// The input is cut at every `_` and `-`; the separators are dropped and the
/// first character of each remaining word is ASCII-uppercased. All other
/// characters are copied through unchanged.
#[must_use]
pub fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split(|ch: char| ch == '_' || ch == '-') {
        let mut letters = word.chars();
        // Empty words (leading, trailing or doubled separators) add nothing.
        if let Some(first) = letters.next() {
            pascal.push(first.to_ascii_uppercase());
            pascal.push_str(letters.as_str());
        }
    }
    pascal
}
534
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_schema;

    /// Parses `xml` and builds its IR, panicking when parsing fails.
    fn ir_from_xml(xml: &str) -> SchemaIr {
        let schema = parse_schema(xml).expect("Failed to parse");
        SchemaIr::from_schema(&schema)
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("clOrdId", "cl_ord_id"),
            ("symbol", "symbol"),
            ("MDEntryPx", "m_d_entry_px"),
            ("some-hyphenated", "some_hyphenated"),
            // Hyphen followed by uppercase should not produce double underscore
            ("some-Hyphen", "some_hyphen"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_to_screaming_snake_case() {
        let cases = [
            ("clOrdId", "CL_ORD_ID"),
            ("symbol", "SYMBOL"),
            ("some-hyphenated", "SOME_HYPHENATED"),
            // Hyphen followed by uppercase should not produce double underscore
            ("some-Hyphen", "SOME_HYPHEN"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_screaming_snake_case(input), expected);
        }
    }

    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("message_header", "MessageHeader"),
            ("side", "Side"),
            ("order-type", "OrderType"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    #[test]
    fn test_schema_ir_from_schema() {
        let ir = ir_from_xml(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <type name="uint64" primitiveType="uint64"/>
    </types>
    <sbe:message name="Test" id="1" blockLength="8">
        <field name="value" id="1" type="uint64" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#,
        );

        assert_eq!(ir.package, "test");
        assert_eq!(ir.schema_id, 1);
        assert_eq!(ir.schema_version, 1);
        assert!(!ir.messages.is_empty());
    }

    #[test]
    fn test_resolved_type_from_primitive() {
        let uint64 = ResolvedType::from_primitive(PrimitiveType::Uint64);
        assert_eq!(uint64.name, "uint64");
        assert_eq!(uint64.encoded_length, 8);
        assert_eq!(uint64.rust_type, "u64");
        assert!(!uint64.is_array);
    }

    #[test]
    fn test_type_kind_debug() {
        // Each kind's `Debug` output must mention its variant name.
        let kinds = [
            (TypeKind::Primitive(PrimitiveType::Int32), "Primitive"),
            (TypeKind::Composite { fields: vec![] }, "Composite"),
            (
                TypeKind::Enum {
                    encoding: PrimitiveType::Uint8,
                    variants: vec![],
                },
                "Enum",
            ),
            (
                TypeKind::Set {
                    encoding: PrimitiveType::Uint16,
                    choices: vec![],
                },
                "Set",
            ),
        ];
        for (kind, label) in kinds.iter() {
            let debug_str = format!("{:?}", kind);
            assert!(debug_str.contains(label));
        }
    }

    #[test]
    fn test_resolved_type_clone() {
        let original = ResolvedType::from_primitive(PrimitiveType::Float);
        let duplicate = original.clone();
        assert_eq!(original.name, duplicate.name);
        assert_eq!(original.encoded_length, duplicate.encoded_length);
    }

    #[test]
    fn test_schema_ir_with_enum() {
        let ir = ir_from_xml(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <enum name="Side" encodingType="uint8">
            <validValue name="Buy">1</validValue>
            <validValue name="Sell">2</validValue>
        </enum>
    </types>
    <sbe:message name="Test" id="1" blockLength="1">
        <field name="side" id="1" type="Side" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#,
        );

        assert!(ir.types.contains_key("Side"));
    }

    #[test]
    fn test_schema_ir_with_composite() {
        let ir = ir_from_xml(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <composite name="Decimal">
            <type name="mantissa" primitiveType="int64"/>
            <type name="exponent" primitiveType="int8"/>
        </composite>
    </types>
    <sbe:message name="Test" id="1" blockLength="9">
        <field name="price" id="1" type="Decimal" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#,
        );

        assert!(ir.types.contains_key("Decimal"));
    }
}