fuel_core_p2p/codecs/
request_response.rs1use std::io;
2
3use crate::request_response::{
4 messages::{
5 RequestMessage,
6 V1ResponseMessage,
7 V2ResponseMessage,
8 },
9 protocols::RequestResponseProtocol,
10};
11use async_trait::async_trait;
12use futures::{
13 AsyncRead,
14 AsyncReadExt,
15 AsyncWrite,
16 AsyncWriteExt,
17};
18use libp2p::request_response;
19use strum::IntoEnumIterator as _;
20
21use super::{
22 Decode,
23 Encode,
24 Encoder,
25 RequestResponseProtocols,
26};
27
/// Pairs a message `Codec` with the size limit used when reading inbound
/// messages from a request/response stream.
#[derive(Clone, Debug)]
pub struct RequestResponseMessageHandler<Codec> {
    /// Codec used by the `request_response::Codec` implementation in this
    /// file to encode outgoing and decode incoming messages.
    pub(crate) codec: Codec,
    /// Upper bound, in bytes, on a single inbound message read from a stream.
    /// Despite the name, the reading code in this file applies it to requests
    /// as well as responses.
    pub(crate) max_response_size: core::num::NonZeroU32,
}
36
37#[async_trait]
45impl<Codec> request_response::Codec for RequestResponseMessageHandler<Codec>
46where
47 Codec: Encode<RequestMessage, Error = io::Error>
48 + Decode<RequestMessage, Error = io::Error>
49 + Encode<V1ResponseMessage, Error = io::Error>
50 + Decode<V1ResponseMessage, Error = io::Error>
51 + Encode<V2ResponseMessage, Error = io::Error>
52 + Decode<V2ResponseMessage, Error = io::Error>
53 + Send,
54{
55 type Protocol = RequestResponseProtocol;
56 type Request = RequestMessage;
57 type Response = V2ResponseMessage;
58
59 async fn read_request<T>(
60 &mut self,
61 _protocol: &Self::Protocol,
62 socket: &mut T,
63 ) -> io::Result<Self::Request>
64 where
65 T: AsyncRead + Unpin + Send,
66 {
67 let mut response = Vec::new();
68 socket
69 .take(self.max_response_size.get() as u64)
70 .read_to_end(&mut response)
71 .await?;
72 self.codec.decode(&response)
73 }
74
75 async fn read_response<T>(
76 &mut self,
77 protocol: &Self::Protocol,
78 socket: &mut T,
79 ) -> io::Result<Self::Response>
80 where
81 T: AsyncRead + Unpin + Send,
82 {
83 let mut response = Vec::new();
84 socket
85 .take(self.max_response_size.get() as u64)
86 .read_to_end(&mut response)
87 .await?;
88
89 match protocol {
90 RequestResponseProtocol::V1 => {
91 let v1_response: V1ResponseMessage = self.codec.decode(&response)?;
92 Ok(v1_response.into())
93 }
94 RequestResponseProtocol::V2 => self.codec.decode(&response),
95 }
96 }
97
98 async fn write_request<T>(
99 &mut self,
100 _protocol: &Self::Protocol,
101 socket: &mut T,
102 req: Self::Request,
103 ) -> io::Result<()>
104 where
105 T: AsyncWrite + Unpin + Send,
106 {
107 let encoded_data = self.codec.encode(&req)?;
108 socket.write_all(&encoded_data.into_bytes()).await?;
109 Ok(())
110 }
111
112 async fn write_response<T>(
113 &mut self,
114 protocol: &Self::Protocol,
115 socket: &mut T,
116 res: Self::Response,
117 ) -> io::Result<()>
118 where
119 T: AsyncWrite + Unpin + Send,
120 {
121 match protocol {
122 RequestResponseProtocol::V1 => {
123 let v1_response: V1ResponseMessage = res.into();
124 let res = self.codec.encode(&v1_response)?;
125 let res = res.into_bytes();
126 socket.write_all(&res).await?;
127 }
128 RequestResponseProtocol::V2 => {
129 let res = self.codec.encode(&res)?;
130 let res = res.into_bytes();
131 socket.write_all(&res).await?;
132 }
133 };
134
135 Ok(())
136 }
137}
138
139impl<Codec> RequestResponseProtocols for Codec
140where
141 Codec: request_response::Codec<Protocol = RequestResponseProtocol>,
142{
143 fn get_req_res_protocols(
144 &self,
145 ) -> impl Iterator<Item = <Self as request_response::Codec>::Protocol> {
146 RequestResponseProtocol::iter().rev()
151 }
152}