use crate::{
Encoding, Status,
encoding::DEFAULT_MAX_MESSAGE_SIZE,
frame::{
reader::{ReadState, poll_read_message},
writer::encode_payload,
},
};
use bytes::Bytes;
use futures_lite::{AsyncRead, AsyncWriteExt, Stream};
use std::{
future::poll_fn,
pin::Pin,
task::{Context, Poll},
};
use trillium::{Headers, Upgrade};
/// Stream of decoded inbound messages read from a boxed [`AsyncRead`].
///
/// Items are produced by `poll_read_message`, either through
/// [`RequestStream::recv`] or through the [`Stream`] implementation.
pub struct RequestStream<'a, T> {
    /// Source of the raw inbound bytes.
    reader: Pin<Box<dyn AsyncRead + Send + 'a>>,
    /// Reader state from `frame::reader`, kept between polls.
    state: ReadState,
    /// Turns a message's bytes into a `T`, or fails with a `Status`.
    decode: fn(&[u8]) -> Result<T, Status>,
    /// Encoding handed to `poll_read_message` for inbound messages
    /// (presumably the negotiated message encoding/compression — TODO confirm).
    encoding: Encoding,
    /// Size limit handed to `poll_read_message`; `new` sets it to
    /// `DEFAULT_MAX_MESSAGE_SIZE`.
    max_message_size: usize,
}
impl<'a, T> RequestStream<'a, T> {
    /// Wraps `reader` so inbound messages can be pulled out one at a time.
    ///
    /// `decode` converts each message's bytes into a `T`, and `encoding` is the
    /// inbound encoding forwarded to `poll_read_message`. The size limit starts
    /// out as `DEFAULT_MAX_MESSAGE_SIZE`.
    pub(crate) fn new(
        reader: Pin<Box<dyn AsyncRead + Send + 'a>>,
        decode: fn(&[u8]) -> Result<T, Status>,
        encoding: Encoding,
    ) -> Self {
        let state = ReadState::new();
        Self {
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
            encoding,
            decode,
            state,
            reader,
        }
    }

    /// Polls `poll_read_message` once with this stream's reader, state and settings.
    fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, Status>>> {
        let Self {
            reader,
            state,
            decode,
            encoding,
            max_message_size,
        } = self;
        poll_read_message(
            Pin::as_mut(reader),
            state,
            cx,
            *decode,
            *encoding,
            *max_message_size,
        )
    }

    /// Waits for the next message.
    ///
    /// Yields `Ok(Some(message))` for a decoded message, `Ok(None)` once
    /// `poll_read_message` reports no further items, and `Err(status)` on failure.
    pub async fn recv(&mut self) -> Result<Option<T>, Status> {
        // `Option<Result<_, _>>` -> `Result<Option<_>, _>` for the ready case.
        poll_fn(|cx| self.poll_message(cx).map(Option::transpose)).await
    }
}
/// Lets a [`RequestStream`] be consumed with stream combinators.
///
/// Each item is the next decoded message, or the `Status` describing why
/// reading or decoding failed.
///
/// No `T: 'static` bound is required: `T` only appears inside a function-pointer
/// type, so whether `RequestStream` is `Unpin` (needed by `Pin::get_mut`) does
/// not depend on `T`, and `poll_message` itself places no bound on `T`.
impl<T> Stream for RequestStream<'_, T> {
    type Item = Result<T, Status>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.get_mut().poll_message(cx)
    }
}
/// Bidirectional message channel over a trillium [`Upgrade`]: `recv` decodes
/// `Req` values read from it and `send` writes encoded `Resp` values to it.
pub struct Channel<'a, Req, Resp> {
    /// Transport used for both reading and writing.
    upgrade: &'a mut Upgrade,
    /// Trailers exposed through `response_trailers_mut`.
    response_trailers: &'a mut Headers,
    /// Reader state from `frame::reader`, kept between `recv` calls.
    state: ReadState,
    /// Decoder for inbound message bytes.
    decode: fn(&[u8]) -> Result<Req, Status>,
    /// Encoder producing the payload bytes for an outbound message.
    encode: fn(&Resp) -> Result<Bytes, Status>,
    /// Encoding used when reading messages in `recv`.
    inbound_encoding: Encoding,
    /// Encoding used when framing messages in `send`.
    outbound_encoding: Encoding,
    /// Size limit passed to `poll_read_message` for inbound messages; `new`
    /// sets it to `DEFAULT_MAX_MESSAGE_SIZE`.
    max_message_size: usize,
}
impl<'a, Req, Resp> Channel<'a, Req, Resp> {
pub(crate) fn new(
upgrade: &'a mut Upgrade,
response_trailers: &'a mut Headers,
decode: fn(&[u8]) -> Result<Req, Status>,
encode: fn(&Resp) -> Result<Bytes, Status>,
inbound_encoding: Encoding,
outbound_encoding: Encoding,
) -> Self {
Self {
upgrade,
response_trailers,
state: ReadState::new(),
decode,
encode,
inbound_encoding,
outbound_encoding,
max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
}
}
pub fn response_trailers_mut(&mut self) -> &mut Headers {
self.response_trailers
}
pub async fn recv(&mut self) -> Option<Result<Req, Status>> {
let upgrade = &mut *self.upgrade;
let state = &mut self.state;
let decode = self.decode;
let encoding = self.inbound_encoding;
let max = self.max_message_size;
poll_fn(|cx| poll_read_message(Pin::new(&mut *upgrade), state, cx, decode, encoding, max))
.await
}
pub async fn send(&mut self, value: Resp) -> Result<(), Status> {
let payload = (self.encode)(&value)?;
let frame = encode_payload(&payload, self.outbound_encoding)?;
self.upgrade
.write_all(&frame)
.await
.map_err(|e| Status::unavailable(format!("write error: {e}")))
}
}