use std::pin::Pin;
use std::task::{Context, Poll};
use anyhow::Result;
use bytes::Bytes;
use futures_core::Stream;
use serde::Serialize;
use tokio::sync::mpsc;
/// Kind of a stream frame; its discriminant is the tag byte stored at offset 0
/// of a stream frame header.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameType {
    Request = 0,
    Response = 1,
    StreamItem = 2,
    StreamEnd = 3,
    StreamError = 4,
}
impl FrameType {
    /// Converts a raw tag byte back into a [`FrameType`].
    ///
    /// Returns `None` for any byte that is not one of the discriminants `0..=4`.
    pub fn from_u8(v: u8) -> Option<Self> {
        // Position `i` holds the variant whose discriminant is `i`, so the tag
        // byte can be used directly as an index.
        const BY_TAG: [FrameType; 5] = [
            FrameType::Request,
            FrameType::Response,
            FrameType::StreamItem,
            FrameType::StreamEnd,
            FrameType::StreamError,
        ];
        BY_TAG.get(usize::from(v)).copied()
    }
}
/// Receiving end of an RPC stream that yields `Ok(item)` values or `Err(message)`
/// errors, optionally paired with a [`StreamSender`] for the opposite direction.
pub struct RpcStream<T> {
    // Incoming items; an `Err` carries the error as a `String`.
    rx: mpsc::Receiver<Result<T, String>>,
    // Outgoing typed sender; `None` for streams built with `RpcStream::new`.
    tx: Option<StreamSender<T>>,
    // NOTE(review): looks redundant, since `rx` already mentions `T`; confirm nothing
    // relies on it before removing.
    _marker: std::marker::PhantomData<T>,
}
/// Typed sending half of a bidirectional [`RpcStream`].
///
/// Each value is serialized with `crate::postcard::to_allocvec` and pushed onto an
/// outgoing `Bytes` channel. `T` is only a type-level marker; no `T` is stored.
pub struct StreamSender<T> {
    inner: mpsc::Sender<Bytes>,
    _marker: std::marker::PhantomData<T>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive would add a `T: Clone`
// bound, although cloning only duplicates the inner channel sender.
impl<T> Clone for StreamSender<T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

impl<T: Serialize> StreamSender<T> {
    /// Serializes `item` and sends the bytes on the outgoing channel.
    ///
    /// # Errors
    ///
    /// Fails if serialization fails. Fails with "Stream closed" if the channel's
    /// receiver is gone.
    pub async fn send(&self, item: T) -> Result<()> {
        let bytes = crate::postcard::to_allocvec(&item)?;
        self.send_encoded(bytes).await
    }

    /// Serializes `Err::<(), String>(error)` and sends the bytes on the outgoing channel.
    ///
    /// # Errors
    ///
    /// Same conditions as [`StreamSender::send`].
    pub async fn send_error(&self, error: String) -> Result<()> {
        // NOTE(review): `send` writes a bare `T`, while this writes a `Result::Err`.
        // Confirm the decoder tells them apart (e.g. via the frame type) rather than
        // decoding every payload as `Result<T, String>`.
        let bytes = crate::postcard::to_allocvec(&Err::<(), _>(error))?;
        self.send_encoded(bytes).await
    }

    /// Pushes already-serialized bytes onto the channel, waiting for capacity.
    async fn send_encoded(&self, bytes: impl Into<Bytes>) -> Result<()> {
        self.inner
            .send(bytes.into())
            .await
            .map_err(|_| anyhow::anyhow!("Stream closed"))
    }
}
impl<T> RpcStream<T> {
    /// Wraps an incoming channel as a receive-only stream with no outgoing sender.
    pub fn new(rx: mpsc::Receiver<Result<T, String>>) -> Self {
        let tx = None;
        Self {
            rx,
            tx,
            _marker: std::marker::PhantomData,
        }
    }

    /// Wraps an incoming channel and an outgoing byte channel.
    ///
    /// The outgoing side is exposed through [`RpcStream::sender`].
    pub fn bidirectional(rx: mpsc::Receiver<Result<T, String>>, tx: mpsc::Sender<Bytes>) -> Self {
        let outgoing = StreamSender {
            inner: tx,
            _marker: std::marker::PhantomData,
        };
        // Reuse the receive-only constructor and fill in the sender slot.
        Self {
            tx: Some(outgoing),
            ..Self::new(rx)
        }
    }

    /// Borrows the outgoing sender, if this stream was created with `bidirectional`.
    pub fn sender(&self) -> Option<&StreamSender<T>> {
        self.tx.as_ref()
    }

    /// Awaits the next incoming item.
    ///
    /// Returns `None` once the channel is closed and drained.
    pub async fn next(&mut self) -> Option<Result<T, String>> {
        self.rx.recv().await
    }
}
/// Lets an [`RpcStream`] be driven as a `futures_core::Stream`; polling forwards
/// straight to the underlying channel receiver.
impl<T: Unpin> Stream for RpcStream<T> {
    type Item = Result<T, String>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The `T: Unpin` bound makes `Self: Unpin`, so the pin can be unwrapped to a
        // plain mutable reference.
        let stream = Pin::into_inner(self);
        stream.rx.poll_recv(cx)
    }
}
/// Producer half of an in-process [`RpcStream`], created by [`StreamBuilder::new`].
///
/// Cloning a builder yields another producer feeding the same stream.
pub struct StreamBuilder<T> {
    tx: mpsc::Sender<Result<T, String>>,
}

impl<T> StreamBuilder<T> {
    /// Creates a bounded channel and returns its producer and consumer halves.
    ///
    /// `buffer_size` is the channel capacity. Tokio's bounded channel panics on a
    /// capacity of zero, so a `buffer_size` of `0` is treated as `1` instead.
    pub fn new(buffer_size: usize) -> (Self, RpcStream<T>) {
        let (tx, rx) = mpsc::channel(buffer_size.max(1));
        (Self { tx }, RpcStream::new(rx))
    }

    /// Sends `item` to the stream as `Ok(item)`, waiting if the buffer is full.
    ///
    /// # Errors
    ///
    /// Returns "Stream receiver dropped" if the receiving `RpcStream` is gone.
    pub async fn send(&self, item: T) -> Result<()> {
        self.deliver(Ok(item)).await
    }

    /// Sends `err` to the stream as `Err(err.to_string())`, waiting if the buffer is full.
    ///
    /// # Errors
    ///
    /// Returns "Stream receiver dropped" if the receiving `RpcStream` is gone.
    pub async fn error(&self, err: impl std::fmt::Display) -> Result<()> {
        self.deliver(Err(err.to_string())).await
    }

    /// Returns a clone of the raw channel sender for callers that need it directly.
    pub fn sender(&self) -> mpsc::Sender<Result<T, String>> {
        self.tx.clone()
    }

    /// Shared send path for [`StreamBuilder::send`] and [`StreamBuilder::error`].
    async fn deliver(&self, msg: Result<T, String>) -> Result<()> {
        self.tx
            .send(msg)
            .await
            .map_err(|_| anyhow::anyhow!("Stream receiver dropped"))
    }
}

// Manual impl so cloning does not require `T: Clone`; only the sender is cloned.
impl<T> Clone for StreamBuilder<T> {
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
        }
    }
}
/// Size in bytes of a stream frame header: a 1-byte frame type, a 2-byte method id,
/// an 8-byte request id and a 4-byte payload length.
pub const STREAM_HEADER_SIZE: usize = 1 + 2 + 8 + 4;
/// Serializes the fixed-size header of a stream frame.
///
/// Layout (multi-byte fields little-endian):
/// byte 0 = frame type tag, bytes 1..3 = `method_id`, bytes 3..11 = `request_id`,
/// bytes 11..15 = `payload_len`.
pub fn encode_stream_header(
    frame_type: FrameType,
    method_id: u16,
    request_id: u64,
    payload_len: u32,
) -> [u8; STREAM_HEADER_SIZE] {
    // Fields in wire order; each is copied into the next free span of the header.
    let fields: [&[u8]; 4] = [
        &[frame_type as u8],
        &method_id.to_le_bytes(),
        &request_id.to_le_bytes(),
        &payload_len.to_le_bytes(),
    ];
    let mut out = [0u8; STREAM_HEADER_SIZE];
    let mut cursor = 0;
    for field in fields.iter() {
        let end = cursor + field.len();
        out[cursor..end].copy_from_slice(field);
        cursor = end;
    }
    out
}
pub fn decode_stream_header(
header: &[u8; STREAM_HEADER_SIZE],
) -> Option<(FrameType, u16, u64, u32)> {
let frame_type = FrameType::from_u8(header[0])?;
let method_id = u16::from_le_bytes([header[1], header[2]]);
let request_id = u64::from_le_bytes([
header[3], header[4], header[5], header[6], header[7], header[8], header[9], header[10],
]);
let payload_len = u32::from_le_bytes([header[11], header[12], header[13], header[14]]);
Some((frame_type, method_id, request_id, payload_len))
}