distant_net/common/transport/framed/codec/
chain.rs

1use std::io;
2
3use super::{Codec, Frame};
4
/// A codec composed of two other codecs that are applied in sequence.
///
/// Encoding runs the `left` codec and then feeds its output into the `right` codec. Decoding
/// runs them in the opposite order: `right` first, then `left`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ChainCodec<T, U> {
    /// Codec applied first when encoding and last when decoding.
    left: T,
    /// Codec applied last when encoding and first when decoding.
    right: U,
}
13
14impl<T, U> ChainCodec<T, U> {
15    /// Chains two codecs together such that `left` will be invoked first during encoding and
16    /// `right` will be invoked first during decoding
17    pub fn new(left: T, right: U) -> Self {
18        Self { left, right }
19    }
20
21    /// Returns reference to left codec
22    pub fn as_left(&self) -> &T {
23        &self.left
24    }
25
26    /// Consumes the chain and returns the left codec
27    pub fn into_left(self) -> T {
28        self.left
29    }
30
31    /// Returns reference to right codec
32    pub fn as_right(&self) -> &U {
33        &self.right
34    }
35
36    /// Consumes the chain and returns the right codec
37    pub fn into_right(self) -> U {
38        self.right
39    }
40
41    /// Consumes the chain and returns the left and right codecs
42    pub fn into_left_right(self) -> (T, U) {
43        (self.left, self.right)
44    }
45}
46
47impl<T, U> Codec for ChainCodec<T, U>
48where
49    T: Codec + Clone,
50    U: Codec + Clone,
51{
52    fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
53        Codec::encode(&mut self.left, frame).and_then(|frame| Codec::encode(&mut self.right, frame))
54    }
55
56    fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
57        Codec::decode(&mut self.right, frame).and_then(|frame| Codec::decode(&mut self.left, frame))
58    }
59}
60
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;

    /// Codec that appends `msg` when encoding and strips it as a suffix when decoding.
    /// Decoding fails with `ErrorKind::InvalidData` if the frame does not end with `msg`.
    #[derive(Copy, Clone)]
    struct TestCodec<'a> {
        msg: &'a str,
    }

    impl<'a> TestCodec<'a> {
        pub fn new(msg: &'a str) -> Self {
            Self { msg }
        }
    }

    impl Codec for TestCodec<'_> {
        fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
            let mut item = frame.into_item().to_vec();
            item.extend_from_slice(self.msg.as_bytes());
            Ok(Frame::from(item))
        }

        fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
            let item = frame.into_item().to_vec();
            let frame = Frame::new(item.strip_suffix(self.msg.as_bytes()).ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "Decode failed because did not end with suffix: {}",
                        self.msg
                    ),
                )
            })?);
            Ok(frame.into_owned())
        }
    }

    /// Codec that always fails.
    ///
    /// It reports `ErrorKind::Other`, a kind `TestCodec` never produces, so a test can tell
    /// that the failure came from this codec and not from a `TestCodec` elsewhere in the chain.
    #[derive(Copy, Clone)]
    struct ErrCodec;

    impl Codec for ErrCodec {
        fn encode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
            Err(io::Error::from(io::ErrorKind::Other))
        }

        fn decode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
            Err(io::Error::from(io::ErrorKind::Other))
        }
    }

    #[test]
    fn encode_should_invoke_left_codec_followed_by_right_codec() {
        let mut codec = ChainCodec::new(TestCodec::new("hello"), TestCodec::new("world"));
        let frame = codec.encode(Frame::new(b"some bytes")).unwrap();
        assert_eq!(frame, b"some byteshelloworld");
    }

    #[test]
    fn encode_should_fail_if_left_codec_fails_to_encode() {
        let mut codec = ChainCodec::new(ErrCodec, TestCodec::new("world"));
        assert_eq!(
            codec.encode(Frame::new(b"some bytes")).unwrap_err().kind(),
            io::ErrorKind::Other
        );
    }

    #[test]
    fn encode_should_fail_if_right_codec_fails_to_encode() {
        let mut codec = ChainCodec::new(TestCodec::new("hello"), ErrCodec);
        assert_eq!(
            codec.encode(Frame::new(b"some bytes")).unwrap_err().kind(),
            io::ErrorKind::Other
        );
    }

    #[test]
    fn decode_should_invoke_right_codec_followed_by_left_codec() {
        let mut codec = ChainCodec::new(TestCodec::new("hello"), TestCodec::new("world"));
        let frame = codec.decode(Frame::new(b"some byteshelloworld")).unwrap();
        assert_eq!(frame, b"some bytes");
    }

    #[test]
    fn decode_should_fail_if_left_codec_fails_to_decode() {
        // The right codec decodes first, so the input must carry its suffix. Otherwise the
        // right codec fails before the left codec is ever reached.
        let mut codec = ChainCodec::new(ErrCodec, TestCodec::new("world"));
        assert_eq!(
            codec
                .decode(Frame::new(b"some bytesworld"))
                .unwrap_err()
                .kind(),
            io::ErrorKind::Other
        );
    }

    #[test]
    fn decode_should_fail_if_right_codec_fails_to_decode() {
        let mut codec = ChainCodec::new(TestCodec::new("hello"), ErrCodec);
        assert_eq!(
            codec.decode(Frame::new(b"some bytes")).unwrap_err().kind(),
            io::ErrorKind::Other
        );
    }

    #[test]
    fn decode_should_reverse_encode() {
        let mut codec = ChainCodec::new(TestCodec::new("hello"), TestCodec::new("world"));
        let encoded = codec.encode(Frame::new(b"some bytes")).unwrap();
        let decoded = codec.decode(encoded).unwrap();
        assert_eq!(decoded, b"some bytes");
    }
}