distant_net/common/transport/framed/codec/predicate.rs

1use std::io;
2use std::sync::Arc;
3
4use super::{Codec, Frame};
5
/// Codec that routes each frame to one of two inner codecs.
///
/// Every frame handed to `encode` or `decode` is first shown to the predicate: a `true` answer
/// selects the `left` codec, a `false` answer selects the `right` codec.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct PredicateCodec<T, U, P> {
    /// Codec used for frames where the predicate returns `true`
    left: T,
    /// Codec used for frames where the predicate returns `false`
    right: U,
    /// Routing predicate, kept behind an `Arc` so clones of this codec share it
    predicate: Arc<P>,
}
13
14impl<T, U, P> PredicateCodec<T, U, P> {
15    /// Creates a new predicate codec where the left codec is invoked if the predicate returns true
16    /// and the right codec is invoked if the predicate returns false
17    pub fn new(left: T, right: U, predicate: P) -> Self {
18        Self {
19            left,
20            right,
21            predicate: Arc::new(predicate),
22        }
23    }
24
25    /// Returns reference to left codec
26    pub fn as_left(&self) -> &T {
27        &self.left
28    }
29
30    /// Consumes the chain and returns the left codec
31    pub fn into_left(self) -> T {
32        self.left
33    }
34
35    /// Returns reference to right codec
36    pub fn as_right(&self) -> &U {
37        &self.right
38    }
39
40    /// Consumes the chain and returns the right codec
41    pub fn into_right(self) -> U {
42        self.right
43    }
44
45    /// Consumes the chain and returns the left and right codecs
46    pub fn into_left_right(self) -> (T, U) {
47        (self.left, self.right)
48    }
49}
50
51impl<T, U, P> Clone for PredicateCodec<T, U, P>
52where
53    T: Clone,
54    U: Clone,
55{
56    fn clone(&self) -> Self {
57        Self {
58            left: self.left.clone(),
59            right: self.right.clone(),
60            predicate: Arc::clone(&self.predicate),
61        }
62    }
63}
64
65impl<T, U, P> Codec for PredicateCodec<T, U, P>
66where
67    T: Codec + Clone,
68    U: Codec + Clone,
69    P: Fn(&Frame) -> bool,
70{
71    fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
72        if (self.predicate)(&frame) {
73            Codec::encode(&mut self.left, frame)
74        } else {
75            Codec::encode(&mut self.right, frame)
76        }
77    }
78
79    fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
80        if (self.predicate)(&frame) {
81            Codec::decode(&mut self.left, frame)
82        } else {
83            Codec::decode(&mut self.right, frame)
84        }
85    }
86}
87
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;

    /// Test codec that appends `msg` when encoding and strips it as a suffix when decoding,
    /// failing with `InvalidData` if the suffix is missing
    #[derive(Copy, Clone)]
    struct TestCodec<'a> {
        msg: &'a str,
    }

    impl<'a> TestCodec<'a> {
        pub fn new(msg: &'a str) -> Self {
            Self { msg }
        }
    }

    impl Codec for TestCodec<'_> {
        fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
            let mut item = frame.into_item().to_vec();
            item.extend_from_slice(self.msg.as_bytes());
            Ok(Frame::from(item))
        }

        fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
            let item = frame.into_item().to_vec();
            let frame = Frame::new(item.strip_suffix(self.msg.as_bytes()).ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "Decode failed because did not end with suffix: {}",
                        self.msg
                    ),
                )
            })?);
            Ok(frame.into_owned())
        }
    }

    /// Test codec that fails every call with `InvalidData`; placed on the side that must NOT be
    /// selected to prove it is never invoked, or on the selected side to check error propagation
    #[derive(Copy, Clone)]
    struct ErrCodec;

    impl Codec for ErrCodec {
        fn encode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
            Err(io::Error::from(io::ErrorKind::InvalidData))
        }

        fn decode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
            Err(io::Error::from(io::ErrorKind::InvalidData))
        }
    }

    #[test]
    fn encode_should_invoke_left_codec_if_predicate_returns_true() {
        let mut codec = PredicateCodec::new(
            TestCodec::new("hello"),
            TestCodec::new("world"),
            |_: &Frame| true,
        );
        let frame = codec.encode(Frame::new(b"some bytes")).unwrap();
        assert_eq!(frame, b"some byteshello");
    }

    #[test]
    fn encode_should_invoke_right_codec_if_predicate_returns_false() {
        let mut codec = PredicateCodec::new(
            TestCodec::new("hello"),
            TestCodec::new("world"),
            |_: &Frame| false,
        );
        let frame = codec.encode(Frame::new(b"some bytes")).unwrap();
        assert_eq!(frame, b"some bytesworld");
    }

    #[test]
    fn decode_should_invoke_left_codec_if_predicate_returns_true() {
        let mut codec = PredicateCodec::new(
            TestCodec::new("hello"),
            TestCodec::new("world"),
            |_: &Frame| true,
        );
        let frame = codec.decode(Frame::new(b"some byteshello")).unwrap();
        assert_eq!(frame, b"some bytes");
    }

    #[test]
    fn decode_should_invoke_right_codec_if_predicate_returns_false() {
        let mut codec = PredicateCodec::new(
            TestCodec::new("hello"),
            TestCodec::new("world"),
            |_: &Frame| false,
        );
        let frame = codec.decode(Frame::new(b"some bytesworld")).unwrap();
        assert_eq!(frame, b"some bytes");
    }

    #[test]
    fn encode_should_not_invoke_right_codec_if_predicate_returns_true() {
        let mut codec = PredicateCodec::new(TestCodec::new("hello"), ErrCodec, |_: &Frame| true);
        let frame = codec.encode(Frame::new(b"some bytes")).unwrap();
        assert_eq!(frame, b"some byteshello");
    }

    #[test]
    fn encode_should_not_invoke_left_codec_if_predicate_returns_false() {
        let mut codec = PredicateCodec::new(ErrCodec, TestCodec::new("world"), |_: &Frame| false);
        let frame = codec.encode(Frame::new(b"some bytes")).unwrap();
        assert_eq!(frame, b"some bytesworld");
    }

    #[test]
    fn decode_should_not_invoke_right_codec_if_predicate_returns_true() {
        let mut codec = PredicateCodec::new(TestCodec::new("hello"), ErrCodec, |_: &Frame| true);
        let frame = codec.decode(Frame::new(b"some byteshello")).unwrap();
        assert_eq!(frame, b"some bytes");
    }

    #[test]
    fn decode_should_not_invoke_left_codec_if_predicate_returns_false() {
        let mut codec = PredicateCodec::new(ErrCodec, TestCodec::new("world"), |_: &Frame| false);
        let frame = codec.decode(Frame::new(b"some bytesworld")).unwrap();
        assert_eq!(frame, b"some bytes");
    }

    #[test]
    fn encode_should_fail_if_selected_codec_fails() {
        let mut codec = PredicateCodec::new(ErrCodec, TestCodec::new("world"), |_: &Frame| true);
        let err = codec.encode(Frame::new(b"some bytes")).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn decode_should_fail_if_selected_codec_fails() {
        let mut codec = PredicateCodec::new(TestCodec::new("hello"), ErrCodec, |_: &Frame| false);
        let err = codec.decode(Frame::new(b"some bytesworld")).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}