use crate::definitions::EnumElement;
use crate::docs::{self, Data};
use convert_case::{Case, Casing};
use endpoint_libs::model::{EnumVariant, Field, Type};
use eyre::bail;
use itertools::Itertools;
use std::collections::{BTreeSet, HashMap};
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::process::Command;
/// Conversion of a schema [`Type`] into Rust source text for code generation.
pub trait ToRust {
    /// Renders a complete declaration (e.g. `pub struct ...` / `pub enum ...`) for this type.
    ///
    /// `serde_with` selects serde adapter attributes/wire types for special scalar types;
    /// `add_derives` requests that the declaration be passed through [`ToRust::add_derives`].
    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String;

    /// Renders the Rust spelling used to refer to this type (field types, generic arguments).
    fn to_rust_ref(&self, serde_with: bool) -> String;

    /// Decorates an already rendered declaration (`input`) with derive attributes.
    fn add_derives(&self, input: String) -> String;
}
impl ToRust for Type {
    /// Renders the Rust spelling used to *refer* to this type (field types, generic
    /// arguments, ...). Named types such as structs render as their name only.
    ///
    /// When `serde_with` is set, blockchain addresses and transaction hashes render as the
    /// `Address` / `H256` types; otherwise they render as `BlockchainAddress` /
    /// `BlockchainTransactionHash`.
    fn to_rust_ref(&self, serde_with: bool) -> String {
        match self {
            Type::UInt32 => "u32".to_owned(),
            Type::Int32 => "i32".to_owned(),
            Type::Int64 => "i64".to_owned(),
            Type::Float64 => "f64".to_owned(),
            Type::TimeStampMs => "i64".to_owned(),
            Type::Struct { name, .. } => name.clone(),
            Type::StructRef(name) => name.clone(),
            Type::Object => "serde_json::Value".to_owned(),
            Type::StructTable { struct_ref } => format!("Vec<{struct_ref}>"),
            Type::Vec(ele) => format!("Vec<{}>", ele.to_rust_ref(serde_with)),
            Type::Unit => "()".to_owned(),
            Type::Optional(t) => format!("Option<{}>", t.to_rust_ref(serde_with)),
            Type::Boolean => "bool".to_owned(),
            Type::String => "String".to_owned(),
            Type::Bytea => "Vec<u8>".to_owned(),
            Type::UUID => "Uuid".to_owned(),
            Type::IpAddr => "IpAddr".to_owned(),
            Type::Enum { name, .. } => format!("Enum{}", name.to_case(Case::Pascal)),
            Type::EnumRef {
                name,
                prefixed_name,
            } => {
                if *prefixed_name {
                    format!("Enum{}", name.to_case(Case::Pascal))
                } else {
                    name.to_case(Case::Pascal)
                }
            }
            Type::BlockchainDecimal => "Decimal".to_owned(),
            // Guarded arms must precede the unguarded ones for the same variants.
            Type::BlockchainAddress if serde_with => "Address".to_owned(),
            Type::BlockchainTransactionHash if serde_with => "H256".to_owned(),
            Type::BlockchainAddress => "BlockchainAddress".to_owned(),
            Type::BlockchainTransactionHash => "BlockchainTransactionHash".to_owned(),
        }
    }

    /// Renders a full declaration for struct and enum types; every other type falls back to
    /// [`ToRust::to_rust_ref`] (without derives).
    ///
    /// * Struct fields of `Optional` type get `#[serde(default)]`; decimal fields (and, with
    ///   `serde_with`, blockchain address/hash fields) get a `#[serde(with = "...")]` adapter.
    /// * Enum variants are emitted with explicit discriminants in ascending discriminant
    ///   order, each preceded by its description as a `///` doc comment.
    ///
    /// When `add_derives` is set, struct/enum declarations are passed through
    /// [`ToRust::add_derives`].
    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
        let decl = match self {
            Type::Struct { name, fields } => {
                let mut rendered_fields = fields.iter().map(|field| {
                    let default_attr = if matches!(&field.ty, Type::Optional(_)) {
                        "#[serde(default)]"
                    } else {
                        ""
                    };
                    let with_module = match &field.ty {
                        Type::BlockchainDecimal => "rust_decimal::serde::str",
                        Type::BlockchainAddress if serde_with => "WithBlockchainAddress",
                        Type::BlockchainTransactionHash if serde_with => {
                            "WithBlockchainTransactionHash"
                        }
                        _ => "",
                    };
                    let with_attr = if with_module.is_empty() {
                        String::new()
                    } else {
                        format!("#[serde(with = \"{with_module}\")]")
                    };
                    format!(
                        "{} {} pub {}: {}",
                        default_attr,
                        with_attr,
                        field.name,
                        field.ty.to_rust_ref(serde_with)
                    )
                });
                format!("pub struct {} {{{}}}", name, rendered_fields.join(","))
            }
            Type::Enum { name, variants } => {
                let body = variants
                    .iter()
                    // Sort on the discriminant itself. Parsing it back out of the rendered
                    // text was fragile: the doc-comment description can contain `= <digits>`
                    // and negative discriminants never matched. The sort is stable.
                    .sorted_by_key(|variant| variant.value)
                    .map(|variant| {
                        // Names ending in a lowercase letter are normalised to PascalCase;
                        // `matches!` avoids panicking on an empty name.
                        let variant_name =
                            if matches!(variant.name.chars().last(), Some(c) if c.is_lowercase())
                            {
                                variant.name.to_case(Case::Pascal)
                            } else {
                                variant.name.clone()
                            };
                        // Keep every line of a multi-line description inside the doc comment.
                        let doc = variant.description.replace('\n', "\n/// ");
                        format!("\n/// {}\n{} = {}\n", doc, variant_name, variant.value)
                    })
                    .join(",");
                format!(
                    r#"pub enum Enum{} {{{}}}"#,
                    name.to_case(Case::Pascal),
                    body
                )
            }
            other => return other.to_rust_ref(serde_with),
        };
        if add_derives {
            self.add_derives(decl)
        } else {
            decl
        }
    }

    /// Adds the default derive attributes for enums and structs; other types are returned
    /// unchanged.
    fn add_derives(&self, input: String) -> String {
        match self {
            Self::Enum { .. } => Self::add_default_enum_derives(input),
            Self::Struct { .. } => Self::add_default_struct_derives(input),
            _ => input,
        }
    }
}
/// Collects every struct type reachable from `t`, in depth-first pre-order: a struct is
/// listed before the structs nested in its fields. `Vec` and `Optional` wrappers are looked
/// through; every other type contributes nothing.
pub fn collect_rust_recursive_types(t: Type) -> Vec<Type> {
    match t {
        Type::Struct { ref fields, .. } => std::iter::once(t.clone())
            .chain(
                fields
                    .iter()
                    .flat_map(|field| collect_rust_recursive_types(field.ty.clone())),
            )
            .collect(),
        Type::Vec(inner) | Type::Optional(inner) => collect_rust_recursive_types(*inner),
        _ => Vec::new(),
    }
}
/// Generates `model.rs` inside `data.output_dir` and formats it with `rustfmt`.
///
/// The file contains, in order: the `use` prelude (plus worktable imports when any enum or
/// struct enables worktable support), the user-declared enums and structs, the
/// `EnumEndpoint` enum and its `schema()` accessor, one struct per error code, the
/// `EnumErrorCode` enum with its `ErrorCode` conversion, every request/response struct
/// reachable from the endpoints, and the `WsRequest`/`WsResponse` impls.
///
/// # Errors
/// Returns an error if the output directory or file cannot be created or written, if the
/// error-message definitions cannot be loaded, if an endpoint schema cannot be serialized,
/// or if `rustfmt` fails.
pub fn gen_model_rs(data: &Data) -> eyre::Result<()> {
    let model_path = data.output_dir.join("model.rs");
    if let Some(parent) = model_path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let worktable_imports = if data.enums.iter().any(|e| e.config.worktable_support)
        || data.structs.iter().any(|s| s.config.worktable_support)
    {
        r#"use worktable::prelude::*;
use rkyv::Archive;
"#
    } else {
        ""
    };
    // Many small writes follow; buffer them rather than hitting the file for each one.
    let mut model_file = std::io::BufWriter::new(File::create(&model_path)?);
    write!(
        &mut model_file,
        "use endpoint_libs::libs::error_code::ErrorCode;
use endpoint_libs::libs::ws::*;
use endpoint_libs::libs::types::*;
use num_derive::FromPrimitive;
use serde::*;
use strum_macros::{{Display, EnumString}};
use uuid::Uuid;
use std::net::IpAddr;
{worktable_imports}
",
    )?;
    for e in &data.enums {
        writeln!(&mut model_file, "{}", e.to_rust_decl(false, true))?;
    }
    for s in &data.structs {
        writeln!(&mut model_file, "{}", s.to_rust_decl(false, true))?;
    }
    check_endpoint_codes(data, &mut model_file)?;
    dump_endpoint_schema(data, &mut model_file)?;

    let errors = docs::get_error_messages(&data.project_root)?;
    // Matches `{placeholder}` segments in error messages.
    let rule = regex::Regex::new(r"\{[\w]+}")?;
    for e in &errors.codes {
        let name = format!("Error{}", e.symbol.to_case(Case::Pascal));
        let s = Type::struct_(
            name,
            rule.find_iter(&e.message)
                .map(|m| m.as_str().trim_matches('{').trim_matches('}'))
                // A placeholder used twice must still produce a single struct field,
                // otherwise the generated struct has duplicate fields.
                .unique()
                .map(|s| Field::new(s.to_string(), Type::String))
                .collect(),
        );
        writeln!(&mut model_file, "{}", s.to_rust_decl(true, true))?;
    }
    let enum_ = Type::enum_(
        "ErrorCode",
        errors
            .codes
            .into_iter()
            .map(|x| {
                EnumVariant::new_with_description(
                    x.symbol.to_case(Case::Pascal),
                    format!("{} {}", x.source, x.message),
                    x.code,
                )
            })
            .collect(),
    );
    writeln!(&mut model_file, "{}", enum_.to_rust_decl(false, true))?;
    writeln!(
        &mut model_file,
        r#"
impl From<EnumErrorCode> for ErrorCode {{
fn from(e: EnumErrorCode) -> Self {{
ErrorCode::new(e as _)
}}
}}
"#
    )?;

    let mut endpoint_reqres_types = BTreeSet::new();
    for s in &data.services {
        for e in &s.endpoints {
            // Use the same PascalCase spelling as the `WsRequest`/`WsResponse` impls and
            // `EnumEndpoint::schema`, so the struct names always match their references.
            let type_name = e.schema.name.to_case(Case::Pascal);
            let req = Type::struct_(
                format!("{type_name}Request"),
                e.schema.parameters.clone(),
            );
            let resp = Type::struct_(
                format!("{type_name}Response"),
                e.schema.returns.clone(),
            );
            endpoint_reqres_types.extend(
                [
                    collect_rust_recursive_types(req),
                    collect_rust_recursive_types(resp),
                    e.schema
                        .stream_response
                        .clone()
                        .into_iter()
                        .flat_map(Type::try_unwrap)
                        .collect::<Vec<_>>(),
                ]
                .concat()
                .into_iter(),
            );
        }
    }
    for s in endpoint_reqres_types {
        write!(&mut model_file, r#"{}"#, s.to_rust_decl(true, true))?;
    }

    for s in &data.services {
        for endpoint in &s.endpoints {
            let roles_list = resolve_roles_ids(&endpoint.schema.roles, &data.enums)
                .into_iter()
                .map(|x| x.to_string())
                .join(", ");
            let schema_json = serde_json::to_string_pretty(&endpoint.schema)?;
            write!(
                &mut model_file,
                "
impl WsRequest for {end_name2}Request {{
type Response = {end_name2}Response;
const METHOD_ID: u32 = {code};
const ROLES: &[u32] = &[{roles_list}];
const SCHEMA: &'static str = {schema};
}}
impl WsResponse for {end_name2}Response {{
type Request = {end_name2}Request;
}}
",
                end_name2 = endpoint.schema.name.to_case(Case::Pascal),
                code = endpoint.schema.code,
                schema = raw_string_literal(&schema_json)
            )?;
        }
    }
    model_file.flush()?;
    drop(model_file);
    rustfmt(&model_path)?;
    Ok(())
}

/// Wraps `s` in a Rust raw string literal, using as many `#` as needed so that no `"`
/// followed by hashes inside `s` can terminate the literal early. This matters because a
/// JSON string value starting with `#` produces `"#`. Typical input still yields
/// `r#"..."#`.
fn raw_string_literal(s: &str) -> String {
    let mut hashes = String::from("#");
    while s.contains(&format!("\"{hashes}")) {
        hashes.push('#');
    }
    format!("r{hashes}\"{s}\"{hashes}")
}
/// Resolves endpoint role references of the form `"<EnumName>::<Variant>"` to the numeric
/// values of the matching enum variants.
///
/// `<EnumName>` is compared with the name each enum is rendered under by
/// `to_rust_ref(false)`. References with no `::`, or naming an unknown enum or variant, are
/// skipped with a warning on stderr. Duplicate ids are reported and collapsed. The result is
/// sorted ascending.
fn resolve_roles_ids(endpoint_roles: &[String], all_enums: &[EnumElement]) -> Vec<i64> {
    // Borrow the variant lists instead of cloning them; a later enum with the same rendered
    // name replaces an earlier one.
    let enums_by_name: HashMap<String, &[EnumVariant]> = all_enums
        .iter()
        .filter_map(|e| match &e.inner {
            Type::Enum { variants, .. } => Some((e.to_rust_ref(false), variants.as_slice())),
            _ => None,
        })
        .collect();

    let mut roles_ids = Vec::new();
    for role in endpoint_roles {
        let (role_enum_name, role_variant_name) =
            role.split_once("::").unwrap_or(("", role.as_str()));
        match enums_by_name.get(role_enum_name) {
            Some(role_enum_variants) => {
                match role_enum_variants
                    .iter()
                    .find(|v| v.name == role_variant_name)
                {
                    Some(variant) => roles_ids.push(variant.value),
                    None => eprintln!(
                        "Warning: Role variant '{role_variant_name}' not found in enum '{role_enum_name}'"
                    ),
                }
            }
            None => eprintln!("Warning: Role enum '{role_enum_name}' not found"),
        }
    }

    let mut roles_ids_set: BTreeSet<i64> = BTreeSet::new();
    for id in &roles_ids {
        if !roles_ids_set.insert(*id) {
            eprintln!("Warning: Duplicate role ID found: {id}");
        }
    }
    roles_ids_set.into_iter().collect()
}
/// Formats the file at `f` in place by invoking the external `rustfmt` binary with
/// `--edition 2021`, inheriting the current process's stdio.
///
/// # Errors
/// Fails if `rustfmt` cannot be launched or exits with a non-success status.
pub fn rustfmt(f: &Path) -> eyre::Result<()> {
    let status = Command::new("rustfmt")
        .args(&["--edition", "2021"])
        .arg(f)
        .status()?;
    if status.success() {
        Ok(())
    } else {
        bail!("failed to rustfmt {:?}", status)
    }
}
pub fn check_endpoint_codes(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
let mut variants = vec![];
for s in &data.services {
for e in &s.endpoints {
variants.push(EnumVariant::new(e.schema.name.clone(), e.schema.code as _));
}
}
let enum_ = Type::enum_("Endpoint", variants);
writeln!(writer, "{}", enum_.to_rust_decl(false, true))?;
Ok(())
}
/// Writes an `impl EnumEndpoint` block to `writer`. Its `schema()` method maps each
/// endpoint variant to the `SCHEMA` constant of the matching `<Name>Request` type and parses
/// that JSON into an `EndpointSchema`. The generated code unwraps the parse.
///
/// # Errors
/// Fails only if writing to `writer` fails.
pub fn dump_endpoint_schema(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
    let match_arms = data
        .services
        .iter()
        .flat_map(|service| &service.endpoints)
        .map(|endpoint| {
            let variant = endpoint.schema.name.to_case(Case::Pascal);
            format!("Self::{variant} => {variant}Request::SCHEMA,")
        })
        .join("\n");
    let code = format!(
        r#"
impl EnumEndpoint {{
pub fn schema(&self) -> endpoint_libs::model::EndpointSchema {{
let schema = match self {{
{cases}
}};
serde_json::from_str(schema).unwrap()
}}
}}
"#,
        cases = match_arms
    );
    writeln!(writer, "{code}")?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use regex::Regex;

    /// Checks that the `= <digits>` pattern pulls the numeric discriminant out of rendered
    /// enum-variant text, with or without surrounding whitespace and doc-comment lines.
    #[test]
    fn test_extract_number_from_error_code() {
        let re = Regex::new(r"=\s*(\d+)").unwrap();
        let cases: [(&str, u64); 4] = [
            (
                r#" ///
LoginStep2 = 10003
,"#,
                10003,
            ),
            ("Authorize = 10000,", 10000),
            ("SomeError=12345,", 12345),
            (
                r#"/// SQL R0019 UnauthorizedMessage
UnauthorizedMessage = 45349677
, "#,
                45349677,
            ),
        ];
        for (text, expected) in cases {
            let caps = re.captures(text).expect("Should match");
            let number: u64 = caps[1].parse().expect("Should parse as u64");
            assert_eq!(number, expected, "input: {text:?}");
        }
    }
}