pub mod auth_routes;
pub mod bearer_auth;
pub mod challenge;
pub mod extractors;
pub mod form_routes;
pub mod login_required;
pub mod mailer;
pub mod password_validation;
pub mod session_user;
pub mod throttle;
pub mod token;
pub use mailer::{AuthMailError, AuthMailer, ConsoleMailer, MailKind, OutgoingMail};
pub use password_validation::{
CommonPasswordValidator, MinLengthValidator, NumericPasswordValidator, PasswordContext,
PasswordPolicy, PasswordValidator, UserAttributeSimilarityValidator, validate_password,
};
pub use bearer_auth::{BearerAuthentication, parse_bearer_header};
pub use challenge::{
AuthChallenge, reset_password, start_email_verification, start_password_reset, verify_email,
};
pub use extractors::{CurrentIdentity, OptionalIdentity, resolve_identity};
pub use login_required::{
LoggedIn, LoginRequired, LoginRequiredLayer, current_session_user_id, current_session_user_pk,
login_required, login_required_html, resolve_user as current_user_as,
};
pub use session_user::{
OptionalUser, SessionAuthentication, User, current_user, login, login_with_request,
user_context_layer,
};
pub use throttle::{
Throttle, ThrottleConfig, email_action_throttle_check, login_throttle_check,
login_throttle_clear, register_throttle_check,
};
pub use token::{AuthToken, PlaintextToken, TOKEN_PREFIX, digest_token};
/// Test hook: returns whatever `auth_routes::openapi_paths` produces for
/// `prefix`, as `(path, spec)` pairs.
///
/// Marked `#[doc(hidden)]`, so it is excluded from the rendered API docs.
#[doc(hidden)]
pub fn auth_routes_openapi_for_test(prefix: &str) -> Vec<(String, serde_json::Value)> {
auth_routes::openapi_paths(prefix)
}
use std::marker::PhantomData;
use argon2::password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
use argon2::{Algorithm, Argon2, Params, Version, password_hash::rand_core::OsRng};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use umbral::prelude::*;
/// Contract a model must satisfy to act as the user type of [`AuthPlugin`].
///
/// Implementors provide access to the primary key, the login name and the
/// stored password hash. The three permission flags have default
/// implementations, so only models that track them need to override.
pub trait UserModel: Model + Send + Sync + 'static {
/// Returns the primary key of this user.
fn id(&self) -> <Self as Model>::PrimaryKey;
/// Returns the primary key rendered via `to_string()`.
fn id_string(&self) -> String {
self.id().to_string()
}
/// Returns the login name. `authenticate` looks users up by the
/// `username` column.
fn username(&self) -> &str;
/// Returns the stored password hash that `authenticate` verifies against.
fn password_hash(&self) -> &str;
/// Replaces the in-memory password hash. `set_password` calls this after
/// the database row has been updated.
fn set_password_hash(&mut self, hash: String);
/// Whether the account may log in. Defaults to `true`.
fn is_active(&self) -> bool {
true
}
/// Whether the account has the staff flag. Defaults to `false`.
fn is_staff(&self) -> bool {
false
}
/// Whether the account has the superuser flag. Defaults to `false`.
fn is_superuser(&self) -> bool {
false
}
}
// Built-in user row; the default `U` of `AuthPlugin` and the type returned by
// `create_user` / `create_superuser`.
//
// Plain `//` comments are used in this item so the input seen by the
// `umbral::orm::Model` derive stays exactly as it was.
#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize, umbral::orm::Model)]
pub struct AuthUser {
// Primary key. `insert_user` passes `0` here when creating a row.
// NOTE(review): presumably the ORM/database assigns the real id — confirm.
pub id: i64,
// Login name; `authenticate` filters on this column.
#[umbral(unique)]
pub username: String,
// Contact address supplied at creation time.
#[umbral(noedit, unique)]
pub email: String,
// Encoded password hash as produced by `hash_password`; never the plaintext.
#[umbral(noform)]
pub password_hash: String,
// `insert_user` creates every account with this set to `true`.
pub is_active: bool,
// Staff flag; `create_superuser` sets it to `true`.
pub is_staff: bool,
// Superuser flag; `create_superuser` sets it to `true`.
pub is_superuser: bool,
// Timestamp taken by `insert_user` just before the password is hashed.
pub date_joined: DateTime<Utc>,
// `None` on creation.
pub last_login: Option<DateTime<Utc>>,
// `None` on creation.
// NOTE(review): presumably filled in by the email verification flow — confirm.
pub email_verified_at: Option<DateTime<Utc>>,
}
/// Maps each [`UserModel`] accessor straight onto the matching [`AuthUser`]
/// field, overriding the trait's default permission flags with the stored
/// values.
impl UserModel for AuthUser {
    fn id(&self) -> <Self as umbral::orm::Model>::PrimaryKey {
        self.id
    }

    fn username(&self) -> &str {
        self.username.as_str()
    }

    fn password_hash(&self) -> &str {
        self.password_hash.as_str()
    }

    fn set_password_hash(&mut self, new_hash: String) {
        self.password_hash = new_hash;
    }

    fn is_active(&self) -> bool {
        self.is_active
    }

    fn is_staff(&self) -> bool {
        self.is_staff
    }

    fn is_superuser(&self) -> bool {
        self.is_superuser
    }
}
/// Holds the mailer handed to the [`AuthPlugin`] builder until `on_ready`
/// takes it out and installs it.
struct MailerSlot(std::sync::Mutex<Option<std::sync::Arc<dyn mailer::AuthMailer>>>);

/// Hand-written `Debug` that prints a fixed placeholder rather than the slot
/// contents, which lets `AuthPlugin` derive `Debug`.
impl std::fmt::Debug for MailerSlot {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(out, "MailerSlot(..)")
    }
}
/// Builder-style configuration for the auth plugin, generic over the user
/// model `U` (the built-in [`AuthUser`] by default).
///
/// The password policy, throttle configuration, mailer and
/// verified-email flag are only stored here; they are installed globally
/// when the plugin's `on_ready` hook runs.
#[derive(Debug)]
pub struct AuthPlugin<U: UserModel = AuthUser> {
/// Optional name set through the `user_model_name` builder method.
pub user_model_name: Option<String>,
/// Prefix for the JSON auth routes; `None` means they are not mounted.
/// May hold `JSON_PREFIX_SENTINEL`, which is resolved to
/// `{api_base}/auth` when the routes are built.
pub default_routes_prefix: Option<String>,
/// Prefix for the form routes; `None` means they are not mounted.
pub form_routes_prefix: Option<String>,
/// When `true`, `wrap_router` adds the `user_context_layer` middleware.
pub user_in_templates: bool,
// Policy chosen on the builder; taken out by `on_ready`, which falls back
// to `PasswordPolicy::default()` when nothing was set.
password_policy: std::sync::Mutex<Option<PasswordPolicy>>,
// Limits passed to `throttle::install` by `on_ready`.
throttle_config: throttle::ThrottleConfig,
// Mailer chosen on the builder; installed by `on_ready` when present.
mailer: MailerSlot,
// Copied into the `REQUIRE_VERIFIED` global by `on_ready`.
require_verified: bool,
// Ties the plugin to its user model without storing a value of it.
_u: PhantomData<U>,
}
impl<U: UserModel> Default for AuthPlugin<U> {
    /// Starts with nothing enabled: no routes, no template user, no explicit
    /// password policy or mailer, default throttle limits, and verified
    /// email not required.
    fn default() -> Self {
        // The two lock-protected slots start out empty.
        let password_policy = std::sync::Mutex::new(None);
        let mailer = MailerSlot(std::sync::Mutex::new(None));

        Self {
            user_model_name: None,
            default_routes_prefix: None,
            form_routes_prefix: None,
            user_in_templates: false,
            password_policy,
            throttle_config: throttle::ThrottleConfig::default(),
            mailer,
            require_verified: false,
            _u: PhantomData,
        }
    }
}
/// Builder methods available for every user model.
impl<U: UserModel> AuthPlugin<U> {
    /// Records a name for the user model.
    pub fn user_model_name(self, name: impl Into<String>) -> Self {
        Self {
            user_model_name: Some(name.into()),
            ..self
        }
    }

    /// Turns on the middleware that `wrap_router` installs.
    pub fn with_user_in_templates(self) -> Self {
        Self {
            user_in_templates: true,
            ..self
        }
    }

    /// Uses `policy` as the password policy installed at startup.
    pub fn password_validators(self, policy: PasswordPolicy) -> Self {
        Self {
            password_policy: std::sync::Mutex::new(Some(policy)),
            ..self
        }
    }

    /// Installs a policy made of a minimum length of `n` plus the
    /// common-password, all-numeric and user-attribute-similarity validators.
    pub fn min_password_length(self, n: usize) -> Self {
        let policy = PasswordPolicy::new(vec![
            Box::new(MinLengthValidator(n)),
            Box::new(CommonPasswordValidator),
            Box::new(NumericPasswordValidator),
            Box::new(UserAttributeSimilarityValidator::default()),
        ]);
        self.password_validators(policy)
    }

    /// Installs an empty policy, so no validator runs.
    pub fn disable_password_validation(self) -> Self {
        Self {
            password_policy: std::sync::Mutex::new(Some(PasswordPolicy::empty())),
            ..self
        }
    }

    /// Sets the login throttle limit and window.
    pub fn login_throttle(mut self, max: usize, window: std::time::Duration) -> Self {
        let cfg = &mut self.throttle_config;
        cfg.login_max = max;
        cfg.login_window = window;
        self
    }

    /// Sets the registration throttle limit and window.
    pub fn register_throttle(mut self, max: usize, window: std::time::Duration) -> Self {
        let cfg = &mut self.throttle_config;
        cfg.register_max = max;
        cfg.register_window = window;
        self
    }

    /// Sets the email-action throttle limit and window.
    pub fn email_action_throttle(mut self, max: usize, window: std::time::Duration) -> Self {
        let cfg = &mut self.throttle_config;
        cfg.email_action_max = max;
        cfg.email_action_window = window;
        self
    }

    /// Switches throttling off.
    pub fn disable_throttle(mut self) -> Self {
        self.throttle_config.enabled = false;
        self
    }

    /// Stores `m` as the mailer to install at startup.
    ///
    /// # Panics
    ///
    /// Panics if the internal mailer mutex is poisoned.
    pub fn mailer(mut self, m: impl mailer::AuthMailer + 'static) -> Self {
        // `self` is owned here, so the slot can be reached without locking.
        let slot = self.mailer.0.get_mut().expect("mailer slot poisoned");
        *slot = Some(std::sync::Arc::new(m));
        self
    }

    /// Resolves the JSON route prefix: `None` when JSON routes are off, the
    /// API base plus `/auth` for the sentinel, otherwise the prefix as given.
    fn json_prefix(&self) -> Option<String> {
        let configured = self.default_routes_prefix.as_deref()?;
        let resolved = if configured == JSON_PREFIX_SENTINEL {
            format!("{}/auth", umbral::web::api_base())
        } else {
            configured.to_owned()
        };
        Some(resolved)
    }
}
/// Process-wide copy of the plugin's `require_verified` flag, written once by
/// the plugin's `on_ready` hook.
static REQUIRE_VERIFIED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Reports whether a verified email is required; `false` until the flag has
/// been set.
pub(crate) fn verified_email_required() -> bool {
    REQUIRE_VERIFIED.get().copied().unwrap_or(false)
}
/// Placeholder stored in `default_routes_prefix` by `with_default_routes`.
/// `json_prefix` replaces it with `{umbral::web::api_base()}/auth` when the
/// routes are built.
// NOTE(review): the NUL bytes presumably keep it from colliding with any real
// path prefix — confirm.
const JSON_PREFIX_SENTINEL: &str = "\0auto-api-base\0";
/// Builder methods that exist only for the built-in [`AuthUser`] model.
impl AuthPlugin<AuthUser> {
    /// Mounts the JSON routes under the API base, resolved when the routes
    /// are built.
    pub fn with_default_routes(self) -> Self {
        self.with_default_routes_at(JSON_PREFIX_SENTINEL)
    }

    /// Mounts the JSON routes under an explicit `prefix`.
    pub fn with_default_routes_at(self, prefix: impl Into<String>) -> Self {
        Self {
            default_routes_prefix: Some(prefix.into()),
            ..self
        }
    }

    /// Records that a verified email is required; published globally by
    /// `on_ready`.
    pub fn require_verified_email(self) -> Self {
        Self {
            require_verified: true,
            ..self
        }
    }

    /// Mounts the form routes under `/auth`.
    pub fn with_form_routes(self) -> Self {
        self.with_form_routes_at("/auth")
    }

    /// Mounts the form routes under an explicit `prefix`.
    pub fn with_form_routes_at(self, prefix: impl Into<String>) -> Self {
        Self {
            form_routes_prefix: Some(prefix.into()),
            ..self
        }
    }
}
/// Hooks the auth configuration into the host application's plugin system.
impl<U: UserModel> Plugin for AuthPlugin<U> {
    fn name(&self) -> &'static str {
        "auth"
    }

    /// Always registers `U`; the token and challenge tables are added only
    /// when `U` is the built-in [`AuthUser`].
    fn models(&self) -> Vec<umbral::migrate::ModelMeta> {
        let uses_builtin_user = std::any::TypeId::of::<U>() == std::any::TypeId::of::<AuthUser>();

        let mut registered = Vec::new();
        registered.push(umbral::migrate::ModelMeta::for_::<U>());
        if uses_builtin_user {
            registered.extend([
                umbral::migrate::ModelMeta::for_::<AuthToken>(),
                umbral::migrate::ModelMeta::for_::<AuthChallenge>(),
            ]);
        }
        registered
    }

    /// Points at the `templates` directory next to this crate's manifest.
    fn templates_dirs(&self) -> Vec<std::path::PathBuf> {
        let crate_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
        vec![crate_root.join("templates")]
    }

    fn commands(&self) -> Vec<Box<dyn umbral::cli::PluginCommand>> {
        let create_superuser: Box<dyn umbral::cli::PluginCommand> =
            Box::new(CreateSuperuserCommand);
        vec![create_superuser]
    }

    /// JSON routes (when a prefix is configured) merged with the form routes
    /// (when a form prefix is configured).
    fn routes(&self) -> umbral::web::Router {
        let json_router = match self.json_prefix() {
            None => umbral::web::Router::new(),
            Some(prefix) => auth_routes::build_router(&prefix),
        };
        match self.form_routes_prefix.as_ref() {
            Some(form_prefix) => json_router.merge(form_routes::build_router(form_prefix)),
            None => json_router,
        }
    }

    /// Declared route specs, in the same order `routes` mounts them.
    fn route_paths(&self) -> Vec<umbral::routes::RouteSpec> {
        let mut specs = self
            .json_prefix()
            .map(|prefix| auth_routes::declared_routes(&prefix))
            .unwrap_or_default();
        if let Some(form_prefix) = self.form_routes_prefix.as_ref() {
            specs.extend(form_routes::declared_routes(form_prefix));
        }
        specs
    }

    /// OpenAPI entries for the JSON routes only.
    fn openapi_paths(&self) -> Vec<(String, serde_json::Value)> {
        self.json_prefix()
            .map_or_else(Vec::new, |prefix| auth_routes::openapi_paths(&prefix))
    }

    fn wrap_router(&self, router: umbral::web::Router) -> umbral::web::Router {
        if !self.user_in_templates {
            return router;
        }
        router.layer(axum::middleware::from_fn(user_context_layer))
    }

    /// Publishes the builder's settings to the module-level globals, in this
    /// order: password policy, throttle, mailer, verified-email flag.
    fn on_ready(
        &self,
        _ctx: &umbral::plugin::AppContext,
    ) -> Result<(), umbral::plugin::PluginError> {
        // A poisoned lock or an unset policy both fall back to the default.
        let policy = match self.password_policy.lock() {
            Ok(mut slot) => slot.take().unwrap_or_default(),
            Err(_) => PasswordPolicy::default(),
        };
        password_validation::install_policy(policy);

        throttle::install(throttle::AuthThrottle::from_config(self.throttle_config));

        // Only install a mailer when one was configured and the lock is healthy.
        let configured_mailer = self.mailer.0.lock().ok().and_then(|mut slot| slot.take());
        if let Some(m) = configured_mailer {
            crate::mailer::install_mailer(m);
        }

        // The first value wins; a later `set` is ignored on purpose.
        let _ = REQUIRE_VERIFIED.set(self.require_verified);
        Ok(())
    }
}
/// Errors returned by the functions in this module.
#[derive(Debug)]
pub enum AuthError {
/// Hashing a password or parsing/verifying a stored hash failed.
PasswordHash(argon2::password_hash::Error),
/// An error reported by `sqlx`.
Sqlx(sqlx::Error),
/// An error reported by the ORM write layer.
Write(umbral::orm::write::WriteError),
/// `authenticate` found no active user with that name, or the password
/// did not match.
InvalidCredentials,
/// The password was rejected; carries the reasons, which `Display` joins
/// with spaces.
WeakPassword(Vec<String>),
/// The blocking task used for hashing/verifying could not be joined.
Runtime(String),
/// A session operation failed (for example in `logout`).
Session(String),
/// Template-related failure, carried as a message.
Template(String),
/// Mail-related failure, carried as a message.
Mail(String),
/// The challenge was invalid or has expired.
InvalidChallenge,
}
impl std::fmt::Display for AuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthError::PasswordHash(e) => write!(f, "umbral-auth: password hash: {e}"),
AuthError::Sqlx(e) => write!(f, "umbral-auth: sqlx: {e}"),
AuthError::Write(e) => write!(f, "umbral-auth: write: {e:?}"),
AuthError::InvalidCredentials => write!(f, "umbral-auth: invalid credentials"),
AuthError::WeakPassword(reasons) => {
write!(f, "umbral-auth: password rejected: {}", reasons.join(" "))
}
AuthError::Runtime(msg) => write!(f, "umbral-auth: blocking task failed: {msg}"),
AuthError::Session(msg) => write!(f, "umbral-auth: session: {msg}"),
AuthError::Template(msg) => write!(f, "umbral-auth: template: {msg}"),
AuthError::Mail(msg) => write!(f, "umbral-auth: mail: {msg}"),
AuthError::InvalidChallenge => write!(f, "umbral-auth: invalid or expired challenge"),
}
}
}
impl std::error::Error for AuthError {}
impl From<argon2::password_hash::Error> for AuthError {
fn from(e: argon2::password_hash::Error) -> Self {
Self::PasswordHash(e)
}
}
impl From<sqlx::Error> for AuthError {
fn from(e: sqlx::Error) -> Self {
Self::Sqlx(e)
}
}
impl From<umbral::orm::write::WriteError> for AuthError {
fn from(e: umbral::orm::write::WriteError) -> Self {
Self::Write(e)
}
}
/// Ends the current session by delegating to `umbral_sessions::logout`.
///
/// # Errors
///
/// Returns [`AuthError::Session`] carrying the session error's message.
pub async fn logout(
    req: &umbral::web::HeaderMap,
    resp: &mut umbral::web::HeaderMap,
) -> Result<(), AuthError> {
    match umbral_sessions::logout(req, resp).await {
        Ok(()) => Ok(()),
        Err(session_err) => Err(AuthError::Session(session_err.to_string())),
    }
}
/// Hashes `plaintext` with a freshly generated random salt and returns the
/// encoded hash string.
///
/// This runs on the calling thread; async callers should use
/// [`hash_password_async`].
///
/// # Errors
///
/// Returns [`AuthError::PasswordHash`] if hashing fails.
pub fn hash_password(plaintext: &str) -> Result<String, AuthError> {
    let hasher = password_hasher();
    let salt = SaltString::generate(&mut OsRng);
    let encoded = hasher.hash_password(plaintext.as_bytes(), &salt)?;
    Ok(encoded.to_string())
}
pub fn verify_password(plaintext: &str, hash: &str) -> Result<bool, AuthError> {
let parsed = PasswordHash::new(hash)?;
match password_hasher().verify_password(plaintext.as_bytes(), &parsed) {
Ok(()) => Ok(true),
Err(argon2::password_hash::Error::Password) => Ok(false),
Err(e) => Err(AuthError::PasswordHash(e)),
}
}
/// Async wrapper around [`hash_password`] that runs the hashing on Tokio's
/// blocking pool.
///
/// # Errors
///
/// Returns [`AuthError::Runtime`] if the blocking task cannot be joined,
/// otherwise whatever [`hash_password`] returns.
pub async fn hash_password_async(plaintext: &str) -> Result<String, AuthError> {
    // The closure must be `'static`, so it gets its own copy of the input.
    let owned = plaintext.to_owned();
    let joined = tokio::task::spawn_blocking(move || hash_password(&owned)).await;
    match joined {
        Ok(hash_result) => hash_result,
        Err(join_err) => Err(AuthError::Runtime(join_err.to_string())),
    }
}
/// Async wrapper around [`verify_password`] that runs the verification on
/// Tokio's blocking pool.
///
/// # Errors
///
/// Returns [`AuthError::Runtime`] if the blocking task cannot be joined,
/// otherwise whatever [`verify_password`] returns.
pub async fn verify_password_async(plaintext: &str, hash: &str) -> Result<bool, AuthError> {
    // The closure must be `'static`, so both inputs are copied into it.
    let (owned_plaintext, owned_hash) = (plaintext.to_owned(), hash.to_owned());
    let joined =
        tokio::task::spawn_blocking(move || verify_password(&owned_plaintext, &owned_hash)).await;
    match joined {
        Ok(verdict) => verdict,
        Err(join_err) => Err(AuthError::Runtime(join_err.to_string())),
    }
}
/// Builds the Argon2id (version 0x13) hasher used for every hash and verify
/// in this module.
fn password_hasher() -> Argon2<'static> {
    // Arguments are memory cost, time cost, parallelism, then output length
    // (`None` keeps the crate default).
    let params =
        Params::new(19_456, 2, 1, None).expect("hard-coded argon2 params are valid");
    Argon2::new(Algorithm::Argon2id, Version::V0x13, params)
}
/// Creates a regular account: both `is_staff` and `is_superuser` are `false`.
///
/// Delegates to `create_user_with_flags`; see that function for errors.
pub async fn create_user(
username: &str,
email: &str,
plaintext: &str,
) -> Result<AuthUser, AuthError> {
create_user_with_flags(username, email, plaintext, false, false).await
}
/// Creates an account with both `is_staff` and `is_superuser` set to `true`.
///
/// Delegates to `insert_user`; see that function for errors.
pub async fn create_superuser(
username: &str,
email: &str,
plaintext: &str,
) -> Result<AuthUser, AuthError> {
insert_user(username, email, plaintext, true, true).await
}
/// Creates an account with the given `is_staff` / `is_superuser` flags.
///
/// Delegates to `insert_user`; see that function for errors.
pub async fn create_user_with_flags(
username: &str,
email: &str,
plaintext: &str,
is_staff: bool,
is_superuser: bool,
) -> Result<AuthUser, AuthError> {
insert_user(username, email, plaintext, is_staff, is_superuser).await
}
async fn insert_user(
username: &str,
email: &str,
plaintext: &str,
is_staff: bool,
is_superuser: bool,
) -> Result<AuthUser, AuthError> {
let now = chrono::Utc::now();
let hash = hash_password_async(plaintext).await?;
let row = AuthUser::objects()
.create(AuthUser {
id: 0,
username: username.to_string(),
email: email.to_string(),
password_hash: hash,
is_active: true,
is_staff,
is_superuser,
date_joined: now,
last_login: None,
email_verified_at: None,
})
.await?;
Ok(row)
}
pub async fn authenticate<U>(username: &str, plaintext: &str) -> Result<U, AuthError>
where
U: UserModel
+ for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
+ for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
+ umbral::orm::HydrateRelated
+ Unpin,
{
let user: Option<U> = umbral::orm::Manager::<U>::default()
.filter(
umbral::orm::Predicate::<U>::col_eq("username", username)
& umbral::orm::Predicate::<U>::col_eq("is_active", true),
)
.first()
.await?;
let Some(user) = user else {
return Err(AuthError::InvalidCredentials);
};
if !user.is_active() {
return Err(AuthError::InvalidCredentials);
}
if verify_password_async(plaintext, user.password_hash()).await? {
Ok(user)
} else {
Err(AuthError::InvalidCredentials)
}
}
/// Hashes `plaintext`, writes the new hash to the row whose `id` matches
/// `user`, and then updates `user` in memory.
///
/// The in-memory value is only changed after the database update succeeded.
///
/// # Errors
///
/// Returns an [`AuthError`] if hashing or the database update fails.
pub async fn set_password<U>(user: &mut U, plaintext: &str) -> Result<(), AuthError>
where
    U: UserModel,
{
    let new_hash = hash_password_async(plaintext).await?;

    // Single-column patch: only `password_hash` is touched.
    let patch: serde_json::Map<String, serde_json::Value> = [(
        "password_hash".to_string(),
        serde_json::Value::String(new_hash.clone()),
    )]
    .into_iter()
    .collect();

    let this_user = umbral::orm::Predicate::<U>::col_eq("id", user.id());
    umbral::orm::Manager::<U>::default()
        .filter(this_user)
        .update_values(patch)
        .await?;

    user.set_password_hash(new_hash);
    Ok(())
}
/// The `createsuperuser` CLI command, registered through the plugin's
/// `commands()` hook.
#[derive(Debug, Default)]
pub struct CreateSuperuserCommand;
#[async_trait::async_trait]
impl umbral::cli::PluginCommand for CreateSuperuserCommand {
    /// Declares `createsuperuser` with its `--username`, `--email` and
    /// `--noinput` options.
    fn command(&self) -> clap::Command {
        let username = clap::Arg::new("username")
            .long("username")
            .help("Skip the interactive username prompt")
            .value_name("NAME");

        let email = clap::Arg::new("email")
            .long("email")
            .help("Skip the interactive email prompt")
            .value_name("ADDR");

        let noinput = clap::Arg::new("noinput")
            .long("noinput")
            .help(
                "Fail rather than prompt for any missing value. \
                 Reads password from UMBRAL_SUPERUSER_PASSWORD env var.",
            )
            .action(clap::ArgAction::SetTrue);

        clap::Command::new("createsuperuser")
            .about("Create a superuser account (is_staff = is_superuser = true)")
            .arg(username)
            .arg(email)
            .arg(noinput)
    }

    /// Collects username, email and password (from flags, the environment or
    /// interactive prompts), creates the superuser and prints a confirmation.
    async fn run(&self, matches: &clap::ArgMatches) -> Result<(), umbral::cli::CliError> {
        let noinput = matches.get_flag("noinput");

        let username_flag = matches.get_one::<String>("username").cloned();
        let username = resolve_or_prompt(username_flag, "Username", noinput, None)?;

        let email_flag = matches.get_one::<String>("email").cloned();
        let email = resolve_or_prompt(email_flag, "Email", noinput, None)?;

        let password = resolve_password(noinput)?;

        let created = match create_superuser(&username, &email, &password).await {
            Ok(user) => user,
            Err(auth_err) => return Err(Box::new(auth_err)),
        };

        println!(
            "Created superuser `{}` (id = {}) - is_staff = true, is_superuser = true",
            created.username, created.id,
        );
        Ok(())
    }
}
fn resolve_or_prompt(
cli_value: Option<String>,
label: &str,
noinput: bool,
env_var: Option<&str>,
) -> Result<String, umbral::cli::CliError> {
if let Some(v) = cli_value
&& !v.is_empty()
{
return Ok(v);
}
if let Some(key) = env_var
&& let Ok(v) = std::env::var(key)
&& !v.is_empty()
{
return Ok(v);
}
if noinput {
return Err(
format!("umbral createsuperuser: {label} not provided and --noinput is set").into(),
);
}
print!("{label}: ");
use std::io::Write;
std::io::stdout().flush().ok();
let mut s = String::new();
std::io::stdin().read_line(&mut s)?;
let v = s.trim().to_string();
if v.is_empty() {
return Err(format!("umbral createsuperuser: {label} cannot be empty").into());
}
Ok(v)
}
fn resolve_password(noinput: bool) -> Result<String, umbral::cli::CliError> {
if let Ok(v) = std::env::var("UMBRAL_SUPERUSER_PASSWORD")
&& !v.is_empty()
{
return Ok(v);
}
if noinput {
return Err(
"umbral createsuperuser: password not provided (set UMBRAL_SUPERUSER_PASSWORD) \
and --noinput is set"
.into(),
);
}
let first = rpassword::prompt_password("Password: ")?;
if first.is_empty() {
return Err("umbral createsuperuser: password cannot be empty".into());
}
let second = rpassword::prompt_password("Password (again): ")?;
if first != second {
return Err("umbral createsuperuser: passwords do not match".into());
}
Ok(first)
}