use std::io::ErrorKind;
use imap_next::{Interrupt, Io, State};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::trace;
/// Async adapter that owns an I/O stream `S` together with a reusable read buffer.
///
/// The buffer is a fixed-size scratch area: each read fills at most `buf.len()` bytes.
pub struct Stream<S> {
    stream: S,
    /// Scratch buffer for reads; its length caps how many bytes a single read returns.
    buf: Vec<u8>,
}

impl<S> Stream<S> {
    /// Wraps `stream`, allocating a zero-filled 1024-byte read buffer.
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            // Zero-filled (not just reserved) so the full length is usable as a read target.
            buf: vec![0; 1024],
        }
    }

    /// Consumes the adapter and returns the wrapped stream, dropping the read buffer.
    pub fn into_inner(self) -> S {
        self.stream
    }
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
    /// Drives `state` until it yields its next event, performing stream I/O on demand.
    ///
    /// [`State::next`] is called repeatedly. Each interrupt is handled as follows:
    ///
    /// * [`Io::Output`]: the bytes are written in full and the stream is flushed. The
    ///   state is then polled again *without* reading first, because it may already
    ///   have more output or an event ready that does not depend on the peer.
    /// * [`Io::NeedMoreInput`]: one read is performed into the internal buffer, and the
    ///   received bytes are passed to [`State::enqueue_input`].
    ///
    /// # Errors
    ///
    /// * [`Error::State`] if the state reports an error interrupt.
    /// * [`Error::Closed`] if a read returns 0 bytes, or if `write_all` fails with
    ///   [`ErrorKind::WriteZero`].
    /// * [`Error::Io`] for any other error from writing, flushing or reading.
    pub async fn next<F: State>(&mut self, mut state: F) -> Result<F::Event, Error<F::Error>> {
        loop {
            // Return events immediately, without doing any I/O.
            let interrupt = match state.next() {
                Ok(event) => return Ok(event),
                Err(interrupt) => interrupt,
            };

            let io = match interrupt {
                Interrupt::Io(io) => io,
                Interrupt::Error(err) => return Err(Error::State(err)),
            };

            match io {
                Io::Output(ref bytes) => {
                    match self.stream.write_all(bytes).await {
                        Ok(()) => trace!("wrote {} bytes", bytes.len()),
                        Err(e) if e.kind() == ErrorKind::WriteZero => return Err(Error::Closed),
                        Err(e) => return Err(e.into()),
                    }
                    self.stream.flush().await?;
                    // Do not read here: blocking on the peer now could delay a ready
                    // event or deadlock if the state still has output queued.
                }
                Io::NeedMoreInput => {
                    trace!("more input needed");
                    match self.stream.read(&mut self.buf).await? {
                        // A zero-length read means the peer reached EOF.
                        0 => return Err(Error::Closed),
                        n => {
                            trace!("read {n}/{} bytes", self.buf.len());
                            state.enqueue_input(&self.buf[..n]);
                        }
                    }
                }
            }
        }
    }
}
#[derive(Debug, Error)]
pub enum Error<E> {
#[error("Stream was closed")]
Closed,
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
State(E),
}