distant_net/common/packet/
request.rs

1use std::borrow::Cow;
2use std::{io, str};
3
4use derive_more::{Display, Error};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7
8use super::{read_header_bytes, read_key_eq, read_str_bytes, Header, Id};
9use crate::common::utils;
10use crate::header;
11
12/// Represents a request to send
13#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
14pub struct Request<T> {
15    /// Optional header data to include with request
16    #[serde(default, skip_serializing_if = "Header::is_empty")]
17    pub header: Header,
18
19    /// Unique id associated with the request
20    pub id: Id,
21
22    /// Payload associated with the request
23    pub payload: T,
24}
25
26impl<T> Request<T> {
27    /// Creates a new request with a random, unique id and no header data
28    pub fn new(payload: T) -> Self {
29        Self {
30            header: header!(),
31            id: rand::random::<u64>().to_string(),
32            payload,
33        }
34    }
35}
36
37impl<T> Request<T>
38where
39    T: Serialize,
40{
41    /// Serializes the request into bytes
42    pub fn to_vec(&self) -> io::Result<Vec<u8>> {
43        utils::serialize_to_vec(self)
44    }
45
46    /// Serializes the request's payload into bytes
47    pub fn to_payload_vec(&self) -> io::Result<Vec<u8>> {
48        utils::serialize_to_vec(&self.payload)
49    }
50
51    /// Attempts to convert a typed request to an untyped request
52    pub fn to_untyped_request(&self) -> io::Result<UntypedRequest> {
53        Ok(UntypedRequest {
54            header: Cow::Owned(if !self.header.is_empty() {
55                utils::serialize_to_vec(&self.header)?
56            } else {
57                Vec::new()
58            }),
59            id: Cow::Borrowed(&self.id),
60            payload: Cow::Owned(self.to_payload_vec()?),
61        })
62    }
63}
64
65impl<T> Request<T>
66where
67    T: DeserializeOwned,
68{
69    /// Deserializes the request from bytes
70    pub fn from_slice(slice: &[u8]) -> io::Result<Self> {
71        utils::deserialize_from_slice(slice)
72    }
73}
74
75impl<T> From<T> for Request<T> {
76    fn from(payload: T) -> Self {
77        Self::new(payload)
78    }
79}
80
81/// Error encountered when attempting to parse bytes as an untyped request
82#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq, Hash)]
83pub enum UntypedRequestParseError {
84    /// When the bytes do not represent a request
85    WrongType,
86
87    /// When a header should be present, but the key is wrong
88    InvalidHeaderKey,
89
90    /// When a header should be present, but the header bytes are wrong
91    InvalidHeader,
92
93    /// When the key for the id is wrong
94    InvalidIdKey,
95
96    /// When the id is not a valid UTF-8 string
97    InvalidId,
98
99    /// When the key for the payload is wrong
100    InvalidPayloadKey,
101}
102
/// Returns `true` when the raw header bytes contain nothing at all
#[inline]
fn header_is_empty(header: &[u8]) -> bool {
    matches!(header, [])
}
107
108/// Represents a request to send whose payload is bytes instead of a specific type
109#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
110pub struct UntypedRequest<'a> {
111    /// Header data associated with the request as bytes
112    #[serde(default, skip_serializing_if = "header_is_empty")]
113    pub header: Cow<'a, [u8]>,
114
115    /// Unique id associated with the request
116    pub id: Cow<'a, str>,
117
118    /// Payload associated with the request as bytes
119    pub payload: Cow<'a, [u8]>,
120}
121
122impl<'a> UntypedRequest<'a> {
123    /// Attempts to convert an untyped request to a typed request
124    pub fn to_typed_request<T: DeserializeOwned>(&self) -> io::Result<Request<T>> {
125        Ok(Request {
126            header: if header_is_empty(&self.header) {
127                header!()
128            } else {
129                utils::deserialize_from_slice(&self.header)?
130            },
131            id: self.id.to_string(),
132            payload: utils::deserialize_from_slice(&self.payload)?,
133        })
134    }
135
136    /// Convert into a borrowed version
137    pub fn as_borrowed(&self) -> UntypedRequest<'_> {
138        UntypedRequest {
139            header: match &self.header {
140                Cow::Borrowed(x) => Cow::Borrowed(x),
141                Cow::Owned(x) => Cow::Borrowed(x.as_slice()),
142            },
143            id: match &self.id {
144                Cow::Borrowed(x) => Cow::Borrowed(x),
145                Cow::Owned(x) => Cow::Borrowed(x.as_str()),
146            },
147            payload: match &self.payload {
148                Cow::Borrowed(x) => Cow::Borrowed(x),
149                Cow::Owned(x) => Cow::Borrowed(x.as_slice()),
150            },
151        }
152    }
153
154    /// Convert into an owned version
155    pub fn into_owned(self) -> UntypedRequest<'static> {
156        UntypedRequest {
157            header: match self.header {
158                Cow::Borrowed(x) => Cow::Owned(x.to_vec()),
159                Cow::Owned(x) => Cow::Owned(x),
160            },
161            id: match self.id {
162                Cow::Borrowed(x) => Cow::Owned(x.to_string()),
163                Cow::Owned(x) => Cow::Owned(x),
164            },
165            payload: match self.payload {
166                Cow::Borrowed(x) => Cow::Owned(x.to_vec()),
167                Cow::Owned(x) => Cow::Owned(x),
168            },
169        }
170    }
171
172    /// Updates the header of the request to the given `header`.
173    pub fn set_header(&mut self, header: impl IntoIterator<Item = u8>) {
174        self.header = Cow::Owned(header.into_iter().collect());
175    }
176
177    /// Updates the id of the request to the given `id`.
178    pub fn set_id(&mut self, id: impl Into<String>) {
179        self.id = Cow::Owned(id.into());
180    }
181
182    /// Allocates a new collection of bytes representing the request.
183    pub fn to_bytes(&self) -> Vec<u8> {
184        let mut bytes = vec![];
185
186        let has_header = !header_is_empty(&self.header);
187        if has_header {
188            rmp::encode::write_map_len(&mut bytes, 3).unwrap();
189        } else {
190            rmp::encode::write_map_len(&mut bytes, 2).unwrap();
191        }
192
193        if has_header {
194            rmp::encode::write_str(&mut bytes, "header").unwrap();
195            bytes.extend_from_slice(&self.header);
196        }
197
198        rmp::encode::write_str(&mut bytes, "id").unwrap();
199        rmp::encode::write_str(&mut bytes, &self.id).unwrap();
200
201        rmp::encode::write_str(&mut bytes, "payload").unwrap();
202        bytes.extend_from_slice(&self.payload);
203
204        bytes
205    }
206
207    /// Parses a collection of bytes, returning a partial request if it can be potentially
208    /// represented as a [`Request`] depending on the payload.
209    ///
210    /// NOTE: This supports parsing an invalid request where the payload would not properly
211    /// deserialize, but the bytes themselves represent a complete request of some kind.
212    pub fn from_slice(input: &'a [u8]) -> Result<Self, UntypedRequestParseError> {
213        if input.is_empty() {
214            return Err(UntypedRequestParseError::WrongType);
215        }
216
217        let has_header = match rmp::Marker::from_u8(input[0]) {
218            rmp::Marker::FixMap(2) => false,
219            rmp::Marker::FixMap(3) => true,
220            _ => return Err(UntypedRequestParseError::WrongType),
221        };
222
223        // Advance position by marker
224        let input = &input[1..];
225
226        // Parse the header if we have one
227        let (header, input) = if has_header {
228            let (_, input) = read_key_eq(input, "header")
229                .map_err(|_| UntypedRequestParseError::InvalidHeaderKey)?;
230
231            let (header, input) =
232                read_header_bytes(input).map_err(|_| UntypedRequestParseError::InvalidHeader)?;
233            (header, input)
234        } else {
235            ([0u8; 0].as_slice(), input)
236        };
237
238        // Validate that next field is id
239        let (_, input) =
240            read_key_eq(input, "id").map_err(|_| UntypedRequestParseError::InvalidIdKey)?;
241
242        // Get the id itself
243        let (id, input) = read_str_bytes(input).map_err(|_| UntypedRequestParseError::InvalidId)?;
244
245        // Validate that final field is payload
246        let (_, input) = read_key_eq(input, "payload")
247            .map_err(|_| UntypedRequestParseError::InvalidPayloadKey)?;
248
249        let header = Cow::Borrowed(header);
250        let id = Cow::Borrowed(id);
251        let payload = Cow::Borrowed(input);
252
253        Ok(Self {
254            header,
255            id,
256            payload,
257        })
258    }
259}
260
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;

    // msgpack encoding of boolean `true`
    const TRUE_BYTE: u8 = 0xc3;

    // byte that msgpack never uses
    const NEVER_USED_BYTE: u8 = 0xc1;

    // fixstr of 6 bytes with str "header"
    const HEADER_FIELD_BYTES: &[u8] = &[0xa6, b'h', b'e', b'a', b'd', b'e', b'r'];

    // fixmap of 2 pairs:
    // 1. key fixstr "key" and value fixstr "value"
    // 2. key fixstr "num" and value fixint 123
    const HEADER_BYTES: &[u8] = &[
        0x82, // valid map with 2 pairs
        0xa3, b'k', b'e', b'y', // key: "key"
        0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value"
        0xa3, b'n', b'u', b'm', // key: "num"
        0x7b, // value: 123
    ];

    // fixstr of 2 bytes with str "id"
    const ID_FIELD_BYTES: &[u8] = &[0xa2, b'i', b'd'];

    // fixstr of 7 bytes with str "payload"
    const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, b'p', b'a', b'y', b'l', b'o', b'a', b'd'];

    // fixstr of 4 bytes with str "test"
    const TEST_STR_BYTES: &[u8] = &[0xa4, b't', b'e', b's', b't'];

    /// Glues the given byte fragments together into a single buffer
    fn join(parts: &[&[u8]]) -> Vec<u8> {
        parts.concat()
    }

    /// Serializes a typed request that has the given header and id and a payload of `true`
    fn serialized_request(header: Header, id: &str) -> Vec<u8> {
        let request = Request {
            header,
            id: id.to_string(),
            payload: true,
        };

        request.to_vec().unwrap()
    }

    /// Parses the concatenation of `parts` as an untyped request and returns the error, if any
    fn parse_error(parts: &[&[u8]]) -> Option<UntypedRequestParseError> {
        let bytes = join(parts);
        UntypedRequest::from_slice(&bytes).err()
    }

    #[test]
    fn untyped_request_should_support_converting_to_bytes() {
        let bytes = serialized_request(header!(), "some id");

        let untyped_request = UntypedRequest::from_slice(&bytes).unwrap();
        assert_eq!(untyped_request.to_bytes(), bytes);
    }

    #[test]
    fn untyped_request_should_support_converting_to_bytes_with_header() {
        let bytes = serialized_request(header!("key" -> 123), "some id");

        let untyped_request = UntypedRequest::from_slice(&bytes).unwrap();
        assert_eq!(untyped_request.to_bytes(), bytes);
    }

    #[test]
    fn untyped_request_should_support_parsing_from_request_bytes_with_header() {
        let bytes = serialized_request(header!("key" -> 123), "some id");

        let expected = UntypedRequest {
            header: Cow::Owned(utils::serialize_to_vec(&header!("key" -> 123)).unwrap()),
            id: Cow::Borrowed("some id"),
            payload: Cow::Owned(vec![TRUE_BYTE]),
        };
        assert_eq!(UntypedRequest::from_slice(&bytes), Ok(expected));
    }

    #[test]
    fn untyped_request_should_support_parsing_from_request_bytes_with_valid_payload() {
        let bytes = serialized_request(header!(), "some id");

        let expected = UntypedRequest {
            header: Cow::Owned(vec![]),
            id: Cow::Borrowed("some id"),
            payload: Cow::Owned(vec![TRUE_BYTE]),
        };
        assert_eq!(UntypedRequest::from_slice(&bytes), Ok(expected));
    }

    #[test]
    fn untyped_request_should_support_parsing_from_request_bytes_with_invalid_payload() {
        // Request with id < 32 bytes
        let mut bytes = serialized_request(header!(), "");

        // Tack on a byte that msgpack never uses
        bytes.push(NEVER_USED_BYTE);

        // The payload is not validated during parsing, so the extra byte is part of it
        let expected = UntypedRequest {
            header: Cow::Owned(vec![]),
            id: Cow::Owned("".to_string()),
            payload: Cow::Owned(vec![TRUE_BYTE, NEVER_USED_BYTE]),
        };
        assert_eq!(UntypedRequest::from_slice(&bytes), Ok(expected));
    }

    #[test]
    fn untyped_request_should_support_parsing_full_request() {
        let input = join(&[
            &[0x83],
            HEADER_FIELD_BYTES,
            HEADER_BYTES,
            ID_FIELD_BYTES,
            TEST_STR_BYTES,
            PAYLOAD_FIELD_BYTES,
            &[TRUE_BYTE],
        ]);

        // Go through the typed form so that each field can be inspected
        let untyped_request = UntypedRequest::from_slice(&input).unwrap();
        let request: Request<bool> = untyped_request.to_typed_request().unwrap();

        assert_eq!(request.header, header!("key" -> "value", "num" -> 123));
        assert_eq!(request.id, "test");
        assert!(request.payload);
    }

    #[test]
    fn untyped_request_should_fail_to_parse_if_given_bytes_not_representing_a_request() {
        // Empty byte slice
        assert_eq!(parse_error(&[]), Some(UntypedRequestParseError::WrongType));

        // Wrong starting byte
        assert_eq!(
            parse_error(&[&[0x00]]),
            Some(UntypedRequestParseError::WrongType)
        );

        // Wrong starting byte (fixmap of 0 fields)
        assert_eq!(
            parse_error(&[&[0x80]]),
            Some(UntypedRequestParseError::WrongType)
        );

        // Invalid header key
        assert_eq!(
            parse_error(&[
                &[0x83],
                &[0xa0], // header key would be defined here, set to empty str
                HEADER_BYTES,
                ID_FIELD_BYTES,
                TEST_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedRequestParseError::InvalidHeaderKey)
        );

        // Invalid header bytes
        assert_eq!(
            parse_error(&[
                &[0x83],
                HEADER_FIELD_BYTES,
                &[0xa0], // header would be defined here, set to empty str
                ID_FIELD_BYTES,
                TEST_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedRequestParseError::InvalidHeader)
        );

        // Missing fields (corrupt data)
        assert_eq!(
            parse_error(&[&[0x82]]),
            Some(UntypedRequestParseError::InvalidIdKey)
        );

        // Missing id field (has valid data itself)
        assert_eq!(
            parse_error(&[
                &[0x82],
                &[0xa0], // id would be defined here, set to empty str
                TEST_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedRequestParseError::InvalidIdKey)
        );

        // Non-str id field value
        assert_eq!(
            parse_error(&[
                &[0x82],
                ID_FIELD_BYTES,
                &[TRUE_BYTE], // id value set to boolean
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedRequestParseError::InvalidId)
        );

        // Non-utf8 id field value
        assert_eq!(
            parse_error(&[
                &[0x82],
                ID_FIELD_BYTES,
                &[0xa4, 0, 159, 146, 150],
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedRequestParseError::InvalidId)
        );

        // Missing payload field (has valid data itself)
        assert_eq!(
            parse_error(&[
                &[0x82],
                ID_FIELD_BYTES,
                TEST_STR_BYTES,
                &[0xa0], // payload would be defined here, set to empty str
                &[TRUE_BYTE],
            ]),
            Some(UntypedRequestParseError::InvalidPayloadKey)
        );
    }
}
532        );
533    }
534}