use std::fmt;
/// Error produced when input bytes cannot be decoded into UTF-8 text.
///
/// The human-readable description lives in [`message`](Self::message); the
/// [`Display`](fmt::Display) form prefixes it with `encoding error: `.
#[derive(Debug, Clone)]
pub struct EncodingError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl EncodingError {
    /// Builds an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        EncodingError { message }
    }
}

impl fmt::Display for EncodingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("encoding error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for EncodingError {}
/// Sniffs a byte order mark at the start of `bytes`.
///
/// Returns the encoding name implied by the BOM together with the number of bytes
/// the BOM occupies, so callers can skip it with `&bytes[skip..]`. Recognized marks
/// are UTF-8 (`EF BB BF`), UTF-16BE (`FE FF`) and UTF-16LE (`FF FE`). When no mark
/// is present (including empty or truncated input) the result is `("UTF-8", 0)`.
#[must_use]
pub fn detect_encoding(bytes: &[u8]) -> (&'static str, usize) {
    match bytes {
        [0xEF, 0xBB, 0xBF, ..] => ("UTF-8", 3),
        [0xFE, 0xFF, ..] => ("UTF-16BE", 2),
        [0xFF, 0xFE, ..] => ("UTF-16LE", 2),
        _ => ("UTF-8", 0),
    }
}
/// Decodes `bytes` from the encoding named by `encoding_name` into a `String`.
///
/// `encoding_name` is resolved with `encoding_rs::Encoding::for_label`, so any
/// WHATWG Encoding Standard label is accepted. The bytes are decoded strictly as
/// that encoding. A leading byte order mark is *not* sniffed or stripped, so callers
/// wanting BOM handling should use [`detect_encoding`] first and pass the
/// remaining bytes.
///
/// # Errors
///
/// Returns [`EncodingError`] in these cases:
/// - the label is unknown;
/// - the label maps to the WHATWG "replacement" encoding (labels such as
///   `iso-2022-kr` that are deliberately unsupported);
/// - `bytes` contains a sequence that is malformed in the named encoding.
pub fn transcode(bytes: &[u8], encoding_name: &str) -> Result<String, EncodingError> {
    let encoding = encoding_rs::Encoding::for_label(encoding_name.as_bytes())
        // The replacement encoding turns any non-empty input into a single U+FFFD;
        // report its labels as unsupported rather than as "malformed" input.
        .filter(|enc| *enc != encoding_rs::REPLACEMENT)
        .ok_or_else(|| EncodingError::new(format!("unsupported encoding: {encoding_name}")))?;
    // `Encoding::decode` would sniff a BOM and silently switch to the encoding it
    // names; decode exactly as the caller requested instead.
    let (decoded, had_errors) = encoding.decode_without_bom_handling(bytes);
    if had_errors {
        return Err(EncodingError::new(format!(
            "malformed byte sequence for encoding {encoding_name}"
        )));
    }
    Ok(decoded.into_owned())
}
/// Returns the value of the `encoding` pseudo-attribute of the XML declaration
/// at the very start of `text`, if there is one.
///
/// Only a real declaration is considered. `<?xml` must be followed by whitespace,
/// so processing instructions such as `<?xml-stylesheet …?>` are not mistaken for
/// it. Whitespace around `=` is allowed, and the value may use either quote style.
/// Returns `None` in these cases:
/// - there is no declaration;
/// - the declaration lacks an `encoding` attribute;
/// - the attribute is malformed.
fn extract_xml_decl_encoding(text: &str) -> Option<String> {
    // Check the prefix first so a document without a declaration is rejected
    // immediately instead of being scanned for `?>`.
    let rest = text.strip_prefix("<?xml")?;
    if !rest.starts_with(|c: char| matches!(c, ' ' | '\t' | '\r' | '\n')) {
        return None;
    }
    let decl = &rest[..rest.find("?>")?];
    let enc_pos = decl.find("encoding")?;
    let after_enc = decl[enc_pos + "encoding".len()..].trim_start();
    let after_eq = after_enc.strip_prefix('=')?.trim_start();
    let quote = after_eq.chars().next().filter(|&c| c == '"' || c == '\'')?;
    let value = &after_eq[quote.len_utf8()..];
    let end = value.find(quote)?;
    Some(value[..end].to_string())
}
pub fn decode_to_utf8(bytes: &[u8]) -> Result<String, EncodingError> {
let (bom_encoding, bom_skip) = detect_encoding(bytes);
let content_bytes = &bytes[bom_skip..];
if bom_encoding == "UTF-8" {
if let Ok(s) = std::str::from_utf8(content_bytes) {
if let Some(declared) = extract_xml_decl_encoding(s) {
let declared_upper = declared.to_ascii_uppercase();
if !is_utf8_label(&declared_upper) {
return transcode(content_bytes, &declared);
}
}
return Ok(s.to_string());
}
if let Some(declared) = extract_encoding_from_ascii_bytes(content_bytes) {
return transcode(content_bytes, &declared);
}
return Err(EncodingError::new("input is not valid UTF-8"));
}
let initial_text = transcode(content_bytes, bom_encoding)?;
if let Some(declared_encoding) = extract_xml_decl_encoding(&initial_text) {
let declared_upper = declared_encoding.to_ascii_uppercase();
let bom_upper = bom_encoding.to_ascii_uppercase();
let effectively_same = declared_upper == bom_upper
|| (is_utf8_label(&declared_upper) && is_utf8_label(&bom_upper))
|| (declared_upper == "UTF-16"
&& (bom_upper == "UTF-16BE" || bom_upper == "UTF-16LE"));
if !effectively_same {
return transcode(content_bytes, &declared_encoding);
}
}
Ok(initial_text)
}
/// Reads the `encoding` pseudo-attribute of an XML declaration at the very start
/// of `bytes`, without requiring the input to be valid UTF-8.
///
/// This is used when the document body is in a legacy encoding but the
/// declaration itself is ASCII. Only the first 200 bytes are examined. `<?xml`
/// must be followed by whitespace, so processing instructions such as
/// `<?xml-stylesheet …?>` are not mistaken for a declaration. Returns `None` in
/// these cases:
/// - no declaration or `encoding` attribute is found;
/// - the attribute is malformed;
/// - its value is not pure ASCII.
fn extract_encoding_from_ascii_bytes(bytes: &[u8]) -> Option<String> {
    const SCAN_LIMIT: usize = 200;
    const ENC_NEEDLE: &[u8] = b"encoding";

    let scan = &bytes[..bytes.len().min(SCAN_LIMIT)];
    let rest = scan.strip_prefix(b"<?xml")?;
    if !matches!(rest.first(), Some(b' ' | b'\t' | b'\r' | b'\n')) {
        return None;
    }
    let decl_end = rest.windows(2).position(|w| w == b"?>")?;
    let decl = &rest[..decl_end];
    let enc_pos = decl
        .windows(ENC_NEEDLE.len())
        .position(|w| w == ENC_NEEDLE)?;
    let after_enc = skip_ascii_whitespace(&decl[enc_pos + ENC_NEEDLE.len()..]);
    let after_eq = skip_ascii_whitespace(after_enc.strip_prefix(b"=")?);
    let (&quote, after_quote) = after_eq.split_first()?;
    if quote != b'"' && quote != b'\'' {
        return None;
    }
    let end = after_quote.iter().position(|&b| b == quote)?;
    let encoding_bytes = &after_quote[..end];
    if !encoding_bytes.is_ascii() {
        return None;
    }
    // Pure ASCII is always valid UTF-8, so this conversion cannot fail here.
    std::str::from_utf8(encoding_bytes).ok().map(str::to_owned)
}
/// Returns `bytes` without its leading XML whitespace (space, tab, CR, LF).
///
/// Other ASCII whitespace, such as form feed, is deliberately left in place.
fn skip_ascii_whitespace(bytes: &[u8]) -> &[u8] {
    let first_non_ws = bytes
        .iter()
        .position(|b| !matches!(b, b' ' | b'\t' | b'\r' | b'\n'))
        .unwrap_or(bytes.len());
    &bytes[first_non_ws..]
}
/// Returns `true` for the two UTF-8 spellings this module recognizes.
///
/// The comparison is exact; callers are expected to upper-case `label` first.
fn is_utf8_label(label: &str) -> bool {
    label == "UTF-8" || label == "UTF8"
}
#[cfg(test)]
// Tests fail by panicking, so `unwrap` on an unexpected `Err`/`None` is the intended behavior.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Encodes `text` as UTF-16 (big- or little-endian) preceded by the matching BOM.
    fn utf16_with_bom(text: &str, big_endian: bool) -> Vec<u8> {
        let mut out = if big_endian {
            vec![0xFE, 0xFF]
        } else {
            vec![0xFF, 0xFE]
        };
        for unit in text.encode_utf16() {
            let pair = if big_endian {
                unit.to_be_bytes()
            } else {
                unit.to_le_bytes()
            };
            out.extend_from_slice(&pair);
        }
        out
    }

    #[test]
    fn test_detect_utf8_bom() {
        let bytes = b"\xEF\xBB\xBF<?xml version=\"1.0\"?><root/>";
        let (encoding, skip) = detect_encoding(bytes);
        assert_eq!(encoding, "UTF-8");
        assert_eq!(skip, 3);
    }

    #[test]
    fn test_detect_utf16le_bom() {
        let bytes = b"\xFF\xFE<\x00r\x00o\x00o\x00t\x00";
        let (encoding, skip) = detect_encoding(bytes);
        assert_eq!(encoding, "UTF-16LE");
        assert_eq!(skip, 2);
    }

    #[test]
    fn test_detect_utf16be_bom() {
        let bytes = b"\xFE\xFF\x00<\x00r\x00o\x00o\x00t";
        let (encoding, skip) = detect_encoding(bytes);
        assert_eq!(encoding, "UTF-16BE");
        assert_eq!(skip, 2);
    }

    #[test]
    fn test_detect_no_bom() {
        let bytes = b"<?xml version=\"1.0\"?><root/>";
        let (encoding, skip) = detect_encoding(bytes);
        assert_eq!(encoding, "UTF-8");
        assert_eq!(skip, 0);
    }

    #[test]
    fn test_detect_empty_input() {
        let (encoding, skip) = detect_encoding(b"");
        assert_eq!(encoding, "UTF-8");
        assert_eq!(skip, 0);
    }

    #[test]
    fn test_detect_single_byte() {
        let (encoding, skip) = detect_encoding(b"\xEF");
        assert_eq!(encoding, "UTF-8");
        assert_eq!(skip, 0);
    }

    #[test]
    fn test_decode_utf8() {
        let bytes = b"<?xml version=\"1.0\"?><root>hello</root>";
        let result = decode_to_utf8(bytes).unwrap();
        assert_eq!(result, "<?xml version=\"1.0\"?><root>hello</root>");
    }

    #[test]
    fn test_decode_utf8_with_bom() {
        let bytes = b"\xEF\xBB\xBF<?xml version=\"1.0\"?><root/>";
        let result = decode_to_utf8(bytes).unwrap();
        assert_eq!(result, "<?xml version=\"1.0\"?><root/>");
    }

    #[test]
    fn test_decode_latin1() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(b"<?xml version=\"1.0\" encoding=\"ISO-8859-1\"?>");
        bytes.extend_from_slice(b"<root>caf\xE9</root>");
        let result = decode_to_utf8(&bytes).unwrap();
        assert!(result.contains("caf\u{00E9}"));
        assert!(result.contains("<root>"));
    }

    #[test]
    fn test_decode_utf16le_with_bom() {
        let bytes = utf16_with_bom("<root>caf\u{00E9}</root>", false);
        let result = decode_to_utf8(&bytes).unwrap();
        assert_eq!(result, "<root>caf\u{00E9}</root>");
    }

    #[test]
    fn test_decode_utf16be_with_bom() {
        let bytes = utf16_with_bom("<root/>", true);
        let result = decode_to_utf8(&bytes).unwrap();
        assert_eq!(result, "<root/>");
    }

    #[test]
    fn test_decode_utf16_with_matching_declaration() {
        let text = "<?xml version=\"1.0\" encoding=\"UTF-16\"?><root/>";
        let bytes = utf16_with_bom(text, false);
        let result = decode_to_utf8(&bytes).unwrap();
        assert_eq!(result, text);
    }

    #[test]
    fn test_transcode_utf8() {
        let result = transcode(b"hello world", "UTF-8").unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn test_transcode_latin1() {
        let result = transcode(b"caf\xE9", "ISO-8859-1").unwrap();
        assert_eq!(result, "caf\u{00E9}");
    }

    #[test]
    fn test_transcode_unknown_encoding() {
        let result = transcode(b"hello", "UNKNOWN-ENCODING-42");
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("unsupported encoding"));
    }

    #[test]
    fn test_extract_xml_decl_encoding_present() {
        let text = "<?xml version=\"1.0\" encoding=\"ISO-8859-1\"?><root/>";
        let enc = extract_xml_decl_encoding(text);
        assert_eq!(enc, Some("ISO-8859-1".to_string()));
    }

    #[test]
    fn test_extract_xml_decl_encoding_single_quotes() {
        let text = "<?xml version='1.0' encoding='UTF-8'?><root/>";
        let enc = extract_xml_decl_encoding(text);
        assert_eq!(enc, Some("UTF-8".to_string()));
    }

    #[test]
    fn test_extract_xml_decl_encoding_whitespace_around_eq() {
        let text = "<?xml version=\"1.0\" encoding = \"UTF-8\" ?><root/>";
        let enc = extract_xml_decl_encoding(text);
        assert_eq!(enc, Some("UTF-8".to_string()));
    }

    #[test]
    fn test_extract_xml_decl_encoding_absent() {
        let text = "<?xml version=\"1.0\"?><root/>";
        let enc = extract_xml_decl_encoding(text);
        assert_eq!(enc, None);
    }

    #[test]
    fn test_extract_xml_decl_no_declaration() {
        let text = "<root/>";
        let enc = extract_xml_decl_encoding(text);
        assert_eq!(enc, None);
    }

    #[test]
    fn test_extract_encoding_from_ascii_bytes_non_utf8_body() {
        let bytes = b"<?xml version=\"1.0\" encoding=\"ISO-8859-1\"?><root>caf\xE9</root>";
        let enc = extract_encoding_from_ascii_bytes(bytes);
        assert_eq!(enc, Some("ISO-8859-1".to_string()));
    }

    #[test]
    fn test_skip_ascii_whitespace() {
        assert_eq!(skip_ascii_whitespace(b" \t\r\nx"), &b"x"[..]);
        assert_eq!(skip_ascii_whitespace(b""), &b""[..]);
    }

    #[test]
    fn test_encoding_error_display() {
        let err = EncodingError::new("test error");
        assert_eq!(err.to_string(), "encoding error: test error");
    }

    #[test]
    fn test_encoding_error_is_error_trait() {
        let err = EncodingError::new("test");
        let _: &dyn std::error::Error = &err;
    }

    #[test]
    fn test_decode_invalid_utf8() {
        let bytes: &[u8] = &[0x80, 0x81, 0x82];
        let result = decode_to_utf8(bytes);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_bytes_utf8() {
        use crate::tree::Document;
        let result = Document::parse_bytes(b"<root/>");
        assert!(result.is_ok());
        let doc = result.unwrap();
        let root = doc.root_element().unwrap();
        assert_eq!(doc.node_name(root), Some("root"));
    }

    #[test]
    fn test_parse_bytes_utf8_with_bom() {
        use crate::tree::Document;
        let mut bytes = vec![0xEF, 0xBB, 0xBF];
        bytes.extend_from_slice(b"<root/>");
        let doc = Document::parse_bytes(&bytes).unwrap();
        let root = doc.root_element().unwrap();
        assert_eq!(doc.node_name(root), Some("root"));
    }
}