#![allow(dead_code)]
use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use std::sync::Mutex;
use crate::tls::quic_hooks::{Direction, Level, QuicHooks};
/// Number of per-level outbound handshake buffers in [`QuicHookState::tx_handshake`].
///
/// The buffers are indexed with `level as usize`, so this value must exceed every
/// `Level` discriminant. NOTE(review): assumes `Level` has at most four variants with
/// discriminants in `0..4` — confirm against `crate::tls::quic_hooks::Level`.
pub(crate) const LEVEL_SLOTS: usize = 4;

/// Shared mutable state written by [`QuicTlsHooks`] and drained through [`HookHandle`].
///
/// Intentionally not `Debug`: `secret_events` carries the raw secret bytes passed to
/// `on_traffic_secret`, and a derived `Debug` would make leaking them into logs trivial.
#[derive(Default)]
pub(crate) struct QuicHookState {
    /// Outbound handshake bytes, appended per level until drained.
    pub(crate) tx_handshake: [Vec<u8>; LEVEL_SLOTS],
    /// Secret events in the order they were reported.
    pub(crate) secret_events: Vec<(Level, Direction, Vec<u8>)>,
    /// The peer's raw transport parameters, if received and not yet taken.
    pub(crate) peer_params: Option<Vec<u8>>,
}
/// TLS-side hook object: records everything the TLS engine reports into shared state
/// so the QUIC layer can collect it later through a [`HookHandle`].
pub(crate) struct QuicTlsHooks {
    /// Buffers and events shared with the paired [`HookHandle`].
    pub(crate) state: Arc<Mutex<QuicHookState>>,
    /// Local transport parameters handed back to TLS on request.
    pub(crate) our_params: Arc<Mutex<Vec<u8>>>,
}

impl QuicHooks for QuicTlsHooks {
    /// Appends `data` to the outbound handshake buffer for `level`.
    ///
    /// # Panics
    /// If the state mutex is poisoned, or if `level as usize` is out of range for the
    /// per-level buffer array.
    fn on_handshake_data(&mut self, level: Level, data: &[u8]) {
        let slot = level as usize;
        self.state.lock().expect("hooks mutex poisoned").tx_handshake[slot]
            .extend_from_slice(data);
    }

    /// Records a copy of `secret` together with its level and direction, preserving
    /// arrival order.
    ///
    /// # Panics
    /// If the state mutex is poisoned.
    fn on_traffic_secret(&mut self, level: Level, dir: Direction, secret: &[u8]) {
        let event = (level, dir, secret.to_vec());
        self.state
            .lock()
            .expect("hooks mutex poisoned")
            .secret_events
            .push(event);
    }

    /// Returns a fresh copy of the current local transport parameters.
    ///
    /// # Panics
    /// If the parameters mutex is poisoned.
    fn our_transport_params(&self) -> Vec<u8> {
        let params = self.our_params.lock().expect("our_params mutex poisoned");
        params.as_slice().to_vec()
    }

    /// Stores a copy of the peer's raw transport parameters, replacing any earlier
    /// value that has not yet been taken.
    ///
    /// # Panics
    /// If the state mutex is poisoned.
    fn on_peer_transport_params(&mut self, raw: &[u8]) {
        let received = raw.to_vec();
        self.state.lock().expect("hooks mutex poisoned").peer_params = Some(received);
    }
}
/// Creates a connected hooks/handle pair over freshly initialised shared state.
///
/// `our_params` seeds the bytes the returned hooks report from
/// [`QuicHooks::our_transport_params`]; the handle can replace them later. Both halves
/// hold clones of the same `Arc`s, so writes through one are visible through the other.
pub(crate) fn build_hooks(our_params: Vec<u8>) -> (Box<QuicTlsHooks>, HookHandle) {
    let shared_state = Arc::new(Mutex::new(QuicHookState::default()));
    let shared_params = Arc::new(Mutex::new(our_params));

    let hooks = QuicTlsHooks {
        state: Arc::clone(&shared_state),
        our_params: Arc::clone(&shared_params),
    };
    let handle = HookHandle {
        state: shared_state,
        our_params: shared_params,
    };

    (Box::new(hooks), handle)
}
/// QUIC-side view of the state filled in by [`QuicTlsHooks`].
///
/// Clones share the same underlying state and parameters, since every field is an `Arc`.
#[derive(Clone)]
pub(crate) struct HookHandle {
    /// Buffers and events written by the TLS hooks.
    pub(crate) state: Arc<Mutex<QuicHookState>>,
    /// Local transport parameters returned to TLS on request.
    pub(crate) our_params: Arc<Mutex<Vec<u8>>>,
}

impl HookHandle {
    /// Runs `f` with exclusive access to the shared hook state.
    ///
    /// # Panics
    /// If the state mutex is poisoned.
    fn with_state<R>(&self, f: impl FnOnce(&mut QuicHookState) -> R) -> R {
        let mut guard = self.state.lock().expect("hooks mutex poisoned");
        f(&mut *guard)
    }

    /// Takes all buffered outbound handshake bytes for `level`, leaving that buffer empty.
    ///
    /// # Panics
    /// If the state mutex is poisoned, or if `level as usize` is out of range for the
    /// per-level buffer array.
    pub(crate) fn drain_handshake(&self, level: Level) -> Vec<u8> {
        self.with_state(|st| core::mem::take(&mut st.tx_handshake[level as usize]))
    }

    /// Takes every recorded secret event, oldest first, leaving the list empty.
    pub(crate) fn drain_secret_events(&self) -> Vec<(Level, Direction, Vec<u8>)> {
        self.with_state(|st| core::mem::take(&mut st.secret_events))
    }

    /// Takes the peer's transport parameters if they have arrived; later calls return
    /// `None` until new parameters are recorded.
    pub(crate) fn take_peer_params(&self) -> Option<Vec<u8>> {
        self.with_state(|st| st.peer_params.take())
    }

    /// Replaces the local transport parameters reported to TLS from now on.
    ///
    /// # Panics
    /// If the parameters mutex is poisoned.
    pub(crate) fn set_our_params(&self, bytes: Vec<u8>) {
        *self.our_params.lock().expect("our_params mutex poisoned") = bytes;
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the hooks <-> handle plumbing: bytes, secrets and parameters
    // written through `QuicTlsHooks` must be observable (once) through `HookHandle`.
    use super::*;

    #[test]
    fn hooks_round_trip_handshake_bytes() {
        let (mut boxed, handle) = build_hooks(alloc::vec![1, 2, 3]);
        boxed.on_handshake_data(Level::Initial, b"hello-CH");
        boxed.on_handshake_data(Level::Handshake, b"finished");
        boxed.on_handshake_data(Level::Initial, b"-cont");
        let init = handle.drain_handshake(Level::Initial);
        assert_eq!(init, b"hello-CH-cont");
        let hs = handle.drain_handshake(Level::Handshake);
        assert_eq!(hs, b"finished");
        assert!(handle.drain_handshake(Level::Initial).is_empty());
    }

    #[test]
    fn hooks_capture_secret_events_in_order() {
        let (mut boxed, handle) = build_hooks(alloc::vec![]);
        boxed.on_traffic_secret(Level::Handshake, Direction::Tx, b"shts");
        boxed.on_traffic_secret(Level::Handshake, Direction::Rx, b"chts");
        boxed.on_traffic_secret(Level::OneRtt, Direction::Tx, b"app");
        let events = handle.drain_secret_events();
        assert_eq!(events.len(), 3);
        assert_eq!(events[0].0, Level::Handshake);
        assert_eq!(events[0].1, Direction::Tx);
        assert_eq!(events[0].2, b"shts");
        assert_eq!(events[1].1, Direction::Rx);
        assert_eq!(events[2].0, Level::OneRtt);
        assert!(handle.drain_secret_events().is_empty());
    }

    #[test]
    fn hooks_capture_peer_params() {
        let (mut boxed, handle) = build_hooks(alloc::vec![0xa, 0xb]);
        assert!(handle.take_peer_params().is_none());
        boxed.on_peer_transport_params(&[0xde, 0xad]);
        let got = handle.take_peer_params().expect("set");
        assert_eq!(got, &[0xde, 0xad]);
        assert!(handle.take_peer_params().is_none());
    }

    #[test]
    fn hooks_peer_params_latest_wins() {
        // A second delivery before the handle takes the value replaces the first.
        let (mut boxed, handle) = build_hooks(alloc::vec![]);
        boxed.on_peer_transport_params(&[0x01]);
        boxed.on_peer_transport_params(&[0x02, 0x03]);
        assert_eq!(handle.take_peer_params().expect("set"), &[0x02, 0x03]);
        assert!(handle.take_peer_params().is_none());
    }

    #[test]
    fn hooks_return_our_params() {
        let (boxed, handle) = build_hooks(alloc::vec![1, 2, 3, 4]);
        assert_eq!(boxed.our_transport_params(), alloc::vec![1u8, 2, 3, 4]);
        handle.set_our_params(alloc::vec![5, 6]);
        assert_eq!(boxed.our_transport_params(), alloc::vec![5u8, 6]);
    }

    #[test]
    fn cloned_handle_shares_state() {
        // Clones of the handle must observe (and consume) the same shared state.
        let (mut boxed, handle) = build_hooks(alloc::vec![]);
        let other = handle.clone();
        boxed.on_handshake_data(Level::Handshake, b"abc");
        other.set_our_params(alloc::vec![9]);
        assert_eq!(handle.drain_handshake(Level::Handshake), b"abc");
        assert!(other.drain_handshake(Level::Handshake).is_empty());
        assert_eq!(boxed.our_transport_params(), alloc::vec![9u8]);
    }
}