qbase 0.0.1

Core structure of the QUIC protocol, a part of gm-quic
Documentation
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

use crate::{
    error::{Error, ErrorKind},
    frame::{HandshakeDoneFrame, ReceiveFrame, SendFrame},
    sid::Role,
};

/// The completion flag for the client handshake.
///
/// The client considers the handshake complete only after
/// receiving the [`HandshakeDoneFrame`] from the server.
/// In the QUIC protocol, there are no tasks that specifically
/// require waiting for the client handshake to complete.
/// Instead, it simply queries the handshake status.
#[derive(Debug, Default, Clone)]
pub struct ClientHandshake {
    has_keys: Arc<AtomicBool>,
    done: Arc<AtomicBool>,
}

impl ClientHandshake {
    /// Check if the client handshake is complete.
    pub fn is_handshake_done(&self) -> bool {
        self.done.load(Ordering::Acquire)
    }

    pub fn has_keys(&self) -> bool {
        self.has_keys.load(Ordering::Acquire)
    }
    /// Receive the HANDSHAKE_DONE frame.
    ///
    /// Once the client receives the HANDSHAKE_DONE frame,
    /// it marks the completion of the client handshake.
    pub fn recv_handshake_done_frame(&self, _frame: &HandshakeDoneFrame) {
        let _has_done = self.done.swap(true, Ordering::AcqRel);
        if !_has_done {
            log::trace!("Client handshake is done");
        }
    }

    /// TLS upgrade the handshake keys.
    pub fn on_key_upgrade(&self) {
        let has_keys = self.has_keys.swap(true, Ordering::AcqRel);
        if !has_keys {
            log::trace!("Client is getting handshake keys");
        }
    }
}

/// Server's handshake status.
///
/// - `T` is responsible for reliably sending [`HandshakeDoneFrame`] to the client.
///    It can be a channel, a queue, or a buffer. Whatever, it must be able to send the
///    [`HandshakeDoneFrame`] to the client.
///
/// The server considers the handshake complete only after receiving
/// the [finished message](https://www.rfc-editor.org/rfc/rfc8446.html#section-4.4.4)
/// from the client during the TLS handshake process.
/// If the [finished message](https://www.rfc-editor.org/rfc/rfc8446.html#section-4.4.4)
/// from the TLS handshake is not received,
/// the server can also consider the handshake complete upon receiving and
/// successfully decrypting the client's 1-RTT packet.
/// Once the server's handshake is complete, the server will send a [`HandshakeDoneFrame`] immediately.
#[derive(Debug, Clone)]
pub struct ServerHandshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    is_done: Arc<AtomicBool>,
    has_keys: Arc<AtomicBool>,
    output: T,
}

impl<T> ServerHandshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    /// Create a new server handshake signal.
    ///
    /// The `output` is responsible for sending the [`HandshakeDoneFrame`] to the client,
    /// see [`ServerHandshake`].
    pub fn new(output: T) -> Self {
        ServerHandshake {
            is_done: Arc::new(AtomicBool::new(false)),
            has_keys: Arc::new(AtomicBool::new(false)),
            output,
        }
    }

    /// Check if the server handshake is complete.
    pub fn is_handshake_done(&self) -> bool {
        self.is_done.load(Ordering::Acquire)
    }

    /// Check if the server is getting handshake keys.
    pub fn has_keys(&self) -> bool {
        self.has_keys.load(Ordering::Acquire)
    }

    /// Actively set the server's handshake status to complete.
    ///
    /// Call this method when the TLS handshake
    /// [finished message](https://www.rfc-editor.org/rfc/rfc8446.html#section-4.4.4) is received.
    /// If the TLS handshake completion message is not received,
    /// receiving and successfully decrypting the client's 1-RTT packet
    /// is also considered handshake completion.
    /// Servers MUST NOT send a [`HandshakeDoneFrame`] before completing the handshake.
    /// and once the server handshake is complete,
    /// servers should send the [`HandshakeDoneFrame`] immediately.
    /// See [`ServerHandshake`].
    pub fn done(&self) {
        if self
            .is_done
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
        {
            log::trace!("Server handshake is done");
            self.output.send_frame([HandshakeDoneFrame]);
        }
    }

    /// TLS upgrade the handshake keys.
    pub fn on_key_upgrade(&self) {
        let has_keys = self.has_keys.swap(true, Ordering::AcqRel);
        if !has_keys {
            log::trace!("Server is getting handshake keys");
        }
    }
}

/// A merged handshake state that can be used by both the client and the server.
///
/// For convenience, a unified [`Handshake`]` should be used,
/// which will internally choose the corresponding behavior based on the role.
#[derive(Debug, Clone)]
pub enum Handshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    /// The client's handshake state if the endpoint is a client.
    Client(ClientHandshake),
    /// The server's handshake state if the endpoint is a server.
    Server(ServerHandshake<T>),
}

impl<T> Handshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    /// Create a new handshake state, based on the role.
    pub fn new(role: Role, output: T) -> Self {
        match role {
            Role::Client => Handshake::Client(ClientHandshake::default()),
            Role::Server => Handshake::Server(ServerHandshake::new(output)),
        }
    }

    /// Create a new client handshake state.
    pub fn new_client() -> Self {
        Handshake::Client(ClientHandshake::default())
    }

    /// Create a new server handshake state.
    /// The `output` is responsible for sending the [`HandshakeDoneFrame`] to the client,
    /// see [`ServerHandshake::new`].
    pub fn new_server(output: T) -> Self {
        Handshake::Server(ServerHandshake::new(output))
    }

    /// Check if the handshake is complete.
    pub fn is_handshake_done(&self) -> bool {
        match self {
            Handshake::Client(h) => h.is_handshake_done(),
            Handshake::Server(h) => h.is_handshake_done(),
        }
    }

    /// Check if the getting handshake keys.
    pub fn is_getting_keys(&self) -> bool {
        match self {
            Handshake::Client(h) => h.has_keys(),
            Handshake::Server(h) => h.has_keys(),
        }
    }

    /// TLS upgrade the handshake keys.
    pub fn on_key_upgrade(&self) {
        match self {
            Handshake::Client(h) => h.on_key_upgrade(),
            Handshake::Server(h) => h.on_key_upgrade(),
        }
    }

    /// Return the role of this handshake signal.
    pub fn role(&self) -> Role {
        match self {
            Handshake::Client(_) => Role::Client,
            Handshake::Server(_) => Role::Server,
        }
    }
}

impl<T> ReceiveFrame<HandshakeDoneFrame> for Handshake<T>
where
    T: SendFrame<HandshakeDoneFrame> + Clone,
{
    type Output = ();

    /// Receive the [`HandshakeDoneFrame`].
    ///
    /// A [`HandshakeDoneFrame`] can only be received by the client.
    /// A server MUST treat receipt of a [`HandshakeDoneFrame`]
    /// as a connection error of type PROTOCOL_VIOLATION.
    /// See [section 19.20](https://www.rfc-editor.org/rfc/rfc9000.html#section-19.20)
    /// of [QUIC](https://www.rfc-editor.org/rfc/rfc9000.html).
    fn recv_frame(&self, frame: &HandshakeDoneFrame) -> Result<(), Error> {
        match self {
            Handshake::Client(h) => {
                h.recv_handshake_done_frame(frame);
                Ok(())
            }
            _ => Err(Error::with_default_fty(
                ErrorKind::ProtocolViolation,
                "Server received a HANDSHAKE_DONE frame",
            )),
        }
    }
}

#[cfg(test)]
mod tests {
    use deref_derive::Deref;

    use super::*;
    use crate::{
        error::{Error, ErrorKind},
        frame::{ReceiveFrame, SendFrame},
        util::ArcAsyncDeque,
    };

    #[derive(Debug, Default, Clone, Deref)]
    struct HandshakeDoneFrameTx(ArcAsyncDeque<HandshakeDoneFrame>);

    impl SendFrame<HandshakeDoneFrame> for HandshakeDoneFrameTx {
        fn send_frame<I: IntoIterator<Item = HandshakeDoneFrame>>(&self, iter: I) {
            (&self.0).extend(iter);
        }
    }

    #[test]
    fn test_client_handshake() {
        let handshake = Handshake::<HandshakeDoneFrameTx>::new_client();
        assert!(!handshake.is_handshake_done());

        let ret = handshake.recv_frame(&HandshakeDoneFrame);
        assert!(ret.is_ok());
        assert!(handshake.is_handshake_done());
    }

    #[test]
    fn test_client_handshake_done() {
        let handshake = Handshake::<HandshakeDoneFrameTx>::new_client();
        assert!(!handshake.is_handshake_done());

        match &handshake {
            Handshake::Client(client_handshake) => {
                client_handshake.recv_handshake_done_frame(&HandshakeDoneFrame)
            }
            Handshake::Server(..) => unreachable!(),
        }
        assert!(handshake.is_handshake_done());
    }

    #[test]
    fn test_server_handshake() {
        let handshake = Handshake::new_server(HandshakeDoneFrameTx::default());
        assert!(!handshake.is_handshake_done());

        match &handshake {
            Handshake::Client(..) => unreachable!(),
            Handshake::Server(server_handshake) => server_handshake.done(),
        }
        assert!(handshake.is_handshake_done());
    }

    #[test]
    fn test_server_recv_handshake_done_frame() {
        let handshake = Handshake::new_server(HandshakeDoneFrameTx::default());
        assert!(!handshake.is_handshake_done());

        let ret = handshake.recv_frame(&HandshakeDoneFrame);
        assert_eq!(
            ret,
            Err(Error::with_default_fty(
                ErrorKind::ProtocolViolation,
                "Server received a HANDSHAKE_DONE frame",
            ))
        );
    }

    #[test]
    fn test_server_send_handshake_done_frame() {
        let handshake = ServerHandshake::new(HandshakeDoneFrameTx::default());
        handshake.done();
        assert!(handshake.is_handshake_done());
        assert_eq!(handshake.output.len(), 1);
    }
}