use crate::types::{is_xml_whitespace, DeclaredEncoding, Encoding};
/// Result of sniffing the leading bytes of an XML document for its character encoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProbeResult {
    /// The detected encoding; `Encoding::Unknown` when nothing could be determined.
    pub encoding: Encoding,
    /// Number of bytes at the start of the input taken up by a byte-order mark
    /// (4 for UTF-32, 3 for UTF-8, 2 for UTF-16), or 0 when no BOM was matched.
    pub bom_length: usize,
}
/// Guesses the character encoding of an XML document from its first bytes.
///
/// Rules are tried in this order and the first match wins:
///
/// 1. Fewer than two bytes: [`Encoding::Unknown`].
/// 2. Byte-order marks: UTF-32 LE (`FF FE 00 00`) and BE (`00 00 FE FF`),
///    UTF-8 (`EF BB BF`), then UTF-16 LE (`FF FE`) and BE (`FE FF`). The
///    four-byte marks are tested first because `FF FE 00 00` also begins with
///    the UTF-16 LE mark.
/// 3. BOM-less wide encodings, recognised by how `<?` or `<` is laid out:
///    UTF-16 BE (`00 3C 00 3F`), UTF-16 LE (`3C 00 3F 00`),
///    UTF-32 BE (`00 00 00 3C`), UTF-32 LE (`3C 00 00 00`).
/// 4. An `encoding` pseudo-attribute in a leading `<?xml` declaration, reported
///    as [`Encoding::Declared`].
/// 5. Any other input whose first byte is ASCII: [`Encoding::Utf8`].
///
/// Everything else is [`Encoding::Unknown`]. `bom_length` is the size of the
/// matched byte-order mark, or 0 when none was present.
pub fn probe_encoding(data: &[u8]) -> ProbeResult {
    let (encoding, bom_length) = match data {
        [] | [_] => (Encoding::Unknown, 0),

        // Byte-order marks, longest first.
        [0xFF, 0xFE, 0x00, 0x00, ..] => (Encoding::Utf32Le, 4),
        [0x00, 0x00, 0xFE, 0xFF, ..] => (Encoding::Utf32Be, 4),
        [0xEF, 0xBB, 0xBF, ..] => (Encoding::Utf8, 3),
        [0xFF, 0xFE, ..] => (Encoding::Utf16Le, 2),
        [0xFE, 0xFF, ..] => (Encoding::Utf16Be, 2),

        // No BOM: wide encodings betray themselves through zero padding around `<`/`?`.
        [0x00, 0x3C, 0x00, 0x3F, ..] => (Encoding::Utf16Be, 0),
        [0x3C, 0x00, 0x3F, 0x00, ..] => (Encoding::Utf16Le, 0),
        [0x00, 0x00, 0x00, 0x3C, ..] => (Encoding::Utf32Be, 0),
        [0x3C, 0x00, 0x00, 0x00, ..] => (Encoding::Utf32Le, 0),

        [first, ..] => match extract_encoding_from_decl(data) {
            Some(declared) => (Encoding::Declared(declared), 0),
            // `<` is itself ASCII, so markup-first documents land here too.
            None if first.is_ascii() => (Encoding::Utf8, 0),
            None => (Encoding::Unknown, 0),
        },
    };
    ProbeResult {
        encoding,
        bom_length,
    }
}
/// Pulls the `encoding` pseudo-attribute value out of a leading `<?xml` declaration.
///
/// `data` must begin with `<?xml` immediately followed by an XML whitespace byte;
/// anything else (including a processing instruction such as `<?xml-stylesheet`)
/// yields `None`. At most the first 256 bytes of `data` are examined.
///
/// The search for `encoding` is confined to the declaration itself: it stops at
/// the first `?>` in the scanned window. This prevents an `encoding` attribute on
/// a later element, or text inside a following comment, from being mistaken for
/// the declaration's own. When the window holds no `?>` (e.g. the caller passed
/// a truncated prefix), the whole window is searched.
///
/// Whitespace is allowed on either side of `=`, and the value may be wrapped in
/// single or double quotes. The raw bytes between the quotes are passed to
/// `DeclaredEncoding::new`, and its result is returned unchanged.
///
/// Returns `None` when:
/// - there is no such declaration;
/// - it has no `encoding = "value"` pair; or
/// - the pair is malformed or cut off before the closing quote.
fn extract_encoding_from_decl(data: &[u8]) -> Option<DeclaredEncoding> {
    // How many leading bytes of the document are scanned for the declaration.
    const SCAN_LIMIT: usize = 256;
    const OPEN: &[u8] = b"<?xml";
    const NAME: &[u8] = b"encoding";

    if !data.starts_with(OPEN) {
        return None;
    }
    // `<?xml` must be followed by whitespace to be an XML declaration.
    match data.get(OPEN.len()) {
        Some(&b) if is_xml_whitespace(b) => {}
        _ => return None,
    }

    // Everything after `<?xml` and its whitespace byte, up to the scan limit.
    // The guard above ensures `data.len() > OPEN.len()`, so this range is valid.
    let window = &data[OPEN.len() + 1..data.len().min(SCAN_LIMIT)];

    // Declaration values (version number, encoding name, yes/no) cannot contain
    // `?>`, so its first occurrence ends the declaration.
    let decl = match find_subsequence(window, b"?>") {
        Some(end) => &window[..end],
        None => window,
    };

    // Index of the first non-whitespace byte at or after `from` (callers keep `from <= decl.len()`).
    let skip_whitespace = |from: usize| {
        from + decl[from..]
            .iter()
            .take_while(|&&b| is_xml_whitespace(b))
            .count()
    };

    let after_name = find_subsequence(decl, NAME)? + NAME.len();

    let eq_at = skip_whitespace(after_name);
    if decl.get(eq_at) != Some(&b'=') {
        return None;
    }

    let quote_at = skip_whitespace(eq_at + 1);
    let quote = *decl.get(quote_at)?;
    if quote != b'"' && quote != b'\'' {
        return None;
    }

    let value_and_rest = &decl[quote_at + 1..];
    let value_len = value_and_rest.iter().position(|&b| b == quote)?;
    DeclaredEncoding::new(&value_and_rest[..value_len])
}
/// Returns the index of the first occurrence of `needle` within `haystack`,
/// or `None` if it does not occur (including when `needle` is longer than
/// `haystack`).
///
/// An empty `needle` matches at index 0, mirroring `str::find("")`. The explicit
/// check is needed because `slice::windows(0)` panics.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Probes `data` and asserts the detected encoding, then the BOM length.
    fn expect_probe(data: &[u8], encoding: Encoding, bom_length: usize) {
        let probed = probe_encoding(data);
        assert_eq!(probed.encoding, encoding);
        assert_eq!(probed.bom_length, bom_length);
    }

    /// Asserts that `probed` carries a declared encoding named `name`.
    fn expect_declared(probed: ProbeResult, name: &str) {
        match probed.encoding {
            Encoding::Declared(enc) => assert_eq!(enc.as_str(), Some(name)),
            other => panic!("expected Declared, got {other:?}"),
        }
    }

    #[test]
    fn utf8_bom() {
        expect_probe(b"\xEF\xBB\xBF<?xml version=\"1.0\"?>", Encoding::Utf8, 3);
    }

    #[test]
    fn utf16_le_bom() {
        expect_probe(b"\xFF\xFE<\x00?\x00x\x00m\x00l\x00", Encoding::Utf16Le, 2);
    }

    #[test]
    fn utf16_be_bom() {
        expect_probe(b"\xFE\xFF\x00<\x00?\x00x\x00m\x00l", Encoding::Utf16Be, 2);
    }

    #[test]
    fn utf32_le_bom() {
        expect_probe(b"\xFF\xFE\x00\x00<\x00\x00\x00", Encoding::Utf32Le, 4);
    }

    #[test]
    fn utf32_be_bom() {
        expect_probe(b"\x00\x00\xFE\xFF\x00\x00\x00<", Encoding::Utf32Be, 4);
    }

    #[test]
    fn utf16_be_no_bom() {
        expect_probe(b"\x00<\x00?\x00x\x00m\x00l", Encoding::Utf16Be, 0);
    }

    #[test]
    fn utf16_le_no_bom() {
        expect_probe(b"<\x00?\x00x\x00m\x00l\x00", Encoding::Utf16Le, 0);
    }

    #[test]
    fn encoding_declaration() {
        let probed = probe_encoding(b"<?xml version=\"1.0\" encoding=\"ISO-8859-1\"?>");
        assert_eq!(probed.bom_length, 0);
        expect_declared(probed, "ISO-8859-1");
    }

    #[test]
    fn encoding_declaration_single_quotes() {
        expect_declared(
            probe_encoding(b"<?xml version='1.0' encoding='Shift_JIS'?>"),
            "Shift_JIS",
        );
    }

    #[test]
    fn no_encoding_declaration() {
        expect_probe(b"<?xml version=\"1.0\"?><root/>", Encoding::Utf8, 0);
    }

    #[test]
    fn plain_utf8_document() {
        expect_probe(b"<root>hello</root>", Encoding::Utf8, 0);
    }

    #[test]
    fn empty_input() {
        assert_eq!(probe_encoding(b"").encoding, Encoding::Unknown);
    }

    #[test]
    fn single_byte() {
        assert_eq!(probe_encoding(b"<").encoding, Encoding::Unknown);
    }

    #[test]
    fn encoding_with_spaces_around_eq() {
        expect_declared(
            probe_encoding(b"<?xml version = \"1.0\" encoding = \"windows-1252\" ?>"),
            "windows-1252",
        );
    }
}