ratrodlib/
protocol.rs

1//! Protocol message types and serialization.
2//!
3//! This module contains the types and serialization methods for the protocol messages.
4
5use std::fmt::{Display, Formatter};
6
7use anyhow::anyhow;
8use bincode::{BorrowDecode, Decode, Encode};
9use bytes::Bytes;
10use ouroboros::self_referencing;
11use serde::{Deserialize, Serialize};
12
13use crate::base::{Constant, Res, Void};
14
15// Wire types.
16
/// A helper type for a challenge: `Constant::CHALLENGE_SIZE` random bytes
/// that the peer is asked to sign.
///
/// NOTE(review): the wire structs carry challenges as `&[u8]`, so this length
/// is not enforced by the type when decoding — presumably checked by callers; confirm.
pub type Challenge = [u8; Constant::CHALLENGE_SIZE];

/// A helper type for a signature: `Constant::SIGNATURE_SIZE` bytes.
///
/// Presumably an Ed25519 signature, since the identity keys are documented as
/// Ed25519 — confirm. As with [`Challenge`], the wire structs carry signatures
/// as `&[u8]`, so this length is not enforced by the type.
pub type Signature = [u8; Constant::SIGNATURE_SIZE];

/// A helper type for an ephemeral public key used in the key exchange:
/// `Constant::EXCHANGE_PUBLIC_KEY_SIZE` bytes.
pub type ExchangePublicKey = [u8; Constant::EXCHANGE_PUBLIC_KEY_SIZE];
25
/// Serves as the preamble for the connection; sent by the client to open the
/// handshake.
///
/// Fields are encoded in declaration order, so their order is part of the wire
/// format. Byte fields are borrowed slices, so their lengths are not enforced
/// by the type.
#[derive(Clone, Debug, PartialEq, Eq, Encode, BorrowDecode)]
pub struct ClientPreamble<'a> {
    /// The client's ephemeral public key for the key exchange
    /// (presumably [`ExchangePublicKey`]-sized — confirm).
    pub exchange_public_key: &'a [u8],
    /// The remote the client asks to reach (format not enforced here —
    /// presumably `host:port`; confirm).
    pub remote: &'a str,
    /// A random challenge for the server to sign
    /// (presumably [`Challenge`]-sized — confirm).
    pub challenge: &'a [u8],
    /// Presumably: whether the client requests an encrypted session — confirm.
    pub should_encrypt: bool,
    /// Presumably: whether the tunneled traffic is UDP, relayed as
    /// [`ProtocolMessage::UdpData`] — confirm.
    pub is_udp: bool,
}
35
/// Serves as the server's response to the [`ClientPreamble`], containing its
/// identity public key, its signature of the client's challenge, and a
/// challenge of its own.
///
/// The server signs the client's challenge to prove its identity.
///
/// Fields are encoded in declaration order, so their order is part of the wire
/// format. Byte fields are borrowed slices, so their lengths are not enforced
/// by the type.
#[derive(Clone, Debug, PartialEq, Eq, Encode, BorrowDecode)]
pub struct ServerPreamble<'a> {
    /// The server's identity public key (base64 encoded Ed25519 key)
    pub identity_public_key: &'a str,
    /// The server's ephemeral public key for the key exchange
    pub exchange_public_key: &'a [u8],
    /// The server's signature of the client's challenge
    pub signature: &'a [u8],
    /// A random challenge for the client to sign
    pub challenge: &'a [u8],
}
50
/// Serves as the client's response to the server's challenge
/// (see [`ServerPreamble::challenge`]).
///
/// Fields are encoded in declaration order, so their order is part of the wire
/// format.
#[derive(Clone, Debug, PartialEq, Eq, Encode, BorrowDecode)]
pub struct ClientAuthentication<'a> {
    /// The client's identity public key (presumably encoded like the server's,
    /// i.e. a base64 Ed25519 key — confirm).
    pub identity_public_key: &'a str,
    /// The client's signature of the server's challenge.
    pub signature: &'a [u8],
}
57
58// Message types.
59
/// The envelope for every message exchanged over the network.
///
/// This is the main message type for the protocol: handshake messages, payload
/// data, errors and shutdown notices are all sent and received as one of these
/// variants. Payload variants borrow their bytes (`&'a [u8]`) rather than owning
/// them; on receipt they borrow from the buffer held by [`ProtocolMessageGuard`].
///
/// Variant order is part of the wire format (bincode encodes the variant index),
/// so new variants should be appended, not inserted.
#[derive(Debug, PartialEq, Eq, Encode, BorrowDecode)]
pub enum ProtocolMessage<'a> {
    /// Handshake opening from the client; see [`ClientPreamble`].
    ClientPreamble(ClientPreamble<'a>),
    /// The server's response to the preamble; see [`ServerPreamble`].
    ServerPreamble(ServerPreamble<'a>),
    /// The client's answer to the server's challenge; see [`ClientAuthentication`].
    ClientAuthentication(ClientAuthentication<'a>),
    /// Handshake-completion marker; carries no payload.
    HandshakeCompletion,
    /// A chunk of payload bytes (presumably from the TCP byte stream — confirm).
    Data(&'a [u8]),
    /// A chunk of UDP payload bytes (presumably one datagram — confirm).
    UdpData(&'a [u8]),
    /// An error report; see [`ProtocolMessage::fail_if_error`] and
    /// [`ProtocolError::send_and_bail`].
    Error(ProtocolError),
    /// Shutdown notice; carries no payload.
    Shutdown,
}
75
76impl ProtocolMessage<'_> {
77    /// Checks if the message is an error.
78    ///
79    /// If it is, returns the message wrapped in an error.
80    pub fn fail_if_error(&self) -> Res<&Self> {
81        if let ProtocolMessage::Error(error) = self {
82            return Err(anyhow!(error.clone()));
83        }
84
85        Ok(self)
86    }
87}
88
89// Message error types.
90
/// An error that one side of the connection reports to the other.
///
/// On the wire a `ProtocolError` is only ever carried inside a
/// [`ProtocolMessage::Error`] message; it should never be sent or received on
/// its own. Note that [`BincodeSend::push`] accepts any `T: Encode`, so the type
/// system does not currently enforce this — callers must wrap the error
/// themselves (as [`ProtocolError::send_and_bail`] does).
///
/// Variant order is part of the wire format (bincode encodes the variant index),
/// so new variants should be appended, not inserted.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Encode, Decode)]
pub enum ProtocolError {
    /// An invalid host; displayed as `Invalid host: <detail>`.
    InvalidHost(String),
    /// An invalid key; displayed as `Invalid key: <detail>`.
    InvalidKey(String),
    /// A remote failure; displayed as `Remote failed: <detail>`.
    RemoteFailed(String),
    /// Any other error; displayed as `Unknown: <detail>`.
    Unknown(String),
}
106
107impl Display for ProtocolError {
108    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
109        match self {
110            ProtocolError::InvalidHost(host) => write!(f, "Invalid host: {}", host),
111            ProtocolError::InvalidKey(key) => write!(f, "Invalid key: {}", key),
112            ProtocolError::RemoteFailed(message) => write!(f, "Remote failed: {}", message),
113            ProtocolError::Unknown(message) => write!(f, "Unknown: {}", message),
114        }
115    }
116}
117
118impl ProtocolError {
119    /// Sends the error message and shuts down the stream.
120    ///
121    /// The generic parameter R represents the return type expected by the calling function.
122    /// This method always returns an error, but needs to have the expected return type for the context.
123    pub async fn send_and_bail<T, R>(self, stream: &mut T) -> Res<R>
124    where
125        T: BincodeSend,
126    {
127        let error_message = self.to_string();
128
129        let _ = stream.push(ProtocolMessage::Error(self)).await;
130        let _ = stream.close().await;
131
132        Err(anyhow!(error_message))
133    }
134}
135
/// A decoded [`ProtocolMessage`] bundled with the buffer it borrows from.
///
/// [`ProtocolMessage`] borrows its payloads (`&[u8]`, `&str`) instead of copying
/// them, so it cannot outlive the bytes it was decoded from. This guard uses
/// `ouroboros` to store both together self-referentially: the guard owns the
/// buffer, and the message borrows from it. Use [`ProtocolMessageGuard::message`]
/// to access the message.
///
/// When the guard goes out of scope, the buffer is dropped with it. The buffer
/// is a [`Bytes`]; if it was split/frozen from a reusable `BytesMut` (presumably
/// what the reader does — confirm), that `BytesMut` _may_ reclaim the memory once
/// no other handle refers to it. This reduces buffer allocations and data copies.
#[self_referencing(pub_extras)]
pub struct ProtocolMessageGuard {
    // The owned bytes that `inner` borrows from.
    pub buffer: Bytes,
    // The message decoded from `buffer`. Marking it covariant lets ouroboros
    // generate the `borrow_inner` accessor used by `message()`.
    #[borrows(buffer)]
    #[covariant]
    pub inner: ProtocolMessage<'this>,
}
154
155impl ProtocolMessageGuard {
156    /// The inner message of this guard.
157    pub fn message(&self) -> &ProtocolMessage<'_> {
158        self.borrow_inner()
159    }
160}
161
162// Bincode stream impls.
163
/// A trait for sending protocol messages over a stream.
///
/// The intent of this trait is that everything sent through the stream is a
/// [`ProtocolMessage`], so that all traffic follows the protocol format and is
/// encrypted by the implementation when necessary.
///
/// NOTE: [`BincodeSend::push`] currently accepts any `T: Encode`, so the type
/// system does not enforce the `ProtocolMessage`-only restriction; callers must
/// wrap payloads in a [`ProtocolMessage`] themselves.
pub trait BincodeSend: Unpin + Sized {
    /// Pushes a message to the stream.
    ///
    /// Right now, this only requires `T: Encode`, but in the future, it may
    /// require a concrete type, such as `ProtocolMessage`.
    fn push<T>(&mut self, message: T) -> impl Future<Output = Void>
    where
        T: Encode;

    /// Closes the stream via `shutdown`.
    fn close(&mut self) -> impl Future<Output = Void>;
}
182
/// A trait for receiving protocol messages over a stream.
///
/// Unlike [`BincodeSend::push`], [`BincodeReceive::pull`] has a concrete result:
/// it always yields a [`ProtocolMessageGuard`], so incoming data is always
/// decoded as a [`ProtocolMessage`]. Implementations are expected to handle
/// decryption and protocol framing of the incoming data.
pub trait BincodeReceive: Unpin + Sized {
    /// Pulls a message from the stream.
    ///
    /// Since we are reading here, this returns the concrete guard type, though it
    /// could be generalized to a guard over any `type Result: BorrowDecode`.
    fn pull(&mut self) -> impl Future<Output = Res<ProtocolMessageGuard>>;
}
196
197// Tests.
198
#[cfg(test)]
mod tests {
    use crate::utils::tests::{generate_test_duplex, generate_test_fake_exchange_public_key};

    use super::*;
    use pretty_assertions::assert_eq;

    /// Round-trips a `ClientPreamble` through an in-memory duplex and checks that
    /// the receiving side decodes exactly the value that was sent.
    #[tokio::test]
    async fn test_bincode() {
        let (mut sender, mut receiver) = generate_test_duplex();

        // Fixtures are bound to locals so the borrowed preamble can point at them.
        let exchange_key = generate_test_fake_exchange_public_key();
        let challenge = Challenge::default();
        let sent = ClientPreamble {
            exchange_public_key: &exchange_key,
            remote: "remote",
            challenge: &challenge,
            should_encrypt: true,
            is_udp: false,
        };

        sender
            .push(ProtocolMessage::ClientPreamble(sent.clone()))
            .await
            .unwrap();

        let guard = receiver.pull().await.unwrap();
        match guard.message() {
            ProtocolMessage::ClientPreamble(received) => assert_eq!(sent, *received),
            _ => panic!("Failed to receive message"),
        }
    }
}