1use std::fmt::{Display, Formatter};
6
7use anyhow::Context;
8use futures::{Sink, SinkExt, Stream, StreamExt};
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11
12use crate::base::{Constant, Err, Res, Void};
13
/// Fixed-size byte array of `Constant::CHALLENGE_SIZE` bytes.
///
/// Used as the `challenge` field of both `ClientPreamble` and `ServerPreamble`.
/// NOTE(review): presumably a nonce the peer must sign to prove key ownership —
/// confirm against the handshake implementation.
pub type Challenge = [u8; Constant::CHALLENGE_SIZE];

/// Fixed-size byte array of `Constant::SIGNATURE_SIZE` bytes holding a raw signature.
///
/// Wrapped by `SerializeableSignature` when it has to be serialized.
pub type Signature = [u8; Constant::SIGNATURE_SIZE];

/// Fixed-size byte array of `Constant::PEER_PUBLIC_KEY_SIZE` bytes.
///
/// Used as the `exchange_public_key` field of the client and server preambles.
pub type ExchangePublicKey = [u8; Constant::PEER_PUBLIC_KEY_SIZE];
24
/// Payload of `ProtocolMessage::ClientPreamble`.
///
/// NOTE(review): if this is encoded with a non-self-describing format such as
/// bincode (as the naming in this file suggests), the field order is part of
/// the wire layout — do not reorder fields without a protocol version bump.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ClientPreamble {
    /// The client's public key for the key exchange.
    pub exchange_public_key: ExchangePublicKey,
    /// Remote target as a string; presumably a `host:port` address the server
    /// should connect to — TODO confirm the expected format.
    pub remote: String,
    /// Challenge bytes supplied by the client; presumably signed by the server
    /// in its `ServerPreamble` — confirm against the handshake code.
    pub challenge: Challenge,
    /// Whether the client requests encryption for the session.
    pub should_encrypt: bool,
    /// Whether the session is UDP rather than the default transport.
    pub is_udp: bool,
}
34
/// Payload of `ProtocolMessage::ServerPreamble`.
///
/// NOTE(review): field order is presumably part of the wire layout (bincode) —
/// do not reorder fields.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ServerPreamble {
    /// The server's identity public key in string form (encoding not visible
    /// in this file — TODO confirm).
    pub identity_public_key: String,
    /// The server's public key for the key exchange.
    pub exchange_public_key: ExchangePublicKey,
    /// Signature produced by the server; presumably over the client's
    /// challenge — confirm against the handshake code.
    pub signature: SerializeableSignature,
    /// Challenge bytes supplied by the server; presumably signed by the client
    /// in its `ClientAuthentication` — confirm.
    pub challenge: Challenge,
}
44
/// Payload of `ProtocolMessage::ClientAuthentication`.
///
/// NOTE(review): field order is presumably part of the wire layout (bincode) —
/// do not reorder fields.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ClientAuthentication {
    /// The client's identity public key in string form (encoding not visible
    /// in this file — TODO confirm).
    pub identity_public_key: String,
    /// Signature produced by the client; presumably over the server's
    /// challenge — confirm against the handshake code.
    pub signature: SerializeableSignature,
}
51
/// Marker trait for message types that are both `Serialize` and
/// `DeserializeOwned`.
///
/// It declares no methods; in this file it is implemented for
/// `ProtocolMessage`.
pub trait BincodeMessage: Serialize + DeserializeOwned {}
56
/// Every message that can be exchanged between the two peers.
///
/// NOTE(review): variant order determines the serialized discriminant for
/// index-based formats such as bincode — append new variants at the end rather
/// than reordering.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProtocolMessage {
    /// Carries the client's `ClientPreamble`.
    ClientPreamble(ClientPreamble),
    /// Carries the server's `ServerPreamble`.
    ServerPreamble(ServerPreamble),
    /// Carries the client's `ClientAuthentication`.
    ClientAuthentication(ClientAuthentication),
    /// Payload-free marker; presumably signals that the handshake has
    /// finished — confirm against the handshake code.
    HandshakeCompletion,
    /// Opaque byte payload.
    Data(Vec<u8>),
    /// Opaque byte payload; presumably used when the session is UDP
    /// (see `ClientPreamble::is_udp`) — confirm.
    UdpData(Vec<u8>),
    /// An error reported by the peer; see `ProtocolMessage::fail_if_error`.
    Error(ProtocolError),
}
71
72impl ProtocolMessage {
73 pub fn fail_if_error(self) -> Res<Self> {
77 if let ProtocolMessage::Error(error) = self {
78 return Err(Err::msg(error));
79 }
80
81 Ok(self)
82 }
83}
84
/// Marks `ProtocolMessage` as a `BincodeMessage`; the trait has no methods to implement.
impl BincodeMessage for ProtocolMessage {}
86
/// Envelope around a `ProtocolMessage` that is either in the clear or encrypted.
///
/// NOTE(review): variant and field order are presumably part of the wire
/// layout (bincode) — do not reorder.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProtocolMessageWrapper {
    /// The message itself, unencrypted.
    Plain(ProtocolMessage),
    /// An encrypted payload: a nonce of `Constant::SHARED_SECRET_NONCE_SIZE`
    /// bytes plus opaque bytes. Presumably `data` is the ciphertext of a
    /// serialized `ProtocolMessage` — confirm against the encryption code.
    Encrypted { nonce: [u8; Constant::SHARED_SECRET_NONCE_SIZE], data: Vec<u8> },
}
93
/// Error that can be carried inside `ProtocolMessage::Error`.
///
/// Each variant carries one string that the `Display` implementation appends
/// to a variant-specific prefix.
///
/// NOTE(review): variant order is presumably part of the wire layout
/// (bincode) — append new variants at the end rather than reordering.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProtocolError {
    /// Displayed as `Invalid host: <value>`.
    InvalidHost(String),
    /// Displayed as `Invalid key: <value>`.
    InvalidKey(String),
    /// Displayed as `Remote failed: <value>`.
    RemoteFailed(String),
    /// Displayed as `Unknown: <value>`.
    Unknown(String),
}
111
112impl Display for ProtocolError {
113 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
114 match self {
115 ProtocolError::InvalidHost(host) => write!(f, "Invalid host: {}", host),
116 ProtocolError::InvalidKey(key) => write!(f, "Invalid key: {}", key),
117 ProtocolError::RemoteFailed(message) => write!(f, "Remote failed: {}", message),
118 ProtocolError::Unknown(message) => write!(f, "Unknown: {}", message),
119 }
120 }
121}
122
123impl ProtocolError {
124 pub async fn send_and_bail<T, R>(self, stream: &mut T) -> Res<R>
126 where
127 T: BincodeSend,
128 {
129 let error_message = self.to_string();
130
131 let _ = stream.push(ProtocolMessage::Error(self)).await;
132 let _ = stream.shutdown().await;
133
134 Err(Err::msg(error_message))
135 }
136}
137
138pub trait BincodeSend: Sink<ProtocolMessage> + AsyncWrite + AsyncWriteExt + Unpin + Sized {
145 fn push(&mut self, message: ProtocolMessage) -> impl Future<Output = Void> {
146 async move { self.send(message).await.map_err(|_| Err::msg("Failed to send message")) }
147 }
148}
149
150pub trait BincodeReceive: Stream<Item = std::io::Result<ProtocolMessage>> + AsyncRead + AsyncReadExt + Unpin + Sized {
155 fn pull(&mut self) -> impl Future<Output = Res<ProtocolMessage>> {
156 async move {
157 let message = self.next().await.context("Failed to read message")?.context("Failed to parse message")?;
158 Ok(message)
159 }
160 }
161}
162
163impl<T> BincodeSend for T where Self: Sink<ProtocolMessage> + AsyncWrite + Unpin + Sized {}
166impl<T> BincodeReceive for T where Self: Stream<Item = std::io::Result<ProtocolMessage>> + AsyncRead + Unpin + Sized {}
167
/// Newtype around a raw `Signature` with hand-written `Serialize` and
/// `Deserialize` implementations.
///
/// The manual implementations encode the signature as a byte string and
/// reject any decoded input whose length is not `Constant::SIGNATURE_SIZE`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SerializeableSignature(pub Signature);
173
174impl From<Signature> for SerializeableSignature {
175 fn from(signature: Signature) -> Self {
176 Self(signature)
177 }
178}
179
180impl From<&Signature> for SerializeableSignature {
181 fn from(signature: &Signature) -> Self {
182 Self(*signature)
183 }
184}
185
186impl From<SerializeableSignature> for Signature {
187 fn from(signature: SerializeableSignature) -> Self {
188 signature.0
189 }
190}
191
192impl Serialize for SerializeableSignature {
193 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
194 where
195 S: serde::Serializer,
196 {
197 serializer.serialize_bytes(&self.0)
198 }
199}
200
201impl<'de> Deserialize<'de> for SerializeableSignature {
202 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203 where
204 D: serde::Deserializer<'de>,
205 {
206 let bytes = <&[u8]>::deserialize(deserializer)?;
207
208 if bytes.len() != Constant::SIGNATURE_SIZE {
209 return Err(serde::de::Error::custom(format!("Invalid signature length: {}", bytes.len())));
210 }
211
212 let mut signature = [0; Constant::SIGNATURE_SIZE];
213 signature.copy_from_slice(bytes);
214
215 Ok(SerializeableSignature(signature))
216 }
217}
218
#[cfg(test)]
mod tests {
    use crate::utils::tests::{generate_test_duplex, generate_test_fake_exchange_public_key};

    use super::*;
    use pretty_assertions::assert_eq;

    /// Pushes a `ClientPreamble` into one end of the test duplex and checks
    /// that the other end pulls back an equal value.
    #[tokio::test]
    async fn test_bincode() {
        let (mut client, mut server) = generate_test_duplex();

        let sent = ClientPreamble {
            exchange_public_key: generate_test_fake_exchange_public_key(),
            remote: "remote".to_string(),
            challenge: Challenge::default(),
            should_encrypt: true,
            is_udp: false,
        };

        let outgoing = ProtocolMessage::ClientPreamble(sent.clone());
        client.push(outgoing).await.unwrap();

        let received = match server.pull().await.unwrap() {
            ProtocolMessage::ClientPreamble(preamble) => preamble,
            _ => panic!("Failed to receive message"),
        };

        assert_eq!(sent, received);
    }
}