use std::{borrow::Cow, ops::ControlFlow};
use encoding_rs::{Decoder, DecoderResult, Encoding};
use crate::{loader::LoadError, LoadableYamlNode, Yaml};
/// Signature of a user-supplied handler for malformed input, used by [`YAMLDecodingTrap::Call`].
///
/// The handler is invoked once per malformed byte sequence with:
///  - `malformation_length`: length, in bytes, of the malformed sequence;
///  - `bytes_read_after_malformation`: number of bytes the decoder consumed after the malformed
///    sequence;
///  - `input_at_malformation`: the raw input, starting at the first byte of the malformed sequence;
///  - `output`: the text decoded so far; the handler may append to it (e.g. a substitute).
///
/// Returning `ControlFlow::Continue(())` resumes decoding. Returning `ControlFlow::Break(msg)`
/// aborts decoding with `LoadError::Decode(msg)`; if `msg` is empty, a default message naming the
/// byte offset and the malformed bytes is reported instead.
pub type YAMLDecodingTrapFn = fn(
    malformation_length: u8,
    bytes_read_after_malformation: u8,
    input_at_malformation: &[u8],
    output: &mut String,
) -> ControlFlow<Cow<'static, str>>;

/// Policy for handling byte sequences that are malformed in the detected input encoding.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum YAMLDecodingTrap {
    /// Drop malformed sequences from the output silently.
    Ignore,
    /// Stop decoding with an error that names the offset and bytes of the malformed sequence.
    Strict,
    /// Replace each malformed sequence with U+FFFD REPLACEMENT CHARACTER.
    Replace,
    /// Delegate to a user-supplied handler; see [`YAMLDecodingTrapFn`] for its contract.
    Call(YAMLDecodingTrapFn),
}
/// Reads raw bytes from a source, decodes them to UTF-8 and loads the YAML documents they contain.
///
/// The `std::io::Read` requirement is expressed on the `impl` block that needs it rather than on
/// the type itself, so merely naming `YamlDecoder<T>` imposes no bound.
pub struct YamlDecoder<T> {
    /// Byte source the YAML text is read from.
    source: T,
    /// How malformed byte sequences are handled while decoding.
    trap: YAMLDecodingTrap,
}
impl<T: std::io::Read> YamlDecoder<T> {
    /// Creates a decoder over `source` that rejects malformed input
    /// ([`YAMLDecodingTrap::Strict`]) until told otherwise via [`Self::encoding_trap`].
    pub fn read(source: T) -> YamlDecoder<T> {
        let trap = YAMLDecodingTrap::Strict;
        YamlDecoder { source, trap }
    }

    /// Sets the policy for malformed byte sequences and returns `self` for chaining.
    pub fn encoding_trap(&mut self, trap: YAMLDecodingTrap) -> &mut Self {
        self.trap = trap;
        self
    }

    /// Reads the entire source, decodes it to UTF-8 and parses every YAML document in it.
    ///
    /// The encoding comes from a byte-order mark when one is present. Otherwise it is guessed
    /// from the first two bytes (UTF-16 BE/LE), falling back to UTF-8.
    ///
    /// # Errors
    /// - any error from reading `source`;
    /// - [`LoadError::Decode`] when the configured trap rejects malformed input;
    /// - [`LoadError::Scan`] when the decoded text is not valid YAML.
    pub fn decode(&mut self) -> Result<Vec<Yaml>, LoadError> {
        let mut raw = Vec::new();
        self.source.read_to_end(&mut raw)?;

        // The BOM length is not needed: the decoder performs its own BOM handling.
        let encoding = match Encoding::for_bom(&raw) {
            Some((from_bom, _bom_len)) => from_bom,
            None => detect_utf16_endianness(&raw),
        };

        let mut text = String::new();
        decode_loop(&raw, &mut text, &mut encoding.new_decoder(), self.trap)?;
        Yaml::load_from_str(&text).map_err(LoadError::Scan)
    }
}
fn decode_loop(
input: &[u8],
output: &mut String,
decoder: &mut Decoder,
trap: YAMLDecodingTrap,
) -> Result<(), LoadError> {
use crate::loader::LoadError;
output.reserve(input.len());
let mut total_bytes_read = 0;
loop {
match decoder.decode_to_string_without_replacement(&input[total_bytes_read..], output, true)
{
(DecoderResult::InputEmpty, _) => break Ok(()),
(DecoderResult::OutputFull, bytes_read) => {
total_bytes_read += bytes_read;
output.reserve(input.len() / 10);
}
(DecoderResult::Malformed(malformed_len, bytes_after_malformed), bytes_read) => {
total_bytes_read += bytes_read;
match trap {
YAMLDecodingTrap::Ignore => {}
YAMLDecodingTrap::Replace => {
output.push('\u{FFFD}');
}
YAMLDecodingTrap::Strict => {
let malformed_len = malformed_len as usize;
let bytes_after_malformed = bytes_after_malformed as usize;
let byte_idx = total_bytes_read - (malformed_len + bytes_after_malformed);
let malformed_sequence = &input[byte_idx..byte_idx + malformed_len];
break Err(LoadError::Decode(Cow::Owned(format!(
"Invalid character sequence at {byte_idx}: {malformed_sequence:?}",
))));
}
YAMLDecodingTrap::Call(callback) => {
let byte_idx =
total_bytes_read - ((malformed_len + bytes_after_malformed) as usize);
let malformed_sequence =
&input[byte_idx..byte_idx + malformed_len as usize];
if let ControlFlow::Break(error) = callback(
malformed_len,
bytes_after_malformed,
&input[byte_idx..],
output,
) {
if error.is_empty() {
break Err(LoadError::Decode(Cow::Owned(format!(
"Invalid character sequence at {byte_idx}: {malformed_sequence:?}",
))));
}
break Err(LoadError::Decode(error));
}
}
}
}
}
}
}
/// Guesses the encoding of input that carries no byte-order mark, from its first two bytes.
///
/// A NUL in exactly one of the first two positions is taken as a sign of UTF-16:
/// `00 xx` selects big-endian and `xx 00` selects little-endian. Everything else is treated as
/// UTF-8, including input shorter than two bytes and input starting with `00 00`.
fn detect_utf16_endianness(b: &[u8]) -> &'static Encoding {
    match b {
        [0, second, ..] if *second != 0 => encoding_rs::UTF_16BE,
        [first, 0, ..] if *first != 0 => encoding_rs::UTF_16LE,
        _ => encoding_rs::UTF_8,
    }
}
#[cfg(test)]
mod test {
    use crate::Scalar;

    use super::{YAMLDecodingTrap, Yaml, YamlDecoder};

    /// A UTF-8 byte-order mark is recognised and does not leak into the parsed document.
    #[test]
    fn test_read_bom() {
        let s = b"\xef\xbb\xbf---
a: 1
b: 2.2
c: [1, 2]
";
        let mut decoder = YamlDecoder::read(s as &[u8]);
        let out = decoder.decode().unwrap();
        let doc = &out[0];
        assert_eq!(doc["a"].as_integer().unwrap(), 1i64);
        assert!((doc["b"].as_floating_point().unwrap() - 2.2f64).abs() <= f64::EPSILON);
        assert_eq!(doc["c"][1].as_integer().unwrap(), 2i64);
        assert!(!doc.contains_mapping_key("d"));
    }

    /// UTF-16LE input with a byte-order mark.
    #[test]
    fn test_read_utf16le() {
        let s = b"\xff\xfe-\x00-\x00-\x00
\x00a\x00:\x00 \x001\x00
\x00b\x00:\x00 \x002\x00.\x002\x00
\x00c\x00:\x00 \x00[\x001\x00,\x00 \x002\x00]\x00
\x00";
        let mut decoder = YamlDecoder::read(s as &[u8]);
        let out = decoder.decode().unwrap();
        let doc = &out[0];
        println!("GOT: {doc:?}");
        assert_eq!(doc["a"].as_integer().unwrap(), 1i64);
        // `.abs()` is required: without it any value <= 2.2 + EPSILON (e.g. 0.0) would pass.
        assert!((doc["b"].as_floating_point().unwrap() - 2.2f64).abs() <= f64::EPSILON);
        assert_eq!(doc["c"][1].as_integer().unwrap(), 2i64);
        assert!(!doc.contains_mapping_key("d"));
    }

    /// UTF-16BE input with a byte-order mark.
    #[test]
    fn test_read_utf16be() {
        let s = b"\xfe\xff\x00-\x00-\x00-\x00
\x00a\x00:\x00 \x001\x00
\x00b\x00:\x00 \x002\x00.\x002\x00
\x00c\x00:\x00 \x00[\x001\x00,\x00 \x002\x00]\x00
";
        let mut decoder = YamlDecoder::read(s as &[u8]);
        let out = decoder.decode().unwrap();
        let doc = &out[0];
        println!("GOT: {doc:?}");
        assert_eq!(doc["a"].as_integer().unwrap(), 1i64);
        assert!((doc["b"].as_floating_point().unwrap() - 2.2f64).abs() <= f64::EPSILON);
        assert_eq!(doc["c"][1].as_integer().unwrap(), 2i64);
        assert!(!doc.contains_mapping_key("d"));
    }

    /// UTF-16LE input without a byte-order mark is detected from its leading bytes.
    #[test]
    fn test_read_utf16le_nobom() {
        let s = b"-\x00-\x00-\x00
\x00a\x00:\x00 \x001\x00
\x00b\x00:\x00 \x002\x00.\x002\x00
\x00c\x00:\x00 \x00[\x001\x00,\x00 \x002\x00]\x00
\x00";
        let mut decoder = YamlDecoder::read(s as &[u8]);
        let out = decoder.decode().unwrap();
        let doc = &out[0];
        println!("GOT: {doc:?}");
        assert_eq!(doc["a"].as_integer().unwrap(), 1i64);
        assert!((doc["b"].as_floating_point().unwrap() - 2.2f64).abs() <= f64::EPSILON);
        assert_eq!(doc["c"][1].as_integer().unwrap(), 2i64);
        assert!(!doc.contains_mapping_key("d"));
    }

    /// With the `Ignore` trap, an invalid UTF-8 byte is dropped and the key reads as `a`.
    #[test]
    fn test_read_trap() {
        let s = b"---
a\xa9: 1
b: 2.2
c: [1, 2]
";
        let mut decoder = YamlDecoder::read(s as &[u8]);
        let out = decoder
            .encoding_trap(YAMLDecodingTrap::Ignore)
            .decode()
            .unwrap();
        let doc = &out[0];
        println!("GOT: {doc:?}");
        assert_eq!(doc["a"].as_integer().unwrap(), 1i64);
        assert!((doc["b"].as_floating_point().unwrap() - 2.2f64).abs() <= f64::EPSILON);
        assert_eq!(doc["c"][1].as_integer().unwrap(), 2i64);
        assert!(!doc.contains_mapping_key("d"));
    }

    /// `Yaml::or` yields the fallback only when the receiver is null.
    #[test]
    fn test_or() {
        assert_eq!(
            Yaml::Value(Scalar::Null).or(Yaml::Value(Scalar::Integer(3))),
            Yaml::Value(Scalar::Integer(3))
        );
        assert_eq!(
            Yaml::Value(Scalar::Integer(3)).or(Yaml::Value(Scalar::Integer(7))),
            Yaml::Value(Scalar::Integer(3))
        );
    }
}