use std::error::Error as StdError;
use std::io;
use dangerous::{BytesReader, Error, Expected, Input, Invalid, ToRetryRequirement};
/// A well-formed message: version byte `0x01`, body length `5`, body `"hello"`.
const VALID_MESSAGE: &[u8] = b"\x01\x05hello";
/// A message with a valid header (version `0x01`, body length `5`) whose body
/// contains the byte `0xff`, which can never appear in valid UTF-8.
const INVALID_MESSAGE: &[u8] = b"\x01\x05he\xfflo";
/// A decoded message whose body borrows from the buffer it was decoded from.
///
/// It only holds a shared `&str`, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Message<'a> {
    /// The UTF-8 body of the message.
    body: &'a str,
}
/// Decodes one valid and one invalid message with the same decoder,
/// printing the body of the first and the error of the second.
fn main() {
    let mut decoder = Decoder::new();

    // Happy path: the body is printed to stdout.
    let mut valid_stream = Stream::new(VALID_MESSAGE);
    let decoded = decoder.read_and_decode_message(&mut valid_stream).unwrap();
    println!("{}", decoded.body);

    // Failure path: the decode error is printed to stderr.
    let mut invalid_stream = Stream::new(INVALID_MESSAGE);
    let failure = decoder.read_and_decode_message(&mut invalid_stream).unwrap_err();
    eprintln!("error reading message: {}", failure);
}
/// Largest encoded message: one version byte, one length byte, and a body of
/// at most `u8::MAX` bytes (the body length is a single `u8`).
const MAX_MESSAGE_LEN: usize = 2 + u8::MAX as usize;

/// Reads a message from an [`io::Read`] source into a fixed internal buffer
/// and decodes it in place, so the decoded [`Message`] borrows that buffer.
pub struct Decoder {
    buf: [u8; MAX_MESSAGE_LEN],
}

impl Decoder {
    /// Creates a decoder whose zeroed buffer can hold any valid message.
    fn new() -> Self {
        Self {
            buf: [0u8; MAX_MESSAGE_LEN],
        }
    }

    /// Reads from `read` until the buffered bytes either decode as a complete
    /// message or are rejected as invalid, then returns the decode result.
    ///
    /// # Errors
    ///
    /// - any error from `read` other than `ErrorKind::Interrupted`, which is retried;
    /// - `ErrorKind::UnexpectedEof` if `read` returns `Ok(0)` before a message is complete;
    /// - `ErrorKind::InvalidData` if the buffer fills while the decoder still wants more;
    /// - a boxed [`Expected`] error if the buffered bytes are not a valid message.
    fn read_and_decode_message<'i, R>(
        &'i mut self,
        mut read: R,
    ) -> Result<Message<'i>, Box<dyn StdError + 'i>>
    where
        R: io::Read,
    {
        let mut written_cur = 0;
        let mut expects_cur = 0;
        loop {
            // Reading into an empty slice yields `Ok(0)`, which would be
            // indistinguishable from end-of-stream (and loop forever).
            if written_cur == self.buf.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "message does not fit in the decoder buffer",
                )
                .into());
            }
            match read.read(&mut self.buf[written_cur..]) {
                Ok(0) => {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "stream ended before a complete message was read",
                    )
                    .into());
                }
                Ok(n) => written_cur += n,
                // `Interrupted` is transient per the `io::Read` contract; retry.
                Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Err(err.into()),
            }
            // Skip re-decoding until the bytes the last attempt asked for arrive.
            if expects_cur > written_cur {
                println!(
                    "not enough to decode, waiting for {} bytes",
                    expects_cur - written_cur
                );
                continue;
            }
            // Trial decode with `Invalid`; only its retry requirement is used.
            let input = dangerous::input(&self.buf[..written_cur]);
            match input.read_all(decode_message) {
                Err(err) => match Invalid::to_retry_requirement(&err) {
                    // `continue_after` counts bytes needed beyond the input just
                    // decoded, i.e. beyond everything buffered so far.
                    Some(req) => expects_cur = written_cur + req.continue_after(),
                    // Not recoverable by reading more: report it below.
                    None => break,
                },
                Ok(_) => break,
            }
        }
        // Decode once more with `Expected` as the error type and return that.
        // NOTE(review): presumably done instead of returning the loop's result
        // because that borrow would conflict with the loop's mutable use of
        // `self.buf` under current borrow-checker rules.
        dangerous::input(&self.buf[..written_cur])
            .read_all(decode_message)
            .map_err(Box::<Expected>::into)
    }
}
/// Decodes a single message from `r`.
///
/// Wire format: a version byte that must be `0x01`, a one-byte body length,
/// then that many body bytes, which must form a valid UTF-8 string. Each
/// field is decoded inside a named context ("version", "body len", "body")
/// nested under "message". The returned [`Message`] borrows its body from
/// the reader's input.
fn decode_message<'i, E>(r: &mut BytesReader<'i, E>) -> Result<Message<'i>, E>
where
    E: Error<'i>,
{
    r.context("message", |msg| {
        // The version byte must be exactly 0x01.
        msg.context("version", |ver| ver.consume(0x01))?;
        // The body is prefixed by its length as a single byte.
        let len = usize::from(msg.context("body len", BytesReader::read_u8)?);
        let body = msg.context("body", |field| {
            let raw = field.take(len)?;
            raw.to_dangerous_str()
        })?;
        Ok(Message { body })
    })
}
/// In-memory [`io::Read`] source that hands out `src` one byte per `read`
/// call, so callers see the input arrive incrementally.
struct Stream {
    /// Index of the next byte of `src` to hand out.
    cur: usize,
    /// The bytes this stream serves.
    src: &'static [u8],
}

impl Stream {
    /// Creates a stream positioned at the start of `src`.
    fn new(src: &'static [u8]) -> Self {
        Self { src, cur: 0 }
    }
}

impl io::Read for Stream {
    /// Copies at most one byte of `src` into `buf`.
    ///
    /// Once every byte has been handed out, returns a `NotConnected` error.
    /// An empty `buf` yields `Ok(0)`, as the `io::Read` contract allows,
    /// instead of panicking on `buf[0]`.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        if self.cur == self.src.len() {
            return Err(io::Error::from(io::ErrorKind::NotConnected));
        }
        if buf.is_empty() {
            return Ok(0);
        }
        buf[0] = self.src[self.cur];
        self.cur += 1;
        Ok(1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::size_of;

    /// `Message` holds only a borrowed `&str`, so it should stay small.
    #[test]
    fn test_message_size() {
        let size = size_of::<Message<'_>>();
        assert!(size < 128);
    }
}