jetstream_rpc/server.rs

1use std::pin::pin;
2
3use crate::{Error, Frame, Protocol};
4use futures::{Sink, Stream};
5use jetstream_wireformat::WireFormat;
6use tokio_util::{
7    bytes::{self, Buf, BufMut},
8    codec::{Decoder, Encoder},
9};
10
/// Server-side [`Decoder`]/[`Encoder`] for protocol `P`.
///
/// Decodes incoming request frames (`Frame<P::Request>`) and encodes outgoing
/// response frames (`Frame<P::Response>`). The codec carries no state of its own.
pub struct ServerCodec<P> {
    // `fn() -> P` instead of `P`: the codec never owns a `P`, so its auto traits
    // (`Send`, `Sync`, `Unpin`) should not be inherited from the protocol type.
    // The `P: Protocol` bound lives on the impls that need it, not on the struct.
    _phantom: std::marker::PhantomData<fn() -> P>,
}

15impl<P: Protocol> ServerCodec<P> {
16    pub fn new() -> Self {
17        Self {
18            _phantom: std::marker::PhantomData,
19        }
20    }
21}
22
23impl<P: Protocol> Default for ServerCodec<P> {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29pub trait ServiceTransport<P: Protocol>:
30    Sink<Frame<P::Response>, Error = P::Error>
31    + Stream<Item = Result<Frame<P::Request>, P::Error>>
32    + Send
33    + Sync
34    + Unpin
35{
36}
37
38impl<P: Protocol, T> ServiceTransport<P> for T where
39    T: Sink<Frame<P::Response>, Error = P::Error>
40        + Stream<Item = Result<Frame<P::Request>, P::Error>>
41        + Send
42        + Sync
43        + Unpin
44{
45}
46
47impl<P> Decoder for ServerCodec<P>
48where
49    P: Protocol,
50{
51    type Error = Error;
52    type Item = Frame<P::Request>;
53
54    fn decode(
55        &mut self,
56        src: &mut bytes::BytesMut,
57    ) -> Result<Option<Self::Item>, Self::Error> {
58        // check to see if you have at least 4 bytes to figure out the size
59        if src.len() < 4 {
60            src.reserve(4);
61            return Ok(None);
62        }
63        let Some(mut bytz) = src.get(..4) else {
64            return Ok(None);
65        };
66
67        let byte_size: u32 = WireFormat::decode(&mut bytz)?;
68        if src.len() < byte_size as usize {
69            src.reserve(byte_size as usize);
70            return Ok(None);
71        }
72
73        Frame::<P::Request>::decode(&mut src.reader())
74            .map(Some)
75            .map_err(|_| Error::Custom("()".to_string()))
76    }
77}
78
79impl<P> Encoder<Frame<P::Response>> for ServerCodec<P>
80where
81    P: Protocol,
82{
83    type Error = Error;
84
85    fn encode(
86        &mut self,
87        item: Frame<P::Response>,
88        dst: &mut bytes::BytesMut,
89    ) -> Result<(), Self::Error> {
90        item.encode(&mut dst.writer())
91            .map_err(|_| Error::Custom("()".to_string()))
92            .map(|_| ())
93    }
94}
95
96pub async fn run<T, P>(p: &mut P, mut stream: T) -> Result<(), P::Error>
97where
98    T: ServiceTransport<P>,
99    P: Protocol,
100{
101    use futures::{SinkExt, StreamExt};
102    let mut a = pin!(p);
103    while let Some(Ok(frame)) = stream.next().await {
104        stream.send(a.rpc(frame).await?).await?
105    }
106    Ok(())
107}