use std::{
future::Future,
ops::DerefMut,
pin::Pin,
sync::{Arc, Mutex, MutexGuard},
task::{Context, Poll, Waker},
};
use futures::FutureExt;
use rustls::quic::{
DirectionalKeys as RustlsDirectionalKeys, HeaderProtectionKey, Keys as RustlsKeys, PacketKey,
Secrets,
};
/// Header- and packet-protection keys for a single direction of traffic.
///
/// The owned keys handed out by rustls are converted into `Arc`s, so cloning this value
/// shares the underlying key objects rather than duplicating them.
#[derive(Clone)]
pub struct DirectionalKeys {
    /// Header protection key for this direction.
    pub header: Arc<dyn HeaderProtectionKey>,
    /// Packet protection key for this direction.
    pub packet: Arc<dyn PacketKey>,
}

impl From<RustlsDirectionalKeys> for DirectionalKeys {
    /// Moves both rustls keys into shared `Arc`s.
    fn from(keys: RustlsDirectionalKeys) -> Self {
        let header: Arc<dyn HeaderProtectionKey> = Arc::from(keys.header);
        let packet: Arc<dyn PacketKey> = Arc::from(keys.packet);
        Self { header, packet }
    }
}
#[derive(Clone)]
pub struct Keys {
pub local: DirectionalKeys,
pub remote: DirectionalKeys,
}
impl From<RustlsKeys> for Keys {
fn from(keys: RustlsKeys) -> Self {
Self {
local: keys.local.into(),
remote: keys.remote.into(),
}
}
}
use super::KeyPhaseBit;
use crate::role::Role;
/// Lifecycle of a key set that may become available asynchronously.
///
/// The state starts `Pending` (optionally remembering the waker of the single task awaiting the
/// keys), becomes `Ready` once [`KeysState::set`] installs the keys, and ends `Invalid` after
/// [`KeysState::invalid`] has been called.
#[derive(Clone)]
enum KeysState<K> {
    /// Keys not installed yet; holds the waker of the task awaiting them, if any.
    Pending(Option<Waker>),
    /// Keys are installed and can be handed out.
    Ready(K),
    /// Keys have been discarded and will not become available again.
    Invalid,
}

impl<K> KeysState<K> {
    /// Installs `keys` and wakes the task waiting for them, if any.
    ///
    /// # Panics
    ///
    /// Panics if keys were already installed or the state has been invalidated; both are
    /// caller misuse rather than impossible states.
    fn set(&mut self, keys: K) {
        match self {
            KeysState::Pending(_) => {}
            KeysState::Ready(_) => panic!("KeysState::set called twice"),
            KeysState::Invalid => panic!("KeysState::set called after invalidation"),
        }
        // The state is already `Ready` by the time the waiter is woken.
        if let KeysState::Pending(Some(waker)) = std::mem::replace(self, KeysState::Ready(keys)) {
            waker.wake();
        }
    }

    /// Returns the installed keys, or `None` while pending or after invalidation.
    fn get(&self) -> Option<&K> {
        match self {
            KeysState::Ready(keys) => Some(keys),
            KeysState::Pending(_) | KeysState::Invalid => None,
        }
    }

    /// Moves the state to `Invalid`, returning the keys if they had been installed.
    ///
    /// A task still waiting on pending keys is woken so it can observe the invalidation.
    /// Invalidating an already-invalid state returns `None`.
    fn invalid(&mut self) -> Option<K> {
        match std::mem::replace(self, KeysState::Invalid) {
            KeysState::Pending(waker) => {
                if let Some(waker) = waker {
                    waker.wake();
                }
                None
            }
            KeysState::Ready(keys) => Some(keys),
            KeysState::Invalid => None,
        }
    }
}

impl<K: Unpin + Clone> Future for KeysState<K> {
    type Output = Option<K>;

    /// Resolves to a clone of the keys once ready, or to `None` once invalidated.
    ///
    /// Only a single waiter is supported: if a waker is already registered and
    /// `Waker::will_wake` reports that the new one is not equivalent, this panics.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            KeysState::Pending(waker) => {
                if waker
                    .as_ref()
                    .is_some_and(|waker| !waker.will_wake(cx.waker()))
                {
                    unreachable!(
                        "Try to get remote keys from multiple tasks! This is a bug, please report it."
                    )
                }
                *waker = Some(cx.waker().clone());
                Poll::Pending
            }
            KeysState::Ready(keys) => Poll::Ready(Some(keys.clone())),
            KeysState::Invalid => Poll::Ready(None),
        }
    }
}
/// Shared handle to a [`Keys`] pair that may be installed after the handle is created.
///
/// Clones share the same state through an `Arc<Mutex<..>>`.
#[derive(Clone)]
pub struct ArcKeys(Arc<Mutex<KeysState<Keys>>>);

impl ArcKeys {
    /// Locks the shared state; panics if the mutex is poisoned.
    fn lock_guard(&self) -> MutexGuard<'_, KeysState<Keys>> {
        let shared = &self.0;
        shared.lock().unwrap()
    }

    /// Creates a handle whose keys are not available yet.
    pub fn new_pending() -> Self {
        let state = KeysState::Pending(None);
        Self(Arc::new(Mutex::new(state)))
    }

    /// Creates a handle whose keys are available immediately.
    pub fn with_keys(keys: Keys) -> Self {
        let state = KeysState::Ready(keys);
        Self(Arc::new(Mutex::new(state)))
    }

    /// Returns a future that resolves to a clone of the keys once they are set, or to `None`
    /// if they are invalidated first. Only one task may wait on it at a time.
    pub fn get_remote_keys(&self) -> GetRemoteKeys<'_, Keys> {
        let state: &Mutex<KeysState<Keys>> = &self.0;
        GetRemoteKeys(state)
    }

    /// Returns a clone of the keys if they are currently installed.
    pub fn get_local_keys(&self) -> Option<Keys> {
        self.lock_guard().get().cloned()
    }

    /// Installs the keys; panics if they were already set or have been invalidated.
    pub fn set_keys(&self, keys: Keys) {
        let mut state = self.lock_guard();
        state.set(keys);
    }

    /// Discards the keys, returning them if they had been installed.
    pub fn invalid(&self) -> Option<Keys> {
        let mut state = self.lock_guard();
        state.invalid()
    }
}
/// Future that waits on a shared key state.
///
/// Resolves to a clone of the keys once they are installed, or to `None` if they are
/// invalidated first.
pub struct GetRemoteKeys<'k, K>(&'k Mutex<KeysState<K>>);

impl<K: Unpin + Clone> Future for GetRemoteKeys<'_, K> {
    type Output = Option<K>;

    /// Locks the state and delegates to the state's own `Future` implementation.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let guard = self.0.lock().unwrap();
        let mut state = Pin::new(guard);
        state.poll_unpin(cx)
    }
}
/// Shared 0-RTT keys.
///
/// Each accessor serves a single role. The client obtains encryption keys through
/// [`ArcZeroRttKeys::get_encrypt_keys`], and the server waits for decryption keys through
/// [`ArcZeroRttKeys::get_decrypt_keys`]. The other role gets `None`. Clones share state.
#[derive(Clone)]
pub struct ArcZeroRttKeys {
    /// Role of this endpoint; decides which accessor yields keys.
    role: Role,
    /// Shared key state, pending until [`ArcZeroRttKeys::set_keys`] is called.
    keys: Arc<Mutex<KeysState<DirectionalKeys>>>,
}

impl ArcZeroRttKeys {
    /// Creates pending 0-RTT keys for an endpoint acting as `role`.
    pub fn new_pending(role: Role) -> Self {
        let keys = Arc::new(Mutex::new(KeysState::Pending(None)));
        Self { role, keys }
    }

    /// Locks the shared state; panics if the mutex is poisoned.
    fn lock_guard(&self) -> MutexGuard<'_, KeysState<DirectionalKeys>> {
        let shared = &self.keys;
        shared.lock().unwrap()
    }

    /// Installs the keys; panics if they were already set or have been invalidated.
    pub fn set_keys(&self, keys: DirectionalKeys) {
        let mut state = self.lock_guard();
        state.set(keys);
    }

    /// Returns a clone of the keys if installed and this endpoint is the client.
    /// Always `None` on the server.
    pub fn get_encrypt_keys(&self) -> Option<DirectionalKeys> {
        if let Role::Server = self.role {
            return None;
        }
        self.lock_guard().get().cloned()
    }

    /// On the server, returns a future resolving to the keys (or `None` if invalidated).
    /// Always `None` on the client.
    pub fn get_decrypt_keys(&self) -> Option<GetRemoteKeys<'_, DirectionalKeys>> {
        match self.role {
            Role::Server => {
                let state: &Mutex<KeysState<DirectionalKeys>> = &self.keys;
                Some(GetRemoteKeys(state))
            }
            Role::Client => None,
        }
    }

    /// Discards the keys, returning them if they had been installed.
    pub fn invalid(&self) -> Option<DirectionalKeys> {
        let mut state = self.lock_guard();
        state.invalid()
    }
}
/// 1-RTT packet keys with support for key updates.
///
/// It holds up to two generations of remote packet keys, indexed by key phase. It also holds
/// the current local packet key and the secrets from which the next generation is derived.
pub struct OneRttPacketKeys {
    /// Key phase of the current generation.
    cur_phase: KeyPhaseBit,
    /// Secrets used by `update` to derive the next generation of packet keys.
    secrets: Secrets,
    /// Remote packet keys indexed by `KeyPhaseBit::as_index`; a slot is `None` once phased out.
    remote: [Option<Arc<dyn PacketKey>>; 2],
    /// Local packet key of the current generation.
    local: Arc<dyn PacketKey>,
}

impl OneRttPacketKeys {
    /// Builds the first generation. The remote key goes in slot 0 and slot 1 starts empty.
    /// (Presumably slot 0 is `KeyPhaseBit::default().as_index()`; TODO confirm.)
    fn new(remote: Box<dyn PacketKey>, local: Box<dyn PacketKey>, secrets: Secrets) -> Self {
        let remote: Arc<dyn PacketKey> = Arc::from(remote);
        let local: Arc<dyn PacketKey> = Arc::from(local);
        Self {
            cur_phase: KeyPhaseBit::default(),
            secrets,
            remote: [Some(remote), None],
            local,
        }
    }

    /// Advances to the next key generation.
    ///
    /// Derives the next packet keys from `secrets` and flips the key phase. The new keys become
    /// the current local key and the remote key for the new phase. The previous phase's remote
    /// key is kept until [`Self::phase_out`].
    pub fn update(&mut self) {
        let next = self.secrets.next_packet_keys();
        self.cur_phase.toggle();
        let slot = self.cur_phase.as_index();
        self.local = Arc::from(next.local);
        self.remote[slot] = Some(Arc::from(next.remote));
    }

    /// Drops the remote key of the non-current phase.
    pub fn phase_out(&mut self) {
        let previous = (!self.cur_phase).as_index();
        self.remote[previous] = None;
    }

    /// Returns the remote packet key for `key_phase`.
    ///
    /// A phase that differs from the current one and whose slot is empty is treated as a key
    /// update: [`Self::update`] runs before the lookup. `_pn` is currently unused.
    pub fn get_remote(&mut self, key_phase: KeyPhaseBit, _pn: u64) -> Arc<dyn PacketKey> {
        let slot = key_phase.as_index();
        let needs_update = key_phase != self.cur_phase && self.remote[slot].is_none();
        if needs_update {
            self.update();
        }
        self.remote[slot].clone().unwrap()
    }

    /// Returns the current key phase together with the current local packet key.
    pub fn get_local(&self) -> (KeyPhaseBit, Arc<dyn PacketKey>) {
        let key = Arc::clone(&self.local);
        (self.cur_phase, key)
    }
}
/// Shared, lockable 1-RTT packet keys paired with a tag length.
///
/// The tag length is captured once, when the value is built.
#[derive(Clone)]
pub struct ArcOneRttPacketKeys(Arc<(Mutex<OneRttPacketKeys>, usize)>);

impl ArcOneRttPacketKeys {
    /// Locks the packet keys; panics if the mutex is poisoned.
    pub fn lock_guard(&self) -> MutexGuard<'_, OneRttPacketKeys> {
        let (keys, _) = &*self.0;
        keys.lock().unwrap()
    }

    /// Returns the tag length stored alongside the keys.
    pub fn tag_len(&self) -> usize {
        let (_, tag_len) = &*self.0;
        *tag_len
    }
}
/// Header protection keys for both directions of 1-RTT traffic.
#[derive(Clone)]
pub struct HeaderProtectionKeys {
    /// Header protection key taken from the rustls `local` keys.
    pub local: Arc<dyn HeaderProtectionKey>,
    /// Header protection key taken from the rustls `remote` keys.
    pub remote: Arc<dyn HeaderProtectionKey>,
}

/// Lifecycle of the keys held by an `ArcOneRttKeys`.
enum OneRttKeysState {
    /// Keys not installed yet; holds the waker of the task awaiting the remote keys, if any.
    Pending(Option<Waker>),
    /// Keys installed.
    Ready {
        /// Header protection keys for both directions.
        hpk: HeaderProtectionKeys,
        /// Shared packet keys.
        pk: ArcOneRttPacketKeys,
    },
    /// Keys discarded; they will not become available again.
    Invalid,
}
/// Shared handle to the 1-RTT keys, which are installed some time after the handle exists.
///
/// Clones share the same state through an `Arc<Mutex<..>>`. The state starts pending, becomes
/// ready via [`ArcOneRttKeys::set_keys`], and is torn down via [`ArcOneRttKeys::invalid`].
#[derive(Clone)]
pub struct ArcOneRttKeys(Arc<Mutex<OneRttKeysState>>);

impl ArcOneRttKeys {
    /// Locks the shared state; panics if the mutex is poisoned.
    fn lock_guard(&self) -> MutexGuard<'_, OneRttKeysState> {
        self.0.lock().unwrap()
    }

    /// Creates a handle whose keys are not available yet.
    pub fn new_pending() -> Self {
        Self(Arc::new(OneRttKeysState::Pending(None).into()))
    }

    /// Installs the 1-RTT keys and wakes the task waiting in [`GetRemoteOneRttKeys`], if any.
    ///
    /// The header protection keys are stored as [`HeaderProtectionKeys`]. The packet keys and
    /// `secrets` are wrapped in an [`ArcOneRttPacketKeys`], together with the tag length
    /// reported by the local packet key.
    ///
    /// # Panics
    ///
    /// Panics if keys were already installed or the handle has been invalidated.
    pub fn set_keys(&self, keys: RustlsKeys, secrets: Secrets) {
        let mut state = self.lock_guard();
        match &*state {
            OneRttKeysState::Pending(_) => {}
            OneRttKeysState::Ready { .. } => panic!("set_keys called twice"),
            OneRttKeysState::Invalid => panic!("set_keys called after invalidation"),
        }

        let hpk = HeaderProtectionKeys {
            local: Arc::from(keys.local.header),
            remote: Arc::from(keys.remote.header),
        };
        let tag_len = keys.local.packet.tag_len();
        let packet_keys = OneRttPacketKeys::new(keys.remote.packet, keys.local.packet, secrets);
        let pk = ArcOneRttPacketKeys(Arc::new((Mutex::new(packet_keys), tag_len)));

        // Transition first; the waiter is woken once the state is already `Ready`.
        if let OneRttKeysState::Pending(Some(waker)) =
            std::mem::replace(&mut *state, OneRttKeysState::Ready { hpk, pk })
        {
            waker.wake();
        }
    }

    /// Discards the keys and returns them if they had been installed.
    ///
    /// A task still waiting for pending keys is woken so it can observe the invalidation.
    /// Invalidating an already-invalid handle returns `None` instead of panicking, matching
    /// `KeysState::invalid`, which backs the other key handles in this module.
    pub fn invalid(&self) -> Option<(HeaderProtectionKeys, ArcOneRttPacketKeys)> {
        let mut state = self.lock_guard();
        match std::mem::replace(state.deref_mut(), OneRttKeysState::Invalid) {
            OneRttKeysState::Pending(rx_waker) => {
                if let Some(waker) = rx_waker {
                    waker.wake();
                }
                None
            }
            OneRttKeysState::Ready { hpk, pk } => Some((hpk, pk)),
            OneRttKeysState::Invalid => None,
        }
    }

    /// Returns the local header protection key and the shared packet keys, if installed.
    pub fn get_local_keys(&self) -> Option<(Arc<dyn HeaderProtectionKey>, ArcOneRttPacketKeys)> {
        match &*self.lock_guard() {
            OneRttKeysState::Ready { hpk, pk } => Some((hpk.local.clone(), pk.clone())),
            OneRttKeysState::Pending(_) | OneRttKeysState::Invalid => None,
        }
    }

    /// Returns the remote header protection key and the shared packet keys, if installed.
    pub fn remote_keys(&self) -> Option<(Arc<dyn HeaderProtectionKey>, ArcOneRttPacketKeys)> {
        match &*self.lock_guard() {
            OneRttKeysState::Ready { hpk, pk } => Some((hpk.remote.clone(), pk.clone())),
            OneRttKeysState::Pending(_) | OneRttKeysState::Invalid => None,
        }
    }

    /// Returns a future resolving to the remote keys once installed, or `None` if invalidated.
    pub fn get_remote_keys(&self) -> GetRemoteOneRttKeys<'_> {
        GetRemoteOneRttKeys(self)
    }
}
/// Future returned by [`ArcOneRttKeys::get_remote_keys`].
///
/// Resolves to the remote header protection key and the shared packet keys once installed,
/// or to `None` if the keys are invalidated first. Only one task may wait on it at a time.
pub struct GetRemoteOneRttKeys<'k>(&'k ArcOneRttKeys);

impl Future for GetRemoteOneRttKeys<'_> {
    type Output = Option<(Arc<dyn HeaderProtectionKey>, ArcOneRttPacketKeys)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.0.lock_guard();
        let registered = match &mut *state {
            OneRttKeysState::Ready { hpk, pk, .. } => {
                return Poll::Ready(Some((hpk.remote.clone(), pk.clone())));
            }
            OneRttKeysState::Invalid => return Poll::Ready(None),
            OneRttKeysState::Pending(registered) => registered,
        };

        // A stored waker that would not wake this task means a second task is polling.
        let polled_elsewhere = registered
            .as_ref()
            .is_some_and(|previous| !previous.will_wake(cx.waker()));
        if polled_elsewhere {
            unreachable!(
                "Try to get remote keys from multiple tasks! This is a bug, please report it."
            )
        }
        *registered = Some(cx.waker().clone());
        Poll::Pending
    }
}