1use crate::types::{PrimitiveType, Schema, TypeDef};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone)]
11pub struct SchemaIr {
12 pub package: String,
14 pub schema_id: u16,
16 pub schema_version: u16,
18 pub types: HashMap<String, ResolvedType>,
20 pub messages: Vec<ResolvedMessage>,
22}
23
24impl SchemaIr {
25 #[must_use]
27 pub fn from_schema(schema: &Schema) -> Self {
28 let mut ir = Self {
29 package: schema.package.clone(),
30 schema_id: schema.id,
31 schema_version: schema.version,
32 types: HashMap::new(),
33 messages: Vec::new(),
34 };
35
36 for type_def in &schema.types {
38 let resolved = ResolvedType::from_type_def(type_def);
39 ir.types.insert(resolved.name.clone(), resolved);
40 }
41
42 for msg in &schema.messages {
44 ir.messages
45 .push(ResolvedMessage::from_message_def(msg, &ir.types));
46 }
47
48 ir
49 }
50
51 #[must_use]
53 pub fn get_type(&self, name: &str) -> Option<&ResolvedType> {
54 self.types.get(name)
55 }
56}
57
58#[derive(Debug, Clone)]
60pub struct ResolvedType {
61 pub name: String,
63 pub kind: TypeKind,
65 pub encoded_length: usize,
67 pub rust_type: String,
69 pub is_array: bool,
71 pub array_length: Option<usize>,
73}
74
75impl ResolvedType {
76 #[must_use]
78 pub fn from_type_def(type_def: &TypeDef) -> Self {
79 match type_def {
80 TypeDef::Primitive(p) => Self {
81 name: p.name.clone(),
82 kind: TypeKind::Primitive(p.primitive_type),
83 encoded_length: p.encoded_length(),
84 rust_type: if p.is_array() {
85 format!(
86 "[{}; {}]",
87 p.primitive_type.rust_type(),
88 p.length.unwrap_or(1)
89 )
90 } else {
91 p.primitive_type.rust_type().to_string()
92 },
93 is_array: p.is_array(),
94 array_length: p.length,
95 },
96 TypeDef::Composite(c) => {
97 let mut offset = 0usize;
98 let fields = c
99 .fields
100 .iter()
101 .filter_map(|f| {
102 let field_offset = offset;
103 offset += f.encoded_length;
104 f.primitive_type.map(|prim| CompositeFieldInfo {
105 name: f.name.clone(),
106 primitive_type: prim,
107 offset: field_offset,
108 encoded_length: f.encoded_length,
109 })
110 })
111 .collect();
112 Self {
113 name: c.name.clone(),
114 kind: TypeKind::Composite { fields },
115 encoded_length: c.encoded_length(),
116 rust_type: to_pascal_case(&c.name),
117 is_array: false,
118 array_length: None,
119 }
120 }
121 TypeDef::Enum(e) => {
122 let variants: Vec<EnumVariant> = e
123 .valid_values
124 .iter()
125 .filter_map(|v| {
126 let value = if e.encoding_type.is_signed() {
128 v.as_i64()
129 } else {
130 v.as_u64().map(|u| u as i64)
131 };
132 value.map(|val| EnumVariant {
133 name: v.name.clone(),
134 value: val,
135 })
136 })
137 .collect();
138 Self {
139 name: e.name.clone(),
140 kind: TypeKind::Enum {
141 encoding: e.encoding_type,
142 variants,
143 },
144 encoded_length: e.encoding_type.size(),
145 rust_type: to_pascal_case(&e.name),
146 is_array: false,
147 array_length: None,
148 }
149 }
150 TypeDef::Set(s) => {
151 let choices = s
152 .choices
153 .iter()
154 .map(|c| SetVariant {
155 name: c.name.clone(),
156 bit_position: c.bit_position,
157 })
158 .collect();
159 Self {
160 name: s.name.clone(),
161 kind: TypeKind::Set {
162 encoding: s.encoding_type,
163 choices,
164 },
165 encoded_length: s.encoding_type.size(),
166 rust_type: to_pascal_case(&s.name),
167 is_array: false,
168 array_length: None,
169 }
170 }
171 }
172 }
173
174 #[must_use]
176 pub fn from_primitive(prim: PrimitiveType) -> Self {
177 Self {
178 name: prim.sbe_name().to_string(),
179 kind: TypeKind::Primitive(prim),
180 encoded_length: prim.size(),
181 rust_type: prim.rust_type().to_string(),
182 is_array: false,
183 array_length: None,
184 }
185 }
186}
187
/// One valid value of an enum type.
///
/// Equality and hashing are derived so variants can be compared directly
/// (e.g. in tests or when de-duplicating).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EnumVariant {
    /// Variant name as declared in the schema.
    pub name: String,
    /// Numeric value. For unsigned encodings the parsed `u64` is stored via an
    /// `as i64` cast, so values above `i64::MAX` appear negative here.
    pub value: i64,
}
196
/// One choice of a set (bitset) type.
///
/// Equality and hashing are derived so choices can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SetVariant {
    /// Choice name as declared in the schema.
    pub name: String,
    /// Bit position of this choice, copied from the schema definition.
    pub bit_position: u8,
}
205
/// A primitive member of a composite type.
#[derive(Debug, Clone)]
pub struct CompositeFieldInfo {
    /// Member name as declared in the composite.
    pub name: String,
    /// The member's primitive type.
    pub primitive_type: PrimitiveType,
    /// Offset from the start of the composite: the sum of the encoded lengths
    /// of all preceding members, including non-primitive members that are not
    /// themselves recorded as `CompositeFieldInfo`.
    pub offset: usize,
    /// Encoded length of this member, copied from the composite definition.
    pub encoded_length: usize,
}
218
/// Kind-specific details of a resolved type.
#[derive(Debug, Clone)]
pub enum TypeKind {
    /// A built-in SBE primitive (the owning type may be a scalar or an array).
    Primitive(PrimitiveType),
    /// A composite type.
    Composite {
        /// The composite's primitive members; members without a primitive
        /// type are not listed here.
        fields: Vec<CompositeFieldInfo>,
    },
    /// An enum type.
    Enum {
        /// The primitive the enum is encoded as.
        encoding: PrimitiveType,
        /// Valid values whose literals parsed successfully.
        variants: Vec<EnumVariant>,
    },
    /// A set (bitset) type.
    Set {
        /// The primitive the set is encoded as.
        encoding: PrimitiveType,
        /// The set's choices with their bit positions.
        choices: Vec<SetVariant>,
    },
}
244
245#[derive(Debug, Clone)]
247pub struct ResolvedMessage {
248 pub name: String,
250 pub template_id: u16,
252 pub block_length: u16,
254 pub fields: Vec<ResolvedField>,
256 pub groups: Vec<ResolvedGroup>,
258 pub var_data: Vec<ResolvedVarData>,
260}
261
262impl ResolvedMessage {
263 #[must_use]
265 pub fn from_message_def(
266 msg: &crate::messages::MessageDef,
267 types: &HashMap<String, ResolvedType>,
268 ) -> Self {
269 let fields = msg
270 .fields
271 .iter()
272 .map(|f| ResolvedField::from_field_def(f, types))
273 .collect();
274
275 let groups = msg
276 .groups
277 .iter()
278 .map(|g| ResolvedGroup::from_group_def(g, types))
279 .collect();
280
281 let var_data = msg
282 .data_fields
283 .iter()
284 .map(|d| ResolvedVarData {
285 name: d.name.clone(),
286 id: d.id,
287 type_name: d.type_name.clone(),
288 })
289 .collect();
290
291 Self {
292 name: msg.name.clone(),
293 template_id: msg.id,
294 block_length: msg.block_length,
295 fields,
296 groups,
297 var_data,
298 }
299 }
300
301 #[must_use]
303 pub fn decoder_name(&self) -> String {
304 format!("{}Decoder", self.name)
305 }
306
307 #[must_use]
309 pub fn encoder_name(&self) -> String {
310 format!("{}Encoder", self.name)
311 }
312}
313
314#[derive(Debug, Clone)]
316pub struct ResolvedField {
317 pub name: String,
319 pub id: u16,
321 pub type_name: String,
323 pub offset: usize,
325 pub encoded_length: usize,
327 pub rust_type: String,
329 pub getter_name: String,
331 pub setter_name: String,
333 pub is_optional: bool,
335 pub is_array: bool,
337 pub array_length: Option<usize>,
339 pub primitive_type: Option<PrimitiveType>,
341}
342
343impl ResolvedField {
344 #[must_use]
346 pub fn from_field_def(
347 field: &crate::messages::FieldDef,
348 types: &HashMap<String, ResolvedType>,
349 ) -> Self {
350 let resolved_type = types.get(&field.type_name).cloned().or_else(|| {
351 PrimitiveType::from_sbe_name(&field.type_name).map(ResolvedType::from_primitive)
352 });
353
354 let (encoded_length, rust_type, is_array, array_length, primitive_type) =
355 if let Some(rt) = &resolved_type {
356 (
357 rt.encoded_length,
358 rt.rust_type.clone(),
359 rt.is_array,
360 rt.array_length,
361 match &rt.kind {
362 TypeKind::Primitive(p) => Some(*p),
363 _ => None,
364 },
365 )
366 } else {
367 (field.encoded_length, "u64".to_string(), false, None, None)
368 };
369
370 Self {
371 name: field.name.clone(),
372 id: field.id,
373 type_name: field.type_name.clone(),
374 offset: field.offset,
375 encoded_length,
376 rust_type,
377 getter_name: to_snake_case(&field.name),
378 setter_name: format!("set_{}", to_snake_case(&field.name)),
379 is_optional: field.is_optional(),
380 is_array,
381 array_length,
382 primitive_type,
383 }
384 }
385}
386
387#[derive(Debug, Clone)]
389pub struct ResolvedGroup {
390 pub name: String,
392 pub id: u16,
394 pub block_length: u16,
396 pub fields: Vec<ResolvedField>,
398 pub nested_groups: Vec<ResolvedGroup>,
400 pub var_data: Vec<ResolvedVarData>,
402}
403
404impl ResolvedGroup {
405 #[must_use]
407 pub fn from_group_def(
408 group: &crate::messages::GroupDef,
409 types: &HashMap<String, ResolvedType>,
410 ) -> Self {
411 let fields = group
412 .fields
413 .iter()
414 .map(|f| ResolvedField::from_field_def(f, types))
415 .collect();
416
417 let nested_groups = group
418 .nested_groups
419 .iter()
420 .map(|g| ResolvedGroup::from_group_def(g, types))
421 .collect();
422
423 let var_data = group
424 .data_fields
425 .iter()
426 .map(|d| ResolvedVarData {
427 name: d.name.clone(),
428 id: d.id,
429 type_name: d.type_name.clone(),
430 })
431 .collect();
432
433 Self {
434 name: group.name.clone(),
435 id: group.id,
436 block_length: group.block_length,
437 fields,
438 nested_groups,
439 var_data,
440 }
441 }
442
443 #[must_use]
445 pub fn decoder_name(&self) -> String {
446 format!("{}GroupDecoder", to_pascal_case(&self.name))
447 }
448
449 #[must_use]
451 pub fn entry_decoder_name(&self) -> String {
452 format!("{}EntryDecoder", to_pascal_case(&self.name))
453 }
454}
455
/// A variable-length data field of a message or group.
///
/// Equality and hashing are derived so entries can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ResolvedVarData {
    /// Field name as declared in the schema.
    pub name: String,
    /// Field id.
    pub id: u16,
    /// Name of the schema type describing the data encoding.
    pub type_name: String,
}
466
/// Shared worker for [`to_snake_case`] and [`to_screaming_snake_case`].
///
/// Word boundaries are:
/// * any run of `-` / `_` separators, which collapses to a single `_`; and
/// * every uppercase letter that is neither the first character nor directly
///   after a separator. Acronyms therefore split per letter
///   (`MDEntryPx` → `m_d_entry_px`).
///
/// Every other character is case-mapped to lower case, or to upper case when
/// `uppercase` is true.
fn snake_case_with(s: &str, uppercase: bool) -> String {
    // `+ 4` leaves room for a few inserted underscores before reallocating.
    let mut result = String::with_capacity(s.len() + 4);
    let mut last_was_separator = false;
    for (i, c) in s.chars().enumerate() {
        if c == '-' || c == '_' {
            // `_` must count as a separator too, otherwise "foo_Bar" would
            // gain a second underscore from the case boundary ("foo__bar").
            if !last_was_separator {
                result.push('_');
            }
            last_was_separator = true;
            continue;
        }
        if c.is_uppercase() && i > 0 && !last_was_separator {
            result.push('_');
        }
        // Full Unicode mapping, so letters flagged by `is_uppercase` are
        // actually converted even when they are not ASCII.
        if uppercase {
            result.extend(c.to_uppercase());
        } else {
            result.extend(c.to_lowercase());
        }
        last_was_separator = false;
    }
    result
}

/// Converts a camelCase / PascalCase / hyphenated name to `snake_case`.
///
/// `clOrdId` → `cl_ord_id`, `some-Hyphen` → `some_hyphen`, `foo_Bar` → `foo_bar`.
#[must_use]
pub fn to_snake_case(s: &str) -> String {
    snake_case_with(s, false)
}

/// Converts a camelCase / PascalCase / hyphenated name to `SCREAMING_SNAKE_CASE`.
///
/// `clOrdId` → `CL_ORD_ID`, `some-Hyphen` → `SOME_HYPHEN`, `Some_Field` → `SOME_FIELD`.
#[must_use]
pub fn to_screaming_snake_case(s: &str) -> String {
    snake_case_with(s, true)
}
514
/// Converts a `snake_case` or hyphenated name to `PascalCase`.
///
/// `_` and `-` act as word separators and are dropped. The first character of
/// each word is ASCII-uppercased and the rest is kept verbatim, so names that
/// are already PascalCase (including acronyms like `MDEntry`) pass through
/// unchanged.
#[must_use]
pub fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split(|ch: char| ch == '_' || ch == '-') {
        let mut chars = word.chars();
        // Empty words come from leading, trailing or doubled separators.
        if let Some(first) = chars.next() {
            out.push(first.to_ascii_uppercase());
            out.extend(chars);
        }
    }
    out
}
534
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_schema;

    // Parses an XML schema fixture and lowers it to IR; panics on a bad fixture.
    fn build_ir(xml: &str) -> SchemaIr {
        let schema = parse_schema(xml).expect("Failed to parse");
        SchemaIr::from_schema(&schema)
    }

    // Asserts `convert(input) == expected` for every pair.
    fn check_cases(convert: fn(&str) -> String, cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(convert(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_to_snake_case() {
        check_cases(
            to_snake_case,
            &[
                ("clOrdId", "cl_ord_id"),
                ("symbol", "symbol"),
                ("MDEntryPx", "m_d_entry_px"),
                ("some-hyphenated", "some_hyphenated"),
                ("some-Hyphen", "some_hyphen"),
            ],
        );
    }

    #[test]
    fn test_to_screaming_snake_case() {
        check_cases(
            to_screaming_snake_case,
            &[
                ("clOrdId", "CL_ORD_ID"),
                ("symbol", "SYMBOL"),
                ("some-hyphenated", "SOME_HYPHENATED"),
                ("some-Hyphen", "SOME_HYPHEN"),
            ],
        );
    }

    #[test]
    fn test_to_pascal_case() {
        check_cases(
            to_pascal_case,
            &[
                ("message_header", "MessageHeader"),
                ("side", "Side"),
                ("order-type", "OrderType"),
            ],
        );
    }

    #[test]
    fn test_schema_ir_from_schema() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <type name="uint64" primitiveType="uint64"/>
    </types>
    <sbe:message name="Test" id="1" blockLength="8">
        <field name="value" id="1" type="uint64" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#;

        let ir = build_ir(xml);

        assert_eq!(ir.package, "test");
        assert_eq!(ir.schema_id, 1);
        assert_eq!(ir.schema_version, 1);
        assert!(!ir.messages.is_empty());
    }

    #[test]
    fn test_resolved_type_from_primitive() {
        let resolved = ResolvedType::from_primitive(PrimitiveType::Uint64);
        assert_eq!(resolved.name, "uint64");
        assert_eq!(resolved.encoded_length, 8);
        assert_eq!(resolved.rust_type, "u64");
        assert!(!resolved.is_array);
    }

    #[test]
    fn test_type_kind_debug() {
        let samples = [
            (TypeKind::Primitive(PrimitiveType::Int32), "Primitive"),
            (TypeKind::Composite { fields: vec![] }, "Composite"),
            (
                TypeKind::Enum {
                    encoding: PrimitiveType::Uint8,
                    variants: vec![],
                },
                "Enum",
            ),
            (
                TypeKind::Set {
                    encoding: PrimitiveType::Uint16,
                    choices: vec![],
                },
                "Set",
            ),
        ];
        for (kind, label) in samples.iter() {
            let rendered = format!("{:?}", kind);
            assert!(rendered.contains(*label), "{} missing from {}", label, rendered);
        }
    }

    #[test]
    fn test_resolved_type_clone() {
        let original = ResolvedType::from_primitive(PrimitiveType::Float);
        let copy = original.clone();
        assert_eq!(original.name, copy.name);
        assert_eq!(original.encoded_length, copy.encoded_length);
    }

    #[test]
    fn test_schema_ir_with_enum() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <enum name="Side" encodingType="uint8">
            <validValue name="Buy">1</validValue>
            <validValue name="Sell">2</validValue>
        </enum>
    </types>
    <sbe:message name="Test" id="1" blockLength="1">
        <field name="side" id="1" type="Side" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#;

        let ir = build_ir(xml);
        assert!(ir.types.contains_key("Side"));
    }

    #[test]
    fn test_schema_ir_with_composite() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
                   package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <composite name="Decimal">
            <type name="mantissa" primitiveType="int64"/>
            <type name="exponent" primitiveType="int8"/>
        </composite>
    </types>
    <sbe:message name="Test" id="1" blockLength="9">
        <field name="price" id="1" type="Decimal" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#;

        let ir = build_ir(xml);
        assert!(ir.types.contains_key("Decimal"));
    }
}