1use std::pin::pin;
2
3use crate::{Error, Frame, Protocol};
4use futures::{Sink, Stream};
5use jetstream_wireformat::WireFormat;
6use tokio_util::{
7 bytes::{self, Buf, BufMut},
8 codec::{Decoder, Encoder},
9};
10
11pub struct ServerCodec<P: Protocol> {
12 _phantom: std::marker::PhantomData<P>,
13}
14
15impl<P: Protocol> ServerCodec<P> {
16 pub fn new() -> Self {
17 Self {
18 _phantom: std::marker::PhantomData,
19 }
20 }
21}
22
23impl<P: Protocol> Default for ServerCodec<P> {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29pub trait ServiceTransport<P: Protocol>:
30 Sink<Frame<P::Response>, Error = P::Error>
31 + Stream<Item = Result<Frame<P::Request>, P::Error>>
32 + Send
33 + Sync
34 + Unpin
35{
36}
37
38impl<P: Protocol, T> ServiceTransport<P> for T where
39 T: Sink<Frame<P::Response>, Error = P::Error>
40 + Stream<Item = Result<Frame<P::Request>, P::Error>>
41 + Send
42 + Sync
43 + Unpin
44{
45}
46
47impl<P> Decoder for ServerCodec<P>
48where
49 P: Protocol,
50{
51 type Error = Error;
52 type Item = Frame<P::Request>;
53
54 fn decode(
55 &mut self,
56 src: &mut bytes::BytesMut,
57 ) -> Result<Option<Self::Item>, Self::Error> {
58 if src.len() < 4 {
60 src.reserve(4);
61 return Ok(None);
62 }
63 let Some(mut bytz) = src.get(..4) else {
64 return Ok(None);
65 };
66
67 let byte_size: u32 = WireFormat::decode(&mut bytz)?;
68 if src.len() < byte_size as usize {
69 src.reserve(byte_size as usize);
70 return Ok(None);
71 }
72
73 Frame::<P::Request>::decode(&mut src.reader())
74 .map(Some)
75 .map_err(|_| Error::Custom("()".to_string()))
76 }
77}
78
79impl<P> Encoder<Frame<P::Response>> for ServerCodec<P>
80where
81 P: Protocol,
82{
83 type Error = Error;
84
85 fn encode(
86 &mut self,
87 item: Frame<P::Response>,
88 dst: &mut bytes::BytesMut,
89 ) -> Result<(), Self::Error> {
90 item.encode(&mut dst.writer())
91 .map_err(|_| Error::Custom("()".to_string()))
92 .map(|_| ())
93 }
94}
95
96pub async fn run<T, P>(p: &mut P, mut stream: T) -> Result<(), P::Error>
97where
98 T: ServiceTransport<P>,
99 P: Protocol,
100{
101 use futures::{SinkExt, StreamExt};
102 let mut a = pin!(p);
103 while let Some(Ok(frame)) = stream.next().await {
104 stream.send(a.rpc(frame).await?).await?
105 }
106 Ok(())
107}