use std::collections::BTreeMap;
use crate::{
DataType, Error, Float, Format, IoResult, Result, SequenceDecoder, SequenceReader, SimpleValue, Value,
codec::{Argument, Head, Major},
io::{HexReader, HexSliceReader, MyReader, SliceReader},
limits,
parse::Parser,
};
/// Settings that control how input is decoded into a [`Value`].
///
/// Values are created with [`DecodeOptions::new`] (or `Default`) and adjusted
/// through the builder-style setters of the same names as the fields.
#[derive(Debug, Clone)]
pub struct DecodeOptions {
/// Input representation; selects the binary, hex or diagnostic code path in
/// `decode` and `read_from`.
format: Format,
/// Number of nesting levels available; the binary/hex decoder spends one
/// level per array, map and tag, and the value is also handed to the
/// diagnostic parser.
recursion_limit: u16,
/// Largest declared length accepted for byte strings, text strings, arrays
/// and maps by the binary/hex decoder.
length_limit: u64,
/// Budget passed to `read_vec` and used to cap array pre-allocation
/// (divided by `size_of::<Value>()`, so it is measured in bytes).
oom_mitigation: usize,
}
impl Default for DecodeOptions {
fn default() -> Self {
Self::new()
}
}
impl DecodeOptions {
/// Creates options for binary input with every limit taken from the
/// crate-level constants in `limits`.
#[must_use]
pub const fn new() -> Self {
    let format = Format::Binary;
    let recursion_limit = limits::RECURSION_LIMIT;
    let length_limit = limits::LENGTH_LIMIT;
    let oom_mitigation = limits::OOM_MITIGATION;
    Self {
        format,
        recursion_limit,
        length_limit,
        oom_mitigation,
    }
}
/// Sets the input format (binary, hex or diagnostic) and returns the
/// updated options.
///
/// The method consumes `self`; the returned value must be used, otherwise
/// the setting is lost.
#[must_use]
pub const fn format(mut self, format: Format) -> Self {
    self.format = format;
    self
}
/// Sets the maximum nesting depth and returns the updated options.
///
/// When decoding binary or hex input, every array, map and tag consumes one
/// level; once the budget is exhausted decoding fails with
/// `Error::NestingTooDeep`. For diagnostic input the limit is handed to the
/// parser.
///
/// The method consumes `self`; the returned value must be used, otherwise
/// the setting is lost.
#[must_use]
pub const fn recursion_limit(mut self, limit: u16) -> Self {
    self.recursion_limit = limit;
    self
}
/// Sets the largest declared length accepted and returns the updated
/// options.
///
/// The binary/hex decoder compares the declared length of every byte
/// string, text string, array and map against this limit and fails with
/// `Error::LengthTooLarge` when it is exceeded.
///
/// The method consumes `self`; the returned value must be used, otherwise
/// the setting is lost.
#[must_use]
pub const fn length_limit(mut self, limit: u64) -> Self {
    self.length_limit = limit;
    self
}
/// Sets the allocation budget, in bytes, and returns the updated options.
///
/// The budget caps how much capacity an array reserves up front and is
/// passed on to the reader when string payloads are read.
/// NOTE(review): how the reader applies the budget is not visible here.
///
/// The method consumes `self`; the returned value must be used, otherwise
/// the setting is lost.
#[must_use]
pub const fn oom_mitigation(mut self, bytes: usize) -> Self {
    self.oom_mitigation = bytes;
    self
}
pub fn decode(&self, bytes: impl AsRef<[u8]>) -> Result<Value> {
let bytes = bytes.as_ref();
match self.format {
Format::Binary => {
let mut reader = SliceReader(bytes);
let value = self.do_read(&mut reader, self.recursion_limit, self.oom_mitigation)?;
if !reader.0.is_empty() {
return Err(Error::InvalidFormat);
}
Ok(value)
}
Format::Hex => {
let mut reader = HexSliceReader(bytes);
let value = self.do_read(&mut reader, self.recursion_limit, self.oom_mitigation)?;
if !reader.0.is_empty() {
return Err(Error::InvalidFormat);
}
Ok(value)
}
Format::Diagnostic => {
let mut parser = Parser::new(SliceReader(bytes), self.recursion_limit);
parser.parse_complete()
}
}
}
/// Reads a single value from a `std::io::Read` source.
///
/// Binary input is read directly from `reader`, hex input through a
/// `HexReader` wrapper, and diagnostic input through the parser's
/// `parse_stream_item`. No check for trailing data is made here.
///
/// # Errors
///
/// Returns whatever error the selected reader or parser reports.
pub fn read_from(&self, reader: impl std::io::Read) -> IoResult<Value> {
    let depth = self.recursion_limit;
    let budget = self.oom_mitigation;
    match self.format {
        Format::Diagnostic => {
            let mut parser = Parser::new(reader, depth);
            parser.parse_stream_item()
        }
        Format::Hex => {
            let mut hex = HexReader(reader);
            self.do_read(&mut hex, depth, budget)
        }
        Format::Binary => {
            let mut raw = reader;
            self.do_read(&mut raw, depth, budget)
        }
    }
}
/// Builds a [`SequenceDecoder`] over `input` that uses a clone of these options.
pub fn sequence_decoder<'a, B: AsRef<[u8]> + ?Sized>(&self, input: &'a B) -> SequenceDecoder<'a> {
    let options = self.clone();
    let bytes: &[u8] = input.as_ref();
    SequenceDecoder::with_options(options, bytes)
}
/// Builds a [`SequenceReader`] over `reader` that uses a clone of these options.
pub fn sequence_reader<R: std::io::Read>(&self, reader: R) -> SequenceReader<R> {
    let options = self.clone();
    SequenceReader::with_options(options, reader)
}
/// Reads one data item from `reader`, starting with the full configured
/// recursion limit and allocation budget.
pub(crate) fn decode_one<R>(&self, reader: &mut R) -> std::result::Result<Value, R::Error>
where
    R: MyReader,
    R::Error: From<Error>,
{
    let depth = self.recursion_limit;
    let budget = self.oom_mitigation;
    self.do_read(reader, depth, budget)
}
pub(crate) fn recursion_limit_value(&self) -> u16 {
self.recursion_limit
}
pub(crate) fn format_value(&self) -> Format {
self.format
}
fn do_read<R>(
&self,
reader: &mut R,
recursion_limit: u16,
oom_mitigation: usize,
) -> std::result::Result<Value, R::Error>
where
R: MyReader,
R::Error: From<Error>,
{
let head = Head::read_from(reader)?;
let is_float = head.initial_byte.major() == Major::SimpleOrFloat
&& matches!(head.argument, Argument::U16(_) | Argument::U32(_) | Argument::U64(_));
if !is_float && !head.argument.is_deterministic() {
return Err(Error::NonDeterministic.into());
}
let this = match head.initial_byte.major() {
Major::Unsigned => Value::Unsigned(head.value()),
Major::Negative => Value::Negative(head.value()),
Major::ByteString => {
let len = head.value();
if len > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
Value::ByteString(reader.read_vec(len, oom_mitigation)?)
}
Major::TextString => {
let len = head.value();
if len > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
let bytes = reader.read_vec(len, oom_mitigation)?;
let string = String::from_utf8(bytes).map_err(Error::from)?;
Value::TextString(string)
}
Major::Array => {
let value = head.value();
if value > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
let Some(recursion_limit) = recursion_limit.checked_sub(1) else {
return Err(Error::NestingTooDeep.into());
};
let request: usize = value.try_into().or(Err(Error::LengthTooLarge))?;
let granted = request.min(oom_mitigation / size_of::<Value>());
let oom_mitigation = oom_mitigation - granted * size_of::<Value>();
let mut vec = Vec::with_capacity(granted);
for _ in 0..value {
vec.push(self.do_read(reader, recursion_limit, oom_mitigation)?);
}
Value::Array(vec)
}
Major::Map => {
let value = head.value();
if value > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
let Some(recursion_limit) = recursion_limit.checked_sub(1) else {
return Err(Error::NestingTooDeep.into());
};
let mut map = BTreeMap::new();
let mut prev = None;
for _ in 0..value {
let key = self.do_read(reader, recursion_limit, oom_mitigation)?;
let value = self.do_read(reader, recursion_limit, oom_mitigation)?;
if let Some((prev_key, prev_value)) = prev.take() {
if prev_key >= key {
return Err(Error::NonDeterministic.into());
}
map.insert(prev_key, prev_value);
}
prev = Some((key, value));
}
if let Some((key, value)) = prev.take() {
map.insert(key, value);
}
Value::Map(map)
}
Major::Tag => {
let Some(recursion_limit) = recursion_limit.checked_sub(1) else {
return Err(Error::NestingTooDeep.into());
};
let tag_number = head.value();
let tag_content = Box::new(self.do_read(reader, recursion_limit, oom_mitigation)?);
let this = Value::Tag(tag_number, tag_content);
if this.data_type() == DataType::BigInt {
let bytes = this.as_bytes().unwrap();
let valid = bytes.len() >= 8 && bytes[0] != 0;
if !valid {
return Err(Error::NonDeterministic.into());
}
}
this
}
Major::SimpleOrFloat => match head.argument {
Argument::None => Value::SimpleValue(SimpleValue(head.initial_byte.info())),
Argument::U8(n) if n >= 32 => Value::SimpleValue(SimpleValue(n)),
Argument::U16(bits) => Value::Float(Float::from_bits_u16(bits)),
Argument::U32(bits) => Value::Float(Float::from_bits_u32(bits)?),
Argument::U64(bits) => Value::Float(Float::from_bits_u64(bits)?),
_ => return Err(Error::Malformed.into()),
},
};
Ok(this)
}
}