use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use crate::{
error::{Error, ErrorKind},
frame::{HandshakeDoneFrame, ReceiveFrame, SendFrame},
sid::Role,
};
/// Client-side handshake progress flags.
///
/// Both flags are held behind an [`Arc`], so every clone of a
/// `ClientHandshake` reads and writes the same shared flags. The `Default`
/// value has both flags cleared.
#[derive(Clone, Debug, Default)]
pub struct ClientHandshake {
    /// Set by [`ClientHandshake::on_key_upgrade`].
    has_keys: Arc<AtomicBool>,
    /// Set by [`ClientHandshake::recv_handshake_done_frame`].
    done: Arc<AtomicBool>,
}
impl ClientHandshake {
    /// Returns `true` once [`Self::recv_handshake_done_frame`] has been called.
    pub fn is_handshake_done(&self) -> bool {
        self.done.load(Ordering::Acquire)
    }

    /// Returns `true` once [`Self::on_key_upgrade`] has been called.
    pub fn has_keys(&self) -> bool {
        self.has_keys.load(Ordering::Acquire)
    }

    /// Records receipt of a HANDSHAKE_DONE frame by setting the `done` flag.
    ///
    /// The frame argument itself is not inspected. Calling this more than once
    /// is harmless: the flag stays set and the trace message is only logged on
    /// the first transition.
    pub fn recv_handshake_done_frame(&self, _frame: &HandshakeDoneFrame) {
        // `swap` returns the previous value, so only the first caller sees `false`.
        let was_done = self.done.swap(true, Ordering::AcqRel);
        if !was_done {
            log::trace!("Client handshake is done");
        }
    }

    /// Sets the `has_keys` flag.
    ///
    /// Repeated calls leave the flag set; the trace message is only logged on
    /// the first transition.
    pub fn on_key_upgrade(&self) {
        let had_keys = self.has_keys.swap(true, Ordering::AcqRel);
        if !had_keys {
            log::trace!("Client is getting handshake keys");
        }
    }
}
/// Server-side handshake state.
///
/// The two flags are held behind [`Arc`], so clones share them. `output` is
/// the sink through which the HANDSHAKE_DONE frame is sent.
#[derive(Clone, Debug)]
pub struct ServerHandshake<T: SendFrame<HandshakeDoneFrame> + Clone> {
    /// Set once the server has declared the handshake done.
    is_done: Arc<AtomicBool>,
    /// Set once a key upgrade has been signalled.
    has_keys: Arc<AtomicBool>,
    /// Destination for the outgoing HANDSHAKE_DONE frame.
    output: T,
}
impl<T: SendFrame<HandshakeDoneFrame> + Clone> ServerHandshake<T> {
    /// Creates server handshake state with both flags cleared.
    ///
    /// `output` receives the HANDSHAKE_DONE frame emitted by [`Self::done`].
    pub fn new(output: T) -> Self {
        let is_done = Arc::new(AtomicBool::new(false));
        let has_keys = Arc::new(AtomicBool::new(false));
        Self {
            is_done,
            has_keys,
            output,
        }
    }

    /// Returns `true` once [`Self::done`] has been called.
    pub fn is_handshake_done(&self) -> bool {
        self.is_done.load(Ordering::Acquire)
    }

    /// Returns `true` once [`Self::on_key_upgrade`] has been called.
    pub fn has_keys(&self) -> bool {
        self.has_keys.load(Ordering::Acquire)
    }

    /// Marks the handshake as done and sends one HANDSHAKE_DONE frame.
    ///
    /// Only the call that flips the flag from `false` to `true` logs and sends
    /// the frame; every later call is a no-op.
    pub fn done(&self) {
        let first_transition = self
            .is_done
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok();
        if !first_transition {
            return;
        }
        log::trace!("Server handshake is done");
        self.output.send_frame([HandshakeDoneFrame]);
    }

    /// Sets the `has_keys` flag.
    ///
    /// Repeated calls leave the flag set; the trace message is only logged on
    /// the first transition.
    pub fn on_key_upgrade(&self) {
        if !self.has_keys.swap(true, Ordering::AcqRel) {
            log::trace!("Server is getting handshake keys");
        }
    }
}
/// Handshake state for either endpoint role.
///
/// Only the server variant carries the `T` output sink used to send
/// HANDSHAKE_DONE.
#[derive(Clone, Debug)]
pub enum Handshake<T: SendFrame<HandshakeDoneFrame> + Clone> {
    /// Client-side state; see [`ClientHandshake`].
    Client(ClientHandshake),
    /// Server-side state; see [`ServerHandshake`].
    Server(ServerHandshake<T>),
}
impl<T: SendFrame<HandshakeDoneFrame> + Clone> Handshake<T> {
    /// Creates the handshake state matching `role`.
    ///
    /// [`ClientHandshake`] has no output sink, so for [`Role::Client`] the
    /// `output` argument is simply dropped.
    pub fn new(role: Role, output: T) -> Self {
        match role {
            Role::Server => Self::new_server(output),
            Role::Client => Self::new_client(),
        }
    }

    /// Creates client-side handshake state with all flags cleared.
    pub fn new_client() -> Self {
        Self::Client(ClientHandshake::default())
    }

    /// Creates server-side handshake state that sends HANDSHAKE_DONE through `output`.
    pub fn new_server(output: T) -> Self {
        Self::Server(ServerHandshake::new(output))
    }

    /// Returns `true` once the role-specific state reports the handshake as done.
    pub fn is_handshake_done(&self) -> bool {
        match self {
            Self::Client(client) => client.is_handshake_done(),
            Self::Server(server) => server.is_handshake_done(),
        }
    }

    /// Returns `true` once [`Self::on_key_upgrade`] has been called.
    pub fn is_getting_keys(&self) -> bool {
        match self {
            Self::Client(client) => client.has_keys(),
            Self::Server(server) => server.has_keys(),
        }
    }

    /// Forwards a key-upgrade notification to the role-specific state.
    pub fn on_key_upgrade(&self) {
        match self {
            Self::Client(client) => client.on_key_upgrade(),
            Self::Server(server) => server.on_key_upgrade(),
        }
    }

    /// Returns the endpoint role this state was created for.
    pub fn role(&self) -> Role {
        match self {
            Self::Client(_) => Role::Client,
            Self::Server(_) => Role::Server,
        }
    }
}
impl<T> ReceiveFrame<HandshakeDoneFrame> for Handshake<T>
where
T: SendFrame<HandshakeDoneFrame> + Clone,
{
type Output = ();
fn recv_frame(&self, frame: &HandshakeDoneFrame) -> Result<(), Error> {
match self {
Handshake::Client(h) => {
h.recv_handshake_done_frame(frame);
Ok(())
}
_ => Err(Error::with_default_fty(
ErrorKind::ProtocolViolation,
"Server received a HANDSHAKE_DONE frame",
)),
}
}
}
#[cfg(test)]
mod tests {
    use deref_derive::Deref;

    use super::*;
    use crate::{
        error::{Error, ErrorKind},
        frame::{ReceiveFrame, SendFrame},
        util::ArcAsyncDeque,
    };

    /// Test sink that collects every HANDSHAKE_DONE frame it is asked to send.
    #[derive(Debug, Default, Clone, Deref)]
    struct HandshakeDoneFrameTx(ArcAsyncDeque<HandshakeDoneFrame>);

    impl SendFrame<HandshakeDoneFrame> for HandshakeDoneFrameTx {
        fn send_frame<I: IntoIterator<Item = HandshakeDoneFrame>>(&self, iter: I) {
            (&self.0).extend(iter);
        }
    }

    #[test]
    fn test_client_handshake() {
        let handshake = Handshake::<HandshakeDoneFrameTx>::new_client();
        assert!(!handshake.is_handshake_done());
        let ret = handshake.recv_frame(&HandshakeDoneFrame);
        assert!(ret.is_ok());
        assert!(handshake.is_handshake_done());
    }

    #[test]
    fn test_client_handshake_done() {
        let handshake = Handshake::<HandshakeDoneFrameTx>::new_client();
        assert!(!handshake.is_handshake_done());
        match &handshake {
            Handshake::Client(client_handshake) => {
                client_handshake.recv_handshake_done_frame(&HandshakeDoneFrame)
            }
            Handshake::Server(..) => unreachable!(),
        }
        assert!(handshake.is_handshake_done());
    }

    /// Receiving HANDSHAKE_DONE more than once on a client must stay `Ok`.
    #[test]
    fn test_client_recv_handshake_done_frame_twice() {
        let handshake = Handshake::<HandshakeDoneFrameTx>::new_client();
        assert!(handshake.recv_frame(&HandshakeDoneFrame).is_ok());
        assert!(handshake.recv_frame(&HandshakeDoneFrame).is_ok());
        assert!(handshake.is_handshake_done());
    }

    #[test]
    fn test_server_handshake() {
        let handshake = Handshake::new_server(HandshakeDoneFrameTx::default());
        assert!(!handshake.is_handshake_done());
        match &handshake {
            Handshake::Client(..) => unreachable!(),
            Handshake::Server(server_handshake) => server_handshake.done(),
        }
        assert!(handshake.is_handshake_done());
    }

    #[test]
    fn test_server_recv_handshake_done_frame() {
        let handshake = Handshake::new_server(HandshakeDoneFrameTx::default());
        assert!(!handshake.is_handshake_done());
        let ret = handshake.recv_frame(&HandshakeDoneFrame);
        assert_eq!(
            ret,
            Err(Error::with_default_fty(
                ErrorKind::ProtocolViolation,
                "Server received a HANDSHAKE_DONE frame",
            ))
        );
    }

    #[test]
    fn test_server_send_handshake_done_frame() {
        let handshake = ServerHandshake::new(HandshakeDoneFrameTx::default());
        handshake.done();
        assert!(handshake.is_handshake_done());
        assert_eq!(handshake.output.len(), 1);
    }

    /// `done()` guards with compare-exchange, so repeated calls send only one frame.
    #[test]
    fn test_server_done_sends_frame_only_once() {
        let handshake = ServerHandshake::new(HandshakeDoneFrameTx::default());
        handshake.done();
        handshake.done();
        assert!(handshake.is_handshake_done());
        assert_eq!(handshake.output.len(), 1);
    }

    /// Shared check: the key flag starts cleared, is set by `on_key_upgrade`,
    /// and stays set on repeat calls.
    fn assert_key_upgrade(handshake: &Handshake<HandshakeDoneFrameTx>) {
        assert!(!handshake.is_getting_keys());
        handshake.on_key_upgrade();
        assert!(handshake.is_getting_keys());
        handshake.on_key_upgrade();
        assert!(handshake.is_getting_keys());
    }

    #[test]
    fn test_client_key_upgrade() {
        assert_key_upgrade(&Handshake::new_client());
    }

    #[test]
    fn test_server_key_upgrade() {
        assert_key_upgrade(&Handshake::new_server(HandshakeDoneFrameTx::default()));
    }

    #[test]
    fn test_new_dispatches_on_role() {
        let client = Handshake::new(Role::Client, HandshakeDoneFrameTx::default());
        assert!(matches!(client, Handshake::Client(_)));
        assert!(matches!(client.role(), Role::Client));

        let server = Handshake::new(Role::Server, HandshakeDoneFrameTx::default());
        assert!(matches!(server, Handshake::Server(_)));
        assert!(matches!(server.role(), Role::Server));
    }
}