#[cfg(feature = "oauth2")]
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
#[cfg(feature = "oauth2")]
use std::time::Duration;
use axum::extract::FromRequestParts;
use axum::response::{IntoResponse, Response};
use http::StatusCode;
use http::request::Parts;
#[cfg(feature = "oauth2")]
use jsonwebtoken::jwk::JwkSet;
#[cfg(feature = "oauth2")]
use serde::Deserialize;
#[cfg(feature = "oauth2")]
use url::Url;
/// bcrypt work factor used by [`hash_password`] and as the default for
/// `AuthConfig::bcrypt_cost`.
const DEFAULT_BCRYPT_COST: u32 = 12;
pub async fn hash_password(password: &str) -> crate::AutumnResult<String> {
let password = password.to_string();
tokio::task::spawn_blocking(move || {
bcrypt::hash(password, DEFAULT_BCRYPT_COST)
.map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))
})
.await
.map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))?
}
pub async fn verify_password(password: &str, hash: &str) -> crate::AutumnResult<bool> {
let password = password.to_string();
let is_valid_format = hash.len() == 60 && hash.starts_with('$');
let hash_to_verify = if is_valid_format {
hash.to_string()
} else {
"$2b$12$KIXe8K4j1sH6/xH.x9d71uJ5Jk8t6O4m6Q110g4H8y1r6J6O6O6O6".to_string()
};
let result = tokio::task::spawn_blocking(move || bcrypt::verify(&password, &hash_to_verify))
.await
.map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))?;
if !is_valid_format {
return Ok(false);
}
result.map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))
}
/// Authorization check behind the `#[secured]` handler macro (presumably
/// invoked from its expansion; not intended to be called directly).
///
/// The session must contain `"user_id"`. If `roles` is non-empty, the session's
/// `"role"` value (empty when absent) must equal one of them.
///
/// # Errors
/// `unauthorized` ("authentication required") when no `"user_id"` is present;
/// `forbidden` ("insufficient permissions") when the role does not match.
#[doc(hidden)]
pub async fn __check_secured(
    session: &crate::session::Session,
    roles: &[&str],
) -> crate::AutumnResult<()> {
    let logged_in = session.get("user_id").await.is_some();
    if !logged_in {
        return Err(crate::AutumnError::unauthorized_msg("authentication required"));
    }

    if roles.is_empty() {
        return Ok(());
    }

    let current_role = session.get("role").await.unwrap_or_default();
    let permitted = roles.iter().copied().any(|allowed| allowed == current_role);
    if permitted {
        Ok(())
    } else {
        Err(crate::AutumnError::forbidden_msg("insufficient permissions"))
    }
}
/// Extractor yielding the authenticated principal of type `T`.
///
/// `T` is looked up in the request extensions, where some earlier middleware
/// must have inserted it; when it is absent the request is rejected with
/// [`AuthRejection`] (401).
pub struct Auth<T>(pub T);

impl<T, S> FromRequestParts<S> for Auth<T>
where
    T: Clone + Send + Sync + 'static,
    S: Send + Sync,
{
    type Rejection = AuthRejection;

    fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
        // Clone before building the future so it does not borrow `parts`.
        let found = parts.extensions.get::<T>().cloned();
        async move { found.map(Auth).ok_or(AuthRejection) }
    }
}
/// Rejection returned by [`Auth`] when no principal is attached to the request.
///
/// Renders as `401 Unauthorized` with a JSON body of the shape
/// `{"error": {"status": 401, "message": "authentication required"}}`.
#[derive(Debug)]
pub struct AuthRejection;

impl IntoResponse for AuthRejection {
    fn into_response(self) -> Response {
        let payload = serde_json::json!({
            "error": {
                "status": 401,
                "message": "authentication required"
            }
        });
        (StatusCode::UNAUTHORIZED, axum::Json(payload)).into_response()
    }
}

impl std::fmt::Display for AuthRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "authentication required")
    }
}
/// Tower layer that only lets a request through when its session contains
/// `session_key`; otherwise it answers `401 Unauthorized`.
///
/// The wrapped service looks for a `crate::session::Session` in the request
/// extensions, so a session layer has to run before this one. Without it,
/// every request is rejected.
#[derive(Clone)]
pub struct RequireAuth {
    session_key: Arc<str>,
}

impl RequireAuth {
    /// Creates a layer that treats a request as authenticated when its session
    /// holds `session_key`.
    pub fn new(session_key: impl Into<String>) -> Self {
        let key: String = session_key.into();
        Self {
            session_key: Arc::from(key),
        }
    }
}

impl<S> tower::Layer<S> for RequireAuth {
    type Service = RequireAuthService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let session_key = self.session_key.clone();
        RequireAuthService { inner, session_key }
    }
}
/// Service produced by `RequireAuth`: forwards authenticated requests to
/// `inner` and short-circuits everything else with a JSON `401 Unauthorized`.
#[derive(Clone)]
pub struct RequireAuthService<S> {
    inner: S,
    session_key: Arc<str>,
}

impl<S, ResBody> tower::Service<axum::extract::Request> for RequireAuthService<S>
where
    S: tower::Service<axum::extract::Request, Response = Response<ResBody>>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    ResBody: From<String> + Send + 'static,
{
    type Response = Response<ResBody>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: axum::extract::Request) -> Self::Future {
        let session_key = Arc::clone(&self.session_key);
        // Take the instance `poll_ready` just drove to readiness and leave a
        // fresh clone behind for the next call (standard tower clone-swap).
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);

        Box::pin(async move {
            // Clone the session out so `req` can be moved into `inner` below.
            let session = req.extensions().get::<crate::session::Session>().cloned();
            let is_authenticated = match session {
                Some(ref session) => session.contains_key(&session_key).await,
                None => false,
            };

            if is_authenticated {
                return inner.call(req).await;
            }

            let body = serde_json::json!({
                "error": {
                    "status": 401,
                    "message": "authentication required"
                }
            });
            // Build the rejection infallibly. The previous
            // `Response::builder()...unwrap_or_default()` would have fallen back
            // to a default (200 OK) response on error, i.e. failed open.
            let mut response = http::Response::new(ResBody::from(body.to_string()));
            *response.status_mut() = StatusCode::UNAUTHORIZED;
            response.headers_mut().insert(
                http::header::CONTENT_TYPE,
                http::HeaderValue::from_static("application/json"),
            );
            Ok(response)
        })
    }
}
/// Authentication settings, deserialized from the `auth` section of the
/// application config (e.g. `[auth.oauth2.<provider>]` tables).
#[derive(Debug, Clone, serde::Deserialize)]
pub struct AuthConfig {
/// bcrypt work factor; defaults to `DEFAULT_BCRYPT_COST` (12).
/// NOTE(review): `hash_password` in this module always hashes with
/// `DEFAULT_BCRYPT_COST` and does not read this field; confirm whether it is
/// honored anywhere else.
#[serde(default = "default_bcrypt_cost")]
pub bcrypt_cost: u32,
/// Session key holding the authenticated user's id; defaults to `"user_id"`.
/// Presumably passed to `RequireAuth::new` / `oauth2_finish_login` by the caller.
/// Note that `__check_secured` always checks the literal `"user_id"`.
#[serde(default = "default_session_key")]
pub session_key: String,
/// Named OAuth2/OIDC providers; empty when the section is absent.
#[cfg(feature = "oauth2")]
#[serde(default)]
pub oauth2: OAuth2Config,
}
/// Serde default for `AuthConfig::bcrypt_cost`.
const fn default_bcrypt_cost() -> u32 {
DEFAULT_BCRYPT_COST
}
/// Serde default for `AuthConfig::session_key`: `"user_id"`.
fn default_session_key() -> String {
    String::from("user_id")
}
/// Serde default for `OAuth2ProviderConfig::scope`: an empty string, which makes
/// `oauth2_authorize_url` omit the `scope` query parameter.
#[cfg(feature = "oauth2")]
const fn default_provider_scope() -> String {
String::new()
}
/// Request timeout, in seconds, applied by `oauth_http_client` to every outbound
/// OAuth2/OIDC HTTP call (token, userinfo, JWKS).
#[cfg(feature = "oauth2")]
const OAUTH_HTTP_TIMEOUT_SECS: u64 = 15;
/// OAuth2 provider table: each `[auth.oauth2.<name>]` section becomes one entry,
/// keyed by `<name>` (the map is flattened into this struct).
#[cfg(feature = "oauth2")]
#[derive(Debug, Clone, Default, serde::Deserialize)]
pub struct OAuth2Config {
#[serde(flatten)]
pub providers: HashMap<String, OAuth2ProviderConfig>,
}
/// Configuration for a single OAuth2 / OpenID Connect provider.
///
/// `issuer` and `jwks_url` are needed only when the provider returns an
/// `id_token`, which is then verified against them. `userinfo_url` is used when
/// it does not. `scope` defaults to empty, in which case no `scope` parameter
/// is sent.
///
/// `Debug` is implemented by hand and redacts `client_secret`, so this config
/// (and `AuthConfig`, which embeds it) can be logged without leaking the secret.
#[cfg(feature = "oauth2")]
#[derive(Clone, serde::Deserialize)]
pub struct OAuth2ProviderConfig {
    /// Client identifier; also the expected `aud` of id_tokens.
    pub client_id: String,
    /// Client secret sent to the token endpoint. Never printed by `Debug`.
    pub client_secret: String,
    /// Provider authorization endpoint the user is redirected to.
    pub authorize_url: String,
    /// Token endpoint used to exchange the authorization code.
    pub token_url: String,
    /// Userinfo endpoint, used when no `id_token` is returned.
    #[serde(default)]
    pub userinfo_url: Option<String>,
    /// Redirect URI registered with the provider.
    pub redirect_uri: String,
    /// Requested scopes as one string; blank means "omit".
    #[serde(default = "default_provider_scope")]
    pub scope: String,
    /// Expected `iss` claim of id_tokens.
    #[serde(default)]
    pub issuer: Option<String>,
    /// JWKS endpoint holding the id_token signing keys.
    #[serde(default)]
    pub jwks_url: Option<String>,
}

#[cfg(feature = "oauth2")]
impl std::fmt::Debug for OAuth2ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2ProviderConfig")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("authorize_url", &self.authorize_url)
            .field("token_url", &self.token_url)
            .field("userinfo_url", &self.userinfo_url)
            .field("redirect_uri", &self.redirect_uri)
            .field("scope", &self.scope)
            .field("issuer", &self.issuer)
            .field("jwks_url", &self.jwks_url)
            .finish()
    }
}
/// Parameters the provider sends back to the redirect URI (presumably
/// extracted from the callback query string by the caller).
/// `state` is checked by `validate_callback_state`; `code` is exchanged for
/// tokens by `exchange_oauth2_token`.
#[cfg(feature = "oauth2")]
#[derive(Debug, Clone, Deserialize)]
pub struct OAuth2Callback {
pub code: String,
pub state: String,
}
/// Identity established by a completed OAuth2/OIDC login.
#[cfg(feature = "oauth2")]
#[derive(Debug, Clone)]
pub struct OidcIdentity {
/// Stable user id: `sub`, or for userinfo responses without `sub`, the `id` field.
pub subject: String,
/// `email` claim, if present as a string.
pub email: Option<String>,
/// `name` claim, if present as a string.
pub name: Option<String>,
/// `preferred_username` claim, if present as a string.
pub preferred_username: Option<String>,
/// Full claim set from the verified id_token or the userinfo response.
pub raw_claims: serde_json::Value,
}
/// Successful token-endpoint response, parsed from either a JSON or a
/// form-urlencoded body by `parse_oauth2_token_response`.
///
/// `Debug` is implemented by hand so bearer credentials never end up in logs.
#[cfg(feature = "oauth2")]
#[derive(Deserialize)]
struct OAuth2TokenResponse {
    /// Bearer token presented to the userinfo endpoint.
    access_token: String,
    /// Token type as reported by the provider; not otherwise used here.
    #[allow(dead_code)]
    token_type: Option<String>,
    /// OIDC id_token (a signed JWT), when the provider issues one.
    id_token: Option<String>,
}

#[cfg(feature = "oauth2")]
impl std::fmt::Debug for OAuth2TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2TokenResponse")
            .field("access_token", &"<redacted>")
            .field("token_type", &self.token_type)
            .field("id_token", &self.id_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Begins an authorization-code login and returns the provider URL to redirect
/// the user to.
///
/// Generates a random `state` (checked later by `validate_callback_state`) and
/// `nonce` (checked against the id_token by `validate_oidc_nonce`). Both are
/// stored in the session under `oauth2:{provider_name}:state` /
/// `oauth2:{provider_name}:nonce`. The query gains `response_type=code`,
/// `client_id`, `redirect_uri`, `scope` (only when non-blank), `state` and `nonce`.
///
/// # Errors
/// Returns a bad-request error if `provider.authorize_url` is not a valid URL.
/// The URL is validated before anything is written to the session, so a
/// misconfigured provider leaves no orphaned state/nonce behind.
#[cfg(feature = "oauth2")]
pub async fn oauth2_authorize_url(
    session: &crate::session::Session,
    provider_name: &str,
    provider: &OAuth2ProviderConfig,
) -> crate::AutumnResult<String> {
    let mut url = Url::parse(&provider.authorize_url)
        .map_err(|e| crate::AutumnError::bad_request_msg(format!("invalid authorize_url: {e}")))?;

    let state = uuid::Uuid::new_v4().to_string();
    let nonce = uuid::Uuid::new_v4().to_string();
    {
        let mut q = url.query_pairs_mut();
        q.append_pair("response_type", "code");
        q.append_pair("client_id", &provider.client_id);
        q.append_pair("redirect_uri", &provider.redirect_uri);
        if !provider.scope.trim().is_empty() {
            q.append_pair("scope", &provider.scope);
        }
        q.append_pair("state", &state);
        q.append_pair("nonce", &nonce);
    }

    session
        .insert(format!("oauth2:{provider_name}:state"), state)
        .await;
    session
        .insert(format!("oauth2:{provider_name}:nonce"), nonce)
        .await;
    Ok(url.into())
}
/// Completes an authorization-code login from the provider's callback.
///
/// The steps run in this order, and the first failure aborts the login:
/// 1. check `state` against the session,
/// 2. exchange `code` at the token endpoint,
/// 3. obtain claims from the id_token (verified) or the userinfo endpoint,
/// 4. check the OIDC nonce (id_token logins only),
/// 5. extract the subject,
/// 6. write the login into the session and rotate its id.
#[cfg(feature = "oauth2")]
pub async fn oauth2_finish_login(
    session: &crate::session::Session,
    session_key: &str,
    provider_name: &str,
    provider: &OAuth2ProviderConfig,
    callback: &OAuth2Callback,
) -> crate::AutumnResult<OidcIdentity> {
    validate_callback_state(session, provider_name, callback).await?;

    let token_response = exchange_oauth2_token(provider, callback).await?;
    let (identity_claims, identity_source) =
        load_identity_claims(provider, &token_response).await?;

    validate_oidc_nonce(session, provider_name, &identity_claims, identity_source).await?;
    let subject = extract_subject(&identity_claims, identity_source)?;

    finalize_oauth2_session(session, session_key, provider_name, subject, identity_claims).await
}
/// Verifies the callback's `state` against the value stored by
/// `oauth2_authorize_url`, comparing in constant time.
///
/// The stored state is removed only after a successful match, so it cannot be
/// replayed. On a mismatch it is kept.
///
/// # Errors
/// `unauthorized` when no state is stored or the values differ.
#[cfg(feature = "oauth2")]
async fn validate_callback_state(
    session: &crate::session::Session,
    provider_name: &str,
    callback: &OAuth2Callback,
) -> crate::AutumnResult<()> {
    let state_key = format!("oauth2:{provider_name}:state");

    let Some(expected) = session.get(&state_key).await else {
        return Err(crate::AutumnError::unauthorized_msg(
            "oauth2 state missing; restart login",
        ));
    };

    let matches: bool =
        subtle::ConstantTimeEq::ct_eq(expected.as_bytes(), callback.state.as_bytes()).into();
    if !matches {
        return Err(crate::AutumnError::unauthorized_msg("oauth2 state mismatch"));
    }

    session.remove(&state_key).await;
    Ok(())
}
/// Exchanges the callback's authorization `code` at the provider's token
/// endpoint.
///
/// Client credentials are sent in the form body, and a JSON reply is requested
/// via `Accept`. Body parsing is delegated to `parse_oauth2_token_response`,
/// which also handles form-encoded replies.
///
/// # Errors
/// `service_unavailable` if the request cannot be sent; `unauthorized` on a
/// non-success HTTP status; `bad_request` if the body cannot be read or parsed.
#[cfg(feature = "oauth2")]
async fn exchange_oauth2_token(
    provider: &OAuth2ProviderConfig,
    callback: &OAuth2Callback,
) -> crate::AutumnResult<OAuth2TokenResponse> {
    let form = [
        ("grant_type", "authorization_code"),
        ("code", callback.code.as_str()),
        ("redirect_uri", provider.redirect_uri.as_str()),
        ("client_id", provider.client_id.as_str()),
        ("client_secret", provider.client_secret.as_str()),
    ];

    let client = oauth_http_client()?;
    let sent = client
        .post(&provider.token_url)
        .header(reqwest::header::ACCEPT, "application/json")
        .form(&form)
        .send()
        .await
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!("token request failed: {e}"))
        })?;
    let response = sent
        .error_for_status()
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("token exchange failed: {e}")))?;

    let content_type: Option<String> = response
        .headers()
        .get(reqwest::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .map(String::from);
    let body = response.text().await.map_err(|e| {
        crate::AutumnError::bad_request_msg(format!("invalid token response body: {e}"))
    })?;

    parse_oauth2_token_response(content_type.as_deref(), &body)
}
/// Obtains the user's claims after a successful token exchange.
///
/// A returned `id_token` is preferred and is verified by
/// `validate_and_decode_id_token`. Otherwise the configured `userinfo_url` is
/// queried with the access token as a bearer credential.
///
/// # Errors
/// `bad_request` if neither an id_token nor a `userinfo_url` is available, or
/// if the userinfo payload is not JSON. `service_unavailable` / `unauthorized`
/// for userinfo transport or HTTP-status failures. id_token verification errors
/// are propagated unchanged.
#[cfg(feature = "oauth2")]
async fn load_identity_claims(
    provider: &OAuth2ProviderConfig,
    token: &OAuth2TokenResponse,
) -> crate::AutumnResult<(serde_json::Value, IdentitySource)> {
    if let Some(id_token) = token.id_token.as_deref() {
        let verified = validate_and_decode_id_token(id_token, provider).await?;
        return Ok((verified, IdentitySource::IdToken));
    }

    let Some(userinfo_url) = provider.userinfo_url.as_ref() else {
        return Err(crate::AutumnError::bad_request_msg(
            "provider must return id_token or configure userinfo_url",
        ));
    };

    let response = oauth_http_client()?
        .get(userinfo_url)
        .header(
            reqwest::header::USER_AGENT,
            concat!("autumn-web/", env!("CARGO_PKG_VERSION")),
        )
        .bearer_auth(&token.access_token)
        .send()
        .await
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!("userinfo request failed: {e}"))
        })?;

    let claims: serde_json::Value = response
        .error_for_status()
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("userinfo failed: {e}")))?
        .json()
        .await
        .map_err(|e| {
            crate::AutumnError::bad_request_msg(format!("invalid userinfo payload: {e}"))
        })?;

    Ok((claims, IdentitySource::UserInfo))
}
/// Checks the id_token's `nonce` claim against the nonce stored by
/// `oauth2_authorize_url`, comparing in constant time.
///
/// The stored nonce is always removed from the session, whatever the identity
/// source. The comparison itself happens only for id_token logins; userinfo
/// logins skip it.
///
/// # Errors
/// `unauthorized` when, for an id_token login, the session has no nonce, the
/// claim is missing or not a string, or the values differ.
#[cfg(feature = "oauth2")]
async fn validate_oidc_nonce(
    session: &crate::session::Session,
    provider_name: &str,
    claims: &serde_json::Value,
    source: IdentitySource,
) -> crate::AutumnResult<()> {
    let nonce_key = format!("oauth2:{provider_name}:nonce");
    let stored_nonce = session.remove(&nonce_key).await;

    if source != IdentitySource::IdToken {
        return Ok(());
    }

    let Some(expected) = stored_nonce else {
        return Err(crate::AutumnError::unauthorized_msg(
            "oauth2 nonce missing from session",
        ));
    };
    let Some(actual) = claims.get("nonce").and_then(serde_json::Value::as_str) else {
        return Err(crate::AutumnError::unauthorized_msg("missing oidc nonce claim"));
    };

    let equal: bool =
        subtle::ConstantTimeEq::ct_eq(expected.as_bytes(), actual.as_bytes()).into();
    if equal {
        Ok(())
    } else {
        Err(crate::AutumnError::unauthorized_msg("oidc nonce mismatch"))
    }
}
/// Records the login in the session and builds the resulting [`OidcIdentity`].
///
/// `subject` is stored under `session_key` and the provider name under
/// `"auth_provider"`. The session id is then rotated (presumably to defeat
/// session fixation; see `Session::rotate_id`). `email`, `name` and
/// `preferred_username` are copied from `claims` when present as strings.
#[cfg(feature = "oauth2")]
async fn finalize_oauth2_session(
    session: &crate::session::Session,
    session_key: &str,
    provider_name: &str,
    subject: String,
    claims: serde_json::Value,
) -> crate::AutumnResult<OidcIdentity> {
    session.insert(session_key, subject.clone()).await;
    session.insert("auth_provider", provider_name).await;
    session.rotate_id().await;

    let text_claim = |key: &str| {
        claims
            .get(key)
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    };
    let email = text_claim("email");
    let name = text_claim("name");
    let preferred_username = text_claim("preferred_username");

    Ok(OidcIdentity {
        subject,
        email,
        name,
        preferred_username,
        raw_claims: claims,
    })
}
/// Parses a token-endpoint body into an [`OAuth2TokenResponse`].
///
/// The body is read as JSON when `content_type` mentions `application/json` or
/// the body starts with `{`. Otherwise it is read as `application/x-www-form-urlencoded`.
///
/// Some providers report failures such as an expired or reused `code` with a
/// successful HTTP status and an RFC 6749 §5.2 body (`error`, optional
/// `error_description`). Such bodies are reported as a token-exchange failure
/// carrying the provider's error code, not as a generic parse error.
///
/// # Errors
/// `unauthorized` for an OAuth2 error body; `bad_request` for malformed JSON or
/// a response without `access_token`.
#[cfg(feature = "oauth2")]
fn parse_oauth2_token_response(
    content_type: Option<&str>,
    body: &str,
) -> crate::AutumnResult<OAuth2TokenResponse> {
    /// Builds the error for an RFC 6749 §5.2 error response.
    fn provider_error(code: &str, description: Option<&str>) -> crate::AutumnError {
        let message = match description {
            Some(detail) if !detail.is_empty() => {
                format!("token exchange failed: {code}: {detail}")
            }
            _ => format!("token exchange failed: {code}"),
        };
        crate::AutumnError::unauthorized_msg(message)
    }

    let looks_like_json = content_type.is_some_and(|v| v.contains("application/json"))
        || body.trim_start().starts_with('{');

    if looks_like_json {
        let value: serde_json::Value = serde_json::from_str(body).map_err(|e| {
            crate::AutumnError::bad_request_msg(format!("invalid json token response: {e}"))
        })?;
        if let Some(code) = value.get("error").and_then(serde_json::Value::as_str) {
            let description = value
                .get("error_description")
                .and_then(serde_json::Value::as_str);
            return Err(provider_error(code, description));
        }
        return serde_json::from_value(value).map_err(|e| {
            crate::AutumnError::bad_request_msg(format!("invalid json token response: {e}"))
        });
    }

    let form: HashMap<String, String> = url::form_urlencoded::parse(body.as_bytes())
        .into_owned()
        .collect();
    if let Some(code) = form.get("error") {
        return Err(provider_error(
            code,
            form.get("error_description").map(String::as_str),
        ));
    }
    let access_token = form.get("access_token").cloned().ok_or_else(|| {
        crate::AutumnError::bad_request_msg("token response missing access_token")
    })?;
    Ok(OAuth2TokenResponse {
        access_token,
        token_type: form.get("token_type").cloned(),
        id_token: form.get("id_token").cloned(),
    })
}
/// Where a login's identity claims came from; decides whether the OIDC nonce is
/// checked and which subject fallbacks are allowed.
#[cfg(feature = "oauth2")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IdentitySource {
/// Claims decoded from a verified OIDC `id_token`.
IdToken,
/// Claims fetched from the provider's userinfo endpoint.
UserInfo,
}
/// Picks the stable user identifier from `claims`.
///
/// A non-empty string `sub` always wins. For userinfo responses (some non-OIDC
/// providers expose `id` instead of `sub`), a non-empty string `id` or any
/// integer `id`, signed or unsigned, is accepted as a fallback. id_tokens must
/// carry `sub`.
///
/// Empty identifiers are treated as missing. An empty subject would map
/// different accounts to the same session user and still satisfy presence-only
/// login checks.
///
/// # Errors
/// `bad_request` when no usable identifier is found.
#[cfg(feature = "oauth2")]
fn extract_subject(
    claims: &serde_json::Value,
    source: IdentitySource,
) -> crate::AutumnResult<String> {
    if let Some(sub) = claims
        .get("sub")
        .and_then(serde_json::Value::as_str)
        .filter(|s| !s.is_empty())
    {
        return Ok(sub.to_owned());
    }

    if source == IdentitySource::UserInfo {
        let fallback = match claims.get("id") {
            // Covers both i64 and u64 ranges; floats are not identifiers.
            Some(serde_json::Value::Number(n)) if n.is_i64() || n.is_u64() => Some(n.to_string()),
            Some(serde_json::Value::String(s)) if !s.is_empty() => Some(s.clone()),
            _ => None,
        };
        return fallback.ok_or_else(|| {
            crate::AutumnError::bad_request_msg(
                "missing identity claim: expected sub or id from userinfo",
            )
        });
    }

    Err(crate::AutumnError::bad_request_msg("missing sub claim"))
}
/// Verifies an OIDC `id_token` against the provider's JWKS and returns its claims.
///
/// Requires `provider.issuer` and `provider.jwks_url`. The signing key is chosen
/// by the header's `kid`. The signature, `iss`, `aud` (= `client_id`), `exp` and
/// `nbf` are validated, and `exp`, `iss`, `aud` and `sub` must be present.
///
/// # Errors
/// - `bad_request`: issuer/jwks_url not configured, or the JWKS body is not valid JSON;
/// - `service_unavailable`: the JWKS endpoint cannot be reached;
/// - `unauthorized`: malformed header, missing `kid`, an HMAC (`HS*`) algorithm,
///   JWKS HTTP error status, unknown `kid`, unusable key, or failed validation.
#[cfg(feature = "oauth2")]
async fn validate_and_decode_id_token(
    token: &str,
    provider: &OAuth2ProviderConfig,
) -> crate::AutumnResult<serde_json::Value> {
    let issuer = provider
        .issuer
        .as_deref()
        .ok_or_else(|| crate::AutumnError::bad_request_msg("provider.issuer required for oidc"))?;
    let jwks_url = provider.jwks_url.as_deref().ok_or_else(|| {
        crate::AutumnError::bad_request_msg("provider.jwks_url required for oidc")
    })?;

    let header = jsonwebtoken::decode_header(token).map_err(|e| {
        crate::AutumnError::unauthorized_msg(format!("invalid id_token header: {e}"))
    })?;
    let kid = header
        .kid
        .as_deref()
        .ok_or_else(|| crate::AutumnError::unauthorized_msg("id_token header missing kid"))?;

    // `alg` comes from the untrusted token header. JWKS keys are public, so an
    // HMAC "signature" checked with one would prove nothing. Refuse HS*
    // outright instead of relying solely on the JWT library's key-family checks.
    let alg = header.alg;
    if matches!(
        alg,
        jsonwebtoken::Algorithm::HS256
            | jsonwebtoken::Algorithm::HS384
            | jsonwebtoken::Algorithm::HS512
    ) {
        return Err(crate::AutumnError::unauthorized_msg(format!(
            "unsupported id_token algorithm: {alg:?}"
        )));
    }

    let jwks: JwkSet = oauth_http_client()?
        .get(jwks_url)
        .send()
        .await
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!("jwks request failed: {e}"))
        })?
        .error_for_status()
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("jwks fetch failed: {e}")))?
        .json()
        .await
        .map_err(|e| crate::AutumnError::bad_request_msg(format!("invalid jwks response: {e}")))?;

    let jwk = jwks
        .keys
        .iter()
        .find(|k| k.common.key_id.as_deref() == Some(kid))
        .ok_or_else(|| crate::AutumnError::unauthorized_msg("no jwk matched id_token kid"))?;
    let decoding_key = jsonwebtoken::DecodingKey::from_jwk(jwk)
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("invalid jwk key: {e}")))?;

    let mut validation = jsonwebtoken::Validation::new(alg);
    validation.set_issuer(&[issuer]);
    validation.set_audience(std::slice::from_ref(&provider.client_id));
    validation.required_spec_claims = ["exp", "iss", "aud", "sub"]
        .into_iter()
        .map(str::to_owned)
        .collect();
    validation.validate_exp = true;
    validation.validate_nbf = true;

    let claims = jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation)
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("invalid id_token: {e}")))?;
    Ok(claims.claims)
}
/// Builds the HTTP client used for every outbound OAuth2/OIDC call (token
/// exchange, userinfo, JWKS).
///
/// Each request is bounded by `OAUTH_HTTP_TIMEOUT_SECS`. Each request also
/// carries an `autumn-web/<version>` User-Agent by default; previously only the
/// userinfo call set one explicitly. A fresh client, with its own connection
/// pool, is built on every call.
///
/// # Errors
/// `service_unavailable` if the client cannot be constructed.
#[cfg(feature = "oauth2")]
fn oauth_http_client() -> crate::AutumnResult<reqwest::Client> {
    reqwest::Client::builder()
        .timeout(Duration::from_secs(OAUTH_HTTP_TIMEOUT_SECS))
        .user_agent(concat!("autumn-web/", env!("CARGO_PKG_VERSION")))
        .build()
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!(
                "failed to build oauth http client: {e}"
            ))
        })
}
impl Default for AuthConfig {
    /// Returns the same values serde fills in when the corresponding keys are
    /// absent: cost `DEFAULT_BCRYPT_COST`, session key `"user_id"`, and (with
    /// the `oauth2` feature) no providers.
    fn default() -> Self {
        let bcrypt_cost = default_bcrypt_cost();
        let session_key = default_session_key();
        Self {
            bcrypt_cost,
            session_key,
            #[cfg(feature = "oauth2")]
            oauth2: OAuth2Config::default(),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn hash_and_verify_password() {
let hash = hash_password("test_password").await.unwrap();
assert!(hash.starts_with("$2b$"));
assert!(verify_password("test_password", &hash).await.unwrap());
assert!(!verify_password("wrong_password", &hash).await.unwrap());
}
#[tokio::test]
async fn verify_invalid_hash_returns_false() {
let result = verify_password("test", "not-a-valid-hash").await;
assert!(result.is_ok());
assert!(!result.unwrap());
}
#[tokio::test]
async fn verify_password_rejects_invalid_hash_format_safely() {
let result = verify_password("test", "short").await;
assert!(result.is_ok());
assert!(!result.unwrap());
let bad_prefix = "a".repeat(60);
let result = verify_password("test", &bad_prefix).await;
assert!(result.is_ok());
assert!(!result.unwrap());
let bad_length = "$2b$12$short";
let result = verify_password("test", bad_length).await;
assert!(result.is_ok());
assert!(!result.unwrap());
}
#[test]
fn auth_config_defaults() {
let config = AuthConfig::default();
assert_eq!(config.bcrypt_cost, 12);
assert_eq!(config.session_key, "user_id");
#[cfg(feature = "oauth2")]
assert!(config.oauth2.providers.is_empty());
}
#[cfg(feature = "oauth2")]
#[test]
fn oauth2_config_deserializes_provider_tables() {
let cfg: crate::config::AutumnConfig = toml::from_str(
r#"
[auth.oauth2.github]
client_id = "cid"
client_secret = "secret"
authorize_url = "https://github.com/login/oauth/authorize"
token_url = "https://github.com/login/oauth/access_token"
redirect_uri = "http://localhost:3000/auth/github/callback"
"#,
)
.unwrap();
let provider = cfg.auth.oauth2.providers.get("github").unwrap();
assert_eq!(provider.client_id, "cid");
assert_eq!(provider.scope, "");
assert!(provider.issuer.is_none());
assert!(provider.jwks_url.is_none());
}
#[cfg(feature = "oauth2")]
#[tokio::test]
async fn oauth2_authorize_url_sets_state_and_nonce() {
let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
let provider = OAuth2ProviderConfig {
client_id: "cid".into(),
client_secret: "secret".into(),
authorize_url: "https://idp.example/authorize".into(),
token_url: "https://idp.example/token".into(),
userinfo_url: None,
redirect_uri: "http://localhost:3000/callback".into(),
scope: "openid profile".into(),
issuer: None,
jwks_url: None,
};
let url = oauth2_authorize_url(&session, "github", &provider)
.await
.unwrap();
assert!(url.contains("response_type=code"));
assert!(session.get("oauth2:github:state").await.is_some());
assert!(session.get("oauth2:github:nonce").await.is_some());
}
#[cfg(feature = "oauth2")]
#[tokio::test]
async fn oauth2_authorize_url_omits_scope_when_empty() {
let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
let provider = OAuth2ProviderConfig {
client_id: "cid".into(),
client_secret: "secret".into(),
authorize_url: "https://idp.example/authorize".into(),
token_url: "https://idp.example/token".into(),
userinfo_url: None,
redirect_uri: "http://localhost:3000/callback".into(),
scope: String::new(),
issuer: None,
jwks_url: None,
};
let url = oauth2_authorize_url(&session, "github", &provider)
.await
.unwrap();
assert!(!url.contains("scope="));
}
#[cfg(feature = "oauth2")]
#[tokio::test]
async fn validate_id_token_requires_oidc_metadata() {
let provider = OAuth2ProviderConfig {
client_id: "cid".into(),
client_secret: "secret".into(),
authorize_url: "https://idp.example/authorize".into(),
token_url: "https://idp.example/token".into(),
userinfo_url: None,
redirect_uri: "http://localhost:3000/callback".into(),
scope: "openid profile".into(),
issuer: None,
jwks_url: None,
};
let err = validate_and_decode_id_token("bad.token.value", &provider)
.await
.unwrap_err();
assert_eq!(err.to_string(), "provider.issuer required for oidc");
}
#[cfg(feature = "oauth2")]
#[test]
fn parse_oauth2_token_response_supports_form_encoded_payload() {
let token = parse_oauth2_token_response(
Some("application/x-www-form-urlencoded"),
"access_token=abc123&token_type=bearer",
)
.unwrap();
assert_eq!(token.access_token, "abc123");
assert_eq!(token.token_type.as_deref(), Some("bearer"));
}
#[cfg(feature = "oauth2")]
#[test]
fn extract_subject_allows_userinfo_id_fallback() {
let claims = serde_json::json!({ "id": 42 });
let subject = extract_subject(&claims, IdentitySource::UserInfo).unwrap();
assert_eq!(subject, "42");
}
#[cfg(feature = "oauth2")]
#[tokio::test]
async fn validate_callback_state_preserves_state_on_mismatch() {
let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
session
.insert("oauth2:github:state".to_owned(), "real-state".to_owned())
.await;
let bad_callback = OAuth2Callback {
code: "c".into(),
state: "wrong-state".into(),
};
let err = validate_callback_state(&session, "github", &bad_callback)
.await
.unwrap_err();
assert!(err.to_string().contains("state mismatch"));
assert_eq!(
session.get("oauth2:github:state").await.as_deref(),
Some("real-state")
);
}
#[cfg(feature = "oauth2")]
#[tokio::test]
async fn validate_oidc_nonce_rejects_missing_nonce_for_id_token() {
let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
let claims = serde_json::json!({ "nonce": "any" });
let err = validate_oidc_nonce(&session, "github", &claims, IdentitySource::IdToken)
.await
.unwrap_err();
assert!(err.to_string().contains("nonce missing from session"));
}
#[cfg(feature = "oauth2")]
#[test]
fn extract_subject_requires_sub_for_id_token() {
let claims = serde_json::json!({ "id": "abc" });
let err = extract_subject(&claims, IdentitySource::IdToken).unwrap_err();
assert_eq!(err.to_string(), "missing sub claim");
}
#[test]
fn auth_rejection_is_401() {
let rejection = AuthRejection;
let response = rejection.into_response();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[test]
fn auth_rejection_display() {
assert_eq!(AuthRejection.to_string(), "authentication required");
}
#[tokio::test]
async fn auth_extractor_returns_401_when_no_user() {
use crate::state::AppState;
use axum::Router;
use axum::body::Body;
use axum::routing::get;
use tower::ServiceExt;
#[derive(Clone)]
struct TestUser {
name: String,
}
async fn handler(Auth(user): Auth<TestUser>) -> String {
user.name
}
let state = AppState {
extensions: std::sync::Arc::new(std::sync::RwLock::new(
std::collections::HashMap::new(),
)),
#[cfg(feature = "db")]
pool: None,
profile: None,
started_at: std::time::Instant::now(),
health_detailed: false,
probes: crate::probe::ProbeState::ready_for_test(),
metrics: crate::middleware::MetricsCollector::new(),
log_levels: crate::actuator::LogLevels::new("info"),
task_registry: crate::actuator::TaskRegistry::new(),
config_props: crate::actuator::ConfigProperties::default(),
#[cfg(feature = "ws")]
channels: crate::channels::Channels::new(32),
#[cfg(feature = "ws")]
shutdown: tokio_util::sync::CancellationToken::new(),
};
let app = Router::new().route("/", get(handler)).with_state(state);
let response = app
.oneshot(
http::Request::builder()
.uri("/")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn auth_extractor_returns_user_when_present() {
use crate::state::AppState;
use axum::Router;
use axum::body::Body;
use axum::routing::get;
use tower::ServiceExt;
#[derive(Clone)]
struct TestUser {
name: String,
}
async fn handler(Auth(user): Auth<TestUser>) -> String {
user.name
}
let state = AppState {
extensions: std::sync::Arc::new(std::sync::RwLock::new(
std::collections::HashMap::new(),
)),
#[cfg(feature = "db")]
pool: None,
profile: None,
started_at: std::time::Instant::now(),
health_detailed: false,
probes: crate::probe::ProbeState::ready_for_test(),
metrics: crate::middleware::MetricsCollector::new(),
log_levels: crate::actuator::LogLevels::new("info"),
task_registry: crate::actuator::TaskRegistry::new(),
config_props: crate::actuator::ConfigProperties::default(),
#[cfg(feature = "ws")]
channels: crate::channels::Channels::new(32),
#[cfg(feature = "ws")]
shutdown: tokio_util::sync::CancellationToken::new(),
};
let app = Router::new()
.route("/", get(handler))
.layer(axum::middleware::from_fn(
|mut req: axum::extract::Request, next: axum::middleware::Next| async move {
req.extensions_mut().insert(TestUser {
name: "alice".into(),
});
next.run(req).await
},
))
.with_state(state);
let response = app
.oneshot(
http::Request::builder()
.uri("/")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
assert_eq!(std::str::from_utf8(&body).unwrap(), "alice");
}
#[tokio::test]
async fn require_auth_rejects_unauthenticated() {
use axum::Router;
use axum::body::Body;
use axum::routing::get;
use tower::ServiceExt;
use crate::session::{MemoryStore, SessionConfig, SessionLayer};
use crate::state::AppState;
let state = AppState {
extensions: std::sync::Arc::new(std::sync::RwLock::new(
std::collections::HashMap::new(),
)),
#[cfg(feature = "db")]
pool: None,
profile: None,
started_at: std::time::Instant::now(),
health_detailed: false,
probes: crate::probe::ProbeState::ready_for_test(),
metrics: crate::middleware::MetricsCollector::new(),
log_levels: crate::actuator::LogLevels::new("info"),
task_registry: crate::actuator::TaskRegistry::new(),
config_props: crate::actuator::ConfigProperties::default(),
#[cfg(feature = "ws")]
channels: crate::channels::Channels::new(32),
#[cfg(feature = "ws")]
shutdown: tokio_util::sync::CancellationToken::new(),
};
let app = Router::new()
.route("/protected", get(|| async { "secret" }))
.layer(RequireAuth::new("user_id"))
.layer(SessionLayer::new(
MemoryStore::new(),
SessionConfig::default(),
))
.with_state(state);
let response = app
.oneshot(
http::Request::builder()
.uri("/protected")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn check_secured_rejects_unauthenticated() {
let session =
crate::session::Session::new_for_test(String::new(), std::collections::HashMap::new());
let result = __check_secured(&session, &[]).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.status(), StatusCode::UNAUTHORIZED);
assert_eq!(err.to_string(), "authentication required");
}
#[tokio::test]
async fn check_secured_allows_authenticated() {
let data = std::collections::HashMap::from([("user_id".into(), "42".into())]);
let session = crate::session::Session::new_for_test("sess".into(), data);
let result = __check_secured(&session, &[]).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn check_secured_rejects_wrong_role() {
let data = std::collections::HashMap::from([
("user_id".into(), "42".into()),
("role".into(), "viewer".into()),
]);
let session = crate::session::Session::new_for_test("sess".into(), data);
let result = __check_secured(&session, &["admin"]).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.status(), StatusCode::FORBIDDEN);
assert_eq!(err.to_string(), "insufficient permissions");
}
#[tokio::test]
async fn check_secured_allows_matching_role() {
let data = std::collections::HashMap::from([
("user_id".into(), "42".into()),
("role".into(), "admin".into()),
]);
let session = crate::session::Session::new_for_test("sess".into(), data);
let result = __check_secured(&session, &["admin"]).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn check_secured_allows_any_of_multiple_roles() {
let data = std::collections::HashMap::from([
("user_id".into(), "42".into()),
("role".into(), "editor".into()),
]);
let session = crate::session::Session::new_for_test("sess".into(), data);
let result = __check_secured(&session, &["admin", "editor"]).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn secured_macro_rejects_unauthenticated() {
use axum::Router;
use axum::body::Body;
use axum::routing::get;
use tower::ServiceExt;
use crate::session::{MemoryStore, SessionConfig, SessionLayer};
use crate::state::AppState;
#[autumn_macros::secured]
async fn protected_handler() -> crate::AutumnResult<&'static str> {
Ok("secret")
}
let state = AppState {
extensions: std::sync::Arc::new(std::sync::RwLock::new(
std::collections::HashMap::new(),
)),
#[cfg(feature = "db")]
pool: None,
profile: None,
started_at: std::time::Instant::now(),
health_detailed: false,
probes: crate::probe::ProbeState::ready_for_test(),
metrics: crate::middleware::MetricsCollector::new(),
log_levels: crate::actuator::LogLevels::new("info"),
task_registry: crate::actuator::TaskRegistry::new(),
config_props: crate::actuator::ConfigProperties::default(),
#[cfg(feature = "ws")]
channels: crate::channels::Channels::new(32),
#[cfg(feature = "ws")]
shutdown: tokio_util::sync::CancellationToken::new(),
};
let app = Router::new()
.route("/", get(protected_handler))
.layer(SessionLayer::new(
MemoryStore::new(),
SessionConfig::default(),
))
.with_state(state);
let response = app
.oneshot(
http::Request::builder()
.uri("/")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn secured_macro_allows_authenticated() {
use axum::Router;
use axum::body::Body;
use axum::routing::get;
use http::header::COOKIE;
use tower::ServiceExt;
use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
use crate::state::AppState;
#[autumn_macros::secured]
async fn protected_handler() -> crate::AutumnResult<&'static str> {
Ok("secret")
}
let store = MemoryStore::new();
store
.save(
"sess1",
std::collections::HashMap::from([("user_id".into(), "42".into())]),
)
.await
.unwrap();
let state = AppState {
extensions: std::sync::Arc::new(std::sync::RwLock::new(
std::collections::HashMap::new(),
)),
#[cfg(feature = "db")]
pool: None,
profile: None,
started_at: std::time::Instant::now(),
health_detailed: false,
probes: crate::probe::ProbeState::ready_for_test(),
metrics: crate::middleware::MetricsCollector::new(),
log_levels: crate::actuator::LogLevels::new("info"),
task_registry: crate::actuator::TaskRegistry::new(),
config_props: crate::actuator::ConfigProperties::default(),
#[cfg(feature = "ws")]
channels: crate::channels::Channels::new(32),
#[cfg(feature = "ws")]
shutdown: tokio_util::sync::CancellationToken::new(),
};
let app = Router::new()
.route("/", get(protected_handler))
.layer(SessionLayer::new(store, SessionConfig::default()))
.with_state(state);
let response = app
.oneshot(
http::Request::builder()
.uri("/")
.header(COOKIE, "autumn.sid=sess1")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
assert_eq!(std::str::from_utf8(&body).unwrap(), "secret");
}
/// A session that is logged in (`user_id` present) but whose `role` is `viewer` must be
/// rejected with `403 Forbidden` by a handler guarded with `#[secured("admin")]`.
#[tokio::test]
async fn secured_macro_with_role_rejects_wrong_role() {
    use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
    use crate::state::AppState;
    use axum::Router;
    use axum::body::Body;
    use axum::routing::get;
    use http::header::COOKIE;
    use tower::ServiceExt;

    // Handler that only the `admin` role may reach.
    #[autumn_macros::secured("admin")]
    async fn admin_only() -> crate::AutumnResult<&'static str> {
        Ok("admin area")
    }

    // Authenticated session whose role does not satisfy the guard.
    let viewer_session = std::collections::HashMap::from([
        ("user_id".into(), "42".into()),
        ("role".into(), "viewer".into()),
    ]);
    let store = MemoryStore::new();
    store.save("sess1", viewer_session).await.unwrap();

    let extensions = std::sync::Arc::new(std::sync::RwLock::new(std::collections::HashMap::new()));
    let state = AppState {
        extensions,
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: false,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };

    let app = Router::new()
        .route("/", get(admin_only))
        .layer(SessionLayer::new(store, SessionConfig::default()))
        .with_state(state);

    // The cookie points the session layer at the seeded `viewer` session.
    let request = http::Request::builder()
        .uri("/")
        .header(COOKIE, "autumn.sid=sess1")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
/// A handler guarded with `#[secured("admin", "editor")]` must admit a session whose role
/// matches any one of the listed roles (here `editor`) and return the handler's body.
#[tokio::test]
async fn secured_macro_with_multiple_roles_allows_match() {
    use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
    use crate::state::AppState;
    use axum::Router;
    use axum::body::Body;
    use axum::routing::get;
    use http::header::COOKIE;
    use tower::ServiceExt;

    // Handler reachable by either `admin` or `editor`.
    #[autumn_macros::secured("admin", "editor")]
    async fn content_handler() -> crate::AutumnResult<&'static str> {
        Ok("content")
    }

    // Authenticated session carrying the second of the accepted roles.
    let editor_session = std::collections::HashMap::from([
        ("user_id".into(), "42".into()),
        ("role".into(), "editor".into()),
    ]);
    let store = MemoryStore::new();
    store.save("sess1", editor_session).await.unwrap();

    let extensions = std::sync::Arc::new(std::sync::RwLock::new(std::collections::HashMap::new()));
    let state = AppState {
        extensions,
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: false,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };

    let app = Router::new()
        .route("/", get(content_handler))
        .layer(SessionLayer::new(store, SessionConfig::default()))
        .with_state(state);

    let request = http::Request::builder()
        .uri("/")
        .header(COOKIE, "autumn.sid=sess1")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    // The guard passed, so the handler's own output must come through unchanged.
    let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = std::str::from_utf8(&payload).unwrap();
    assert_eq!(text, "content");
}
/// `RequireAuth::new("user_id")` must let a request through when its session contains the
/// `user_id` key, so the wrapped route responds normally.
#[tokio::test]
async fn require_auth_allows_authenticated() {
    use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
    use crate::state::AppState;
    use axum::Router;
    use axum::body::Body;
    use axum::routing::get;
    use http::header::COOKIE;
    use tower::ServiceExt;

    // Session that satisfies the `user_id` requirement.
    let store = MemoryStore::new();
    let logged_in = std::collections::HashMap::from([("user_id".into(), "42".into())]);
    store.save("valid-session", logged_in).await.unwrap();

    let extensions = std::sync::Arc::new(std::sync::RwLock::new(std::collections::HashMap::new()));
    let state = AppState {
        extensions,
        #[cfg(feature = "db")]
        pool: None,
        profile: None,
        started_at: std::time::Instant::now(),
        health_detailed: false,
        probes: crate::probe::ProbeState::ready_for_test(),
        metrics: crate::middleware::MetricsCollector::new(),
        log_levels: crate::actuator::LogLevels::new("info"),
        task_registry: crate::actuator::TaskRegistry::new(),
        config_props: crate::actuator::ConfigProperties::default(),
        #[cfg(feature = "ws")]
        channels: crate::channels::Channels::new(32),
        #[cfg(feature = "ws")]
        shutdown: tokio_util::sync::CancellationToken::new(),
    };

    // `SessionLayer` is added last, so it is the outer layer and runs before `RequireAuth`.
    let app = Router::new()
        .route("/protected", get(|| async { "secret" }))
        .layer(RequireAuth::new("user_id"))
        .layer(SessionLayer::new(store, SessionConfig::default()))
        .with_state(state);

    let request = http::Request::builder()
        .uri("/protected")
        .header(COOKIE, "autumn.sid=valid-session")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = std::str::from_utf8(&payload).unwrap();
    assert_eq!(text, "secret");
}
/// The service produced by `RequireAuth` must delegate readiness to the service it wraps.
/// A pending inner service keeps the wrapper pending and a ready one makes it ready. Each
/// wrapper `poll_ready` call polls the inner service exactly once.
#[tokio::test]
async fn require_auth_poll_ready_propagates() {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Context, Poll};
    use tower::{Layer, Service};

    // Inner service with a fixed readiness that records how often it was polled.
    #[derive(Clone)]
    struct CountingService {
        ready: bool,
        polls: Arc<AtomicUsize>,
    }

    impl Service<axum::extract::Request> for CountingService {
        type Response = axum::response::Response;
        type Error = std::convert::Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.polls.fetch_add(1, Ordering::SeqCst);
            match self.ready {
                true => Poll::Ready(Ok(())),
                false => Poll::Pending,
            }
        }

        fn call(&mut self, _req: axum::extract::Request) -> Self::Future {
            std::future::ready(Ok(axum::response::Response::new(axum::body::Body::empty())))
        }
    }

    let layer = RequireAuth::new("user_id");
    let polls = Arc::new(AtomicUsize::new(0));
    let waker = futures::task::noop_waker();
    let mut cx = Context::from_waker(&waker);

    // Inner service not ready: the wrapper stays pending after a single inner poll.
    let mut pending_svc = layer.layer(CountingService {
        ready: false,
        polls: Arc::clone(&polls),
    });
    assert!(pending_svc.poll_ready(&mut cx).is_pending());
    assert_eq!(polls.load(Ordering::SeqCst), 1);

    // Inner service ready: the wrapper becomes ready after exactly one more inner poll.
    let mut ready_svc = layer.layer(CountingService {
        ready: true,
        polls: Arc::clone(&polls),
    });
    assert!(ready_svc.poll_ready(&mut cx).is_ready());
    assert_eq!(polls.load(Ordering::SeqCst), 2);
}
/// `AuthRejection` must render as `401 Unauthorized` with a JSON body of the form
/// `{"error": {"status": 401, "message": "authentication required"}}`.
#[tokio::test]
async fn auth_rejection_into_response() {
    let response = AuthRejection.into_response();
    assert_eq!(response.status(), axum::http::StatusCode::UNAUTHORIZED);

    let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let payload: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    let error = &payload["error"];
    assert_eq!(error["status"], 401);
    assert_eq!(error["message"], "authentication required");
}
/// `AuthConfig::default()` must use this module's `DEFAULT_BCRYPT_COST` and the
/// `user_id` session key.
#[test]
fn test_auth_config_defaults() {
    let defaults = AuthConfig::default();
    assert_eq!(defaults.bcrypt_cost, DEFAULT_BCRYPT_COST);
    assert_eq!(defaults.session_key, "user_id");
}
/// Round trip through `hash_password` / `verify_password`. The produced hash carries the
/// `$2b$` prefix, the original password verifies, and an unrelated password does not.
#[tokio::test]
async fn test_hash_password() {
    let secret = uuid::Uuid::new_v4().to_string();
    let hashed = super::hash_password(&secret)
        .await
        .expect("Failed to hash password");
    assert!(hashed.starts_with("$2b$"));

    let accepts_secret = super::verify_password(&secret, &hashed)
        .await
        .expect("Failed to verify password");
    assert!(accepts_secret, "Password should be verified successfully");

    let impostor = uuid::Uuid::new_v4().to_string();
    let accepts_impostor = super::verify_password(&impostor, &hashed)
        .await
        .expect("Failed to verify wrong password");
    assert!(!accepts_impostor, "Wrong password should not be verified");
}
/// The empty password is hashable and verifies against its own hash.
#[tokio::test]
async fn test_hash_password_empty() {
    let empty = "";
    let hashed = super::hash_password(empty)
        .await
        .expect("Failed to hash empty password");
    assert!(hashed.starts_with("$2b$"));

    let accepted = super::verify_password(empty, &hashed)
        .await
        .expect("Failed to verify empty password");
    assert!(accepted, "Empty password should be verified successfully");
}
/// A 100-byte ASCII password is hashable and verifies against its own hash.
///
/// NOTE(review): bcrypt only consumes the first 72 bytes of input, so this presumably
/// exercises truncation rather than full-length hashing. Confirm against the pinned
/// `bcrypt` crate version before relying on it.
#[tokio::test]
async fn test_hash_password_long() {
    let long_secret = "a".repeat(100);
    let hashed = super::hash_password(&long_secret)
        .await
        .expect("Failed to hash long password");
    assert!(hashed.starts_with("$2b$"));

    let accepted = super::verify_password(&long_secret, &hashed)
        .await
        .expect("Failed to verify long password");
    assert!(accepted, "Long password should be verified successfully");
}
/// A password containing multi-byte UTF-8 (emoji, accented letters) round-trips through
/// `hash_password` / `verify_password`.
#[tokio::test]
async fn test_hash_password_unicode() {
    let unicode_secret = format!("{}🚀my_secrët_passwörd🔑", uuid::Uuid::new_v4());
    let hashed = super::hash_password(&unicode_secret)
        .await
        .expect("Failed to hash unicode password");
    assert!(hashed.starts_with("$2b$"));

    let accepted = super::verify_password(&unicode_secret, &hashed)
        .await
        .expect("Failed to verify unicode password");
    assert!(accepted, "Unicode password should be verified successfully");
}
/// Malformed hashes must be rejected with `Ok(false)`, not with an error.
///
/// `verify_password` treats any hash that is not exactly 60 characters long and
/// `$`-prefixed as malformed. It still runs a bcrypt verification against a built-in dummy
/// hash (presumably so malformed input costs about as much time as real input), then
/// discards that result and returns `Ok(false)`. The previous assertion also accepted
/// `Err`, which would have hidden a regression on that path. This test pins the exact
/// outcome instead.
#[tokio::test]
async fn test_verify_password_invalid_hash() {
    let test_input = uuid::Uuid::new_v4().to_string();
    for malformed in ["invalid_hash_string", "$2b$04$"] {
        let result = super::verify_password(&test_input, malformed).await;
        assert!(
            matches!(result, Ok(false)),
            "malformed hash {malformed:?} must be rejected with Ok(false), got {result:?}"
        );
    }
}
}