1mod decoder;
5use std::{io, marker::PhantomData, time::Duration};
6
7use async_trait::async_trait;
8use decoder::DagCborDecodingReader;
9use futures::prelude::*;
10use libp2p::request_response::{self, OutboundFailure};
11use serde::{Serialize, de::DeserializeOwned};
12
/// Stateless libp2p request-response codec, generic over the protocol name
/// type `P`, the request type `RQ` and the response type `RS`.
///
/// It holds no data: the `PhantomData` fields only record the type parameters.
pub struct CborRequestResponse<P, RQ, RS> {
    protocol: PhantomData<P>,
    request: PhantomData<RQ>,
    response: PhantomData<RS>,
}

// `#[derive(Clone)]` would require `P: Clone, RQ: Clone, RS: Clone` even though
// only `PhantomData` is stored. This hand-written impl is clonable for any
// type parameters, so request/response types need not implement `Clone`.
impl<P, RQ, RS> Clone for CborRequestResponse<P, RQ, RS> {
    fn clone(&self) -> Self {
        Self {
            protocol: PhantomData,
            request: PhantomData,
            response: PhantomData,
        }
    }
}
22
23impl<P, RQ, RS> Default for CborRequestResponse<P, RQ, RS> {
24 fn default() -> Self {
25 Self {
26 protocol: PhantomData::<P>,
27 request: PhantomData::<RQ>,
28 response: PhantomData::<RS>,
29 }
30 }
31}
32
33#[derive(Debug, thiserror::Error)]
40pub enum RequestResponseError {
41 #[error("DialFailure")]
43 DialFailure,
44 #[error("Timeout")]
49 Timeout,
50 #[error("ConnectionClosed")]
55 ConnectionClosed,
56 #[error("UnsupportedProtocols")]
58 UnsupportedProtocols,
59 #[error("{0}")]
61 Io(io::Error),
62}
63
64impl From<OutboundFailure> for RequestResponseError {
65 fn from(err: OutboundFailure) -> Self {
66 match err {
67 OutboundFailure::DialFailure => Self::DialFailure,
68 OutboundFailure::Timeout => Self::Timeout,
69 OutboundFailure::ConnectionClosed => Self::ConnectionClosed,
70 OutboundFailure::UnsupportedProtocols => Self::UnsupportedProtocols,
71 OutboundFailure::Io(e) => Self::Io(e),
72 }
73 }
74}
75
76#[async_trait]
77impl<P, RQ, RS> request_response::Codec for CborRequestResponse<P, RQ, RS>
78where
79 P: AsRef<str> + Send + Clone,
80 RQ: Serialize + DeserializeOwned + Send + Sync,
81 RS: Serialize + DeserializeOwned + Send + Sync,
82{
83 type Protocol = P;
84 type Request = RQ;
85 type Response = RS;
86
87 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
88 where
89 T: AsyncRead + Unpin + Send,
90 {
91 read_request_and_decode(io).await
92 }
93
94 async fn read_response<T>(
95 &mut self,
96 _: &Self::Protocol,
97 io: &mut T,
98 ) -> io::Result<Self::Response>
99 where
100 T: AsyncRead + Unpin + Send,
101 {
102 let mut bytes = vec![];
103 io.read_to_end(&mut bytes).await?;
104 serde_ipld_dagcbor::de::from_reader(bytes.as_slice()).map_err(io::Error::other)
105 }
106
107 async fn write_request<T>(
108 &mut self,
109 _: &Self::Protocol,
110 io: &mut T,
111 req: Self::Request,
112 ) -> io::Result<()>
113 where
114 T: AsyncWrite + Unpin + Send,
115 {
116 encode_and_write(io, req).await
117 }
118
119 async fn write_response<T>(
120 &mut self,
121 _: &Self::Protocol,
122 io: &mut T,
123 res: Self::Response,
124 ) -> io::Result<()>
125 where
126 T: AsyncWrite + Unpin + Send,
127 {
128 encode_and_write(io, res).await
129 }
130}
131
132async fn read_request_and_decode<IO, T>(io: &mut IO) -> io::Result<T>
153where
154 IO: AsyncRead + Unpin,
155 T: serde::de::DeserializeOwned,
156{
157 const MAX_BYTES_ALLOWED: usize = 2 * 1024 * 1024; const TIMEOUT: Duration = Duration::from_secs(30);
159
160 match tokio::time::timeout(TIMEOUT, DagCborDecodingReader::new(io, MAX_BYTES_ALLOWED)).await {
165 Ok(r) => r,
166 Err(_) => {
167 let err = io::Error::other("read_and_decode timeout");
168 tracing::warn!("{err}");
169 Err(err)
170 }
171 }
172}
173
174async fn encode_and_write<IO, T>(io: &mut IO, data: T) -> io::Result<()>
175where
176 IO: AsyncWrite + Unpin,
177 T: serde::Serialize,
178{
179 let bytes = fvm_ipld_encoding::to_vec(&data).map_err(io::Error::other)?;
180 io.write_all(&bytes).await?;
181 io.close().await?;
182 Ok(())
183}