use std::sync::{Arc, Mutex};
use axum::Router;
use axum::body::Body;
use axum::http::{Request, StatusCode, header};
use tokio::sync::OnceCell;
use tower::ServiceExt;
use umbral_auth::mailer::{AuthMailError, AuthMailer, OutgoingMail};
use umbral_auth::{AuthPlugin, AuthUser};
use umbral_sessions::SessionsPlugin;
/// In-memory mailer used by the tests: every message handed to `send` is appended to a
/// shared, mutex-guarded list instead of being delivered anywhere.
#[derive(Default, Clone)]
struct Recorder(Arc<Mutex<Vec<OutgoingMail>>>);

#[async_trait::async_trait]
impl AuthMailer for Recorder {
    /// Stores `mail` in the shared list and reports success.
    ///
    /// Panics if the mutex has been poisoned by an earlier panic while it was held.
    async fn send(&self, mail: OutgoingMail) -> Result<(), AuthMailError> {
        let outbox = &self.0;
        // The guard is a temporary, so the lock is released at the end of this statement.
        outbox.lock().unwrap().push(mail);
        Ok(())
    }
}
/// One-shot gate for the async setup performed in `boot`; it stores no value and only
/// records that the setup closure has already run.
static BOOT: OnceCell<()> = OnceCell::const_new();
/// Holds the router built inside `boot`'s setup closure so `boot` can hand it out as
/// `&'static Router`.
static ROUTER: std::sync::OnceLock<Router> = std::sync::OnceLock::new();
/// Builds the shared test application once and returns its router.
///
/// The first caller runs the setup closure guarded by `BOOT`: it loads settings, opens a
/// SQLite pool on a temp-file database, builds the app with `SessionsPlugin` and `AuthPlugin`
/// (form routes enabled, throttle disabled, a `Recorder` as mailer), stores the router in
/// `ROUTER`, and creates the `auth_user`, `auth_challenge`, `auth_token` and `session` tables.
/// Subsequent callers only read the stored router.
///
/// # Panics
/// Panics if any setup step fails (each one is `expect`ed) or if `ROUTER` is empty after setup.
async fn boot() -> &'static Router {
BOOT.get_or_init(|| async {
let settings =
umbral::Settings::from_env().expect("figment defaults always load in a test env");
let tmp = tempfile::tempdir().expect("tempdir");
let db_path = tmp.path().join("umbral_form_surface.sqlite");
// Deliberately leak the TempDir guard so its destructor never runs; presumably the
// directory (and the SQLite file in it) would otherwise be removed when `tmp` is dropped.
std::mem::forget(tmp);
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
// File-backed SQLite pool: up to 5 connections, WAL journal mode, 30 s busy timeout.
let pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(
SqliteConnectOptions::new()
.filename(&db_path)
.create_if_missing(true)
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
.busy_timeout(std::time::Duration::from_secs(30)),
)
.await
.expect("sqlite tempfile pool");
// The recorder is moved into the plugin below and no clone is kept here, so nothing in
// this function can read back the recorded mail.
let rec = Recorder::default();
let app = umbral::App::builder()
.settings(settings)
.database("default", pool)
.plugin(SessionsPlugin::default())
.plugin(
AuthPlugin::<AuthUser>::default()
.with_form_routes()
.disable_throttle()
.mailer(rec),
)
.build()
.expect("App::build should succeed with AuthPlugin + SessionsPlugin + form routes");
let router = app.into_router();
// `.ok()` discards the result of `set`, so an already-populated ROUTER is not an error.
ROUTER.set(router).ok();
// Shadows the pool that was moved into the builder above; presumably this returns the
// "default" database registered there — TODO confirm against `umbral::db::pool`.
let pool = umbral::db::pool();
// The schema is created by hand with raw DDL; NOTE(review): the column lists presumably
// mirror what the auth and sessions plugins expect — verify when those models change.
sqlx::query(
"CREATE TABLE auth_user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
email TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
is_active INTEGER NOT NULL,
is_staff INTEGER NOT NULL,
is_superuser INTEGER NOT NULL,
date_joined TEXT NOT NULL,
last_login TEXT,
email_verified_at TEXT
)",
)
.execute(&pool)
.await
.expect("create auth_user table");
sqlx::query(
"CREATE TABLE auth_challenge (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
purpose TEXT NOT NULL,
secret_hash TEXT NOT NULL,
expires_at TEXT NOT NULL,
attempts INTEGER NOT NULL,
used_at TEXT,
created_at TEXT NOT NULL
)",
)
.execute(&pool)
.await
.expect("create auth_challenge table");
sqlx::query(
"CREATE TABLE auth_token (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
key_hash TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
created_at TEXT NOT NULL,
last_used_at TEXT
)",
)
.execute(&pool)
.await
.expect("create auth_token table");
// Session rows are keyed by a text id and carry a JSON `data` column defaulting to '{}'.
sqlx::query(
"CREATE TABLE session (
id TEXT PRIMARY KEY,
user_id TEXT,
data TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
expires_at TEXT NOT NULL
)",
)
.execute(&pool)
.await
.expect("create session table");
})
.await;
// By this point the setup closure has run (now or on an earlier call), so ROUTER is set.
ROUTER.get().expect("router set during boot")
}
/// Sends a single URL-encoded form `POST` to `uri` through a clone of `router` and returns
/// the response.
///
/// `body` is sent verbatim, so callers must percent-encode it themselves.
/// Panics if the request cannot be built or the service call returns an error.
async fn post_form(router: &Router, uri: &str, body: &str) -> axum::http::Response<Body> {
    let payload = Body::from(body.to_owned());
    let request = Request::post(uri)
        .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
        .body(payload)
        .unwrap();

    // `oneshot` consumes the service, hence the clone of the shared router.
    let service = router.clone();
    service.oneshot(request).await.unwrap()
}
/// A login with unknown credentials must redirect to `/` with a 303, set an `umbral_session`
/// cookie, and leave an `error`-level entry in the session's `_umbral_messages` queue.
#[tokio::test]
async fn form_login_bad_creds_redirects_and_sets_session_for_flash() {
    let app = boot().await;
    let response = post_form(app, "/auth/login", "username=nobody&password=wrong").await;

    let status = response.status();
    assert_eq!(
        status,
        StatusCode::SEE_OTHER,
        "bad-creds login must redirect"
    );

    // A missing or non-UTF-8 header is treated as the empty string.
    let headers = response.headers();
    let loc = match headers.get(header::LOCATION) {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    };
    assert_eq!(loc, "/", "bad-creds redirect must go to '/'");

    let set_cookie = match headers.get(header::SET_COOKIE) {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    };
    assert!(
        set_cookie.contains("umbral_session"),
        "bad-creds login must set a session cookie via session_layer (for flash storage); got: {set_cookie}"
    );

    // The cookie's name=value pair is everything before the first ';' attribute separator.
    let name_value_pair = set_cookie.split(';').next();
    let raw_token = name_value_pair
        .and_then(|pair| pair.strip_prefix("umbral_session="))
        .map(str::trim)
        .expect("Set-Cookie must contain umbral_session=<token>");

    // The session table is looked up by the hashed form of the cookie token.
    let stored_id = umbral_sessions::store::hash_token_pub(raw_token);
    let pool = umbral::db::pool();
    let (payload,): (String,) = sqlx::query_as("SELECT data FROM session WHERE id = ?")
        .bind(&stored_id)
        .fetch_one(&pool)
        .await
        .expect("session row must exist after flash write");

    let data: serde_json::Value =
        serde_json::from_str(&payload).expect("session.data must be valid JSON");
    let queue = data
        .get("_umbral_messages")
        .expect("session.data must contain _umbral_messages key after msgs.error()");
    let messages = queue
        .as_array()
        .expect("_umbral_messages must be a JSON array");
    assert!(
        !messages.is_empty(),
        "flash queue must be non-empty after a bad-creds error"
    );

    let level = messages[0]
        .get("level")
        .and_then(serde_json::Value::as_str);
    assert_eq!(
        level,
        Some("error"),
        "flash level must be 'error' for a bad-creds attempt"
    );
}
/// A login with valid credentials must answer 303 and set an `umbral_session` cookie.
#[tokio::test]
async fn form_login_good_creds_sets_session_cookie() {
    let app = boot().await;

    // Seed the account this test logs in with; the username is unique to this test.
    let seeded =
        umbral_auth::create_user("formuser1", "formuser1@example.com", "G00d$Pass!").await;
    seeded.expect("seed user");

    // `$` and `!` in the password are percent-encoded in the form body.
    let form = "username=formuser1&password=G00d%24Pass%21";
    let response = post_form(app, "/auth/login", form).await;

    let status = response.status();
    assert_eq!(
        status,
        StatusCode::SEE_OTHER,
        "good-creds login must redirect"
    );

    // A missing or non-UTF-8 header is treated as the empty string.
    let set_cookie = match response.headers().get(header::SET_COOKIE) {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    };
    assert!(
        set_cookie.contains("umbral_session"),
        "good-creds login must set an umbral_session cookie; got: {set_cookie}"
    );
}
/// A same-site `redirect` query parameter (`/account`) must be used as the post-login target.
#[tokio::test]
async fn form_login_safe_redirect_param_honored() {
    let app = boot().await;

    // Seed the account this test logs in with; the username is unique to this test.
    let seeded =
        umbral_auth::create_user("formuser2", "formuser2@example.com", "G00d$Pass!").await;
    seeded.expect("seed user");

    // `%2Faccount` is the percent-encoded form of `/account`.
    let target = "/auth/login?redirect=%2Faccount";
    let form = "username=formuser2&password=G00d%24Pass%21";
    let response = post_form(app, target, form).await;

    let status = response.status();
    assert_eq!(status, StatusCode::SEE_OTHER);

    // A missing or non-UTF-8 header is treated as the empty string.
    let loc = match response.headers().get(header::LOCATION) {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    };
    assert_eq!(
        loc, "/account",
        "safe redirect param must be honored; got: {loc}"
    );
}
/// A protocol-relative `redirect` value (`//evil.com`) must be ignored in favour of `/`.
#[tokio::test]
async fn form_login_open_redirect_rejected() {
    let app = boot().await;

    // Seed the account this test logs in with; the username is unique to this test.
    let seeded =
        umbral_auth::create_user("formuser3", "formuser3@example.com", "G00d$Pass!").await;
    seeded.expect("seed user");

    // `%2F%2Fevil.com` decodes to `//evil.com`, which a browser would treat as another host.
    let target = "/auth/login?redirect=%2F%2Fevil.com";
    let form = "username=formuser3&password=G00d%24Pass%21";
    let response = post_form(app, target, form).await;

    let status = response.status();
    assert_eq!(status, StatusCode::SEE_OTHER);

    // A missing or non-UTF-8 header is treated as the empty string.
    let loc = match response.headers().get(header::LOCATION) {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    };
    assert_eq!(
        loc, "/",
        "open-redirect via // must be rejected to '/'; got: {loc}"
    );
}
/// Posting the signup form must answer 303 and leave behind an account that can be
/// authenticated with the submitted password.
#[tokio::test]
async fn form_signup_creates_user_and_redirects() {
    let app = boot().await;

    // `@`, `$` and `!` are percent-encoded in the form body.
    let form = "username=signupuser&email=signup%40example.com&password=G00d%24Pass%21";
    let response = post_form(app, "/auth/signup", form).await;

    let status = response.status();
    assert_eq!(status, StatusCode::SEE_OTHER, "signup must redirect");

    // Prove the account exists by authenticating with the plain-text password.
    let login = umbral_auth::authenticate::<AuthUser>("signupuser", "G00d$Pass!").await;
    login.expect("user created by form signup must be authenticatable");
}
/// Posting to the logout route (with no prior login) must answer 303 and emit a `Set-Cookie`
/// header that matches one of the accepted markers.
#[tokio::test]
async fn form_logout_redirects_and_clears_cookie() {
    let app = boot().await;
    let response = post_form(app, "/auth/logout", "").await;

    let status = response.status();
    assert_eq!(status, StatusCode::SEE_OTHER, "logout must redirect");

    // A missing or non-UTF-8 header is treated as the empty string, which matches no marker.
    let set_cookie = match response.headers().get(header::SET_COOKIE) {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    };

    // NOTE(review): the last marker accepts any Set-Cookie that mentions `umbral_session`,
    // so this check alone does not prove the cookie was actually expired.
    let accepted_markers = ["Max-Age=0", "max-age=0", "umbral_session"];
    let looks_cleared = accepted_markers
        .iter()
        .any(|&marker| set_cookie.contains(marker));
    assert!(
        looks_cleared,
        "logout must emit a cookie-clearing Set-Cookie; got: {set_cookie:?}"
    );
}