use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use crate::{
error::{Error, ErrorKind, QuicError},
frame::{HandshakeDoneFrame, ReceiveFrame, SendFrame},
role::Role,
};
/// Client-side handshake-completion state.
///
/// The completion flag lives behind an [`Arc`], so every clone of a
/// `ClientHandshake` observes (and updates) the same flag.
#[derive(Debug, Default, Clone)]
pub struct ClientHandshake {
    /// Set once a HANDSHAKE_DONE frame has been received; starts `false`.
    done: Arc<AtomicBool>,
}

impl ClientHandshake {
    /// Returns `true` once a HANDSHAKE_DONE frame has been recorded.
    pub fn is_handshake_done(&self) -> bool {
        let flag: &AtomicBool = &self.done;
        flag.load(Ordering::Acquire)
    }

    /// Records receipt of a HANDSHAKE_DONE frame.
    ///
    /// Returns `true` only for the call that first sets the flag; any
    /// duplicate frame afterwards yields `false`. The frame contents are
    /// not inspected.
    pub fn recv_handshake_done_frame(&self, _frame: &HandshakeDoneFrame) -> bool {
        let previously_done = self.done.swap(true, Ordering::AcqRel);
        !previously_done
    }
}
/// Server-side handshake-completion state.
///
/// The completion flag lives behind an [`Arc`], so clones share it; each
/// clone carries its own clone of `output`.
#[derive(Debug, Clone)]
pub struct ServerHandshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    /// Becomes `true` the first time [`ServerHandshake::done`] succeeds.
    is_done: Arc<AtomicBool>,
    /// Sink that the HANDSHAKE_DONE frame is sent through.
    output: T,
}

impl<T> ServerHandshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    /// Builds a not-yet-done server handshake that will emit through `output`.
    pub fn new(output: T) -> Self {
        let is_done = Arc::new(AtomicBool::new(false));
        Self { is_done, output }
    }

    /// Reports whether [`ServerHandshake::done`] has already taken effect.
    pub fn is_handshake_done(&self) -> bool {
        self.is_done.load(Ordering::Acquire)
    }

    /// Marks the handshake as done and sends one HANDSHAKE_DONE frame.
    ///
    /// Only the call that flips the flag from `false` to `true` sends the
    /// frame and returns `true`. Every later call, on this value or any
    /// clone, returns `false` and sends nothing.
    pub fn done(&self) -> bool {
        let first = self
            .is_done
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok();
        if first {
            self.output.send_frame([HandshakeDoneFrame]);
        }
        first
    }
}
/// Role-specific handshake-completion state.
///
/// The client variant records receipt of a HANDSHAKE_DONE frame; the server
/// variant records whether it has sent HANDSHAKE_DONE through its output `T`.
#[derive(Debug, Clone)]
pub enum Handshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    /// Client side: completes when a HANDSHAKE_DONE frame is received.
    Client(ClientHandshake),
    /// Server side: completes when [`Handshake::done`] sends HANDSHAKE_DONE.
    Server(ServerHandshake<T>),
}

impl<T> Handshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    /// Creates the handshake state for `role`.
    ///
    /// `output` is only kept for [`Role::Server`]; for a client it is dropped.
    pub fn new(role: Role, output: T) -> Self {
        // Delegate so each role has a single construction path.
        match role {
            Role::Client => Self::new_client(),
            Role::Server => Self::new_server(output),
        }
    }

    /// Creates a client-side handshake state that is not yet done.
    pub fn new_client() -> Self {
        Handshake::Client(ClientHandshake::default())
    }

    /// Creates a server-side handshake state that will send through `output`.
    pub fn new_server(output: T) -> Self {
        Handshake::Server(ServerHandshake::new(output))
    }

    /// Returns whether the handshake has been marked done for this role.
    pub fn is_handshake_done(&self) -> bool {
        match self {
            Handshake::Client(h) => h.is_handshake_done(),
            Handshake::Server(h) => h.is_handshake_done(),
        }
    }

    /// Server: marks the handshake done and sends HANDSHAKE_DONE, returning
    /// `true` only on the first successful call.
    ///
    /// Client: has no effect and always returns `false`, because a client
    /// only learns of completion by receiving the frame.
    pub fn done(&self) -> bool {
        match self {
            Handshake::Client(_) => false,
            Handshake::Server(h) => h.done(),
        }
    }

    /// Returns the endpoint role this state was built for.
    pub fn role(&self) -> Role {
        match self {
            Handshake::Client(_) => Role::Client,
            Handshake::Server(_) => Role::Server,
        }
    }
}
impl<T> ReceiveFrame<HandshakeDoneFrame> for Handshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    /// `true` when this frame is the first HANDSHAKE_DONE seen by the client.
    type Output = bool;

    /// Dispatches an incoming HANDSHAKE_DONE frame according to the role.
    ///
    /// On a client the frame marks the handshake as done; the result is
    /// `Ok(true)` for the first such frame and `Ok(false)` for duplicates.
    ///
    /// # Errors
    ///
    /// Returns a `ProtocolViolation` error when `self` is the server side.
    fn recv_frame(&self, frame: &HandshakeDoneFrame) -> Result<bool, Error> {
        // Match every variant explicitly so that adding a new variant is a
        // compile error here rather than being silently rejected as a server.
        match self {
            Handshake::Client(h) => Ok(h.recv_handshake_done_frame(frame)),
            // RFC 9000 §19.20: a server MUST treat receipt of HANDSHAKE_DONE
            // as a connection error of type PROTOCOL_VIOLATION.
            Handshake::Server(_) => Err(QuicError::with_default_fty(
                ErrorKind::ProtocolViolation,
                "Server received a HANDSHAKE_DONE frame",
            )
            .into()),
        }
    }
}
#[cfg(test)]
mod tests {
    use derive_more::Deref;

    use super::*;
    use crate::{
        error::ErrorKind,
        frame::{ReceiveFrame, SendFrame},
        util::ArcAsyncDeque,
    };

    /// Test sink that queues every HANDSHAKE_DONE frame it is asked to send.
    #[derive(Debug, Default, Clone, Deref)]
    struct HandshakeDoneFrameTx(ArcAsyncDeque<HandshakeDoneFrame>);

    impl SendFrame<HandshakeDoneFrame> for HandshakeDoneFrameTx {
        fn send_frame<I: IntoIterator<Item = HandshakeDoneFrame>>(&self, iter: I) {
            (&self.0).extend(iter);
        }
    }

    #[test]
    fn test_client_handshake() {
        let handshake = Handshake::<HandshakeDoneFrameTx>::new_client();
        assert!(!handshake.is_handshake_done());
        let ret = handshake.recv_frame(&HandshakeDoneFrame);
        assert!(ret.is_ok());
        assert!(handshake.is_handshake_done());
    }

    #[test]
    fn test_client_handshake_done() {
        let handshake = Handshake::<HandshakeDoneFrameTx>::new_client();
        assert!(!handshake.is_handshake_done());
        assert!(handshake.recv_frame(&HandshakeDoneFrame).unwrap());
        assert!(handshake.is_handshake_done());
        assert!(!handshake.recv_frame(&HandshakeDoneFrame).unwrap());
        assert!(handshake.is_handshake_done());
    }

    #[test]
    fn test_server_handshake() {
        let handshake = Handshake::new_server(HandshakeDoneFrameTx::default());
        assert!(!handshake.is_handshake_done());
        assert!(handshake.done());
        assert!(handshake.is_handshake_done());
        assert!(!handshake.done());
        assert!(handshake.is_handshake_done());
    }

    #[test]
    fn test_server_recv_handshake_done_frame() {
        let handshake = Handshake::new_server(HandshakeDoneFrameTx::default());
        assert!(!handshake.is_handshake_done());
        let ret = handshake.recv_frame(&HandshakeDoneFrame);
        assert_eq!(
            ret,
            Err(QuicError::with_default_fty(
                ErrorKind::ProtocolViolation,
                "Server received a HANDSHAKE_DONE frame",
            )
            .into())
        );
    }

    #[test]
    fn test_server_send_handshake_done_frame() {
        let handshake = ServerHandshake::new(HandshakeDoneFrameTx::default());
        assert!(!handshake.is_handshake_done());
        assert!(handshake.done());
        assert!(handshake.is_handshake_done());
        assert_eq!(handshake.output.len(), 1);
        // A repeated `done()` must not queue a second HANDSHAKE_DONE frame.
        assert!(!handshake.done());
        assert_eq!(handshake.output.len(), 1);
    }

    #[test]
    fn test_new_by_role() {
        let client = Handshake::new(Role::Client, HandshakeDoneFrameTx::default());
        assert!(matches!(client.role(), Role::Client));
        // `done()` is a no-op on the client side.
        assert!(!client.done());
        assert!(!client.is_handshake_done());

        let server = Handshake::new(Role::Server, HandshakeDoneFrameTx::default());
        assert!(matches!(server.role(), Role::Server));
        assert!(!server.is_handshake_done());
        assert!(server.done());
        assert!(server.is_handshake_done());
    }

    #[test]
    fn test_clones_share_completion_state() {
        let client = Handshake::<HandshakeDoneFrameTx>::new_client();
        let client_clone = client.clone();
        assert!(client_clone.recv_frame(&HandshakeDoneFrame).unwrap());
        assert!(client.is_handshake_done());
        assert!(!client.recv_frame(&HandshakeDoneFrame).unwrap());

        let server = ServerHandshake::new(HandshakeDoneFrameTx::default());
        let server_clone = server.clone();
        assert!(server.done());
        assert!(server_clone.is_handshake_done());
        // The clone sees the shared flag and therefore does not send again.
        assert!(!server_clone.done());
        assert_eq!(server.output.len(), 1);
    }
}