use axum::body::Body;
use axum::http::Request;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use tokio::sync::OnceCell;
use tower::ServiceExt;
use umbral::prelude::Plugin;
use umbral_auth::{
AuthPlugin, AuthUser, CommonPasswordValidator, MinLengthValidator, NumericPasswordValidator,
PasswordContext, PasswordPolicy, PasswordValidator, UserAttributeSimilarityValidator,
create_user, validate_password,
};
/// The default min-length validator must refuse a 3-character password and
/// accept an 8-character one (these two probes pin the default threshold).
#[test]
fn min_length_validator_rejects_short_password() {
    let validator = MinLengthValidator::default();

    let too_short = validator.validate_via("abc");
    assert!(
        too_short.is_err(),
        "a 3-char password must be rejected by the min-length validator"
    );

    let long_enough = validator.validate_via("abcdefgh");
    assert!(
        long_enough.is_ok(),
        "an 8-char password must pass the default min-length validator"
    );
}
/// The common-password denylist must catch `password` in both lower and upper
/// case, i.e. the lookup is case-insensitive.
#[test]
fn common_password_validator_rejects_password() {
    let validator = CommonPasswordValidator;
    let cases = [
        ("password", "`password` must be in the common-password denylist"),
        ("PASSWORD", "the common-password match must be case-insensitive"),
    ];
    for (candidate, why) in cases {
        assert!(validator.validate_via(candidate).is_err(), "{why}");
    }
}
/// A password made only of digits must be rejected, while one mixing letters
/// and digits must pass the numeric validator.
#[test]
fn numeric_password_validator_rejects_all_digits() {
    let validator = NumericPasswordValidator;
    let all_digits = "12345678";
    let mixed = "abc12345";

    let digits_outcome = validator.validate_via(all_digits);
    assert!(
        digits_outcome.is_err(),
        "an all-numeric password must be rejected"
    );

    let mixed_outcome = validator.validate_via(mixed);
    assert!(
        mixed_outcome.is_ok(),
        "a mixed alphanumeric password must pass the numeric validator"
    );
}
/// With a context that only knows the username `alice`, the attribute
/// similarity validator must flag `alice123` but leave an unrelated strong
/// password alone.
#[test]
fn similarity_validator_rejects_password_like_username() {
    let validator = UserAttributeSimilarityValidator::default();
    let ctx = PasswordContext::for_username("alice");

    let derived = validator.validate("alice123", &ctx);
    assert!(
        derived.is_err(),
        "`alice123` must be flagged as too similar to username `alice`"
    );

    let unrelated = validator.validate("Tr0ub4dour&3xpl", &ctx);
    assert!(
        unrelated.is_ok(),
        "an unrelated strong password must not be flagged"
    );
}
/// `validate_password` must report every failing check rather than stopping
/// at the first one: `12345678` is expected to trip at least two validators
/// (presumably numeric-only and common-password under the default policy).
#[test]
fn validate_password_aggregates_multiple_failures() {
    let no_attributes = PasswordContext::empty();
    let failures = validate_password("12345678", &no_attributes)
        .expect_err("a doubly-weak password must fail");

    let failure_count = failures.len();
    assert!(
        failure_count >= 2,
        "validate_password must collect every failure; got {failures:?}"
    );
}
/// A long, mixed-character password unrelated to the user's username and email
/// must clear the full default validation policy.
#[test]
fn strong_password_passes_all_validators() {
    let username = Some("alice");
    let email = Some("alice@example.com");
    let ctx = PasswordContext::new(username, email);

    let outcome = validate_password("Tr0ub4dour&3xpl", &ctx);
    assert!(
        outcome.is_ok(),
        "a strong password must pass the full default policy"
    );
}
/// `PasswordPolicy::empty()` must accept any password, while
/// `PasswordPolicy::default()` must not be empty.
///
/// NOTE(review): despite its name, this test does not call any
/// "disable password validation" switch; it only checks the two policy
/// values directly.
#[test]
fn disable_password_validation_installs_empty_policy() {
    let permissive = PasswordPolicy::empty();
    let no_attributes = PasswordContext::empty();
    let anything_goes = permissive.check("a", &no_attributes);
    assert!(
        anything_goes.is_ok(),
        "an empty policy must accept any password"
    );

    let secure_default = PasswordPolicy::default();
    assert!(
        !secure_default.is_empty(),
        "the default policy must enforce the secure validator set"
    );
}
trait ValidateVia {
fn validate_via(&self, password: &str) -> Result<(), String>;
}
impl<T: umbral_auth::PasswordValidator> ValidateVia for T {
fn validate_via(&self, password: &str) -> Result<(), String> {
self.validate(password, &PasswordContext::empty())
}
}
/// URI prefix the route tests put in front of auth endpoint paths
/// (e.g. `post_register` targets `{PREFIX}/register`).
const PREFIX: &str = "/api/auth";
/// Guard so the app/database setup performed by `boot` completes only once per
/// test process, however many tests call `boot().await`.
static BOOT: OnceCell<()> = OnceCell::const_new();
/// Perform the shared test-app setup, once per test process.
///
/// Creates a file-backed SQLite database inside a deliberately leaked temp
/// directory. It then builds an `umbral::App` that registers that pool as the
/// `"default"` database and installs `AuthPlugin` (default routes, throttling
/// disabled). Finally it creates the `auth_user` table on the pool returned by
/// `umbral::db::pool()`. Tests that hit the routes or the database await this
/// first.
///
/// NOTE(review): every `#[tokio::test]` runs on its own runtime, but the pool
/// is created on whichever runtime calls this first. Confirm sqlx tolerates the
/// pool outliving that runtime.
async fn boot() {
    // Byte-for-byte the schema the tests read from and write to.
    const AUTH_USER_DDL: &str = concat!(
        "CREATE TABLE auth_user (\n",
        "id INTEGER PRIMARY KEY AUTOINCREMENT,\n",
        "username TEXT NOT NULL UNIQUE,\n",
        "email TEXT NOT NULL,\n",
        "password_hash TEXT NOT NULL,\n",
        "is_active INTEGER NOT NULL,\n",
        "is_staff INTEGER NOT NULL,\n",
        "is_superuser INTEGER NOT NULL,\n",
        "date_joined TEXT NOT NULL,\n",
        "last_login TEXT\n",
        ")",
    );

    BOOT.get_or_init(|| async {
        let settings =
            umbral::Settings::from_env().expect("figment defaults always load in a test env");

        // Forgetting the guard skips `TempDir`'s drop, so the directory (and the
        // SQLite file in it) is not deleted while later tests still use it.
        let db_dir = tempfile::tempdir().expect("create tempdir for the test DB");
        let db_file = db_dir.path().join("umbral_auth_pwvalidation.sqlite");
        std::mem::forget(db_dir);

        let connect = SqliteConnectOptions::new()
            .filename(&db_file)
            .create_if_missing(true);
        let seed_pool = SqlitePoolOptions::new()
            .max_connections(5)
            .connect_with(connect)
            .await
            .expect("sqlite should connect against the tempfile");

        umbral::App::builder()
            .settings(settings)
            .database("default", seed_pool)
            .plugin(
                AuthPlugin::<AuthUser>::default()
                    .with_default_routes()
                    .disable_throttle(),
            )
            .build()
            .expect("App::build should succeed with AuthPlugin");

        // The schema goes onto the pool the app exposes, not the local handle
        // that was moved into the builder.
        let app_pool = umbral::db::pool();
        sqlx::query(AUTH_USER_DDL)
            .execute(&app_pool)
            .await
            .expect("create auth_user table");
    })
    .await;
}
/// Build a fresh router exposing the auth plugin's default routes.
///
/// The plugin here is configured like the one `boot` registers on the app
/// (default routes, throttling disabled). Route tests therefore run against the
/// same configuration the app was built with, rather than a throttled variant.
fn auth_router() -> axum::Router {
    AuthPlugin::<AuthUser>::default()
        .with_default_routes()
        .disable_throttle()
        .routes()
}
/// Send `json` as a `POST {PREFIX}/register` request through a fresh
/// `auth_router()`.
///
/// Returns the response status and the fully collected body bytes. Panics if
/// the request cannot be built, the service call errors, or the body cannot be
/// collected.
async fn post_register(json: &str) -> (http::StatusCode, Vec<u8>) {
    let router = auth_router();
    let request = Request::builder()
        .method("POST")
        .uri(format!("{PREFIX}/register"))
        .header(http::header::CONTENT_TYPE, "application/json")
        .body(Body::from(json.to_owned()))
        .unwrap();

    let response = router
        .oneshot(request)
        .await
        .expect("register request must not panic");

    let status = response.status();
    let collected = http_body_util::BodyExt::collect(response.into_body())
        .await
        .expect("collect body");
    (status, collected.to_bytes().to_vec())
}
/// Registering with the one-character password `a` must:
/// - return 400 with a JSON body whose `error` is `weak_password`;
/// - carry a non-empty `detail` string;
/// - leave no `auth_user` row for the rejected username.
#[tokio::test]
async fn register_route_rejects_weak_password() {
    boot().await;

    let payload = r#"{"username":"weakling","email":"weak@example.com","password":"a"}"#;
    let (status, body) = post_register(payload).await;
    assert_eq!(
        status,
        http::StatusCode::BAD_REQUEST,
        "register with password `a` must be 400; body={}",
        String::from_utf8_lossy(&body),
    );

    let parsed: serde_json::Value = serde_json::from_slice(&body).expect("error body is JSON");
    assert_eq!(parsed["error"], "weak_password");
    assert!(
        parsed["detail"].as_str().is_some_and(|d| !d.is_empty()),
        "the 400 must carry at least one human-readable reason; body={parsed}"
    );

    // Bind the username (as the sibling success test does) instead of
    // splicing it into the SQL text.
    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM auth_user WHERE username = ?")
        .bind("weakling")
        .fetch_one(&umbral::db::pool())
        .await
        .expect("count query");
    assert_eq!(count, 0, "a rejected register must not write a row");
}
/// Registering with a strong password must return 201 and persist an
/// `auth_user` row whose `password_hash` column is not the plaintext.
#[tokio::test]
async fn register_route_accepts_strong_password() {
    boot().await;

    let payload =
        r#"{"username":"stronguser","email":"strong@example.com","password":"Tr0ub4dour&3xpl"}"#;
    let (status, body) = post_register(payload).await;
    assert_eq!(
        status,
        http::StatusCode::CREATED,
        "register with a strong password must be 201; body={}",
        String::from_utf8_lossy(&body),
    );

    let (stored_username, stored_hash): (String, String) =
        sqlx::query_as("SELECT username, password_hash FROM auth_user WHERE username = ?")
            .bind("stronguser")
            .fetch_one(&umbral::db::pool())
            .await
            .expect("the stronguser row should exist after a successful register");
    assert_eq!(stored_username, "stronguser");
    assert_ne!(
        stored_hash, "Tr0ub4dour&3xpl",
        "the stored value must be the hash, not the plaintext"
    );
}
/// `create_user` must succeed with passwords the validation policy would
/// reject. It accepts both a one-character password and one that resembles
/// the username, and it still stores a value other than the plaintext.
#[tokio::test]
async fn create_user_helper_does_not_validate() {
    boot().await;

    let short_pw_user = create_user("lowlevel", "lowlevel@example.com", "a")
        .await
        .expect("create_user is low-level and must NOT validate the password");
    assert_eq!(short_pw_user.username, "lowlevel");
    assert_ne!(
        short_pw_user.password_hash, "a",
        "create_user must still hash, just not validate"
    );

    let lookalike_user = create_user("bobby", "bobby@example.com", "bobby1234")
        .await
        .expect("create_user must not run the similarity validator");
    assert_eq!(lookalike_user.username, "bobby");
}