use std::{borrow::Cow, collections::BTreeMap};
use crate::{
Error, Float, Format, IoResult, Result, SequenceDecoder, SequenceReader, SimpleValue, Strictness, Value,
codec::{Argument, Head, HeadOrStop, Major},
io::{HexReader, HexSliceReader, MyReader, SliceReader},
limits,
parse::Parser,
tag::{NEG_BIG_INT, POS_BIG_INT},
util::{trim_leading_zeros, u64_from_slice},
};
/// Configuration for decoding CBOR values.
///
/// Start from [`DecodeOptions::new`] (or `Default`) and adjust with the
/// by-value builder setters, then call `decode`, `decode_owned`, `read_from`,
/// `sequence_decoder` or `sequence_reader`.
#[derive(Debug, Clone)]
pub struct DecodeOptions {
/// How the input is interpreted: `Format::Binary`, `Format::Hex` (read via
/// `HexSliceReader` / `HexReader`) or `Format::Diagnostic` (handed to `Parser`).
pub(crate) format: Format,
/// Remaining nesting budget; each array, map or tag level consumes one unit
/// and `Error::NestingTooDeep` is returned when it runs out. Also passed to
/// the diagnostic `Parser`.
pub(crate) recursion_limit: u16,
/// Upper bound on byte/text string lengths (summed over chunks for
/// indefinite-length strings) and on array item / map entry counts;
/// exceeding it yields `Error::LengthTooLarge`.
pub(crate) length_limit: u64,
/// Byte budget for up-front allocation driven by declared lengths: passed
/// to `read_cow` and used to cap array pre-allocation.
pub(crate) oom_mitigation: usize,
/// Which non-deterministic encodings are tolerated instead of rejected.
pub(crate) strictness: Strictness,
}
impl Default for DecodeOptions {
fn default() -> Self {
Self::new()
}
}
impl DecodeOptions {
/// Creates options for strict decoding of binary input.
///
/// Strictness starts at `Strictness::STRICT`. The recursion, length and
/// allocation limits come from the `limits` module. Because this is a
/// `const fn`, it can initialise `const`/`static` items.
#[must_use]
pub const fn new() -> Self {
    let format = Format::Binary;
    let strictness = Strictness::STRICT;
    Self {
        strictness,
        format,
        oom_mitigation: limits::OOM_MITIGATION,
        length_limit: limits::LENGTH_LIMIT,
        recursion_limit: limits::RECURSION_LIMIT,
    }
}
/// Selects how input is interpreted.
///
/// Choices are `Format::Binary`, `Format::Hex` (read through
/// `HexSliceReader` / `HexReader`) or `Format::Diagnostic` (parsed by
/// `Parser`).
///
/// Consumes and returns `self`. The result must be used, or the setting is lost.
#[must_use]
pub const fn format(mut self, format: Format) -> Self {
    self.format = format;
    self
}
/// Sets the maximum nesting depth.
///
/// For binary/hex input, each array, map or tag consumes one level. When
/// the budget is exhausted, decoding fails with `Error::NestingTooDeep`.
/// The value is also passed to the diagnostic `Parser`.
///
/// Consumes and returns `self`. The result must be used, or the setting is lost.
#[must_use]
pub const fn recursion_limit(mut self, limit: u16) -> Self {
    self.recursion_limit = limit;
    self
}
/// Sets the maximum accepted length.
///
/// The limit applies to:
/// - byte and text string lengths (summed over chunks for indefinite-length
///   strings);
/// - array item counts;
/// - map entry counts.
///
/// Anything larger is rejected with `Error::LengthTooLarge`.
///
/// Consumes and returns `self`. The result must be used, or the setting is lost.
#[must_use]
pub const fn length_limit(mut self, limit: u64) -> Self {
    self.length_limit = limit;
    self
}
/// Sets the byte budget for allocations made from declared lengths before
/// the data is actually read.
///
/// The budget is used in two places:
/// - it is passed to the reader's `read_cow` for string payloads;
/// - it caps how many array elements are pre-allocated, and the array's
///   share is subtracted before the budget reaches its items.
///
/// Consumes and returns `self`. The result must be used, or the setting is lost.
#[must_use]
pub const fn oom_mitigation(mut self, bytes: usize) -> Self {
    self.oom_mitigation = bytes;
    self
}
/// Sets which non-deterministic encodings are tolerated rather than
/// rejected with `Error::NonDeterministic`.
///
/// The relevant cases are: indefinite lengths, non-shortest integers or
/// floats, unsorted or duplicate map keys, and oversized bignums.
///
/// Consumes and returns `self`. The result must be used, or the setting is lost.
#[must_use]
pub const fn strictness(mut self, strictness: Strictness) -> Self {
    self.strictness = strictness;
    self
}
pub fn decode<'a, T>(&self, bytes: &'a T) -> Result<Value<'a>>
where
T: AsRef<[u8]> + ?Sized,
{
let bytes = bytes.as_ref();
match self.format {
Format::Binary => {
let mut reader = SliceReader(bytes);
let value = self.do_read(&mut reader, self.recursion_limit, self.oom_mitigation)?;
if !reader.0.is_empty() {
return Err(Error::InvalidFormat);
}
Ok(value)
}
Format::Hex => {
let mut reader = HexSliceReader(bytes);
let value = self.do_read(&mut reader, self.recursion_limit, self.oom_mitigation)?;
if !reader.0.is_empty() {
return Err(Error::InvalidFormat);
}
Ok(value)
}
Format::Diagnostic => {
let mut parser = Parser::new(SliceReader(bytes), self.recursion_limit, self.strictness);
parser.parse_complete()
}
}
}
pub fn decode_owned<'a>(&self, bytes: impl AsRef<[u8]>) -> Result<Value<'a>> {
let mut bytes = bytes.as_ref();
match self.format {
Format::Binary | Format::Hex => {
let value = self.read_from(&mut bytes).map_err(|err| match err {
crate::IoError::Io(_io_error) => unreachable!(),
crate::IoError::Data(error) => error,
})?;
if bytes.is_empty() {
Ok(value)
} else {
Err(Error::InvalidFormat)
}
}
Format::Diagnostic => {
let mut parser = Parser::new(SliceReader(bytes), self.recursion_limit, self.strictness);
parser.parse_complete()
}
}
}
/// Decodes a single value from a `std::io::Read` source.
///
/// Behavior by format:
/// - Binary: the raw reader is decoded directly.
/// - Hex: the reader is wrapped in `HexReader` first.
/// - Diagnostic: `Parser::parse_stream_item` is used.
///
/// Unlike `decode`, no check is made for data remaining after the value.
///
/// # Errors
/// `IoError` wrapping either an I/O failure or a decoding error.
pub fn read_from<'a>(&self, mut reader: impl std::io::Read) -> IoResult<Value<'a>> {
    let depth = self.recursion_limit;
    let budget = self.oom_mitigation;
    match self.format {
        Format::Diagnostic => {
            let mut parser = Parser::new(reader, depth, self.strictness);
            parser.parse_stream_item()
        }
        Format::Hex => self.do_read(&mut HexReader(reader), depth, budget),
        Format::Binary => self.do_read(&mut reader, depth, budget),
    }
}
/// Returns a `SequenceDecoder` over `input` that uses a copy of these
/// options.
pub fn sequence_decoder<'a, T>(&self, input: &'a T) -> SequenceDecoder<'a>
where
    T: AsRef<[u8]> + ?Sized,
{
    let options = self.clone();
    let data = input.as_ref();
    SequenceDecoder::with_options(options, data)
}
/// Returns a `SequenceReader` over `reader` that uses a copy of these
/// options.
pub fn sequence_reader<R>(&self, reader: R) -> SequenceReader<R>
where
    R: std::io::Read,
{
    let options = self.clone();
    SequenceReader::with_options(options, reader)
}
/// Decodes one value from `reader` using the configured recursion and
/// allocation budgets.
///
/// A stray break marker is an error (see `do_read`). Trailing input is
/// not checked here.
pub(crate) fn decode_one<'a, R>(&self, reader: &mut R) -> std::result::Result<Value<'a>, R::Error>
where
    R: MyReader<'a>,
    R::Error: From<Error>,
{
    let (depth, budget) = (self.recursion_limit, self.oom_mitigation);
    self.do_read(reader, depth, budget)
}
fn do_read<'a, R>(
&self,
reader: &mut R,
recursion_limit: u16,
oom_mitigation: usize,
) -> std::result::Result<Value<'a>, R::Error>
where
R: MyReader<'a>,
R::Error: From<Error>,
{
match self.read_value_or_break(reader, recursion_limit, oom_mitigation)? {
Some(value) => Ok(value),
None => Err(Error::Malformed.into()),
}
}
fn read_value_or_break<'a, R>(
&self,
reader: &mut R,
recursion_limit: u16,
oom_mitigation: usize,
) -> std::result::Result<Option<Value<'a>>, R::Error>
where
R: MyReader<'a>,
R::Error: From<Error>,
{
match HeadOrStop::read_from(reader)? {
HeadOrStop::Definite(head) => self
.process_head(head, reader, recursion_limit, oom_mitigation)
.map(Some),
HeadOrStop::Indefinite(major) => {
if self.strictness.allow_indefinite_length {
self.process_indefinite(major, reader, recursion_limit, oom_mitigation)
.map(Some)
} else {
Err(Error::NonDeterministic.into())
}
}
HeadOrStop::Break => Ok(None),
}
}
/// Decodes the item introduced by a definite-length `head`, reading its
/// payload or nested items from `reader`.
///
/// - `recursion_limit`: remaining nesting budget. Arrays, maps and tags
///   each consume one level for their children.
/// - `oom_mitigation`: byte budget for up-front allocation. It is forwarded
///   to `read_cow` and used to cap array pre-allocation.
///
/// # Errors
/// - `NonDeterministic`: a non-shortest argument (unless
///   `allow_non_shortest_integers`), a non-canonical bignum (unless
///   `allow_oversized_bigints`), or a non-shortest 32/64-bit float (see
///   `checked_float`).
/// - `LengthTooLarge`: a length above `length_limit`, or an array length
///   that does not fit in `usize`.
/// - `NestingTooDeep`: the recursion budget is exhausted.
/// - `Malformed`: any other simple/float argument form.
/// - Reader, UTF-8 or nested-item errors are also propagated.
fn process_head<'a, R>(
&self,
head: Head,
reader: &mut R,
recursion_limit: u16,
oom_mitigation: usize,
) -> std::result::Result<Value<'a>, R::Error>
where
R: MyReader<'a>,
R::Error: From<Error>,
{
// Float heads are exempt from the shortest-integer check below: a float's
// width is meaningful. 32/64-bit floats are validated by `checked_float`
// instead; 16-bit floats are accepted as-is.
let is_float = head.initial_byte.major() == Major::SimpleOrFloat
&& matches!(head.argument, Argument::U16(_) | Argument::U32(_) | Argument::U64(_));
if !is_float && !head.argument.is_deterministic() && !self.strictness.allow_non_shortest_integers {
return Err(Error::NonDeterministic.into());
}
let this = match head.initial_byte.major() {
// Integers store the head's raw argument.
Major::Unsigned => Value::Unsigned(head.value()),
Major::Negative => Value::Negative(head.value()),
Major::ByteString => {
let len = head.value();
// Reject before reading so an oversized declared length reads nothing.
if len > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
Value::ByteString(reader.read_cow(len, oom_mitigation)?)
}
Major::TextString => {
let len = head.value();
if len > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
// Validate UTF-8 while preserving whether the payload was borrowed or owned.
let text = match reader.read_cow(len, oom_mitigation)? {
Cow::Borrowed(bytes) => Cow::Borrowed(std::str::from_utf8(bytes).map_err(Error::from)?),
Cow::Owned(bytes) => Cow::Owned(String::from_utf8(bytes).map_err(Error::from)?),
};
Value::TextString(text)
}
Major::Array => {
let value = head.value();
if value > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
let Some(recursion_limit) = recursion_limit.checked_sub(1) else {
return Err(Error::NestingTooDeep.into());
};
let request: usize = value.try_into().or(Err(Error::LengthTooLarge))?;
// Pre-allocate only as many elements as the byte budget covers. The
// children get what is left; this cannot underflow because
// `granted * size_of::<Value>() <= oom_mitigation`.
let granted = request.min(oom_mitigation / size_of::<Value>());
let oom_mitigation = oom_mitigation - granted * size_of::<Value>();
let mut vec = Vec::with_capacity(granted);
for _ in 0..value {
vec.push(self.do_read(reader, recursion_limit, oom_mitigation)?);
}
Value::Array(vec)
}
Major::Map => {
// `value` counts key/value pairs.
let value = head.value();
if value > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
let Some(recursion_limit) = recursion_limit.checked_sub(1) else {
return Err(Error::NestingTooDeep.into());
};
// `BTreeMap` is not pre-allocated, so the full budget passes to entries.
let mut map = BTreeMap::new();
for _ in 0..value {
let key = self.do_read(reader, recursion_limit, oom_mitigation)?;
let val = self.do_read(reader, recursion_limit, oom_mitigation)?;
// Enforces the key-ordering and duplicate-key rules.
self.map_insert(&mut map, key, val)?;
}
Value::Map(map)
}
Major::Tag => {
let Some(recursion_limit) = recursion_limit.checked_sub(1) else {
return Err(Error::NestingTooDeep.into());
};
let tag_number = head.value();
let tag_content = self.do_read(reader, recursion_limit, oom_mitigation)?;
match tag_content {
Value::ByteString(bytes) if matches!(tag_number, POS_BIG_INT | NEG_BIG_INT) => {
// A bignum is kept as-is only when it needs more than 8 bytes and
// has no leading zero byte. Anything else has a shorter form:
// `normalize_bigint` turns values that fit a `u64` into plain
// integers and trims leading zeros otherwise.
let canonical = bytes.len() > 8 && bytes[0] != 0;
if canonical {
Value::Tag(tag_number, Box::new(Value::ByteString(bytes)))
} else if self.strictness.allow_oversized_bigints {
normalize_bigint(tag_number, bytes)
} else {
return Err(Error::NonDeterministic.into());
}
}
other => Value::Tag(tag_number, Box::new(other)),
}
}
Major::SimpleOrFloat => match head.argument {
// Simple value held directly in the initial byte's low bits.
Argument::None => Value::SimpleValue(SimpleValue(head.initial_byte.info())),
// A one-byte extended simple value must be >= 32. Smaller values fall
// through to `Malformed`, as does any other argument form.
Argument::U8(n) if n >= 32 => Value::SimpleValue(SimpleValue(n)),
Argument::U16(bits) => Value::Float(Float::from_bits_u16(bits)),
Argument::U32(bits) => self.checked_float(Float::from_bits_u32(bits))?,
Argument::U64(bits) => self.checked_float(Float::from_bits_u64(bits))?,
_ => return Err(Error::Malformed.into()),
},
};
Ok(this)
}
fn checked_float<'a>(&self, float: Float) -> Result<Value<'a>> {
if float.is_deterministic() {
Ok(Value::Float(float))
} else if self.strictness.allow_non_shortest_floats {
Ok(Value::Float(float.shortest()))
} else {
Err(Error::NonDeterministic)
}
}
fn map_insert<'a>(&self, map: &mut BTreeMap<Value<'a>, Value<'a>>, key: Value<'a>, val: Value<'a>) -> Result<()> {
if !self.strictness.allow_unsorted_map_keys
&& let Some(last) = map.last_entry()
&& *last.key() >= key
{
Err(Error::NonDeterministic)
} else if map.insert(key, val).is_some() && !self.strictness.allow_duplicate_map_keys {
Err(Error::NonDeterministic)
} else {
Ok(())
}
}
/// Dispatches an indefinite-length item to the reader for its major type.
///
/// Only byte strings, text strings, arrays and maps can be indefinite. Any
/// other major here is an internal invariant violation.
fn process_indefinite<'a, R>(
    &self,
    major: Major,
    reader: &mut R,
    recursion_limit: u16,
    oom_mitigation: usize,
) -> std::result::Result<Value<'a>, R::Error>
where
    R: MyReader<'a>,
    R::Error: From<Error>,
{
    match major {
        // Containers consume nesting depth; strings do not.
        Major::Array => self.read_indefinite_array(reader, recursion_limit, oom_mitigation),
        Major::Map => self.read_indefinite_map(reader, recursion_limit, oom_mitigation),
        Major::ByteString => self.read_indefinite_bytes(reader, oom_mitigation),
        Major::TextString => self.read_indefinite_text(reader, oom_mitigation),
        _ => unreachable!("process_indefinite: invalid major"),
    }
}
fn read_indefinite_bytes<'a, R>(
&self,
reader: &mut R,
oom_mitigation: usize,
) -> std::result::Result<Value<'a>, R::Error>
where
R: MyReader<'a>,
R::Error: From<Error>,
{
let mut buf = Vec::new();
let mut total: u64 = 0;
loop {
match HeadOrStop::read_from(reader)? {
HeadOrStop::Break => break,
HeadOrStop::Definite(head) if head.initial_byte.major() == Major::ByteString => {
if !head.argument.is_deterministic() && !self.strictness.allow_non_shortest_integers {
return Err(Error::NonDeterministic.into());
}
let chunk_len = head.value();
total = total.checked_add(chunk_len).ok_or(Error::LengthTooLarge)?;
if total > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
let chunk = reader.read_cow(chunk_len, oom_mitigation)?;
buf.extend_from_slice(&chunk);
}
_ => return Err(Error::Malformed.into()),
}
}
Ok(Value::ByteString(Cow::Owned(buf)))
}
fn read_indefinite_text<'a, R>(
&self,
reader: &mut R,
oom_mitigation: usize,
) -> std::result::Result<Value<'a>, R::Error>
where
R: MyReader<'a>,
R::Error: From<Error>,
{
let mut buf = String::new();
let mut total: u64 = 0;
loop {
match HeadOrStop::read_from(reader)? {
HeadOrStop::Break => break,
HeadOrStop::Definite(head) if head.initial_byte.major() == Major::TextString => {
if !head.argument.is_deterministic() && !self.strictness.allow_non_shortest_integers {
return Err(Error::NonDeterministic.into());
}
let chunk_len = head.value();
total = total.checked_add(chunk_len).ok_or(Error::LengthTooLarge)?;
if total > self.length_limit {
return Err(Error::LengthTooLarge.into());
}
let chunk = reader.read_cow(chunk_len, oom_mitigation)?;
buf.push_str(std::str::from_utf8(&chunk).map_err(Error::from)?);
}
_ => return Err(Error::Malformed.into()),
}
}
Ok(Value::TextString(Cow::Owned(buf)))
}
fn read_indefinite_array<'a, R>(
&self,
reader: &mut R,
recursion_limit: u16,
oom_mitigation: usize,
) -> std::result::Result<Value<'a>, R::Error>
where
R: MyReader<'a>,
R::Error: From<Error>,
{
let Some(recursion_limit) = recursion_limit.checked_sub(1) else {
return Err(Error::NestingTooDeep.into());
};
let mut vec = Vec::new();
for _ in 0..self.length_limit {
match self.read_value_or_break(reader, recursion_limit, oom_mitigation)? {
Some(item) => vec.push(item),
None => return Ok(Value::Array(vec)),
};
}
match HeadOrStop::read_from(reader)? {
HeadOrStop::Definite(_) => Err(Error::LengthTooLarge.into()),
HeadOrStop::Indefinite(_) => Err(Error::Malformed.into()),
HeadOrStop::Break => Ok(Value::Array(vec)),
}
}
fn read_indefinite_map<'a, R>(
&self,
reader: &mut R,
recursion_limit: u16,
oom_mitigation: usize,
) -> std::result::Result<Value<'a>, R::Error>
where
R: MyReader<'a>,
R::Error: From<Error>,
{
let Some(recursion_limit) = recursion_limit.checked_sub(1) else {
return Err(Error::NestingTooDeep.into());
};
let mut map = BTreeMap::new();
for _ in 0..self.length_limit {
match self.read_value_or_break(reader, recursion_limit, oom_mitigation)? {
Some(key) => {
let value = self.do_read(reader, recursion_limit, oom_mitigation)?;
self.map_insert(&mut map, key, value)?;
}
None => return Ok(Value::Map(map)),
};
}
match HeadOrStop::read_from(reader)? {
HeadOrStop::Definite(_) => Err(Error::LengthTooLarge.into()),
HeadOrStop::Indefinite(_) => Err(Error::Malformed.into()),
HeadOrStop::Break => Ok(Value::Map(map)),
}
}
}
/// Rewrites a non-canonical bignum payload into its canonical form.
///
/// Leading zero bytes are stripped first. If the remaining magnitude fits
/// in a `u64`, the result is a plain `Value::Unsigned` (for `POS_BIG_INT`)
/// or `Value::Negative` (for `NEG_BIG_INT`). Otherwise the tag is kept
/// around the trimmed bytes.
///
/// The borrowed/owned nature of the input is preserved. An owned buffer is
/// reused untouched when there was nothing to trim.
///
/// # Panics
/// If `tag_number` is neither `POS_BIG_INT` nor `NEG_BIG_INT`. Callers
/// only pass those two tags.
fn normalize_bigint<'a>(tag_number: u64, bytes: Cow<'a, [u8]>) -> Value<'a> {
    let small = |n: u64| -> Value<'a> {
        match tag_number {
            POS_BIG_INT => Value::Unsigned(n),
            NEG_BIG_INT => Value::Negative(n),
            _other => unreachable!("normalize_bigint: invalid tag"),
        }
    };
    let magnitude: Cow<'a, [u8]> = match bytes {
        Cow::Borrowed(slice) => {
            let significant = trim_leading_zeros(slice);
            if let Ok(n) = u64_from_slice(significant) {
                return small(n);
            }
            Cow::Borrowed(significant)
        }
        Cow::Owned(vec) => {
            let significant = trim_leading_zeros(&vec);
            if let Ok(n) = u64_from_slice(significant) {
                return small(n);
            }
            if significant.len() == vec.len() {
                Cow::Owned(vec)
            } else {
                Cow::Owned(significant.to_vec())
            }
        }
    };
    Value::Tag(tag_number, Box::new(Value::ByteString(magnitude)))
}