mod decoder;
use std::{io, marker::PhantomData, time::Duration};
use async_trait::async_trait;
use decoder::DagCborDecodingReader;
use futures::prelude::*;
use libp2p::request_response::{self, OutboundFailure};
use serde::{Serialize, de::DeserializeOwned};
/// Compile-time settings consumed by the `request_response::Codec`
/// implementation for `CborRequestResponse`.
///
/// All values are associated constants, so an implementor only acts as a
/// type-level configuration marker.
pub trait CodecConfig {
/// Byte limit handed to the decoder when reading a request.
// NOTE(review): presumably the decoder rejects payloads larger than this;
// the enforcement lives in `decoder::DagCborDecodingReader` — confirm there.
const MAX_REQUEST_BYTES: usize;
/// Byte limit handed to the decoder when reading a response.
const MAX_RESPONSE_BYTES: usize;
/// Upper bound on how long decoding a single request or response may take;
/// when it elapses the read fails with `io::ErrorKind::TimedOut`.
const DECODE_TIMEOUT: Duration;
}
/// Stateless codec marker, generic over the protocol (`P`), request (`RQ`),
/// response (`RS`) and configuration (`C`) types.
///
/// Every field is a [`PhantomData`], so a value carries no runtime data; the
/// type parameters only select which messages and limits the codec works with.
pub struct CborRequestResponse<P, RQ, RS, C> {
    protocol: PhantomData<P>,
    request: PhantomData<RQ>,
    response: PhantomData<RS>,
    config: PhantomData<C>,
}

// `Default`, `Clone` and `Copy` are written by hand rather than derived:
// deriving would add `P: Default`, `P: Clone`, ... bounds on the type
// parameters, which are unnecessary because only `PhantomData` is stored.

impl<P, RQ, RS, C> Default for CborRequestResponse<P, RQ, RS, C> {
    /// Builds the (only possible) value of this marker type.
    fn default() -> Self {
        // Each `PhantomData` is inferred from the field it is moved into.
        let (protocol, request, response, config) =
            (PhantomData, PhantomData, PhantomData, PhantomData);
        Self {
            protocol,
            request,
            response,
            config,
        }
    }
}

impl<P, RQ, RS, C> Clone for CborRequestResponse<P, RQ, RS, C> {
    /// Cloning is a plain bitwise copy because the type is `Copy`.
    fn clone(&self) -> Self {
        *self
    }
}

impl<P, RQ, RS, C> Copy for CborRequestResponse<P, RQ, RS, C> {}
/// Error type for a failed outbound request.
///
/// Each variant corresponds one-to-one to a variant of `OutboundFailure`
/// (see the `From<OutboundFailure>` impl); the `Display` output is the text
/// given in the matching `#[error(..)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum RequestResponseError {
/// Converted from `OutboundFailure::DialFailure`.
#[error("DialFailure")]
DialFailure,
/// Converted from `OutboundFailure::Timeout`.
#[error("Timeout")]
Timeout,
/// Converted from `OutboundFailure::ConnectionClosed`.
#[error("ConnectionClosed")]
ConnectionClosed,
/// Converted from `OutboundFailure::UnsupportedProtocols`.
#[error("UnsupportedProtocols")]
UnsupportedProtocols,
/// Converted from `OutboundFailure::Io`; displays as the wrapped
/// `io::Error`'s own message.
#[error("{0}")]
Io(io::Error),
}
impl From<OutboundFailure> for RequestResponseError {
fn from(err: OutboundFailure) -> Self {
match err {
OutboundFailure::DialFailure => Self::DialFailure,
OutboundFailure::Timeout => Self::Timeout,
OutboundFailure::ConnectionClosed => Self::ConnectionClosed,
OutboundFailure::UnsupportedProtocols => Self::UnsupportedProtocols,
OutboundFailure::Io(e) => Self::Io(e),
}
}
}
// Codec implementation: reads go through `timed_decode` with the limits taken
// from `C`, writes go through `encode_and_write`. The protocol argument is
// ignored by every method.
#[async_trait]
impl<P, RQ, RS, C> request_response::Codec for CborRequestResponse<P, RQ, RS, C>
where
    P: AsRef<str> + Send + Clone,
    RQ: Serialize + DeserializeOwned + Send + Sync,
    RS: Serialize + DeserializeOwned + Send + Sync,
    C: CodecConfig + Send + Sync,
{
    type Protocol = P;
    type Request = RQ;
    type Response = RS;

    /// Decodes one request from `stream`, bounded by
    /// `C::MAX_REQUEST_BYTES` and `C::DECODE_TIMEOUT`.
    async fn read_request<T>(
        &mut self,
        _: &Self::Protocol,
        stream: &mut T,
    ) -> io::Result<Self::Request>
    where
        T: AsyncRead + Unpin + Send,
    {
        timed_decode(stream, C::MAX_REQUEST_BYTES, C::DECODE_TIMEOUT).await
    }

    /// Decodes one response from `stream`, bounded by
    /// `C::MAX_RESPONSE_BYTES` and `C::DECODE_TIMEOUT`.
    async fn read_response<T>(
        &mut self,
        _: &Self::Protocol,
        stream: &mut T,
    ) -> io::Result<Self::Response>
    where
        T: AsyncRead + Unpin + Send,
    {
        timed_decode(stream, C::MAX_RESPONSE_BYTES, C::DECODE_TIMEOUT).await
    }

    /// Serializes `request` and writes it to `stream`.
    async fn write_request<T>(
        &mut self,
        _: &Self::Protocol,
        stream: &mut T,
        request: Self::Request,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        encode_and_write(stream, request).await
    }

    /// Serializes `response` and writes it to `stream`.
    async fn write_response<T>(
        &mut self,
        _: &Self::Protocol,
        stream: &mut T,
        response: Self::Response,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        encode_and_write(stream, response).await
    }
}
/// Decodes a `T` from `io` via `DagCborDecodingReader`, giving up after
/// `timeout`.
///
/// `max_bytes` is forwarded unchanged to `DagCborDecodingReader::new`.
///
/// # Errors
///
/// * Whatever error the decoder itself reports, passed through untouched.
/// * An `io::ErrorKind::TimedOut` error (also logged at debug level) when the
///   decoder does not finish within `timeout`.
async fn timed_decode<IO, T>(io: &mut IO, max_bytes: usize, timeout: Duration) -> io::Result<T>
where
    IO: AsyncRead + Unpin,
    T: serde::de::DeserializeOwned,
{
    let decoding = DagCborDecodingReader::new(io, max_bytes);

    // The outer `Result` only says whether the deadline was met; the inner
    // one is the decoder's own outcome.
    let Ok(decoded) = tokio::time::timeout(timeout, decoding).await else {
        let err = io::Error::from(io::ErrorKind::TimedOut);
        tracing::debug!("{err}");
        return Err(err);
    };

    decoded
}
/// Serializes `data` with `fvm_ipld_encoding::to_vec`, writes the resulting
/// bytes to `io` in full, then closes `io`.
///
/// # Errors
///
/// * A serialization failure, wrapped with `io::Error::other`.
/// * Any error returned by the write or by the close.
async fn encode_and_write<IO, T>(io: &mut IO, data: T) -> io::Result<()>
where
    IO: AsyncWrite + Unpin,
    T: serde::Serialize,
{
    let payload = match fvm_ipld_encoding::to_vec(&data) {
        Ok(encoded) => encoded,
        Err(encode_err) => return Err(io::Error::other(encode_err)),
    };

    io.write_all(&payload).await?;
    // The result of closing the writer is the function's result.
    io.close().await
}