1use crate::types::{PrimitiveType, Schema, TypeDef};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone)]
11pub struct SchemaIr {
12 pub package: String,
14 pub schema_id: u16,
16 pub schema_version: u16,
18 pub types: HashMap<String, ResolvedType>,
20 pub messages: Vec<ResolvedMessage>,
22}
23
24impl SchemaIr {
25 #[must_use]
27 pub fn from_schema(schema: &Schema) -> Self {
28 let mut ir = Self {
29 package: schema.package.clone(),
30 schema_id: schema.id,
31 schema_version: schema.version,
32 types: HashMap::new(),
33 messages: Vec::new(),
34 };
35
36 for type_def in &schema.types {
38 let resolved = ResolvedType::from_type_def(type_def);
39 ir.types.insert(resolved.name.clone(), resolved);
40 }
41
42 for msg in &schema.messages {
44 ir.messages
45 .push(ResolvedMessage::from_message_def(msg, &ir.types));
46 }
47
48 ir
49 }
50
51 #[must_use]
53 pub fn get_type(&self, name: &str) -> Option<&ResolvedType> {
54 self.types.get(name)
55 }
56}
57
58#[derive(Debug, Clone)]
60pub struct ResolvedType {
61 pub name: String,
63 pub kind: TypeKind,
65 pub encoded_length: usize,
67 pub rust_type: String,
69 pub is_array: bool,
71 pub array_length: Option<usize>,
73}
74
75impl ResolvedType {
76 #[must_use]
78 pub fn from_type_def(type_def: &TypeDef) -> Self {
79 match type_def {
80 TypeDef::Primitive(p) => Self {
81 name: p.name.clone(),
82 kind: TypeKind::Primitive(p.primitive_type),
83 encoded_length: p.encoded_length(),
84 rust_type: if p.is_array() {
85 format!(
86 "[{}; {}]",
87 p.primitive_type.rust_type(),
88 p.length.unwrap_or(1)
89 )
90 } else {
91 p.primitive_type.rust_type().to_string()
92 },
93 is_array: p.is_array(),
94 array_length: p.length,
95 },
96 TypeDef::Composite(c) => Self {
97 name: c.name.clone(),
98 kind: TypeKind::Composite,
99 encoded_length: c.encoded_length(),
100 rust_type: to_pascal_case(&c.name),
101 is_array: false,
102 array_length: None,
103 },
104 TypeDef::Enum(e) => Self {
105 name: e.name.clone(),
106 kind: TypeKind::Enum(e.encoding_type),
107 encoded_length: e.encoding_type.size(),
108 rust_type: to_pascal_case(&e.name),
109 is_array: false,
110 array_length: None,
111 },
112 TypeDef::Set(s) => Self {
113 name: s.name.clone(),
114 kind: TypeKind::Set(s.encoding_type),
115 encoded_length: s.encoding_type.size(),
116 rust_type: to_pascal_case(&s.name),
117 is_array: false,
118 array_length: None,
119 },
120 }
121 }
122
123 #[must_use]
125 pub fn from_primitive(prim: PrimitiveType) -> Self {
126 Self {
127 name: prim.sbe_name().to_string(),
128 kind: TypeKind::Primitive(prim),
129 encoded_length: prim.size(),
130 rust_type: prim.rust_type().to_string(),
131 is_array: false,
132 array_length: None,
133 }
134 }
135}
136
137#[derive(Debug, Clone, Copy)]
139pub enum TypeKind {
140 Primitive(PrimitiveType),
142 Composite,
144 Enum(PrimitiveType),
146 Set(PrimitiveType),
148}
149
150#[derive(Debug, Clone)]
152pub struct ResolvedMessage {
153 pub name: String,
155 pub template_id: u16,
157 pub block_length: u16,
159 pub fields: Vec<ResolvedField>,
161 pub groups: Vec<ResolvedGroup>,
163 pub var_data: Vec<ResolvedVarData>,
165}
166
167impl ResolvedMessage {
168 #[must_use]
170 pub fn from_message_def(
171 msg: &crate::messages::MessageDef,
172 types: &HashMap<String, ResolvedType>,
173 ) -> Self {
174 let fields = msg
175 .fields
176 .iter()
177 .map(|f| ResolvedField::from_field_def(f, types))
178 .collect();
179
180 let groups = msg
181 .groups
182 .iter()
183 .map(|g| ResolvedGroup::from_group_def(g, types))
184 .collect();
185
186 let var_data = msg
187 .data_fields
188 .iter()
189 .map(|d| ResolvedVarData {
190 name: d.name.clone(),
191 id: d.id,
192 type_name: d.type_name.clone(),
193 })
194 .collect();
195
196 Self {
197 name: msg.name.clone(),
198 template_id: msg.id,
199 block_length: msg.block_length,
200 fields,
201 groups,
202 var_data,
203 }
204 }
205
206 #[must_use]
208 pub fn decoder_name(&self) -> String {
209 format!("{}Decoder", self.name)
210 }
211
212 #[must_use]
214 pub fn encoder_name(&self) -> String {
215 format!("{}Encoder", self.name)
216 }
217}
218
219#[derive(Debug, Clone)]
221pub struct ResolvedField {
222 pub name: String,
224 pub id: u16,
226 pub type_name: String,
228 pub offset: usize,
230 pub encoded_length: usize,
232 pub rust_type: String,
234 pub getter_name: String,
236 pub setter_name: String,
238 pub is_optional: bool,
240 pub is_array: bool,
242 pub array_length: Option<usize>,
244 pub primitive_type: Option<PrimitiveType>,
246}
247
248impl ResolvedField {
249 #[must_use]
251 pub fn from_field_def(
252 field: &crate::messages::FieldDef,
253 types: &HashMap<String, ResolvedType>,
254 ) -> Self {
255 let resolved_type = types.get(&field.type_name).cloned().or_else(|| {
256 PrimitiveType::from_sbe_name(&field.type_name).map(ResolvedType::from_primitive)
257 });
258
259 let (encoded_length, rust_type, is_array, array_length, primitive_type) =
260 if let Some(rt) = &resolved_type {
261 (
262 rt.encoded_length,
263 rt.rust_type.clone(),
264 rt.is_array,
265 rt.array_length,
266 match rt.kind {
267 TypeKind::Primitive(p) => Some(p),
268 _ => None,
269 },
270 )
271 } else {
272 (field.encoded_length, "u64".to_string(), false, None, None)
273 };
274
275 Self {
276 name: field.name.clone(),
277 id: field.id,
278 type_name: field.type_name.clone(),
279 offset: field.offset,
280 encoded_length,
281 rust_type,
282 getter_name: to_snake_case(&field.name),
283 setter_name: format!("set_{}", to_snake_case(&field.name)),
284 is_optional: field.is_optional(),
285 is_array,
286 array_length,
287 primitive_type,
288 }
289 }
290}
291
292#[derive(Debug, Clone)]
294pub struct ResolvedGroup {
295 pub name: String,
297 pub id: u16,
299 pub block_length: u16,
301 pub fields: Vec<ResolvedField>,
303 pub nested_groups: Vec<ResolvedGroup>,
305 pub var_data: Vec<ResolvedVarData>,
307}
308
309impl ResolvedGroup {
310 #[must_use]
312 pub fn from_group_def(
313 group: &crate::messages::GroupDef,
314 types: &HashMap<String, ResolvedType>,
315 ) -> Self {
316 let fields = group
317 .fields
318 .iter()
319 .map(|f| ResolvedField::from_field_def(f, types))
320 .collect();
321
322 let nested_groups = group
323 .nested_groups
324 .iter()
325 .map(|g| ResolvedGroup::from_group_def(g, types))
326 .collect();
327
328 let var_data = group
329 .data_fields
330 .iter()
331 .map(|d| ResolvedVarData {
332 name: d.name.clone(),
333 id: d.id,
334 type_name: d.type_name.clone(),
335 })
336 .collect();
337
338 Self {
339 name: group.name.clone(),
340 id: group.id,
341 block_length: group.block_length,
342 fields,
343 nested_groups,
344 var_data,
345 }
346 }
347
348 #[must_use]
350 pub fn decoder_name(&self) -> String {
351 format!("{}GroupDecoder", to_pascal_case(&self.name))
352 }
353
354 #[must_use]
356 pub fn entry_decoder_name(&self) -> String {
357 format!("{}EntryDecoder", to_pascal_case(&self.name))
358 }
359}
360
/// A variable-length data field of a message or group.
#[derive(Debug, Clone)]
pub struct ResolvedVarData {
    /// Field name as declared in the schema.
    pub name: String,
    /// Field id from the schema.
    pub id: u16,
    /// Name of the field's declared encoding type.
    pub type_name: String,
}
371
/// Converts a camelCase or PascalCase identifier to snake_case.
///
/// An underscore is inserted before every uppercase character except the
/// first, and every character is lowercased. Consecutive capitals are split
/// one by one, so `"MDEntryPx"` becomes `"m_d_entry_px"`.
#[must_use]
pub fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    for (i, c) in s.chars().enumerate() {
        if i > 0 && c.is_uppercase() {
            result.push('_');
        }
        // `is_uppercase` is Unicode-aware, so lowercasing must be too;
        // otherwise a non-ASCII capital like 'É' would get a separator but
        // stay uppercase. For ASCII input this matches `to_ascii_lowercase`.
        result.extend(c.to_lowercase());
    }
    result
}
384
/// Converts a snake_case or kebab-case identifier to PascalCase.
///
/// `_` and `-` act as word separators and are dropped. The first character
/// of every word is ASCII-uppercased, and all other characters are kept as-is.
/// Runs of separators, and leading or trailing separators, produce no output.
#[must_use]
pub fn to_pascal_case(s: &str) -> String {
    s.split(|c: char| c == '_' || c == '-')
        .flat_map(|word| {
            let mut chars = word.chars();
            chars
                .next()
                .map(|first| first.to_ascii_uppercase())
                .into_iter()
                .chain(chars)
        })
        .collect()
}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_schema;

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("clOrdId"), "cl_ord_id");
        assert_eq!(to_snake_case("symbol"), "symbol");
        assert_eq!(to_snake_case("MDEntryPx"), "m_d_entry_px");
        assert_eq!(to_snake_case(""), "");
    }

    #[test]
    fn test_to_pascal_case() {
        assert_eq!(to_pascal_case("message_header"), "MessageHeader");
        assert_eq!(to_pascal_case("side"), "Side");
        assert_eq!(to_pascal_case("order-type"), "OrderType");
        // Separator runs and leading/trailing separators are dropped.
        assert_eq!(to_pascal_case("__a--b_"), "AB");
        assert_eq!(to_pascal_case(""), "");
    }

    #[test]
    fn test_schema_ir_from_schema() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
    package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <type name="uint64" primitiveType="uint64"/>
    </types>
    <sbe:message name="Test" id="1" blockLength="8">
        <field name="value" id="1" type="uint64" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#;

        let schema = parse_schema(xml).expect("Failed to parse");
        let ir = SchemaIr::from_schema(&schema);

        assert_eq!(ir.package, "test");
        assert_eq!(ir.schema_id, 1);
        assert_eq!(ir.schema_version, 1);
        assert_eq!(ir.messages.len(), 1);
        assert!(ir.get_type("Missing").is_none());

        let msg = &ir.messages[0];
        assert_eq!(msg.name, "Test");
        assert_eq!(msg.template_id, 1);
        assert_eq!(msg.block_length, 8);
        assert_eq!(msg.decoder_name(), "TestDecoder");
        assert_eq!(msg.encoder_name(), "TestEncoder");
        assert_eq!(msg.fields.len(), 1);

        let field = &msg.fields[0];
        assert_eq!(field.name, "value");
        assert_eq!(field.getter_name, "value");
        assert_eq!(field.setter_name, "set_value");
        assert_eq!(field.offset, 0);
        assert_eq!(field.rust_type, "u64");
        assert!(matches!(field.primitive_type, Some(PrimitiveType::Uint64)));
    }

    #[test]
    fn test_resolved_type_from_primitive() {
        let resolved = ResolvedType::from_primitive(PrimitiveType::Uint64);
        assert_eq!(resolved.name, "uint64");
        assert_eq!(resolved.encoded_length, 8);
        assert_eq!(resolved.rust_type, "u64");
        assert!(!resolved.is_array);
        assert!(resolved.array_length.is_none());
        assert!(matches!(
            resolved.kind,
            TypeKind::Primitive(PrimitiveType::Uint64)
        ));
    }

    #[test]
    fn test_type_kind_debug() {
        let kind = TypeKind::Primitive(PrimitiveType::Int32);
        let debug_str = format!("{:?}", kind);
        assert!(debug_str.contains("Primitive"));

        let kind = TypeKind::Composite;
        let debug_str = format!("{:?}", kind);
        assert!(debug_str.contains("Composite"));

        let kind = TypeKind::Enum(PrimitiveType::Uint8);
        let debug_str = format!("{:?}", kind);
        assert!(debug_str.contains("Enum"));

        let kind = TypeKind::Set(PrimitiveType::Uint16);
        let debug_str = format!("{:?}", kind);
        assert!(debug_str.contains("Set"));
    }

    #[test]
    fn test_resolved_type_clone() {
        let resolved = ResolvedType::from_primitive(PrimitiveType::Float);
        let cloned = resolved.clone();
        assert_eq!(resolved.name, cloned.name);
        assert_eq!(resolved.encoded_length, cloned.encoded_length);
        assert_eq!(resolved.rust_type, cloned.rust_type);
    }

    #[test]
    fn test_schema_ir_with_enum() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
    package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <enum name="Side" encodingType="uint8">
            <validValue name="Buy">1</validValue>
            <validValue name="Sell">2</validValue>
        </enum>
    </types>
    <sbe:message name="Test" id="1" blockLength="1">
        <field name="side" id="1" type="Side" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#;

        let schema = parse_schema(xml).expect("Failed to parse");
        let ir = SchemaIr::from_schema(&schema);

        assert!(ir.types.contains_key("Side"));
        let side = ir.get_type("Side").expect("Side should be resolved");
        assert!(matches!(side.kind, TypeKind::Enum(PrimitiveType::Uint8)));
        assert_eq!(side.encoded_length, 1);
        assert_eq!(side.rust_type, "Side");
        assert!(!side.is_array);

        // A field of enum type takes the enum's Rust type and has no primitive.
        let field = &ir.messages[0].fields[0];
        assert_eq!(field.rust_type, "Side");
        assert!(field.primitive_type.is_none());
    }

    #[test]
    fn test_schema_ir_with_composite() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
    package="test" id="1" version="1" byteOrder="littleEndian">
    <types>
        <composite name="Decimal">
            <type name="mantissa" primitiveType="int64"/>
            <type name="exponent" primitiveType="int8"/>
        </composite>
    </types>
    <sbe:message name="Test" id="1" blockLength="9">
        <field name="price" id="1" type="Decimal" offset="0"/>
    </sbe:message>
</sbe:messageSchema>"#;

        let schema = parse_schema(xml).expect("Failed to parse");
        let ir = SchemaIr::from_schema(&schema);

        assert!(ir.types.contains_key("Decimal"));
        let decimal = ir.get_type("Decimal").expect("Decimal should be resolved");
        assert!(matches!(decimal.kind, TypeKind::Composite));
        assert_eq!(decimal.rust_type, "Decimal");
        assert!(!decimal.is_array);
        assert!(decimal.array_length.is_none());

        let field = &ir.messages[0].fields[0];
        assert_eq!(field.rust_type, "Decimal");
        assert_eq!(field.getter_name, "price");
        assert!(field.primitive_type.is_none());
    }
}