use std::collections::{HashMap, HashSet};
use crate::tier3_aead::KEY_LEN;
/// HMAC signing key material held as an owned byte vector.
///
/// Only `Clone` is derived here; no `Debug` implementation is derived for
/// this type in this file.
#[derive(Clone)]
pub struct HmacKey(
    /// The raw key bytes.
    pub Vec<u8>,
);
/// AEAD key material stored as a fixed-size array of exactly `KEY_LEN` bytes.
///
/// Only `Clone` is derived here; no `Debug` implementation is derived for
/// this type in this file.
#[derive(Clone)]
pub struct AeadKey(
    /// The raw key bytes.
    pub [u8; KEY_LEN],
);
/// The set of channels something is allowed to be used on.
///
/// A scope is either a wildcard that permits every channel, or an explicit
/// list of channel names. The derived `Default` is an empty, non-wildcard
/// scope and therefore permits nothing.
#[derive(Clone, Default)]
pub struct ChannelScope {
    /// When `true`, every channel is permitted and `channels` is not consulted.
    allow_all: bool,
    /// Explicitly permitted channel names; only consulted when `allow_all` is `false`.
    channels: HashSet<String>,
}

impl ChannelScope {
    /// Builds a scope that permits exactly the given channel names.
    ///
    /// Duplicate names collapse into one entry; an empty input yields a scope
    /// that permits nothing.
    pub fn list<I, S>(channels: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut permitted = HashSet::new();
        for name in channels {
            permitted.insert(name.into());
        }
        Self {
            allow_all: false,
            channels: permitted,
        }
    }

    /// Builds a wildcard scope that permits every channel.
    pub fn all() -> Self {
        Self {
            allow_all: true,
            ..Self::default()
        }
    }

    /// Returns `true` when `channel` is covered by this scope.
    ///
    /// Names are compared by exact string equality.
    pub fn permits(&self, channel: &str) -> bool {
        if self.allow_all {
            return true;
        }
        self.channels.contains(channel)
    }
}
/// Lookup interface for signing keys, AEAD keys, and per-key authorization
/// state.
///
/// Implementors must be `Send + Sync`.
pub trait KeyProvider: Send + Sync {
    /// Returns the signing keys associated with `key_id`.
    fn signing_keys(&self, key_id: &str) -> Vec<HmacKey>;

    /// Returns the first key reported by [`KeyProvider::signing_keys`] for
    /// `key_id`, or `None` when that list is empty.
    fn primary_signing_key(&self, key_id: &str) -> Option<HmacKey> {
        let mut candidates = self.signing_keys(key_id);
        if candidates.is_empty() {
            return None;
        }
        // Index 0 is the primary; the remaining keys are discarded either
        // way, so the reordering done by `swap_remove` is irrelevant.
        Some(candidates.swap_remove(0))
    }

    /// Returns the AEAD key associated with `enc_key_id`, if any.
    fn aead_key(&self, enc_key_id: &str) -> Option<AeadKey>;

    /// Reports whether `key_id` may be used on `channel`.
    fn is_authorized(&self, key_id: &str, channel: &str) -> bool;

    /// Reports whether `key_id` has been revoked.
    fn is_revoked(&self, key_id: &str) -> bool;
}
/// Everything recorded for one signing key ID.
///
/// The derived `Default` has no keys, the default `ChannelScope`, and
/// `revoked == false`.
#[derive(Default)]
struct SigningEntry {
    /// Keys registered under this ID; index 0 is what
    /// `KeyProvider::primary_signing_key` returns by default.
    keys: Vec<HmacKey>,
    /// Channels this key ID may be used on.
    scope: ChannelScope,
    /// Set once the key ID has been revoked.
    revoked: bool,
}
#[derive(Default)]
pub struct StaticKeyProvider {
signing: HashMap<String, SigningEntry>,
aead: HashMap<String, AeadKey>,
}
impl StaticKeyProvider {
pub fn new() -> Self {
Self::default()
}
pub fn with_signing_key(mut self, key_id: &str, key: HmacKey, scope: ChannelScope) -> Self {
self.signing.insert(
key_id.to_string(),
SigningEntry {
keys: vec![key],
scope,
revoked: false,
},
);
self
}
pub fn add_overlap_key(mut self, key_id: &str, key: HmacKey) -> Self {
self.signing
.entry(key_id.to_string())
.or_default()
.keys
.push(key);
self
}
pub fn revoke(mut self, key_id: &str) -> Self {
if let Some(e) = self.signing.get_mut(key_id) {
e.revoked = true;
}
self
}
pub fn with_aead_key(mut self, enc_key_id: &str, key: AeadKey) -> Self {
self.aead.insert(enc_key_id.to_string(), key);
self
}
}
impl KeyProvider for StaticKeyProvider {
    /// Returns a copy of every key stored under `key_id`, or an empty vector
    /// when the ID is unknown. The revoked flag is not consulted here.
    fn signing_keys(&self, key_id: &str) -> Vec<HmacKey> {
        match self.signing.get(key_id) {
            Some(entry) => entry.keys.clone(),
            None => Vec::new(),
        }
    }

    /// Returns a copy of the AEAD key stored under `enc_key_id`, if any.
    fn aead_key(&self, enc_key_id: &str) -> Option<AeadKey> {
        let stored = self.aead.get(enc_key_id);
        stored.cloned()
    }

    /// Returns `true` only when `key_id` is known, not revoked, and its scope
    /// permits `channel`.
    fn is_authorized(&self, key_id: &str, channel: &str) -> bool {
        match self.signing.get(key_id) {
            Some(entry) if !entry.revoked => entry.scope.permits(channel),
            _ => false,
        }
    }

    /// Returns the revoked flag stored for `key_id`; unknown IDs report
    /// `false`.
    fn is_revoked(&self, key_id: &str) -> bool {
        matches!(self.signing.get(key_id), Some(entry) if entry.revoked)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A key scoped to one channel is found by ID, is authorized only on that
    /// channel, and is not reported as revoked.
    #[test]
    fn lookup_and_scope() {
        let scope = ChannelScope::list(["SportsFeed-East"]);
        let key = HmacKey(b"secret".to_vec());
        let provider = StaticKeyProvider::new().with_signing_key("sas-east-01", key, scope);

        let known = provider.signing_keys("sas-east-01");
        assert_eq!(known.len(), 1);
        let unknown = provider.signing_keys("nope");
        assert!(unknown.is_empty());

        assert!(provider.is_authorized("sas-east-01", "SportsFeed-East"));
        assert!(!provider.is_authorized("sas-east-01", "PremiumFeed"));
        assert!(!provider.is_revoked("sas-east-01"));
    }

    /// Revoking a key removes its authorization even under a wildcard scope.
    #[test]
    fn revoked_key_not_authorized() {
        let registered = StaticKeyProvider::new().with_signing_key(
            "k",
            HmacKey(b"s".to_vec()),
            ChannelScope::all(),
        );
        let provider = registered.revoke("k");

        assert!(provider.is_revoked("k"));
        assert!(!provider.is_authorized("k", "anything"));
    }

    /// AEAD key IDs are looked up independently of signing key IDs.
    #[test]
    fn separate_aead_namespace() {
        let aead = AeadKey([9u8; KEY_LEN]);
        let provider = StaticKeyProvider::new().with_aead_key("enc-2026q1", aead);

        assert!(provider.aead_key("enc-2026q1").is_some());
        assert!(provider.aead_key("sas-east-01").is_none());
    }
}