1use bytes::{Buf, BufMut, Bytes};
2use thiserror::Error;
3use tokio_util::codec::{Decoder, Encoder};
4
/// Tag byte that opens every frame on the wire; the decoder rejects any frame whose
/// first byte differs.
const WIRE_ID: u8 = 2;
7
8#[derive(Debug, Error)]
9pub enum Error {
10 #[error("IO error: {0:?}")]
11 Io(#[from] std::io::Error),
12 #[error("Invalid wire ID: {0}")]
13 WireId(u8),
14 #[error("Failed to decompress message")]
15 Decompression,
16}
17
18#[derive(Debug, Clone)]
19pub struct Message {
20 header: Header,
22 payload: Bytes,
24}
25
26impl Message {
27 #[inline]
28 pub fn new(id: u32, compression_type: u8, payload: Bytes) -> Self {
29 Self { header: Header { id, compression_type, size: payload.len() as u32 }, payload }
30 }
31
32 #[inline]
33 pub fn id(&self) -> u32 {
34 self.header.id
35 }
36
37 #[inline]
38 pub fn payload_size(&self) -> u32 {
39 self.header.size
40 }
41
42 #[inline]
43 pub fn size(&self) -> usize {
44 self.header.len() + self.payload_size() as usize
45 }
46
47 #[inline]
48 pub fn header(&self) -> &Header {
49 &self.header
50 }
51
52 #[inline]
53 pub fn payload(&self) -> &Bytes {
54 &self.payload
55 }
56
57 #[inline]
58 pub fn into_payload(self) -> Bytes {
59 self.payload
60 }
61}
62
/// Fixed-size per-message header, encoded on the wire immediately after the wire-ID byte.
#[derive(Debug, Clone, Copy)]
pub struct Header {
    /// Compression scheme tag for the payload; the codec passes it through uninterpreted.
    pub(crate) compression_type: u8,
    /// Message identifier, written verbatim into the frame.
    pub(crate) id: u32,
    /// Payload length in bytes.
    pub(crate) size: u32,
}
72
73impl Header {
74 #[inline]
76 pub fn len(&self) -> usize {
77 4 + 4 + 1 }
81
82 #[inline]
83 pub fn is_empty(&self) -> bool {
84 self.len() == 0
85 }
86
87 #[inline]
88 pub fn compression_type(&self) -> u8 {
89 self.compression_type
90 }
91}
92
93#[derive(Default)]
94enum State {
95 #[default]
96 Header,
97 Payload(Header),
98}
99
100#[derive(Default)]
101pub struct Codec {
102 state: State,
104}
105
106impl Codec {
107 pub fn new() -> Self {
108 Self::default()
109 }
110}
111
112impl Decoder for Codec {
113 type Item = Message;
114 type Error = Error;
115
116 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
117 loop {
118 match self.state {
119 State::Header => {
120 let mut cursor = 0;
121
122 if src.is_empty() {
123 return Ok(None);
124 }
125
126 let wire_id = u8::from_be_bytes([src[cursor]]);
128 cursor += 1;
129 if wire_id != WIRE_ID {
130 return Err(Error::WireId(wire_id));
131 }
132
133 if src.len() < cursor + 1 {
135 return Ok(None);
136 }
137
138 let compression_type = u8::from_be_bytes([src[cursor]]);
139
140 cursor += 1;
141
142 if src.len() < cursor + 8 {
143 return Ok(None);
144 }
145
146 src.advance(cursor);
148
149 let header =
151 Header { compression_type, id: src.get_u32(), size: src.get_u32() };
152
153 self.state = State::Payload(header);
154 }
155 State::Payload(header) => {
156 if src.len() < header.size as usize {
157 return Ok(None);
158 }
159
160 let payload = src.split_to(header.size as usize);
161 let message = Message { header, payload: payload.freeze() };
162
163 self.state = State::Header;
164 return Ok(Some(message));
165 }
166 }
167 }
168 }
169}
170
171impl Encoder<Message> for Codec {
172 type Error = Error;
173
174 fn encode(&mut self, item: Message, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
175 dst.reserve(1 + item.header.len() + item.payload_size() as usize);
176
177 dst.put_u8(WIRE_ID);
178 dst.put_u8(item.header.compression_type);
179 dst.put_u32(item.header.id);
180 dst.put_u32(item.header.size);
181 dst.put(item.payload);
182
183 Ok(())
184 }
185}