use crate::auth::sessions::{
invalidate_sessions, InvalidationOutcome, SessionInvalidationReason, SessionTarget,
};
use crate::auth::users::hash_password;
use crate::auth::Role;
use crate::error::{Error, Result};
use crate::orm::Db;
/// Result of [`reset_password`].
#[derive(Debug, Clone)]
pub enum ResetOutcome {
/// The new password hash was stored. `revoked_session_count` is the number of
/// session ids listed in the outcome returned by `invalidate_sessions`.
Ok { revoked_session_count: usize },
/// No `rustio_users` row has the requested id; nothing was written.
UnknownTarget,
}
/// Result of [`unlock`].
#[derive(Debug, Clone)]
pub enum UnlockOutcome {
/// The lock columns were cleared. `previously_locked` is `true` when, just
/// before clearing, `locked_until` lay in the future or `failed_login_count`
/// was greater than zero.
Ok { previously_locked: bool },
/// No `rustio_users` row has the requested id; nothing was written.
UnknownTarget,
}
/// Result of [`disable_mfa`].
#[derive(Debug, Clone)]
pub enum DisableMfaOutcome {
/// The MFA columns were cleared and the backup codes deleted.
Ok {
// Value of `mfa_enabled` read before the update (NULL counts as `false`).
was_enabled: bool,
// Number of rows removed from `rustio_mfa_backup_codes`.
deleted_backup_codes: usize,
// Number of session ids listed in the `invalidate_sessions` outcome.
revoked_session_count: usize,
},
/// No `rustio_users` row has the requested id; nothing was written.
UnknownTarget,
}
/// Result of [`promote`].
#[derive(Debug, Clone)]
pub enum PromoteOutcome {
/// The role column was updated and the user's sessions were invalidated.
Ok {
// Role parsed from the row before the update.
previous_role: Role,
// Role that was written.
new_role: Role,
// Number of session ids listed in the `invalidate_sessions` outcome.
revoked_session_count: usize,
},
/// The stored role already equals the requested one; nothing was written.
NoChange {
current_role: Role,
},
/// No `rustio_users` row has the requested id.
UnknownTarget,
/// The target is an administrator, the requested role is not, and no other
/// active administrator exists, so the change was refused.
SoleAdministratorDemoteRefused,
}
/// Result of [`emergency_access`].
#[derive(Debug, Clone)]
pub enum EmergencyAccessOutcome {
/// A reset token row was inserted.
Ok {
// `id` returned by the insert into `rustio_password_reset_tokens`.
token_id: i64,
// `/admin/reset-password/` followed by the raw (unhashed) token.
url_path: String,
// Expiry written to the token row.
expires_at: chrono::DateTime<chrono::Utc>,
},
/// No `rustio_users` row has the requested id.
UnknownTarget,
/// The user exists but its `is_active` flag is false; no token was created.
InactiveTarget,
}
/// Fetches the `is_active` flag of the `rustio_users` row whose id is `user_id`.
///
/// The return value carries two facts at once:
/// * `Ok(None)` — no row has that id;
/// * `Ok(Some(flag))` — the row exists and `flag` is its `is_active` column.
///
/// # Errors
/// Database failures are converted with `Error::from` and returned.
async fn target_exists(db: &Db, user_id: i64) -> Result<Option<bool>> {
    // Only one column is selected, so decode it directly as a scalar.
    let is_active: Option<bool> =
        sqlx::query_scalar("SELECT is_active FROM rustio_users WHERE id = $1")
            .bind(user_id)
            .fetch_optional(db.pool())
            .await
            .map_err(Error::from)?;
    Ok(is_active)
}
pub async fn reset_password(
db: &Db,
target_user_id: i64,
new_password: &str,
) -> Result<ResetOutcome> {
if target_exists(db, target_user_id).await?.is_none() {
return Ok(ResetOutcome::UnknownTarget);
}
let hash = hash_password(new_password)?;
let mut tx = db.pool().begin().await.map_err(Error::from)?;
sqlx::query(
"UPDATE rustio_users \
SET password_hash = $1, \
password_changed_at = NOW(), \
must_change_password = TRUE \
WHERE id = $2",
)
.bind(&hash)
.bind(target_user_id)
.execute(&mut *tx)
.await
.map_err(Error::from)?;
tx.commit().await.map_err(Error::from)?;
let outcome = invalidate_sessions(
db,
SessionTarget::User {
user_id: target_user_id,
},
SessionInvalidationReason::PasswordResetByOther,
)
.await?;
Ok(ResetOutcome::Ok {
revoked_session_count: outcome.revoked_session_ids.len(),
})
}
/// Clears the lockout state of `target_user_id`.
///
/// Returns `UnlockOutcome::UnknownTarget` when no `rustio_users` row has that
/// id. Otherwise `locked_until` is set to NULL and `failed_login_count` to 0.
/// The reported `previously_locked` flag is true when, before clearing,
/// `locked_until` was in the future or `failed_login_count` was above zero.
///
/// # Errors
/// Database failures are converted with `Error::from` and returned.
pub async fn unlock(db: &Db, target_user_id: i64) -> Result<UnlockOutcome> {
    const STATUS_SQL: &str = "SELECT (locked_until IS NOT NULL AND locked_until > NOW()) \
         OR failed_login_count > 0 \
         FROM rustio_users WHERE id = $1";
    const CLEAR_SQL: &str = "UPDATE rustio_users \
         SET locked_until = NULL, failed_login_count = 0 \
         WHERE id = $1";

    match target_exists(db, target_user_id).await? {
        None => return Ok(UnlockOutcome::UnknownTarget),
        Some(_) => {}
    }

    // Capture the state before it is wiped so the caller can report it.
    let previously_locked: bool = sqlx::query_scalar(STATUS_SQL)
        .bind(target_user_id)
        .fetch_one(db.pool())
        .await
        .map_err(Error::from)?;

    sqlx::query(CLEAR_SQL)
        .bind(target_user_id)
        .execute(db.pool())
        .await
        .map_err(Error::from)?;

    Ok(UnlockOutcome::Ok { previously_locked })
}
pub async fn disable_mfa(db: &Db, target_user_id: i64) -> Result<DisableMfaOutcome> {
if target_exists(db, target_user_id).await?.is_none() {
return Ok(DisableMfaOutcome::UnknownTarget);
}
let was_enabled: bool =
sqlx::query_scalar("SELECT COALESCE(mfa_enabled, FALSE) FROM rustio_users WHERE id = $1")
.bind(target_user_id)
.fetch_one(db.pool())
.await
.map_err(Error::from)?;
let mut tx = db.pool().begin().await.map_err(Error::from)?;
sqlx::query(
"UPDATE rustio_users SET \
mfa_enabled = FALSE, \
mfa_secret_ciphertext = NULL, \
mfa_secret_key_id = NULL, \
mfa_last_used_step = NULL \
WHERE id = $1",
)
.bind(target_user_id)
.execute(&mut *tx)
.await
.map_err(Error::from)?;
let deleted: u64 = sqlx::query("DELETE FROM rustio_mfa_backup_codes WHERE user_id = $1")
.bind(target_user_id)
.execute(&mut *tx)
.await
.map_err(Error::from)?
.rows_affected();
tx.commit().await.map_err(Error::from)?;
let outcome = invalidate_sessions(
db,
SessionTarget::User {
user_id: target_user_id,
},
SessionInvalidationReason::MfaDisabledByOther,
)
.await?;
Ok(DisableMfaOutcome::Ok {
was_enabled,
deleted_backup_codes: deleted as usize,
revoked_session_count: outcome.revoked_session_ids.len(),
})
}
/// Changes the stored role of `target_user_id` to `new_role` and revokes that
/// user's sessions.
///
/// Outcomes:
/// * `UnknownTarget` — no `rustio_users` row has that id.
/// * `NoChange` — the stored role already equals `new_role`; nothing is written
///   and no sessions are revoked.
/// * `SoleAdministratorDemoteRefused` — the target is an administrator,
///   `new_role` is not, and no other active administrator exists.
/// * `Ok` — the role was updated and the user's sessions were invalidated with
///   reason `RoleChangedByOther`.
///
/// # Errors
/// Propagates failures from `invalidate_sessions`; database errors are
/// converted with `Error::from`.
pub async fn promote(db: &Db, target_user_id: i64, new_role: Role) -> Result<PromoteOutcome> {
    let row: Option<(String, bool)> =
        sqlx::query_as("SELECT role, is_active FROM rustio_users WHERE id = $1")
            .bind(target_user_id)
            .fetch_optional(db.pool())
            .await
            .map_err(Error::from)?;
    // `is_active` is selected but not consulted: the role change applies to
    // inactive users as well.
    let (current_role_str, _is_active) = match row {
        Some(r) => r,
        None => return Ok(PromoteOutcome::UnknownTarget),
    };

    // A stored role string that `parse_role` does not recognise is treated as
    // `Role::User`.
    let current_role = parse_role(&current_role_str).unwrap_or(Role::User);
    if current_role == new_role {
        return Ok(PromoteOutcome::NoChange { current_role });
    }

    // Refuse to demote the last active administrator.
    // NOTE(review): this count and the UPDATE below are separate statements on
    // the pool, so two concurrent demotions could both pass the check. Consider
    // a transaction with row locking if that matters.
    if current_role == Role::Administrator && new_role != Role::Administrator {
        let other_admins: i64 = sqlx::query_scalar(
            "SELECT COUNT(*) FROM rustio_users \
             WHERE role = 'administrator' AND is_active = TRUE AND id <> $1",
        )
        .bind(target_user_id)
        .fetch_one(db.pool())
        .await
        .map_err(Error::from)?;
        if other_admins == 0 {
            return Ok(PromoteOutcome::SoleAdministratorDemoteRefused);
        }
    }

    sqlx::query("UPDATE rustio_users SET role = $1 WHERE id = $2")
        .bind(role_as_str(new_role))
        .bind(target_user_id)
        .execute(db.pool())
        .await
        .map_err(Error::from)?;

    // Existing sessions are revoked after the role change.
    let outcome = invalidate_sessions(
        db,
        SessionTarget::User {
            user_id: target_user_id,
        },
        SessionInvalidationReason::RoleChangedByOther,
    )
    .await?;

    Ok(PromoteOutcome::Ok {
        previous_role: current_role,
        new_role,
        revoked_session_count: outcome.revoked_session_ids.len(),
    })
}
/// Issues a short-lived password-reset token for `target_user_id`.
///
/// Returns `UnknownTarget` when no `rustio_users` row has that id and
/// `InactiveTarget` when the row's `is_active` flag is false. Otherwise a fresh
/// token is generated, its storage hash is inserted into
/// `rustio_password_reset_tokens`, and the raw token is handed back only as
/// part of `url_path`.
///
/// `ttl_minutes` is clamped to the range 1..=60 before the expiry is computed.
///
/// # Errors
/// Database failures are converted with `Error::from` and returned.
pub async fn emergency_access(
    db: &Db,
    target_user_id: i64,
    ttl_minutes: i64,
) -> Result<EmergencyAccessOutcome> {
    const INSERT_SQL: &str = "INSERT INTO rustio_password_reset_tokens \
         (user_id, token_hash, expires_at) \
         VALUES ($1, $2, $3) \
         RETURNING id";

    match target_exists(db, target_user_id).await? {
        None => return Ok(EmergencyAccessOutcome::UnknownTarget),
        Some(false) => return Ok(EmergencyAccessOutcome::InactiveTarget),
        Some(true) => {}
    }

    let lifetime_minutes = ttl_minutes.clamp(1, 60);
    let expires_at = chrono::Utc::now() + chrono::Duration::minutes(lifetime_minutes);

    // Only the hash is persisted; the raw token leaves this function via the URL.
    let token = crate::auth::sessions::random_token();
    let token_hash = crate::auth::sessions::hash_token_for_storage(&token);

    let token_id: i64 = sqlx::query_scalar(INSERT_SQL)
        .bind(target_user_id)
        .bind(&token_hash)
        .bind(expires_at)
        .fetch_one(db.pool())
        .await
        .map_err(Error::from)?;

    let url_path = format!("/admin/reset-password/{token}");
    Ok(EmergencyAccessOutcome::Ok {
        token_id,
        url_path,
        expires_at,
    })
}
/// Maps a stored role string to its `Role`.
///
/// The comparison is exact (case-sensitive); any string other than the five
/// known names yields `None`.
fn parse_role(s: &str) -> Option<Role> {
    let role = match s {
        "administrator" => Role::Administrator,
        "developer" => Role::Developer,
        "staff" => Role::Staff,
        "supervisor" => Role::Supervisor,
        "user" => Role::User,
        _ => return None,
    };
    Some(role)
}
/// Returns the lowercase string under which `role` is written to the `role`
/// column. Each string is one that `parse_role` maps back to the same variant.
fn role_as_str(role: Role) -> &'static str {
    match role {
        Role::Developer => "developer",
        Role::Administrator => "administrator",
        Role::Supervisor => "supervisor",
        Role::Staff => "staff",
        Role::User => "user",
    }
}
// Anonymous re-import of `InvalidationOutcome`, which is brought in at the top
// of the file but never named by the code visible here.
// NOTE(review): presumably this exists only to keep that import from being
// reported as unused — confirm, and consider removing both if so.
#[allow(unused_imports)]
use InvalidationOutcome as _;
/// Builds a random password of exactly `len` characters.
///
/// Every character is drawn independently from a 55-symbol alphabet of ASCII
/// letters and digits that leaves out `I`, `O`, `i`, `l`, `o`, `0` and `1`
/// (presumably to avoid look-alike glyphs). `len == 0` yields an empty string.
pub fn generate_temp_password(len: usize) -> String {
    use rand::Rng;
    const ALPHABET: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjkmnpqrstuvwxyz23456789";

    let mut rng = rand::thread_rng();
    // Every symbol is ASCII, so `len` bytes is the exact final size.
    let mut password = String::with_capacity(len);
    for _ in 0..len {
        let pick = rng.gen_range(0..ALPHABET.len());
        password.push(char::from(ALPHABET[pick]));
    }
    password
}
/// Generates a new version-7 UUID and returns it in hyphenated text form.
pub fn fresh_correlation_id() -> String {
    let id = uuid::Uuid::now_v7();
    id.hyphenated().to_string()
}