use std::error;
use std::fmt::{self, Display};
use std::io::{self, BufRead, BufReader, Read, Write};
use rmp::Marker;
use rmp_serde::decode::Error::{InvalidDataRead, InvalidMarkerRead};
use serde::{Deserialize, de, ser};
use crate::input::{self, Input, Ref};
use crate::transcode;
/// Maximum nesting depth of MessagePack values handled by this module.
///
/// The same limit is applied to every `rmp_serde` deserializer through
/// `set_max_depth` and to the slice scanner `next_value_size`.
const DEPTH_LIMIT: usize = 1024;
/// Reports whether `input` starts with a decodable MessagePack array or map.
///
/// Detection is only attempted when the first byte is an array or map marker;
/// every other leading byte (or empty input) yields `Ok(false)` without any
/// further decoding.
// NOTE(review): presumably scalars are excluded because most arbitrary bytes
// decode as some MessagePack scalar — confirm against the design notes.
///
/// # Errors
///
/// Returns the underlying I/O error when reading the prefix fails, or when the
/// decoder reports `InvalidMarkerRead` / `InvalidDataRead`. Any other decoding
/// failure is treated as "not MessagePack" and yields `Ok(false)`.
pub(crate) fn input_matches(mut input: Ref) -> io::Result<bool> {
    let first_marker = input.prefix(1)?.first().copied().map(Marker::from_u8);
    let starts_with_container = matches!(
        first_marker,
        Some(
            Marker::FixArray(_)
                | Marker::Array16
                | Marker::Array32
                | Marker::FixMap(_)
                | Marker::Map16
                | Marker::Map32
        )
    );
    if !starts_with_container {
        return Ok(false);
    }

    let outcome = match &mut input {
        Ref::Slice(b) => match_input_buffer(b),
        Ref::Reader(r) => match_input_reader(r),
    };
    match outcome {
        Ok(()) => Ok(true),
        Err(InvalidMarkerRead(io_err) | InvalidDataRead(io_err)) => Err(io_err),
        Err(_) => Ok(false),
    }
}
/// Decodes and discards one MessagePack value from the start of a byte slice.
///
/// Nesting is capped at `DEPTH_LIMIT`. Returns `Ok(())` when a value was
/// decoded, otherwise the decoder's error.
fn match_input_buffer(input: &[u8]) -> Result<(), rmp_serde::decode::Error> {
    let mut deserializer = rmp_serde::Deserializer::from_read_ref(input);
    deserializer.set_max_depth(DEPTH_LIMIT);
    de::IgnoredAny::deserialize(&mut deserializer).map(|_ignored| ())
}
/// Decodes and discards one MessagePack value read from `input`.
///
/// Nesting is capped at `DEPTH_LIMIT`. Returns `Ok(())` when a value was
/// decoded, otherwise the decoder's error.
fn match_input_reader<R: Read>(input: R) -> Result<(), rmp_serde::decode::Error> {
    let mut deserializer = rmp_serde::Deserializer::new(input);
    deserializer.set_max_depth(DEPTH_LIMIT);
    de::IgnoredAny::deserialize(&mut deserializer).map(|_ignored| ())
}
/// Transcodes every MessagePack value found in `input` into `output`.
///
/// For slice input, each value is first delimited with `next_value_size` so
/// that the deserializer is handed exactly one value at a time. For reader
/// input, the reader is buffered and values are decoded until the buffer
/// reports that no more data is available.
///
/// # Errors
///
/// Fails when a value's size cannot be determined, when reading fails, or
/// when `output.transcode_from` fails.
pub(crate) fn transcode<O>(input: input::Handle, mut output: O) -> crate::Result<()>
where
    O: crate::Output,
{
    match input.into() {
        Input::Slice(bytes) => {
            let mut remaining = &*bytes;
            while !remaining.is_empty() {
                let size = next_value_size(remaining, DEPTH_LIMIT)?;
                let (value, tail) = remaining.split_at(size);
                remaining = tail;

                let mut deserializer = rmp_serde::Deserializer::from_read_ref(value);
                deserializer.set_max_depth(DEPTH_LIMIT);
                output.transcode_from(&mut deserializer)?;
            }
        }
        Input::Reader(reader) => {
            let mut reader = BufReader::new(reader);
            loop {
                // An empty buffer after a fill means the reader is exhausted.
                if reader.fill_buf()?.is_empty() {
                    break;
                }
                let mut deserializer = rmp_serde::Deserializer::new(&mut reader);
                deserializer.set_max_depth(DEPTH_LIMIT);
                output.transcode_from(&mut deserializer)?;
            }
        }
    }
    Ok(())
}
/// MessagePack output that writes to the wrapped writer.
pub(crate) struct Output<W: Write>(W);

impl<W> Output<W>
where
    W: Write,
{
    /// Wraps `w` so that values can be written to it.
    pub(crate) fn new(w: W) -> Self {
        Self(w)
    }
}
impl<W: Write> crate::Output for Output<W> {
    /// Transcodes the next value from `de` into MessagePack on the wrapped writer.
    fn transcode_from<'de, D, E>(&mut self, de: D) -> crate::Result<()>
    where
        D: de::Deserializer<'de, Error = E>,
        E: de::Error + Send + Sync + 'static,
    {
        let Self(writer) = self;
        let mut serializer = rmp_serde::Serializer::new(writer);
        transcode::transcode(&mut serializer, de)?;
        Ok(())
    }

    /// Serializes `value` as MessagePack onto the wrapped writer.
    fn transcode_value<S>(&mut self, value: S) -> crate::Result<()>
    where
        S: ser::Serialize,
    {
        let Self(writer) = self;
        let mut serializer = rmp_serde::Serializer::new(writer);
        value.serialize(&mut serializer)?;
        Ok(())
    }

    /// Flushes the wrapped writer.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.0)
    }
}
fn next_value_size(input: &[u8], depth_limit: usize) -> Result<usize, ReadSizeError> {
if depth_limit == 0 {
return Err(ReadSizeError::DepthLimitExceeded);
}
if input.is_empty() {
return Ok(0);
}
let marker = rmp::Marker::from_u8(input[0]);
let total_size = match marker {
Marker::Reserved => return Err(ReadSizeError::InvalidMarker),
Marker::Null | Marker::True | Marker::False | Marker::FixPos(_) | Marker::FixNeg(_) => 1,
Marker::U8 | Marker::I8 => 2,
Marker::U16 | Marker::I16 => 3,
Marker::U32 | Marker::I32 | Marker::F32 => 5,
Marker::U64 | Marker::I64 | Marker::F64 => 9,
Marker::FixExt1 => 3,
Marker::FixExt2 => 4,
Marker::FixExt4 => 6,
Marker::FixExt8 => 10,
Marker::FixExt16 => 18,
Marker::Ext8 => 3 + try_read_length_8(input)? as usize,
Marker::Ext16 => 4 + try_read_length_16(input)? as usize,
Marker::Ext32 => 6 + try_read_length_32(input)? as usize,
Marker::FixStr(n) => 1 + n as usize,
Marker::Str8 | Marker::Bin8 => 2 + try_read_length_8(input)? as usize,
Marker::Str16 | Marker::Bin16 => 3 + try_read_length_16(input)? as usize,
Marker::Str32 | Marker::Bin32 => 5 + try_read_length_32(input)? as usize,
Marker::FixArray(count) => 1 + total_seq_size(&input[1..], count, depth_limit)?,
Marker::FixMap(pairs) => 1 + total_map_size(&input[1..], pairs, depth_limit)?,
Marker::Array16 => {
let count = try_read_length_16(input)?;
3 + total_seq_size(&input[3..], count, depth_limit)?
}
Marker::Map16 => {
let pairs = try_read_length_16(input)?;
3 + total_map_size(&input[3..], pairs, depth_limit)?
}
Marker::Array32 => {
let count = try_read_length_32(input)?;
5 + total_seq_size(&input[5..], count, depth_limit)?
}
Marker::Map32 => {
let pairs = try_read_length_32(input)?;
5 + total_map_size(&input[5..], pairs, depth_limit)?
}
};
if total_size <= input.len() {
Ok(total_size)
} else {
Err(ReadSizeError::Truncated)
}
}
/// Returns the combined encoded size of `count` consecutive values at the
/// start of `input`.
///
/// Each value is measured with `next_value_size` using one less level of
/// remaining depth.
///
/// # Errors
///
/// Returns `ReadSizeError::Truncated` when the input runs out before `count`
/// values were measured, and forwards any error from `next_value_size`.
fn total_seq_size<N>(input: &[u8], count: N, depth_limit: usize) -> Result<usize, ReadSizeError>
where
    N: Into<u32>,
{
    let count: u32 = count.into();
    // Number of bytes measured so far, which is also where the next value starts.
    let mut offset = 0;
    for _ in 0..count {
        let remaining = &input[offset..];
        if remaining.is_empty() {
            return Err(ReadSizeError::Truncated);
        }
        offset += next_value_size(remaining, depth_limit - 1)?;
    }
    Ok(offset)
}
/// Returns the combined encoded size of the entries of a map with `pairs`
/// key/value pairs at the start of `input`.
///
/// The `2 * pairs` values are measured as two back-to-back runs of `pairs`
/// values each, using `total_seq_size` for both runs.
fn total_map_size<N>(input: &[u8], pairs: N, depth_limit: usize) -> Result<usize, ReadSizeError>
where
    N: Into<u32>,
{
    let pairs: u32 = pairs.into();
    let first_run = total_seq_size(input, pairs, depth_limit)?;
    let (_, tail) = input.split_at(first_run);
    total_seq_size(tail, pairs, depth_limit).map(|second_run| first_run + second_run)
}
/// Reads the 1-byte length field that follows the marker byte of `input`.
fn try_read_length_8(input: &[u8]) -> Result<u8, ReadSizeError> {
    try_read_length::<1, u8, _>(input, u8::from_be_bytes)
}
/// Reads the 2-byte big-endian length field that follows the marker byte of `input`.
fn try_read_length_16(input: &[u8]) -> Result<u16, ReadSizeError> {
    try_read_length::<2, u16, _>(input, u16::from_be_bytes)
}
/// Reads the 4-byte big-endian length field that follows the marker byte of `input`.
fn try_read_length_32(input: &[u8]) -> Result<u32, ReadSizeError> {
    try_read_length::<4, u32, _>(input, u32::from_be_bytes)
}
/// Reads the `N` bytes that follow the first byte of `input` and converts
/// them with `convert`.
///
/// # Errors
///
/// Returns `ReadSizeError::Truncated` when `input` holds fewer than `1 + N`
/// bytes.
fn try_read_length<const N: usize, T, F>(input: &[u8], convert: F) -> Result<T, ReadSizeError>
where
    F: FnOnce([u8; N]) -> T,
{
    match input.get(1..1 + N) {
        Some(field) => {
            // The slice was taken with exactly `N` elements, so the
            // conversion to a fixed-size array cannot fail.
            let bytes: [u8; N] = field.try_into().unwrap();
            Ok(convert(bytes))
        }
        None => Err(ReadSizeError::Truncated),
    }
}
/// Reasons why the size of a MessagePack value could not be determined.
#[derive(Clone, Debug, Eq, PartialEq)]
enum ReadSizeError {
    /// The input ended before the value was complete.
    Truncated,
    /// A reserved marker byte was found in the input.
    InvalidMarker,
    /// The value is nested more deeply than the depth limit allows.
    DepthLimitExceeded,
}

impl error::Error for ReadSizeError {}

impl Display for ReadSizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Truncated => "unexpected end of MessagePack input",
            Self::InvalidMarker => "invalid MessagePack marker in input",
            Self::DepthLimitExceeded => "depth limit exceeded",
        };
        f.write_str(message)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;

    /// Complete encodings of single values; `next_value_size` must report the
    /// full length of each one.
    const VALID_INPUTS: &[&[u8]] = &[
        &[],
        // nil and booleans
        &hex!("c0"),
        &hex!("c2"),
        &hex!("c3"),
        // fixints
        &hex!("2a"),
        &hex!("f4"),
        // unsigned integers
        &hex!("cc 09"),
        &hex!("cd 09 f9"),
        &hex!("ce 09 f9 11 02"),
        &hex!("cf 09 f9 11 02 9d 74 e3 5b"),
        // signed integers
        &hex!("d0 d8"),
        &hex!("d1 d8 41"),
        &hex!("d2 d8 41 56 c5"),
        &hex!("d3 d8 41 56 c5 63 56 88 c0"),
        // floats
        &hex!("ca 64 7a 5a 6e"),
        &hex!("cb 54 79 4b 50 45 67 4e 64"),
        // strings
        &hex!("a2 78 74"),
        &hex!("d9 02 78 74"),
        &hex!("da 00 02 78 74"),
        &hex!("db 00 00 00 02 78 74"),
        &hex!("a0"),
        &hex!("d9 00"),
        &hex!("da 00 00"),
        &hex!("db 00 00 00 00"),
        // binaries
        &hex!("c4 02 78 74"),
        &hex!("c5 00 02 78 74"),
        &hex!("c6 00 00 00 02 78 74"),
        &hex!("c4 00"),
        &hex!("c5 00 00"),
        &hex!("c6 00 00 00 00"),
        // arrays
        &hex!("92 a2 78 74 c3"),
        &hex!("dc 00 02 a2 78 74 c3"),
        &hex!("dd 00 00 00 02 a2 78 74 c3"),
        &hex!("90"),
        &hex!("dc 00 00"),
        &hex!("dd 00 00 00 00"),
        // maps
        &hex!("82 a2 78 74 c3 a4 67 6f 6f 64 c3"),
        &hex!("de 00 02 a2 78 74 c3 a4 67 6f 6f 64 c3"),
        &hex!("df 00 00 00 02 a2 78 74 c3 a4 67 6f 6f 64 c3"),
        &hex!("80"),
        &hex!("de 00 00"),
        &hex!("df 00 00 00 00"),
        // extensions
        &hex!("d4 01 09"),
        &hex!("d5 01 09 f9"),
        &hex!("d6 01 09 f9 11 02"),
        &hex!("d7 01 09 f9 11 02 9d 74 e3 5b"),
        &hex!("d8 01 09 f9 11 02 9d 74 e3 5b d8 41 56 c5 63 56 88 c0"),
        &hex!("c7 04 01 09 f9 11 02"),
        &hex!("c8 00 04 01 09 f9 11 02"),
        &hex!("c9 00 00 00 04 01 09 f9 11 02"),
    ];

    #[test]
    fn valid_input_size() {
        for input in VALID_INPUTS {
            assert_eq!(next_value_size(input, DEPTH_LIMIT), Ok(input.len()));
        }
    }

    #[test]
    fn truncated_valid_input_size() {
        for input in VALID_INPUTS.iter().filter(|i| i.len() > 1) {
            // Every proper, non-empty prefix must be reported as truncated,
            // including the prefix that is missing only the final byte.
            for len in 1..input.len() {
                assert_eq!(
                    next_value_size(&input[..len], DEPTH_LIMIT),
                    Err(ReadSizeError::Truncated)
                );
            }
        }
    }

    #[test]
    fn nonsensically_large_input_size() {
        assert_eq!(
            next_value_size(&hex!("db ff ff ff ff 78 74"), DEPTH_LIMIT),
            Err(ReadSizeError::Truncated)
        );
    }

    #[test]
    fn invalid_marker_size() {
        assert_eq!(
            next_value_size(&hex!("c1"), DEPTH_LIMIT),
            Err(ReadSizeError::InvalidMarker)
        );
        assert_eq!(
            next_value_size(&hex!("92 a2 78 74 c1"), DEPTH_LIMIT),
            Err(ReadSizeError::InvalidMarker)
        );
        assert_eq!(
            next_value_size(&hex!("82 a2 78 74 c3 a4 67 6f 6f 64 c1"), DEPTH_LIMIT),
            Err(ReadSizeError::InvalidMarker)
        );
    }

    #[test]
    fn size_skipping_invalid_suffixes() {
        // Only the first value is measured; trailing garbage is not inspected.
        assert_eq!(next_value_size(&hex!("c3 c1"), DEPTH_LIMIT), Ok(1));
        assert_eq!(next_value_size(&hex!("91 a2 78 74 c1"), DEPTH_LIMIT), Ok(4));
    }

    #[test]
    fn excessively_deep_input_size() {
        assert_eq!(next_value_size(&hex!("91 91 c3"), 3), Ok(3));
        assert_eq!(
            next_value_size(&hex!("91 91 91 c3"), 3),
            Err(ReadSizeError::DepthLimitExceeded)
        );
    }

    // Checks that detection and transcoding accept the same maximum nesting
    // depth for both slice and reader inputs.
    // NOTE(review): presumably skipped under Miri because of the deep
    // recursion on a large dedicated stack — confirm.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn consistent_depth_limits() {
        // DEPTH_LIMIT - 1 nested single-element arrays wrapping a nil.
        let mut input = [0x91_u8; DEPTH_LIMIT];
        *input.last_mut().unwrap() = 0xc0;

        std::thread::Builder::new()
            .stack_size(8 * 1024 * 1024)
            .spawn(move || {
                match_input_buffer(&input[..]).expect("failed to detect buffer");
                super::transcode(
                    input::Handle::from_slice(&input[..]),
                    super::Output::new(io::sink()),
                )
                .expect("failed to translate buffer");

                match_input_reader(&input[..]).expect("failed to detect reader");
                super::transcode(
                    input::Handle::from_reader(&input[..]),
                    super::Output::new(io::sink()),
                )
                .expect("failed to translate reader");
            })
            .unwrap()
            .join()
            .unwrap();
    }
}