use encoding_rs::{Encoding, GB18030, UTF_8};
use roxmltree::{Document, Node, ParsingOptions};
use std::fs::File;
use std::io::{self, Read, Write};
use std::path::Path;
pub fn generate_errors_wrapper_code<X, P>(
source_file: X,
target_dir: P,
) -> Result<(), Box<dyn std::error::Error>>
where
X: AsRef<Path>,
P: AsRef<Path>,
{
let xml_content = std::fs::read_to_string(source_file)?;
let popts = ParsingOptions {
allow_dtd: true,
nodes_limit: u32::MAX,
};
let doc = Document::parse_with_options(&xml_content, popts)?;
let mut output =
File::create(target_dir.as_ref().join("error.rs")).expect("create `errors.rs` error");
write_header(&mut output)?;
let mut errors = Vec::new();
for node in doc.descendants().filter(|n| n.has_tag_name("error")) {
if let (Some(id), Some(value), Some(prompt)) = (
node.attribute("id"),
node.attribute("value"),
node.attribute("prompt"),
) {
errors.push((id, value.parse::<i32>().unwrap_or(0), prompt));
}
}
errors.sort_by_key(|e| e.1);
write_enum_definition(&mut output, &errors)?;
write_impl(&mut output, &errors)?;
write_traits(&mut output)?;
write_from_impls(&mut output)?;
Ok(())
}
/// Appends an `error` module that `include!`s the generated `error.rs` to `src/<version>/mod.rs`.
///
/// `<version>` is `v1alpha2` when the `v1alpha2` feature is enabled, otherwise `v1alpha1` when
/// that feature is enabled. If `mod.rs` already contains `pub mod error {`, nothing is written.
/// Progress is reported to Cargo through `cargo::warning=` lines.
///
/// # Errors
///
/// Fails if `CARGO_MANIFEST_DIR` is unset or not valid Unicode, if neither version feature is
/// enabled, or if `mod.rs` cannot be read or written.
fn add_error_module_to_mod_rs() -> Result<(), Box<dyn std::error::Error>> {
    let base_dir = std::env::var("CARGO_MANIFEST_DIR")?;
    // `v1alpha2` takes precedence when both features are enabled.
    let version = if cfg!(feature = "v1alpha2") {
        "v1alpha2"
    } else if cfg!(feature = "v1alpha1") {
        "v1alpha1"
    } else {
        return Err("enable the `v1alpha1` or `v1alpha2` feature to select the API version".into());
    };
    let mod_rs_path = Path::new(&base_dir)
        .join("src")
        .join(version)
        .join("mod.rs");
    let mod_content = std::fs::read_to_string(&mod_rs_path)?;
    if mod_content.contains("pub mod error {") {
        println!("cargo::warning=Error module already exists in mod.rs");
        return Ok(());
    }
    let error_module = r#"
pub mod error {
include!(concat!(env!("OUT_DIR"), "/error.rs"));
}
pub use error::*;
"#;
    let new_content = mod_content + error_module;
    std::fs::write(mod_rs_path, new_content)?;
    println!("cargo::warning=Added error module to mod.rs");
    Ok(())
}
/// Reads the file at `path` and decodes it into a `String`, choosing the text encoding.
///
/// When `check_xml_declaration` is true, the `encoding` value of a leading XML declaration is
/// tried first. A `gb2312` label (any ASCII case) is decoded with `GB18030`. Any other label that
/// `Encoding::for_label` recognises is used directly. When there is no declared encoding, or
/// the label is unknown, `detect_encoding` guesses from the raw bytes.
///
/// Returns the decoded text and a human-readable name of the encoding that was selected (the
/// declared label itself when it was honoured). The decoder's "had errors" flag is discarded, so
/// decoding problems are not reported to the caller.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read.
fn read_file_with_encoding_detection(
    path: &str,
    check_xml_declaration: bool,
) -> Result<(String, String), Box<dyn std::error::Error>> {
    let mut raw = Vec::new();
    File::open(path)?.read_to_end(&mut raw)?;

    let declared_label = if check_xml_declaration {
        extract_encoding_from_xml_declaration(&raw)
    } else {
        None
    };

    let (encoding, encoding_name) = match declared_label {
        Some(label) if label.eq_ignore_ascii_case("gb2312") => (GB18030, "gb2312".to_string()),
        Some(label) => Encoding::for_label(label.as_bytes())
            .map(|known| (known, label))
            .unwrap_or_else(|| detect_encoding(&raw)),
        None => detect_encoding(&raw),
    };

    let (text, _, _) = encoding.decode(&raw);
    Ok((text.into_owned(), encoding_name))
}
/// Returns the value of the `encoding` pseudo-attribute of the XML declaration, if any.
///
/// Only the first 100 bytes of `bytes` are inspected. They are decoded with `encoding_rs`'s
/// `UTF_8` decoder, and any decoding problems are ignored. The text before the first `?>` is
/// treated as the declaration. The first `encoding="…"` or `encoding='…'` inside it is returned
/// verbatim (for example `gb2312`). Whitespace around `=` is accepted, as the XML grammar allows
/// (`Eq ::= S? '=' S?`).
///
/// Returns `None` when the inspected prefix contains no `?>`, when there is no `encoding`
/// pseudo-attribute, or when its value is not enclosed in matching quotes.
fn extract_encoding_from_xml_declaration(bytes: &[u8]) -> Option<String> {
    // Parses `= "value"` (optional whitespace around `=`) at the start of `rest`.
    let quoted_value = |rest: &str| -> Option<String> {
        let value = rest.trim_start().strip_prefix('=')?.trim_start();
        let quote = value.chars().next().filter(|&c| c == '"' || c == '\'')?;
        let value = &value[quote.len_utf8()..];
        value.find(quote).map(|end| value[..end].to_string())
    };

    // The declaration, when present, sits at the very start of the document.
    let (decoded, _, _) = UTF_8.decode(&bytes[..bytes.len().min(100)]);
    let head: &str = &decoded;
    let declaration = &head[..head.find("?>")?];
    declaration
        .match_indices("encoding")
        .find_map(|(start, name)| quoted_value(&declaration[start + name.len()..]))
}
/// Guesses the text encoding of `bytes` from their content alone.
///
/// Returns the encoding to decode with, plus a human-readable name. The checks run in order:
///
/// 1. A UTF-8 byte-order mark → UTF-8 (`"utf-8 with BOM"`).
/// 2. Bytes that already form valid UTF-8 (including pure ASCII and empty input) → UTF-8
///    (`"utf-8"`). UTF-8 is tested before GB18030 because UTF-8 is self-validating, whereas the
///    GB18030 decoder accepts many UTF-8 byte sequences without error. For example, the six
///    UTF-8 bytes of "中文" pair up into three GB18030 two-byte codes. Without this check,
///    UTF-8 Chinese text would be mis-detected.
/// 3. More than 10% non-ASCII bytes that decode as GB18030 without errors → GB18030
///    (`"gb18030/gb2312"`).
/// 4. Anything else falls back to UTF-8 (`"utf-8"`); decoding such input is presumably lossy.
fn detect_encoding(bytes: &[u8]) -> (&'static Encoding, String) {
    if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) {
        return (UTF_8, "utf-8 with BOM".to_string());
    }
    if std::str::from_utf8(bytes).is_ok() {
        return (UTF_8, "utf-8".to_string());
    }
    let high_byte_count = bytes.iter().filter(|&&b| b >= 0x80).count();
    // Only trust a clean GB18030 decode when non-ASCII bytes make up a meaningful share.
    if high_byte_count > 0 && high_byte_count as f32 / bytes.len() as f32 > 0.1 {
        let (_, _, had_errors) = GB18030.decode(bytes);
        if !had_errors {
            return (GB18030, "gb18030/gb2312".to_string());
        }
    }
    (UTF_8, "utf-8".to_string())
}
/// Writes the preamble of the generated `error.rs`.
///
/// The preamble has two "generated file" comment lines, the `fmt` and `StdError` imports used
/// by the later trait impls, and the doc comment and derives of `CtpError`. It ends with the
/// opening `pub enum CtpError {` line; the variants and closing brace are written separately.
fn write_header(output: &mut impl Write) -> io::Result<()> {
    const HEADER_LINES: [&str; 9] = [
        "// 自动生成的代码 - 请勿手动修改",
        "// 由 gen_error.rs 从 error.xml 生成",
        "",
        "use std::fmt;",
        "use std::error::Error as StdError;",
        "",
        "/// CTP错误代码和消息,从error.xml转换而来",
        "#[derive(Debug, Clone, PartialEq, Eq)]",
        "pub enum CtpError {",
    ];
    HEADER_LINES
        .iter()
        .try_for_each(|line| writeln!(output, "{line}"))
}
fn write_enum_definition(output: &mut impl Write, errors: &[(&str, i32, &str)]) -> io::Result<()> {
writeln!(output, " /// {} ({})", errors[0].2, errors[0].1)?;
writeln!(output, " None,")?;
writeln!(output, "")?;
let mut current_range = 0;
for (id, code, prompt) in errors.iter().skip(1) {
let range = match *code {
1..=100 => 1,
101..=999 => 2,
1000..=1999 => 3,
2000..=2999 => 4,
3000..=3999 => 5,
_ => 6,
};
if range != current_range {
current_range = range;
match range {
1 => writeln!(output, " // 一般错误 (1-100)")?,
2 => writeln!(output, " // 灾备系统错误 (101-999)")?,
3 => writeln!(output, " // 转账系统错误 (1000-1999)")?,
4 => writeln!(output, " // 附加转账错误 (2000-2999)")?,
5 => writeln!(output, " // 外汇系统错误 (3000-3999)")?,
_ => writeln!(output, " // 其他错误 ({}+)", code)?,
}
}
writeln!(output, " /// {} ({})", prompt, code)?;
writeln!(output, " {},", to_rust_enum_name(id))?;
}
writeln!(output, "")?;
writeln!(output, " // 未知错误")?;
writeln!(output, " Unknown(i32),")?;
writeln!(output, "}}")?;
writeln!(output, "")?;
Ok(())
}
fn write_impl(output: &mut impl Write, errors: &[(&str, i32, &str)]) -> io::Result<()> {
writeln!(output, "impl CtpError {{")?;
writeln!(output, " /// 从错误码转换为CtpError枚举")?;
writeln!(output, " pub fn from_code(code: i32) -> Self {{")?;
writeln!(output, " match code {{")?;
for (id, code, _) in errors {
writeln!(
output,
" {} => CtpError::{},",
code,
to_rust_enum_name(id)
)?;
}
writeln!(
output,
" unknown_code => CtpError::Unknown(unknown_code),"
)?;
writeln!(output, " }}")?;
writeln!(output, " }}")?;
writeln!(output, "")?;
writeln!(output, " /// 获取错误码")?;
writeln!(output, " pub fn code(&self) -> i32 {{")?;
writeln!(output, " match self {{")?;
for (id, code, _) in errors {
writeln!(
output,
" CtpError::{} => {},",
to_rust_enum_name(id),
code
)?;
}
writeln!(output, " CtpError::Unknown(code) => *code,")?;
writeln!(output, " }}")?;
writeln!(output, " }}")?;
writeln!(output, "")?;
writeln!(output, " /// 获取错误消息")?;
writeln!(output, " pub fn message(&self) -> &'static str {{")?;
writeln!(output, " match self {{")?;
for (id, _, prompt) in errors {
writeln!(
output,
" CtpError::{} => \"{}\",",
to_rust_enum_name(id),
prompt
)?;
}
writeln!(
output,
" CtpError::Unknown(_) => \"CTP:未知错误\","
)?;
writeln!(output, " }}")?;
writeln!(output, " }}")?;
writeln!(output, "}}")?;
writeln!(output, "")?;
Ok(())
}
/// Writes the `fmt::Display` and `StdError` impls for the generated `CtpError`.
///
/// `Display` renders an error as `message (code)` through the generated `message()` and `code()`
/// methods. `StdError` is the alias under which the generated header imports
/// `std::error::Error`. The snippet ends with a blank line.
fn write_traits(output: &mut impl Write) -> io::Result<()> {
    const TRAIT_IMPLS: &str = concat!(
        "impl fmt::Display for CtpError {\n",
        " fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n",
        " write!(f, \"{} ({})\", self.message(), self.code())\n",
        " }\n",
        "}\n",
        "\n",
        "impl StdError for CtpError {}\n",
        "\n",
    );
    output.write_all(TRAIT_IMPLS.as_bytes())
}
/// Writes `impl From<i32> for CtpError`, which delegates to the generated `CtpError::from_code`.
///
/// The snippet is preceded by a blank line and a comment line, and followed by a blank line.
fn write_from_impls(output: &mut impl Write) -> io::Result<()> {
    let snippet = [
        "",
        "// 实现从i32转换为CtpError",
        "impl From<i32> for CtpError {",
        " fn from(code: i32) -> Self {",
        " CtpError::from_code(code)",
        " }",
        "}",
        "",
    ];
    for line in snippet {
        writeln!(output, "{line}")?;
    }
    Ok(())
}
/// Converts a `SCREAMING_SNAKE_CASE` error id into an `UpperCamelCase` variant name.
///
/// The id is split on `_`. Each segment contributes its first character upper-cased and the rest
/// lower-cased; only ASCII letters change case, and other characters pass through untouched.
/// Empty segments from leading, trailing or repeated underscores contribute nothing. For example,
/// `"NONE"` becomes `"None"` and `"INVALID_DATA_SYNC_STATUS"` becomes `"InvalidDataSyncStatus"`.
fn to_rust_enum_name(id: &str) -> String {
    id.split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let head = chars.next().map(|c| c.to_ascii_uppercase());
            head.into_iter().chain(chars.map(|c| c.to_ascii_lowercase()))
        })
        .collect()
}