use std::{
future::Future,
ops::DerefMut,
pin::Pin,
sync::{Arc, Mutex, MutexGuard},
task::{Context, Poll, Waker},
};
use rustls::quic::{HeaderProtectionKey, Keys, PacketKey, Secrets};
use super::KeyPhaseBit;
/// Lifecycle of a set of QUIC [`Keys`] shared through [`ArcKeys`].
#[derive(Clone)]
enum KeysState {
    /// Keys are not available yet. Holds the waker of the task currently
    /// awaiting [`GetRemoteKeys`], if one has polled.
    Pending(Option<Waker>),
    /// Keys have been installed and can be handed out.
    Ready(Arc<Keys>),
    /// Keys were invalidated; awaiting [`GetRemoteKeys`] resolves to `None`.
    Invalid,
}
/// Cheaply clonable, mutex-protected handle to a set of QUIC [`Keys`].
///
/// All clones share the same state, so keys installed or invalidated through one
/// handle are observed through every other handle.
#[derive(Clone)]
pub struct ArcKeys(Arc<Mutex<KeysState>>);

impl ArcKeys {
    /// Locks the shared state.
    ///
    /// # Panics
    ///
    /// Panics if the mutex has been poisoned.
    #[inline]
    fn lock_guard(&self) -> MutexGuard<KeysState> {
        self.0.lock().unwrap()
    }

    /// Creates a handle whose keys are not known yet.
    pub fn new_pending() -> Self {
        let state = Mutex::new(KeysState::Pending(None));
        Self(Arc::new(state))
    }

    /// Creates a handle that already holds `keys`.
    pub fn with_keys(keys: Keys) -> Self {
        let state = Mutex::new(KeysState::Ready(Arc::new(keys)));
        Self(Arc::new(state))
    }

    /// Returns a future that resolves to `Some(keys)` once keys are set, or to
    /// `None` once they are invalidated.
    pub fn get_remote_keys(&self) -> GetRemoteKeys {
        GetRemoteKeys(self)
    }

    /// Returns the keys if they are currently available, without waiting.
    pub fn get_local_keys(&self) -> Option<Arc<Keys>> {
        let guard = self.lock_guard();
        if let KeysState::Ready(keys) = &*guard {
            Some(Arc::clone(keys))
        } else {
            None
        }
    }

    /// Installs `keys` and wakes the task waiting on [`Self::get_remote_keys`], if any.
    ///
    /// # Panics
    ///
    /// Panics if keys were already set or have been invalidated.
    pub fn set_keys(&self, keys: Keys) {
        let mut guard = self.lock_guard();
        let waiting = match &mut *guard {
            KeysState::Ready(_) => panic!("set_keys called twice"),
            KeysState::Invalid => panic!("set_keys called after invalidation"),
            KeysState::Pending(slot) => slot.take(),
        };
        // The lock is still held, so a woken task cannot observe the state before
        // it becomes `Ready` below.
        if let Some(waker) = waiting {
            waker.wake();
        }
        *guard = KeysState::Ready(Arc::new(keys));
    }

    /// Marks the keys as invalid and returns them if they were ready.
    ///
    /// A task waiting on [`Self::get_remote_keys`] is woken and will observe `None`.
    ///
    /// # Panics
    ///
    /// Panics if the keys were already invalidated.
    pub fn invalid(&self) -> Option<Arc<Keys>> {
        let mut guard = self.lock_guard();
        let previous = std::mem::replace(guard.deref_mut(), KeysState::Invalid);
        match previous {
            KeysState::Ready(keys) => Some(keys),
            KeysState::Pending(waiting) => {
                if let Some(waker) = waiting {
                    waker.wake();
                }
                None
            }
            KeysState::Invalid => unreachable!(),
        }
    }
}
/// Future returned by [`ArcKeys::get_remote_keys`].
///
/// Resolves to `Some(keys)` once [`ArcKeys::set_keys`] has been called, or to `None`
/// if the keys were invalidated. The shared state stores a single waker, so only one
/// task may wait on a given [`ArcKeys`] at a time.
pub struct GetRemoteKeys<'k>(&'k ArcKeys);

impl Future for GetRemoteKeys<'_> {
    type Output = Option<Arc<Keys>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut keys = self.0.lock_guard();
        match &mut *keys {
            KeysState::Pending(waker) => {
                // A re-poll from the already-registered task (e.g. after a spurious
                // wake-up) is legitimate. A registered waker that would NOT wake the
                // current task means a second task is waiting. Its waker would replace
                // the first one and lose that task's wake-up, so treat it as a bug.
                // NOTE(review): `Waker::will_wake` is best-effort and may return false
                // for wakers of the same task with some executors — confirm for ours.
                if waker
                    .as_ref()
                    .is_some_and(|registered| !registered.will_wake(cx.waker()))
                {
                    unreachable!("Try to get remote keys from multiple tasks! This is a bug, please report it.")
                }
                *waker = Some(cx.waker().clone());
                Poll::Pending
            }
            KeysState::Ready(keys) => Poll::Ready(Some(keys.clone())),
            KeysState::Invalid => Poll::Ready(None),
        }
    }
}
/// 1-RTT packet protection keys that support key updates.
///
/// Remote keys are kept in two slots indexed by [`KeyPhaseBit::as_index`], so keys
/// for both key phases can coexist. Only the local key of the current phase is kept.
pub struct OneRttPacketKeys {
    /// Current key phase, as reported by [`Self::get_local`].
    cur_phase: KeyPhaseBit,
    /// Secrets from which the next generation of packet keys is derived.
    secrets: Secrets,
    /// Remote packet keys, one slot per key phase.
    remote: [Option<Arc<dyn PacketKey>>; 2],
    /// Local packet key for `cur_phase`.
    local: Arc<dyn PacketKey>,
}

impl OneRttPacketKeys {
    /// Starts in the default key phase, with `remote` stored in remote slot 0.
    fn new(remote: Box<dyn PacketKey>, local: Box<dyn PacketKey>, secrets: Secrets) -> Self {
        let first_remote: Arc<dyn PacketKey> = Arc::from(remote);
        let local: Arc<dyn PacketKey> = Arc::from(local);
        Self {
            cur_phase: KeyPhaseBit::default(),
            secrets,
            // NOTE(review): slot 0 is assumed to be `KeyPhaseBit::default().as_index()` — confirm.
            remote: [Some(first_remote), None],
            local,
        }
    }

    /// Performs a key update.
    ///
    /// Toggles `cur_phase` and derives the next packet keys from `secrets`. The new
    /// remote key goes into the new phase's slot and the new local key replaces the
    /// old one. The remote key of the previous phase is kept until [`Self::phase_out`].
    pub fn update(&mut self) {
        self.cur_phase.toggle();
        let next = self.secrets.next_packet_keys();
        let slot = self.cur_phase.as_index();
        self.remote[slot] = Some(Arc::from(next.remote));
        self.local = Arc::from(next.local);
    }

    /// Drops the remote key stored for the non-current key phase.
    pub fn phase_out(&mut self) {
        let previous = !self.cur_phase;
        self.remote[previous.as_index()] = None;
    }

    /// Returns the remote packet key for `key_phase`.
    ///
    /// If `key_phase` differs from the current phase and no key is stored for it,
    /// [`Self::update`] runs first. That update also rotates the local key.
    /// `_pn` is currently unused.
    ///
    /// NOTE(review): the update is committed before any packet has been authenticated
    /// with the new key. RFC 9001 §6.3 ties updating to successful decryption, so
    /// confirm that callers account for this.
    pub fn get_remote(&mut self, key_phase: KeyPhaseBit, _pn: u64) -> Arc<dyn PacketKey> {
        let slot = key_phase.as_index();
        let needs_update = key_phase != self.cur_phase && self.remote[slot].is_none();
        if needs_update {
            self.update();
        }
        // `update()` fills the slot of the new current phase, which is expected to be
        // `key_phase` here.
        self.remote[slot].clone().unwrap()
    }

    /// Returns the current key phase together with the local packet key.
    pub fn get_local(&self) -> (KeyPhaseBit, Arc<dyn PacketKey>) {
        let key = Arc::clone(&self.local);
        (self.cur_phase, key)
    }
}
/// Shared handle to [`OneRttPacketKeys`] bundled with a tag length.
///
/// The `usize` is the tag length supplied at construction. In this file it comes
/// from the local packet key's `tag_len()` when the 1-RTT keys are installed.
#[derive(Clone)]
pub struct ArcOneRttPacketKeys(Arc<(Mutex<OneRttPacketKeys>, usize)>);

impl ArcOneRttPacketKeys {
    /// Locks the packet keys.
    ///
    /// # Panics
    ///
    /// Panics if the mutex has been poisoned.
    pub fn lock_guard(&self) -> MutexGuard<OneRttPacketKeys> {
        let (keys, _) = &*self.0;
        keys.lock().unwrap()
    }

    /// Returns the tag length recorded at construction; key updates do not change it.
    pub fn tag_len(&self) -> usize {
        let (_, tag_len) = &*self.0;
        *tag_len
    }
}
/// Header protection keys for both directions.
///
/// These are stored separately from the packet keys and are not touched by
/// [`OneRttPacketKeys::update`].
#[derive(Clone)]
pub struct HeaderProtectionKeys {
    /// Key used for the local direction.
    pub local: Arc<dyn HeaderProtectionKey>,
    /// Key used for the remote direction.
    pub remote: Arc<dyn HeaderProtectionKey>,
}
/// Lifecycle of the 1-RTT keys shared through [`ArcOneRttKeys`].
enum OneRttKeysState {
    /// Keys are not available yet. Holds the waker of the task currently
    /// awaiting [`GetRemoteOneRttKeys`], if one has polled.
    Pending(Option<Waker>),
    /// Keys have been installed.
    Ready {
        /// Header protection keys for both directions.
        hpk: HeaderProtectionKeys,
        /// Shared packet keys (with key-update support).
        pk: ArcOneRttPacketKeys,
    },
    /// Keys were invalidated; awaiting [`GetRemoteOneRttKeys`] resolves to `None`.
    Invalid,
}
/// Cheaply clonable, mutex-protected handle to the 1-RTT keys.
///
/// Once ready, it holds header protection keys and shared packet keys. All clones
/// share the same state.
#[derive(Clone)]
pub struct ArcOneRttKeys(Arc<Mutex<OneRttKeysState>>);

impl ArcOneRttKeys {
    /// Locks the shared state.
    ///
    /// # Panics
    ///
    /// Panics if the mutex has been poisoned.
    fn lock_guard(&self) -> MutexGuard<OneRttKeysState> {
        self.0.lock().unwrap()
    }

    /// Creates a handle whose keys are not known yet.
    pub fn new_pending() -> Self {
        Self(Arc::new(Mutex::new(OneRttKeysState::Pending(None))))
    }

    /// Installs the 1-RTT keys and wakes the task waiting on
    /// [`Self::get_remote_keys`], if any.
    ///
    /// `secrets` is handed to the packet keys so they can derive keys for later key
    /// updates. The stored tag length is taken from the local packet key.
    ///
    /// # Panics
    ///
    /// Panics if keys were already set or have been invalidated.
    pub fn set_keys(&self, keys: Keys, secrets: Secrets) {
        let mut guard = self.lock_guard();
        let waiting = match &mut *guard {
            OneRttKeysState::Ready { .. } => panic!("set_keys called twice"),
            OneRttKeysState::Invalid => panic!("set_keys called after invalidation"),
            OneRttKeysState::Pending(slot) => slot,
        };
        let tag_len = keys.local.packet.tag_len();
        let hpk = HeaderProtectionKeys {
            local: Arc::from(keys.local.header),
            remote: Arc::from(keys.remote.header),
        };
        let packet_keys = OneRttPacketKeys::new(keys.remote.packet, keys.local.packet, secrets);
        let pk = ArcOneRttPacketKeys(Arc::new((Mutex::new(packet_keys), tag_len)));
        // The lock is still held, so a woken task cannot observe the state before
        // it becomes `Ready` below.
        if let Some(waker) = waiting.take() {
            waker.wake();
        }
        *guard = OneRttKeysState::Ready { hpk, pk };
    }

    /// Marks the keys as invalid and returns them if they were ready.
    ///
    /// A task waiting on [`Self::get_remote_keys`] is woken and will observe `None`.
    ///
    /// # Panics
    ///
    /// Panics if the keys were already invalidated.
    pub fn invalid(&self) -> Option<(HeaderProtectionKeys, ArcOneRttPacketKeys)> {
        let mut guard = self.lock_guard();
        let previous = std::mem::replace(guard.deref_mut(), OneRttKeysState::Invalid);
        match previous {
            OneRttKeysState::Ready { hpk, pk } => Some((hpk, pk)),
            OneRttKeysState::Pending(waiting) => {
                if let Some(waker) = waiting {
                    waker.wake();
                }
                None
            }
            OneRttKeysState::Invalid => unreachable!(),
        }
    }

    /// Returns the local header protection key and the shared packet keys if they
    /// are currently available, without waiting.
    pub fn get_local_keys(&self) -> Option<(Arc<dyn HeaderProtectionKey>, ArcOneRttPacketKeys)> {
        let guard = self.lock_guard();
        match &*guard {
            OneRttKeysState::Ready { hpk, pk } => Some((Arc::clone(&hpk.local), pk.clone())),
            OneRttKeysState::Pending(_) | OneRttKeysState::Invalid => None,
        }
    }

    /// Returns a future that resolves to the remote header protection key and the
    /// shared packet keys once they are set, or to `None` once they are invalidated.
    pub fn get_remote_keys(&self) -> GetRemoteOneRttKeys {
        GetRemoteOneRttKeys(self)
    }
}
/// Future returned by [`ArcOneRttKeys::get_remote_keys`].
///
/// Resolves to `Some((remote header protection key, packet keys))` once
/// [`ArcOneRttKeys::set_keys`] has been called, or to `None` if the keys were
/// invalidated. The shared state stores a single waker, so only one task may wait on
/// a given [`ArcOneRttKeys`] at a time.
pub struct GetRemoteOneRttKeys<'k>(&'k ArcOneRttKeys);

impl Future for GetRemoteOneRttKeys<'_> {
    type Output = Option<(Arc<dyn HeaderProtectionKey>, ArcOneRttPacketKeys)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut keys = self.0.lock_guard();
        match &mut *keys {
            OneRttKeysState::Pending(waker) => {
                // A re-poll from the already-registered task (e.g. after a spurious
                // wake-up) is legitimate. A registered waker that would NOT wake the
                // current task means a second task is waiting. Its waker would replace
                // the first one and lose that task's wake-up, so treat it as a bug.
                // NOTE(review): `Waker::will_wake` is best-effort and may return false
                // for wakers of the same task with some executors — confirm for ours.
                if waker
                    .as_ref()
                    .is_some_and(|registered| !registered.will_wake(cx.waker()))
                {
                    unreachable!("Try to get remote keys from multiple tasks! This is a bug, please report it.")
                }
                *waker = Some(cx.waker().clone());
                Poll::Pending
            }
            OneRttKeysState::Ready { hpk, pk } => {
                Poll::Ready(Some((hpk.remote.clone(), pk.clone())))
            }
            OneRttKeysState::Invalid => Poll::Ready(None),
        }
    }
}