1use bytes::{Buf, BufMut, BytesMut};
2use tokio::io;
3use tokio_util::codec::{Decoder, Encoder};
4
/// Protocol header installed by `Codec`'s `From`/`Default` constructors.
///
/// The `Encoder` impl writes it ahead of the first encoded frame, and the
/// `Decoder` impl rejects input that does not begin with exactly these bytes.
static PREFACE: &[u8] = b"ort.olix0r.net/load\r\n\r\n";
6
/// A codec wrapper that brackets an inner codec with a fixed protocol preface.
///
/// Encoding writes the preface into the output buffer before the first value
/// handed to `inner`. Decoding requires the preface at the start of the
/// input, strips it, and then delegates every frame to `inner`.
///
/// NOTE(review): a single `state` field is shared by the `Encoder` and
/// `Decoder` impls. Once this value has either written or consumed a preface,
/// it neither writes nor expects another one. Use separate instances for the
/// read and write halves if each direction must carry its own preface.
#[derive(Debug)]
pub struct Codec<C> {
    /// Bytes emitted/expected at the very start of the stream.
    preface: &'static [u8],
    /// Codec that handles all frames once the preface has been dealt with.
    inner: C,
    /// Whether the preface has been written or consumed yet.
    state: State,
}
13
/// Progress of a `Codec` through its preface handshake.
#[derive(Debug)]
enum State {
    /// No preface has been written or read yet.
    Init,
    /// The preface has been handled; subsequent encode/decode calls delegate
    /// straight to the inner codec.
    Prefaced,
}
19
20impl<C> From<C> for Codec<C> {
23 fn from(inner: C) -> Self {
24 Self {
25 inner,
26 preface: PREFACE,
27 state: State::Init,
28 }
29 }
30}
31
32impl<C: Default> Default for Codec<C> {
33 fn default() -> Self {
34 Self::from(C::default())
35 }
36}
37
38impl<D: Decoder> Decoder for Codec<D> {
39 type Item = D::Item;
40 type Error = D::Error;
41
42 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<D::Item>, D::Error> {
43 loop {
44 match self.state {
45 State::Prefaced => {
46 return self.inner.decode(src);
47 }
48 State::Init => {
49 if src.len() < self.preface.len() {
50 return Ok(None);
51 }
52 if &src[0..self.preface.len()] != self.preface {
53 return Err(D::Error::from(io::Error::new(
54 io::ErrorKind::InvalidData,
55 "Invalid protocol header",
56 )));
57 }
58 src.advance(self.preface.len());
59 self.state = State::Prefaced;
60 }
61 }
62 }
63 }
64}
65
66impl<T, E: Encoder<T>> Encoder<T> for Codec<E> {
67 type Error = E::Error;
68
69 fn encode(&mut self, value: T, dst: &mut BytesMut) -> Result<(), E::Error> {
70 loop {
71 match self.state {
72 State::Prefaced => {
73 return self.inner.encode(value, dst);
74 }
75 State::Init => {
76 dst.reserve(self.preface.len());
77 dst.put(self.preface);
78 self.state = State::Prefaced;
79 }
80 }
81 }
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tokio_util::codec::LengthDelimitedCodec;

    /// Frames encoded by one codec decode to the same payloads with another;
    /// the preface is written once, stripped once, and no input is left over.
    #[test]
    fn roundtrip() {
        let b0 = Bytes::from_static(b"abcde");
        let b1 = Bytes::from_static(b"fghij");

        let mut buf = BytesMut::with_capacity(100);

        let mut enc = Codec::from(LengthDelimitedCodec::default());
        enc.encode(b0.clone(), &mut buf).expect("must encode");
        enc.encode(b1.clone(), &mut buf).expect("must encode");
        assert!(buf.starts_with(PREFACE), "stream must begin with the preface");

        let mut dec = Codec::from(LengthDelimitedCodec::default());
        let d0 = dec
            .decode(&mut buf)
            .expect("must decode")
            .expect("must yield a frame");
        let d1 = dec
            .decode(&mut buf)
            .expect("must decode")
            .expect("must yield a frame");
        assert_eq!(d0.freeze(), b0);
        assert_eq!(d1.freeze(), b1);
        assert!(buf.is_empty(), "all input must be consumed");
    }

    /// A buffer holding only a valid prefix of the preface is not an error yet,
    /// and nothing is consumed while waiting for the rest.
    #[test]
    fn waits_for_complete_preface() {
        let partial = &PREFACE[..PREFACE.len() - 1];
        let mut buf = BytesMut::from(partial);
        let mut dec = Codec::from(LengthDelimitedCodec::default());
        let out = dec.decode(&mut buf).expect("partial preface is not an error");
        assert!(out.is_none());
        assert_eq!(&buf[..], partial, "partial preface must not be consumed");
    }

    /// A full-length header that does not match the preface is rejected.
    #[test]
    fn rejects_invalid_preface() {
        let mut buf = BytesMut::new();
        buf.resize(PREFACE.len(), b'x');
        let mut dec = Codec::from(LengthDelimitedCodec::default());
        let err = dec
            .decode(&mut buf)
            .expect_err("mismatched preface must fail");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}