ratrodlib/
protocol.rs

1//! Protocol message types and serialization.
2//!
3//! This module contains the types and serialization methods for the protocol messages.
4
5use std::fmt::{Display, Formatter};
6
7use anyhow::Context;
8use futures::{Sink, SinkExt, Stream, StreamExt};
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11
12use crate::base::{Constant, Err, Res, Void};
13
14// Wire types.
15
/// A fixed-size random challenge that the receiving peer must sign to prove its identity
/// (see [`ClientPreamble::challenge`] and [`ServerPreamble::challenge`]).
pub type Challenge = [u8; Constant::CHALLENGE_SIZE];

/// A raw, fixed-size signature of a peer's [`Challenge`].
///
/// Wrapped in [`SerializeableSignature`] whenever it has to be serialized.
pub type Signature = [u8; Constant::SIGNATURE_SIZE];

/// An ephemeral public key sent during the handshake for the key exchange.
pub type ExchangePublicKey = [u8; Constant::PEER_PUBLIC_KEY_SIZE];
24
/// The first handshake message, sent by the client to open a connection.
///
/// NOTE(review): bincode encodes struct fields in declaration order, so the field
/// order here is presumably part of the wire format — do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ClientPreamble {
    /// The client's ephemeral public key for the key exchange.
    pub exchange_public_key: ExchangePublicKey,
    /// The remote the client asks the server to reach
    /// (presumably a `host:port` target — confirm against the server side).
    pub remote: String,
    /// A random challenge the server must sign to prove its identity.
    pub challenge: Challenge,
    /// Whether the client requests encrypted traffic for this session
    /// (TODO confirm how the server honors this flag).
    pub should_encrypt: bool,
    /// Whether the tunneled traffic is UDP
    /// (presumably paired with [`ProtocolMessage::UdpData`] — confirm).
    pub is_udp: bool,
}
34
/// The server's reply to a [`ClientPreamble`].
///
/// It proves the server's identity by signing the client's challenge, supplies the
/// server's half of the key exchange, and issues a challenge for the client to sign.
///
/// NOTE(review): bincode encodes struct fields in declaration order, so the field
/// order here is presumably part of the wire format — do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ServerPreamble {
    /// The server's identity public key (base64-encoded Ed25519 key).
    pub identity_public_key: String,
    /// The server's ephemeral public key for the key exchange.
    pub exchange_public_key: ExchangePublicKey,
    /// The server's signature of the client's [`ClientPreamble::challenge`].
    pub signature: SerializeableSignature,
    /// A random challenge for the client to sign in [`ClientAuthentication`].
    pub challenge: Challenge,
}
49
/// The client's reply to a [`ServerPreamble`], proving the client's identity.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ClientAuthentication {
    /// The client's identity public key (presumably base64-encoded Ed25519,
    /// mirroring [`ServerPreamble::identity_public_key`] — confirm).
    pub identity_public_key: String,
    /// The client's signature of the server's [`ServerPreamble::challenge`].
    pub signature: SerializeableSignature,
}
56
57// Message types.
58
/// Marker trait for types that can be serialized and deserialized as a protocol message.
///
/// It only bundles `Serialize + DeserializeOwned`; it adds no methods of its own.
pub trait BincodeMessage: Serialize + DeserializeOwned {}
61
/// A single protocol message sent or received over the network.
///
/// This is the only item type accepted by [`BincodeSend::push`] and produced by
/// [`BincodeReceive::pull`].
///
/// NOTE(review): bincode encodes enum variants by index, so variant order is
/// presumably part of the wire format — append new variants, do not reorder.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProtocolMessage {
    /// Handshake step 1 (client → server): opens the connection.
    ClientPreamble(ClientPreamble),
    /// Handshake step 2 (server → client): answers the client's challenge.
    ServerPreamble(ServerPreamble),
    /// Handshake step 3 (client → server): answers the server's challenge.
    ClientAuthentication(ClientAuthentication),
    /// Marks the end of the handshake (presumably sent by the server — confirm).
    HandshakeCompletion,
    /// Tunneled payload bytes (presumably for stream/TCP sessions — confirm).
    Data(Vec<u8>),
    /// Tunneled UDP payload (presumably one datagram per message — confirm).
    UdpData(Vec<u8>),
    /// An error reported by the peer; see [`ProtocolMessage::fail_if_error`].
    Error(ProtocolError),
}
76
77impl ProtocolMessage {
78    /// Checks if the message is an error.
79    ///
80    /// If it is, returns the message wrapped in an error.
81    pub fn fail_if_error(self) -> Res<Self> {
82        if let ProtocolMessage::Error(error) = self {
83            return Err(Err::msg(error));
84        }
85
86        Ok(self)
87    }
88}
89
// `ProtocolMessage` is the message type carried by `BincodeSend` / `BincodeReceive` streams.
impl BincodeMessage for ProtocolMessage {}
91
/// A wrapper that carries a [`ProtocolMessage`] either in the clear or encrypted.
///
/// NOTE(review): presumably `Plain` is used when encryption is off and `Encrypted`
/// once a shared secret is in place — confirm against the codec that builds these.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProtocolMessageWrapper {
    /// An unencrypted message.
    Plain(ProtocolMessage),
    /// An encrypted message.
    Encrypted {
        /// Nonce (`Constant::SHARED_SECRET_NONCE_SIZE` bytes) needed to decrypt `data`.
        nonce: [u8; Constant::SHARED_SECRET_NONCE_SIZE],
        /// Ciphertext (presumably of a serialized `ProtocolMessage` — confirm).
        data: Vec<u8>,
    },
}
98
99// Message error types.
100
/// An error that one peer reports to the other as part of the protocol.
///
/// A `ProtocolError` is serialized only as the payload of a
/// [`ProtocolMessage::Error`] message, never on its own. [`BincodeSend::push`]
/// accepts only [`ProtocolMessage`] values, so the type system enforces this.
///
/// NOTE(review): bincode encodes enum variants by index, so variant order is
/// presumably part of the wire format — append new variants, do not reorder.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProtocolError {
    /// The requested host was rejected; displayed as `Invalid host: …`
    /// (presumably unparseable or unresolvable — confirm).
    InvalidHost(String),
    /// A key or signature failed validation; displayed as `Invalid key: …`
    /// (TODO confirm which checks produce this).
    InvalidKey(String),
    /// The remote side of the tunnel failed; displayed as `Remote failed: …`
    /// (TODO confirm exact trigger).
    RemoteFailed(String),
    /// Any other failure; displayed as `Unknown: …`.
    Unknown(String),
}
116
117impl Display for ProtocolError {
118    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
119        match self {
120            ProtocolError::InvalidHost(host) => write!(f, "Invalid host: {}", host),
121            ProtocolError::InvalidKey(key) => write!(f, "Invalid key: {}", key),
122            ProtocolError::RemoteFailed(message) => write!(f, "Remote failed: {}", message),
123            ProtocolError::Unknown(message) => write!(f, "Unknown: {}", message),
124        }
125    }
126}
127
128impl ProtocolError {
129    /// Sends the error message and shuts down the stream.
130    ///
131    /// The generic parameter R represents the return type expected by the calling function.
132    /// This method always returns an error, but needs to have the expected return type for the context.
133    pub async fn send_and_bail<T, R>(self, stream: &mut T) -> Res<R>
134    where
135        T: BincodeSend,
136    {
137        let error_message = self.to_string();
138
139        let _ = stream.push(ProtocolMessage::Error(self)).await;
140        let _ = stream.shutdown().await;
141
142        Err(Err::msg(error_message))
143    }
144}
145
146// Bincode stream impls.
147
148/// A trait for sending protocol messages over a stream.
149///
150/// This impl is designed to ensure that the push method can only be used to send
151/// [`ProtocolMessage`] messages. This restriction is important for type safety
152/// and to ensure that all messages sent through the stream follow the protocol
153/// format and are properly encrypted if necessary.
154pub trait BincodeSend: Sink<ProtocolMessage> + AsyncWrite + AsyncWriteExt + Unpin + Sized {
155    fn push(&mut self, message: ProtocolMessage) -> impl Future<Output = Void> {
156        async move { self.send(message).await.map_err(|_| Err::msg("Failed to send message")) }
157    }
158}
159
160/// A trait for receiving protocol messages over a stream.
161///
162/// This impl is designed to ensure that the pull method can only be used to receive
163/// [`ProtocolMessage`] messages. This restriction provides type safety and ensures
164/// proper message decryption and protocol handling for incoming data.
165pub trait BincodeReceive: Stream<Item = std::io::Result<ProtocolMessage>> + AsyncRead + AsyncReadExt + Unpin + Sized {
166    fn pull(&mut self) -> impl Future<Output = Res<ProtocolMessage>> {
167        async move {
168            let message = self.next().await.context("Failed to read message")?.context("Failed to parse message")?;
169            Ok(message)
170        }
171    }
172}
173
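// Blanket impls: any `Unpin` `Sink<ProtocolMessage>` that is also `AsyncWrite` gets `BincodeSend`,
// and any `Unpin` `Stream<Item = io::Result<ProtocolMessage>>` that is also `AsyncRead` gets `BincodeReceive`.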
175
// `AsyncWriteExt` is not listed in the bounds: tokio implements it for every
// `AsyncWrite` type, so `BincodeSend`'s supertrait requirement is met automatically.
impl<T> BincodeSend for T where Self: Sink<ProtocolMessage> + AsyncWrite + Unpin + Sized {}
// Likewise, tokio implements `AsyncReadExt` for every `AsyncRead` type.
impl<T> BincodeReceive for T where Self: Stream<Item = std::io::Result<ProtocolMessage>> + AsyncRead + Unpin + Sized {}
178
179// Signature serialization.
180
/// A serializable wrapper around a raw [`Signature`].
///
/// serde only provides `Serialize`/`Deserialize` for arrays of up to 32 elements, so a
/// `[u8; 64]` signature cannot be derived directly. This wrapper serializes the signature
/// as a byte string instead. Convert to and from [`Signature`] with `From`/`Into`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SerializeableSignature(pub Signature);
184
185impl From<Signature> for SerializeableSignature {
186    fn from(signature: Signature) -> Self {
187        Self(signature)
188    }
189}
190
191impl From<&Signature> for SerializeableSignature {
192    fn from(signature: &Signature) -> Self {
193        Self(*signature)
194    }
195}
196
197impl From<SerializeableSignature> for Signature {
198    fn from(signature: SerializeableSignature) -> Self {
199        signature.0
200    }
201}
202
203impl Serialize for SerializeableSignature {
204    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
205    where
206        S: serde::Serializer,
207    {
208        serializer.serialize_bytes(&self.0)
209    }
210}
211
212impl<'de> Deserialize<'de> for SerializeableSignature {
213    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
214    where
215        D: serde::Deserializer<'de>,
216    {
217        let bytes = <&[u8]>::deserialize(deserializer)?;
218
219        if bytes.len() != Constant::SIGNATURE_SIZE {
220            return Err(serde::de::Error::custom(format!("Invalid signature length: {}", bytes.len())));
221        }
222
223        let mut signature = [0; Constant::SIGNATURE_SIZE];
224        signature.copy_from_slice(bytes);
225
226        Ok(SerializeableSignature(signature))
227    }
228}
229
230// Tests.
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::tests::{generate_test_duplex, generate_test_fake_exchange_public_key};
    use pretty_assertions::assert_eq;

    /// Pushes a `ClientPreamble` from the client end of a test duplex and checks
    /// that the server end pulls back an equal message.
    #[tokio::test]
    async fn test_bincode() {
        let (mut client, mut server) = generate_test_duplex();

        let sent = ClientPreamble {
            exchange_public_key: generate_test_fake_exchange_public_key(),
            remote: "remote".to_string(),
            challenge: Challenge::default(),
            should_encrypt: true,
            is_udp: false,
        };
        let outgoing = ProtocolMessage::ClientPreamble(sent.clone());
        client.push(outgoing).await.unwrap();

        let received = match server.pull().await.unwrap() {
            ProtocolMessage::ClientPreamble(preamble) => preamble,
            _ => panic!("Failed to receive message"),
        };

        assert_eq!(sent, received);
    }
}