use std::fmt::{Display, Formatter};
use anyhow::anyhow;
use bincode::{BorrowDecode, Encode};
use bytes::Bytes;
use ouroboros::self_referencing;
use serde::{Deserialize, Serialize};
use crate::base::{Constant, Res, Void};
/// Fixed-size challenge buffer of `Constant::CHALLENGE_SIZE` bytes.
pub type Challenge = [u8; Constant::CHALLENGE_SIZE];
/// Fixed-size signature buffer of `Constant::SIGNATURE_SIZE` bytes.
pub type Signature = [u8; Constant::SIGNATURE_SIZE];
/// Fixed-size key-exchange public key buffer of `Constant::EXCHANGE_PUBLIC_KEY_SIZE` bytes.
pub type ExchangePublicKey = [u8; Constant::EXCHANGE_PUBLIC_KEY_SIZE];
/// Client-side handshake opening message, carried by [`ProtocolMessage::ClientPreamble`].
///
/// All byte/string fields are borrowed (`BorrowDecode`), so a decoded value points into the
/// buffer it was decoded from.
/// NOTE(review): the derived bincode encoding presumably follows field order — reordering
/// fields likely changes the wire format; confirm before editing.
#[derive(Clone, Debug, PartialEq, Eq, Encode, BorrowDecode)]
pub struct ClientPreamble<'a> {
    /// Client's key-exchange public key bytes (presumably `ExchangePublicKey`-sized — confirm).
    pub exchange_public_key: &'a [u8],
    /// Remote target the client asks for; exact format (e.g. `host:port`) not shown here.
    pub remote: &'a str,
    /// Challenge bytes (the test builds this from a `Challenge`).
    pub challenge: &'a [u8],
    /// Whether the client requests encryption of the session.
    pub should_encrypt: bool,
    /// Whether the session carries UDP traffic.
    pub is_udp: bool,
}
/// Server-side handshake message, carried by [`ProtocolMessage::ServerPreamble`].
///
/// Fields are borrowed from the decode buffer (`BorrowDecode`).
/// NOTE(review): derived encoding is presumably field-order dependent; do not reorder fields
/// without confirming wire compatibility.
#[derive(Clone, Debug, PartialEq, Eq, Encode, BorrowDecode)]
pub struct ServerPreamble<'a> {
    /// Server's identity public key in textual form (encoding not shown here — confirm).
    pub identity_public_key: &'a str,
    /// Server's key-exchange public key bytes.
    pub exchange_public_key: &'a [u8],
    /// Signature bytes; what exactly is signed is defined elsewhere — TODO confirm.
    pub signature: &'a [u8],
    /// Challenge bytes sent by the server.
    pub challenge: &'a [u8],
}
/// Client authentication message, carried by [`ProtocolMessage::ClientAuthentication`].
///
/// NOTE(review): derived encoding is presumably field-order dependent; keep field order stable.
#[derive(Clone, Debug, PartialEq, Eq, Encode, BorrowDecode)]
pub struct ClientAuthentication<'a> {
    /// Client's identity public key in textual form (encoding not shown here — confirm).
    pub identity_public_key: &'a str,
    /// Signature bytes; the signed payload is defined elsewhere — TODO confirm.
    pub signature: &'a [u8],
}
/// Every frame exchanged over a [`BincodeSend`] / [`BincodeReceive`] transport.
///
/// Payload-carrying variants borrow from the decode buffer; see [`ProtocolMessageGuard`] for
/// the owner that keeps that buffer alive.
/// NOTE(review): bincode's derived enum encoding presumably uses variant order for the tag —
/// append new variants at the end rather than inserting; confirm before changing.
#[derive(Debug, PartialEq, Eq, Encode, BorrowDecode)]
pub enum ProtocolMessage<'a> {
    /// Client's opening handshake message.
    ClientPreamble(ClientPreamble<'a>),
    /// Server's handshake reply.
    ServerPreamble(ServerPreamble<'a>),
    /// Client's authentication message.
    ClientAuthentication(ClientAuthentication<'a>),
    /// Marks the handshake as complete; carries no payload.
    HandshakeCompletion,
    /// Raw payload bytes.
    Data(&'a [u8]),
    /// Raw payload bytes for UDP sessions.
    UdpData(&'a [u8]),
    /// Error reported by the peer; see [`ProtocolMessage::fail_if_error`].
    Error(ProtocolError<'a>),
    /// Shutdown notification; carries no payload.
    Shutdown,
}
impl ProtocolMessage<'_> {
pub fn fail_if_error(&self) -> Res<&Self> {
if let ProtocolMessage::Error(error) = self {
return Err(anyhow!(error.to_string()));
}
Ok(self)
}
}
/// Error reported to the peer inside [`ProtocolMessage::Error`].
///
/// Each variant carries a borrowed detail string; `Display` renders it as `"<label>: <detail>"`.
/// NOTE(review): serde and bincode layouts are derived, so variant order is presumably part of
/// the encoded form — keep it stable; confirm before reordering.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Encode, BorrowDecode)]
pub enum ProtocolError<'a> {
    /// The requested host was rejected; carries the host text.
    InvalidHost(&'a str),
    /// A key was rejected; carries the key text.
    InvalidKey(&'a str),
    /// The remote side failed; carries a message.
    RemoteFailed(&'a str),
    /// Any other failure; carries a message.
    Unknown(&'a str),
}
impl Display for ProtocolError<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
ProtocolError::InvalidHost(host) => write!(f, "Invalid host: {}", host),
ProtocolError::InvalidKey(key) => write!(f, "Invalid key: {}", key),
ProtocolError::RemoteFailed(message) => write!(f, "Remote failed: {}", message),
ProtocolError::Unknown(message) => write!(f, "Unknown: {}", message),
}
}
}
impl ProtocolError<'_> {
pub async fn send_and_bail<T, R>(self, stream: &mut T) -> Res<R>
where
T: BincodeSend,
{
let error_message = self.to_string();
let _ = stream.push(ProtocolMessage::Error(self)).await;
let _ = stream.close().await;
Err(anyhow!(error_message))
}
}
// Owns a frame's bytes together with the `ProtocolMessage` that borrows from them, so the
// message can outlive the function that received it (returned by `BincodeReceive::pull`).
// `#[self_referencing]` (ouroboros) generates the constructor and `borrow_*` accessors;
// `pub_extras` makes the generated items public.
#[self_referencing(pub_extras)]
pub struct ProtocolMessageGuard {
    // Backing bytes; presumably the encoded frame `inner` was decoded from — confirm at the
    // construction site.
    pub buffer: Bytes,
    // Message whose borrowed fields point into `buffer`. `#[covariant]` is what lets
    // ouroboros generate a plain `borrow_inner` accessor.
    #[borrows(buffer)]
    #[covariant]
    pub inner: ProtocolMessage<'this>,
}
impl ProtocolMessageGuard {
    /// Borrows the held message; it cannot outlive this guard, which owns the bytes it points
    /// into.
    pub fn message(&self) -> &ProtocolMessage<'_> {
        Self::borrow_inner(self)
    }
}
/// Sending half of a message transport.
///
/// The returned futures carry no `Send` bound, so generic callers cannot assume they are `Send`.
pub trait BincodeSend: Sized + Unpin {
    /// Sends `message`, which must implement bincode's `Encode`.
    fn push<T: Encode>(&mut self, message: T) -> impl Future<Output = Void>;

    /// Closes the sending side of the transport.
    fn close(&mut self) -> impl Future<Output = Void>;
}
/// Receiving half of a message transport.
pub trait BincodeReceive: Unpin + Sized {
    /// Resolves to the next message, wrapped in a [`ProtocolMessageGuard`] that owns the bytes
    /// the borrowed message points into.
    fn pull(&mut self) -> impl Future<Output = Res<ProtocolMessageGuard>>;
}
#[cfg(test)]
mod tests {
    use crate::utils::tests::{generate_test_duplex, generate_test_fake_exchange_public_key};
    use super::*;
    use pretty_assertions::assert_eq;

    // Round-trips a `ClientPreamble` through an in-memory duplex and checks that the decoded
    // message equals what was sent.
    #[tokio::test]
    async fn test_bincode() {
        let (mut client, mut server) = generate_test_duplex();
        let data = ClientPreamble {
            exchange_public_key: &generate_test_fake_exchange_public_key(),
            remote: "remote",
            challenge: &Challenge::default(),
            should_encrypt: true,
            is_udp: false,
        };
        client.push(ProtocolMessage::ClientPreamble(data.clone())).await.unwrap();
        let guard = server.pull().await.unwrap();
        let ProtocolMessage::ClientPreamble(message) = guard.message() else {
            panic!("Failed to receive message");
        };
        assert_eq!(data, *message);
    }

    // `fail_if_error` must pass non-error messages through untouched and turn an `Error`
    // frame into an error carrying the `ProtocolError`'s `Display` text.
    #[test]
    fn test_fail_if_error() {
        let completion = ProtocolMessage::HandshakeCompletion;
        let passed = completion.fail_if_error().unwrap();
        assert!(std::ptr::eq(passed, &completion));

        let failure = ProtocolMessage::Error(ProtocolError::InvalidHost("example.invalid"));
        let error = failure.fail_if_error().unwrap_err();
        assert_eq!(error.to_string(), "Invalid host: example.invalid");
    }
}