use alloc::vec::Vec;
use core::ops::{Deref, DerefMut};
use async_tungstenite::WebSocketStream;
use future_form::{FutureForm, Local, Sendable, future_form};
use futures_util::{AsyncRead, AsyncWrite, SinkExt, StreamExt};
use subduction_core::connection::handshake::Handshake;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum WebSocketHandshakeError {
#[error("WebSocket error: {0}")]
WebSocket(#[from] tungstenite::Error),
#[error("expected binary message, got: {0}")]
UnexpectedMessageType(&'static str),
#[error("connection closed during handshake")]
ConnectionClosed,
}
/// Newtype around a [`WebSocketStream`] that implements the handshake
/// transport on top of it.
///
/// The wrapped stream is exposed as the public field `.0`.
#[derive(Debug)]
pub struct WebSocketHandshake<T>(pub WebSocketStream<T>);

impl<T> WebSocketHandshake<T> {
    /// Wraps `stream` for use as a handshake transport.
    pub const fn new(stream: WebSocketStream<T>) -> Self {
        WebSocketHandshake(stream)
    }

    /// Consumes the wrapper, handing back the underlying WebSocket stream.
    pub fn into_inner(self) -> WebSocketStream<T> {
        let WebSocketHandshake(stream) = self;
        stream
    }
}
impl<T> Deref for WebSocketHandshake<T> {
    type Target = WebSocketStream<T>;

    /// Borrows the wrapped WebSocket stream.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
impl<T> DerefMut for WebSocketHandshake<T> {
    /// Mutably borrows the wrapped WebSocket stream.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(inner) = self;
        inner
    }
}
#[future_form(
    Sendable where T: AsyncRead + AsyncWrite + Unpin + Send,
    Local where T: AsyncRead + AsyncWrite + Unpin
)]
impl<K: FutureForm, T> Handshake<K> for WebSocketHandshake<T> {
    type Error = WebSocketHandshakeError;

    /// Sends `bytes` to the peer as a single binary WebSocket message.
    ///
    /// # Errors
    ///
    /// Returns [`WebSocketHandshakeError::WebSocket`] if the underlying
    /// stream fails to send the message.
    fn send(&mut self, bytes: Vec<u8>) -> K::Future<'_, Result<(), Self::Error>> {
        K::from_future(async move {
            SinkExt::send(&mut self.0, tungstenite::Message::Binary(bytes.into())).await?;
            Ok(())
        })
    }

    /// Waits for the next binary message from the peer and returns its payload.
    ///
    /// Ping and pong messages are skipped; the loop keeps reading until a
    /// binary message or a terminal condition arrives.
    ///
    /// # Errors
    ///
    /// - [`WebSocketHandshakeError::ConnectionClosed`] if the stream ends or a
    ///   close message is received.
    /// - [`WebSocketHandshakeError::UnexpectedMessageType`] if a text or raw
    ///   frame message is received.
    /// - [`WebSocketHandshakeError::WebSocket`] if reading from the stream fails.
    fn recv(&mut self) -> K::Future<'_, Result<Vec<u8>, Self::Error>> {
        K::from_future(async move {
            loop {
                let Some(next) = self.0.next().await else {
                    // The stream is exhausted: the peer went away mid-handshake.
                    return Err(WebSocketHandshakeError::ConnectionClosed);
                };

                match next? {
                    // Take ownership of the payload instead of copying it with
                    // `to_vec()`: identity for a `Vec<u8>` payload, and for a
                    // `Bytes` payload it can reuse the buffer when possible.
                    tungstenite::Message::Binary(payload) => return Ok(Vec::from(payload)),
                    tungstenite::Message::Text(_) => {
                        return Err(WebSocketHandshakeError::UnexpectedMessageType("text"));
                    }
                    // Control frames carry no handshake data; keep waiting.
                    tungstenite::Message::Ping(_) | tungstenite::Message::Pong(_) => {}
                    tungstenite::Message::Close(_) => {
                        return Err(WebSocketHandshakeError::ConnectionClosed);
                    }
                    tungstenite::Message::Frame(_) => {
                        return Err(WebSocketHandshakeError::UnexpectedMessageType("frame"));
                    }
                }
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use future_form::Sendable;
    use subduction_core::{
        connection::handshake::{Audience, Challenge, HandshakeMessage, Rejection, RejectionReason},
        timestamp::TimestampSeconds,
    };
    use subduction_crypto::{nonce::Nonce, signed::Signed, signer::memory::MemorySigner};

    /// Builds an in-memory signer from 32 copies of `seed`, giving each test a
    /// deterministic key.
    fn test_signer(seed: u8) -> MemorySigner {
        let secret = [seed; 32];
        MemorySigner::from_bytes(&secret)
    }

    mod handshake_message {
        use super::*;

        /// A signed challenge survives encode/decode with the same issuer.
        #[tokio::test]
        async fn signed_challenge_roundtrips() {
            let signer = test_signer(1);
            let challenge = Challenge::new(
                Audience::discover(b"test"),
                TimestampSeconds::new(1000),
                Nonce::from_u128(42),
            );
            let sealed = Signed::seal::<Sendable, _>(&signer, challenge)
                .await
                .into_signed();

            let message = HandshakeMessage::SignedChallenge(sealed.clone());
            let wire = message.encode();

            let roundtripped = match HandshakeMessage::try_decode(&wire) {
                Ok(HandshakeMessage::SignedChallenge(inner)) => inner,
                Ok(other) => unreachable!(
                    "expected SignedChallenge, got {:?}",
                    core::mem::discriminant(&other)
                ),
                Err(e) => unreachable!("decoding should succeed: {e}"),
            };

            assert_eq!(roundtripped.issuer(), sealed.issuer());
        }

        /// A rejection survives encode/decode with the same reason and
        /// server timestamp.
        #[test]
        fn rejection_roundtrips() {
            let sent = Rejection::new(RejectionReason::ClockDrift, TimestampSeconds::new(1000));

            let message = HandshakeMessage::Rejection(sent);
            let wire = message.encode();

            let received = match HandshakeMessage::try_decode(&wire) {
                Ok(HandshakeMessage::Rejection(inner)) => inner,
                Ok(other) => unreachable!(
                    "expected Rejection, got {:?}",
                    core::mem::discriminant(&other)
                ),
                Err(e) => unreachable!("decoding should succeed: {e}"),
            };

            assert_eq!(received.reason, sent.reason);
            assert_eq!(received.server_timestamp, sent.server_timestamp);
        }
    }
}