1use bytes::{Buf, BufMut, Bytes};
2use thiserror::Error;
3use tokio_util::codec::{Decoder, Encoder};
4
/// Identifier byte written as the first byte of every frame, and checked by the decoder.
const WIRE_ID: u8 = 1;
7
8#[derive(Debug, Error)]
9pub enum Error {
10 #[error("IO error: {0:?}")]
11 Io(#[from] std::io::Error),
12 #[error("Invalid wire ID: {0}")]
13 WireId(u8),
14 #[error("Rejected")]
15 Rejected,
16}
17
18pub struct Codec {
20 state: State,
21}
22
23impl Codec {
24 pub fn new_client() -> Self {
27 Self { state: State::Ack }
28 }
29
30 pub fn new_server() -> Self {
34 Self { state: State::AuthReceive }
35 }
36}
37
/// Decoder state: which frame layout [`Codec`] expects to read next.
#[derive(Debug, Clone)]
enum State {
    /// Expect `[WIRE_ID][u32 big-endian length][id bytes]`.
    AuthReceive,
    /// Expect `[WIRE_ID][ack byte]`.
    Ack,
}
45
46#[derive(Debug, Clone)]
47pub enum Message {
48 Auth(Bytes),
50 Ack,
52 Reject,
54}
55
56impl Decoder for Codec {
57 type Item = Message;
58 type Error = Error;
59
60 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
61 match self.state {
62 State::AuthReceive => {
64 if src.is_empty() {
66 return Ok(None);
67 }
68
69 let wire_id = u8::from_be_bytes([src[0]]);
71 if wire_id != WIRE_ID {
72 return Err(Error::WireId(wire_id));
73 }
74
75 if src.len() < 4 {
76 return Ok(None);
77 }
78
79 let id_size = u32::from_be_bytes([src[1], src[2], src[3], src[4]]);
80 if src.len() < id_size as usize {
81 return Ok(None);
82 }
83
84 src.advance(1);
85 src.advance(4);
86
87 let id = src.split_to(id_size as usize).freeze();
88 self.state = State::Ack;
89 Ok(Some(Message::Auth(id)))
90 }
91 State::Ack => {
93 if src.len() < 2 {
94 return Ok(None);
95 }
96
97 let wire_id = u8::from_be_bytes([src[0]]);
99 if wire_id != WIRE_ID {
100 return Err(Error::WireId(wire_id));
101 }
102
103 src.advance(1);
104
105 let ack = src.get_u8();
106
107 if ack == 0 {
108 return Err(Error::Rejected);
109 }
110
111 Ok(Some(Message::Ack))
112 }
113 }
114 }
115}
116
117impl Encoder<Message> for Codec {
118 type Error = std::io::Error;
119
120 fn encode(&mut self, item: Message, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
121 match item {
122 Message::Auth(id) => {
124 self.state = State::Ack;
125 dst.reserve(1 + 4 + id.len());
126 dst.put_u8(WIRE_ID);
127 dst.put_u32(id.len() as u32);
128 dst.put(id);
129 }
130 Message::Ack => {
132 dst.reserve(1 + 1);
133 dst.put_u8(WIRE_ID);
134 dst.put_u8(1);
135 }
136 Message::Reject => {
137 dst.reserve(1 + 1);
138 dst.put_u8(WIRE_ID);
139 dst.put_u8(0);
140 }
141 }
142
143 Ok(())
144 }
145}