use anyhow::{Context, Result};
use heck::{ToPascalCase, ToSnakeCase};
use roxmltree::Node;
use std::env;
use std::fs;
use std::io::Write;
use std::path::Path;
/// In-memory model of the parsed SBE schema root element and its children.
// `dead_code` is allowed because some fields (`package`, `semantic_version`,
// `description`) are parsed but never read by the code generator.
#[allow(dead_code)]
#[derive(Debug, Default)]
struct MessageSchema {
    /// Root `package` attribute; "rustysbe" when absent.
    package: String,
    /// Root `version` attribute; emitted as `SCHEMA_VERSION` in generated code.
    version: u16,
    /// Root `semanticVersion` attribute; empty when absent.
    semantic_version: String,
    /// Root `description` attribute; empty when absent.
    description: String,
    /// `<composite>` and `<enum>` definitions found under `<types>`, in document order.
    types: Vec<SbeType>,
    /// `<message>` definitions, in document order.
    messages: Vec<SbeMessage>,
}
/// A named type definition collected from a `<types>` element.
#[allow(dead_code)]
#[derive(Debug)]
enum SbeType {
    /// Parsed from a `<composite>` element.
    Composite(SbeComposite),
    /// Parsed from an `<enum>` element.
    Enum(SbeEnum),
}
/// An SBE `<composite>`: a fixed sequence of member types.
#[allow(dead_code)]
#[derive(Debug)]
struct SbeComposite {
    /// The composite's `name` attribute.
    name: String,
    /// One entry per child element of the `<composite>`, each built by `parse_field`.
    types: Vec<SbeField>,
}
/// An SBE `<enum>` and its list of valid values.
#[allow(dead_code)]
#[derive(Debug)]
struct SbeEnum {
    /// The enum's `name` attribute.
    name: String,
    /// The `encodingType` attribute (empty when absent).
    encoding_type: String,
    /// One entry per child element, in document order.
    valid_values: Vec<SbeValidValue>,
}
/// A message field or composite member, as produced by `parse_field`.
#[allow(dead_code)]
#[derive(Debug)]
struct SbeField {
    /// `name` attribute.
    name: String,
    /// `id` attribute; 0 when absent or not a valid `u16`.
    id: u16,
    /// SBE type name as extracted by `parse_field` (a primitive such as
    /// `uint16`, or the name of a composite/enum).
    field_type: String,
    /// `description` attribute; empty when absent.
    description: String,
    /// `presence` attribute; "required" when absent. "constant" fields get a
    /// `todo!()` accessor in generated code.
    presence: String,
    /// `offset` attribute in bytes. `parse_field` stores 0 when it is absent,
    /// so "absent" and "explicitly 0" are indistinguishable here.
    offset: u16,
    /// `length` attribute (element count); 1 when absent or unparsable.
    length: usize,
}
/// A repeating `<group>` inside a message; groups may nest.
// NOTE(review): groups are parsed but not yet emitted by `generate_code`.
#[allow(dead_code)]
#[derive(Debug)]
struct SbeGroup {
    /// `name` attribute.
    name: String,
    /// `id` attribute; 0 when absent or invalid.
    id: u16,
    /// `dimensionType` attribute; empty when absent.
    dimension_type: String,
    /// Child `<field>` elements, in document order.
    fields: Vec<SbeField>,
    /// Child `<group>` elements, in document order.
    groups: Vec<SbeGroup>,
}
/// An SBE `<message>` definition.
#[allow(dead_code)]
#[derive(Debug)]
struct SbeMessage {
    /// `name` attribute; PascalCased for the generated struct name.
    name: String,
    /// `id` attribute; emitted as `TEMPLATE_ID`.
    id: u16,
    /// `description` attribute; emitted as the generated struct's doc comment.
    description: String,
    /// Child `<field>` elements, in document order.
    fields: Vec<SbeField>,
    /// Child `<group>` elements, in document order.
    groups: Vec<SbeGroup>,
    /// `blockLength` attribute, or a size computed from `fields` when absent.
    block_length: u16,
}
/// One child of an SBE `<enum>` element.
#[allow(dead_code)]
#[derive(Debug)]
struct SbeValidValue {
    /// `name` attribute; PascalCased into the generated variant name.
    name: String,
    /// Raw (untrimmed) text content of the element; empty when it has none.
    value: String,
}
/// Returns the encoded size in bytes of the SBE type named `sbe_type`.
///
/// - Primitive types have their fixed wire size.
/// - A composite is the sum of each member's size multiplied by its `length`.
/// - An enum is encoded as its `encodingType`, so it has that type's size.
/// - Unknown names (and the empty name) yield 0.
fn get_type_size(sbe_type: &str, schema: &MessageSchema) -> usize {
    match sbe_type {
        "char" | "int8" | "uint8" => 1,
        "int16" | "uint16" => 2,
        "int32" | "uint32" | "float" => 4,
        "int64" | "uint64" | "double" => 8,
        _ => schema
            .types
            .iter()
            .find_map(|t| match t {
                SbeType::Composite(c) if c.name == sbe_type => Some(
                    c.types
                        .iter()
                        .map(|f| get_type_size(&f.field_type, schema) * f.length)
                        .sum::<usize>(),
                ),
                // Previously enums fell through to 0, under-sizing every
                // message and composite that contains an enum field.
                SbeType::Enum(e) if e.name == sbe_type => {
                    Some(get_type_size(&e.encoding_type, schema))
                }
                _ => None,
            })
            .unwrap_or(0),
    }
}
/// Builds an [`SbeField`] from a message `<field>` or a composite member.
///
/// The SBE type name comes from the `type` attribute (used by `<field>` and
/// `<ref>`), falling back to `primitiveType` (used by composite `<type>`
/// members). Missing or unparsable numeric attributes use defaults: `id` and
/// `offset` become 0 and `length` becomes 1. `presence` defaults to
/// `"required"`.
fn parse_field(node: &Node) -> SbeField {
    let field_type = node
        .attribute("type")
        .or_else(|| node.attribute("primitiveType"))
        .unwrap_or("");
    SbeField {
        name: node.attribute("name").unwrap_or("").to_string(),
        id: node.attribute("id").unwrap_or("0").parse().unwrap_or(0),
        field_type: field_type.to_string(),
        description: node.attribute("description").unwrap_or("").to_string(),
        presence: node.attribute("presence").unwrap_or("required").to_string(),
        // NOTE: an absent offset is stored as 0; consumers must treat 0 on a
        // non-first field as "follows the previous field".
        offset: node.attribute("offset").unwrap_or("0").parse().unwrap_or(0),
        length: node.attribute("length").unwrap_or("1").parse().unwrap_or(1),
    }
}
/// Recursively builds an [`SbeGroup`] from a `<group>` element.
///
/// Child `<field>` and `<group>` elements are collected in document order;
/// any other child element is skipped. A missing or invalid `id` becomes 0.
fn parse_group(node: &Node) -> SbeGroup {
    let name = node.attribute("name").unwrap_or("").to_string();
    let id = node.attribute("id").unwrap_or("0").parse().unwrap_or(0);
    let dimension_type = node.attribute("dimensionType").unwrap_or("").to_string();

    let mut group = SbeGroup {
        name,
        id,
        dimension_type,
        fields: Vec::new(),
        groups: Vec::new(),
    };
    for child in node.children().filter(Node::is_element) {
        let tag = child.tag_name().name();
        if tag == "field" {
            group.fields.push(parse_field(&child));
        } else if tag == "group" {
            group.groups.push(parse_group(&child));
        }
    }
    group
}
fn generate_code(schema: &MessageSchema, dest: &mut fs::File) -> Result<()> {
writeln!(dest, "// Generated by `build.rs`. DO NOT EDIT.")?;
writeln!(dest, "// SBE message types generated from schema")?;
writeln!(dest)?;
writeln!(dest, "#[allow(dead_code)]")?;
writeln!(dest, "#[allow(unused_imports)]")?;
writeln!(dest, "#[allow(non_snake_case)]")?;
writeln!(dest, "#[allow(missing_docs)]")?;
writeln!(dest)?;
writeln!(
dest,
"use crate::{{SbeMessage, SbeDecoder, SbeEncoder, SbeResult}};"
)?;
writeln!(dest, "use zerocopy::{{IntoBytes, FromBytes, Unaligned}};")?;
writeln!(dest)?;
for sbe_type in &schema.types {
match sbe_type {
SbeType::Enum(e) => {
writeln!(dest, "#[derive(Debug, Clone, Copy, PartialEq, Eq)]")?;
writeln!(dest, "#[repr(u8)]")?;
writeln!(dest, "pub enum {} {{", e.name.to_pascal_case())?;
for (index, vv) in e.valid_values.iter().enumerate() {
let value = if let Ok(numeric_value) = vv.value.trim().parse::<u32>() {
numeric_value.to_string()
} else {
index.to_string()
};
writeln!(dest, " {} = {},", vv.name.to_pascal_case(), value)?;
}
writeln!(dest, "}}")?;
writeln!(dest)?;
}
SbeType::Composite(c) => {
writeln!(
dest,
"#[derive(Debug, Clone, Copy, IntoBytes, FromBytes, Unaligned)]"
)?;
writeln!(dest, "#[repr(C, packed)]")?;
writeln!(dest, "pub struct {} {{", c.name.to_pascal_case())?;
for field in &c.types {
writeln!(
dest,
" pub {}: {},",
field.name.to_snake_case(),
map_type(&field.field_type)
)?;
}
writeln!(dest, "}}")?;
writeln!(dest)?;
}
}
}
for msg in &schema.messages {
let name_pascal = msg.name.to_pascal_case();
writeln!(dest, "/// {}", msg.description)?;
writeln!(dest, "#[derive(Debug, Clone, Copy)]")?;
writeln!(dest, "pub struct {name_pascal}<'a> {{")?;
writeln!(dest, " buffer: &'a [u8],")?;
writeln!(dest, " offset: usize,")?;
writeln!(dest, "}}")?;
writeln!(dest)?;
writeln!(dest, "impl<'a> {name_pascal}<'a> {{")?;
writeln!(dest, " pub const TEMPLATE_ID: u16 = {};", msg.id)?;
writeln!(
dest,
" pub const SCHEMA_VERSION: u16 = {};",
schema.version
)?;
writeln!(
dest,
" pub const BLOCK_LENGTH: u16 = {};",
msg.block_length
)?;
writeln!(dest)?;
writeln!(
dest,
" pub fn wrap(buffer: &'a [u8], offset: usize) -> Self {{"
)?;
writeln!(dest, " Self {{ buffer, offset }}")?;
writeln!(dest, " }}")?;
writeln!(dest)?;
for field in &msg.fields {
let _field_name_pascal = field.name.to_pascal_case();
let return_type = map_type(&field.field_type);
if field.presence != "constant" {
writeln!(
dest,
" pub fn {}(&self) -> {} {{",
field.name.to_snake_case(),
return_type
)?;
writeln!(
dest,
" let range = self.offset + {}..;",
field.offset
)?;
writeln!(
dest,
" unsafe {{ *(&self.buffer[range.start] as *const u8 as *const {return_type}) }}"
)?;
writeln!(dest, " }}")?;
writeln!(dest)?;
} else {
writeln!(
dest,
" pub fn {}(&self) -> {} {{",
field.name.to_snake_case(),
return_type
)?;
writeln!(dest, " todo!(\"constant value\")")?;
writeln!(dest, " }}")?;
writeln!(dest)?;
}
}
writeln!(dest, "}}")?;
writeln!(dest)?;
writeln!(dest, "impl SbeMessage for {name_pascal}<'_> {{")?;
writeln!(dest, " const TEMPLATE_ID: u16 = {};", msg.id)?;
writeln!(dest, " const SCHEMA_VERSION: u16 = {};", schema.version)?;
writeln!(dest, " const BLOCK_LENGTH: u16 = {};", msg.block_length)?;
writeln!(
dest,
" const MESSAGE_NAME: &'static str = \"{}\";",
msg.name
)?;
writeln!(dest, "}}")?;
writeln!(dest)?;
}
Ok(())
}
/// Translates an SBE type name into the Rust type spelled in generated code.
///
/// Primitive SBE types map to the matching Rust primitive; `char` and the
/// empty name both become `u8`. Every other name is treated as a generated
/// composite or enum and converted to PascalCase.
fn map_type(sbe_type: &str) -> String {
    const PRIMITIVES: [(&str, &str); 12] = [
        ("", "u8"),
        ("char", "u8"),
        ("int8", "i8"),
        ("uint8", "u8"),
        ("int16", "i16"),
        ("uint16", "u16"),
        ("int32", "i32"),
        ("uint32", "u32"),
        ("int64", "i64"),
        ("uint64", "u64"),
        ("float", "f32"),
        ("double", "f64"),
    ];
    PRIMITIVES
        .iter()
        .find(|(sbe, _)| *sbe == sbe_type)
        .map_or_else(|| sbe_type.to_pascal_case(), |(_, rust)| (*rust).to_string())
}
fn main() -> Result<()> {
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=resources/sbe.xml");
let xml_string = fs::read_to_string("resources/sbe.xml")?;
let doc = roxmltree::Document::parse(&xml_string)?;
let root = doc.root_element();
let mut schema = MessageSchema {
package: root.attribute("package").unwrap_or("rustysbe").to_string(),
version: root.attribute("version").unwrap_or("0").parse()?,
semantic_version: root.attribute("semanticVersion").unwrap_or("").to_string(),
description: root.attribute("description").unwrap_or("").to_string(),
..Default::default()
};
for node in root.children().filter(Node::is_element) {
match node.tag_name().name() {
"types" => {
for type_node in node.children().filter(Node::is_element) {
match type_node.tag_name().name() {
"composite" => {
let types = type_node
.children()
.filter(Node::is_element)
.map(|n| parse_field(&n))
.collect();
schema.types.push(SbeType::Composite(SbeComposite {
name: type_node.attribute("name").unwrap_or("").to_string(),
types,
}));
}
"enum" => {
let valid_values = type_node
.children()
.filter(Node::is_element)
.map(|n| SbeValidValue {
name: n.attribute("name").unwrap_or("").to_string(),
value: n.text().unwrap_or("").to_string(),
})
.collect();
schema.types.push(SbeType::Enum(SbeEnum {
name: type_node.attribute("name").unwrap_or("").to_string(),
encoding_type: type_node
.attribute("encodingType")
.unwrap_or("")
.to_string(),
valid_values,
}));
}
_ => {}
}
}
}
"message" => {
let mut fields = Vec::new();
let mut groups = Vec::new();
for child in node.children().filter(Node::is_element) {
match child.tag_name().name() {
"field" => fields.push(parse_field(&child)),
"group" => groups.push(parse_group(&child)),
_ => {}
}
}
let block_length = if let Some(bl) = node.attribute("blockLength") {
bl.parse()?
} else {
fields
.iter()
.map(|f| get_type_size(&f.field_type, &schema) * f.length)
.sum::<usize>() as u16
};
schema.messages.push(SbeMessage {
name: node.attribute("name").unwrap_or("").to_string(),
id: node.attribute("id").unwrap_or("0").parse()?,
description: node.attribute("description").unwrap_or("").to_string(),
fields,
groups,
block_length,
});
}
_ => {}
}
}
let out_dir = env::var_os("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join("sbe.rs");
let mut file = fs::File::create(&dest_path)?;
generate_code(&schema, &mut file)?;
Ok(())
}