use convert_case::Casing;
use glob::glob;
use serde::Deserialize;
use std::{
collections::HashSet,
fs::File,
io::{Read, Write},
path::{Path, PathBuf},
};
/// Root of a parsed `protocol.xml` file.
#[derive(Debug, Deserialize)]
struct Protocol {
/// Top-level `<enum>`, `<struct>` and `<packet>` children; `default` yields an empty list when
/// the file has none.
#[serde(rename = "$value", default)]
pub elements: Vec<Element>,
}
/// A top-level protocol definition. Variant names map to the lowercase XML tags
/// `<enum>`, `<struct>` and `<packet>`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
enum Element {
Enum(Enum),
Struct(Struct),
Packet(Packet),
}
/// An `<enum name=".." type="..">` definition; each one becomes its own generated Rust enum.
#[derive(Debug, Deserialize, Clone)]
struct Enum {
/// Name of the generated Rust enum (also snake-cased for the file/module name).
#[serde(rename = "@name")]
pub name: String,
/// Protocol data type the enum is encoded as. It is mapped through `get_field_type` for the
/// generated `From` impls and used as the `writer.add_*` / `reader.get_*` suffix.
#[serde(rename = "@type")]
pub data_type: String,
/// The enum's `<comment>` and `<value>` children.
#[serde(rename = "$value", default)]
pub elements: Vec<EnumElement>,
}
/// A child of `<enum>`: a `<comment>` (the first one documents the generated enum) or a `<value>`.
#[derive(Debug, Deserialize, Clone)]
enum EnumElement {
#[serde(rename = "comment")]
Comment(String),
#[serde(rename = "value")]
Value(EnumValue),
}
/// A `<value name="..">` entry: one named variant together with its integer discriminant.
#[derive(Debug, Deserialize, Clone)]
struct EnumValue {
/// Variant name (passed through `replace_keyword` when generated).
#[serde(rename = "@name")]
pub name: String,
/// Optional `<comment>` child, emitted as the variant's doc comment.
#[serde(rename = "comment")]
pub comment: Option<String>,
/// The discriminant, parsed from the element's text content.
#[serde(rename = "$text")]
pub value: i32,
}
/// A `<struct name="..">` definition. Its children are serialized and deserialized in the order
/// they appear.
#[derive(Debug, Deserialize, Clone)]
struct Struct {
/// Name of the generated Rust struct.
#[serde(rename = "@name")]
pub name: String,
/// Members, markers and comments of the struct.
#[serde(rename = "$value", default)]
pub elements: Vec<StructElement>,
}
/// A child of `<struct>`, `<packet>`, `<chunked>` or `<case>`. Variant names map to the
/// lowercase XML tags.
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
enum StructElement {
/// `<break/>`: written as a 0xff byte; when deserializing inside `<chunked>` it advances to the
/// next chunk.
Break,
/// `<chunked>`: a section read with the reader in chunked mode.
Chunked(Chunked),
/// `<comment>`: the first one becomes the generated type's doc comment.
Comment(String),
/// `<dummy>`: a constant that is written on serialization and read-and-discarded on
/// deserialization.
Dummy(Dummy),
/// `<field>`: a single value.
Field(Field),
/// `<array>`: a repeated value (deserialized element by element with `push`).
Array(Array),
/// `<length>`: a length that a field or array refers to by name.
Length(Length),
/// `<switch>`: selects a case body based on the value of the field named by `field`.
Switch(Switch),
}
/// A `<chunked>` section. Nested `<chunked>` sections are rejected when generating deserializers.
#[derive(Debug, Deserialize, Clone)]
struct Chunked {
/// Elements inside the chunked section.
#[serde(rename = "$value", default)]
pub elements: Vec<StructElement>,
}
/// A `<field>` element.
#[derive(Debug, Deserialize, Clone)]
struct Field {
/// Member name. Unnamed fields are not declared on the generated struct; their value is still
/// written, and on reading it is evaluated and discarded.
#[serde(rename = "@name")]
pub name: Option<String>,
/// Protocol data type, optionally followed by `:<type>` to override the wire type used for
/// enums and bools.
#[serde(rename = "@type")]
pub data_type: String,
/// Fixed value written instead of a struct member.
#[serde(rename = "$value", default)]
pub value: Option<String>,
/// Doc comment text for the generated member.
pub comment: Option<String>,
/// Pad `string`/`encoded_string` values up to `length` with `ÿ` characters when writing.
#[serde(rename = "@padded")]
pub padded: Option<bool>,
/// Declare the member as `Option<..>` and only read it while data remains.
#[serde(rename = "@optional")]
pub optional: Option<bool>,
/// Fixed length for strings: a literal, or the name of a `<length>` element.
#[serde(rename = "@length")]
pub length: Option<String>,
}
/// Serde default for `Array::trailing_delimiter`: a delimiter is assumed after the last element
/// unless the XML says `trailing-delimiter="false"`.
fn default_as_true() -> bool {
true
}
/// An `<array>` element.
#[derive(Debug, Deserialize, Clone)]
struct Array {
/// Name of the collection member on the generated struct.
#[serde(rename = "@name")]
pub name: String,
/// Data type of each element.
#[serde(rename = "@type")]
pub data_type: String,
/// Element count: a literal or the name of a `<length>` element. Without it, elements are read
/// while the reader has data remaining.
#[serde(rename = "@length")]
pub length: Option<String>,
/// Optional arrays; the serializer generator currently panics on these.
#[serde(rename = "@optional")]
pub optional: Option<bool>,
/// Elements are separated by 0xff bytes when writing and by chunks when reading.
#[serde(rename = "@delimited")]
pub delimited: Option<bool>,
/// Whether a delimiter also follows the last element (defaults to `true`).
#[serde(rename = "@trailing-delimiter")]
#[serde(default = "default_as_true")]
pub trailing_delimiter: bool,
/// Doc comment text for the generated member.
pub comment: Option<String>,
}
/// A `<length>` element: the length of the array or field whose `length` attribute names it.
#[derive(Debug, Deserialize, Clone)]
struct Length {
    /// Name that arrays/fields use to refer to this length.
    #[serde(rename = "@name")]
    pub name: String,
    /// Protocol data type the length is encoded as.
    #[serde(rename = "@type")]
    pub data_type: String,
    /// Whether the length may be absent from the data (`optional="true"`). This must be read from
    /// the `optional` attribute; mapping it from `length` left it permanently `None`.
    #[serde(rename = "@optional")]
    pub optional: Option<bool>,
    /// Adjustment between the encoded value and the real length: added after reading, subtracted
    /// before writing.
    #[serde(rename = "@offset")]
    pub offset: Option<i32>,
}
/// A `<dummy type="..">value</dummy>` placeholder: `value` is always written, and the value read
/// back is discarded.
#[derive(Debug, Deserialize, Clone)]
struct Dummy {
/// Protocol data type used for the `add_*` / `get_*` call.
#[serde(rename = "@type")]
pub data_type: String,
/// Constant text written on serialization (an integer literal or a string).
#[serde(rename = "$value", default)]
pub value: String,
}
/// A `<switch field="..">`: the generated type gets a `<field>_data` member holding the data of
/// the matching case.
#[derive(Debug, Deserialize, Clone)]
struct Switch {
/// Name of the field whose value selects the case.
#[serde(rename = "@field")]
pub field: String,
/// The switch's `<case>` children.
#[serde(rename = "$value", default)]
pub cases: Vec<Case>,
}
/// A `<case>` of a switch: matched either by `value` or as the `default="true"` fallback.
#[derive(Debug, Deserialize, Clone)]
#[serde(rename = "case")]
struct Case {
/// Marks the fallback case.
#[serde(rename = "@default")]
pub default: Option<bool>,
/// Name of a variant of the switch field's enum, or a literal used verbatim.
#[serde(rename = "@value")]
pub value: Option<String>,
/// Case body; cases without one (`None`) carry no data and get no generated variant.
#[serde(rename = "$value", default)]
pub elements: Option<Vec<StructElement>>,
}
/// A `<packet family=".." action="..">` definition, generated as
/// `<family><action><Client|Server>Packet`.
#[derive(Debug, Deserialize, Clone)]
struct Packet {
#[serde(rename = "@action")]
pub action: String,
#[serde(rename = "@family")]
pub family: String,
/// The packet's members, markers and comments.
#[serde(rename = "$value", default)]
pub elements: Vec<StructElement>,
}
/// Comment banner written at the start and end of every generated file, and at the top of each
/// generated `mod.rs`.
static CODEGEN_WARNING: &str = r"// This file is automatically generated by build.rs
// Please do not edit it directly
";
/// Identifiers that cannot be emitted verbatim as Rust names: strict and reserved keywords,
/// including those added in the 2018 edition (`async`, `await`, `dyn`, `try`).
/// Presumably consulted by `replace_keyword` (defined outside this view) to escape
/// protocol names.
static RUST_KEYWORDS: [&str; 56] = [
    "abstract", "alignof", "as", "async", "await", "become", "box", "break", "const", "continue",
    "crate", "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "if", "impl",
    "in", "let", "loop", "macro", "match", "mod", "move", "mut", "offsetof", "override", "priv",
    "proc", "pub", "pure", "ref", "return", "Self", "self", "sizeof", "static", "struct", "super",
    "trait", "true", "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where",
    "while", "yield",
];
/// Build-script entry point: generates Rust sources for every `protocol.xml` under
/// `eo-protocol/xml`.
///
/// Each protocol file gets an output directory (from `get_output_directory`) containing one
/// file per enum, struct and packet, plus a `mod.rs` that declares and re-exports them. The
/// root and `net` directories also declare their sub-modules.
///
/// # Panics
/// Panics, failing the build, if a protocol file cannot be read or parsed, or if generation or
/// writing fails. Every message names the file or directory involved.
fn main() {
    println!("cargo:rerun-if-changed=eo-protocol/xml");

    let mut protocols = Vec::new();
    for entry in glob("eo-protocol/xml/**/protocol.xml").expect("Failed to read glob pattern") {
        let path = entry.unwrap_or_else(|e| panic!("Failed to read protocol file: {}", e));
        match parse_protocol_file(&path) {
            // `path` is owned here, so it can be stored without cloning.
            Ok(protocol) => protocols.push((protocol, path)),
            Err(e) => panic!("Failed to parse protocol file {}: {}", path.display(), e),
        }
    }

    // Every enum and struct across all protocol files; generators need them to resolve
    // field types defined in other files.
    let enums: Vec<Enum> = protocols
        .iter()
        .flat_map(|(protocol, _)| &protocol.elements)
        .filter_map(|element| match element {
            Element::Enum(protocol_enum) => Some(protocol_enum.clone()),
            _ => None,
        })
        .collect();
    let structs: Vec<Struct> = protocols
        .iter()
        .flat_map(|(protocol, _)| &protocol.elements)
        .filter_map(|element| match element {
            Element::Struct(protocol_struct) => Some(protocol_struct.clone()),
            _ => None,
        })
        .collect();

    for (protocol, path) in &protocols {
        let output_dir = get_output_directory(path);
        std::fs::create_dir_all(&output_dir).unwrap_or_else(|e| {
            panic!(
                "Failed to create output directory {}: {}",
                output_dir.display(),
                e
            )
        });

        let mut mod_code = String::new();
        mod_code.push_str(CODEGEN_WARNING);
        for element in &protocol.elements {
            let result = match element {
                Element::Enum(protocol_enum) => {
                    generate_enum_file(protocol_enum, &output_dir, &mut mod_code)
                }
                Element::Struct(protocol_struct) => {
                    let imports = get_imports(&protocol_struct.elements, &protocols);
                    generate_struct_file(
                        protocol_struct,
                        imports,
                        &output_dir,
                        &mut mod_code,
                        &enums,
                        &structs,
                    )
                }
                Element::Packet(packet) => {
                    let imports = get_imports(&packet.elements, &protocols);
                    generate_packet_file(
                        packet,
                        imports,
                        &output_dir,
                        &mut mod_code,
                        &enums,
                        &structs,
                    )
                }
            };
            if let Err(e) = result {
                panic!("Failed to generate code for {}: {}", path.display(), e);
            }
        }

        // The root and `net` protocol directories are parents of other generated modules.
        match path.parent() {
            Some(dir) if dir == Path::new("eo-protocol/xml") => {
                mod_code.push_str("pub mod map;\n");
                mod_code.push_str("pub mod net;\n");
                mod_code.push_str("pub mod r#pub;\n");
            }
            Some(dir) if dir == Path::new("eo-protocol/xml/net") => {
                mod_code.push_str("pub mod client;\n");
                mod_code.push_str("pub mod server;\n");
            }
            _ => {}
        }

        let mod_path = output_dir.join("mod.rs");
        std::fs::write(&mod_path, mod_code)
            .unwrap_or_else(|e| panic!("Failed to write {}: {}", mod_path.display(), e));
    }
}
/// Writes `<snake_case name>.rs` for a protocol enum into `path` and registers it in `mod_code`.
///
/// The generated enum has one unit variant per `<value>` plus an `Unrecognized(..)` catch-all.
/// It implements `From` in both directions with its underlying integer type, and `Default`
/// returns the first declared value.
///
/// # Errors
/// Returns an error if the enum declares no values (there is no variant to use as `Default`),
/// or if the output file cannot be created or written.
fn generate_enum_file(
    protocol_enum: &Enum,
    path: &Path,
    mod_code: &mut String,
) -> Result<(), Box<dyn std::error::Error>> {
    let variants: Vec<&EnumValue> = protocol_enum
        .elements
        .iter()
        .filter_map(|e| match e {
            EnumElement::Value(value) => Some(value),
            _ => None,
        })
        .collect();
    // The first value doubles as `Default`; report a clear error instead of an index panic.
    let first_variant = match variants.first() {
        Some(variant) => *variant,
        None => {
            return Err(format!(
                "Enum {} has no values, so no Default variant can be generated",
                protocol_enum.name
            )
            .into())
        }
    };
    // Rust type of the underlying integer, shared by `Unrecognized` and both `From` impls.
    let underlying = get_field_type(&protocol_enum.data_type);

    let mut code = String::new();
    code.push_str(CODEGEN_WARNING);
    let comments = match protocol_enum
        .elements
        .iter()
        .find(|e| matches!(e, EnumElement::Comment(_)))
    {
        Some(EnumElement::Comment(comment)) => get_comments(comment),
        _ => vec![],
    };
    append_doc_comments(&mut code, comments);

    // Enum declaration.
    code.push_str("#[derive(Debug, PartialEq, Eq, Copy, Clone)]\n");
    code.push_str(&format!("pub enum {} {{\n", protocol_enum.name));
    for variant in &variants {
        let comments = match &variant.comment {
            Some(comment) => get_comments(comment),
            None => vec![],
        };
        append_doc_comments(&mut code, comments);
        code.push_str(&format!(" {},\n", replace_keyword(&variant.name)));
    }
    code.push_str(&format!(" Unrecognized({}),\n", underlying));
    code.push_str("}\n\n");

    // Integer -> enum: unknown values are preserved in `Unrecognized`.
    code.push_str(&format!(
        "impl From<{}> for {} {{\n",
        underlying, protocol_enum.name
    ));
    code.push_str(&format!(" fn from(value: {}) -> Self {{\n", underlying));
    code.push_str(" match value {\n");
    for variant in &variants {
        code.push_str(&format!(
            " {} => Self::{},\n",
            variant.value,
            replace_keyword(&variant.name)
        ));
    }
    code.push_str(" _ => Self::Unrecognized(value),\n");
    code.push_str(" }\n");
    code.push_str(" }\n");
    code.push_str("}\n\n");

    // Enum -> integer.
    code.push_str(&format!(
        "impl From<{}> for {} {{\n",
        protocol_enum.name, underlying
    ));
    code.push_str(&format!(
        " fn from(value: {}) -> {} {{\n",
        protocol_enum.name, underlying
    ));
    code.push_str(" match value {\n");
    for variant in &variants {
        code.push_str(&format!(
            " {}::{} => {},\n",
            protocol_enum.name,
            replace_keyword(&variant.name),
            variant.value
        ));
    }
    code.push_str(&format!(
        " {}::Unrecognized(value) => value,\n",
        protocol_enum.name
    ));
    code.push_str(" }\n");
    code.push_str(" }\n");
    code.push_str("}\n\n");

    code.push_str(&format!("impl Default for {} {{\n", protocol_enum.name));
    code.push_str(" fn default() -> Self {\n");
    code.push_str(&format!(
        " Self::{}\n",
        replace_keyword(&first_variant.name)
    ));
    code.push_str(" }\n");
    code.push_str("}\n");
    code.push_str(CODEGEN_WARNING);

    let snake_name = protocol_enum.name.to_case(convert_case::Case::Snake);
    mod_code.push_str(&format!("mod {};\n", snake_name));
    mod_code.push_str(&format!("pub use {}::*;\n", snake_name));
    let mut file = File::create(path.join(format!("{}.rs", snake_name)))?;
    file.write_all(code.as_bytes())?;
    Ok(())
}
/// Writes `<snake_case name>.rs` for a protocol struct into `path` and registers it in `mod_code`.
///
/// The file contains `imports`, the struct with its `EoSerialize` impl, and a data enum (plus
/// case structs) for every `<switch>`. That includes switches inside a `<chunked>` section, which
/// the generated (de)serializer also references, matching `generate_packet_file`.
///
/// # Errors
/// Returns an error if the output file cannot be created or written.
fn generate_struct_file(
    protocol_struct: &Struct,
    imports: Vec<String>,
    path: &Path,
    mod_code: &mut String,
    enums: &[Enum],
    structs: &[Struct],
) -> Result<(), Box<dyn std::error::Error>> {
    let mut code = String::new();
    code.push_str(CODEGEN_WARNING);
    for import in &imports {
        code.push_str(&format!("{}\n", import));
    }
    if !imports.is_empty() {
        code.push('\n');
    }
    write_struct(
        &protocol_struct.name,
        &protocol_struct.elements,
        &mut code,
        enums,
        structs,
    );

    // Top-level switches first, then those nested in <chunked> sections.
    let chunked_elements = protocol_struct
        .elements
        .iter()
        .filter_map(|e| match e {
            StructElement::Chunked(chunked) => Some(&chunked.elements),
            _ => None,
        })
        .flatten();
    let switches = protocol_struct
        .elements
        .iter()
        .chain(chunked_elements)
        .filter_map(|e| match e {
            StructElement::Switch(switch) => Some(switch),
            _ => None,
        });
    for switch in switches {
        let name = get_field_type(&format!("{}_{}_data", protocol_struct.name, switch.field));
        generate_switch_code(&name, &mut code, switch, enums, structs);
    }

    code.push_str(CODEGEN_WARNING);
    let snake_name = protocol_struct.name.to_case(convert_case::Case::Snake);
    let mut file = File::create(path.join(format!("{}.rs", snake_name)))?;
    file.write_all(code.as_bytes())?;
    mod_code.push_str(&format!("mod {};\n", snake_name));
    mod_code.push_str(&format!("pub use {}::*;\n", snake_name));
    Ok(())
}
/// Writes one packet definition to its own file under `path` and registers it in `mod_code`.
///
/// The type is named `<family><action><Server|Client>Packet`. The side is `Server` when the
/// output directory's path string ends in "server". Data enums are emitted for every top-level
/// `<switch>`, then for each `<switch>` inside a `<chunked>` section.
///
/// # Errors
/// Returns an error if the output file cannot be created or written.
fn generate_packet_file(
    packet: &Packet,
    imports: Vec<String>,
    path: &Path,
    mod_code: &mut String,
    enums: &[Enum],
    structs: &[Struct],
) -> Result<(), Box<dyn std::error::Error>> {
    let mut code = String::from(CODEGEN_WARNING);
    code.extend(imports.iter().map(|import| format!("{}\n", import)));
    if !imports.is_empty() {
        code.push('\n');
    }

    let side = match path.to_str().unwrap().ends_with("server") {
        true => "Server",
        false => "Client",
    };
    let packet_name = format!("{}{}{}Packet", packet.family, packet.action, side);
    write_struct(&packet_name, &packet.elements, &mut code, enums, structs);

    let nested = packet
        .elements
        .iter()
        .filter_map(|e| match e {
            StructElement::Chunked(chunked) => Some(&chunked.elements),
            _ => None,
        })
        .flatten();
    let switches = packet
        .elements
        .iter()
        .chain(nested)
        .filter_map(|e| match e {
            StructElement::Switch(switch) => Some(switch),
            _ => None,
        });
    for switch in switches {
        let data_enum = get_field_type(&format!("{}_{}_data", packet_name, switch.field));
        generate_switch_code(&data_enum, &mut code, switch, enums, structs);
    }

    code.push_str(CODEGEN_WARNING);
    let file_stem = packet_name.to_case(convert_case::Case::Snake);
    File::create(path.join(format!("{}.rs", file_stem)))?.write_all(code.as_bytes())?;
    mod_code.push_str(&format!("mod {};\n", file_stem));
    mod_code.push_str(&format!("pub use {}::*;\n", file_stem));
    Ok(())
}
fn generate_switch_code(
name: &str,
code: &mut String,
switch: &Switch,
enums: &[Enum],
structs: &[Struct],
) {
code.push_str("#[derive(Debug, PartialEq, Eq)]\n");
code.push_str(&format!("pub enum {} {{\n", name));
for case in switch.cases.iter().filter(|c| c.elements.is_some()) {
match case.default {
Some(true) => {
code.push_str(&format!(
" Default({}),\n",
get_field_type(&format!("{}_default", name)),
));
}
_ => {
code.push_str(&format!(
" {}({}),\n",
replace_keyword(case.value.as_ref().unwrap()),
get_field_type(&format!("{}_{}", name, case.value.as_ref().unwrap()))
));
}
}
}
code.push_str("}\n\n");
for case in switch.cases.iter().filter(|c| c.elements.is_some()) {
let elements = case.elements.as_ref().unwrap();
let name = match case.default {
Some(true) => get_field_type(&format!("{}_default", name)),
_ => get_field_type(&format!("{}_{}", name, case.value.as_ref().unwrap())),
};
write_struct(&name, elements, code, enums, structs);
for switch in elements.iter().filter_map(|e| match e {
StructElement::Switch(switch) => Some(switch),
_ => None,
}) {
let name = get_field_type(&format!("{}_{}_data", name, switch.field));
generate_switch_code(&name, code, switch, enums, structs);
}
}
}
/// Emits a struct definition named `name` together with a `new()` constructor and an
/// `EoSerialize` impl.
///
/// The first `<comment>` among `elements` becomes the struct's doc comment. `Coords` also
/// derives `Clone` and `Copy`. When the struct has no members, the generated
/// `serialize`/`deserialize` ignore their writer/reader (the parameters get a `_` prefix).
fn write_struct(
    name: &str,
    elements: &[StructElement],
    code: &mut String,
    enums: &[Enum],
    structs: &[Struct],
) {
    let doc_lines = elements
        .iter()
        .find_map(|e| match e {
            StructElement::Comment(comment) => Some(get_comments(comment)),
            _ => None,
        })
        .unwrap_or_default();
    for line in &doc_lines {
        code.push_str(&format!("/// {}\n", line));
    }

    let derives = if name == "Coords" {
        "Debug, Default, PartialEq, Eq, Clone, Copy"
    } else {
        "Debug, Default, PartialEq, Eq"
    };
    code.push_str(&format!("#[derive({})]\n", derives));
    code.push_str(&format!("pub struct {} {{\n", name));
    let has_fields = write_struct_fields(code, name, elements, 0) > 0;
    code.push_str("}\n\n");

    code.push_str(&format!("impl {} {{\n", name));
    code.push_str(" pub fn new() -> Self {\n");
    code.push_str(" Self::default()\n");
    code.push_str(" }\n");
    code.push_str("}\n\n");

    let (writer_param, reader_param) = if has_fields {
        ("writer", "reader")
    } else {
        ("_writer", "_reader")
    };

    code.push_str(&format!("impl EoSerialize for {} {{\n", name));
    code.push_str(&format!(
        " /// Serializes a [{}] into the given [EoWriter] instance\n",
        name
    ));
    code.push_str(&format!(
        " fn serialize(&self, {}: &mut EoWriter) -> Result<(), EoSerializeError> {{\n",
        writer_param
    ));
    if has_fields {
        write_struct_serialize(code, name, elements, enums, structs);
    } else {
        code.push_str(" Ok(())\n");
    }
    code.push_str(" }\n");

    code.push_str(&format!(
        " /// Deserializes a [{}] from an [EoReader] instance\n",
        name
    ));
    code.push_str(&format!(
        " fn deserialize({}: &EoReader) -> Result<Self, EoReaderError> {{\n",
        reader_param
    ));
    if has_fields {
        write_struct_deserialize(code, name, elements, enums, structs);
    } else {
        code.push_str(" Ok(Self::default())\n");
    }
    code.push_str(" }\n");
    code.push_str("}\n\n");
}
/// Emits the body of the generated `serialize` function for a struct/packet named `name`.
///
/// Elements are written in order. The members of a `<chunked>` section are written inline, and
/// a `<chunked>` nested inside another is skipped. Length references are resolved against the
/// full top-level `elements` list. The body ends with `Ok(())`.
fn write_struct_serialize(
    code: &mut String,
    name: &str,
    elements: &[StructElement],
    enums: &[Enum],
    structs: &[Struct],
) {
    /// Emits the writer code for a single element; `siblings` is the top-level element list.
    fn emit_element(
        code: &mut String,
        name: &str,
        element: &StructElement,
        siblings: &[StructElement],
        enums: &[Enum],
        structs: &[Struct],
    ) {
        match element {
            StructElement::Break => generate_serialize_break(code),
            StructElement::Dummy(dummy) => generate_serialize_dummy(code, dummy),
            StructElement::Field(field) => generate_serialize_field(code, field, enums, structs),
            StructElement::Array(array) => generate_serialize_array(code, array, enums, structs),
            StructElement::Length(length) => generate_serialize_length(
                code,
                get_name_of_field_that_uses_this_length(&length.name, siblings),
                length,
            ),
            StructElement::Switch(switch) => generate_serialize_switch(code, name, switch),
            StructElement::Chunked(_) | StructElement::Comment(_) => {}
        }
    }

    for element in elements {
        if let StructElement::Chunked(chunked) = element {
            for inner in &chunked.elements {
                emit_element(code, name, inner, elements, enums, structs);
            }
        } else {
            emit_element(code, name, element, elements, enums, structs);
        }
    }
    code.push_str(" Ok(())\n");
}
/// Returns the name of the array or field whose `length` attribute equals `length_name`.
///
/// Searches `elements` in order and looks one level into `<chunked>` sections.
///
/// # Panics
/// Panics if no element uses the length, or if the matching `<field>` is unnamed.
fn get_name_of_field_that_uses_this_length(
    length_name: &str,
    elements: &[StructElement],
) -> String {
    /// Name of `element` if it is an array/field sized by `length_name`.
    fn direct_user(element: &StructElement, length_name: &str) -> Option<String> {
        match element {
            StructElement::Array(array) if array.length.as_deref() == Some(length_name) => {
                Some(array.name.clone())
            }
            StructElement::Field(field) if field.length.as_deref() == Some(length_name) => {
                Some(field.name.as_ref().unwrap().clone())
            }
            _ => None,
        }
    }

    elements
        .iter()
        .find_map(|element| match element {
            StructElement::Chunked(chunked) => chunked
                .elements
                .iter()
                .find_map(|inner| direct_user(inner, length_name)),
            other => direct_user(other, length_name),
        })
        .unwrap_or_else(|| panic!("Length field not found! {}", length_name))
}
fn write_struct_deserialize(
code: &mut String,
name: &str,
elements: &[StructElement],
enums: &[Enum],
structs: &[Struct],
) {
code.push_str(
" let current_chunked_readming_mode = reader.get_chunked_reading_mode();\n",
);
code.push_str(" let mut data = Self::default();\n");
for element in elements {
match element {
StructElement::Chunked(chunked) => {
code.push_str(" reader.set_chunked_reading_mode(true);\n");
for element in &chunked.elements {
match element {
StructElement::Break => {
code.push_str(" reader.next_chunk()?;\n");
}
StructElement::Chunked(_) => {
panic!("Nested chunked elements are not supported! {}", name);
}
StructElement::Length(length) => {
generate_deserialize_length(code, length);
}
StructElement::Dummy(dummy) => {
code.push_str(&format!(" reader.get_{}()?;\n", dummy.data_type));
}
StructElement::Field(field) => {
generate_deserialize_field(code, field, enums, structs)
}
StructElement::Array(array) => {
generate_deserialize_array(code, array, enums, structs)
}
StructElement::Switch(switch) => {
let field = match elements.iter().find(|e| match e {
StructElement::Field(field) => {
field.name == Some(switch.field.clone())
}
StructElement::Chunked(chunked) => {
chunked.elements.iter().any(|e| match e {
StructElement::Field(field) => {
field.name == Some(switch.field.clone())
}
_ => false,
})
}
_ => false,
}) {
Some(StructElement::Field(field)) => field,
Some(StructElement::Chunked(chunked)) => {
match chunked.elements.iter().find(|e| match e {
StructElement::Field(field) => {
field.name == Some(switch.field.clone())
}
_ => false,
}) {
Some(StructElement::Field(field)) => field,
_ => panic!("Switch field not found! {}", name),
}
}
_ => panic!("Switch field not found! {}", name),
};
let switch_enum = enums
.iter()
.find(|e| e.name == field.data_type)
.expect("Switch enum not found!");
generate_deserialize_switch(code, name, switch, switch_enum);
}
_ => {}
}
}
}
StructElement::Dummy(dummy) => {
code.push_str(&format!(" reader.get_{}()?;\n", dummy.data_type));
}
StructElement::Length(length) => {
generate_deserialize_length(code, length);
}
StructElement::Field(field) => generate_deserialize_field(code, field, enums, structs),
StructElement::Array(array) => generate_deserialize_array(code, array, enums, structs),
StructElement::Switch(switch) => {
let field = match elements.iter().find(|e| match e {
StructElement::Field(field) => field.name == Some(switch.field.clone()),
StructElement::Chunked(chunked) => chunked.elements.iter().any(|e| match e {
StructElement::Field(field) => field.name == Some(switch.field.clone()),
_ => false,
}),
_ => false,
}) {
Some(StructElement::Field(field)) => field,
Some(StructElement::Chunked(chunked)) => {
match chunked.elements.iter().find(|e| match e {
StructElement::Field(field) => field.name == Some(switch.field.clone()),
_ => false,
}) {
Some(StructElement::Field(field)) => field,
_ => panic!("Switch field not found! {}", name),
}
}
_ => panic!("Switch field not found! {}", name),
};
let switch_enum = enums
.iter()
.find(|e| e.name == field.data_type)
.expect("Switch enum not found!");
generate_deserialize_switch(code, name, switch, switch_enum);
}
_ => {}
}
}
code.push_str(" reader.set_chunked_reading_mode(current_chunked_readming_mode);\n");
code.push_str(" Ok(data)\n");
}
/// Returns whether the generated `writer.add_<data_type>(..)` call yields a `Result` and so
/// needs a trailing `?`. Only the `byte`, `string` and `encoded_string` writers are treated
/// as infallible.
fn needs_result(data_type: &str) -> bool {
    const INFALLIBLE_WRITERS: [&str; 3] = ["byte", "string", "encoded_string"];
    !INFALLIBLE_WRITERS.contains(&data_type)
}
/// Emits a write of the 0xff break byte.
fn generate_serialize_break(code: &mut String) {
    *code += " writer.add_byte(0xff);\n";
}
/// Emits a `writer.add_<type>(<value>)` call for a `<dummy>` element.
///
/// A non-empty value made only of ASCII digits is emitted as an integer literal. Anything else,
/// including an empty value, becomes an escaped Rust string literal. `?` is appended when the
/// writer method returns a `Result` (see `needs_result`).
fn generate_serialize_dummy(code: &mut String, dummy: &Dummy) {
    // `char::is_numeric` also accepts non-ASCII numerals that are not valid in a Rust integer
    // literal. And `all` is vacuously true for "", which used to emit `add_x()` with no argument.
    let is_integer_literal =
        !dummy.value.is_empty() && dummy.value.chars().all(|c| c.is_ascii_digit());
    let argument = if is_integer_literal {
        dummy.value.clone()
    } else {
        // Debug formatting produces a quoted literal with `"` and `\` escaped.
        format!("{:?}", dummy.value)
    };
    let suffix = if needs_result(&dummy.data_type) {
        "?"
    } else {
        ""
    };
    code.push_str(&format!(
        " writer.add_{}({}){};\n",
        dummy.data_type, argument, suffix
    ));
}
/// Emits the writer code for a `<field>`.
///
/// For optional fields, the generated code binds the value to a local with the field's
/// (keyword-escaped) name and only writes it when present.
/// NOTE(review): `if let Some(x) = self.x` moves out of `&self` unless the member type is `Copy`.
///
/// # Panics
/// Panics if an optional field has no name.
fn generate_serialize_field(code: &mut String, field: &Field, enums: &[Enum], structs: &[Struct]) {
    if field.optional != Some(true) {
        generate_inner_field_serialize(code, field, enums, structs);
        return;
    }
    let field_name = field
        .name
        .as_ref()
        .expect("Field name is required for optional fields!");
    let binding = replace_keyword(field_name);
    code.push_str(&format!(
        " if let Some({}) = self.{} {{\n",
        binding, binding
    ));
    generate_inner_field_serialize(code, field, enums, structs);
    code.push_str(" }\n");
}
/// Emits the writer call(s) for one field. Array items also use this: they arrive as an
/// optional field named `array_item`.
///
/// The field's `type` may carry a `:<type>` override, which selects the writer method for
/// enums and bools. The value written is the field's fixed `value` if it has one. Otherwise it
/// is the field itself: the unwrapped local binding for optional fields and array items, or
/// `self.<name>`. Names are escaped with `replace_keyword` in every branch so they match the
/// member identifiers declared by `write_struct_fields`.
fn generate_inner_field_serialize(
    code: &mut String,
    field: &Field,
    enums: &[Enum],
    structs: &[Struct],
) {
    let (data_type, enum_data_type) = field
        .data_type
        .split_once(':')
        .unwrap_or((field.data_type.as_str(), ""));
    let optional = matches!(field.optional, Some(true));
    // Expression naming the field's own value. Optional fields and array items were already
    // bound to a local of that name by the caller; anything else is read through `self`.
    let own_value = || {
        let name = field.name.as_ref().unwrap();
        if optional {
            replace_keyword(name)
        } else {
            format!("self.{}", replace_keyword(name))
        }
    };

    if let Some(protocol_enum) = enums.iter().find(|e| e.name == data_type) {
        let wire_type = if enum_data_type.is_empty() {
            protocol_enum.data_type.as_str()
        } else {
            enum_data_type
        };
        let value = match &field.value {
            Some(value) => format!("{}::{}", get_field_type(data_type), value),
            None => own_value(),
        };
        code.push_str(&format!(
            " writer.add_{}({}.into()){};\n",
            wire_type,
            value,
            if needs_result(wire_type) { "?" } else { "" }
        ));
    } else if structs.iter().any(|s| s.name == data_type) {
        let value = match &field.value {
            Some(value) => value.to_owned(),
            None => own_value(),
        };
        code.push_str(&format!("{}.serialize(writer)?;\n", value))
    } else {
        match data_type {
            "blob" => code.push_str(&format!(" writer.add_bytes(&{});\n", own_value())),
            "bool" => {
                let wire_type = if enum_data_type.is_empty() {
                    "char"
                } else {
                    enum_data_type
                };
                let value = match &field.value {
                    Some(value) => value.to_owned(),
                    None => own_value(),
                };
                code.push_str(&format!(
                    " writer.add_{}(if {} {{ 1 }} else {{ 0 }}){};\n",
                    wire_type,
                    value,
                    if needs_result(wire_type) { "?" } else { "" }
                ))
            }
            _ => {
                let value = match &field.value {
                    // Only plain ASCII digit strings are valid integer literals.
                    Some(value)
                        if !value.is_empty() && value.chars().all(|c| c.is_ascii_digit()) =>
                    {
                        value.to_owned()
                    }
                    // Debug formatting yields a correctly escaped, quoted string literal.
                    Some(value) => format!("{:?}", value),
                    None => own_value(),
                };
                let length = field.length.as_deref().unwrap_or("");
                let padded = field.padded.unwrap_or(false);
                let is_string_type =
                    matches!(field.data_type.as_str(), "string" | "encoded_string");
                if padded && !length.is_empty() && is_string_type {
                    code.push_str(&format!(
                        " let padding_length = {} - {}.len();\n",
                        length, value
                    ));
                    code.push_str(" let padding = \"ÿ\".repeat(padding_length);\n");
                    code.push_str(&format!(
                        " writer.add_{}(&format!(\"{{}}{{}}\", {}, padding));\n",
                        replace_keyword(&field.data_type),
                        value
                    ));
                    return;
                }
                let is_numeric_type = matches!(
                    field.data_type.as_str(),
                    "byte" | "char" | "short" | "three" | "int"
                );
                // Array items are references into the array: numbers need a deref, strings are
                // already borrowed. Other non-literal string values must be borrowed.
                let prefix = if value == "array_item" && is_numeric_type {
                    "*"
                } else if is_string_type && !value.starts_with('"') && value != "array_item" {
                    "&"
                } else {
                    ""
                };
                code.push_str(&format!(
                    " writer.add_{}({}{}){};\n",
                    replace_keyword(&field.data_type),
                    prefix,
                    value,
                    if needs_result(&field.data_type) { "?" } else { "" }
                ));
            }
        }
    }
}
/// Emits a loop that serializes every element of the array member `self.<name>`.
///
/// Each element is bound to `array_item` and written as an optional field of the array's type,
/// so no `self.` prefix is used. Delimited arrays write a 0xff byte between elements, and also
/// after the last one when `trailing_delimiter` is set.
///
/// # Panics
/// Panics for optional arrays, which are not supported.
fn generate_serialize_array(code: &mut String, array: &Array, enums: &[Enum], structs: &[Struct]) {
    if array.optional == Some(true) {
        panic!("Optional array not yet supported because I'm lazy");
    }
    let member = replace_keyword(&array.name);
    let delimited = array.delimited == Some(true);
    let separator_between = delimited && !array.trailing_delimiter;
    let separator_after = delimited && array.trailing_delimiter;

    if separator_between {
        code.push_str(&format!(
            " for (i, array_item) in self.{}.iter().enumerate() {{\n",
            member
        ));
        code.push_str(" if i > 0 {\n");
        code.push_str(" writer.add_byte(0xff);\n");
        code.push_str(" }\n");
    } else {
        code.push_str(&format!(" for array_item in &self.{} {{\n ", member));
    }

    let item = Field {
        name: Some(String::from("array_item")),
        data_type: array.data_type.clone(),
        value: None,
        comment: None,
        padded: None,
        optional: Some(true),
        length: None,
    };
    generate_inner_field_serialize(code, &item, enums, structs);

    if separator_after {
        code.push_str(" writer.add_byte(0xff);\n");
    }
    code.push_str(" }\n");
}
fn generate_serialize_length(code: &mut String, field_name: String, length: &Length) {
let optional = matches!(length.optional, Some(true));
let offset = length.offset.unwrap_or(0);
if optional {
code.push_str(&format!(
" if let Some(length) = &self.{} {{\n",
length.name
));
}
let offset_operation = match offset.cmp(&0) {
std::cmp::Ordering::Less => format!(" + {}", offset.abs()),
std::cmp::Ordering::Greater => format!(" - {}", offset.abs()),
std::cmp::Ordering::Equal => "".to_owned(),
};
code.push_str(&format!(
" writer.add_{}((self.{}.len(){}) as i32){};\n",
length.data_type,
field_name,
offset_operation,
if needs_result(&length.data_type) {
"?"
} else {
""
}
));
if optional {
code.push_str(" }\n");
}
}
/// Emits a `match` on `self.<field>_data` that serializes whichever case data is present.
///
/// There is one arm per case that has a body (`Default` for the default case); `None` and
/// anything else is ignored via `_ => ()`.
///
/// # Panics
/// Panics for a case that has neither a `value` nor `default="true"`.
fn generate_serialize_switch(code: &mut String, struct_name: &str, switch: &Switch) {
    let data_enum = get_field_type(&format!("{}_{}_data", struct_name, switch.field));
    code.push_str(&format!(
        " match &self.{}_data {{\n",
        replace_keyword(&switch.field)
    ));
    for case in switch.cases.iter().filter(|c| c.elements.is_some()) {
        let variant = match (&case.value, case.default) {
            (Some(value), _) => replace_keyword(value),
            (None, Some(true)) => String::from("Default"),
            (None, _) => panic!("Unnamed switch case with default=false"),
        };
        code.push_str(&format!(
            " Some({}::{}(data)) => {{\n",
            data_enum, variant
        ));
        code.push_str(" data.serialize(writer)?;\n");
        code.push_str(" }\n");
    }
    code.push_str(" _ => (),\n");
    code.push_str(" }\n");
}
/// Emits `let <name> = (reader.get_<type>()? ± offset) as usize;` for a `<length>` element.
///
/// A positive offset is added and a negative one subtracted. For an optional length the value
/// is only read while data remains, and is 0 otherwise. The `let` stays at function scope so
/// later arrays and fixed-length strings can use it. (Previously it was emitted inside an `if`
/// block, and the brace was a literal `{{` because `push_str` does not process format escapes.)
fn generate_deserialize_length(code: &mut String, length: &Length) {
    let offset = length.offset.unwrap_or(0);
    let offset_operation = match offset.cmp(&0) {
        std::cmp::Ordering::Greater => format!(" + {}", offset.unsigned_abs()),
        std::cmp::Ordering::Less => format!(" - {}", offset.unsigned_abs()),
        std::cmp::Ordering::Equal => String::new(),
    };
    let read = format!(
        "(reader.get_{}()?{}) as usize",
        length.data_type, offset_operation
    );
    let name = replace_keyword(&length.name);
    if matches!(length.optional, Some(true)) {
        code.push_str(&format!(
            " let {} = if reader.remaining()? > 0 {{ {} }} else {{ 0 }};\n",
            name, read
        ));
    } else {
        code.push_str(&format!(" let {} = {};\n", name, read));
    }
}
/// Emits reader code for a `<field>`.
///
/// An optional field is assigned `Some(..)` only while the reader has data remaining, and
/// `None` otherwise. A named field is assigned to `data.<name>`. An unnamed field's read
/// expression is emitted as a bare statement, so the value is discarded.
///
/// # Panics
/// Panics if an optional field has no name.
fn generate_deserialize_field(
    code: &mut String,
    field: &Field,
    enums: &[Enum],
    structs: &[Struct],
) {
    if field.optional == Some(true) {
        let field_name = field
            .name
            .as_ref()
            .expect("Field name is required for optional fields!");
        code.push_str(&format!(
            " data.{} = if reader.remaining()? > 0 {{\n",
            replace_keyword(field_name)
        ));
        code.push_str(" Some(");
        generate_inner_field_deserialize(code, field, enums, structs);
        code.push_str(")\n");
        code.push_str(" } else {\n");
        code.push_str(" None\n");
        code.push_str(" };\n");
        return;
    }
    if let Some(field_name) = &field.name {
        code.push_str(&format!(" data.{} = ", replace_keyword(field_name)));
    }
    generate_inner_field_deserialize(code, field, enums, structs);
    code.push_str(";\n");
}
/// Emits reader code for an `<array>`. Optional arrays are only read while data remains.
fn generate_deserialize_array(
    code: &mut String,
    array: &Array,
    enums: &[Enum],
    structs: &[Struct],
) {
    if matches!(array.optional, Some(true)) {
        // `push_str` copies text verbatim: `{{` is only an escape inside `format!`, so a single
        // brace is required here to emit a balanced block.
        code.push_str(" if reader.remaining()? > 0 {\n");
        generate_inner_array_deserialize(code, array, enums, structs);
        code.push_str(" }\n");
    } else {
        generate_inner_array_deserialize(code, array, enums, structs);
    }
}
/// Emits `data.<field>_data = match <int>::from(data.<field>) { .. };` for a `<switch>`.
///
/// Each case with a body becomes an arm that deserializes `<Struct><Field>Data<Value>`. The arm
/// matches the discriminant of the same-named value in `switch_enum`, or the case value used
/// verbatim when no such value exists. The catch-all arm is always emitted last: the first
/// default case with a body, or `_ => None`. This keeps a default listed before other cases from
/// shadowing them, and keeps a body-less default from leaving the match non-exhaustive.
///
/// # Panics
/// Panics for a case that has neither a `value` nor `default="true"`.
fn generate_deserialize_switch(
    code: &mut String,
    struct_name: &str,
    switch: &Switch,
    switch_enum: &Enum,
) {
    let data_enum = get_field_type(&format!("{}_{}_data", struct_name, switch.field));
    code.push_str(&format!(
        " data.{}_data = match {}::from(data.{}) {{\n",
        replace_keyword(&switch.field),
        get_field_type(&switch_enum.data_type),
        replace_keyword(&switch.field)
    ));
    let mut default_arm: Option<String> = None;
    for case in switch.cases.iter().filter(|c| c.elements.is_some()) {
        match &case.value {
            Some(value) => {
                let discriminant = switch_enum
                    .elements
                    .iter()
                    .find_map(|e| match e {
                        EnumElement::Value(v) if v.name == *value => Some(v.value.to_string()),
                        _ => None,
                    })
                    .unwrap_or_else(|| value.clone());
                code.push_str(&format!(
                    " {} => Some({}::{}({}::deserialize(reader)?)),\n",
                    discriminant,
                    data_enum,
                    replace_keyword(value),
                    get_field_type(&format!(
                        "{}_{}_data_{}",
                        struct_name, switch.field, value
                    ))
                ));
            }
            None if case.default == Some(true) => {
                // Keep the first default, which is the one that would have won at runtime.
                if default_arm.is_none() {
                    default_arm = Some(format!(
                        " _ => Some({}::Default({}::deserialize(reader)?)),\n",
                        data_enum,
                        get_field_type(&format!(
                            "{}_{}_data_default",
                            struct_name, switch.field
                        ))
                    ));
                }
            }
            None => panic!("Unnamed switch case with default=false"),
        }
    }
    code.push_str(&default_arm.unwrap_or_else(|| String::from(" _ => None,\n")));
    code.push_str(" };\n");
}
/// Emits the read expression (no trailing `;`) for a single field value.
///
/// - Enums: `Enum::from(reader.get_<wire type>()?)`.
/// - Structs: `<Type>::deserialize(reader)?`.
/// - Strings with a `length`: fixed-length string reads.
/// - Blobs: read everything remaining.
/// - Bools: compare the read value with 1.
/// - Anything else: `reader.get_<type>()?`.
///
/// A `:<type>` suffix on the field type overrides the wire type for enums and bools.
///
/// # Panics
/// Panics if a `length` is given for a type other than `string`/`encoded_string`.
fn generate_inner_field_deserialize(
    code: &mut String,
    field: &Field,
    enums: &[Enum],
    structs: &[Struct],
) {
    let (data_type, wire_override) = match field.data_type.split_once(':') {
        Some(parts) => parts,
        None => (field.data_type.as_str(), ""),
    };
    let expression = if let Some(protocol_enum) = enums.iter().find(|e| e.name == data_type) {
        let wire_type = if wire_override.is_empty() {
            protocol_enum.data_type.as_str()
        } else {
            wire_override
        };
        format!(
            "{}::from(reader.get_{}()?)",
            get_field_type(data_type),
            wire_type
        )
    } else if structs.iter().any(|s| s.name == data_type) {
        format!("{}::deserialize(reader)?", field.data_type)
    } else if let Some(length) = &field.length {
        match data_type {
            "string" => format!(" reader.get_fixed_string({})?", length),
            "encoded_string" => format!(" reader.get_fixed_encoded_string({})?", length),
            other => panic!("Unexpected length for data type: {}", other),
        }
    } else {
        match data_type {
            "blob" => String::from(" reader.get_bytes(reader.remaining()?)?"),
            "bool" => format!(
                "reader.get_{}()? == 1",
                if wire_override.is_empty() {
                    "char"
                } else {
                    wire_override
                }
            ),
            other => format!(" reader.get_{}()?", other),
        }
    };
    code.push_str(&expression);
}
/// Appends to `code` a loop that reads every element of `array` into `data.<name>`.
///
/// With a `length`, the loop is `for _ in 0..length`. Without one, it runs while
/// `reader.remaining()? > 0`. Each element is read with `generate_inner_field_deserialize`
/// using the array's data type and `optional` flag.
///
/// For delimited arrays, `reader.next_chunk()?` is emitted after each element. The one
/// exception is a delimited array that has a `length` and `trailing_delimiter == false`.
/// There the call is wrapped in `if i + 1 < length`, so no delimiter is consumed after the
/// last element.
fn generate_inner_array_deserialize(
    code: &mut String,
    array: &Array,
    enums: &[Enum],
    structs: &[Struct],
) {
    let delimited = matches!(array.delimited, Some(true));
    // `Some(length)` exactly when the per-element delimiter must be skipped after the last
    // element. That needs both a known length and the loop index `i`.
    // NOTE(review): a delimited array with `trailing_delimiter == false` and no `length` falls
    // through to the unguarded form and reads a delimiter after every element; confirm that
    // is intended.
    let guard_length = if delimited && !array.trailing_delimiter {
        array.length.as_ref()
    } else {
        None
    };

    match &array.length {
        Some(length) => code.push_str(&format!(
            " for {} in 0..{} {{\n",
            if guard_length.is_some() { "i" } else { "_" },
            length
        )),
        None => code.push_str(" while reader.remaining()? > 0 {\n"),
    }

    code.push_str(&format!(" data.{}.push(", array.name));
    generate_inner_field_deserialize(
        code,
        &Field {
            name: Some(array.name.clone()),
            data_type: array.data_type.clone(),
            value: None,
            comment: None,
            padded: None,
            optional: array.optional,
            length: None,
        },
        enums,
        structs,
    );
    code.push_str(");\n");

    if delimited {
        match guard_length {
            Some(length) => {
                code.push_str(&format!(" if i + 1 < {} {{\n", length));
                code.push_str(" reader.next_chunk()?;\n");
                code.push_str(" }\n");
            }
            None => code.push_str(" reader.next_chunk()?;\n"),
        }
    }
    code.push_str(" }\n");
}
/// Appends the `pub` member declarations of a generated struct for `elements` to `code`.
///
/// * Named fields become `pub name: Type`, or `pub name: Option<Type>` when
///   `optional="true"`. Unnamed fields produce no member.
/// * Arrays become `pub name: Vec<Type>`.
/// * Switches become `pub <field>_data: Option<…>`, where the inner type is the
///   `get_field_type` rendering of `<struct_name>_<field>_data`.
/// * Chunked sections are flattened into the same struct.
/// * All other elements produce no member.
///
/// Field and array comments are written as `///` doc lines above their member. Member names
/// pass through `replace_keyword`.
///
/// `field_count` is the number of members already written. The return value is that count
/// plus every member written by this call, including those in nested chunked sections.
fn write_struct_fields(
    code: &mut String,
    struct_name: &str,
    elements: &[StructElement],
    field_count: usize,
) -> usize {
    let mut field_count = field_count;
    for element in elements {
        match element {
            StructElement::Field(field) => {
                // Unnamed fields are (de)serialized but not stored on the struct.
                let name = match &field.name {
                    Some(name) => name,
                    None => continue,
                };
                field_count += 1;
                if let Some(comment) = &field.comment {
                    for line in get_comments(comment) {
                        code.push_str(&format!(" /// {}\n", line));
                    }
                }
                let field_type = get_field_type(&field.data_type);
                if matches!(field.optional, Some(true)) {
                    code.push_str(&format!(
                        " pub {}: Option<{}>,\n",
                        replace_keyword(name),
                        field_type
                    ));
                } else {
                    code.push_str(&format!(
                        " pub {}: {},\n",
                        replace_keyword(name),
                        field_type
                    ));
                }
            }
            StructElement::Chunked(chunked) => {
                // The recursive call starts from `field_count` and returns the running total.
                // Assign rather than add, otherwise the members before this chunk are counted
                // twice.
                field_count =
                    write_struct_fields(code, struct_name, &chunked.elements, field_count);
            }
            StructElement::Array(array) => {
                if let Some(comment) = &array.comment {
                    for line in get_comments(comment) {
                        code.push_str(&format!(" /// {}\n", line));
                    }
                }
                field_count += 1;
                code.push_str(&format!(
                    " pub {}: Vec<{}>,\n",
                    replace_keyword(&array.name),
                    get_field_type(&array.data_type)
                ));
            }
            StructElement::Switch(switch) => {
                field_count += 1;
                code.push_str(&format!(
                    " pub {}_data: Option<{}>,\n",
                    replace_keyword(&switch.field),
                    get_field_type(&format!("{}_{}_data", struct_name, switch.field))
                ));
            }
            _ => {}
        }
    }
    field_count
}
/// Maps a protocol data type name to the Rust type used in generated code.
///
/// Any `:width` suffix (as in `Type:width`) is ignored, so only the part before the first
/// colon matters. The mapping is:
/// * `byte` → `u8`
/// * `char`, `short`, `three`, `int` → `i32`
/// * `bool` → `bool`
/// * `string`, `encoded_string` → `String`
/// * `blob` → `Vec<u8>`
///
/// Every other name is converted to PascalCase.
fn get_field_type(data_type: &str) -> String {
    let base = match data_type.split_once(':') {
        Some((head, _width)) => head,
        None => data_type,
    };
    let builtin = match base {
        "byte" => "u8",
        "char" | "short" | "three" | "int" => "i32",
        "bool" => "bool",
        "string" | "encoded_string" => "String",
        "blob" => "Vec<u8>",
        _ => return base.to_owned().to_case(convert_case::Case::Pascal),
    };
    builtin.to_owned()
}
/// Built-in protocol data types. `get_field_type` maps these directly to Rust types, and
/// they are excluded from the generated `use` imports.
const PRIMITIVE_TYPES: [&str; 9] = [
    // numeric
    "byte", "char", "short", "three", "int",
    // boolean
    "bool",
    // text
    "string", "encoded_string",
    // raw bytes
    "blob",
];
/// Builds the `use` lines needed by generated code that references `elements`.
///
/// The first line always imports the `crate::data` reader/writer/serialization items. After
/// it comes one `use <module>::<Type>;` per non-primitive type referenced in `elements` that
/// is declared in one of `protocols`. Types not declared in any protocol are skipped. The
/// type imports are sorted by type name, so the generated source is identical from build to
/// build.
///
/// # Panics
/// If a type is declared in a protocol file whose path is not one of the known
/// `eo-protocol/xml/**/protocol.xml` locations.
fn get_imports(elements: &[StructElement], protocols: &[(Protocol, PathBuf)]) -> Vec<String> {
    let mut imports = vec![
        "use crate::data::{EoReader, EoReaderError, EoWriter, EoSerialize, EoSerializeError};"
            .to_owned(),
    ];

    let mut unique_types = HashSet::new();
    find_unique_types(elements, &mut unique_types);
    for primitive in &PRIMITIVE_TYPES {
        unique_types.remove(*primitive);
    }
    // HashSet iteration order differs between runs. Sort so the generated file is stable.
    let mut unique_types: Vec<String> = unique_types.into_iter().collect();
    unique_types.sort_unstable();

    for unique_type in &unique_types {
        let protocol_path = match find_protocol_for_type(unique_type, protocols) {
            Some(path) => path,
            None => continue,
        };
        // Match on path components rather than the raw string, so `\`-separated paths
        // (e.g. on Windows) resolve to the same module.
        let components: Vec<&str> = protocol_path
            .components()
            .map(|component| component.as_os_str().to_str().unwrap_or(""))
            .collect();
        let use_path = match components.as_slice() {
            ["eo-protocol", "xml", "protocol.xml"] => "crate::protocol",
            ["eo-protocol", "xml", "map", "protocol.xml"] => "crate::protocol::map",
            ["eo-protocol", "xml", "pub", "protocol.xml"] => "crate::protocol::r#pub",
            ["eo-protocol", "xml", "net", "protocol.xml"] => "crate::protocol::net",
            ["eo-protocol", "xml", "net", "client", "protocol.xml"] => {
                "crate::protocol::net::client"
            }
            ["eo-protocol", "xml", "net", "server", "protocol.xml"] => {
                "crate::protocol::net::server"
            }
            _ => panic!("Unknown protocol path: {}", protocol_path.to_string_lossy()),
        };
        imports.push(format!("use {}::{};", use_path, unique_type));
    }
    imports
}
/// Returns the path of the first protocol in `protocols` that declares a struct or enum
/// named `data_type`. Returns `None` if no protocol declares one. Packets are not considered.
fn find_protocol_for_type<'a>(
    data_type: &str,
    protocols: &'a [(Protocol, PathBuf)],
) -> Option<&'a PathBuf> {
    let declares_type = |protocol: &Protocol| {
        protocol.elements.iter().any(|element| match element {
            Element::Struct(protocol_struct) => protocol_struct.name == data_type,
            Element::Enum(protocol_enum) => protocol_enum.name == data_type,
            Element::Packet(_) => false,
        })
    };
    protocols
        .iter()
        .find(|(protocol, _)| declares_type(protocol))
        .map(|(_, path)| path)
}
/// Inserts into `unique_types` the base type name of every field and array in `elements`.
/// It recurses into chunked sections and into the elements of every switch case.
///
/// A `Type:width` data type contributes only `Type`, which is how `get_field_type` and the
/// deserializers interpret it. This applies to arrays as well as fields, so enum-typed
/// arrays are resolved like enum-typed fields. Primitive type names are inserted too, and
/// callers are expected to filter them out.
fn find_unique_types(elements: &[StructElement], unique_types: &mut HashSet<String>) {
    /// Drops an optional `:width` suffix from a data type name.
    fn base_type(data_type: &str) -> &str {
        data_type
            .split_once(':')
            .map_or(data_type, |(head, _width)| head)
    }

    for element in elements {
        match element {
            StructElement::Field(field) => {
                unique_types.insert(base_type(&field.data_type).to_owned());
            }
            StructElement::Chunked(chunked) => {
                find_unique_types(&chunked.elements, unique_types);
            }
            StructElement::Array(array) => {
                unique_types.insert(base_type(&array.data_type).to_owned());
            }
            StructElement::Switch(switch) => {
                for case in &switch.cases {
                    if let Some(elements) = &case.elements {
                        find_unique_types(elements, unique_types);
                    }
                }
            }
            _ => {}
        }
    }
}
/// Splits a possibly multi-line comment into lines and trims whitespace from each one.
///
/// Every `\n` starts a new entry. Blank lines are kept as empty strings, including the one
/// after a trailing newline.
fn get_comments(comment: &str) -> Vec<&str> {
    let lines = comment.split('\n');
    lines.map(str::trim).collect()
}
/// Appends each entry of `comments` to `code` as its own `/// ` doc-comment line, with no
/// leading indentation.
fn append_doc_comments(code: &mut String, comments: Vec<&str>) {
    comments
        .iter()
        .for_each(|line| code.push_str(&format!("/// {}\n", line)));
}
/// Returns the directory under `$OUT_DIR` that mirrors where `base` sits below
/// `eo-protocol/xml`. For example, `eo-protocol/xml/net/client/protocol.xml` maps to
/// `$OUT_DIR/net/client`, and `eo-protocol/xml/protocol.xml` maps to `$OUT_DIR` itself.
///
/// # Panics
/// If `OUT_DIR` is not set (Cargo sets it when running build scripts), if `base` has no
/// parent directory, or if that parent is not under `eo-protocol/xml`. Each panic message
/// names the missing variable or the offending path.
fn get_output_directory(base: &Path) -> PathBuf {
    let out_dir = std::env::var_os("OUT_DIR")
        .expect("OUT_DIR is not set; this is expected to run as a Cargo build script");
    let parent = base
        .parent()
        .unwrap_or_else(|| panic!("protocol path has no parent: {}", base.display()));
    let relative = parent.strip_prefix("eo-protocol/xml").unwrap_or_else(|_| {
        panic!(
            "protocol path is not under eo-protocol/xml: {}",
            base.display()
        )
    });
    Path::new(&out_dir).join(relative)
}
/// Turns a protocol-supplied name into a usable Rust identifier.
///
/// Three names are replaced outright: `Self` becomes `SELF`, `Ok` becomes `OK` and `0`
/// becomes `Zero`. Any other word found in `RUST_KEYWORDS` is escaped as a raw identifier
/// (`r#word`). Every other word is returned unchanged.
fn replace_keyword(word: &str) -> String {
    match word {
        // `Self` cannot be written as a raw identifier (`r#Self` is rejected), so rename it.
        "Self" => "SELF".to_owned(),
        // NOTE(review): presumably renamed to avoid clashing with `Result::Ok` in generated
        // code; confirm.
        "Ok" => "OK".to_owned(),
        // Identifiers cannot start with a digit.
        "0" => "Zero".to_owned(),
        keyword if RUST_KEYWORDS.contains(&keyword) => format!("r#{}", keyword),
        other => other.to_owned(),
    }
}
/// Reads the XML protocol definition at `path` and deserializes it into a `Protocol`.
///
/// # Errors
/// Returns an error if the file cannot be opened or is not valid UTF-8 text, or if
/// `quick_xml` fails to deserialize the XML into `Protocol`.
fn parse_protocol_file(path: &std::path::Path) -> Result<Protocol, Box<dyn std::error::Error>> {
    let xml = std::fs::read_to_string(path)?;
    Ok(quick_xml::de::from_str::<Protocol>(&xml)?)
}