ratrodlib/protocol.rs
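//! Protocol message types and serialization.
//!
//! This module defines the handshake wire types, the top-level [`ProtocolMessage`] enum
//! together with its [`ProtocolError`] type, and the [`BincodeSend`] / [`BincodeReceive`]
//! traits used to push and pull bincode-encoded messages over a stream.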
4
5use std::fmt::{Display, Formatter};
6
7use anyhow::anyhow;
8use bincode::{BorrowDecode, Decode, Encode};
9use bytes::Bytes;
10use ouroboros::self_referencing;
11use serde::{Deserialize, Serialize};
12
13use crate::base::{Constant, Res, Void};
14
15// Wire types.
16
17/// A helper type for a challenge.
18pub type Challenge = [u8; Constant::CHALLENGE_SIZE];
19
20/// A helper type for a signature.
21pub type Signature = [u8; Constant::SIGNATURE_SIZE];
22
23/// A helper type for an ephemeral public key.
24pub type ExchangePublicKey = [u8; Constant::EXCHANGE_PUBLIC_KEY_SIZE];
25
26/// Serves as the preamble for the connection.
27#[derive(Clone, Debug, PartialEq, Eq, Encode, BorrowDecode)]
28pub struct ClientPreamble<'a> {
29 pub exchange_public_key: &'a [u8],
30 pub remote: &'a str,
31 pub challenge: &'a [u8],
32 pub should_encrypt: bool,
33 pub is_udp: bool,
34}
35
36/// Serves as the server's response to the preamble, containing its
37/// public key, its signature of the client's challenge and a challenge.
38/// The server signs the client's challenge to prove its identity.
39#[derive(Clone, Debug, PartialEq, Eq, Encode, BorrowDecode)]
40pub struct ServerPreamble<'a> {
41 /// The server's identity public key (base64 encoded Ed25519 key)
42 pub identity_public_key: &'a str,
43 /// The server's ephemeral public key for the key exchange
44 pub exchange_public_key: &'a [u8],
45 /// The server's signature of the client's challenge
46 pub signature: &'a [u8],
47 /// A random challenge for the client to sign
48 pub challenge: &'a [u8],
49}
50
51/// Serves as the client's response to the server's challenge.
52#[derive(Clone, Debug, PartialEq, Eq, Encode, BorrowDecode)]
53pub struct ClientAuthentication<'a> {
54 pub identity_public_key: &'a str,
55 pub signature: &'a [u8],
56}
57
58// Message types.
59
60/// A helper type for protocol messages.
61///
62/// This is the main message type for the protocol. It is used to send and receive messages over the network.
63/// It is also used to serialize and deserialize messages.
64#[derive(Debug, PartialEq, Eq, Encode, BorrowDecode)]
65pub enum ProtocolMessage<'a> {
66 ClientPreamble(ClientPreamble<'a>),
67 ServerPreamble(ServerPreamble<'a>),
68 ClientAuthentication(ClientAuthentication<'a>),
69 HandshakeCompletion,
70 Data(&'a [u8]),
71 UdpData(&'a [u8]),
72 Error(ProtocolError),
73 Shutdown,
74}
75
76impl ProtocolMessage<'_> {
77 /// Checks if the message is an error.
78 ///
79 /// If it is, returns the message wrapped in an error.
80 pub fn fail_if_error(&self) -> Res<&Self> {
81 if let ProtocolMessage::Error(error) = self {
82 return Err(anyhow!(error.clone()));
83 }
84
85 Ok(self)
86 }
87}
88
89// Message error types.
90
91/// A helper type for protocol errors.
92///
93/// This is used to send and receive errors over the network.
94/// It is also used to serialize and deserialize errors.
95///
96/// It should not be sent / received over the network, as it
97/// should be sent as a [`ProtocolMessage::Error`] message.
98/// The type system should prevent this from happening.
99#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Encode, Decode)]
100pub enum ProtocolError {
101 InvalidHost(String),
102 InvalidKey(String),
103 RemoteFailed(String),
104 Unknown(String),
105}
106
107impl Display for ProtocolError {
108 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
109 match self {
110 ProtocolError::InvalidHost(host) => write!(f, "Invalid host: {}", host),
111 ProtocolError::InvalidKey(key) => write!(f, "Invalid key: {}", key),
112 ProtocolError::RemoteFailed(message) => write!(f, "Remote failed: {}", message),
113 ProtocolError::Unknown(message) => write!(f, "Unknown: {}", message),
114 }
115 }
116}
117
118impl ProtocolError {
119 /// Sends the error message and shuts down the stream.
120 ///
121 /// The generic parameter R represents the return type expected by the calling function.
122 /// This method always returns an error, but needs to have the expected return type for the context.
123 pub async fn send_and_bail<T, R>(self, stream: &mut T) -> Res<R>
124 where
125 T: BincodeSend,
126 {
127 let error_message = self.to_string();
128
129 let _ = stream.push(ProtocolMessage::Error(self)).await;
130 let _ = stream.close().await;
131
132 Err(anyhow!(error_message))
133 }
134}
135
136/// A helper type for protocol message guards.
137///
138/// Essentially, this is a wrapper around [`ProtocolMessage`] that allows
139/// for tying, self-referentially, the underlying buffer to the message.
140/// As a result, while the message is essentially "borrowed" from the buffer,
141/// the buffer is "owned" by the guard.
142///
143/// After a guard goes out of scope, the buffer is dropped, and, due to
144/// the way `BytesMut` works, it _may_ (read: "will when able") reclaim
145/// the memory used by this buffer, thereby reducing buffer allocations
146/// and data clones.
147#[self_referencing(pub_extras)]
148pub struct ProtocolMessageGuard {
149 pub buffer: Bytes,
150 #[borrows(buffer)]
151 #[covariant]
152 pub inner: ProtocolMessage<'this>,
153}
154
155impl ProtocolMessageGuard {
156 /// The inner message of this guard.
157 pub fn message(&self) -> &ProtocolMessage<'_> {
158 self.borrow_inner()
159 }
160}
161
162// Bincode stream impls.
163
164/// A trait for sending protocol messages over a stream.
165///
166/// This impl is designed to ensure that the push method can only be used to send
167/// [`ProtocolMessage`] messages. This restriction is important for type safety
168/// and to ensure that all messages sent through the stream follow the protocol
169/// format and are properly encrypted if necessary.
170pub trait BincodeSend: Unpin + Sized {
171 /// Pushes a message to the stream.
172 ///
173 /// Right now, this only requires `T: Encode`, but in the future, it may
174 /// require a concrete type, such as `ProtocolMessage`.
175 fn push<T>(&mut self, message: T) -> impl Future<Output = Void>
176 where
177 T: Encode;
178
179 /// Closes the stream via `shutdown`.
180 fn close(&mut self) -> impl Future<Output = Void>;
181}
182
183/// A trait for receiving protocol messages over a stream.
184///
185/// This impl is designed to ensure that the pull method can only be used to receive
186/// [`ProtocolMessage`] messages. This restriction provides type safety and ensures
187/// proper message decryption and protocol handling for incoming data.
188pub trait BincodeReceive: Unpin + Sized {
189 /// Pulls a message from the stream.
190 ///
191 /// Since are reading here, we just return the concrete type, though
192 /// it stands to reason that we could just constrain this with a Guard
193 /// of a `type Result: BorrowDecode`.
194 fn pull(&mut self) -> impl Future<Output = Res<ProtocolMessageGuard>>;
195}
196
197// Tests.
198
199#[cfg(test)]
200mod tests {
201 use crate::utils::tests::{generate_test_duplex, generate_test_fake_exchange_public_key};
202
203 use super::*;
204 use pretty_assertions::assert_eq;
205
206 #[tokio::test]
207 async fn test_bincode() {
208 let (mut client, mut server) = generate_test_duplex();
209
210 let data = ClientPreamble {
211 exchange_public_key: &generate_test_fake_exchange_public_key(),
212 remote: "remote",
213 challenge: &Challenge::default(),
214 should_encrypt: true,
215 is_udp: false,
216 };
217
218 client.push(ProtocolMessage::ClientPreamble(data.clone())).await.unwrap();
219
220 let guard = server.pull().await.unwrap();
221 let ProtocolMessage::ClientPreamble(message) = guard.message() else {
222 panic!("Failed to receive message");
223 };
224
225 assert_eq!(data, *message);
226 }
227}