use std::io::Cursor;
use std::io::{Read, Seek, SeekFrom};
use byteorder::{BigEndian, ReadBytesExt};
use super::tuple::{Decode, RawBytes, ToTupleBuffer};
use crate::Result;
/// Advances `cur` past exactly one MessagePack value without decoding it.
///
/// How each kind of value is skipped:
/// - Fixed-size values and string/binary/extension payloads are skipped by seeking.
/// - Arrays and maps are skipped by recursively skipping each element (both the key and the
///   value of every map entry).
///
/// Because of the recursion, stack usage grows with the nesting depth of the value.
///
/// Seeking beyond the end of a stream is permitted by [`Seek`], so a truncated payload may
/// not be reported here. For example, with a [`Cursor`] the position simply ends up past the
/// end of the data. Callers that need to detect this must compare the resulting position
/// with the length of the data.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the marker or a length prefix cannot be read;
/// - a seek fails;
/// - the marker is [`rmp::Marker::Reserved`].
pub fn skip_value(cur: &mut (impl Read + Seek)) -> Result<()> {
    use rmp::Marker;

    /// Skips `count` consecutive MessagePack values.
    fn skip_values(cur: &mut (impl Read + Seek), count: u64) -> Result<()> {
        for _ in 0..count {
            skip_value(cur)?;
        }
        Ok(())
    }

    match rmp::decode::read_marker(cur)? {
        // The marker byte is the entire value.
        Marker::FixPos(_) | Marker::FixNeg(_) | Marker::Null | Marker::True | Marker::False => {}
        // Fixed-width numbers: skip the payload that follows the marker.
        Marker::U8 | Marker::I8 => {
            cur.seek(SeekFrom::Current(1))?;
        }
        Marker::U16 | Marker::I16 => {
            cur.seek(SeekFrom::Current(2))?;
        }
        Marker::U32 | Marker::I32 | Marker::F32 => {
            cur.seek(SeekFrom::Current(4))?;
        }
        Marker::U64 | Marker::I64 | Marker::F64 => {
            cur.seek(SeekFrom::Current(8))?;
        }
        // Strings and binaries: the length is either embedded in the marker or read from a
        // big-endian prefix; then `len` payload bytes follow.
        Marker::FixStr(len) => {
            cur.seek(SeekFrom::Current(i64::from(len)))?;
        }
        Marker::Str8 | Marker::Bin8 => {
            let len = cur.read_u8()?;
            cur.seek(SeekFrom::Current(i64::from(len)))?;
        }
        Marker::Str16 | Marker::Bin16 => {
            let len = cur.read_u16::<BigEndian>()?;
            cur.seek(SeekFrom::Current(i64::from(len)))?;
        }
        Marker::Str32 | Marker::Bin32 => {
            let len = cur.read_u32::<BigEndian>()?;
            cur.seek(SeekFrom::Current(i64::from(len)))?;
        }
        // Arrays: skip each element.
        Marker::FixArray(len) => {
            skip_values(cur, u64::from(len))?;
        }
        Marker::Array16 => {
            let len = cur.read_u16::<BigEndian>()?;
            skip_values(cur, u64::from(len))?;
        }
        Marker::Array32 => {
            let len = cur.read_u32::<BigEndian>()?;
            skip_values(cur, u64::from(len))?;
        }
        // Maps: skip every key and every value. The entry count is widened before doubling,
        // because doubling in `u16`/`u32` overflows for large Map16/Map32 entry counts.
        Marker::FixMap(len) => {
            skip_values(cur, u64::from(len) * 2)?;
        }
        Marker::Map16 => {
            let len = cur.read_u16::<BigEndian>()?;
            skip_values(cur, u64::from(len) * 2)?;
        }
        Marker::Map32 => {
            let len = cur.read_u32::<BigEndian>()?;
            skip_values(cur, u64::from(len) * 2)?;
        }
        // Fixed-size extensions: 1 type byte plus 1/2/4/8/16 data bytes.
        Marker::FixExt1 => {
            cur.seek(SeekFrom::Current(2))?;
        }
        Marker::FixExt2 => {
            cur.seek(SeekFrom::Current(3))?;
        }
        Marker::FixExt4 => {
            cur.seek(SeekFrom::Current(5))?;
        }
        Marker::FixExt8 => {
            cur.seek(SeekFrom::Current(9))?;
        }
        Marker::FixExt16 => {
            cur.seek(SeekFrom::Current(17))?;
        }
        // Variable-size extensions: the length prefix counts only the data bytes, so add 1
        // for the type byte that precedes them.
        Marker::Ext8 => {
            let len = cur.read_u8()?;
            cur.seek(SeekFrom::Current(i64::from(len) + 1))?;
        }
        Marker::Ext16 => {
            let len = cur.read_u16::<BigEndian>()?;
            cur.seek(SeekFrom::Current(i64::from(len) + 1))?;
        }
        Marker::Ext32 => {
            let len = cur.read_u32::<BigEndian>()?;
            cur.seek(SeekFrom::Current(i64::from(len) + 1))?;
        }
        Marker::Reserved => {
            return Err(rmp::decode::ValueReadError::TypeMismatch(Marker::Reserved).into());
        }
    }
    Ok(())
}
/// Writes `arr` as a MessagePack array.
///
/// The output is an array-length header followed by each element, encoded with
/// [`ToTupleBuffer::write_tuple_data`].
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `arr` has more than `u32::MAX` elements, which is more than the array header accepted
///   by `rmp::encode::write_array_len` can encode;
/// - writing the header fails;
/// - writing any element fails.
pub fn write_array<T>(w: &mut impl std::io::Write, arr: &[T]) -> Result<()>
where
    T: ToTupleBuffer,
{
    // An `as` cast would silently truncate the length, producing a header that
    // disagrees with the number of elements actually written.
    if arr.len() > u32::MAX as usize {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "array has too many elements to be encoded as a MessagePack array",
        )
        .into());
    }
    rmp::encode::write_array_len(w, arr.len() as u32)?;
    for elem in arr {
        elem.write_tuple_data(w)?;
    }
    Ok(())
}
/// Writes only the MessagePack array-length header for `len` elements to `w`.
///
/// The elements themselves are not written; the caller must write exactly `len` values
/// afterwards.
///
/// # Errors
///
/// Returns the `rmp` encoding error if writing the header fails.
pub fn write_array_len(
    w: &mut impl std::io::Write,
    len: u32,
) -> std::result::Result<(), rmp::encode::ValueWriteError> {
    // `rmp` reports which marker it chose; callers of this wrapper don't need it.
    rmp::encode::write_array_len(w, len).map(|_marker| ())
}
#[derive(Debug)]
pub struct ArrayWriter<W> {
writer: W,
start: u64,
len: u32,
}
impl ArrayWriter<Cursor<Vec<u8>>> {
#[track_caller]
#[inline(always)]
pub fn from_vec(buf: Vec<u8>) -> Self {
Self::new(Cursor::new(buf)).expect("allocation error")
}
}
impl<W> ArrayWriter<W>
where
W: std::io::Write + std::io::Seek,
{
const MAX_ARRAY_HEADER_SIZE: i64 = 5;
#[inline(always)]
pub fn new(mut writer: W) -> Result<Self> {
let start = writer.stream_position()?;
writer.seek(SeekFrom::Current(Self::MAX_ARRAY_HEADER_SIZE))?;
Ok(Self {
start,
writer,
len: 0,
})
}
#[inline(always)]
pub fn start(&self) -> u64 {
self.start
}
#[inline(always)]
pub fn len(&self) -> u32 {
self.len
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline(always)]
pub fn into_inner(self) -> W {
self.writer
}
#[inline(always)]
pub fn push<T>(&mut self, v: &T) -> Result<()>
where
T: ::serde::Serialize + ?Sized,
{
rmp_serde::encode::write(&mut self.writer, &v)?;
self.len += 1;
Ok(())
}
#[inline(always)]
pub fn push_tuple<T>(&mut self, v: &T) -> Result<()>
where
T: ToTupleBuffer + ?Sized,
{
v.write_tuple_data(&mut self.writer)?;
self.len += 1;
Ok(())
}
#[inline(always)]
pub fn push_raw(&mut self, v: &[u8]) -> Result<()> {
self.writer.write_all(v)?;
self.len += 1;
Ok(())
}
pub fn finish(mut self) -> Result<W> {
use rmp::encode::RmpWrite;
self.writer.seek(SeekFrom::Start(self.start))?;
self.writer.write_u8(rmp::Marker::Array32.to_u8())?;
self.writer
.write_data_u32(self.len)
.map_err(rmp::encode::ValueWriteError::from)?;
Ok(self.writer)
}
}
/// Iterator over consecutive MessagePack values stored in a byte slice.
///
/// Iterating yields the raw encoded bytes of each value. [`ValueIter::decode_next`]
/// additionally decodes the value and reports errors.
#[derive(Debug)]
pub struct ValueIter<'a> {
    cursor: Cursor<&'a [u8]>,
}

impl<'a> ValueIter<'a> {
    /// Creates an iterator over the elements of the MessagePack array encoded in `array`.
    ///
    /// Only the array header is consumed. The element count it declares is not used, so
    /// iteration continues until the end of `array`.
    ///
    /// # Errors
    ///
    /// Returns an error if `array` does not start with a readable array header.
    pub fn from_array(array: &'a [u8]) -> std::result::Result<Self, rmp::decode::ValueReadError> {
        let mut cursor = Cursor::new(array);
        rmp::decode::read_array_len(&mut cursor)?;
        Ok(Self { cursor })
    }

    /// Creates an iterator over the MessagePack values concatenated in `data`.
    pub fn new(data: &'a [u8]) -> Self {
        Self {
            cursor: Cursor::new(data),
        }
    }

    /// Decodes the next value as `T`.
    ///
    /// Returns:
    /// - `None` once all of the data has been consumed;
    /// - `Some(Err(_))` if the next value is malformed, is truncated, or fails to decode
    ///   as `T`.
    pub fn decode_next<T>(&mut self) -> Option<Result<T>>
    where
        T: Decode<'a>,
    {
        // Copy out the `&'a [u8]` so the slice handed to `T::decode` borrows the original
        // data rather than `self`.
        let bytes: &'a [u8] = *self.cursor.get_ref();
        let start = self.cursor.position() as usize;
        if start >= bytes.len() {
            return None;
        }
        if let Err(e) = skip_value(&mut self.cursor) {
            return Some(Err(e));
        }
        let end = self.cursor.position() as usize;
        debug_assert_ne!(start, end, "skip_value should've returned Err in this case");
        // `skip_value` seeks over string/binary/extension payloads without checking that
        // they are present. A truncated value leaves the cursor past the end of the data,
        // and slicing with that position would panic.
        if end > bytes.len() {
            return Some(Err(
                std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into()
            ));
        }
        Some(T::decode(&bytes[start..end]))
    }

    /// Returns the underlying cursor at its current position.
    pub fn into_inner(self) -> Cursor<&'a [u8]> {
        self.cursor
    }
}

impl<'a> Iterator for ValueIter<'a> {
    type Item = &'a [u8];

    /// Returns the raw encoded bytes of the next value, or `None` at the end of the data.
    ///
    /// If the next value is malformed, `None` is returned as well and the error is
    /// discarded. Use [`ValueIter::decode_next`] to observe it.
    #[inline(always)]
    fn next(&mut self) -> Option<&'a [u8]> {
        self.decode_next::<&RawBytes>()?.ok().map(|b| &**b)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // Elements pushed through all three `push*` methods end up in one array, and the
    // finished buffer decodes as a 3-tuple.
    #[test]
    fn array_writer() {
        let mut aw = ArrayWriter::from_vec(Vec::new());
        aw.push_tuple(&(420, "foo")).unwrap();
        aw.push(&"bar").unwrap();
        // Raw bytes: 0xa3 is the fixstr header for a 3-byte string, followed by "baz".
        aw.push_raw(b"\xa3baz").unwrap();
        let data = aw.finish().unwrap().into_inner();
        eprintln!("{:x?}", &data);
        let res: ((u32, String), String, String) = rmp_serde::from_slice(&data).unwrap();
        assert_eq!(
            res,
            ((420, "foo".to_owned()), "bar".to_owned(), "baz".to_owned())
        );
    }

    #[test]
    fn value_iter() {
        // `new`: empty input yields no values.
        let mut iter = ValueIter::new(b"");
        assert_eq!(iter.next(), None);
        // `*` (0x2a) is the single-byte positive fixint 42.
        let mut iter = ValueIter::new(b"*");
        assert_eq!(iter.next(), Some(&b"*"[..]));
        assert_eq!(iter.next(), None);
        // `from_array` requires an array header to be present.
        let err = ValueIter::from_array(b"").unwrap_err();
        assert_eq!(err.to_string(), "failed to read MessagePack marker");
        // 0x99 is a fixarray header declaring 9 elements. The declared count is not
        // enforced: iteration simply stops at the end of the data.
        let mut iter = ValueIter::from_array(b"\x99").unwrap();
        assert_eq!(iter.next(), None);
        let mut iter = ValueIter::from_array(b"\x99*").unwrap();
        assert_eq!(iter.next(), Some(&b"*"[..]));
        assert_eq!(iter.next(), None);
        // A 3-element array containing 42, [nil, false, true] and "sup".
        let data = b"\x93*\x93\xc0\xc2\xc3\xa3sup";
        // `decode_next` decodes each element into a concrete type.
        let mut iter = ValueIter::from_array(data).unwrap();
        let v: u32 = iter.decode_next().unwrap().unwrap();
        assert_eq!(v, 42);
        let v: Vec<Option<bool>> = iter.decode_next().unwrap().unwrap();
        assert_eq!(v, [None, Some(false), Some(true)]);
        let v: String = iter.decode_next().unwrap().unwrap();
        assert_eq!(v, "sup");
        // The plain iterator yields each element's raw encoded bytes.
        let mut iter = ValueIter::from_array(data).unwrap();
        let v = iter.next().unwrap();
        assert_eq!(v, b"*");
        let v = iter.next().unwrap();
        assert_eq!(v, b"\x93\xc0\xc2\xc3");
        let v = iter.next().unwrap();
        assert_eq!(v, b"\xa3sup");
        // With `new`, the whole array is a single value.
        let mut iter = ValueIter::new(data);
        let v: (u32, Vec<Option<bool>>, String) =
            rmp_serde::from_slice(iter.next().unwrap()).unwrap();
        assert_eq!(v, (42, vec![None, Some(false), Some(true)], "sup".into()));
    }
}