use super::authorization::Role;
use super::device_id::DeviceId;
use super::error::SecurityError;
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2,
};
use hmac::{Hmac, Mac};
use rand_core::RngCore;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use uuid::Uuid;
/// Default lifetime of a newly created session, in hours.
pub const DEFAULT_SESSION_EXPIRY_HOURS: u64 = 8;
/// Length of one TOTP time step, in seconds.
pub const TOTP_TIME_STEP_SECS: u64 = 30;
/// Number of decimal digits in a generated one-time code.
pub const TOTP_CODE_LENGTH: usize = 6;
/// Number of time steps before and after the current one that are still
/// accepted when a code is verified.
pub const TOTP_CLOCK_DRIFT_STEPS: i64 = 1;
/// Rank (or civilian status) carried by a [`UserIdentity`].
///
/// The `Display` implementation renders the short abbreviation
/// (for example `Captain` is shown as `CPT`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MilitaryRank {
Private,
Specialist,
Corporal,
Sergeant,
StaffSergeant,
SergeantFirstClass,
MasterSergeant,
FirstSergeant,
SergeantMajor,
WarrantOfficer1,
ChiefWarrantOfficer2,
ChiefWarrantOfficer3,
ChiefWarrantOfficer4,
ChiefWarrantOfficer5,
SecondLieutenant,
FirstLieutenant,
Captain,
Major,
LieutenantColonel,
Colonel,
BrigadierGeneral,
MajorGeneral,
LieutenantGeneral,
General,
/// Non-military personnel; this is the builder's default rank.
Civilian,
}
impl std::fmt::Display for MilitaryRank {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MilitaryRank::Private => write!(f, "PVT"),
MilitaryRank::Specialist => write!(f, "SPC"),
MilitaryRank::Corporal => write!(f, "CPL"),
MilitaryRank::Sergeant => write!(f, "SGT"),
MilitaryRank::StaffSergeant => write!(f, "SSG"),
MilitaryRank::SergeantFirstClass => write!(f, "SFC"),
MilitaryRank::MasterSergeant => write!(f, "MSG"),
MilitaryRank::FirstSergeant => write!(f, "1SG"),
MilitaryRank::SergeantMajor => write!(f, "SGM"),
MilitaryRank::WarrantOfficer1 => write!(f, "WO1"),
MilitaryRank::ChiefWarrantOfficer2 => write!(f, "CW2"),
MilitaryRank::ChiefWarrantOfficer3 => write!(f, "CW3"),
MilitaryRank::ChiefWarrantOfficer4 => write!(f, "CW4"),
MilitaryRank::ChiefWarrantOfficer5 => write!(f, "CW5"),
MilitaryRank::SecondLieutenant => write!(f, "2LT"),
MilitaryRank::FirstLieutenant => write!(f, "1LT"),
MilitaryRank::Captain => write!(f, "CPT"),
MilitaryRank::Major => write!(f, "MAJ"),
MilitaryRank::LieutenantColonel => write!(f, "LTC"),
MilitaryRank::Colonel => write!(f, "COL"),
MilitaryRank::BrigadierGeneral => write!(f, "BG"),
MilitaryRank::MajorGeneral => write!(f, "MG"),
MilitaryRank::LieutenantGeneral => write!(f, "LTG"),
MilitaryRank::General => write!(f, "GEN"),
MilitaryRank::Civilian => write!(f, "CIV"),
}
}
}
/// Clearance level attached to a [`UserIdentity`].
///
/// `PartialOrd`/`Ord` are derived, so comparisons follow declaration order:
/// `Unclassified` is the lowest level and `TopSecretSci` the highest.
/// Reordering the variants changes comparison results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum SecurityClearance {
Unclassified,
Cui,
Confidential,
Secret,
TopSecret,
TopSecretSci,
}
impl std::fmt::Display for SecurityClearance {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SecurityClearance::Unclassified => write!(f, "UNCLASSIFIED"),
SecurityClearance::Cui => write!(f, "CUI"),
SecurityClearance::Confidential => write!(f, "CONFIDENTIAL"),
SecurityClearance::Secret => write!(f, "SECRET"),
SecurityClearance::TopSecret => write!(f, "TOP SECRET"),
SecurityClearance::TopSecretSci => write!(f, "TS/SCI"),
}
}
}
/// An organizational unit a user belongs to.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OrganizationUnit {
/// Name of this unit.
pub name: String,
/// Name of the parent unit, or `None` for a top-level unit.
pub parent: Option<String>,
/// Optional identifier for the unit.
/// NOTE(review): presumably a Unit Identification Code — confirm.
pub uic: Option<String>,
}
impl OrganizationUnit {
    /// Creates a unit called `name` that sits under `parent`; no UIC is set.
    pub fn new(name: impl Into<String>, parent: impl Into<String>) -> Self {
        let mut unit = Self::top_level(name);
        unit.parent = Some(parent.into());
        unit
    }

    /// Creates a unit called `name` with neither a parent nor a UIC.
    pub fn top_level(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            parent: None,
            uic: None,
        }
    }

    /// Returns the unit with its UIC replaced by `uic`.
    pub fn with_uic(self, uic: impl Into<String>) -> Self {
        Self {
            uic: Some(uic.into()),
            ..self
        }
    }
}
// Renders "<name>, <parent>" when a parent is set, otherwise just "<name>".
impl std::fmt::Display for OrganizationUnit {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.parent.as_deref() {
            Some(parent) => write!(f, "{}, {}", self.name, parent),
            None => f.write_str(&self.name),
        }
    }
}
/// Descriptive attributes of a user; contains no authentication secrets.
///
/// A copy of this value is embedded in every [`UserSession`] created for the user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserIdentity {
/// Login name; `LocalUserStore` uses it as the record key.
pub username: String,
/// Human-readable name; the builder falls back to `username` when unset.
pub display_name: String,
/// Rank of the user.
pub rank: MilitaryRank,
/// Clearance level of the user.
pub clearance: SecurityClearance,
/// Unit the user belongs to.
pub unit: OrganizationUnit,
/// Set of roles assigned to the user.
pub roles: HashSet<Role>,
}
impl UserIdentity {
/// Starts a [`UserIdentityBuilder`] for `username`.
///
/// Every other attribute starts at the builder's defaults and can be
/// overridden before calling `build`.
pub fn builder(username: impl Into<String>) -> UserIdentityBuilder {
UserIdentityBuilder::new(username)
}
}
/// Step-by-step builder for [`UserIdentity`].
pub struct UserIdentityBuilder {
// Login name; the only mandatory input.
username: String,
// Display name; `build` falls back to `username` when this is `None`.
display_name: Option<String>,
// Rank; starts as `MilitaryRank::Civilian`.
rank: MilitaryRank,
// Clearance; starts as `SecurityClearance::Unclassified`.
clearance: SecurityClearance,
// Unit; `build` substitutes a top-level unit named "Unknown" when `None`.
unit: Option<OrganizationUnit>,
// Roles collected so far; starts empty.
roles: HashSet<Role>,
}
impl UserIdentityBuilder {
    /// Starts a builder for `username` with a civilian rank, unclassified
    /// clearance, no display name, no unit and no roles.
    pub fn new(username: impl Into<String>) -> Self {
        let username = username.into();
        Self {
            username,
            display_name: None,
            rank: MilitaryRank::Civilian,
            clearance: SecurityClearance::Unclassified,
            unit: None,
            roles: HashSet::new(),
        }
    }

    /// Sets the human-readable display name.
    pub fn display_name(self, name: impl Into<String>) -> Self {
        Self {
            display_name: Some(name.into()),
            ..self
        }
    }

    /// Sets the rank.
    pub fn rank(self, rank: MilitaryRank) -> Self {
        Self { rank, ..self }
    }

    /// Sets the clearance level.
    pub fn clearance(self, clearance: SecurityClearance) -> Self {
        Self { clearance, ..self }
    }

    /// Sets the organizational unit.
    pub fn unit(self, unit: OrganizationUnit) -> Self {
        Self {
            unit: Some(unit),
            ..self
        }
    }

    /// Adds a single role to the set collected so far.
    pub fn role(self, role: Role) -> Self {
        self.roles(Some(role))
    }

    /// Adds every role yielded by `roles` to the set collected so far.
    pub fn roles(mut self, roles: impl IntoIterator<Item = Role>) -> Self {
        for role in roles {
            self.roles.insert(role);
        }
        self
    }

    /// Finishes the builder.
    ///
    /// A missing display name falls back to the username, and a missing unit
    /// becomes a top-level unit named "Unknown".
    pub fn build(self) -> UserIdentity {
        let Self {
            username,
            display_name,
            rank,
            clearance,
            unit,
            roles,
        } = self;
        let display_name = display_name.unwrap_or_else(|| username.clone());
        let unit = unit.unwrap_or_else(|| OrganizationUnit::top_level("Unknown"));
        UserIdentity {
            username,
            display_name,
            rank,
            clearance,
            unit,
            roles,
        }
    }
}
/// Everything a [`UserStore`] keeps about one user.
///
/// NOTE(review): the derived `Debug` also prints `auth_method`, which holds
/// hashes and the raw TOTP secret — avoid logging whole records.
#[derive(Debug, Clone)]
pub struct UserRecord {
/// Descriptive attributes of the user.
pub identity: UserIdentity,
/// Stored verifier material for the user's authentication method.
pub auth_method: AuthMethod,
/// Current account state; only `Active` accounts can authenticate.
pub status: AccountStatus,
/// Time at which the record was created.
pub created_at: SystemTime,
/// Time of the most recent successful authentication, if any.
pub last_login: Option<SystemTime>,
/// Failed attempts since the last successful login or unlock.
pub failed_attempts: u32,
}
/// Lifecycle state of a user account.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AccountStatus {
/// The account may authenticate.
Active,
/// Set by the authenticator once failed attempts reach the configured maximum.
Locked,
/// Set by `UserAuthenticator::disable_account`.
Disabled,
/// Authentication is refused with `AccountPending`.
/// NOTE(review): presumably awaiting activation — nothing in this file sets it.
Pending,
}
/// Stored (server-side) verifier material, one variant per authentication method.
///
/// Each variant is checked against the [`Credential`] variant of the same name.
#[derive(Debug, Clone)]
pub enum AuthMethod {
/// Password plus time-based one-time code.
PasswordMfa {
/// Password hash string, parsed with `PasswordHash::new` and verified with Argon2.
password_hash: String,
/// Raw secret bytes used as the HMAC key when computing one-time codes.
totp_secret: Vec<u8>,
},
/// Smart card identified by `card_id` and protected by a PIN.
SmartCard {
/// Identifier the presented card must match exactly.
card_id: String,
/// PIN hash string, verified the same way as a password hash.
pin_hash: String,
},
/// Certificate identified by its fingerprint.
Certificate {
/// Fingerprint the presented certificate must match exactly.
certificate_fingerprint: String,
},
}
/// Credentials presented by a caller, one variant per authentication method.
///
/// NOTE(review): the derived `Debug` prints the plaintext password, TOTP code
/// and PIN — avoid formatting credentials into logs.
#[derive(Debug, Clone)]
pub enum Credential {
/// Password plus the current one-time code.
PasswordMfa {
/// Plaintext password.
password: String,
/// One-time code as a decimal string.
totp_code: String,
},
/// Smart-card identifier plus PIN.
SmartCard {
/// Identifier of the presented card.
card_id: String,
/// Plaintext PIN.
pin: String,
},
/// Fingerprint of the presented certificate.
Certificate {
/// Fingerprint compared against the stored one.
fingerprint: String,
},
}
/// An authenticated session issued by `UserAuthenticator`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserSession {
/// Unique identifier of the session.
pub session_id: SessionId,
/// Identity of the user, copied from the user record at login time.
pub identity: UserIdentity,
/// Device the session is bound to; `None` until `bind_session_to_device` is called.
pub device_id: Option<DeviceId>,
/// Time at which the session was created.
pub created_at: SystemTime,
/// Time after which the session counts as expired.
pub expires_at: SystemTime,
}
impl UserSession {
    /// Returns `true` once the current system time is strictly past `expires_at`.
    pub fn is_expired(&self) -> bool {
        let now = SystemTime::now();
        self.expires_at < now
    }

    /// Returns the time left until expiry, or `None` when `expires_at` is
    /// already in the past.
    pub fn remaining_time(&self) -> Option<Duration> {
        match self.expires_at.duration_since(SystemTime::now()) {
            Ok(left) => Some(left),
            Err(_) => None,
        }
    }
}
/// Opaque session identifier wrapping a UUID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionId(Uuid);
impl SessionId {
/// Generates a fresh identifier from a version-4 UUID.
pub fn new() -> Self {
Self(Uuid::new_v4())
}
}
// `Default` yields a freshly generated identifier, not a fixed value.
impl Default for SessionId {
fn default() -> Self {
Self::new()
}
}
// Displays the identifier using the wrapped UUID's own `Display` output.
impl std::fmt::Display for SessionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Storage backend for [`UserRecord`]s, keyed by username.
///
/// Implementations must be `Send + Sync`. The per-method notes describe what
/// `LocalUserStore` does; other implementations are presumably expected to
/// behave the same way.
pub trait UserStore: Send + Sync {
/// Returns a copy of the record for `username`, or `None` when absent.
fn get_user(&self, username: &str) -> Option<UserRecord>;
/// Adds a new record; `LocalUserStore` fails with `UserAlreadyExists` on a duplicate username.
fn store_user(&self, record: UserRecord) -> Result<(), SecurityError>;
/// Replaces an existing record; `LocalUserStore` fails with `UserNotFound` when absent.
fn update_user(&self, record: UserRecord) -> Result<(), SecurityError>;
/// Removes a record; `LocalUserStore` fails with `UserNotFound` when absent.
fn delete_user(&self, username: &str) -> Result<(), SecurityError>;
/// Returns the usernames of all stored records.
fn list_users(&self) -> Vec<String>;
}
/// In-memory [`UserStore`] backed by a `HashMap` behind an `RwLock`.
#[derive(Debug, Default)]
pub struct LocalUserStore {
// Records keyed by `identity.username`.
users: RwLock<HashMap<String, UserRecord>>,
}
impl LocalUserStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a store pre-populated with `users`.
    ///
    /// Records are keyed by `identity.username`; when two records share a
    /// username, the later one in `users` wins.
    pub fn with_users(users: Vec<UserRecord>) -> Self {
        let by_name: HashMap<String, UserRecord> = users
            .into_iter()
            .map(|record| (record.identity.username.clone(), record))
            .collect();
        Self {
            users: RwLock::new(by_name),
        }
    }
}
impl UserStore for LocalUserStore {
    /// Returns a copy of the record for `username`; `None` when the user is
    /// absent or the lock is poisoned.
    fn get_user(&self, username: &str) -> Option<UserRecord> {
        match self.users.read() {
            Ok(users) => users.get(username).cloned(),
            Err(_) => None,
        }
    }

    /// Inserts `record`, failing with `UserAlreadyExists` when the username is taken.
    fn store_user(&self, record: UserRecord) -> Result<(), SecurityError> {
        use std::collections::hash_map::Entry;

        let mut users = match self.users.write() {
            Ok(guard) => guard,
            Err(e) => {
                return Err(SecurityError::Internal(format!(
                    "users lock poisoned: {e}"
                )))
            }
        };
        match users.entry(record.identity.username.clone()) {
            Entry::Occupied(_) => Err(SecurityError::UserAlreadyExists {
                username: record.identity.username,
            }),
            Entry::Vacant(slot) => {
                slot.insert(record);
                Ok(())
            }
        }
    }

    /// Replaces the stored record, failing with `UserNotFound` when the
    /// username is unknown.
    fn update_user(&self, record: UserRecord) -> Result<(), SecurityError> {
        let mut users = match self.users.write() {
            Ok(guard) => guard,
            Err(e) => {
                return Err(SecurityError::Internal(format!(
                    "users lock poisoned: {e}"
                )))
            }
        };
        match users.get_mut(&record.identity.username) {
            Some(existing) => {
                *existing = record;
                Ok(())
            }
            None => Err(SecurityError::UserNotFound {
                username: record.identity.username,
            }),
        }
    }

    /// Removes the record for `username`, failing with `UserNotFound` when absent.
    fn delete_user(&self, username: &str) -> Result<(), SecurityError> {
        let mut users = match self.users.write() {
            Ok(guard) => guard,
            Err(e) => {
                return Err(SecurityError::Internal(format!(
                    "users lock poisoned: {e}"
                )))
            }
        };
        match users.remove(username) {
            Some(_) => Ok(()),
            None => Err(SecurityError::UserNotFound {
                username: username.to_string(),
            }),
        }
    }

    /// Returns all stored usernames; an empty list when the lock is poisoned.
    fn list_users(&self) -> Vec<String> {
        match self.users.read() {
            Ok(users) => users.keys().cloned().collect(),
            Err(_) => Vec::new(),
        }
    }
}
/// Authenticates users against a [`UserStore`] and tracks the resulting sessions.
pub struct UserAuthenticator {
// Backend holding the user records.
user_store: Box<dyn UserStore>,
// Live sessions keyed by session id.
sessions: Arc<RwLock<HashMap<SessionId, UserSession>>>,
// Lifetime given to every newly created session.
session_expiry: Duration,
// Failed-attempt count at which an account is switched to `Locked`.
max_failed_attempts: u32,
}
impl UserAuthenticator {
/// Creates an authenticator over `user_store` with no sessions, the default
/// session lifetime (`DEFAULT_SESSION_EXPIRY_HOURS`) and a lockout threshold
/// of five failed attempts.
pub fn new(user_store: Box<dyn UserStore>) -> Self {
    let session_expiry = Duration::from_secs(DEFAULT_SESSION_EXPIRY_HOURS * 3600);
    let sessions = Arc::new(RwLock::new(HashMap::new()));
    Self {
        user_store,
        sessions,
        session_expiry,
        max_failed_attempts: 5,
    }
}
pub fn with_session_expiry(mut self, expiry: Duration) -> Self {
self.session_expiry = expiry;
self
}
pub fn with_max_failed_attempts(mut self, max: u32) -> Self {
self.max_failed_attempts = max;
self
}
/// Registers a new, immediately active password + TOTP user.
///
/// The password is hashed with Argon2 under a fresh random salt, a random
/// 20-byte TOTP secret is generated, and the record is stored under
/// `identity.username`.
///
/// NOTE(review): `_username` is ignored; the record is keyed solely by
/// `identity.username`. Callers passing a different value get no error —
/// confirm whether a mismatch should be rejected.
///
/// # Returns
/// The TOTP secret encoded as unpadded base32, for provisioning the user's
/// authenticator.
///
/// # Errors
/// * `PasswordHashError` when hashing fails.
/// * Whatever the store returns from `store_user` (for `LocalUserStore`,
///   `UserAlreadyExists` on a duplicate username).
pub fn register_user(
    &self,
    _username: &str,
    password: &str,
    identity: UserIdentity,
) -> Result<String, SecurityError> {
    let salt = SaltString::generate(&mut OsRng);
    let password_hash = Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .map_err(|e| SecurityError::PasswordHashError {
            message: e.to_string(),
        })?
        .to_string();

    // Fill the whole secret in one call instead of drawing a u64 per byte
    // and truncating it with `as u8`.
    let mut totp_secret = vec![0u8; 20];
    rand_core::OsRng.fill_bytes(&mut totp_secret);
    let totp_secret_b32 = base32_encode(&totp_secret);

    let record = UserRecord {
        identity,
        auth_method: AuthMethod::PasswordMfa {
            password_hash,
            totp_secret,
        },
        status: AccountStatus::Active,
        created_at: SystemTime::now(),
        last_login: None,
        failed_attempts: 0,
    };
    self.user_store.store_user(record)?;
    Ok(totp_secret_b32)
}
/// Authenticates `username` with `credential` and opens a new session.
///
/// Only `Active` accounts may authenticate. Every rejected attempt — a wrong
/// password/PIN/fingerprint *or* a wrong one-time code — increments the
/// failed-attempt counter, and the account is switched to `Locked` once the
/// counter reaches the configured maximum. A successful login resets the
/// counter and records the login time.
///
/// # Errors
/// * `UserNotFound` when the store has no such user.
/// * `AccountLocked` / `AccountDisabled` / `AccountPending` for non-active accounts.
/// * `InvalidCredential` when the primary credential does not match.
/// * `InvalidMfaCode` when the password matched but the one-time code did not.
/// * `UnsupportedAuthMethod`, `PasswordHashError`, `TotpError` as produced by
///   credential verification (these are not counted as failed attempts).
/// * `Internal` when the session table lock is poisoned.
pub fn authenticate(
    &self,
    username: &str,
    credential: &Credential,
) -> Result<UserSession, SecurityError> {
    let mut user =
        self.user_store
            .get_user(username)
            .ok_or_else(|| SecurityError::UserNotFound {
                username: username.to_string(),
            })?;
    match user.status {
        AccountStatus::Active => {}
        AccountStatus::Locked => {
            return Err(SecurityError::AccountLocked {
                username: username.to_string(),
            })
        }
        AccountStatus::Disabled => {
            return Err(SecurityError::AccountDisabled {
                username: username.to_string(),
            })
        }
        AccountStatus::Pending => {
            return Err(SecurityError::AccountPending {
                username: username.to_string(),
            })
        }
    }

    // Classify the outcome. A wrong one-time code arrives as
    // `Err(InvalidMfaCode)`; it must count as a failed attempt too, otherwise
    // the second factor can be guessed without ever triggering lockout.
    let rejection = match self.verify_credentials(&user.auth_method, credential) {
        Ok(true) => None,
        Ok(false) => Some(SecurityError::InvalidCredential {
            username: username.to_string(),
        }),
        Err(SecurityError::InvalidMfaCode) => Some(SecurityError::InvalidMfaCode),
        // Configuration/processing errors are not guesses; pass them through.
        Err(other) => return Err(other),
    };
    if let Some(reason) = rejection {
        user.failed_attempts = user.failed_attempts.saturating_add(1);
        if user.failed_attempts >= self.max_failed_attempts {
            user.status = AccountStatus::Locked;
        }
        // Best-effort bookkeeping: the caller is told about the rejected
        // credential even if the store cannot persist the counter.
        let _ = self.user_store.update_user(user);
        return Err(reason);
    }

    let now = SystemTime::now();
    user.failed_attempts = 0;
    user.last_login = Some(now);
    // Only the identity is needed after the record is handed to the store.
    let identity = user.identity.clone();
    // Best-effort bookkeeping, as above.
    let _ = self.user_store.update_user(user);

    let session = UserSession {
        session_id: SessionId::new(),
        identity,
        device_id: None,
        created_at: now,
        expires_at: now + self.session_expiry,
    };
    self.sessions
        .write()
        .map_err(|e| SecurityError::Internal(format!("sessions lock poisoned: {e}")))?
        .insert(session.session_id, session.clone());
    Ok(session)
}
/// Authenticates `username` with a password alone and opens a new session.
///
/// Works only for accounts whose stored method is `PasswordMfa`; the
/// one-time code is not checked on this path.
/// NOTE(review): this skips the second factor entirely — confirm which
/// callers are allowed to use it.
///
/// # Errors
/// * `UserNotFound` when the store has no such user.
/// * `AccountLocked` / `AccountDisabled` / `AccountPending` for non-active accounts.
/// * `UnsupportedAuthMethod` when the account does not use a password.
/// * `PasswordHashError` when the stored hash cannot be parsed.
/// * `InvalidCredential` when the password is wrong (the attempt is counted
///   and may lock the account).
/// * `Internal` when the session table lock is poisoned.
pub fn authenticate_password_only(
    &self,
    username: &str,
    password: &str,
) -> Result<UserSession, SecurityError> {
    let mut user = match self.user_store.get_user(username) {
        Some(record) => record,
        None => {
            return Err(SecurityError::UserNotFound {
                username: username.to_string(),
            })
        }
    };

    // Refuse anything that is not an active account.
    match user.status {
        AccountStatus::Active => {}
        AccountStatus::Locked => {
            return Err(SecurityError::AccountLocked {
                username: username.to_string(),
            })
        }
        AccountStatus::Disabled => {
            return Err(SecurityError::AccountDisabled {
                username: username.to_string(),
            })
        }
        AccountStatus::Pending => {
            return Err(SecurityError::AccountPending {
                username: username.to_string(),
            })
        }
    }

    // Check the password against the stored hash; other methods are rejected.
    let password_ok = match &user.auth_method {
        AuthMethod::PasswordMfa { password_hash, .. } => {
            let parsed_hash = PasswordHash::new(password_hash).map_err(|e| {
                SecurityError::PasswordHashError {
                    message: e.to_string(),
                }
            })?;
            Argon2::default()
                .verify_password(password.as_bytes(), &parsed_hash)
                .is_ok()
        }
        AuthMethod::SmartCard { .. } | AuthMethod::Certificate { .. } => {
            return Err(SecurityError::UnsupportedAuthMethod {
                method: "non-password".to_string(),
            })
        }
    };

    if !password_ok {
        // Count the failure and lock the account at the threshold.
        user.failed_attempts += 1;
        if user.failed_attempts >= self.max_failed_attempts {
            user.status = AccountStatus::Locked;
        }
        let _ = self.user_store.update_user(user);
        return Err(SecurityError::InvalidCredential {
            username: username.to_string(),
        });
    }

    // Success: reset the counter and record the login time (best-effort).
    user.failed_attempts = 0;
    user.last_login = Some(SystemTime::now());
    let _ = self.user_store.update_user(user.clone());

    let created_at = SystemTime::now();
    let session = UserSession {
        session_id: SessionId::new(),
        identity: user.identity,
        device_id: None,
        created_at,
        expires_at: created_at + self.session_expiry,
    };
    let mut sessions = self
        .sessions
        .write()
        .map_err(|e| SecurityError::Internal(format!("sessions lock poisoned: {e}")))?;
    sessions.insert(session.session_id, session.clone());
    drop(sessions);
    Ok(session)
}
/// Returns the identity bound to `session_id` when the session exists and
/// has not expired.
///
/// An expired session is removed from the table before the error is returned.
///
/// # Errors
/// * `Internal` when the session table lock is poisoned.
/// * `SessionNotFound` when no such session exists.
/// * `SessionExpired` when the session's expiry time has passed.
pub fn validate_session(&self, session_id: &SessionId) -> Result<UserIdentity, SecurityError> {
    // Look the session up under the read lock; the lock is released at the
    // end of this block, before any removal needs the write lock.
    let live_identity: Option<UserIdentity> = {
        let table = self
            .sessions
            .read()
            .map_err(|e| SecurityError::Internal(format!("sessions lock poisoned: {e}")))?;
        let entry = table
            .get(session_id)
            .ok_or(SecurityError::SessionNotFound)?;
        if entry.is_expired() {
            None
        } else {
            Some(entry.identity.clone())
        }
    };
    match live_identity {
        Some(identity) => Ok(identity),
        None => {
            self.invalidate_session(session_id);
            Err(SecurityError::SessionExpired)
        }
    }
}
/// Returns a copy of the session for `session_id`, without checking expiry.
///
/// Yields `None` when the session is unknown or the lock is poisoned.
pub fn get_session(&self, session_id: &SessionId) -> Option<UserSession> {
    let table = self.sessions.read().ok()?;
    let found = table.get(session_id).cloned();
    found
}
/// Removes the session for `session_id` if present.
///
/// Does nothing when the session is unknown or the lock is poisoned.
pub fn invalidate_session(&self, session_id: &SessionId) {
    match self.sessions.write() {
        Ok(mut table) => {
            table.remove(session_id);
        }
        Err(_) => {}
    }
}
pub fn invalidate_user_sessions(&self, username: &str) {
if let Ok(mut sessions) = self.sessions.write() {
sessions.retain(|_, session| session.identity.username != username);
}
}
/// Drops every session whose expiry time is not strictly in the future.
///
/// The reference time is taken once, before the lock is acquired. Does
/// nothing when the lock is poisoned.
pub fn cleanup_expired_sessions(&self) {
    let cutoff = SystemTime::now();
    match self.sessions.write() {
        Ok(mut table) => table.retain(|_, entry| cutoff < entry.expires_at),
        Err(_) => {}
    }
}
/// Returns the number of sessions currently in the table, expired or not.
///
/// Yields 0 when the lock is poisoned.
pub fn active_session_count(&self) -> usize {
    match self.sessions.read() {
        Ok(table) => table.len(),
        Err(_) => 0,
    }
}
/// Records `device_id` on the session, replacing any earlier binding.
///
/// Expiry is not checked here.
///
/// # Errors
/// * `Internal` when the session table lock is poisoned.
/// * `SessionNotFound` when no such session exists.
pub fn bind_session_to_device(
    &self,
    session_id: &SessionId,
    device_id: DeviceId,
) -> Result<(), SecurityError> {
    let mut table = self
        .sessions
        .write()
        .map_err(|e| SecurityError::Internal(format!("sessions lock poisoned: {e}")))?;
    match table.get_mut(session_id) {
        Some(entry) => {
            entry.device_id = Some(device_id);
            Ok(())
        }
        None => Err(SecurityError::SessionNotFound),
    }
}
fn verify_credentials(
&self,
auth_method: &AuthMethod,
credential: &Credential,
) -> Result<bool, SecurityError> {
match (auth_method, credential) {
(
AuthMethod::PasswordMfa {
password_hash,
totp_secret,
},
Credential::PasswordMfa {
password,
totp_code,
},
) => {
let parsed_hash = PasswordHash::new(password_hash).map_err(|e| {
SecurityError::PasswordHashError {
message: e.to_string(),
}
})?;
if Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.is_err()
{
return Ok(false);
}
if !verify_totp(totp_secret, totp_code)? {
return Err(SecurityError::InvalidMfaCode);
}
Ok(true)
}
(
AuthMethod::SmartCard {
card_id: stored_id,
pin_hash,
},
Credential::SmartCard { card_id, pin },
) => {
if stored_id != card_id {
return Ok(false);
}
let parsed_hash =
PasswordHash::new(pin_hash).map_err(|e| SecurityError::PasswordHashError {
message: e.to_string(),
})?;
Ok(Argon2::default()
.verify_password(pin.as_bytes(), &parsed_hash)
.is_ok())
}
(
AuthMethod::Certificate {
certificate_fingerprint: stored_fp,
},
Credential::Certificate { fingerprint },
) => {
Ok(stored_fp == fingerprint)
}
_ => Err(SecurityError::UnsupportedAuthMethod {
method: "mismatched auth method and credential".to_string(),
}),
}
}
pub fn unlock_account(&self, username: &str) -> Result<(), SecurityError> {
let mut user =
self.user_store
.get_user(username)
.ok_or_else(|| SecurityError::UserNotFound {
username: username.to_string(),
})?;
user.status = AccountStatus::Active;
user.failed_attempts = 0;
self.user_store.update_user(user)
}
pub fn disable_account(&self, username: &str) -> Result<(), SecurityError> {
let mut user =
self.user_store
.get_user(username)
.ok_or_else(|| SecurityError::UserNotFound {
username: username.to_string(),
})?;
user.status = AccountStatus::Disabled;
self.user_store.update_user(user)?;
self.invalidate_user_sessions(username);
Ok(())
}
pub fn change_password(
&self,
user_name: &str,
old_password: &str,
new_password: &str,
) -> Result<(), SecurityError> {
let mut user =
self.user_store
.get_user(user_name)
.ok_or_else(|| SecurityError::UserNotFound {
username: user_name.to_string(),
})?;
match &user.auth_method {
AuthMethod::PasswordMfa {
password_hash,
totp_secret,
} => {
let parsed_hash = PasswordHash::new(password_hash).map_err(|e| {
SecurityError::PasswordHashError {
message: e.to_string(),
}
})?;
if Argon2::default()
.verify_password(old_password.as_bytes(), &parsed_hash)
.is_err()
{
return Err(SecurityError::InvalidCredential {
username: user_name.to_string(),
});
}
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let new_hash = argon2
.hash_password(new_password.as_bytes(), &salt)
.map_err(|e| SecurityError::PasswordHashError {
message: e.to_string(),
})?
.to_string();
user.auth_method = AuthMethod::PasswordMfa {
password_hash: new_hash,
totp_secret: totp_secret.clone(),
};
}
_ => {
return Err(SecurityError::UnsupportedAuthMethod {
method: "non-password".to_string(),
})
}
}
self.user_store.update_user(user)?;
self.invalidate_user_sessions(user_name);
Ok(())
}
}
/// Computes the one-time code for `secret` at the current system time.
///
/// The counter is the number of whole `TOTP_TIME_STEP_SECS` steps since the
/// UNIX epoch.
///
/// # Errors
/// `TotpError` when the system clock is before the UNIX epoch or the HMAC
/// cannot be keyed.
#[allow(dead_code)]
pub fn generate_totp(secret: &[u8]) -> Result<String, SecurityError> {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => {
            return Err(SecurityError::TotpError {
                message: "System time before UNIX epoch".to_string(),
            })
        }
    };
    let step = since_epoch.as_secs() / TOTP_TIME_STEP_SECS;
    generate_hotp(secret, step)
}
/// Computes the counter-based one-time code for `secret` and `counter`.
///
/// The big-endian counter is authenticated with HMAC-SHA256; the low nibble
/// of the last digest byte selects a 4-byte window, whose top bit is cleared
/// before reducing it to `TOTP_CODE_LENGTH` zero-padded decimal digits.
///
/// NOTE(review): the MAC is SHA-256; authenticator apps that assume SHA-1
/// would produce different codes from the same secret — confirm how clients
/// are provisioned.
///
/// # Errors
/// `TotpError` when the HMAC cannot be keyed with `secret`.
fn generate_hotp(secret: &[u8], counter: u64) -> Result<String, SecurityError> {
    let mut mac = match Hmac::<Sha256>::new_from_slice(secret) {
        Ok(keyed) => keyed,
        Err(e) => {
            return Err(SecurityError::TotpError {
                message: format!("HMAC error: {}", e),
            })
        }
    };
    mac.update(&counter.to_be_bytes());
    let digest = mac.finalize().into_bytes();

    // Dynamic truncation: the start index is at most 15, so the 4-byte
    // window always lies inside the 32-byte digest.
    let start = usize::from(digest[digest.len() - 1] & 0x0f);
    let window = [
        digest[start],
        digest[start + 1],
        digest[start + 2],
        digest[start + 3],
    ];
    let truncated = u32::from_be_bytes(window) & 0x7fff_ffff;

    let modulus = 10u32.pow(TOTP_CODE_LENGTH as u32);
    let code = truncated % modulus;
    Ok(format!("{:0width$}", code, width = TOTP_CODE_LENGTH))
}
/// Checks `code` against the codes for the current time step and for
/// `TOTP_CLOCK_DRIFT_STEPS` steps on either side of it.
///
/// Candidates are tried from the earliest step to the latest, and the search
/// stops at the first match.
///
/// # Errors
/// `TotpError` when the system clock is before the UNIX epoch or the HMAC
/// cannot be keyed.
pub fn verify_totp(secret: &[u8], code: &str) -> Result<bool, SecurityError> {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|_| SecurityError::TotpError {
            message: "System time before UNIX epoch".to_string(),
        })?;
    let current_step = (since_epoch.as_secs() / TOTP_TIME_STEP_SECS) as i64;

    let mut drift = -TOTP_CLOCK_DRIFT_STEPS;
    while drift <= TOTP_CLOCK_DRIFT_STEPS {
        let candidate = generate_hotp(secret, (current_step + drift) as u64)?;
        if constant_time_compare(code.as_bytes(), candidate.as_bytes()) {
            return Ok(true);
        }
        drift += 1;
    }
    Ok(false)
}
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length compare unequal immediately; for equal
/// lengths every byte pair is XOR-ed and accumulated before the result is
/// decided.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let difference = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    difference == 0
}
/// Encodes `data` as base32 with the `A-Z2-7` alphabet and no `=` padding.
///
/// Bits are consumed most-significant first in groups of five; a final
/// partial group is padded with zero bits on the right.
fn base32_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";

    let mut encoded = String::with_capacity((data.len() * 8 + 4) / 5);
    let mut acc: u32 = 0;
    let mut pending_bits: u32 = 0;

    for &byte in data {
        acc = (acc << 8) | u32::from(byte);
        pending_bits += 8;
        while pending_bits >= 5 {
            pending_bits -= 5;
            let symbol = ((acc >> pending_bits) & 0x1f) as usize;
            encoded.push(char::from(ALPHABET[symbol]));
        }
        // Keep only the bits that have not been emitted yet (at most four).
        acc &= (1u32 << pending_bits) - 1;
    }

    if pending_bits > 0 {
        let symbol = ((acc << (5 - pending_bits)) & 0x1f) as usize;
        encoded.push(char::from(ALPHABET[symbol]));
    }
    encoded
}
// Unit tests for identities, the in-memory user store, the TOTP helpers and
// the authenticator's session and account handling.
#[cfg(test)]
mod tests {
use super::*;
// Ranks display as their short abbreviations.
#[test]
fn test_military_rank_display() {
assert_eq!(MilitaryRank::Captain.to_string(), "CPT");
assert_eq!(MilitaryRank::Sergeant.to_string(), "SGT");
assert_eq!(MilitaryRank::Colonel.to_string(), "COL");
}
// Clearance levels compare according to declaration order.
#[test]
fn test_security_clearance_ordering() {
assert!(SecurityClearance::Secret > SecurityClearance::Confidential);
assert!(SecurityClearance::TopSecret > SecurityClearance::Secret);
assert!(SecurityClearance::TopSecretSci > SecurityClearance::TopSecret);
}
// Units display as "name, parent", or just "name" at the top level.
#[test]
fn test_organization_unit() {
let unit = OrganizationUnit::new("1st Platoon", "Alpha Company");
assert_eq!(unit.to_string(), "1st Platoon, Alpha Company");
let top = OrganizationUnit::top_level("Battalion HQ");
assert_eq!(top.to_string(), "Battalion HQ");
}
// The builder carries every configured attribute into the identity.
#[test]
fn test_user_identity_builder() {
let identity = UserIdentity::builder("alpha_6")
.display_name("CPT John Smith")
.rank(MilitaryRank::Captain)
.clearance(SecurityClearance::Secret)
.unit(OrganizationUnit::new("1st Plt", "A Co"))
.role(Role::Commander)
.build();
assert_eq!(identity.username, "alpha_6");
assert_eq!(identity.rank, MilitaryRank::Captain);
assert!(identity.roles.contains(&Role::Commander));
}
// Two freshly generated session ids differ.
#[test]
fn test_session_id_generation() {
let id1 = SessionId::new();
let id2 = SessionId::new();
assert_ne!(id1, id2);
}
// Store, fetch, list and delete round-trip through the in-memory store.
#[test]
fn test_local_user_store() {
let store = LocalUserStore::new();
let identity = UserIdentity::builder("test_user")
.rank(MilitaryRank::Sergeant)
.build();
let record = UserRecord {
identity,
auth_method: AuthMethod::PasswordMfa {
password_hash: "hash".to_string(),
totp_secret: vec![1, 2, 3],
},
status: AccountStatus::Active,
created_at: SystemTime::now(),
last_login: None,
failed_attempts: 0,
};
store.store_user(record.clone()).unwrap();
let retrieved = store.get_user("test_user").unwrap();
assert_eq!(retrieved.identity.username, "test_user");
let users = store.list_users();
assert_eq!(users.len(), 1);
assert!(users.contains(&"test_user".to_string()));
store.delete_user("test_user").unwrap();
assert!(store.get_user("test_user").is_none());
}
// Storing the same username twice is rejected.
#[test]
fn test_user_store_duplicate_prevention() {
let store = LocalUserStore::new();
let identity = UserIdentity::builder("test_user").build();
let record = UserRecord {
identity: identity.clone(),
auth_method: AuthMethod::PasswordMfa {
password_hash: "hash".to_string(),
totp_secret: vec![1, 2, 3],
},
status: AccountStatus::Active,
created_at: SystemTime::now(),
last_login: None,
failed_attempts: 0,
};
store.store_user(record.clone()).unwrap();
let result = store.store_user(record);
assert!(matches!(
result,
Err(SecurityError::UserAlreadyExists { .. })
));
}
// A generated code has six characters.
#[test]
fn test_hotp_generation() {
let secret = b"12345678901234567890";
let code = generate_hotp(secret, 0).unwrap();
assert_eq!(code.len(), 6);
}
// A freshly generated code verifies; a fixed code is expected not to.
// NOTE(review): "000000" can coincide with a currently valid code, so the
// last assertion can fail on rare occasions.
#[test]
fn test_totp_generation_and_verification() {
let secret = b"test_secret_key_1234";
let code = generate_totp(secret).unwrap();
assert_eq!(code.len(), 6);
assert!(verify_totp(secret, &code).unwrap());
assert!(!verify_totp(secret, "000000").unwrap());
}
// Encoding of "", "f", "fo", ... "foobar" without `=` padding.
#[test]
fn test_base32_encode() {
assert_eq!(base32_encode(b""), "");
assert_eq!(base32_encode(b"f"), "MY");
assert_eq!(base32_encode(b"fo"), "MZXQ");
assert_eq!(base32_encode(b"foo"), "MZXW6");
assert_eq!(base32_encode(b"foob"), "MZXW6YQ");
assert_eq!(base32_encode(b"fooba"), "MZXW6YTB");
assert_eq!(base32_encode(b"foobar"), "MZXW6YTBOI");
}
// Equal slices match; differing content or length does not.
#[test]
fn test_constant_time_compare() {
assert!(constant_time_compare(b"abc", b"abc"));
assert!(!constant_time_compare(b"abc", b"abd"));
assert!(!constant_time_compare(b"abc", b"ab"));
}
// A registered user can log in with the password alone.
#[test]
fn test_user_registration_and_password_auth() {
let store = LocalUserStore::new();
let authenticator = UserAuthenticator::new(Box::new(store));
let identity = UserIdentity::builder("test_commander")
.display_name("MAJ Test")
.rank(MilitaryRank::Major)
.clearance(SecurityClearance::Secret)
.role(Role::Commander)
.build();
let totp_secret = authenticator
.register_user("test_commander", "secure_password_123", identity)
.unwrap();
assert!(!totp_secret.is_empty());
let session = authenticator
.authenticate_password_only("test_commander", "secure_password_123")
.unwrap();
assert_eq!(session.identity.username, "test_commander");
assert!(!session.is_expired());
}
// A session validates until it is invalidated.
#[test]
fn test_session_validation() {
let store = LocalUserStore::new();
let authenticator = UserAuthenticator::new(Box::new(store));
let identity = UserIdentity::builder("session_test")
.role(Role::Observer)
.build();
authenticator
.register_user("session_test", "password123", identity)
.unwrap();
let session = authenticator
.authenticate_password_only("session_test", "password123")
.unwrap();
let validated = authenticator.validate_session(&session.session_id).unwrap();
assert_eq!(validated.username, "session_test");
authenticator.invalidate_session(&session.session_id);
assert!(matches!(
authenticator.validate_session(&session.session_id),
Err(SecurityError::SessionNotFound)
));
}
// Three wrong passwords lock the account; unlocking restores access.
#[test]
fn test_account_lockout() {
let store = LocalUserStore::new();
let authenticator = UserAuthenticator::new(Box::new(store)).with_max_failed_attempts(3);
let identity = UserIdentity::builder("lockout_test").build();
authenticator
.register_user("lockout_test", "correct_password", identity)
.unwrap();
for _ in 0..3 {
let _ = authenticator.authenticate_password_only("lockout_test", "wrong_password");
}
let result = authenticator.authenticate_password_only("lockout_test", "correct_password");
assert!(matches!(result, Err(SecurityError::AccountLocked { .. })));
authenticator.unlock_account("lockout_test").unwrap();
let session = authenticator
.authenticate_password_only("lockout_test", "correct_password")
.unwrap();
assert_eq!(session.identity.username, "lockout_test");
}
// After a password change the old password fails and the new one works.
#[test]
fn test_password_change() {
let store = LocalUserStore::new();
let authenticator = UserAuthenticator::new(Box::new(store));
let identity = UserIdentity::builder("pwd_change_test").build();
authenticator
.register_user("pwd_change_test", "old_password", identity)
.unwrap();
authenticator
.change_password("pwd_change_test", "old_password", "new_password")
.unwrap();
let result = authenticator.authenticate_password_only("pwd_change_test", "old_password");
assert!(result.is_err());
let session = authenticator
.authenticate_password_only("pwd_change_test", "new_password")
.unwrap();
assert_eq!(session.identity.username, "pwd_change_test");
}
// A 10 ms session is gone after cleanup once 20 ms have passed.
#[test]
fn test_session_count_and_cleanup() {
let store = LocalUserStore::new();
let authenticator =
UserAuthenticator::new(Box::new(store)).with_session_expiry(Duration::from_millis(10));
let identity = UserIdentity::builder("cleanup_test").build();
authenticator
.register_user("cleanup_test", "password", identity)
.unwrap();
let _session = authenticator
.authenticate_password_only("cleanup_test", "password")
.unwrap();
assert_eq!(authenticator.active_session_count(), 1);
std::thread::sleep(Duration::from_millis(20));
authenticator.cleanup_expired_sessions();
assert_eq!(authenticator.active_session_count(), 0);
}
// Disabling an account drops its sessions and blocks further logins.
#[test]
fn test_disable_account() {
let store = LocalUserStore::new();
let authenticator = UserAuthenticator::new(Box::new(store));
let identity = UserIdentity::builder("disable_test").build();
authenticator
.register_user("disable_test", "password", identity)
.unwrap();
let session = authenticator
.authenticate_password_only("disable_test", "password")
.unwrap();
authenticator.disable_account("disable_test").unwrap();
assert!(authenticator.validate_session(&session.session_id).is_err());
let result = authenticator.authenticate_password_only("disable_test", "password");
assert!(matches!(result, Err(SecurityError::AccountDisabled { .. })));
}
// A new session has no device until one is bound to it.
#[test]
fn test_bind_session_to_device() {
let store = LocalUserStore::new();
let authenticator = UserAuthenticator::new(Box::new(store));
let identity = UserIdentity::builder("device_bind_test").build();
authenticator
.register_user("device_bind_test", "password", identity)
.unwrap();
let session = authenticator
.authenticate_password_only("device_bind_test", "password")
.unwrap();
assert!(session.device_id.is_none());
let keypair = crate::security::DeviceKeypair::generate();
let device_id = keypair.device_id();
authenticator
.bind_session_to_device(&session.session_id, device_id)
.unwrap();
let updated_session = authenticator.get_session(&session.session_id).unwrap();
assert!(updated_session.device_id.is_some());
}
}
}