use std::{
collections::{BTreeMap, HashSet},
sync::Arc,
time::{Duration, Instant},
};
use snap_tun::server::SnapTunAuthorization;
use crate::crpc_api::api_service::model::SnapTunIdentityRegistry;
/// Raw 32-byte identity tracked by the registry (the tests use x25519 public-key bytes).
type Identity = [u8; 32];

/// Registry contents; `IdentityRegistry` publishes clones of this through an `ArcSwap`.
///
/// `associations` maps a registration key to the identity currently registered under it.
/// `expiry` maps each registered identity to the instant at which it stops being
/// authorized. Several keys may reference the same identity, in which case they share a
/// single expiry entry.
#[derive(Default, Clone)]
struct IdentityRegistryState {
    pub associations: BTreeMap<Arc<str>, Arc<Identity>>,
    pub expiry: BTreeMap<Arc<Identity>, Instant>,
}

impl IdentityRegistryState {
    /// Returns `true` if `ident` has an expiry entry that lies strictly after `now`.
    pub(crate) fn is_authorized(&self, now: Instant, ident: &Identity) -> bool {
        self.expiry.get(ident).is_some_and(|expiry| *expiry > now)
    }

    /// Associates `identity` with `key` and sets the identity's expiry to `expiry`.
    ///
    /// If `key` was previously associated with a different identity, that identity's
    /// expiry is revoked, but only when no other key still references it. Otherwise the
    /// other keys would lose authorization and be left with associations that
    /// `clean_expired` could never remove.
    ///
    /// Returns `true` if `identity` had no expiry entry before this call.
    pub(crate) fn add_identity<S: AsRef<str>>(
        &mut self,
        key: S,
        identity: Identity,
        expiry: Instant,
    ) -> bool {
        let key = Arc::<str>::from(key.as_ref());
        let ident = Arc::new(identity);
        let previous = self.associations.insert(key, Arc::clone(&ident));
        // The replaced identity is orphaned only if it differs from the new one and no
        // remaining key still points at it.
        let orphaned = previous.filter(|prev| {
            *prev != ident && !self.associations.values().any(|other| other == prev)
        });
        if let Some(prev_ident) = orphaned {
            self.expiry.remove(&prev_ident);
        }
        self.expiry.insert(ident, expiry).is_none()
    }

    /// Drops every identity whose expiry is at or before `now`, together with all keys
    /// associated with those identities.
    pub(crate) fn clean_expired(&mut self, now: Instant) {
        let mut removed: HashSet<Arc<Identity>> = HashSet::new();
        self.expiry.retain(|ident, expiry| {
            let alive = *expiry > now;
            if !alive {
                removed.insert(Arc::clone(ident));
            }
            alive
        });
        self.associations.retain(|_, ident| !removed.contains(ident));
    }
}
pub struct IdentityRegistry {
state: arc_swap::ArcSwap<IdentityRegistryState>,
}
impl IdentityRegistry {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self {
state: Default::default(),
}
}
pub fn is_authorized(&self, now: Instant, identity: &Identity) -> bool {
self.state.load().is_authorized(now, identity)
}
pub fn register<S: AsRef<str>>(
&self,
now: Instant,
key: S,
ident: Identity,
lifetime: Duration,
) -> bool {
let mut res = false;
self.update_state(|state| {
res = state.add_identity(key, ident, now + lifetime);
});
res
}
pub fn remove_expired(&self, now: Instant) {
self.update_state(|state| state.clean_expired(now));
}
fn update_state<F>(&self, modifier: F)
where
F: FnOnce(&mut IdentityRegistryState),
{
let mut state: IdentityRegistryState = (**self.state.load()).clone();
(modifier)(&mut state);
self.state.store(Arc::new(state))
}
#[cfg(test)]
pub(crate) fn ident_exist(&self, ident: &Identity) -> bool {
self.state
.load()
.associations
.values()
.any(|v| v.as_ref() == ident)
|| self.state.load().expiry.keys().any(|k| k.as_ref() == ident)
}
}
impl SnapTunIdentityRegistry for IdentityRegistry {
fn register(
&self,
now: Instant,
key: &str,
identity: Identity,
_psk_share: Option<[u8; 32]>,
lifetime: Duration,
) -> bool {
self.register(now, key, identity, lifetime)
}
}
impl SnapTunAuthorization for IdentityRegistry {
fn is_authorized(&self, now: Instant, identity: &Identity) -> bool {
self.is_authorized(now, identity)
}
}
#[cfg(test)]
mod tests {
    use x25519_dalek::PublicKey;

    use super::*;

    /// Builds a deterministic key whose first byte is `seed` and whose remaining bytes are zero.
    fn create_test_identity(seed: u8) -> PublicKey {
        let mut raw = [0u8; 32];
        raw[0] = seed;
        PublicKey::from(raw)
    }

    #[test]
    fn test_identity_not_registered() {
        let key = create_test_identity(1);
        let registry = IdentityRegistry::new();
        let authorized = registry.is_authorized(Instant::now(), key.as_bytes());
        assert!(!authorized);
    }

    #[test]
    fn test_identity_is_authorized_before_expires() {
        let (registry, start) = (IdentityRegistry::new(), Instant::now());
        let raw = *create_test_identity(1).as_bytes();
        registry.register(start, "", raw, Duration::from_secs(30));
        assert!(registry.is_authorized(start, &raw));
    }

    #[test]
    fn test_reregistering_identity_returns_false_but_succeeds() {
        let (registry, start) = (IdentityRegistry::new(), Instant::now());
        let raw = *create_test_identity(1).as_bytes();
        let step = Duration::from_secs(10);
        let checkpoint = start + step;

        registry.register(start, "", raw, step);
        assert!(!registry.is_authorized(checkpoint, &raw));

        let newly_added = registry.register(start, "", raw, step * 2);
        assert!(!newly_added);
        assert!(registry.is_authorized(checkpoint, &raw));
    }

    #[test]
    fn test_identity_is_unauthorized_at_expiry() {
        let (registry, start) = (IdentityRegistry::new(), Instant::now());
        let raw = *create_test_identity(1).as_bytes();
        let lifetime = Duration::from_secs(30);
        registry.register(start, "", raw, lifetime);
        assert!(!registry.is_authorized(start + lifetime, &raw));
    }

    #[test]
    fn test_identity_is_removed_after_expiry() {
        let (registry, start) = (IdentityRegistry::new(), Instant::now());
        let raw = *create_test_identity(1).as_bytes();
        let lifetime = Duration::from_secs(30);
        registry.register(start, "", raw, lifetime);
        assert!(registry.ident_exist(&raw));

        registry.remove_expired(start + lifetime);
        assert!(!registry.ident_exist(&raw));
    }
}