use super::MembershipCertificate;
/// Lifecycle state of a membership certificate at a particular instant.
///
/// Values are produced by `AuthStateTracker::check_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CertificateState {
    /// Not yet expired, with more than the configured warning threshold remaining.
    Valid {
        /// Milliseconds until the certificate's `expires_at_ms` is reached.
        expires_in_ms: u64,
    },
    /// Not yet expired, but within the configured warning threshold of expiry.
    Warning {
        /// Milliseconds until the certificate's `expires_at_ms` is reached.
        expires_in_ms: u64,
    },
    /// At or past expiry, but still inside the configured grace period.
    GracePeriod {
        /// Milliseconds left before the grace period runs out.
        grace_remaining_ms: u64,
    },
    /// Past expiry and past the end of the grace period.
    Expired,
}
impl CertificateState {
    /// Returns `true` while the certificate may still be used, i.e. in every state
    /// other than `Expired` (including the grace period).
    pub fn is_operational(&self) -> bool {
        !self.is_expired()
    }

    /// Returns `true` when re-authentication should be attempted: during the warning
    /// window and during the grace period. A fully expired certificate returns `false`.
    pub fn should_reauth(&self) -> bool {
        match self {
            CertificateState::Warning { .. } | CertificateState::GracePeriod { .. } => true,
            CertificateState::Valid { .. } | CertificateState::Expired => false,
        }
    }

    /// Returns `true` only for the `Expired` state.
    pub fn is_expired(&self) -> bool {
        *self == CertificateState::Expired
    }
}
/// Timing policy for certificate re-authentication, expressed in whole hours.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AuthConfig {
    /// Configured authentication interval, in hours. Exposed via `auth_interval_ms`;
    /// `AuthStateTracker::check_state` does not consult it.
    pub auth_interval_hours: u16,
    /// How long a certificate stays usable after its expiry, in hours.
    pub grace_period_hours: u16,
    /// How long before expiry the certificate enters the warning state, in hours.
    pub warning_threshold_hours: u16,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
auth_interval_hours: 24,
grace_period_hours: 4,
warning_threshold_hours: 1,
}
}
}
impl AuthConfig {
pub fn new(
auth_interval_hours: u16,
grace_period_hours: u16,
warning_threshold_hours: u16,
) -> Self {
Self {
auth_interval_hours,
grace_period_hours,
warning_threshold_hours,
}
}
pub fn auth_interval_ms(&self) -> u64 {
self.auth_interval_hours as u64 * 3_600_000
}
pub fn grace_period_ms(&self) -> u64 {
self.grace_period_hours as u64 * 3_600_000
}
pub fn warning_threshold_ms(&self) -> u64 {
self.warning_threshold_hours as u64 * 3_600_000
}
}
/// Classifies a certificate at a given timestamp into a `CertificateState`.
///
/// Holds nothing but its `AuthConfig`; every query is a pure function of the
/// certificate, the supplied time, and that configuration.
#[derive(Debug, Clone)]
pub struct AuthStateTracker {
    /// Timing policy applied by every check.
    config: AuthConfig,
}
impl Default for AuthStateTracker {
fn default() -> Self {
Self::new(AuthConfig::default())
}
}
impl AuthStateTracker {
pub fn new(config: AuthConfig) -> Self {
Self { config }
}
pub fn config(&self) -> &AuthConfig {
&self.config
}
pub fn check_state(&self, cert: &MembershipCertificate, now_ms: u64) -> CertificateState {
let expires_at = cert.expires_at_ms;
if now_ms < expires_at {
let expires_in_ms = expires_at - now_ms;
if expires_in_ms <= self.config.warning_threshold_ms() {
CertificateState::Warning { expires_in_ms }
} else {
CertificateState::Valid { expires_in_ms }
}
} else {
let expired_for_ms = now_ms - expires_at;
if expired_for_ms < self.config.grace_period_ms() {
let grace_remaining_ms = self.config.grace_period_ms() - expired_for_ms;
CertificateState::GracePeriod { grace_remaining_ms }
} else {
CertificateState::Expired
}
}
}
pub fn needs_reauth(&self, cert: &MembershipCertificate, now_ms: u64) -> bool {
self.check_state(cert, now_ms).should_reauth()
}
pub fn is_operational(&self, cert: &MembershipCertificate, now_ms: u64) -> bool {
self.check_state(cert, now_ms).is_operational()
}
pub fn is_expired(&self, cert: &MembershipCertificate, now_ms: u64) -> bool {
self.check_state(cert, now_ms).is_expired()
}
pub fn reauth_deadline(&self, cert: &MembershipCertificate) -> u64 {
cert.expires_at_ms
.saturating_sub(self.config.warning_threshold_ms())
}
pub fn hard_cutoff(&self, cert: &MembershipCertificate) -> u64 {
cert.expires_at_ms
.saturating_add(self.config.grace_period_ms())
}
}
/// Notification emitted by `AuthStateMonitor` when a certificate's state changes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthStateEvent {
    /// The certificate entered the warning window before expiry.
    EnteringWarning {
        /// Milliseconds until expiry when the transition was observed.
        expires_in_ms: u64,
    },
    /// The certificate passed its expiry and entered the grace period.
    EnteringGracePeriod {
        /// Milliseconds of grace left when the transition was observed.
        grace_remaining_ms: u64,
    },
    /// The certificate passed the end of its grace period.
    Expired,
    /// A replacement certificate was registered via `AuthStateMonitor::notify_renewed`.
    Renewed {
        /// The new certificate's `expires_at_ms`.
        new_expires_at_ms: u64,
    },
}
/// Wraps an `AuthStateTracker` and remembers the last observed state, so callers
/// receive transition events instead of re-deriving them from raw states.
#[derive(Debug, Clone)]
pub struct AuthStateMonitor {
    /// Classifier used for every observation.
    tracker: AuthStateTracker,
    /// Most recent state recorded by `update` or `notify_renewed`; `None` before either runs.
    last_state: Option<CertificateState>,
}
impl AuthStateMonitor {
    /// Creates a monitor that has not yet observed any certificate state.
    pub fn new(tracker: AuthStateTracker) -> Self {
        Self {
            tracker,
            last_state: None,
        }
    }

    /// Re-evaluates `cert` at `now_ms`, records the result, and reports a transition.
    ///
    /// An event is returned whenever the new state is *more severe* than the last one
    /// recorded, using the order `Valid` < `Warning` < `GracePeriod` < `Expired`. Having
    /// no previous observation counts as `Valid`.
    ///
    /// The event describes the state just entered. A sparse poll that jumps from `Valid`
    /// straight to `GracePeriod` therefore still yields `EnteringGracePeriod`.
    ///
    /// Staying in the same state, or moving to a less severe one, returns `None`.
    pub fn update(&mut self, cert: &MembershipCertificate, now_ms: u64) -> Option<AuthStateEvent> {
        let new_state = self.tracker.check_state(cert, now_ms);
        let previous_rank = self.last_state.as_ref().map_or(0, Self::severity);

        let event = if Self::severity(&new_state) > previous_rank {
            match new_state {
                // Rank 0 can never exceed `previous_rank`, so `Valid` never produces an event.
                CertificateState::Valid { .. } => None,
                CertificateState::Warning { expires_in_ms } => {
                    Some(AuthStateEvent::EnteringWarning { expires_in_ms })
                }
                CertificateState::GracePeriod { grace_remaining_ms } => {
                    Some(AuthStateEvent::EnteringGracePeriod { grace_remaining_ms })
                }
                CertificateState::Expired => Some(AuthStateEvent::Expired),
            }
        } else {
            None
        };

        self.last_state = Some(new_state);
        event
    }

    /// Records that the certificate was renewed and returns the matching `Renewed` event.
    ///
    /// The recorded state is reset to `Valid`, so the next non-`Valid` observation from
    /// `update` raises a fresh event.
    ///
    /// No clock reading is supplied here, so `expires_in_ms` is the new certificate's
    /// lifetime measured from its `issued_at_ms` (clamped at zero). The next `update`
    /// replaces it with a value computed at the caller's current time.
    pub fn notify_renewed(&mut self, new_cert: &MembershipCertificate) -> AuthStateEvent {
        // `expires_in_ms` is a remaining duration, not an absolute timestamp.
        let lifetime_ms = new_cert
            .expires_at_ms
            .saturating_sub(new_cert.issued_at_ms);
        self.last_state = Some(CertificateState::Valid {
            expires_in_ms: lifetime_ms,
        });
        AuthStateEvent::Renewed {
            new_expires_at_ms: new_cert.expires_at_ms,
        }
    }

    /// Returns the most recently recorded state. Returns `None` before the first
    /// `update` or `notify_renewed` call.
    pub fn current_state(&self) -> Option<&CertificateState> {
        self.last_state.as_ref()
    }

    /// Ranks states by how close the certificate is to being unusable (higher = worse).
    fn severity(state: &CertificateState) -> u8 {
        match state {
            CertificateState::Valid { .. } => 0,
            CertificateState::Warning { .. } => 1,
            CertificateState::GracePeriod { .. } => 2,
            CertificateState::Expired => 3,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One hour in milliseconds; keeps the timestamps below readable.
    const HOUR_MS: u64 = 3_600_000;

    /// Builds a certificate with fixed dummy identity fields and the given validity window.
    fn test_cert(issued_at_ms: u64, expires_at_ms: u64) -> MembershipCertificate {
        MembershipCertificate {
            member_public_key: [0u8; 32],
            mesh_id: "A1B2C3D4".to_string(),
            callsign: "TEST-01".to_string(),
            permissions: super::super::MemberPermissions::STANDARD,
            issued_at_ms,
            expires_at_ms,
            issuer_public_key: [0u8; 32],
            issuer_signature: [0u8; 64],
        }
    }

    #[test]
    fn test_config_defaults() {
        let config = AuthConfig::default();
        assert_eq!(config.auth_interval_hours, 24);
        assert_eq!(config.grace_period_hours, 4);
        assert_eq!(config.warning_threshold_hours, 1);
    }

    #[test]
    fn test_config_to_ms() {
        let config = AuthConfig::default();
        assert_eq!(config.auth_interval_ms(), 24 * HOUR_MS);
        assert_eq!(config.grace_period_ms(), 4 * HOUR_MS);
        assert_eq!(config.warning_threshold_ms(), HOUR_MS);
    }

    #[test]
    fn test_valid_state() {
        let tracker = AuthStateTracker::default();
        let cert = test_cert(0, 24 * HOUR_MS);

        let state = tracker.check_state(&cert, 0);
        assert_eq!(state, CertificateState::Valid { expires_in_ms: 24 * HOUR_MS });
        assert!(state.is_operational());
        assert!(!state.should_reauth());

        let state = tracker.check_state(&cert, 12 * HOUR_MS);
        assert_eq!(state, CertificateState::Valid { expires_in_ms: 12 * HOUR_MS });
    }

    #[test]
    fn test_warning_state() {
        let tracker = AuthStateTracker::default();
        let cert = test_cert(0, 24 * HOUR_MS);

        let state = tracker.check_state(&cert, 23 * HOUR_MS);
        assert_eq!(state, CertificateState::Warning { expires_in_ms: HOUR_MS });
        assert!(state.is_operational());
        assert!(state.should_reauth());

        let state = tracker.check_state(&cert, 23 * HOUR_MS + HOUR_MS / 2);
        assert_eq!(state, CertificateState::Warning { expires_in_ms: HOUR_MS / 2 });
    }

    #[test]
    fn test_grace_period_state() {
        let tracker = AuthStateTracker::default();
        let cert = test_cert(0, 24 * HOUR_MS);

        let state = tracker.check_state(&cert, 24 * HOUR_MS);
        assert_eq!(state, CertificateState::GracePeriod { grace_remaining_ms: 4 * HOUR_MS });
        assert!(state.is_operational());
        assert!(state.should_reauth());

        let state = tracker.check_state(&cert, 26 * HOUR_MS);
        assert_eq!(state, CertificateState::GracePeriod { grace_remaining_ms: 2 * HOUR_MS });
    }

    #[test]
    fn test_expired_state() {
        let tracker = AuthStateTracker::default();
        let cert = test_cert(0, 24 * HOUR_MS);

        let state = tracker.check_state(&cert, 28 * HOUR_MS);
        assert_eq!(state, CertificateState::Expired);
        assert!(!state.is_operational());
        assert!(!state.should_reauth());
        assert!(state.is_expired());

        assert_eq!(tracker.check_state(&cert, 30 * HOUR_MS), CertificateState::Expired);
    }

    #[test]
    fn test_needs_reauth() {
        let tracker = AuthStateTracker::default();
        let cert = test_cert(0, 24 * HOUR_MS);
        assert!(!tracker.needs_reauth(&cert, 0));
        assert!(!tracker.needs_reauth(&cert, 22 * HOUR_MS));
        assert!(tracker.needs_reauth(&cert, 23 * HOUR_MS));
        assert!(tracker.needs_reauth(&cert, 23 * HOUR_MS + HOUR_MS / 2));
        assert!(tracker.needs_reauth(&cert, 25 * HOUR_MS));
        assert!(!tracker.needs_reauth(&cert, 29 * HOUR_MS));
    }

    #[test]
    fn test_is_operational() {
        let tracker = AuthStateTracker::default();
        let cert = test_cert(0, 24 * HOUR_MS);
        assert!(tracker.is_operational(&cert, 0));
        assert!(tracker.is_operational(&cert, 23 * HOUR_MS));
        assert!(tracker.is_operational(&cert, 26 * HOUR_MS));
        assert!(!tracker.is_operational(&cert, 29 * HOUR_MS));
    }

    #[test]
    fn test_deadlines() {
        let tracker = AuthStateTracker::default();
        let cert = test_cert(0, 24 * HOUR_MS);
        assert_eq!(tracker.reauth_deadline(&cert), 23 * HOUR_MS);
        assert_eq!(tracker.hard_cutoff(&cert), 28 * HOUR_MS);
    }

    #[test]
    fn test_deadlines_match_state_boundaries() {
        let tracker = AuthStateTracker::default();
        let cert = test_cert(0, 24 * HOUR_MS);

        let reauth_at = tracker.reauth_deadline(&cert);
        assert!(matches!(
            tracker.check_state(&cert, reauth_at - 1),
            CertificateState::Valid { .. }
        ));
        assert!(matches!(
            tracker.check_state(&cert, reauth_at),
            CertificateState::Warning { .. }
        ));
        assert_eq!(
            tracker.check_state(&cert, 24 * HOUR_MS - 1),
            CertificateState::Warning { expires_in_ms: 1 }
        );

        let cutoff = tracker.hard_cutoff(&cert);
        assert_eq!(
            tracker.check_state(&cert, cutoff - 1),
            CertificateState::GracePeriod { grace_remaining_ms: 1 }
        );
        assert!(tracker.is_expired(&cert, cutoff));
    }

    #[test]
    fn test_deadlines_saturate() {
        let tracker = AuthStateTracker::default();
        assert_eq!(tracker.reauth_deadline(&test_cert(0, HOUR_MS / 2)), 0);
        assert_eq!(tracker.hard_cutoff(&test_cert(0, u64::MAX)), u64::MAX);
    }

    #[test]
    fn test_zero_grace_period_expires_immediately() {
        let tracker = AuthStateTracker::new(AuthConfig::new(24, 0, 1));
        let cert = test_cert(0, 24 * HOUR_MS);
        assert_eq!(
            tracker.check_state(&cert, 24 * HOUR_MS - 1),
            CertificateState::Warning { expires_in_ms: 1 }
        );
        assert_eq!(tracker.check_state(&cert, 24 * HOUR_MS), CertificateState::Expired);
    }

    #[test]
    fn test_custom_config() {
        let config = AuthConfig::new(48, 8, 2);
        let tracker = AuthStateTracker::new(config);
        let cert = test_cert(0, 48 * HOUR_MS);

        assert_eq!(
            tracker.check_state(&cert, 45 * HOUR_MS),
            CertificateState::Valid { expires_in_ms: 3 * HOUR_MS }
        );
        assert_eq!(
            tracker.check_state(&cert, 46 * HOUR_MS + HOUR_MS / 2),
            CertificateState::Warning { expires_in_ms: HOUR_MS + HOUR_MS / 2 }
        );
        assert_eq!(
            tracker.check_state(&cert, 52 * HOUR_MS),
            CertificateState::GracePeriod { grace_remaining_ms: 4 * HOUR_MS }
        );
        assert_eq!(tracker.check_state(&cert, 56 * HOUR_MS), CertificateState::Expired);
    }

    #[test]
    fn test_monitor_transitions() {
        let mut monitor = AuthStateMonitor::new(AuthStateTracker::default());
        let cert = test_cert(0, 24 * HOUR_MS);

        assert_eq!(monitor.update(&cert, 0), None);
        assert_eq!(monitor.update(&cert, 22 * HOUR_MS), None);
        assert_eq!(
            monitor.update(&cert, 23 * HOUR_MS),
            Some(AuthStateEvent::EnteringWarning { expires_in_ms: HOUR_MS })
        );
        assert_eq!(
            monitor.update(&cert, 24 * HOUR_MS),
            Some(AuthStateEvent::EnteringGracePeriod { grace_remaining_ms: 4 * HOUR_MS })
        );
        assert_eq!(monitor.update(&cert, 28 * HOUR_MS), Some(AuthStateEvent::Expired));
        assert_eq!(monitor.update(&cert, 30 * HOUR_MS), None);
    }

    #[test]
    fn test_monitor_renewal() {
        let mut monitor = AuthStateMonitor::new(AuthStateTracker::default());
        let cert = test_cert(0, 24 * HOUR_MS);
        assert_eq!(
            monitor.update(&cert, 23 * HOUR_MS),
            Some(AuthStateEvent::EnteringWarning { expires_in_ms: HOUR_MS })
        );

        let new_cert = test_cert(23 * HOUR_MS, 47 * HOUR_MS);
        let event = monitor.notify_renewed(&new_cert);
        assert_eq!(
            event,
            AuthStateEvent::Renewed { new_expires_at_ms: 47 * HOUR_MS }
        );
        assert!(matches!(
            monitor.current_state(),
            Some(CertificateState::Valid { .. })
        ));
    }
}