1use std::fmt::{Display, Formatter};
6
7use anyhow::Context;
8use futures::{Sink, SinkExt, Stream, StreamExt};
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11
12use crate::base::{Constant, Err, Res, Void};
13
/// Fixed-size challenge buffer: `Constant::CHALLENGE_SIZE` raw bytes.
pub type Challenge = [u8; Constant::CHALLENGE_SIZE];
18
/// Fixed-size signature buffer: `Constant::SIGNATURE_SIZE` raw bytes.
pub type Signature = [u8; Constant::SIGNATURE_SIZE];
21
/// Fixed-size buffer of `Constant::PEER_PUBLIC_KEY_SIZE` bytes holding an exchange public key.
pub type ExchangePublicKey = [u8; Constant::PEER_PUBLIC_KEY_SIZE];
24
/// Payload of [`ProtocolMessage::ClientPreamble`].
///
/// NOTE(review): the field descriptions below are read off the field names and types only;
/// confirm them against the handshake implementation.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ClientPreamble {
    /// Exchange public key bytes carried in this preamble.
    pub exchange_public_key: ExchangePublicKey,
    /// Presumably the remote endpoint the client wants to reach — TODO confirm the expected format.
    pub remote: String,
    /// Challenge bytes carried in this preamble.
    pub challenge: Challenge,
    /// Presumably requests encryption of the traffic that follows — TODO confirm.
    pub should_encrypt: bool,
    /// Presumably selects UDP semantics for the session — TODO confirm.
    pub is_udp: bool,
}
34
/// Payload of [`ProtocolMessage::ServerPreamble`].
///
/// NOTE(review): the field descriptions below are read off the field names and types only;
/// confirm them against the handshake implementation.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ServerPreamble {
    /// Identity public key in textual form (encoding not visible here — TODO confirm).
    pub identity_public_key: String,
    /// Exchange public key bytes carried in this preamble.
    pub exchange_public_key: ExchangePublicKey,
    /// Signature bytes, wrapped so they are serialized as a byte string.
    pub signature: SerializeableSignature,
    /// Challenge bytes carried in this preamble.
    pub challenge: Challenge,
}
49
/// Payload of [`ProtocolMessage::ClientAuthentication`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ClientAuthentication {
    /// Identity public key in textual form (encoding not visible here — TODO confirm).
    pub identity_public_key: String,
    /// Signature bytes, wrapped so they are serialized as a byte string.
    pub signature: SerializeableSignature,
}
56
/// Marker trait (no methods) for types that are both `Serialize` and `DeserializeOwned`.
///
/// NOTE(review): the name suggests these types are encoded with bincode; the codec itself is
/// not visible from this file.
pub trait BincodeMessage: Serialize + DeserializeOwned {}
61
/// Every message kind exchanged through [`BincodeSend::push`] and [`BincodeReceive::pull`].
///
/// NOTE(review): variant order is part of the serde-derived representation for index-based
/// formats; presumably it must not be reordered without a protocol change — confirm.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProtocolMessage {
    /// Carries a [`ClientPreamble`].
    ClientPreamble(ClientPreamble),
    /// Carries a [`ServerPreamble`].
    ServerPreamble(ServerPreamble),
    /// Carries a [`ClientAuthentication`].
    ClientAuthentication(ClientAuthentication),
    /// Payload-free message; by its name, marks the end of the handshake.
    HandshakeCompletion,
    /// Opaque payload bytes.
    Data(Vec<u8>),
    /// Opaque payload bytes, tagged separately from `Data` (presumably UDP traffic — confirm).
    UdpData(Vec<u8>),
    /// Failure report; `fail_if_error` turns this variant into an `Err`.
    Error(ProtocolError),
}
76
77impl ProtocolMessage {
78 pub fn fail_if_error(self) -> Res<Self> {
82 if let ProtocolMessage::Error(error) = self {
83 return Err(Err::msg(error));
84 }
85
86 Ok(self)
87 }
88}
89
// Opts `ProtocolMessage` into the `BincodeMessage` marker trait; there is nothing to implement.
impl BincodeMessage for ProtocolMessage {}
91
/// Envelope around a [`ProtocolMessage`]: either the message itself or an encrypted blob.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProtocolMessageWrapper {
    /// The message in the clear.
    Plain(ProtocolMessage),
    /// A nonce of `Constant::SHARED_SECRET_NONCE_SIZE` bytes plus opaque bytes.
    ///
    /// NOTE(review): `data` is presumably the encrypted serialized `ProtocolMessage`; the
    /// cipher and framing are not visible here — confirm.
    Encrypted { nonce: [u8; Constant::SHARED_SECRET_NONCE_SIZE], data: Vec<u8> },
}
98
/// Failure categories carried by [`ProtocolMessage::Error`].
///
/// Each variant holds a free-form detail string; the `Display` impl renders it after a
/// fixed prefix (for example `Invalid host: <detail>`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProtocolError {
    /// Rendered as `Invalid host: <detail>`.
    InvalidHost(String),
    /// Rendered as `Invalid key: <detail>`.
    InvalidKey(String),
    /// Rendered as `Remote failed: <detail>`.
    RemoteFailed(String),
    /// Rendered as `Unknown: <detail>`.
    Unknown(String),
}
116
117impl Display for ProtocolError {
118 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
119 match self {
120 ProtocolError::InvalidHost(host) => write!(f, "Invalid host: {}", host),
121 ProtocolError::InvalidKey(key) => write!(f, "Invalid key: {}", key),
122 ProtocolError::RemoteFailed(message) => write!(f, "Remote failed: {}", message),
123 ProtocolError::Unknown(message) => write!(f, "Unknown: {}", message),
124 }
125 }
126}
127
128impl ProtocolError {
129 pub async fn send_and_bail<T, R>(self, stream: &mut T) -> Res<R>
134 where
135 T: BincodeSend,
136 {
137 let error_message = self.to_string();
138
139 let _ = stream.push(ProtocolMessage::Error(self)).await;
140 let _ = stream.shutdown().await;
141
142 Err(Err::msg(error_message))
143 }
144}
145
146pub trait BincodeSend: Sink<ProtocolMessage> + AsyncWrite + AsyncWriteExt + Unpin + Sized {
155 fn push(&mut self, message: ProtocolMessage) -> impl Future<Output = Void> {
156 async move { self.send(message).await.map_err(|_| Err::msg("Failed to send message")) }
157 }
158}
159
160pub trait BincodeReceive: Stream<Item = std::io::Result<ProtocolMessage>> + AsyncRead + AsyncReadExt + Unpin + Sized {
166 fn pull(&mut self) -> impl Future<Output = Res<ProtocolMessage>> {
167 async move {
168 let message = self.next().await.context("Failed to read message")?.context("Failed to parse message")?;
169 Ok(message)
170 }
171 }
172}
173
174impl<T> BincodeSend for T where Self: Sink<ProtocolMessage> + AsyncWrite + Unpin + Sized {}
177impl<T> BincodeReceive for T where Self: Stream<Item = std::io::Result<ProtocolMessage>> + AsyncRead + Unpin + Sized {}
178
/// Newtype around [`Signature`] with hand-written serde impls.
///
/// The bytes are written with `serialize_bytes`, and deserialization rejects input whose
/// length differs from `Constant::SIGNATURE_SIZE`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SerializeableSignature(pub Signature);
184
185impl From<Signature> for SerializeableSignature {
186 fn from(signature: Signature) -> Self {
187 Self(signature)
188 }
189}
190
191impl From<&Signature> for SerializeableSignature {
192 fn from(signature: &Signature) -> Self {
193 Self(*signature)
194 }
195}
196
197impl From<SerializeableSignature> for Signature {
198 fn from(signature: SerializeableSignature) -> Self {
199 signature.0
200 }
201}
202
203impl Serialize for SerializeableSignature {
204 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
205 where
206 S: serde::Serializer,
207 {
208 serializer.serialize_bytes(&self.0)
209 }
210}
211
212impl<'de> Deserialize<'de> for SerializeableSignature {
213 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
214 where
215 D: serde::Deserializer<'de>,
216 {
217 let bytes = <&[u8]>::deserialize(deserializer)?;
218
219 if bytes.len() != Constant::SIGNATURE_SIZE {
220 return Err(serde::de::Error::custom(format!("Invalid signature length: {}", bytes.len())));
221 }
222
223 let mut signature = [0; Constant::SIGNATURE_SIZE];
224 signature.copy_from_slice(bytes);
225
226 Ok(SerializeableSignature(signature))
227 }
228}
229
#[cfg(test)]
mod tests {
    use crate::utils::tests::{generate_test_duplex, generate_test_fake_exchange_public_key};

    use super::*;
    use pretty_assertions::assert_eq;

    /// Pushes a `ClientPreamble` on one end of a test duplex and checks that the other end
    /// pulls an equal value.
    #[tokio::test]
    async fn test_bincode() {
        let (mut client, mut server) = generate_test_duplex();

        let sent = ClientPreamble {
            exchange_public_key: generate_test_fake_exchange_public_key(),
            remote: "remote".to_string(),
            challenge: Challenge::default(),
            should_encrypt: true,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(sent.clone())).await.unwrap();

        let received = match server.pull().await.unwrap() {
            ProtocolMessage::ClientPreamble(preamble) => preamble,
            _ => panic!("Failed to receive message"),
        };

        assert_eq!(sent, received);
    }
}
259}