distant_net/common/packet/
response.rs

1use std::borrow::Cow;
2use std::io;
3
4use derive_more::{Display, Error};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7
8use super::{read_header_bytes, read_key_eq, read_str_bytes, Header, Id};
9use crate::common::utils;
10use crate::header;
11
12/// Represents a response received related to some response
13#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
14pub struct Response<T> {
15    /// Optional header data to include with response
16    #[serde(default, skip_serializing_if = "Header::is_empty")]
17    pub header: Header,
18
19    /// Unique id associated with the response
20    pub id: Id,
21
22    /// Unique id associated with the response that triggered the response
23    pub origin_id: Id,
24
25    /// Payload associated with the response
26    pub payload: T,
27}
28
29impl<T> Response<T> {
30    /// Creates a new response with a random, unique id and no header data
31    pub fn new(origin_id: Id, payload: T) -> Self {
32        Self {
33            header: header!(),
34            id: rand::random::<u64>().to_string(),
35            origin_id,
36            payload,
37        }
38    }
39}
40
41impl<T> Response<T>
42where
43    T: Serialize,
44{
45    /// Serializes the response into bytes
46    pub fn to_vec(&self) -> std::io::Result<Vec<u8>> {
47        utils::serialize_to_vec(self)
48    }
49
50    /// Serializes the response's payload into bytes
51    pub fn to_payload_vec(&self) -> io::Result<Vec<u8>> {
52        utils::serialize_to_vec(&self.payload)
53    }
54
55    /// Attempts to convert a typed response to an untyped response
56    pub fn to_untyped_response(&self) -> io::Result<UntypedResponse> {
57        Ok(UntypedResponse {
58            header: Cow::Owned(if !self.header.is_empty() {
59                utils::serialize_to_vec(&self.header)?
60            } else {
61                Vec::new()
62            }),
63            id: Cow::Borrowed(&self.id),
64            origin_id: Cow::Borrowed(&self.origin_id),
65            payload: Cow::Owned(self.to_payload_vec()?),
66        })
67    }
68}
69
70impl<T> Response<T>
71where
72    T: DeserializeOwned,
73{
74    /// Deserializes the response from bytes
75    pub fn from_slice(slice: &[u8]) -> std::io::Result<Self> {
76        utils::deserialize_from_slice(slice)
77    }
78}
79
80/// Error encountered when attempting to parse bytes as an untyped response
81#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq, Hash)]
82pub enum UntypedResponseParseError {
83    /// When the bytes do not represent a response
84    WrongType,
85
86    /// When a header should be present, but the key is wrong
87    InvalidHeaderKey,
88
89    /// When a header should be present, but the header bytes are wrong
90    InvalidHeader,
91
92    /// When the key for the id is wrong
93    InvalidIdKey,
94
95    /// When the id is not a valid UTF-8 string
96    InvalidId,
97
98    /// When the key for the origin id is wrong
99    InvalidOriginIdKey,
100
101    /// When the origin id is not a valid UTF-8 string
102    InvalidOriginId,
103
104    /// When the key for the payload is wrong
105    InvalidPayloadKey,
106}
107
/// Reports whether raw header bytes contain nothing at all.
///
/// Serves as the `skip_serializing_if` predicate for the header of an untyped response.
#[inline]
fn header_is_empty(header: &[u8]) -> bool {
    matches!(header, [])
}
112
113/// Represents a response to send whose payload is bytes instead of a specific type
114#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
115pub struct UntypedResponse<'a> {
116    /// Header data associated with the response as bytes
117    #[serde(default, skip_serializing_if = "header_is_empty")]
118    pub header: Cow<'a, [u8]>,
119
120    /// Unique id associated with the response
121    pub id: Cow<'a, str>,
122
123    /// Unique id associated with the response that triggered the response
124    pub origin_id: Cow<'a, str>,
125
126    /// Payload associated with the response as bytes
127    pub payload: Cow<'a, [u8]>,
128}
129
130impl<'a> UntypedResponse<'a> {
131    /// Attempts to convert an untyped response to a typed response
132    pub fn to_typed_response<T: DeserializeOwned>(&self) -> io::Result<Response<T>> {
133        Ok(Response {
134            header: if header_is_empty(&self.header) {
135                header!()
136            } else {
137                utils::deserialize_from_slice(&self.header)?
138            },
139            id: self.id.to_string(),
140            origin_id: self.origin_id.to_string(),
141            payload: utils::deserialize_from_slice(&self.payload)?,
142        })
143    }
144
145    /// Convert into a borrowed version
146    pub fn as_borrowed(&self) -> UntypedResponse<'_> {
147        UntypedResponse {
148            header: match &self.header {
149                Cow::Borrowed(x) => Cow::Borrowed(x),
150                Cow::Owned(x) => Cow::Borrowed(x.as_slice()),
151            },
152            id: match &self.id {
153                Cow::Borrowed(x) => Cow::Borrowed(x),
154                Cow::Owned(x) => Cow::Borrowed(x.as_str()),
155            },
156            origin_id: match &self.origin_id {
157                Cow::Borrowed(x) => Cow::Borrowed(x),
158                Cow::Owned(x) => Cow::Borrowed(x.as_str()),
159            },
160            payload: match &self.payload {
161                Cow::Borrowed(x) => Cow::Borrowed(x),
162                Cow::Owned(x) => Cow::Borrowed(x.as_slice()),
163            },
164        }
165    }
166
167    /// Convert into an owned version
168    pub fn into_owned(self) -> UntypedResponse<'static> {
169        UntypedResponse {
170            header: match self.header {
171                Cow::Borrowed(x) => Cow::Owned(x.to_vec()),
172                Cow::Owned(x) => Cow::Owned(x),
173            },
174            id: match self.id {
175                Cow::Borrowed(x) => Cow::Owned(x.to_string()),
176                Cow::Owned(x) => Cow::Owned(x),
177            },
178            origin_id: match self.origin_id {
179                Cow::Borrowed(x) => Cow::Owned(x.to_string()),
180                Cow::Owned(x) => Cow::Owned(x),
181            },
182            payload: match self.payload {
183                Cow::Borrowed(x) => Cow::Owned(x.to_vec()),
184                Cow::Owned(x) => Cow::Owned(x),
185            },
186        }
187    }
188
189    /// Updates the header of the response to the given `header`.
190    pub fn set_header(&mut self, header: impl IntoIterator<Item = u8>) {
191        self.header = Cow::Owned(header.into_iter().collect());
192    }
193
194    /// Updates the id of the response to the given `id`.
195    pub fn set_id(&mut self, id: impl Into<String>) {
196        self.id = Cow::Owned(id.into());
197    }
198
199    /// Updates the origin id of the response to the given `origin_id`.
200    pub fn set_origin_id(&mut self, origin_id: impl Into<String>) {
201        self.origin_id = Cow::Owned(origin_id.into());
202    }
203
204    /// Allocates a new collection of bytes representing the response.
205    pub fn to_bytes(&self) -> Vec<u8> {
206        let mut bytes = vec![];
207
208        let has_header = !header_is_empty(&self.header);
209        if has_header {
210            rmp::encode::write_map_len(&mut bytes, 4).unwrap();
211        } else {
212            rmp::encode::write_map_len(&mut bytes, 3).unwrap();
213        }
214
215        if has_header {
216            rmp::encode::write_str(&mut bytes, "header").unwrap();
217            bytes.extend_from_slice(&self.header);
218        }
219
220        rmp::encode::write_str(&mut bytes, "id").unwrap();
221        rmp::encode::write_str(&mut bytes, &self.id).unwrap();
222
223        rmp::encode::write_str(&mut bytes, "origin_id").unwrap();
224        rmp::encode::write_str(&mut bytes, &self.origin_id).unwrap();
225
226        rmp::encode::write_str(&mut bytes, "payload").unwrap();
227        bytes.extend_from_slice(&self.payload);
228
229        bytes
230    }
231
232    /// Parses a collection of bytes, returning an untyped response if it can be potentially
233    /// represented as a [`Response`] depending on the payload.
234    ///
235    /// NOTE: This supports parsing an invalid response where the payload would not properly
236    /// deserialize, but the bytes themselves represent a complete response of some kind.
237    pub fn from_slice(input: &'a [u8]) -> Result<Self, UntypedResponseParseError> {
238        if input.is_empty() {
239            return Err(UntypedResponseParseError::WrongType);
240        }
241
242        let has_header = match rmp::Marker::from_u8(input[0]) {
243            rmp::Marker::FixMap(3) => false,
244            rmp::Marker::FixMap(4) => true,
245            _ => return Err(UntypedResponseParseError::WrongType),
246        };
247
248        // Advance position by marker
249        let input = &input[1..];
250
251        // Parse the header if we have one
252        let (header, input) = if has_header {
253            let (_, input) = read_key_eq(input, "header")
254                .map_err(|_| UntypedResponseParseError::InvalidHeaderKey)?;
255
256            let (header, input) =
257                read_header_bytes(input).map_err(|_| UntypedResponseParseError::InvalidHeader)?;
258            (header, input)
259        } else {
260            ([0u8; 0].as_slice(), input)
261        };
262
263        // Validate that next field is id
264        let (_, input) =
265            read_key_eq(input, "id").map_err(|_| UntypedResponseParseError::InvalidIdKey)?;
266
267        // Get the id itself
268        let (id, input) =
269            read_str_bytes(input).map_err(|_| UntypedResponseParseError::InvalidId)?;
270
271        // Validate that next field is origin_id
272        let (_, input) = read_key_eq(input, "origin_id")
273            .map_err(|_| UntypedResponseParseError::InvalidOriginIdKey)?;
274
275        // Get the origin_id itself
276        let (origin_id, input) =
277            read_str_bytes(input).map_err(|_| UntypedResponseParseError::InvalidOriginId)?;
278
279        // Validate that final field is payload
280        let (_, input) = read_key_eq(input, "payload")
281            .map_err(|_| UntypedResponseParseError::InvalidPayloadKey)?;
282
283        let header = Cow::Borrowed(header);
284        let id = Cow::Borrowed(id);
285        let origin_id = Cow::Borrowed(origin_id);
286        let payload = Cow::Borrowed(input);
287
288        Ok(Self {
289            header,
290            id,
291            origin_id,
292            payload,
293        })
294    }
295}
296
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;

    const TRUE_BYTE: u8 = 0xc3;
    const NEVER_USED_BYTE: u8 = 0xc1;

    // fixstr of 6 bytes with str "header"
    const HEADER_FIELD_BYTES: &[u8] = &[0xa6, b'h', b'e', b'a', b'd', b'e', b'r'];

    // fixmap of 2 objects with
    // 1. key fixstr "key" and value fixstr "value"
    // 2. key fixstr "num" and value fixint 123
    const HEADER_BYTES: &[u8] = &[
        0x82, // valid map with 2 pairs
        0xa3, b'k', b'e', b'y', // key: "key"
        0xa5, b'v', b'a', b'l', b'u', b'e', // value: "value"
        0xa3, b'n', b'u', b'm', // key: "num"
        0x7b, // value: 123
    ];

    // fixstr of 2 bytes with str "id"
    const ID_FIELD_BYTES: &[u8] = &[0xa2, b'i', b'd'];

    // fixstr of 9 bytes with str "origin_id"
    const ORIGIN_ID_FIELD_BYTES: &[u8] =
        &[0xa9, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x5f, 0x69, 0x64];

    // fixstr of 7 bytes with str "payload"
    const PAYLOAD_FIELD_BYTES: &[u8] = &[0xa7, b'p', b'a', b'y', b'l', b'o', b'a', b'd'];

    /// fixstr of 4 bytes with str "test"
    const TEST_STR_BYTES: &[u8] = &[0xa4, b't', b'e', b's', b't'];

    /// fixstr of 4 bytes whose content is not valid UTF-8
    const NON_UTF8_STR_BYTES: &[u8] = &[0xa4, 0, 159, 146, 150];

    /// Serializes a typed response with the given header and ids and a payload of `true`
    fn serialized_response(header: Header, id: &str, origin_id: &str) -> Vec<u8> {
        Response {
            header,
            id: id.to_string(),
            origin_id: origin_id.to_string(),
            payload: true,
        }
        .to_vec()
        .unwrap()
    }

    /// Concatenates `parts` and reports the error (if any) from parsing the result
    fn parse_failure(parts: &[&[u8]]) -> Option<UntypedResponseParseError> {
        UntypedResponse::from_slice(&parts.concat()).err()
    }

    #[test]
    fn untyped_response_should_support_converting_to_bytes() {
        let bytes = serialized_response(header!(), "some id", "some origin id");

        let untyped_response = UntypedResponse::from_slice(&bytes).unwrap();
        assert_eq!(untyped_response.to_bytes(), bytes);
    }

    #[test]
    fn untyped_response_should_support_converting_to_bytes_with_header() {
        let bytes = serialized_response(header!("key" -> 123), "some id", "some origin id");

        let untyped_response = UntypedResponse::from_slice(&bytes).unwrap();
        assert_eq!(untyped_response.to_bytes(), bytes);
    }

    #[test]
    fn untyped_response_should_support_parsing_from_response_bytes_with_header() {
        let bytes = serialized_response(header!("key" -> 123), "some id", "some origin id");

        let expected = UntypedResponse {
            header: Cow::Owned(utils::serialize_to_vec(&header!("key" -> 123)).unwrap()),
            id: Cow::Borrowed("some id"),
            origin_id: Cow::Borrowed("some origin id"),
            payload: Cow::Owned(vec![TRUE_BYTE]),
        };
        assert_eq!(UntypedResponse::from_slice(&bytes), Ok(expected));
    }

    #[test]
    fn untyped_response_should_support_parsing_from_response_bytes_with_valid_payload() {
        let bytes = serialized_response(header!(), "some id", "some origin id");

        let expected = UntypedResponse {
            header: Cow::Owned(vec![]),
            id: Cow::Borrowed("some id"),
            origin_id: Cow::Borrowed("some origin id"),
            payload: Cow::Owned(vec![TRUE_BYTE]),
        };
        assert_eq!(UntypedResponse::from_slice(&bytes), Ok(expected));
    }

    #[test]
    fn untyped_response_should_support_parsing_from_response_bytes_with_invalid_payload() {
        // Response whose id and origin id are both empty strings
        let mut bytes = serialized_response(header!(), "", "");

        // Append a byte that is never used in msgpack
        bytes.push(NEVER_USED_BYTE);

        // The payload is not validated, so the extra byte simply becomes part of it
        let expected = UntypedResponse {
            header: Cow::Owned(vec![]),
            id: Cow::Owned("".to_string()),
            origin_id: Cow::Owned("".to_string()),
            payload: Cow::Owned(vec![TRUE_BYTE, NEVER_USED_BYTE]),
        };
        assert_eq!(UntypedResponse::from_slice(&bytes), Ok(expected));
    }

    #[test]
    fn untyped_response_should_support_parsing_full_request() {
        let parts: &[&[u8]] = &[
            &[0x84],
            HEADER_FIELD_BYTES,
            HEADER_BYTES,
            ID_FIELD_BYTES,
            TEST_STR_BYTES,
            ORIGIN_ID_FIELD_BYTES,
            &[0xa2, b'o', b'g'],
            PAYLOAD_FIELD_BYTES,
            &[TRUE_BYTE],
        ];
        let input = parts.concat();

        // Convert into a typed response so the individual fields can be checked
        let untyped_response = UntypedResponse::from_slice(&input).unwrap();
        let response: Response<bool> = untyped_response.to_typed_response().unwrap();

        assert_eq!(response.header, header!("key" -> "value", "num" -> 123));
        assert_eq!(response.id, "test");
        assert_eq!(response.origin_id, "og");
        assert!(response.payload);
    }

    #[test]
    fn untyped_response_should_fail_to_parse_if_given_bytes_not_representing_a_response() {
        // Empty byte slice
        assert_eq!(
            parse_failure(&[]),
            Some(UntypedResponseParseError::WrongType)
        );

        // Wrong starting byte
        assert_eq!(
            parse_failure(&[&[0x00]]),
            Some(UntypedResponseParseError::WrongType)
        );

        // Wrong starting byte (fixmap of 0 fields)
        assert_eq!(
            parse_failure(&[&[0x80]]),
            Some(UntypedResponseParseError::WrongType)
        );

        // Invalid header key
        assert_eq!(
            parse_failure(&[
                &[0x84],
                &[0xa0], // header key would be defined here, set to empty str
                HEADER_BYTES,
                ID_FIELD_BYTES,
                TEST_STR_BYTES,
                ORIGIN_ID_FIELD_BYTES,
                TEST_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedResponseParseError::InvalidHeaderKey)
        );

        // Invalid header bytes
        assert_eq!(
            parse_failure(&[
                &[0x84],
                HEADER_FIELD_BYTES,
                &[0xa0], // header would be defined here, set to empty str
                ID_FIELD_BYTES,
                TEST_STR_BYTES,
                ORIGIN_ID_FIELD_BYTES,
                TEST_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedResponseParseError::InvalidHeader)
        );

        // Missing fields (corrupt data)
        assert_eq!(
            parse_failure(&[&[0x83]]),
            Some(UntypedResponseParseError::InvalidIdKey)
        );

        // Missing id key (the value itself is valid)
        assert_eq!(
            parse_failure(&[
                &[0x83],
                &[0xa0], // id key would be defined here, set to empty str
                TEST_STR_BYTES,
                ORIGIN_ID_FIELD_BYTES,
                TEST_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedResponseParseError::InvalidIdKey)
        );

        // Non-str id field value
        assert_eq!(
            parse_failure(&[
                &[0x83],
                ID_FIELD_BYTES,
                &[TRUE_BYTE], // id value set to boolean
                ORIGIN_ID_FIELD_BYTES,
                TEST_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedResponseParseError::InvalidId)
        );

        // Non-utf8 id field value
        assert_eq!(
            parse_failure(&[
                &[0x83],
                ID_FIELD_BYTES,
                NON_UTF8_STR_BYTES,
                ORIGIN_ID_FIELD_BYTES,
                TEST_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedResponseParseError::InvalidId)
        );

        // Missing origin_id key (the value itself is valid)
        assert_eq!(
            parse_failure(&[
                &[0x83],
                ID_FIELD_BYTES,
                TEST_STR_BYTES,
                &[0xa0], // origin_id key would be defined here, set to empty str
                TEST_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedResponseParseError::InvalidOriginIdKey)
        );

        // Non-str origin_id field value
        assert_eq!(
            parse_failure(&[
                &[0x83],
                ID_FIELD_BYTES,
                TEST_STR_BYTES,
                ORIGIN_ID_FIELD_BYTES,
                &[TRUE_BYTE], // origin_id value set to boolean
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedResponseParseError::InvalidOriginId)
        );

        // Non-utf8 origin_id field value
        assert_eq!(
            parse_failure(&[
                &[0x83],
                ID_FIELD_BYTES,
                TEST_STR_BYTES,
                ORIGIN_ID_FIELD_BYTES,
                NON_UTF8_STR_BYTES,
                PAYLOAD_FIELD_BYTES,
                &[TRUE_BYTE],
            ]),
            Some(UntypedResponseParseError::InvalidOriginId)
        );

        // Missing payload key (the value itself is valid)
        assert_eq!(
            parse_failure(&[
                &[0x83],
                ID_FIELD_BYTES,
                TEST_STR_BYTES,
                ORIGIN_ID_FIELD_BYTES,
                TEST_STR_BYTES,
                &[0xa0], // payload key would be defined here, set to empty str
                &[TRUE_BYTE],
            ]),
            Some(UntypedResponseParseError::InvalidPayloadKey)
        );
    }
}