use core::cmp::min;
use core::fmt::Debug;
use core::mem::{replace, size_of};
use core::ptr::read_unaligned;
use core::slice::from_raw_parts;
use crate::{BinProto, BytesLike, Error, Fixed, Result, Sigint, SizeStack, Varint, _size_varint, EGRP};
/// Continuation flag of a varint byte: set on every byte except the last one.
const MSB: u8 = 0x80;
/// Mask selecting the seven payload bits of a varint byte (everything but `MSB`).
const DROP_MSB: u8 = !MSB;
/// Cursor over a borrowed byte buffer for decoding the protobuf-style wire
/// format (varints, length-delimited bytes, fixed-width values and groups).
///
/// Reads start at `pos` and may not go past `limit`; `_field_nested`
/// temporarily narrows `limit` to the extent of a nested message.
pub struct InputStream<'buf> {
/// The entire input; decoded byte/string fields are sub-slices of it (no copy).
pub(crate) buf: &'buf [u8],
/// Offset in `buf` of the next byte to read.
pub(crate) pos: usize,
/// Absolute offset in `buf` one past the last byte that may currently be
/// read; every setter clamps it to at most `buf.len()`.
pub(crate) limit: usize,
}
impl<'buf> InputStream<'buf> {
#[inline(always)]
pub fn new(buf: &'buf [u8]) -> Self {
Self {
buf,
pos: 0,
limit: buf.len(),
}
}
#[inline(always)]
pub fn len(&self) -> usize {
min(self.buf.len(), self.limit) - self.pos
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn limit(&mut self, limit: usize) -> Result<usize> {
if self.limit < limit {
return Err(Error::InvalidBytesLimit);
}
Ok(replace(&mut self.limit, min(self.pos + limit, self.buf.len())))
}
#[inline(always)]
pub fn unlimit(&mut self, limit: usize) {
debug_assert!(self.limit <= limit);
self.limit = min(self.buf.len(), limit);
}
pub fn skip(&mut self, tag: u32) -> Result<()> {
if tag >> 3 == 0 {
return crate::unknown_tag(0);
}
match (tag & 0b111) as u8 {
crate::VARINT => {
self._varint::<u64>()?;
}
crate::BYTES => {
self._bytes()?;
}
crate::FIX32 => {
self._fixed::<u32>()?;
}
crate::FIX64 => {
self._fixed::<u64>()?;
}
crate::SGRP => {
self.skip_until_tag((tag & !0b111) | EGRP as u32)?;
}
other => return crate::unknown_wire(other),
}
Ok(())
}
pub fn skip_until_tag(&mut self, next_tag: u32) -> Result<()> {
let tag = self._varint()?;
if tag == next_tag {
return Ok(());
}
self.skip(tag)
}
pub fn _varint<T: Varint>(&mut self) -> Result<T> {
let mut result: u64 = 0;
let mut shift = 0;
let mut success = false;
for b in unsafe { self.buf.get_unchecked(self.pos .. self.limit).iter() } {
let msb_dropped = b & DROP_MSB;
result |= (msb_dropped as u64) << shift;
shift += 7;
if b & MSB == 0 || shift > (9 * 7) {
success = b & MSB == 0;
break;
}
}
if success {
self.pos += (shift / 7) as usize;
Ok(T::from_u64(result))
} else {
Err(Error::UnexpectedEOF)
}
}
pub fn _sigint<T: Sigint>(&mut self) -> Result<T> {
Ok(T::decode(self._varint::<T::Varint>()?))
}
pub fn _fixed<T: Fixed>(&mut self) -> Result<T> {
let tlen = size_of::<T>();
if self.len() < tlen {
return Err(Error::UnexpectedEOF);
}
let out = unsafe { read_unaligned(self.buf.as_ptr().add(self.pos) as *const T) };
self.pos += tlen;
Ok(out)
}
pub fn _bytes(&mut self) -> Result<&'buf [u8]> {
let len: usize = self._varint()?;
if self.len() < len {
return Err(Error::UnexpectedEOF);
}
self.pos += len;
Ok(&self.buf[self.pos - len .. self.pos])
}
pub fn _string(&mut self) -> Result<&'buf str> {
Ok(core::str::from_utf8(self._bytes()?)?)
}
pub fn varint<T: Varint>(&mut self, field: &mut T) -> Result<()> {
*field = self._varint()?;
Ok(())
}
#[inline(always)]
pub fn sigint<T: Sigint + Default>(&mut self, field: &mut T) -> Result<()> {
*field = self._sigint()?;
Ok(())
}
#[inline(always)]
pub fn protoenum<T: From<u32>>(&mut self, field: &mut T) -> Result<()> {
*field = self._varint::<u32>()?.into();
Ok(())
}
#[inline(always)]
pub fn bool(&mut self, field: &mut bool) -> Result<()> {
*field = self._varint::<u64>()? > 0;
Ok(())
}
#[inline(always)]
pub fn fixed32<T: Default + Fixed>(&mut self, field: &mut T) -> Result<()> {
debug_assert_eq!(size_of::<T>(), 4);
*field = self._fixed()?;
Ok(())
}
#[inline(always)]
pub fn fixed64<T: Fixed>(&mut self, field: &mut T) -> Result<()> {
debug_assert_eq!(size_of::<T>(), 8);
*field = self._fixed()?;
Ok(())
}
pub fn bytes<'x, T: BytesLike<'buf>>(&mut self, field: &mut T) -> Result<()> {
field.set(self._bytes()?)?;
Ok(())
}
pub fn string<T: BytesLike<'buf>>(&mut self, field: &mut T) -> Result<()> {
field.set(self._bytes()?)?;
Ok(())
}
pub fn nested<P: BinProto<'buf>>(&mut self, p: &mut P) -> Result<()> {
self._field_nested(p)
}
pub fn group<P: BinProto<'buf>>(&mut self, p: &mut P) -> Result<()> {
self._field_group(0, p)
}
pub fn _field_nested(&mut self, proto: &mut dyn BinProto<'buf>) -> Result<()> {
let len = self._varint()?;
if len > self.len() {
return Err(Error::UnexpectedEOF);
}
let start = self.pos;
let olimit = self.limit(len)?;
while !self.is_empty() {
let tag = self._varint()?;
proto.merge_field(tag, self)?;
}
assert_eq!(self.pos, start + len);
self.unlimit(olimit);
Ok(())
}
pub fn _field_group(&mut self, _gtag: u32, proto: &mut dyn BinProto<'buf>) -> Result<()> {
while !self.is_empty() {
let tag = self._varint()?;
if tag & 7 == EGRP as _ {
break;
}
proto.merge_field(tag, self)?
}
Ok(())
}
}
/// Growable buffer that encodes values into the protobuf-style wire format.
#[derive(Default)]
pub struct OutputStream {
/// Stack of (message address, encoded length) entries; `nested` checks that
/// the top entry belongs to the message being written and pops its length.
/// NOTE(review): presumably filled by a sizing pass elsewhere — confirm.
pub(crate) stack: SizeStack,
/// Encoded output; returned by `finish`.
pub(crate) buf: Vec<u8>,
}
impl OutputStream {
    /// Creates a stream that appends encoded output to `vec`; existing
    /// contents of `vec` are kept.
    pub fn new(vec: Vec<u8>) -> Self {
        Self {
            stack: SizeStack::default(),
            buf: vec,
        }
    }

    /// Number of bytes in the output buffer.
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// Returns `true` if the output buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Consumes the stream and returns the output buffer.
    pub fn finish(self) -> Vec<u8> {
        self.buf
    }

    /// Writes a field tag as a varint.
    pub(crate) fn _tag(&mut self, t: u32) {
        self._varint(t)
    }

    /// Appends `v` as a base-128 varint: seven bits per byte, least-significant
    /// group first, MSB set on every byte except the last.
    pub(crate) fn _varint<V: Varint + Debug>(&mut self, v: V) {
        let mut n = v.into_u64();
        // Expects `_size_varint` to return the exact encoded length (>= 1);
        // `len - 1` would underflow on 0 — TODO confirm.
        let len = _size_varint(n);
        self.buf.reserve(len);
        for _ in 0..len - 1 {
            self.buf.push(MSB | (n as u8 & DROP_MSB));
            n >>= 7;
        }
        self.buf.push(n as u8);
    }

    /// Appends raw bytes with no length prefix.
    pub(crate) fn _bytes(&mut self, v: &[u8]) {
        self.buf.extend_from_slice(v);
    }

    // In the public writers below the `u32` tag argument is ignored.
    // NOTE(review): presumably the caller writes the tag beforehand — confirm.

    /// Writes `v` as a varint.
    pub fn varint<V: Varint + Debug>(&mut self, _: u32, v: &V) {
        self._varint(*v)
    }

    /// Writes the varint produced by `v.encode()`.
    pub fn sigint<V: Sigint + Debug>(&mut self, _: u32, v: &V) {
        self._varint(v.encode())
    }

    /// Writes `b` as the varint `1` or `0`.
    pub fn bool(&mut self, _: u32, b: &bool) {
        self._varint(if *b { 1 } else { 0 });
    }

    /// Writes `v` converted to `u32` as a varint.
    pub fn protoenum<V: Clone + Copy + Into<u32>>(&mut self, _: u32, v: &V) {
        self._varint((*v).into());
    }

    /// Appends the in-memory bytes of `v.to_wire()`, with no length prefix.
    pub fn fixed<V: Fixed>(&mut self, _: u32, v: &V) {
        let wire = v.to_wire();
        // SAFETY: `wire` is a live local of type `V::Wire`, so the pointer is
        // valid for `size_of::<V::Wire>()` bytes while `slice` is in use.
        // Assumes `V::Wire` has no padding bytes (e.g. a plain integer) —
        // TODO confirm; viewing padding as `u8` would be UB.
        let slice = unsafe { from_raw_parts(&wire as *const V::Wire as *const u8, size_of::<V::Wire>()) };
        self.buf.extend_from_slice(slice);
    }

    /// Writes a 4-byte fixed-width value.
    pub fn fixed32<V: Fixed>(&mut self, _: u32, v: &V) {
        // Same size check as `InputStream::fixed32` performs when reading.
        debug_assert_eq!(size_of::<V::Wire>(), 4);
        self.fixed(0, v)
    }

    /// Writes an 8-byte fixed-width value.
    pub fn fixed64<V: Fixed>(&mut self, _: u32, v: &V) {
        // Same size check as `InputStream::fixed64` performs when reading.
        debug_assert_eq!(size_of::<V::Wire>(), 8);
        self.fixed(0, v)
    }

    /// Writes `b.buf()` prefixed by its length as a varint.
    pub fn bytes<'x, B: BytesLike<'x>>(&mut self, _: u32, b: &B) {
        self._varint(b.buf().len());
        self._bytes(b.buf());
    }

    /// Writes `b.buf()` prefixed by its length as a varint.
    pub fn string<'out, B: BytesLike<'out>>(&mut self, _: u32, b: &B) {
        self._varint(b.buf().len());
        self._bytes(b.buf());
    }

    /// Writes `len` as a varint followed by `v`'s encoding.
    #[inline(never)]
    fn _nested(&mut self, len: usize, v: &dyn BinProto<'_>) {
        self._varint(len);
        v.encode(self)
    }

    /// Writes a length-delimited nested message whose length is popped from
    /// the size stack.
    ///
    /// # Panics
    ///
    /// Panics if the top entry of the size stack does not refer to `v`.
    pub fn nested<'buf, P: BinProto<'buf>>(&mut self, _: u32, v: &P) {
        assert_eq!(v as *const P as *const u8, self.stack.top().0);
        let len = self.stack.pop().1;
        self._nested(len, v);
    }

    /// Writes the fields of `v` followed by an end-group tag.
    ///
    /// The end tag is `num` with its low three (wire-type) bits replaced by
    /// `EGRP`; `num` is presumably the group's start tag. The start tag itself
    /// is not written here.
    pub fn group<'buf, P: BinProto<'buf>>(&mut self, num: u32, v: &P) {
        v.encode(self);
        self._tag((num & !0b111) | EGRP as u32);
    }
}