ort_tcp/
preface.rs

1use bytes::{Buf, BufMut, BytesMut};
2use tokio::io;
3use tokio_util::codec::{Decoder, Encoder};
4
/// Fixed protocol header used by [`Codec`]: written before the first value a
/// fresh codec encodes, and required (then stripped) at the start of the input
/// a fresh codec decodes.
static PREFACE: &[u8] = b"ort.olix0r.net/load\r\n\r\n";
6
/// Wraps an inner codec so that a fixed protocol preface precedes its frames.
///
/// Encoding writes the preface before the first encoded value. Decoding
/// requires the input to begin with the preface, strips it, and then hands the
/// remaining bytes to the inner decoder.
///
/// NOTE(review): the `Encoder` and `Decoder` impls share the single `state`
/// field. Once either direction has prefaced, the other direction will neither
/// send nor expect a preface. Confirm this one-preface-per-instance behavior is
/// what callers rely on when one instance serves both directions (e.g. via
/// `Framed`).
#[derive(Debug)]
pub struct Codec<C> {
    /// Header bytes to emit/expect; the `From` impl sets this to `PREFACE`.
    preface: &'static [u8],
    /// Codec that handles everything after the preface.
    inner: C,
    /// Whether the preface has been written or consumed yet.
    state: State,
}
13
/// Tracks whether a [`Codec`]'s preface has been handled yet.
///
/// A plain fieldless tag, so it derives the common value traits (`Clone`,
/// `Copy`, `PartialEq`, `Eq`) alongside `Debug`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    /// The preface has not yet been written (encode) or consumed (decode).
    Init,
    /// The preface has been written or consumed; frames go straight to the
    /// inner codec.
    Prefaced,
}
19
20// === impl Codec ===
21
22impl<C> From<C> for Codec<C> {
23    fn from(inner: C) -> Self {
24        Self {
25            inner,
26            preface: PREFACE,
27            state: State::Init,
28        }
29    }
30}
31
32impl<C: Default> Default for Codec<C> {
33    fn default() -> Self {
34        Self::from(C::default())
35    }
36}
37
38impl<D: Decoder> Decoder for Codec<D> {
39    type Item = D::Item;
40    type Error = D::Error;
41
42    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<D::Item>, D::Error> {
43        loop {
44            match self.state {
45                State::Prefaced => {
46                    return self.inner.decode(src);
47                }
48                State::Init => {
49                    if src.len() < self.preface.len() {
50                        return Ok(None);
51                    }
52                    if &src[0..self.preface.len()] != self.preface {
53                        return Err(D::Error::from(io::Error::new(
54                            io::ErrorKind::InvalidData,
55                            "Invalid protocol header",
56                        )));
57                    }
58                    src.advance(self.preface.len());
59                    self.state = State::Prefaced;
60                }
61            }
62        }
63    }
64}
65
66impl<T, E: Encoder<T>> Encoder<T> for Codec<E> {
67    type Error = E::Error;
68
69    fn encode(&mut self, value: T, dst: &mut BytesMut) -> Result<(), E::Error> {
70        loop {
71            match self.state {
72                State::Prefaced => {
73                    return self.inner.encode(value, dst);
74                }
75                State::Init => {
76                    dst.reserve(self.preface.len());
77                    dst.put(self.preface);
78                    self.state = State::Prefaced;
79                }
80            }
81        }
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tokio_util::codec::LengthDelimitedCodec;

    /// Frames encoded through one codec decode back, in order, through a fresh
    /// one. The preface leads the stream and is fully consumed.
    ///
    /// Nothing here is async, so a plain `#[test]` suffices; no runtime needed.
    #[test]
    fn roundtrip() {
        let b0 = Bytes::from_static(b"abcde");
        let b1 = Bytes::from_static(b"fghij");

        let mut buf = BytesMut::with_capacity(100);

        let mut enc = Codec::from(LengthDelimitedCodec::default());
        enc.encode(b0.clone(), &mut buf).expect("must encode");
        enc.encode(b1.clone(), &mut buf).expect("must encode");
        assert!(buf.starts_with(PREFACE), "stream must start with the preface");

        let mut dec = Codec::from(LengthDelimitedCodec::default());
        let d0 = dec
            .decode(&mut buf)
            .expect("must decode")
            .expect("must decode");
        let d1 = dec
            .decode(&mut buf)
            .expect("must decode")
            .expect("must decode");
        assert_eq!(d0.freeze(), b0);
        assert_eq!(d1.freeze(), b1);
        assert!(buf.is_empty(), "all input must be consumed");
    }

    /// A valid but incomplete preface yields `Ok(None)` and consumes nothing.
    #[test]
    fn incomplete_preface_waits_for_more_data() {
        let mut dec = Codec::from(LengthDelimitedCodec::default());
        let partial = &PREFACE[..PREFACE.len() - 1];
        let mut buf = BytesMut::from(partial);

        let out = dec.decode(&mut buf).expect("partial preface must not fail");
        assert!(out.is_none());
        assert_eq!(&buf[..], partial, "partial preface must not be consumed");
    }

    /// Input that does not begin with the preface is rejected as `InvalidData`.
    #[test]
    fn invalid_preface_is_rejected() {
        let mut dec = Codec::from(LengthDelimitedCodec::default());
        let mut buf = BytesMut::from(&b"GET / HTTP/1.1\r\nHost: example\r\n\r\n"[..]);

        let err = dec.decode(&mut buf).expect_err("bad header must be rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}
114}