ratrodlib/
protocol.rs

//! Protocol message types and serialization.
//!
//! This module contains the protocol message types, their serde serialization helpers, and the
//! traits used to send and receive those messages over a stream.
4
5use std::fmt::{Display, Formatter};
6
7use anyhow::Context;
8use futures::{Sink, SinkExt, Stream, StreamExt};
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11
12use crate::base::{Constant, Err, Res, Void};
13
14// Wire types.
15
16/// A helper type for a challenge.
17pub type Challenge = [u8; Constant::CHALLENGE_SIZE];
18
19/// A helper type for a signature.
20pub type Signature = [u8; Constant::SIGNATURE_SIZE];
21
22/// A helper type for an ephemeral public key.
23pub type ExchangePublicKey = [u8; Constant::PEER_PUBLIC_KEY_SIZE];
24
25/// Serves as the preamble for the connection.
26#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
27pub struct ClientPreamble {
28    pub exchange_public_key: ExchangePublicKey,
29    pub remote: String,
30    pub challenge: Challenge,
31    pub should_encrypt: bool,
32    pub is_udp: bool,
33}
34
35/// Serves as the server's response to the preamble, containing its
36/// public key, its signature of the client's challenge and a challenge.
37#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
38pub struct ServerPreamble {
39    pub identity_public_key: String,
40    pub exchange_public_key: ExchangePublicKey,
41    pub signature: SerializeableSignature,
42    pub challenge: Challenge,
43}
44
45/// Serves as the client's response to the server's challenge.
46#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
47pub struct ClientAuthentication {
48    pub identity_public_key: String,
49    pub signature: SerializeableSignature,
50}
51
52// Message types.
53
54/// A helper trait for protocol messages.
55pub trait BincodeMessage: Serialize + DeserializeOwned {}
56
57/// A helper type for protocol messages.
58///
59/// This is the main message type for the protocol. It is used to send and receive messages over the network.
60/// It is also used to serialize and deserialize messages.
61#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
62pub enum ProtocolMessage {
63    ClientPreamble(ClientPreamble),
64    ServerPreamble(ServerPreamble),
65    ClientAuthentication(ClientAuthentication),
66    HandshakeCompletion,
67    Data(Vec<u8>),
68    UdpData(Vec<u8>),
69    Error(ProtocolError),
70}
71
72impl ProtocolMessage {
73    /// Checks if the message is an error.
74    ///
75    /// If it is, returns the message wrapped in an error.
76    pub fn fail_if_error(self) -> Res<Self> {
77        if let ProtocolMessage::Error(error) = self {
78            return Err(Err::msg(error));
79        }
80
81        Ok(self)
82    }
83}
84
85impl BincodeMessage for ProtocolMessage {}
86
87/// A wrapper type for protocol messages.
88#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
89pub enum ProtocolMessageWrapper {
90    Plain(ProtocolMessage),
91    Encrypted { nonce: [u8; Constant::SHARED_SECRET_NONCE_SIZE], data: Vec<u8> },
92}
93
94// Message error types.
95
96/// A helper type for protocol errors.
97///
98/// This is used to send and receive errors over the network.
99/// It is also used to serialize and deserialize errors.
100///
101/// It should not be sent / received over the network, as it
102/// should be sent as a [`ProtocolMessage::Error`] message.
103/// The type system should prevent this from happening.
104#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
105pub enum ProtocolError {
106    InvalidHost(String),
107    InvalidKey(String),
108    RemoteFailed(String),
109    Unknown(String),
110}
111
112impl Display for ProtocolError {
113    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
114        match self {
115            ProtocolError::InvalidHost(host) => write!(f, "Invalid host: {}", host),
116            ProtocolError::InvalidKey(key) => write!(f, "Invalid key: {}", key),
117            ProtocolError::RemoteFailed(message) => write!(f, "Remote failed: {}", message),
118            ProtocolError::Unknown(message) => write!(f, "Unknown: {}", message),
119        }
120    }
121}
122
123impl ProtocolError {
124    /// Sends the error message and shuts down the stream.
125    pub async fn send_and_bail<T, R>(self, stream: &mut T) -> Res<R>
126    where
127        T: BincodeSend,
128    {
129        let error_message = self.to_string();
130
131        let _ = stream.push(ProtocolMessage::Error(self)).await;
132        let _ = stream.shutdown().await;
133
134        Err(Err::msg(error_message))
135    }
136}
137
138// Bincode stream impls.
139
140/// A trait for sending protocol messages over a stream.
141///
142/// This impl is designed to ensure that the push method can only be used to send
143/// [`ProtocolMessage`] messages.
144pub trait BincodeSend: Sink<ProtocolMessage> + AsyncWrite + AsyncWriteExt + Unpin + Sized {
145    fn push(&mut self, message: ProtocolMessage) -> impl Future<Output = Void> {
146        async move { self.send(message).await.map_err(|_| Err::msg("Failed to send message")) }
147    }
148}
149
150/// A trait for receiving protocol messages over a stream.
151///
152/// This impl is designed to ensure that the pull method can only be used to receive
153/// [`ProtocolMessage`] messages.
154pub trait BincodeReceive: Stream<Item = std::io::Result<ProtocolMessage>> + AsyncRead + AsyncReadExt + Unpin + Sized {
155    fn pull(&mut self) -> impl Future<Output = Res<ProtocolMessage>> {
156        async move {
157            let message = self.next().await.context("Failed to read message")?.context("Failed to parse message")?;
158            Ok(message)
159        }
160    }
161}
162
// Blanket impls: every `Sink<ProtocolMessage> + AsyncWrite` type gets `BincodeSend`, and every
// `Stream<Item = std::io::Result<ProtocolMessage>> + AsyncRead` type gets `BincodeReceive`.
164
165impl<T> BincodeSend for T where Self: Sink<ProtocolMessage> + AsyncWrite + Unpin + Sized {}
166impl<T> BincodeReceive for T where Self: Stream<Item = std::io::Result<ProtocolMessage>> + AsyncRead + Unpin + Sized {}
167
168// Signature serialization.
169
170/// A helper type for serializing signatures (bincode cannot serialize a `[u8; 64]` our of the box).
171#[derive(Clone, Debug, PartialEq, Eq)]
172pub struct SerializeableSignature(pub Signature);
173
174impl From<Signature> for SerializeableSignature {
175    fn from(signature: Signature) -> Self {
176        Self(signature)
177    }
178}
179
180impl From<&Signature> for SerializeableSignature {
181    fn from(signature: &Signature) -> Self {
182        Self(*signature)
183    }
184}
185
186impl From<SerializeableSignature> for Signature {
187    fn from(signature: SerializeableSignature) -> Self {
188        signature.0
189    }
190}
191
192impl Serialize for SerializeableSignature {
193    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
194    where
195        S: serde::Serializer,
196    {
197        serializer.serialize_bytes(&self.0)
198    }
199}
200
201impl<'de> Deserialize<'de> for SerializeableSignature {
202    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203    where
204        D: serde::Deserializer<'de>,
205    {
206        let bytes = <&[u8]>::deserialize(deserializer)?;
207
208        if bytes.len() != Constant::SIGNATURE_SIZE {
209            return Err(serde::de::Error::custom(format!("Invalid signature length: {}", bytes.len())));
210        }
211
212        let mut signature = [0; Constant::SIGNATURE_SIZE];
213        signature.copy_from_slice(bytes);
214
215        Ok(SerializeableSignature(signature))
216    }
217}
218
219// Tests.
220
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::utils::tests::{generate_test_duplex, generate_test_fake_exchange_public_key};

    /// Pushes a `ClientPreamble` from one end of the test duplex pair and checks that the other
    /// end pulls back an identical value.
    #[tokio::test]
    async fn test_bincode() {
        let (mut sender, mut receiver) = generate_test_duplex();

        let expected = ClientPreamble {
            exchange_public_key: generate_test_fake_exchange_public_key(),
            remote: "remote".to_string(),
            challenge: Challenge::default(),
            should_encrypt: true,
            is_udp: false,
        };

        sender.push(ProtocolMessage::ClientPreamble(expected.clone())).await.unwrap();

        let received = match receiver.pull().await.unwrap() {
            ProtocolMessage::ClientPreamble(preamble) => preamble,
            _ => panic!("Failed to receive message"),
        };

        assert_eq!(expected, received);
    }
}