1mod decoder;
5use std::{io, marker::PhantomData, time::Duration};
6
7use async_trait::async_trait;
8use decoder::DagCborDecodingReader;
9use futures::prelude::*;
10use libp2p::request_response::{self, OutboundFailure};
11use serde::{Serialize, de::DeserializeOwned};
12
/// Compile-time settings consumed by [`CborRequestResponse`] when it reads
/// messages off a stream.
///
/// The codec forwards these constants unchanged to its decoding helper: the
/// byte limits as the helper's `max_bytes` argument and the duration as its
/// timeout.
pub trait CodecConfig {
    /// Time budget for decoding a single inbound message (request or response).
    const DECODE_TIMEOUT: Duration;
    /// Byte limit handed to the decoder while reading a request.
    const MAX_REQUEST_BYTES: usize;
    /// Byte limit handed to the decoder while reading a response.
    const MAX_RESPONSE_BYTES: usize;
}
21
/// Zero-sized codec type parameterised over the protocol name `P`, the
/// request type `RQ`, the response type `RS` and the [`CodecConfig`] `C`.
///
/// Nothing is stored at runtime: every field is a [`PhantomData`] marker whose
/// only job is to pin one of the type parameters.
pub struct CborRequestResponse<P, RQ, RS, C> {
    protocol: PhantomData<P>,
    request: PhantomData<RQ>,
    response: PhantomData<RS>,
    config: PhantomData<C>,
}

// `Default`, `Clone` and `Copy` are spelled out by hand on purpose: the derive
// macros would add `P: Default`, `RQ: Clone`, ... bounds, none of which the
// marker fields need.

impl<P, RQ, RS, C> Default for CborRequestResponse<P, RQ, RS, C> {
    /// Builds the (only possible) value of this marker type.
    fn default() -> Self {
        let (protocol, request, response, config) =
            (PhantomData, PhantomData, PhantomData, PhantomData);
        Self {
            protocol,
            request,
            response,
            config,
        }
    }
}

impl<P, RQ, RS, C> Clone for CborRequestResponse<P, RQ, RS, C> {
    /// Cloning is a plain bitwise copy because the type is `Copy`.
    fn clone(&self) -> Self {
        *self
    }
}

impl<P, RQ, RS, C> Copy for CborRequestResponse<P, RQ, RS, C> {}
51
52#[derive(Debug, thiserror::Error)]
59pub enum RequestResponseError {
60 #[error("DialFailure")]
62 DialFailure,
63 #[error("Timeout")]
68 Timeout,
69 #[error("ConnectionClosed")]
74 ConnectionClosed,
75 #[error("UnsupportedProtocols")]
77 UnsupportedProtocols,
78 #[error("{0}")]
80 Io(io::Error),
81}
82
83impl From<OutboundFailure> for RequestResponseError {
84 fn from(err: OutboundFailure) -> Self {
85 match err {
86 OutboundFailure::DialFailure => Self::DialFailure,
87 OutboundFailure::Timeout => Self::Timeout,
88 OutboundFailure::ConnectionClosed => Self::ConnectionClosed,
89 OutboundFailure::UnsupportedProtocols => Self::UnsupportedProtocols,
90 OutboundFailure::Io(e) => Self::Io(e),
91 }
92 }
93}
94
95#[async_trait]
96impl<P, RQ, RS, C> request_response::Codec for CborRequestResponse<P, RQ, RS, C>
97where
98 P: AsRef<str> + Send + Clone,
99 RQ: Serialize + DeserializeOwned + Send + Sync,
100 RS: Serialize + DeserializeOwned + Send + Sync,
101 C: CodecConfig + Send + Sync,
102{
103 type Protocol = P;
104 type Request = RQ;
105 type Response = RS;
106
107 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
108 where
109 T: AsyncRead + Unpin + Send,
110 {
111 timed_decode(io, C::MAX_REQUEST_BYTES, C::DECODE_TIMEOUT).await
112 }
113
114 async fn read_response<T>(
115 &mut self,
116 _: &Self::Protocol,
117 io: &mut T,
118 ) -> io::Result<Self::Response>
119 where
120 T: AsyncRead + Unpin + Send,
121 {
122 timed_decode(io, C::MAX_RESPONSE_BYTES, C::DECODE_TIMEOUT).await
123 }
124
125 async fn write_request<T>(
126 &mut self,
127 _: &Self::Protocol,
128 io: &mut T,
129 req: Self::Request,
130 ) -> io::Result<()>
131 where
132 T: AsyncWrite + Unpin + Send,
133 {
134 encode_and_write(io, req).await
135 }
136
137 async fn write_response<T>(
138 &mut self,
139 _: &Self::Protocol,
140 io: &mut T,
141 res: Self::Response,
142 ) -> io::Result<()>
143 where
144 T: AsyncWrite + Unpin + Send,
145 {
146 encode_and_write(io, res).await
147 }
148}
149
150async fn timed_decode<IO, T>(io: &mut IO, max_bytes: usize, timeout: Duration) -> io::Result<T>
174where
175 IO: AsyncRead + Unpin,
176 T: serde::de::DeserializeOwned,
177{
178 match tokio::time::timeout(timeout, DagCborDecodingReader::new(io, max_bytes)).await {
179 Ok(r) => r,
180 Err(_) => {
181 let err = io::Error::from(io::ErrorKind::TimedOut);
182 tracing::debug!("{err}");
183 Err(err)
184 }
185 }
186}
187
188async fn encode_and_write<IO, T>(io: &mut IO, data: T) -> io::Result<()>
189where
190 IO: AsyncWrite + Unpin,
191 T: serde::Serialize,
192{
193 let bytes = fvm_ipld_encoding::to_vec(&data).map_err(io::Error::other)?;
194 io.write_all(&bytes).await?;
195 io.close().await?;
196 Ok(())
197}