Skip to main content

nimiq_network_interface/request/
mod.rs

1use std::{fmt, sync::Arc, time::Duration};
2
3use futures::{stream::BoxStream, Future, StreamExt};
4use nimiq_serde::{Deserialize, DeserializeError, Serialize};
5use thiserror::Error;
6
/// Default value of [`RequestCommon::TIME_WINDOW`], used by request types that do not
/// override it.
///
/// NOTE(review): presumably the window over which `RequestCommon::MAX_REQUESTS` is counted
/// for rate limiting — confirm against the network implementation.
pub const DEFAULT_MAX_REQUEST_RESPONSE_TIME_WINDOW: Duration = Duration::from_millis(10_000);
9
10use crate::network::Network;
11
12#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
13pub struct RequestType(pub u16);
14
15impl fmt::Display for RequestType {
16    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17        write!(
18            f,
19            "{}{}",
20            self.type_id(),
21            if self.requires_response() { 'r' } else { 'm' },
22        )
23    }
24}
25
26impl RequestType {
27    const fn new(type_id: u16, requires_response: bool) -> Self {
28        Self((type_id << 1) | requires_response as u16)
29    }
30    pub fn from_request<R: RequestCommon>() -> Self {
31        Self::new(R::TYPE_ID, R::Kind::EXPECT_RESPONSE)
32    }
33    pub const fn request(type_id: u16) -> Self {
34        Self::new(type_id, true)
35    }
36    pub const fn message(type_id: u16) -> Self {
37        Self::new(type_id, false)
38    }
39    pub const fn type_id(self) -> u16 {
40        self.0 >> 1
41    }
42    pub const fn requires_response(self) -> bool {
43        self.0 & 1 != 0
44    }
45}
46
47/// Error enumeration for requests
48#[derive(Clone, Debug, Error, Eq, PartialEq)]
49pub enum RequestError {
50    /// Outbound request error
51    #[error("Outbound error: {0}")]
52    OutboundRequest(#[from] OutboundRequestError),
53    /// Inbound request error
54    #[error("Inbound error: {0}")]
55    InboundRequest(#[from] InboundRequestError),
56}
57
58#[derive(Clone, Debug, Error, Eq, PartialEq)]
59pub enum OutboundRequestError {
60    /// The connection closed before a response was received.
61    ///
62    /// It is not known whether the request may have been
63    /// received (and processed) by the remote peer.
64    #[error("Connection to peer is closed")]
65    ConnectionClosed,
66    /// The request could not be sent because a dialing attempt failed.
67    #[error("Dial attempt failed")]
68    DialFailure,
69    /// No receiver was found for this request and no response could be transmitted
70    #[error("No receiver for request")]
71    NoReceiver,
72    /// Error sending this request
73    #[error("Couldn't send request")]
74    SendError,
75    /// Sender future has already been dropped
76    #[error("Sender future is already dropped")]
77    SenderFutureDropped,
78    /// Request failed to be serialized
79    #[error("Failed to serialized request")]
80    SerializationError,
81    /// Timeout waiting for the response of this request.
82    /// In this case a receiver was registered for responding these requests
83    /// but the response never arrived before the timeout was hit.
84    #[error("Request timed out")]
85    Timeout,
86    /// The remote supports none of the requested protocols.
87    #[error("Remote doesn't support requested protocol")]
88    UnsupportedProtocols,
89    /// No response after asking a couple of peers.
90    #[error("No response after asking a couple of peers")]
91    NoResponse,
92    /// Error that doesn't match any of the other error causes
93    #[error("Other: {0}")]
94    Other(String),
95}
96
97#[repr(u8)]
98#[derive(Clone, Copy, Debug, Error, Eq, PartialEq, Serialize, Deserialize)]
99pub enum InboundRequestError {
100    /// Response failed to be deserialized
101    #[error("Response failed to be deserialized")]
102    DeSerializationError = 1,
103    /// No receiver was found for this incoming request
104    #[error("No receiver for request")]
105    NoReceiver = 2,
106    /// Sender future has already been dropped
107    #[error("Sender future is already dropped")]
108    SenderFutureDropped = 3,
109    /// The request timed out before a response could have been sent.
110    #[error("Request timed out")]
111    Timeout = 4,
112    /// The request exceeded the maximum defined rate limit for its request type.
113    #[error("Request exceeds the maximum rate limit")]
114    ExceedsRateLimit = 5,
115}
116
/// Type-level tag telling whether a [`RequestCommon`] type expects a response.
pub trait RequestKind {
    /// `true` if a response is expected, `false` otherwise.
    const EXPECT_RESPONSE: bool;
}

/// Kind marker for requests: a response is expected.
pub struct RequestMarker;

impl RequestKind for RequestMarker {
    const EXPECT_RESPONSE: bool = true;
}

/// Kind marker for messages: no response is expected.
pub struct MessageMarker;

impl RequestKind for MessageMarker {
    const EXPECT_RESPONSE: bool = false;
}
130
131pub trait RequestCommon:
132    Serialize + Deserialize + Send + Sync + Unpin + fmt::Debug + 'static
133{
134    type Kind: RequestKind;
135    const TYPE_ID: u16;
136    type Response: Deserialize + Serialize + Send;
137    const MAX_REQUESTS: u32;
138    const TIME_WINDOW: Duration = DEFAULT_MAX_REQUEST_RESPONSE_TIME_WINDOW;
139
140    /// Returns the type name of the given request type `T`.
141    /// This only works for
142    ///   - non-generic types
143    ///   - generic types with a single level of nesting, in which case the name of
144    ///     the type parameter is returned
145    fn type_name<T>() -> &'static str {
146        let name = std::any::type_name::<T>();
147        let name = match name.rfind("::") {
148            Some(index) => &name[index + 2..],
149            None => name,
150        };
151        match name.rfind(">") {
152            Some(index) => &name[..index],
153            None => name,
154        }
155    }
156}
157
158pub trait RequestSerialize: RequestCommon {
159    /// Serializes a request.
160    /// A serialized request is composed of:
161    /// - A variable sized integer for the Type ID of the request
162    /// - Serialized content of the inner type.
163    fn serialize_request(&self) -> Vec<u8> {
164        let mut data = Vec::with_capacity(self.serialized_request_size());
165        RequestType::from_request::<Self>()
166            .serialize_to_writer(&mut data)
167            .unwrap();
168        Serialize::serialize_to_writer(self, &mut data).unwrap();
169        data
170    }
171
172    /// Computes the size in bytes of a serialized request.
173    /// A serialized request is composed of:
174    /// - A 2 bytes (u16) for the Type ID of the request
175    /// - Serialized content of the inner type.
176    fn serialized_request_size(&self) -> usize {
177        let mut size = 0;
178        size += RequestType::from_request::<Self>().0.serialized_size();
179        size += self.serialized_size();
180        size
181    }
182
183    /// Deserializes a request
184    /// A serialized request is composed of:
185    /// - A variable sized integer for the Type ID of the request
186    /// - Serialized content of the inner type.
187    fn deserialize_request(buffer: &[u8]) -> Result<Self, DeserializeError> {
188        // Check for correct type.
189        let (ty, message_buf) = u16::deserialize_take(buffer)?;
190        if ty != RequestType::from_request::<Self>().0 {
191            return Err(DeserializeError::bad_enum());
192        }
193        Self::deserialize_from_vec(message_buf)
194    }
195}
196
197impl<T: RequestCommon> RequestSerialize for T {}
198
199pub trait Request: RequestCommon<Kind = RequestMarker> {}
200pub trait Message: RequestCommon<Kind = MessageMarker, Response = ()> {}
201
202impl<T: RequestCommon<Kind = RequestMarker>> Request for T {}
203impl<T: RequestCommon<Kind = MessageMarker, Response = ()>> Message for T {}
204
205pub fn peek_type(buffer: &[u8]) -> Result<RequestType, DeserializeError> {
206    let ty = u16::deserialize_from_vec(buffer)?;
207    Ok(RequestType(ty))
208}
209
210/// This trait defines the behaviour when receiving a message and how to generate the response.
211pub trait Handle<N: Network, T>: Request {
212    fn handle(&self, peer_id: N::PeerId, context: &T) -> <Self as RequestCommon>::Response;
213}
214
215/// This trait defines the behaviour when receiving a message
216pub trait MessageHandle<N: Network, T> {
217    fn message_handle(&self, peer_id: N::PeerId, context: &T);
218}
219
220const MAX_CONCURRENT_HANDLERS: usize = 64;
221
222pub fn request_handler<T: Send + Sync + Clone + 'static, Req: Handle<N, T>, N: Network>(
223    network: &Arc<N>,
224    stream: BoxStream<'static, (Req, N::RequestId, N::PeerId)>,
225    req_environment: &T,
226) -> impl Future<Output = ()> + use<T, Req, N> {
227    let req_environment = req_environment.clone();
228    let network = Arc::clone(network);
229    async move {
230        stream
231            .for_each_concurrent(MAX_CONCURRENT_HANDLERS, |(msg, request_id, peer_id)| {
232                let network = Arc::clone(&network);
233                let req_environment = req_environment.clone();
234                async move {
235                    // Try to send the response. If it fails, it usually means that the peer has
236                    // disconnected, so we silently ignore any errors here.
237                    network
238                        .respond::<Req>(request_id, msg.handle(peer_id, &req_environment))
239                        .await
240                        .ok();
241                }
242            })
243            .await
244    }
245}
246
247/// Handler that takes care of sending messages to a network, similar to a request except that we don't expect an answer.
248pub fn message_handler<
249    T: Send + Sync + Clone + 'static,
250    Msg: MessageHandle<N, T> + Message,
251    N: Network,
252>(
253    _network: &Arc<N>,
254    stream: BoxStream<'static, (Msg, N::PeerId)>,
255    req_environment: &T,
256) -> impl Future<Output = ()> {
257    let req_environment = req_environment.clone();
258    async move {
259        stream
260            .for_each_concurrent(MAX_CONCURRENT_HANDLERS, |(msg, peer_id)| {
261                let req_environment = req_environment.clone();
262                async move {
263                    // Messages do not have a response (so the response is ignored)
264                    msg.message_handle(peer_id, &req_environment);
265                }
266            })
267            .await
268    }
269}