use std::collections::HashMap;
use std::path::Path;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, RwLock};
use tracing::info;
use crate::types::TenantId;
use super::super::super::catalog::SystemCatalog;
use super::super::super::identity::Role;
use super::super::super::time::now_secs;
use crate::config::auth::Argon2Config;
use super::super::hash::{
compute_scram_salted_password, generate_scram_salt, hash_password_argon2,
};
use super::super::lockout::LoginAttemptTracker;
use super::super::record::UserRecord;
/// Store of user credentials held in memory and, when a catalog is attached,
/// also written to a [`SystemCatalog`].
///
/// Mutable state is wrapped in `RwLock`s; the methods in this file operate on `&self`.
pub struct CredentialStore {
/// User records keyed by username.
pub(in crate::control::security::credential) users: RwLock<HashMap<String, UserRecord>>,
/// Next user id to hand out; `new()` starts it at 1, `open()` loads it from the catalog.
pub(in crate::control::security::credential) next_user_id: RwLock<u64>,
/// Backing catalog. `None` for a store built with `new()`, in which case the
/// `persist_*` helpers skip the catalog write.
pub(in crate::control::security::credential) catalog: Option<SystemCatalog>,
/// Login-attempt trackers keyed by a string (presumably the username — confirm
/// against the lockout module).
pub(in crate::control::security::credential) login_attempts:
RwLock<HashMap<String, LoginAttemptTracker>>,
/// Failed-login threshold; both constructors set 0.
/// NOTE(review): 0 presumably disables lockout — confirm where this is read.
pub(in crate::control::security::credential) max_failed_logins: u32,
/// Lockout duration; both constructors set 300 seconds.
pub(in crate::control::security::credential) lockout_duration: std::time::Duration,
/// Password lifetime in seconds; 0 means `compute_expiry` returns 0 (no expiry timestamp).
pub(in crate::control::security::credential) password_expiry_secs: u64,
/// Grace period in days for expired passwords; not read in this file.
pub(in crate::control::security::credential) password_expiry_grace_days: u32,
/// Parameters passed to `hash_password_argon2`.
pub(in crate::control::security::credential) argon2_config: Argon2Config,
/// Per-user counters keyed by user id, incremented by `bump_version`.
pub(in crate::control::security::credential) versions: RwLock<HashMap<u64, Arc<AtomicU64>>>,
/// Session-invalidation bus; unset until `set_buses` is called.
pub(in crate::control::security::credential) si_bus:
std::sync::OnceLock<Arc<crate::control::security::buses::SessionInvalidationBus>>,
/// User-change bus; unset until `set_buses` is called.
pub(in crate::control::security::credential) uc_bus:
std::sync::OnceLock<Arc<crate::control::security::buses::UserChangeBus>>,
}
impl Default for CredentialStore {
fn default() -> Self {
Self::new()
}
}
pub(in crate::control::security::credential) fn read_lock<T>(
lock: &RwLock<T>,
) -> crate::Result<std::sync::RwLockReadGuard<'_, T>> {
lock.read().map_err(|e| {
tracing::error!("credential store read lock poisoned: {e}");
crate::Error::Internal {
detail: "credential store lock poisoned".into(),
}
})
}
pub(in crate::control::security::credential) fn write_lock<T>(
lock: &RwLock<T>,
) -> crate::Result<std::sync::RwLockWriteGuard<'_, T>> {
lock.write().map_err(|e| {
tracing::error!("credential store write lock poisoned: {e}");
crate::Error::Internal {
detail: "credential store lock poisoned".into(),
}
})
}
impl CredentialStore {
pub fn new() -> Self {
Self {
users: RwLock::new(HashMap::new()),
next_user_id: RwLock::new(1),
catalog: None,
login_attempts: RwLock::new(HashMap::new()),
max_failed_logins: 0,
lockout_duration: std::time::Duration::from_secs(300),
password_expiry_secs: 0,
password_expiry_grace_days: 0,
argon2_config: Argon2Config::default(),
versions: RwLock::new(HashMap::new()),
si_bus: std::sync::OnceLock::new(),
uc_bus: std::sync::OnceLock::new(),
}
}
pub fn open(path: &Path) -> crate::Result<Self> {
let catalog = SystemCatalog::open(path)?;
let stored_users = catalog.load_all_users()?;
let next_id = catalog.load_next_user_id()?;
let mut users = HashMap::with_capacity(stored_users.len());
for stored in stored_users {
let record = UserRecord::from_stored(stored);
users.insert(record.username.clone(), record);
}
let count = users.len();
if count > 0 {
info!(count, "loaded users from system catalog");
}
Ok(Self {
users: RwLock::new(users),
next_user_id: RwLock::new(next_id),
catalog: Some(catalog),
login_attempts: RwLock::new(HashMap::new()),
max_failed_logins: 0,
lockout_duration: std::time::Duration::from_secs(300),
password_expiry_secs: 0,
password_expiry_grace_days: 0,
argon2_config: Argon2Config::default(),
versions: RwLock::new(HashMap::new()),
si_bus: std::sync::OnceLock::new(),
uc_bus: std::sync::OnceLock::new(),
})
}
pub(in crate::control::security::credential) fn persist_user(
&self,
record: &mut UserRecord,
) -> crate::Result<()> {
record.updated_at = now_secs();
if let Some(ref catalog) = self.catalog {
catalog.put_user(&record.to_stored())?;
}
Ok(())
}
pub(in crate::control::security::credential) fn persist_next_id(
&self,
id: u64,
) -> crate::Result<()> {
if let Some(ref catalog) = self.catalog {
catalog.save_next_user_id(id)?;
}
Ok(())
}
/// Returns the expiry timestamp for a password set right now.
///
/// Returns 0 when `password_expiry_secs` is 0 (no expiry configured);
/// otherwise the current time plus the configured lifetime, in seconds.
pub(in crate::control::security::credential) fn compute_expiry(&self) -> u64 {
    if self.password_expiry_secs == 0 {
        return 0;
    }
    // Saturate instead of overflowing: a very large configured lifetime must
    // not panic (debug) or wrap to a timestamp in the past / to 0 (release).
    now_secs().saturating_add(self.password_expiry_secs)
}
/// Reserves and returns the next user id.
///
/// The in-memory counter is advanced first and the new counter value is then
/// persisted, all while holding the counter's write lock.
///
/// # Errors
///
/// Fails if the lock is poisoned or the new counter value cannot be persisted;
/// in the latter case the in-memory counter has already been advanced.
pub(in crate::control::security::credential) fn alloc_user_id(&self) -> crate::Result<u64> {
    let mut counter = write_lock(&self.next_user_id)?;
    let allocated = *counter;
    let upcoming = allocated + 1;
    *counter = upcoming;
    self.persist_next_id(upcoming)?;
    Ok(allocated)
}
/// Installs the session-invalidation and user-change buses.
///
/// Each bus lives in a `OnceLock`, so only the first value sticks; the result
/// of setting an already-initialised slot is deliberately discarded.
pub fn set_buses(
    &self,
    si_bus: Arc<crate::control::security::buses::SessionInvalidationBus>,
    uc_bus: Arc<crate::control::security::buses::UserChangeBus>,
) {
    self.si_bus.set(si_bus).ok();
    self.uc_bus.set(uc_bus).ok();
}
/// Returns a receiver for user-change events.
///
/// If no bus has been installed, the receiver comes from a fresh capacity-1
/// broadcast channel whose sender is dropped immediately.
pub fn subscribe_user_changes(
    &self,
) -> tokio::sync::broadcast::Receiver<crate::control::security::buses::UserChanged> {
    if let Some(bus) = self.uc_bus.get() {
        return bus.subscribe();
    }
    let (sender, receiver) = tokio::sync::broadcast::channel(1);
    drop(sender);
    receiver
}
/// Returns a receiver for session-invalidation events.
///
/// If no bus has been installed, the receiver comes from a fresh capacity-1
/// broadcast channel whose sender is dropped immediately.
pub fn subscribe_session_invalidation(
    &self,
) -> tokio::sync::broadcast::Receiver<crate::control::security::buses::SessionInvalidated> {
    self.si_bus.get().map_or_else(
        || tokio::sync::broadcast::channel(1).1,
        |bus| bus.subscribe(),
    )
}
/// Increments the version counter for `user_id` and returns the new value.
///
/// An existing counter is bumped under the read lock; the write lock is only
/// taken when the user has no counter yet, and `entry` covers the case where
/// another caller created it in between.
///
/// # Errors
///
/// Fails if the `versions` lock is poisoned.
pub(in crate::control::security::credential) fn bump_version(
    &self,
    user_id: u64,
) -> crate::Result<u64> {
    use std::sync::atomic::Ordering;

    // Fast path: the read guard is released at the end of this block.
    let bumped = {
        let versions = read_lock(&self.versions)?;
        versions
            .get(&user_id)
            .map(|counter| counter.fetch_add(1, Ordering::Relaxed) + 1)
    };
    if let Some(version) = bumped {
        return Ok(version);
    }

    // Slow path: create the counter if it is still missing, then bump it.
    let mut versions = write_lock(&self.versions)?;
    let counter = versions
        .entry(user_id)
        .or_insert_with(|| Arc::new(AtomicU64::new(0)));
    Ok(counter.fetch_add(1, Ordering::Relaxed) + 1)
}
/// Returns the current version counter for `user_id`, or 0 if the user has
/// no counter yet.
///
/// A poisoned lock is recovered via `into_inner` rather than reported.
pub fn current_version(&self, user_id: u64) -> u64 {
    let versions = self
        .versions
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    versions
        .get(&user_id)
        .map_or(0, |counter| counter.load(std::sync::atomic::Ordering::Relaxed))
}
/// Persists a changed user record and announces the change.
///
/// Steps, in order: persist the record, bump the user's version counter,
/// publish `UserChanged` if a user-change bus is installed, and — when
/// `invalidation` carries a reason and a session-invalidation bus is
/// installed — publish `SessionInvalidated`.
///
/// # Errors
///
/// Returns early, before anything is published, if persisting or bumping
/// the version fails.
pub(in crate::control::security::credential) fn commit_user_mutation(
    &self,
    record: &mut UserRecord,
    invalidation: Option<crate::control::security::buses::SessionInvalidationReason>,
) -> crate::Result<()> {
    use crate::control::security::buses::{SessionInvalidated, UserChanged};

    let user_id = record.user_id;
    self.persist_user(record)?;
    self.bump_version(user_id)?;

    if let Some(change_bus) = self.uc_bus.get() {
        change_bus.publish(UserChanged { user_id });
    }
    if let (Some(reason), Some(invalidation_bus)) = (invalidation, self.si_bus.get()) {
        invalidation_bus.publish(SessionInvalidated { user_id, reason });
    }
    Ok(())
}
/// Creates the superuser `username`, or resets it if it already exists.
///
/// An existing account gets the new password material, is marked active and
/// superuser, has `must_change_password` cleared, gains `Role::Superuser` if
/// missing, and has its password timestamps refreshed. A new account is
/// created in tenant 0 with a freshly allocated user id and only the
/// `Superuser` role.
///
/// # Errors
///
/// Fails if password hashing fails, a lock is poisoned, id allocation fails,
/// or the record cannot be persisted.
pub fn bootstrap_superuser(&self, username: &str, password: &str) -> crate::Result<()> {
    // Derive all password material before taking the users lock.
    let salt = generate_scram_salt();
    let scram_salted_password = compute_scram_salted_password(password, &salt);
    let password_hash = hash_password_argon2(password, &self.argon2_config)?;

    let mut users = write_lock(&self.users)?;
    if let Some(existing) = users.get_mut(username) {
        existing.password_hash = password_hash;
        existing.scram_salt = salt;
        existing.scram_salted_password = scram_salted_password;
        existing.is_superuser = true;
        existing.is_active = true;
        existing.must_change_password = false;
        existing.password_changed_at = now_secs();
        // The password was just replaced, so its expiry restarts too; keeping
        // the old timestamp would leave a previously expired account expired.
        existing.password_expires_at = self.compute_expiry();
        if !existing.roles.contains(&Role::Superuser) {
            existing.roles.push(Role::Superuser);
        }
        // NOTE(review): this persists directly, so no version bump or
        // `UserChanged` event is emitted for the reset — confirm that is
        // intended for bootstrap.
        self.persist_user(existing)?;
    } else {
        let user_id = self.alloc_user_id()?;
        let now = now_secs();
        let mut record = UserRecord {
            user_id,
            username: username.to_string(),
            tenant_id: TenantId::new(0),
            password_hash,
            scram_salt: salt,
            scram_salted_password,
            roles: vec![Role::Superuser],
            is_superuser: true,
            is_active: true,
            is_service_account: false,
            created_at: now,
            updated_at: now,
            password_expires_at: self.compute_expiry(),
            must_change_password: false,
            password_changed_at: now,
            default_database_id: 0,
            accessible_databases: vec![],
        };
        // Persist first so a failed write does not leave an in-memory-only user.
        self.persist_user(&mut record)?;
        users.insert(username.to_string(), record);
    }
    Ok(())
}
}