Skip to main content

msg_wire/
auth.rs

1use bytes::{Buf, BufMut, Bytes};
2use thiserror::Error;
3use tokio_util::codec::{Decoder, Encoder};
4
/// Identifier byte written first in every auth-codec frame; the decoder
/// rejects any frame whose first byte differs.
const WIRE_ID: u8 = 0x01;
7
8#[derive(Debug, Error)]
9pub enum Error {
10    #[error("IO error: {0:?}")]
11    Io(#[from] std::io::Error),
12    #[error("Invalid wire ID: {0}")]
13    WireId(u8),
14    #[error("Rejected")]
15    Rejected,
16}
17
18/// Authentication codec.
19pub struct Codec {
20    state: State,
21}
22
23impl Codec {
24    /// Creates a new authentication codec for a client. This will put the
25    /// codec in the `Ack` state since it will be waiting for an ack.
26    pub fn new_client() -> Self {
27        Self { state: State::Ack }
28    }
29
30    /// Creates a new authentication codec for a server. This will put the
31    /// codec in the `AuthReceive` state since it will be waiting for the
32    /// client to send its ID.
33    pub fn new_server() -> Self {
34        Self { state: State::AuthReceive }
35    }
36}
37
/// Which handshake frame the decoder expects to see next.
#[derive(Clone, Debug)]
enum State {
    /// Server side: the next frame should be the client's auth (ID) frame.
    AuthReceive,
    /// Client side: the next frame should be the server's ACK/reject frame.
    Ack,
}
45
46#[derive(Debug, Clone)]
47pub enum Message {
48    /// The client sends the ID to the server
49    Auth(Bytes),
50    /// The server responds with an ACK
51    Ack,
52    /// We reject the client
53    Reject,
54}
55
56impl Decoder for Codec {
57    type Item = Message;
58    type Error = Error;
59
60    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
61        match self.state {
62            // We are the server, waiting for the client to send its auth message
63            State::AuthReceive => {
64                // We need at least 5 bytes to read the wire ID and the auth ID
65                if src.is_empty() {
66                    return Ok(None);
67                }
68
69                // Wire ID check (without advancing the cursor)
70                let wire_id = u8::from_be_bytes([src[0]]);
71                if wire_id != WIRE_ID {
72                    return Err(Error::WireId(wire_id));
73                }
74
75                if src.len() < 4 {
76                    return Ok(None);
77                }
78
79                let id_size = u32::from_be_bytes([src[1], src[2], src[3], src[4]]);
80                if src.len() < id_size as usize {
81                    return Ok(None);
82                }
83
84                src.advance(1);
85                src.advance(4);
86
87                let id = src.split_to(id_size as usize).freeze();
88                self.state = State::Ack;
89                Ok(Some(Message::Auth(id)))
90            }
91            // We are the client, and we are waiting for the server to send an ACK
92            State::Ack => {
93                if src.len() < 2 {
94                    return Ok(None);
95                }
96
97                // Wire ID check (without advancing the cursor)
98                let wire_id = u8::from_be_bytes([src[0]]);
99                if wire_id != WIRE_ID {
100                    return Err(Error::WireId(wire_id));
101                }
102
103                src.advance(1);
104
105                let ack = src.get_u8();
106
107                if ack == 0 {
108                    return Err(Error::Rejected);
109                }
110
111                Ok(Some(Message::Ack))
112            }
113        }
114    }
115}
116
117impl Encoder<Message> for Codec {
118    type Error = std::io::Error;
119
120    fn encode(&mut self, item: Message, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
121        match item {
122            // We are the client, and we are sending the ID to the server
123            Message::Auth(id) => {
124                self.state = State::Ack;
125                dst.reserve(1 + 4 + id.len());
126                dst.put_u8(WIRE_ID);
127                dst.put_u32(id.len() as u32);
128                dst.put(id);
129            }
130            // We are the server, and we are sending an ACK to the client
131            Message::Ack => {
132                dst.reserve(1 + 1);
133                dst.put_u8(WIRE_ID);
134                dst.put_u8(1);
135            }
136            Message::Reject => {
137                dst.reserve(1 + 1);
138                dst.put_u8(WIRE_ID);
139                dst.put_u8(0);
140            }
141        }
142
143        Ok(())
144    }
145}