use crate::types::{PrimitiveType, Schema, TypeDef};
use std::collections::HashMap;
/// Intermediate representation of a parsed SBE message schema.
///
/// Built from a parsed [`Schema`] by [`SchemaIr::from_schema`].
#[derive(Debug, Clone)]
pub struct SchemaIr {
/// Package name, copied from `Schema::package`.
pub package: String,
/// Schema identifier, copied from `Schema::id`.
pub schema_id: u16,
/// Schema version, copied from `Schema::version`.
pub schema_version: u16,
/// Resolved schema types keyed by type name.
///
/// `HashMap` iteration order is unspecified; anything that produces output by iterating
/// this map should sort the entries first if deterministic output matters.
pub types: HashMap<String, ResolvedType>,
/// Resolved messages, in the order they appear in the schema.
pub messages: Vec<ResolvedMessage>,
}
impl SchemaIr {
    /// Lowers a parsed schema into the IR.
    ///
    /// Every type definition is resolved first and stored under its name (if two
    /// definitions share a name, the later one wins). Messages are then resolved against
    /// that type table, keeping the schema's message order.
    #[must_use]
    pub fn from_schema(schema: &Schema) -> Self {
        let types: HashMap<String, ResolvedType> = schema
            .types
            .iter()
            .map(|def| {
                let resolved = ResolvedType::from_type_def(def);
                (resolved.name.clone(), resolved)
            })
            .collect();
        let messages = schema
            .messages
            .iter()
            .map(|message_def| ResolvedMessage::from_message_def(message_def, &types))
            .collect();
        Self {
            package: schema.package.clone(),
            schema_id: schema.id,
            schema_version: schema.version,
            types,
            messages,
        }
    }

    /// Looks up a resolved type by the name it was declared with in the schema.
    ///
    /// Only types from the schema's type definitions are present; a built-in primitive
    /// name that the schema did not declare yields `None`.
    #[must_use]
    pub fn get_type(&self, name: &str) -> Option<&ResolvedType> {
        self.types.get(name)
    }
}
/// A schema type resolved to its kind, encoded size and Rust type spelling.
#[derive(Debug, Clone)]
pub struct ResolvedType {
/// Type name as declared in the schema, or the SBE primitive name for values built by
/// [`ResolvedType::from_primitive`].
pub name: String,
/// Kind of type, with kind-specific details.
pub kind: TypeKind,
/// Encoded size in bytes.
pub encoded_length: usize,
/// Rust spelling of the type: the primitive's Rust type, `[T; N]` for primitive arrays,
/// or the PascalCase type name for composites, enums and sets.
pub rust_type: String,
/// `true` only for primitive types declared as arrays.
pub is_array: bool,
/// Declared `length` of a primitive type (copied even when `is_array` is `false`);
/// always `None` for composites, enums and sets.
pub array_length: Option<usize>,
}
impl ResolvedType {
    /// Resolves one schema type definition into its IR form.
    ///
    /// * Primitives keep their primitive kind. Arrays are spelled `[T; N]`, where `N` is
    ///   the declared length, or 1 if none was given.
    /// * Composites list every member that has a primitive type together with its byte
    ///   offset. Offsets accumulate over *all* members, so members without a primitive
    ///   type still advance the offset even though they are not listed.
    /// * Enums keep only the valid values for which `as_i64` (signed encodings) or
    ///   `as_u64` (unsigned encodings) returns `Some`. Unsigned values are converted with
    ///   `as i64`, so values above `i64::MAX` wrap.
    /// * Sets copy every choice with its bit position.
    ///
    /// Composites, enums and sets use the PascalCase form of their name as `rust_type`.
    #[must_use]
    pub fn from_type_def(type_def: &TypeDef) -> Self {
        match type_def {
            TypeDef::Primitive(p) => {
                let element = p.primitive_type.rust_type();
                let array = p.is_array();
                let rust_type = if array {
                    format!("[{}; {}]", element, p.length.unwrap_or(1))
                } else {
                    element.to_string()
                };
                Self {
                    name: p.name.clone(),
                    kind: TypeKind::Primitive(p.primitive_type),
                    encoded_length: p.encoded_length(),
                    rust_type,
                    is_array: array,
                    array_length: p.length,
                }
            }
            TypeDef::Composite(c) => {
                let mut fields = Vec::new();
                let mut next_offset = 0usize;
                for member in c.fields.iter() {
                    let member_offset = next_offset;
                    // Advance before filtering: unlisted members still occupy bytes.
                    next_offset += member.encoded_length;
                    if let Some(prim) = member.primitive_type {
                        fields.push(CompositeFieldInfo {
                            name: member.name.clone(),
                            primitive_type: prim,
                            offset: member_offset,
                            encoded_length: member.encoded_length,
                        });
                    }
                }
                Self {
                    name: c.name.clone(),
                    kind: TypeKind::Composite { fields },
                    encoded_length: c.encoded_length(),
                    rust_type: to_pascal_case(&c.name),
                    is_array: false,
                    array_length: None,
                }
            }
            TypeDef::Enum(e) => {
                let variants: Vec<EnumVariant> = e
                    .valid_values
                    .iter()
                    .filter_map(|valid| {
                        let value = if e.encoding_type.is_signed() {
                            valid.as_i64()?
                        } else {
                            // Bit-reinterpreting conversion: wraps above i64::MAX.
                            valid.as_u64()? as i64
                        };
                        Some(EnumVariant {
                            name: valid.name.clone(),
                            value,
                        })
                    })
                    .collect();
                Self {
                    name: e.name.clone(),
                    kind: TypeKind::Enum {
                        encoding: e.encoding_type,
                        variants,
                    },
                    encoded_length: e.encoding_type.size(),
                    rust_type: to_pascal_case(&e.name),
                    is_array: false,
                    array_length: None,
                }
            }
            TypeDef::Set(s) => {
                let mut choices = Vec::new();
                for choice in s.choices.iter() {
                    choices.push(SetVariant {
                        name: choice.name.clone(),
                        bit_position: choice.bit_position,
                    });
                }
                Self {
                    name: s.name.clone(),
                    kind: TypeKind::Set {
                        encoding: s.encoding_type,
                        choices,
                    },
                    encoded_length: s.encoding_type.size(),
                    rust_type: to_pascal_case(&s.name),
                    is_array: false,
                    array_length: None,
                }
            }
        }
    }

    /// Builds a non-array resolved type for a built-in primitive, named by its SBE name.
    #[must_use]
    pub fn from_primitive(prim: PrimitiveType) -> Self {
        let name = prim.sbe_name().to_string();
        let rust_type = prim.rust_type().to_string();
        Self {
            name,
            kind: TypeKind::Primitive(prim),
            encoded_length: prim.size(),
            rust_type,
            is_array: false,
            array_length: None,
        }
    }
}
/// One valid value of an enum type.
#[derive(Debug, Clone)]
pub struct EnumVariant {
/// Value name as declared in the schema.
pub name: String,
/// Numeric value. For unsigned encodings the value is converted with `as i64`, so
/// anything above `i64::MAX` wraps to a negative number.
pub value: i64,
}
/// One choice of a set type.
#[derive(Debug, Clone)]
pub struct SetVariant {
/// Choice name as declared in the schema.
pub name: String,
/// Bit position of this choice, copied from the schema definition.
pub bit_position: u8,
}
/// A primitive-typed member of a composite type.
#[derive(Debug, Clone)]
pub struct CompositeFieldInfo {
/// Member name as declared in the composite.
pub name: String,
/// Primitive type of the member.
pub primitive_type: PrimitiveType,
/// Offset from the start of the composite: the sum of the encoded lengths of all
/// preceding members, including members omitted for having no primitive type.
pub offset: usize,
/// Encoded length of this member.
pub encoded_length: usize,
}
/// Kind-specific details of a [`ResolvedType`].
#[derive(Debug, Clone)]
pub enum TypeKind {
/// A built-in primitive type (possibly declared as an array).
Primitive(PrimitiveType),
/// A composite type.
Composite {
/// Members that have a primitive type, with their offsets; other members are omitted
/// but still counted when offsets are computed.
fields: Vec<CompositeFieldInfo>,
},
/// An enum type.
Enum {
/// Primitive type the enum is encoded as.
encoding: PrimitiveType,
/// Valid values that could be read as integers.
variants: Vec<EnumVariant>,
},
/// A set type.
Set {
/// Primitive type the set is encoded as.
encoding: PrimitiveType,
/// The set's choices.
choices: Vec<SetVariant>,
},
}
/// A message definition with its field types resolved.
#[derive(Debug, Clone)]
pub struct ResolvedMessage {
/// Message name as declared in the schema.
pub name: String,
/// Template id, copied from the message definition's `id`.
pub template_id: u16,
/// Block length, copied from the message definition.
pub block_length: u16,
/// Resolved fields, in the order of the message definition's fields.
pub fields: Vec<ResolvedField>,
/// Resolved groups, in definition order.
pub groups: Vec<ResolvedGroup>,
/// Data fields, in definition order.
pub var_data: Vec<ResolvedVarData>,
}
impl ResolvedMessage {
    /// Resolves a message definition: its fields, groups and data fields, with field types
    /// looked up in `types`.
    #[must_use]
    pub fn from_message_def(
        msg: &crate::messages::MessageDef,
        types: &HashMap<String, ResolvedType>,
    ) -> Self {
        Self {
            name: msg.name.clone(),
            template_id: msg.id,
            block_length: msg.block_length,
            fields: msg
                .fields
                .iter()
                .map(|field_def| ResolvedField::from_field_def(field_def, types))
                .collect(),
            groups: msg
                .groups
                .iter()
                .map(|group_def| ResolvedGroup::from_group_def(group_def, types))
                .collect(),
            var_data: msg
                .data_fields
                .iter()
                .map(|data_def| ResolvedVarData {
                    name: data_def.name.clone(),
                    id: data_def.id,
                    type_name: data_def.type_name.clone(),
                })
                .collect(),
        }
    }

    /// Returns the message name with `Decoder` appended (no case conversion).
    #[must_use]
    pub fn decoder_name(&self) -> String {
        self.name_with_suffix("Decoder")
    }

    /// Returns the message name with `Encoder` appended (no case conversion).
    #[must_use]
    pub fn encoder_name(&self) -> String {
        self.name_with_suffix("Encoder")
    }

    /// Concatenates the verbatim message name and `suffix`.
    fn name_with_suffix(&self, suffix: &str) -> String {
        let mut out = String::with_capacity(self.name.len() + suffix.len());
        out.push_str(&self.name);
        out.push_str(suffix);
        out
    }
}
/// A fixed field of a message or group, with its type resolved.
#[derive(Debug, Clone)]
pub struct ResolvedField {
/// Field name as declared in the schema.
pub name: String,
/// Field id, copied from the field definition.
pub id: u16,
/// Name of the referenced type, as written in the field definition.
pub type_name: String,
/// Offset, copied from the field definition.
pub offset: usize,
/// Encoded length of the resolved type, or the field definition's own length when the
/// type name could not be resolved.
pub encoded_length: usize,
/// Rust type of the resolved type, or `"u64"` when the type name could not be resolved.
pub rust_type: String,
/// `snake_case` form of the field name.
pub getter_name: String,
/// `set_` followed by the `snake_case` field name.
pub setter_name: String,
/// Copied from the field definition's `is_optional()`.
pub is_optional: bool,
/// Copied from the resolved type; `false` when unresolved.
pub is_array: bool,
/// Copied from the resolved type; `None` when unresolved.
pub array_length: Option<usize>,
/// `Some` only when the field's type resolved to a primitive kind.
pub primitive_type: Option<PrimitiveType>,
}
impl ResolvedField {
    /// Resolves a field definition against the schema's type table.
    ///
    /// `field.type_name` is looked up in `types` first; if it is not declared there, it is
    /// tried as a built-in SBE primitive name. When neither matches, the field keeps its
    /// own declared length, is given the Rust type `u64`, and has no primitive type.
    /// NOTE(review): that fallback silently accepts unknown type names — confirm intended.
    #[must_use]
    pub fn from_field_def(
        field: &crate::messages::FieldDef,
        types: &HashMap<String, ResolvedType>,
    ) -> Self {
        // Borrow the declared type rather than cloning it: composite/enum/set kinds own
        // vectors that would otherwise be deep-copied for every field reference. The
        // primitive fallback is only built on a lookup miss; it is bound out here so the
        // borrow taken from it can outlive the `match`.
        let primitive_fallback: Option<ResolvedType>;
        let resolved_type: Option<&ResolvedType> = match types.get(&field.type_name) {
            Some(declared) => Some(declared),
            None => {
                primitive_fallback = PrimitiveType::from_sbe_name(&field.type_name)
                    .map(ResolvedType::from_primitive);
                primitive_fallback.as_ref()
            }
        };
        let (encoded_length, rust_type, is_array, array_length, primitive_type) =
            match resolved_type {
                Some(rt) => (
                    rt.encoded_length,
                    rt.rust_type.clone(),
                    rt.is_array,
                    rt.array_length,
                    // Exhaustive on purpose: a new TypeKind variant should force a decision here.
                    match &rt.kind {
                        TypeKind::Primitive(p) => Some(*p),
                        TypeKind::Composite { .. }
                        | TypeKind::Enum { .. }
                        | TypeKind::Set { .. } => None,
                    },
                ),
                None => (field.encoded_length, "u64".to_string(), false, None, None),
            };
        let getter_name = to_snake_case(&field.name);
        let setter_name = format!("set_{getter_name}");
        Self {
            name: field.name.clone(),
            id: field.id,
            type_name: field.type_name.clone(),
            offset: field.offset,
            encoded_length,
            rust_type,
            getter_name,
            setter_name,
            is_optional: field.is_optional(),
            is_array,
            array_length,
            primitive_type,
        }
    }
}
/// A group definition with its field types resolved, including nested groups.
#[derive(Debug, Clone)]
pub struct ResolvedGroup {
/// Group name as declared in the schema.
pub name: String,
/// Group id, copied from the group definition.
pub id: u16,
/// Block length, copied from the group definition.
pub block_length: u16,
/// Resolved fields, in definition order.
pub fields: Vec<ResolvedField>,
/// Groups nested inside this group, resolved recursively.
pub nested_groups: Vec<ResolvedGroup>,
/// Data fields, in definition order.
pub var_data: Vec<ResolvedVarData>,
}
impl ResolvedGroup {
    /// Resolves a group definition, recursing into nested groups; field types are looked
    /// up in `types`.
    #[must_use]
    pub fn from_group_def(
        group: &crate::messages::GroupDef,
        types: &HashMap<String, ResolvedType>,
    ) -> Self {
        Self {
            name: group.name.clone(),
            id: group.id,
            block_length: group.block_length,
            fields: group
                .fields
                .iter()
                .map(|field_def| ResolvedField::from_field_def(field_def, types))
                .collect(),
            nested_groups: group
                .nested_groups
                .iter()
                .map(|nested| Self::from_group_def(nested, types))
                .collect(),
            var_data: group
                .data_fields
                .iter()
                .map(|data_def| ResolvedVarData {
                    name: data_def.name.clone(),
                    id: data_def.id,
                    type_name: data_def.type_name.clone(),
                })
                .collect(),
        }
    }

    /// PascalCase group name followed by `GroupDecoder`.
    #[must_use]
    pub fn decoder_name(&self) -> String {
        self.pascal_name_with("GroupDecoder")
    }

    /// PascalCase group name followed by `EntryDecoder`.
    #[must_use]
    pub fn entry_decoder_name(&self) -> String {
        self.pascal_name_with("EntryDecoder")
    }

    /// PascalCase group name followed by `GroupEncoder`.
    #[must_use]
    pub fn encoder_name(&self) -> String {
        self.pascal_name_with("GroupEncoder")
    }

    /// PascalCase group name followed by `EntryEncoder`.
    #[must_use]
    pub fn entry_encoder_name(&self) -> String {
        self.pascal_name_with("EntryEncoder")
    }

    /// PascalCase form of the group name with `suffix` appended.
    fn pascal_name_with(&self, suffix: &str) -> String {
        let mut out = to_pascal_case(&self.name);
        out.push_str(suffix);
        out
    }
}
/// A data field of a message or group, copied from its definition.
#[derive(Debug, Clone)]
pub struct ResolvedVarData {
/// Field name as declared in the schema.
pub name: String,
/// Field id, copied from the data field definition.
pub id: u16,
/// Name of the referenced type, as written in the definition (not resolved).
pub type_name: String,
}
/// Converts an identifier to `snake_case`.
///
/// Any non-alphanumeric character acts as a separator: runs of separators collapse into a
/// single `_`, and leading/trailing separators are dropped. Within each segment a word
/// boundary is inserted:
/// * before an uppercase letter that follows a lowercase one (`clOrdId` → `cl_ord_id`);
/// * before the last capital of an uppercase run when a lowercase letter follows it
///   (`MDEntryPx` → `md_entry_px`).
///
/// Digits are neither lower- nor uppercase, so they never start a new word
/// (`field1Name` → `field1name`). Only ASCII letters are lowercased.
#[must_use]
pub fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    let mut prev_lower = false;
    let mut start = 0;
    let mut first = true;
    for segment in s.split(|c: char| !c.is_alphanumeric()) {
        if segment.is_empty() {
            continue;
        }
        let chars: Vec<char> = segment.chars().collect();
        let len = chars.len();
        for i in 0..len {
            let c = chars[i];
            let is_lower = c.is_lowercase();
            let is_upper = c.is_uppercase();
            if i > 0 {
                let prev = chars[i - 1];
                let next_is_lower = chars.get(i + 1).is_some_and(|n| n.is_lowercase());
                let boundary =
                    (prev_lower && is_upper) || (prev.is_uppercase() && is_upper && next_is_lower);
                if boundary {
                    if !first {
                        result.push('_');
                    }
                    for wc in &chars[start..i] {
                        result.push(wc.to_ascii_lowercase());
                    }
                    start = i;
                    first = false;
                }
            }
            prev_lower = is_lower;
        }
        // Flush the remainder of the segment as its final word.
        if !first {
            result.push('_');
        }
        for wc in &chars[start..] {
            result.push(wc.to_ascii_lowercase());
        }
        first = false;
        start = 0;
    }
    result
}

/// Converts an identifier to `SCREAMING_SNAKE_CASE`.
///
/// Defined as [`to_snake_case`] followed by ASCII uppercasing, so both functions split
/// words identically: `clOrdId` → `CL_ORD_ID`, `MDEntryPx` → `MD_ENTRY_PX`,
/// `some-Hyphen` → `SOME_HYPHEN`. Names that are already screaming snake case
/// (`REDUCE_ONLY`) come back unchanged instead of having every capital split apart.
#[must_use]
pub fn to_screaming_snake_case(s: &str) -> String {
    let mut out = to_snake_case(s);
    out.make_ascii_uppercase();
    out
}
/// Converts an identifier to `PascalCase`.
///
/// Every non-alphanumeric character is treated as a word separator and dropped, matching
/// how [`to_snake_case`] splits words. The first character of each word is uppercased
/// (ASCII only); all other characters are kept as-is, so existing camel humps survive
/// (`clOrdId` → `ClOrdId`, `message_header` → `MessageHeader`, `order-type` → `OrderType`).
#[must_use]
pub fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if !c.is_alphanumeric() {
            // Separators (`_`, `-`, `.`, spaces, ...) never reach the output, which keeps
            // the result a valid identifier body.
            capitalize_next = true;
        } else if capitalize_next {
            result.push(c.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_schema;

    /// Schema with a single declared `uint64` type and one message using it.
    const UINT64_SCHEMA: &str = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
package="test" id="1" version="1" byteOrder="littleEndian">
<types>
<type name="uint64" primitiveType="uint64"/>
</types>
<sbe:message name="Test" id="1" blockLength="8">
<field name="value" id="1" type="uint64" offset="0"/>
</sbe:message>
</sbe:messageSchema>"#;

    /// Schema declaring a `Side` enum and one message field of that type.
    const ENUM_SCHEMA: &str = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
package="test" id="1" version="1" byteOrder="littleEndian">
<types>
<enum name="Side" encodingType="uint8">
<validValue name="Buy">1</validValue>
<validValue name="Sell">2</validValue>
</enum>
</types>
<sbe:message name="Test" id="1" blockLength="1">
<field name="side" id="1" type="Side" offset="0"/>
</sbe:message>
</sbe:messageSchema>"#;

    /// Schema declaring a `Decimal` composite and one message field of that type.
    const COMPOSITE_SCHEMA: &str = r#"<?xml version="1.0" encoding="UTF-8"?>
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
package="test" id="1" version="1" byteOrder="littleEndian">
<types>
<composite name="Decimal">
<type name="mantissa" primitiveType="int64"/>
<type name="exponent" primitiveType="int8"/>
</composite>
</types>
<sbe:message name="Test" id="1" blockLength="9">
<field name="price" id="1" type="Decimal" offset="0"/>
</sbe:message>
</sbe:messageSchema>"#;

    /// Parses `xml` and lowers it into the IR, failing the test on parse errors.
    fn lower(xml: &str) -> SchemaIr {
        let schema = parse_schema(xml).expect("Failed to parse");
        SchemaIr::from_schema(&schema)
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("clOrdId", "cl_ord_id"),
            ("symbol", "symbol"),
            ("MDEntryPx", "md_entry_px"),
            ("some-hyphenated", "some_hyphenated"),
            ("some-Hyphen", "some_hyphen"),
            ("some_underscore", "some_underscore"),
            ("some-mixed_separator", "some_mixed_separator"),
            ("some__double_underscore", "some_double_underscore"),
            ("some--double-hyphen", "some_double_hyphen"),
            ("AB_C", "ab_c"),
            ("ABC", "abc"),
            ("REDUCE_ONLY", "reduce_only"),
            ("DISABLE_SELF_TRADE", "disable_self_trade"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_to_screaming_snake_case() {
        let cases = [
            ("clOrdId", "CL_ORD_ID"),
            ("symbol", "SYMBOL"),
            ("some-hyphenated", "SOME_HYPHENATED"),
            ("some-Hyphen", "SOME_HYPHEN"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_screaming_snake_case(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("message_header", "MessageHeader"),
            ("side", "Side"),
            ("order-type", "OrderType"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_schema_ir_from_schema() {
        let ir = lower(UINT64_SCHEMA);
        assert_eq!(ir.package, "test");
        assert_eq!(ir.schema_id, 1);
        assert_eq!(ir.schema_version, 1);
        assert!(!ir.messages.is_empty());
    }

    #[test]
    fn test_resolved_type_from_primitive() {
        let resolved = ResolvedType::from_primitive(PrimitiveType::Uint64);
        assert_eq!(resolved.name, "uint64");
        assert_eq!(resolved.encoded_length, 8);
        assert_eq!(resolved.rust_type, "u64");
        assert!(!resolved.is_array);
    }

    #[test]
    fn test_type_kind_debug() {
        let kinds = [
            (TypeKind::Primitive(PrimitiveType::Int32), "Primitive"),
            (TypeKind::Composite { fields: vec![] }, "Composite"),
            (
                TypeKind::Enum {
                    encoding: PrimitiveType::Uint8,
                    variants: vec![],
                },
                "Enum",
            ),
            (
                TypeKind::Set {
                    encoding: PrimitiveType::Uint16,
                    choices: vec![],
                },
                "Set",
            ),
        ];
        for (kind, label) in kinds {
            let rendered = format!("{kind:?}");
            assert!(rendered.contains(label), "{rendered} should mention {label}");
        }
    }

    #[test]
    fn test_resolved_type_clone() {
        let original = ResolvedType::from_primitive(PrimitiveType::Float);
        let copy = original.clone();
        assert_eq!(original.name, copy.name);
        assert_eq!(original.encoded_length, copy.encoded_length);
    }

    #[test]
    fn test_schema_ir_with_enum() {
        assert!(lower(ENUM_SCHEMA).types.contains_key("Side"));
    }

    #[test]
    fn test_schema_ir_with_composite() {
        assert!(lower(COMPOSITE_SCHEMA).types.contains_key("Decimal"));
    }
}