mod handshake;
mod keychain;
mod rotation;
mod window;
use crate::{crypto::Tag, Result};
use bramble_common::transport::{Id, Latency};
use bramble_crypto::{Role, SymmetricKey};
use futures::{AsyncRead, AsyncWrite};
use handshake::HandshakeKeys;
use keychain::Keychain;
use rotation::RotationKeys;
use window::ReorderingWindows;
pub use window::StreamId;
/// Key schedule backing a [`KeyManager`]: either handshake-mode or rotation-mode keys.
// Clippy's `large_enum_variant` lint fires because the two variants differ a lot in size;
// it is allowed here instead of boxing the larger variant.
// NOTE(review): presumably acceptable because each `KeyManager` holds exactly one
// rotator — confirm before relying on this.
#[allow(clippy::large_enum_variant)]
enum KeyRotator<T: Id + Latency + AsyncRead + AsyncWrite> {
    /// Keys from [`HandshakeKeys`], as built by [`KeyManager::handshake`].
    Handshake(HandshakeKeys<T>),
    /// Keys from [`RotationKeys`], as built by [`KeyManager::rotation`].
    Rotation(RotationKeys<T>),
}
/// Supplies the [`Keychain`] to use for a given rotation period, per direction.
pub trait RotateKeys {
    /// Returns the keychain for incoming streams belonging to `period`.
    fn incoming_keys_for(&mut self, period: u64) -> Result<Keychain>;

    /// Returns the keychain for outgoing streams belonging to `period`.
    fn outgoing_keys_for(&mut self, period: u64) -> Result<Keychain>;
}
/// Forwards every key request to whichever key schedule the rotator currently wraps.
impl<T: Id + Latency + AsyncRead + AsyncWrite> RotateKeys for KeyRotator<T> {
    fn outgoing_keys_for(&mut self, period: u64) -> Result<Keychain> {
        match self {
            Self::Handshake(keys) => keys.outgoing_keys_for(period),
            Self::Rotation(keys) => keys.outgoing_keys_for(period),
        }
    }

    fn incoming_keys_for(&mut self, period: u64) -> Result<Keychain> {
        match self {
            Self::Handshake(keys) => keys.incoming_keys_for(period),
            Self::Rotation(keys) => keys.incoming_keys_for(period),
        }
    }
}
/// Owns the key-rotation state together with the reordering windows used to recognise
/// incoming tags.
///
/// Construct one with [`KeyManager::handshake`] or [`KeyManager::rotation`].
pub struct KeyManager<T: Id + Latency + AsyncRead + AsyncWrite> {
    /// Source of outgoing and incoming keychains.
    rotator: KeyRotator<T>,
    /// Tag windows; consulted by `see_tag` and moved forward by `advance_to`.
    windows: ReorderingWindows,
}
impl<T: Id + Latency + AsyncRead + AsyncWrite> KeyManager<T> {
    /// Wraps `rotator` and builds reordering windows starting at `now`.
    ///
    /// # Errors
    /// Propagates any error from `ReorderingWindows::new`.
    fn new(mut rotator: KeyRotator<T>, now: u64) -> Result<Self> {
        ReorderingWindows::new(now, &mut rotator).map(|windows| Self { rotator, windows })
    }

    /// Creates a manager whose keys come from a handshake root key.
    ///
    /// # Errors
    /// Fails if the initial reordering windows cannot be built.
    pub fn handshake(root_key: SymmetricKey, now: u64, role: Role) -> Result<Self> {
        Self::new(KeyRotator::Handshake(HandshakeKeys::new(root_key, role)), now)
    }

    /// Creates a manager whose keys come from rotation keys seeded with `now`.
    ///
    /// # Errors
    /// Fails if the initial reordering windows cannot be built.
    pub fn rotation(root_key: SymmetricKey, now: u64, role: Role) -> Result<Self> {
        Self::new(KeyRotator::Rotation(RotationKeys::new(root_key, now, role)), now)
    }

    /// Returns the keychain to use for outgoing streams.
    ///
    /// NOTE(review): `now` is forwarded unchanged to `RotateKeys::outgoing_keys_for`,
    /// whose parameter is named `period`; confirm whether callers pass a time or a period.
    ///
    /// # Errors
    /// Propagates any error from the underlying key rotator.
    pub fn outgoing_keys_for(&mut self, now: u64) -> Result<Keychain> {
        self.rotator.outgoing_keys_for(now)
    }

    /// Checks `new_tag` against the reordering windows.
    ///
    /// Returns the matching stream id together with the incoming keychain for that
    /// stream's period. Returns `None` in two cases: the windows do not recognise the tag,
    /// or deriving the incoming keys fails. In the second case the error is discarded.
    ///
    /// NOTE(review): the windows have already been consulted (and may have updated their
    /// state) before the key lookup runs, even if that lookup then fails — confirm this
    /// is intended.
    pub fn see_tag(&mut self, new_tag: Tag) -> Option<(StreamId, Keychain)> {
        let stream_id = self.windows.see(new_tag, &mut self.rotator)?;
        let period = stream_id.period;
        self.rotator
            .incoming_keys_for(period)
            .ok()
            .map(|keys| (stream_id, keys))
    }

    /// Advances the reordering windows to `now`, handing them mutable access to the rotator.
    ///
    /// # Errors
    /// Propagates any error from `ReorderingWindows::advance_to`.
    pub fn advance_to(&mut self, now: u64) -> Result<()> {
        // Split the borrow so the windows and the rotator can be borrowed mutably together.
        let Self { rotator, windows } = self;
        windows.advance_to(now, rotator)
    }
}