volans_request/codec/protobuf.rs

1use std::{io, marker::PhantomData};
2
3use async_trait::async_trait;
4use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5use volans_swarm::StreamProtocol;
6
7use crate::Codec;
8
/// A request/response [`Codec`] that exchanges one protobuf message per stream,
/// using `prost` for encoding and decoding.
///
/// Messages are written as their raw protobuf encoding with no length prefix;
/// the reading side consumes the stream to EOF, bounded by the configured
/// size maximum.
pub struct ProtobufCodec<Req, Resp> {
    /// Maximum number of bytes read from the stream for an inbound request.
    request_size_maximum: u64,
    /// Maximum number of bytes read from the stream for an inbound response.
    response_size_maximum: u64,
    /// Ties the codec to its message types without storing any values of them.
    phantom: PhantomData<(Req, Resp)>,
}

// `Clone` and `Debug` are implemented by hand instead of derived: `#[derive]`
// would require `Req`/`Resp` to be `Clone`/`Debug` themselves, even though the
// type parameters only appear inside `PhantomData`.
impl<Req, Resp> Clone for ProtobufCodec<Req, Resp> {
    fn clone(&self) -> Self {
        Self {
            request_size_maximum: self.request_size_maximum,
            response_size_maximum: self.response_size_maximum,
            phantom: PhantomData,
        }
    }
}

impl<Req, Resp> std::fmt::Debug for ProtobufCodec<Req, Resp> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProtobufCodec")
            .field("request_size_maximum", &self.request_size_maximum)
            .field("response_size_maximum", &self.response_size_maximum)
            .field("phantom", &self.phantom)
            .finish()
    }
}
15
16impl<Req, Resp> Default for ProtobufCodec<Req, Resp> {
17    fn default() -> Self {
18        ProtobufCodec {
19            request_size_maximum: 1024 * 1024,
20            response_size_maximum: 10 * 1024 * 1024,
21            phantom: PhantomData,
22        }
23    }
24}
25
26impl<Req, Resp> ProtobufCodec<Req, Resp> {
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    pub fn request_size_maximum(mut self, size: u64) -> Self {
32        self.request_size_maximum = size;
33        self
34    }
35
36    pub fn response_size_maximum(mut self, size: u64) -> Self {
37        self.response_size_maximum = size;
38        self
39    }
40}
41
42#[async_trait]
43impl<Req, Resp> Codec for ProtobufCodec<Req, Resp>
44where
45    Req: prost::Message + Send + Default,
46    Resp: prost::Message + Send + Default,
47{
48    type Protocol = StreamProtocol;
49    type Request = Req;
50    type Response = Resp;
51
52    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
53    where
54        T: AsyncRead + Unpin + Send,
55    {
56        let mut buffer = Vec::new();
57        io.take(self.request_size_maximum)
58            .read_to_end(&mut buffer)
59            .await?;
60        Ok(prost::Message::decode(buffer.as_slice())?)
61    }
62    async fn read_response<T>(
63        &mut self,
64        _: &Self::Protocol,
65        io: &mut T,
66    ) -> io::Result<Self::Response>
67    where
68        T: AsyncRead + Unpin + Send,
69    {
70        let mut buffer = Vec::new();
71        io.take(self.response_size_maximum)
72            .read_to_end(&mut buffer)
73            .await?;
74        Ok(prost::Message::decode(buffer.as_slice())?)
75    }
76    async fn write_request<T>(
77        &mut self,
78        _: &Self::Protocol,
79        io: &mut T,
80        request: Self::Request,
81    ) -> io::Result<()>
82    where
83        T: AsyncWrite + Unpin + Send,
84    {
85        let data = request.encode_to_vec();
86        io.write_all(&data).await?;
87        Ok(())
88    }
89    async fn write_response<T>(
90        &mut self,
91        _: &Self::Protocol,
92        io: &mut T,
93        response: Self::Response,
94    ) -> io::Result<()>
95    where
96        T: AsyncWrite + Unpin + Send,
97    {
98        let data = response.encode_to_vec();
99        io.write_all(&data).await?;
100        Ok(())
101    }
102}