use super::{Decoder, Packet};
use crate::traits::{DecodePartial, Partial, PartialIterator, ResumePartial, SeekSentinel};
use winnow::error::ContextError;
/// Builds a `winnow::error::ContextError` that carries `$label` as a
/// `winnow::error::StrContext::Label`.
///
/// * `$input` — an expression whose type implements `winnow::stream::Stream`.
///   It is evaluated exactly once and only borrowed, never moved out of.
/// * `$label` — the value wrapped in `StrContext::Label`.
///
/// A checkpoint is taken from the input and handed, together with the input,
/// to `AddContext::add_context` on a freshly created `ContextError`.
macro_rules! label_to_context_error {
    ($input:expr, $label:expr) => {{
        // Borrow the argument a single time: expanding `$input` in two places
        // would evaluate the caller's expression twice.
        let input = &$input;
        // Fully-qualified call, so no trait import has to be injected into the
        // caller's scope.
        let cp = winnow::stream::Stream::checkpoint(input);
        winnow::error::AddContext::add_context(
            winnow::error::ContextError::new(),
            input,
            &cp,
            winnow::error::StrContext::Label($label),
        )
    }};
}
/// Reasons a single decode step on the decoder can fail to produce a packet.
#[derive(Debug)]
pub enum DecodeIterError {
/// The decoder's buffer is empty (and, when resuming, no partial state is pending).
Eof,
/// A resumed decode returned an incomplete packet; the partial state is kept
/// in the decoder so a later call can continue from it.
NeedMore,
/// Decoding failed; carries the winnow context error describing the failure.
Malformed(ContextError),
}
// No methods are overridden: the `Error` impl relies on the `Debug` and `Display` impls.
impl std::error::Error for DecodeIterError {}
impl core::fmt::Display for DecodeIterError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
DecodeIterError::Eof => write!(f, "EOF"),
DecodeIterError::NeedMore => write!(f, "need more bytes"),
DecodeIterError::Malformed(e) => write!(f, "malformed KLV: {e}"),
}
}
}
/// Iterator adapter holding a mutable borrow of a `Decoder` for the lifetime `'a`.
pub struct DecoderIter<'a, P, S> {
// The borrowed decoder; each `next()` call on the iterator drives it.
pub(super) dec: &'a mut Decoder<P, S>,
}
/// Yields finalized packets by driving the borrowed decoder one step per call.
impl<P, T, S> Iterator for DecoderIter<'_, P, S>
where
    Decoder<P, S>: PartialIterator<P>,
    P: Partial<Final = T> + Default,
    S: winnow::stream::Stream,
{
    type Item = T;

    /// Resumes the pending partial packet if the decoder holds one, otherwise
    /// starts a fresh decode.
    ///
    /// Every error from the decoder (`Eof`, `NeedMore`, `Malformed`) is mapped
    /// to `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let outcome = if self.dec.partial.is_none() {
            self.dec.next_fresh()
        } else {
            self.dec.next_resume()
        };
        outcome.ok()
    }
}
/// Lets a mutably borrowed decoder be used directly in a `for` loop by
/// wrapping the borrow in a [`DecoderIter`].
impl<'d, P, T, S> IntoIterator for &'d mut Decoder<P, S>
where
    S: winnow::stream::Stream,
    P: Partial<Final = T> + Default,
    T: DecodePartial<S, Partial = P> + ResumePartial<S> + SeekSentinel<S>,
    Decoder<P, S>: PartialIterator<P>,
{
    type Item = T;
    type IntoIter = DecoderIter<'d, P, S>;

    /// Wraps the `&mut Decoder` without touching its state.
    fn into_iter(self) -> DecoderIter<'d, P, S> {
        let dec = self;
        DecoderIter { dec }
    }
}
/// Forwarding impl: a `&mut Decoder` over `&[u8]` decodes exactly like the
/// decoder it points to.
impl<'a, P, T> PartialIterator<P> for &mut Decoder<P, &'a [u8]>
where
    P: Partial<Final = T> + Default,
    Decoder<P, &'a [u8]>: PartialIterator<P>,
{
    /// Delegates to the underlying decoder's `next_resume`.
    #[inline(always)]
    fn next_resume(&mut self) -> Result<<P as Partial>::Final, DecodeIterError> {
        // Reborrow through both reference layers to reach the decoder itself.
        let inner: &mut Decoder<P, &'a [u8]> = &mut **self;
        <Decoder<P, &'a [u8]> as PartialIterator<P>>::next_resume(inner)
    }

    /// Delegates to the underlying decoder's `next_fresh`.
    #[inline(always)]
    fn next_fresh(&mut self) -> Result<<P as Partial>::Final, DecodeIterError> {
        let inner: &mut Decoder<P, &'a [u8]> = &mut **self;
        <Decoder<P, &'a [u8]> as PartialIterator<P>>::next_fresh(inner)
    }
}
/// Packet extraction for a [`Decoder`] whose buffered bytes are parsed as `&[u8]`.
impl<P, T> PartialIterator<P> for Decoder<P, &[u8]>
where
    P: Partial<Final = T> + Default,
    for<'a> T:
        DecodePartial<&'a [u8], Partial = P> + ResumePartial<&'a [u8]> + SeekSentinel<&'a [u8]>,
{
    /// Continues decoding from the stored partial state (or a default one if
    /// none is stored) using the bytes currently buffered.
    ///
    /// The bytes the parser advanced over are removed from the buffer whether
    /// or not the step succeeds.
    ///
    /// # Errors
    /// * [`DecodeIterError::Eof`] — the buffer is empty and no partial state is pending.
    /// * [`DecodeIterError::NeedMore`] — the packet is still incomplete; the
    ///   new partial state is stored for the next call.
    /// * [`DecodeIterError::Malformed`] — `resume_partial` reported an error label.
    fn next_resume(&mut self) -> Result<T, DecodeIterError> {
        if self.buf.is_empty() && self.partial.is_none() {
            return Err(DecodeIterError::Eof);
        }
        let partial = self.partial.take().unwrap_or_default();
        let mut cursor: &[u8] = self.buf.as_slice();
        let before = cursor.len();
        let packet = <T as ResumePartial<&[u8]>>::resume_partial(&mut cursor, partial);
        let consumed = before - cursor.len();
        // Build the error (which slices the buffer) BEFORE draining, while the
        // consumed bytes are still present.
        let result = packet.map_err(|label| {
            DecodeIterError::Malformed(label_to_context_error!(&self.buf[..consumed], label))
        });
        self.buf.drain(..consumed);
        match result? {
            Packet::Ready(pkt) => Ok(pkt),
            Packet::NeedMore(p) => {
                // Keep the state so the next call can pick up where this one stopped.
                self.partial = Some(p);
                Err(DecodeIterError::NeedMore)
            }
        }
    }

    /// Starts decoding a new packet: seeks the sentinel-delimited body in the
    /// buffer, decodes it, and finalizes it if the decoder only produced a
    /// partial result.
    ///
    /// If the sentinel seek fails, nothing is removed from the buffer.
    /// Otherwise the bytes the seek advanced over are removed, whether or not
    /// decoding succeeds.
    ///
    /// # Errors
    /// * [`DecodeIterError::Eof`] — the buffer is empty.
    /// * [`DecodeIterError::Malformed`] — the sentinel seek failed, or
    ///   `decode_partial` / `finalize` reported an error label.
    fn next_fresh(&mut self) -> Result<T, DecodeIterError> {
        if self.buf.is_empty() {
            return Err(DecodeIterError::Eof);
        }
        let mut cursor: &[u8] = self.buf.as_slice();
        let before = cursor.len();
        let body = T::seek_sentinel(&mut cursor).map_err(DecodeIterError::Malformed)?;
        let mut body_cursor: &[u8] = body;
        let packet = T::decode_partial(&mut body_cursor);
        let consumed = before - cursor.len();
        // Resolve the complete outcome — including `finalize` — before the
        // buffer is drained. Slicing `self.buf[..consumed]` after the drain
        // would index bytes that are no longer there and panic whenever the
        // consumed span is longer than what remains in the buffer.
        let result = match packet {
            Ok(Packet::Ready(pkt)) => Ok(pkt),
            Ok(Packet::NeedMore(p)) => p.finalize().map_err(|label| {
                DecodeIterError::Malformed(label_to_context_error!(&self.buf[..consumed], label))
            }),
            Err(label) => Err(DecodeIterError::Malformed(label_to_context_error!(
                &self.buf[..consumed],
                label
            ))),
        };
        self.buf.drain(..consumed);
        result
    }
}