Skip to main content

autumn_web/
auth.rs

1//! Authentication utilities for Autumn applications.
2//!
3//! Provides password hashing, an [`Auth<T>`] extractor for retrieving the
4//! authenticated user, and a [`RequireAuth`] middleware layer for protecting
5//! routes.
6//!
7//! ## Quick start
8//!
9//! ```rust,no_run
10//! use autumn_web::prelude::*;
11//! use autumn_web::auth::{Auth, hash_password, verify_password};
12//! use autumn_web::session::Session;
13//!
14//! #[derive(Clone)]
15//! struct User { id: i64, name: String }
16//!
17//! #[post("/register")]
18//! async fn register() -> AutumnResult<&'static str> {
19//!     let hashed = hash_password("secret123").await?;
20//!     // Save hashed password to database...
21//!     Ok("registered")
22//! }
23//!
24//! #[post("/login")]
25//! async fn login(session: Session) -> AutumnResult<&'static str> {
26//!     // Verify credentials...
27//!     let stored_hash = "$2b$12$..."; // from database
28//!     if verify_password("secret123", stored_hash).await? {
29//!         session.insert("user_id", "42").await;
30//!         Ok("logged in")
31//!     } else {
32//!         Err(AutumnError::bad_request_msg("invalid credentials"))
33//!     }
34//! }
35//! ```
36//!
37//! ## Password hashing
38//!
//! Uses bcrypt with a default cost of 12. The [`hash_password`] and
//! [`verify_password`] functions are thin async wrappers that run bcrypt on
//! Tokio's blocking thread pool and return [`AutumnResult`](crate::AutumnResult).
42//!
43//! ## The `Auth<T>` extractor
44//!
//! [`Auth<T>`] extracts the authenticated user from request extensions.
//! It is typically populated by custom middleware that loads the user (for
//! example, from the session) and calls `request.extensions_mut().insert(user)`
//! before the handler runs. Returns `401 Unauthorized` if no user is present.
49//!
50//! ## Route protection with `RequireAuth`
51//!
//! The [`RequireAuth`] layer rejects unauthenticated requests with
//! `401 Unauthorized` before they reach the handler. It checks for the
//! presence of the session key passed to [`RequireAuth::new`] (conventionally
//! `"user_id"`, the default of [`AuthConfig::session_key`]).
55
56#[cfg(feature = "oauth2")]
57use std::collections::HashMap;
58use std::future::Future;
59use std::pin::Pin;
60use std::sync::Arc;
61use std::task::{Context, Poll};
62#[cfg(feature = "oauth2")]
63use std::time::Duration;
64
65use axum::extract::FromRequestParts;
66use axum::response::{IntoResponse, Response};
67use http::StatusCode;
68use http::request::Parts;
69#[cfg(feature = "oauth2")]
70use jsonwebtoken::jwk::JwkSet;
71#[cfg(feature = "oauth2")]
72use serde::Deserialize;
73#[cfg(feature = "oauth2")]
74use url::Url;
75
76// ── Password hashing ────────────────────────────────────────────
77
/// Default bcrypt cost factor.
///
/// Used by [`hash_password`], and returned by `default_bcrypt_cost`, which is the
/// serde and `Default` value of `AuthConfig::bcrypt_cost`.
const DEFAULT_BCRYPT_COST: u32 = 12;
80
81/// Hash a plaintext password using bcrypt.
82///
83/// Returns the hashed password string suitable for database storage.
84///
85/// # Errors
86///
87/// Returns an error if bcrypt hashing fails (extremely unlikely).
88///
89/// # Examples
90///
91/// ```rust
92/// use autumn_web::auth::hash_password;
93///
94/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
95/// let hashed = hash_password("my_secret").await.unwrap();
96/// assert!(hashed.starts_with("$2b$"));
97/// # });
98/// ```
99pub async fn hash_password(password: &str) -> crate::AutumnResult<String> {
100    let password = password.to_string();
101    tokio::task::spawn_blocking(move || {
102        bcrypt::hash(password, DEFAULT_BCRYPT_COST)
103            .map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))
104    })
105    .await
106    .map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))?
107}
108
109/// Verify a plaintext password against a bcrypt hash.
110///
111/// Returns `true` if the password matches the hash.
112///
113/// # Errors
114///
115/// Returns an error if bcrypt verification fails (e.g., invalid hash format).
116///
117/// # Examples
118///
119/// ```rust
120/// use autumn_web::auth::{hash_password, verify_password};
121///
122/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
123/// let hashed = hash_password("my_secret").await.unwrap();
124/// assert!(verify_password("my_secret", &hashed).await.unwrap());
125/// assert!(!verify_password("wrong_password", &hashed).await.unwrap());
126/// # });
127/// ```
128pub async fn verify_password(password: &str, hash: &str) -> crate::AutumnResult<bool> {
129    let password = password.to_string();
130
131    // Parse the hash format outside the blocking task.
132    // A valid bcrypt hash is typically 60 characters and starts with "$".
133    let is_valid_format = hash.len() == 60 && hash.starts_with('$');
134
135    let hash_to_verify = if is_valid_format {
136        hash.to_string()
137    } else {
138        // To prevent timing attacks, perform a dummy verification against a known hash.
139        "$2b$12$KIXe8K4j1sH6/xH.x9d71uJ5Jk8t6O4m6Q110g4H8y1r6J6O6O6O6".to_string()
140    };
141
142    let result = tokio::task::spawn_blocking(move || bcrypt::verify(&password, &hash_to_verify))
143        .await
144        .map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))?;
145
146    if !is_valid_format {
147        return Ok(false);
148    }
149
150    result.map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))
151}
152
153// ── Runtime check for #[secured] macro ──────────────────────────
154
155/// Runtime authentication and authorization check used by the
156/// `#[secured]` proc macro. **Not intended for direct use** -- use
157/// `#[secured]` instead.
158///
159/// Checks the session for the configured auth key (default: `"user_id"`).
160/// If `roles` is non-empty, also checks that the session's `"role"` value
161/// matches at least one of the given roles.
162///
163/// Returns `401 Unauthorized` if not authenticated, or `403 Forbidden`
164/// if the user lacks the required role.
165#[doc(hidden)]
166pub async fn __check_secured(
167    session: &crate::session::Session,
168    roles: &[&str],
169) -> crate::AutumnResult<()> {
170    __check_secured_with_key(session, "user_id", roles).await
171}
172
173/// Runtime check used by `#[secured]` when `AppState` is available.
174///
175/// Accepts the configured auth session key so generated login/signup/reset
176/// handlers and `#[secured]` resolve authentication through the same session
177/// entry.
178#[doc(hidden)]
179pub async fn __check_secured_with_key(
180    session: &crate::session::Session,
181    auth_session_key: &str,
182    roles: &[&str],
183) -> crate::AutumnResult<()> {
184    // Check authentication: session must contain the auth key
185    if session.get(auth_session_key).await.is_none() {
186        return Err(crate::AutumnError::unauthorized_msg(
187            "authentication required",
188        ));
189    }
190
191    // Check authorization: if roles are specified, the session's "role"
192    // must match at least one of them
193    if !roles.is_empty() {
194        let user_role = session.get("role").await.unwrap_or_default();
195        if !roles.iter().any(|&r| r == user_role) {
196            return Err(crate::AutumnError::forbidden_msg(
197                "insufficient permissions",
198            ));
199        }
200    }
201
202    Ok(())
203}
204
205// ── Auth<T> extractor ───────────────────────────────────────────
206
207/// Extractor that retrieves the authenticated user from request extensions.
208///
209/// Handlers can declare `Auth<MyUser>` as a parameter to access the
210/// current user. If no user is present in the request extensions,
211/// a `401 Unauthorized` response is returned automatically.
212///
213/// ## Populating the user
214///
215/// The user is typically inserted into request extensions by middleware.
216/// For example, a custom middleware can load the user from the session
217/// and call `request.extensions_mut().insert(user)`.
218///
219/// ## Examples
220///
221/// ```rust,no_run
222/// use autumn_web::prelude::*;
223/// use autumn_web::auth::Auth;
224///
225/// #[derive(Clone)]
226/// struct CurrentUser { id: i64, name: String }
227///
228/// #[get("/profile")]
229/// async fn profile(Auth(user): Auth<CurrentUser>) -> String {
230///     format!("Hello, {}!", user.name)
231/// }
232/// ```
233pub struct Auth<T>(pub T);
234
235impl<T, S> FromRequestParts<S> for Auth<T>
236where
237    T: Clone + Send + Sync + 'static,
238    S: Send + Sync,
239{
240    type Rejection = AuthRejection;
241
242    fn from_request_parts(
243        parts: &mut Parts,
244        _state: &S,
245    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
246        let user = parts.extensions.get::<T>().cloned();
247        async move { user.map_or_else(|| Err(AuthRejection), |user| Ok(Self(user))) }
248    }
249}
250
251/// Rejection type for [`Auth<T>`] when no authenticated user is present.
252#[derive(Debug)]
253pub struct AuthRejection;
254
255impl IntoResponse for AuthRejection {
256    fn into_response(self) -> Response {
257        crate::AutumnError::unauthorized_msg("authentication required").into_response()
258    }
259}
260
261impl std::fmt::Display for AuthRejection {
262    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263        f.write_str("authentication required")
264    }
265}
266
267// ── RequireAuth middleware ───────────────────────────────────────
268
/// Tower [`tower::Layer`] that answers unauthenticated requests with `401`.
///
/// A request counts as authenticated when its session contains the key given
/// to [`RequireAuth::new`]. Requests without that key, or without a session
/// at all, are rejected before they reach the wrapped handler.
///
/// # Examples
///
/// ```rust,no_run
/// use autumn_web::auth::RequireAuth;
/// use autumn_web::reexports::axum::{Router, routing::get};
/// use autumn_web::AppState;
///
/// // Protect all routes under /admin
/// let admin_routes = Router::<AppState>::new()
///     .route("/dashboard", get(|| async { "admin" }))
///     .layer(RequireAuth::new("user_id"));
/// ```
#[derive(Clone)]
pub struct RequireAuth {
    /// Session key whose presence marks a request as authenticated. Stored as
    /// `Arc<str>` so every produced service shares it without copying.
    session_key: Arc<str>,
}

impl RequireAuth {
    /// Create a layer that requires `session_key` to be present in the session.
    pub fn new(session_key: impl Into<String>) -> Self {
        let owned: String = session_key.into();
        Self {
            session_key: owned.into(),
        }
    }
}
300
301impl<S> tower::Layer<S> for RequireAuth {
302    type Service = RequireAuthService<S>;
303
304    fn layer(&self, inner: S) -> Self::Service {
305        RequireAuthService {
306            inner,
307            session_key: Arc::clone(&self.session_key),
308        }
309    }
310}
311
312/// Tower [`tower::Service`] produced by [`RequireAuth`].
313#[derive(Clone)]
314pub struct RequireAuthService<S> {
315    inner: S,
316    session_key: Arc<str>,
317}
318
319impl<S, ResBody> tower::Service<axum::extract::Request> for RequireAuthService<S>
320where
321    S: tower::Service<axum::extract::Request, Response = Response<ResBody>>
322        + Clone
323        + Send
324        + 'static,
325    S::Future: Send + 'static,
326    S::Error: Send + 'static,
327    ResBody: From<String> + Default + Send + 'static,
328{
329    type Response = Response<ResBody>;
330    type Error = S::Error;
331    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
332
333    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
334        self.inner.poll_ready(cx)
335    }
336
337    fn call(&mut self, req: axum::extract::Request) -> Self::Future {
338        let session_key = Arc::clone(&self.session_key);
339        let mut inner = self.inner.clone();
340        std::mem::swap(&mut self.inner, &mut inner);
341
342        Box::pin(async move {
343            // Check if session has the required key
344            let session = req.extensions().get::<crate::session::Session>().cloned();
345
346            let is_authenticated = if let Some(ref session) = session {
347                session.contains_key(&session_key).await
348            } else {
349                false
350            };
351
352            if is_authenticated {
353                inner.call(req).await
354            } else {
355                let body = crate::error::problem_details_json_string(
356                    StatusCode::UNAUTHORIZED,
357                    "authentication required",
358                    None,
359                    None,
360                    req.extensions()
361                        .get::<crate::middleware::RequestId>()
362                        .map(std::string::ToString::to_string),
363                    Some(req.uri().path().to_owned()),
364                    true,
365                );
366                let response = Response::builder()
367                    .status(StatusCode::UNAUTHORIZED)
368                    .header(http::header::CONTENT_TYPE, "application/problem+json")
369                    .body(ResBody::from(body))
370                    .unwrap_or_default();
371                Ok(response)
372            }
373        })
374    }
375}
376
377// ── Auth configuration ──────────────────────────────────────────
378
/// Configuration for authentication.
///
/// Every field has a serde default, so an empty auth table deserializes to
/// the same value as [`AuthConfig::default`].
///
/// # Defaults
///
/// | Field | Default |
/// |-------|---------|
/// | `bcrypt_cost` | `12` |
/// | `session_key` | `"user_id"` |
/// | `oauth2` | no providers (field exists only with the `oauth2` feature) |
#[derive(Debug, Clone, serde::Deserialize)]
pub struct AuthConfig {
    /// Bcrypt cost factor for password hashing.
    // NOTE(review): `hash_password` in this module hashes with the built-in
    // default cost rather than reading this field — confirm where it is applied.
    #[serde(default = "default_bcrypt_cost")]
    pub bcrypt_cost: u32,

    /// Session key used to identify authenticated users.
    ///
    /// Generated login/signup/reset handlers and `#[secured]` are meant to
    /// read and write authentication state under this same key.
    #[serde(default = "default_session_key")]
    pub session_key: String,

    /// OAuth2/OIDC provider configuration by provider key
    /// (for example: `github`, `google`, `okta`).
    #[cfg(feature = "oauth2")]
    #[serde(default)]
    pub oauth2: OAuth2Config,
}
403
/// Serde default for [`AuthConfig::bcrypt_cost`]; also used by `AuthConfig::default`.
const fn default_bcrypt_cost() -> u32 {
    DEFAULT_BCRYPT_COST
}
407
/// Serde default for [`AuthConfig::session_key`]: the conventional `"user_id"`
/// key. Also used by `AuthConfig::default`.
fn default_session_key() -> String {
    String::from("user_id")
}
411
#[cfg(feature = "oauth2")]
/// Serde default for [`OAuth2ProviderConfig::scope`]: an empty string, which
/// makes [`oauth2_authorize_url`] omit the `scope` query parameter.
const fn default_provider_scope() -> String {
    String::new()
}

#[cfg(feature = "oauth2")]
/// Timeout, in seconds, applied to every outbound `OAuth2`/OIDC HTTP request
/// (token exchange, userinfo and JWKS fetches) through `oauth_http_client`.
const OAUTH_HTTP_TIMEOUT_SECS: u64 = 15;
419
#[cfg(feature = "oauth2")]
/// `OAuth2` provider map loaded from `autumn.toml`.
///
/// `providers` is `#[serde(flatten)]`ed, so each provider is a sub-table
/// directly under `[auth.oauth2]`, named by its provider key. There is no
/// intermediate `providers` table.
///
/// Example:
///
/// ```toml
/// [auth.oauth2.github]
/// client_id = "..."
/// client_secret = "..."
/// authorize_url = "https://github.com/login/oauth/authorize"
/// token_url = "https://github.com/login/oauth/access_token"
/// userinfo_url = "https://api.github.com/user"
/// redirect_uri = "http://localhost:3000/auth/github/callback"
/// scope = "read:user user:email"
/// ```
#[derive(Debug, Clone, Default, serde::Deserialize)]
pub struct OAuth2Config {
    /// Dynamic provider table keyed by provider name (for example `github`).
    #[serde(flatten)]
    pub providers: HashMap<String, OAuth2ProviderConfig>,
}
441
#[cfg(feature = "oauth2")]
/// A single OAuth2/OIDC provider configuration entry.
///
/// `Debug` is implemented by hand so that `client_secret` is printed as
/// `[redacted]`. Debug-printing the configuration, including through
/// `AuthConfig`, therefore never writes the secret to logs.
#[derive(Clone, serde::Deserialize)]
pub struct OAuth2ProviderConfig {
    /// The client ID provided by the `OAuth2` identity provider.
    pub client_id: String,
    /// The client secret provided by the `OAuth2` identity provider.
    pub client_secret: String,
    /// The authorization endpoint URL where users are redirected to authenticate.
    pub authorize_url: String,
    /// The token endpoint URL used to exchange an authorization code for tokens.
    pub token_url: String,
    /// The optional userinfo endpoint URL used to fetch profile details.
    #[serde(default)]
    pub userinfo_url: Option<String>,
    /// The local redirect URI registered with the identity provider (e.g., `http://localhost/auth/callback`).
    pub redirect_uri: String,
    /// The requested scope string (e.g., `openid profile email`).
    #[serde(default = "default_provider_scope")]
    pub scope: String,
    /// Expected OIDC issuer (`iss`) used for ID token validation.
    #[serde(default)]
    pub issuer: Option<String>,
    /// JWKS endpoint URL used to verify ID token signatures.
    #[serde(default)]
    pub jwks_url: Option<String>,
}

#[cfg(feature = "oauth2")]
impl std::fmt::Debug for OAuth2ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2ProviderConfig")
            .field("client_id", &self.client_id)
            // Never expose the secret in diagnostics.
            .field("client_secret", &"[redacted]")
            .field("authorize_url", &self.authorize_url)
            .field("token_url", &self.token_url)
            .field("userinfo_url", &self.userinfo_url)
            .field("redirect_uri", &self.redirect_uri)
            .field("scope", &self.scope)
            .field("issuer", &self.issuer)
            .field("jwks_url", &self.jwks_url)
            .finish()
    }
}
469
#[cfg(feature = "oauth2")]
/// Query extractor payload for `OAuth2` callback handlers.
///
/// Deserialize the provider's redirect query into this type and pass it to
/// [`oauth2_finish_login`], which checks `state` and redeems `code`.
// NOTE(review): both fields are required, so a callback query without `code`
// (e.g. a provider error redirect) will fail to deserialize. Also, the derived
// `Debug` prints the authorization code, so avoid debug-logging this value.
#[derive(Debug, Clone, Deserialize)]
pub struct OAuth2Callback {
    /// The authorization code returned by the provider.
    pub code: String,
    /// The anti-CSRF state token passed during the authorization request.
    pub state: String,
}
479
#[cfg(feature = "oauth2")]
/// Identity information extracted from an OIDC ID token or userinfo endpoint.
///
/// Returned by [`oauth2_finish_login`]. The convenience fields are copied from
/// string-valued claims of the same name. Claims that are absent or not
/// strings become `None`.
#[derive(Debug, Clone)]
pub struct OidcIdentity {
    /// The primary subject identifier representing the user. This is the `sub`
    /// claim, or for userinfo responses without `sub`, the provider's `id`
    /// field rendered as a string.
    pub subject: String,
    /// The user's email address, if available in the claims.
    pub email: Option<String>,
    /// The user's full name, if available in the claims.
    pub name: Option<String>,
    /// The user's preferred username or nickname, if available in the claims.
    pub preferred_username: Option<String>,
    /// The raw JSON claims extracted from the token or userinfo response.
    pub raw_claims: serde_json::Value,
}
495
#[cfg(feature = "oauth2")]
/// The token-endpoint response fields this module uses.
///
/// `Debug` is implemented by hand so that bearer credentials (`access_token`,
/// `id_token`) are shown as `[redacted]` if the value is ever debug-printed.
#[derive(Deserialize)]
struct OAuth2TokenResponse {
    /// Bearer token sent to the userinfo endpoint.
    access_token: String,
    /// Token type reported by the provider; only surfaced in `Debug` output.
    token_type: Option<String>,
    /// OIDC ID token, validated and decoded when present.
    id_token: Option<String>,
}

#[cfg(feature = "oauth2")]
impl std::fmt::Debug for OAuth2TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2TokenResponse")
            .field("access_token", &"[redacted]")
            .field("token_type", &self.token_type)
            // Show whether an ID token exists without revealing its contents.
            .field("id_token", &self.id_token.as_ref().map(|_| "[redacted]"))
            .finish()
    }
}
504
#[cfg(feature = "oauth2")]
/// Build an `OAuth2` authorization URL and persist anti-CSRF state + nonce in session.
///
/// Two fresh random UUIDv4 values are generated. They are stored in the
/// session under `oauth2:{provider_name}:state` and `oauth2:{provider_name}:nonce`,
/// and appended to the URL as the `state` and `nonce` query parameters,
/// together with `response_type=code`, `client_id`, `redirect_uri` and, when
/// non-blank, `scope`.
///
/// # Errors
///
/// Returns an error if `authorize_url` is not a valid URL. In that case the
/// session is left untouched.
pub async fn oauth2_authorize_url(
    session: &crate::session::Session,
    provider_name: &str,
    provider: &OAuth2ProviderConfig,
) -> crate::AutumnResult<String> {
    // Validate the configured endpoint before touching the session, so a
    // failure does not leave an orphaned state/nonce behind.
    let mut url = Url::parse(&provider.authorize_url)
        .map_err(|e| crate::AutumnError::bad_request_msg(format!("invalid authorize_url: {e}")))?;

    let state = uuid::Uuid::new_v4().to_string();
    let nonce = uuid::Uuid::new_v4().to_string();

    {
        let mut query = url.query_pairs_mut();
        query
            .append_pair("response_type", "code")
            .append_pair("client_id", &provider.client_id)
            .append_pair("redirect_uri", &provider.redirect_uri);
        if !provider.scope.trim().is_empty() {
            query.append_pair("scope", &provider.scope);
        }
        query
            .append_pair("state", &state)
            .append_pair("nonce", &nonce);
    }

    session
        .insert(format!("oauth2:{provider_name}:state"), state)
        .await;
    session
        .insert(format!("oauth2:{provider_name}:nonce"), nonce)
        .await;

    Ok(url.into())
}
540
#[cfg(feature = "oauth2")]
/// Complete an `OAuth2`/OIDC login from the provider's callback.
///
/// The steps run in order:
/// 1. check the callback `state` against the value stored by
///    [`oauth2_authorize_url`];
/// 2. exchange the authorization `code` at the token endpoint;
/// 3. load identity claims from a verified ID token, or else from userinfo;
/// 4. for ID tokens, check the `nonce` claim against the stored nonce;
/// 5. pick the subject identifier and write the session.
///
/// On success the session receives `session_key` (the subject) and
/// `auth_provider` (the provider key, like `github`), and its ID is rotated.
///
/// # Errors
///
/// Returns an error when callback state/nonce validation fails, token exchange
/// fails, ID token/userinfo payloads are invalid, or identity extraction fails.
pub async fn oauth2_finish_login(
    session: &crate::session::Session,
    session_key: &str,
    provider_name: &str,
    provider: &OAuth2ProviderConfig,
    callback: &OAuth2Callback,
) -> crate::AutumnResult<OidcIdentity> {
    // CSRF check first: a forged callback never reaches the provider.
    validate_callback_state(session, provider_name, callback).await?;

    let token_response = exchange_oauth2_token(provider, callback).await?;
    let (claims, identity_source) = load_identity_claims(provider, &token_response).await?;

    validate_oidc_nonce(session, provider_name, &claims, identity_source).await?;
    let subject = extract_subject(&claims, identity_source)?;

    finalize_oauth2_session(session, session_key, provider_name, subject, claims).await
}
566
#[cfg(feature = "oauth2")]
/// Check that the callback's `state` equals the value stored for this
/// provider by [`oauth2_authorize_url`].
///
/// The stored state is removed only after a successful match. A stray or
/// attacker-controlled callback carrying a wrong state therefore cannot
/// consume it and break the user's pending, legitimate redirect.
///
/// # Errors
///
/// Unauthorized if no state is stored for the provider, or if the values differ.
async fn validate_callback_state(
    session: &crate::session::Session,
    provider_name: &str,
    callback: &OAuth2Callback,
) -> crate::AutumnResult<()> {
    let state_key = format!("oauth2:{provider_name}:state");

    let Some(expected_state) = session.get(&state_key).await else {
        return Err(crate::AutumnError::unauthorized_msg(
            "oauth2 state missing; restart login",
        ));
    };

    // Constant-time comparison, so timing does not reveal matching prefixes.
    let states_match =
        subtle::ConstantTimeEq::ct_eq(expected_state.as_bytes(), callback.state.as_bytes())
            .unwrap_u8()
            == 1;
    if !states_match {
        return Err(crate::AutumnError::unauthorized_msg(
            "oauth2 state mismatch",
        ));
    }

    session.remove(&state_key).await;
    Ok(())
}
592
#[cfg(feature = "oauth2")]
/// Redeem the callback's authorization `code` at the provider's token endpoint.
///
/// Client credentials travel in the form body (`client_id` + `client_secret`)
/// and JSON is requested through `Accept`. The raw body and content type are
/// then handed to `parse_oauth2_token_response`, which also understands
/// form-encoded replies.
///
/// # Errors
///
/// - Service unavailable if the request cannot be sent.
/// - Unauthorized if the endpoint answers with an error status.
/// - Bad request if the body cannot be read or parsed.
async fn exchange_oauth2_token(
    provider: &OAuth2ProviderConfig,
    callback: &OAuth2Callback,
) -> crate::AutumnResult<OAuth2TokenResponse> {
    let client = oauth_http_client()?;
    let form = [
        ("grant_type", "authorization_code"),
        ("code", callback.code.as_str()),
        ("redirect_uri", provider.redirect_uri.as_str()),
        ("client_id", provider.client_id.as_str()),
        ("client_secret", provider.client_secret.as_str()),
    ];

    let sent = client
        .post(&provider.token_url)
        .header(reqwest::header::ACCEPT, "application/json")
        .form(&form)
        .send()
        .await
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!("token request failed: {e}"))
        })?;
    let response = sent.error_for_status().map_err(|e| {
        crate::AutumnError::unauthorized_msg(format!("token exchange failed: {e}"))
    })?;

    let content_type = response
        .headers()
        .get(reqwest::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .map(String::from);
    let body = response.text().await.map_err(|e| {
        crate::AutumnError::bad_request_msg(format!("invalid token response body: {e}"))
    })?;

    parse_oauth2_token_response(content_type.as_deref(), &body)
}
626
#[cfg(feature = "oauth2")]
/// Obtain the identity claims for a completed token exchange.
///
/// An ID token in the response wins: it is signature-checked and decoded. If
/// there is none, the provider's userinfo endpoint is queried with the access
/// token. The returned [`IdentitySource`] records which path was taken.
///
/// # Errors
///
/// Fails if ID-token validation fails, if the userinfo request fails or
/// returns invalid JSON, or if neither an ID token nor a `userinfo_url` is
/// available.
async fn load_identity_claims(
    provider: &OAuth2ProviderConfig,
    token: &OAuth2TokenResponse,
) -> crate::AutumnResult<(serde_json::Value, IdentitySource)> {
    if let Some(id_token) = token.id_token.as_deref() {
        let claims = validate_and_decode_id_token(id_token, provider).await?;
        return Ok((claims, IdentitySource::IdToken));
    }

    let Some(userinfo_url) = &provider.userinfo_url else {
        return Err(crate::AutumnError::bad_request_msg(
            "provider must return id_token or configure userinfo_url",
        ));
    };

    // Some userinfo APIs reject requests that carry no User-Agent.
    const USER_AGENT: &str = concat!("autumn-web/", env!("CARGO_PKG_VERSION"));

    let response = oauth_http_client()?
        .get(userinfo_url)
        .header(reqwest::header::USER_AGENT, USER_AGENT)
        .bearer_auth(&token.access_token)
        .send()
        .await
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!("userinfo request failed: {e}"))
        })?
        .error_for_status()
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("userinfo failed: {e}")))?;

    let claims: serde_json::Value = response.json().await.map_err(|e| {
        crate::AutumnError::bad_request_msg(format!("invalid userinfo payload: {e}"))
    })?;
    Ok((claims, IdentitySource::UserInfo))
}
664
#[cfg(feature = "oauth2")]
/// Consume the stored OIDC nonce and, for ID-token logins, require the
/// token's `nonce` claim to equal it.
///
/// The nonce entry is removed from the session on every call, whatever the
/// identity source, so it can be used only once.
///
/// # Errors
///
/// For [`IdentitySource::IdToken`] only: unauthorized if the session holds no
/// nonce, if the claims lack a string `nonce`, or if the values differ.
async fn validate_oidc_nonce(
    session: &crate::session::Session,
    provider_name: &str,
    claims: &serde_json::Value,
    source: IdentitySource,
) -> crate::AutumnResult<()> {
    let nonce_key = format!("oauth2:{provider_name}:nonce");
    let stored_nonce = session.remove(&nonce_key).await;

    // Userinfo payloads carry no nonce, so there is nothing to compare.
    if source != IdentitySource::IdToken {
        return Ok(());
    }

    // A missing nonce (e.g. a partially cleared session) is an error rather
    // than a pass, to prevent replay/mix-up attacks.
    let Some(expected_nonce) = stored_nonce else {
        return Err(crate::AutumnError::unauthorized_msg(
            "oauth2 nonce missing from session",
        ));
    };
    let Some(actual_nonce) = claims.get("nonce").and_then(serde_json::Value::as_str) else {
        return Err(crate::AutumnError::unauthorized_msg("missing oidc nonce claim"));
    };

    let nonces_match =
        subtle::ConstantTimeEq::ct_eq(expected_nonce.as_bytes(), actual_nonce.as_bytes())
            .unwrap_u8()
            == 1;
    if nonces_match {
        Ok(())
    } else {
        Err(crate::AutumnError::unauthorized_msg("oidc nonce mismatch"))
    }
}
694
#[cfg(feature = "oauth2")]
/// Record the login in the session and build the caller-facing identity.
///
/// Writes `session_key` (the subject) and `auth_provider`, then rotates the
/// session ID. `email`, `name` and `preferred_username` are copied from
/// string claims of the same name, when present.
async fn finalize_oauth2_session(
    session: &crate::session::Session,
    session_key: &str,
    provider_name: &str,
    subject: String,
    claims: serde_json::Value,
) -> crate::AutumnResult<OidcIdentity> {
    session.insert(session_key, subject.clone()).await;
    session.insert("auth_provider", provider_name).await;
    session.rotate_id().await;

    let string_claim = |claim: &str| -> Option<String> {
        claims
            .get(claim)
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    };
    let email = string_claim("email");
    let name = string_claim("name");
    let preferred_username = string_claim("preferred_username");

    Ok(OidcIdentity {
        subject,
        email,
        name,
        preferred_username,
        raw_claims: claims,
    })
}
723
#[cfg(feature = "oauth2")]
/// Parse the body of a token-endpoint response.
///
/// The body is treated as JSON when the `Content-Type` mentions
/// `application/json` (compared case-insensitively, as media types are) or
/// when the body starts with `{`. Otherwise it is parsed as
/// `application/x-www-form-urlencoded`.
///
/// Some providers report token errors with a 2xx status and an RFC 6749
/// `error` field in the body. That case is surfaced as
/// `token exchange failed: <error> (<error_description>)`, matching the
/// message used for error statuses, instead of a misleading
/// "missing access_token".
///
/// # Errors
///
/// - Unauthorized if the body carries an `error` field.
/// - Bad request if JSON is malformed or lacks required fields, or if a
///   form-encoded body has no `access_token`.
fn parse_oauth2_token_response(
    content_type: Option<&str>,
    body: &str,
) -> crate::AutumnResult<OAuth2TokenResponse> {
    /// Build the error for an RFC 6749 error payload returned with a 2xx status.
    fn token_endpoint_error(error: &str, description: Option<&str>) -> crate::AutumnError {
        let detail = match description {
            Some(text) if !text.trim().is_empty() => format!("{error} ({text})"),
            _ => error.to_owned(),
        };
        crate::AutumnError::unauthorized_msg(format!("token exchange failed: {detail}"))
    }

    let looks_like_json = content_type
        .is_some_and(|v| v.to_ascii_lowercase().contains("application/json"))
        || body.trim_start().starts_with('{');

    if looks_like_json {
        let value: serde_json::Value = serde_json::from_str(body).map_err(|e| {
            crate::AutumnError::bad_request_msg(format!("invalid json token response: {e}"))
        })?;
        if let Some(error) = value.get("error").and_then(serde_json::Value::as_str) {
            let description = value
                .get("error_description")
                .and_then(serde_json::Value::as_str);
            return Err(token_endpoint_error(error, description));
        }
        return serde_json::from_value(value).map_err(|e| {
            crate::AutumnError::bad_request_msg(format!("invalid json token response: {e}"))
        });
    }

    let mut access_token = None;
    let mut token_type = None;
    let mut id_token = None;
    let mut error = None;
    let mut error_description = None;

    for (k, v) in url::form_urlencoded::parse(body.as_bytes()) {
        match k.as_ref() {
            "access_token" => access_token = Some(v.into_owned()),
            "token_type" => token_type = Some(v.into_owned()),
            "id_token" => id_token = Some(v.into_owned()),
            "error" => error = Some(v.into_owned()),
            "error_description" => error_description = Some(v.into_owned()),
            _ => {}
        }
    }

    if let Some(error) = error.as_deref() {
        return Err(token_endpoint_error(error, error_description.as_deref()));
    }

    let access_token = access_token.ok_or_else(|| {
        crate::AutumnError::bad_request_msg("token response missing access_token")
    })?;

    Ok(OAuth2TokenResponse {
        access_token,
        token_type,
        id_token,
    })
}
760
#[cfg(feature = "oauth2")]
/// Where the identity claims for a login came from.
///
/// Decides whether the OIDC nonce is checked (ID token only) and whether a
/// userinfo `id` field may stand in for a missing `sub`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IdentitySource {
    /// Claims decoded from a signature-verified OIDC ID token.
    IdToken,
    /// Claims fetched as JSON from the provider's userinfo endpoint.
    UserInfo,
}
767
#[cfg(feature = "oauth2")]
/// Pick the stable user identifier out of identity claims.
///
/// The standard `sub` claim is preferred. For userinfo responses without a
/// usable `sub`, a provider-specific `id` field is accepted instead. It may be
/// an integer (signed or unsigned 64-bit) or a string.
///
/// Blank identifiers (empty or whitespace-only strings) are rejected. Every
/// such account would otherwise be written to the session under the same
/// identity.
///
/// # Errors
///
/// Bad request if no usable identifier is found.
fn extract_subject(
    claims: &serde_json::Value,
    source: IdentitySource,
) -> crate::AutumnResult<String> {
    // Returns the claim's string value, unless it is missing, non-string, or blank.
    let non_blank_str = |key: &str| -> Option<String> {
        claims
            .get(key)
            .and_then(serde_json::Value::as_str)
            .filter(|value| !value.trim().is_empty())
            .map(str::to_owned)
    };

    if let Some(sub) = non_blank_str("sub") {
        return Ok(sub);
    }

    if source == IdentitySource::UserInfo {
        // `as_u64` covers ids above `i64::MAX`. For values both accessors can
        // represent, they yield the same decimal string.
        let numeric_id = claims.get("id").and_then(|id| {
            id.as_i64()
                .map(|n| n.to_string())
                .or_else(|| id.as_u64().map(|n| n.to_string()))
        });
        if let Some(id) = numeric_id.or_else(|| non_blank_str("id")) {
            return Ok(id);
        }
        return Err(crate::AutumnError::bad_request_msg(
            "missing identity claim: expected sub or id from userinfo",
        ));
    }

    Err(crate::AutumnError::bad_request_msg("missing sub claim"))
}
791
#[cfg(feature = "oauth2")]
/// Verify an OIDC ID token against the provider's JWKS and return its claims.
///
/// Steps:
/// 1. The provider must configure both `issuer` and `jwks_url`.
/// 2. The token header must name a `kid`.
/// 3. The JWKS is fetched, and the key with that `kid` becomes the decoding key.
/// 4. The signature is checked with the header's algorithm, which requires
///    `exp`, `iss`, `aud` and `sub`, and validates `exp`/`nbf`, the issuer,
///    and the audience (`client_id`).
///
/// # Errors
///
/// Bad request for missing provider settings or an unreadable JWKS.
/// Service unavailable if the JWKS request cannot be sent. Unauthorized for
/// any header, key, or claim validation failure.
async fn validate_and_decode_id_token(
    token: &str,
    provider: &OAuth2ProviderConfig,
) -> crate::AutumnResult<serde_json::Value> {
    let Some(issuer) = provider.issuer.as_deref() else {
        return Err(crate::AutumnError::bad_request_msg(
            "provider.issuer required for oidc",
        ));
    };
    let Some(jwks_url) = provider.jwks_url.as_deref() else {
        return Err(crate::AutumnError::bad_request_msg(
            "provider.jwks_url required for oidc",
        ));
    };

    let header = jsonwebtoken::decode_header(token).map_err(|e| {
        crate::AutumnError::unauthorized_msg(format!("invalid id_token header: {e}"))
    })?;
    let Some(kid) = header.kid.as_deref() else {
        return Err(crate::AutumnError::unauthorized_msg(
            "id_token header missing kid",
        ));
    };

    let jwks_response = oauth_http_client()?
        .get(jwks_url)
        .send()
        .await
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!("jwks request failed: {e}"))
        })?
        .error_for_status()
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("jwks fetch failed: {e}")))?;
    let jwks: JwkSet = jwks_response
        .json()
        .await
        .map_err(|e| crate::AutumnError::bad_request_msg(format!("invalid jwks response: {e}")))?;

    let Some(jwk) = jwks
        .keys
        .iter()
        .find(|candidate| candidate.common.key_id.as_deref() == Some(kid))
    else {
        return Err(crate::AutumnError::unauthorized_msg(
            "no jwk matched id_token kid",
        ));
    };
    let decoding_key = jsonwebtoken::DecodingKey::from_jwk(jwk)
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("invalid jwk key: {e}")))?;

    let mut validation = jsonwebtoken::Validation::new(header.alg);
    validation.set_issuer(&[issuer]);
    validation.set_audience(std::slice::from_ref(&provider.client_id));
    validation.required_spec_claims = ["exp", "iss", "aud", "sub"]
        .iter()
        .map(|claim| String::from(*claim))
        .collect();
    validation.validate_exp = true;
    validation.validate_nbf = true;

    jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation)
        .map(|data| data.claims)
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("invalid id_token: {e}")))
}
849
#[cfg(feature = "oauth2")]
/// Build the HTTP client used for every outbound `OAuth2`/OIDC request.
///
/// Each request is bounded by `OAUTH_HTTP_TIMEOUT_SECS`.
///
/// # Errors
///
/// Service unavailable if the client cannot be constructed.
fn oauth_http_client() -> crate::AutumnResult<reqwest::Client> {
    let timeout = Duration::from_secs(OAUTH_HTTP_TIMEOUT_SECS);
    match reqwest::Client::builder().timeout(timeout).build() {
        Ok(client) => Ok(client),
        Err(e) => Err(crate::AutumnError::service_unavailable_msg(format!(
            "failed to build oauth http client: {e}"
        ))),
    }
}
861
862impl Default for AuthConfig {
863    fn default() -> Self {
864        Self {
865            bcrypt_cost: default_bcrypt_cost(),
866            session_key: default_session_key(),
867            #[cfg(feature = "oauth2")]
868            oauth2: OAuth2Config::default(),
869        }
870    }
871}
872
873// ─────────────────────────────────────────────────────────────────────────────
874// API Token Authentication
875// ─────────────────────────────────────────────────────────────────────────────
876
/// Backend trait for storing and verifying API bearer tokens.
///
/// Implementations persist only the token hash — the raw token is never stored
/// at rest. The default backend for tests is [`InMemoryApiTokenStore`].
/// Production deployments should use a database-backed implementation.
///
/// All methods take `&self`; use interior mutability where write access is
/// needed.
///
/// Methods return boxed futures instead of using `async fn`. This keeps the
/// trait usable as `dyn ApiTokenStore`, which is how [`RequireApiToken`] holds
/// its store (`Arc<dyn ApiTokenStore>`).
///
/// # Errors
///
/// Each method returns `Err` when the backing storage fails. [`RequireApiToken`]
/// turns such an error into a Problem Details response that carries the error's
/// own status. By contrast, `Ok(None)` from [`verify`](Self::verify) produces a
/// plain `401 Unauthorized`.
pub trait ApiTokenStore: Send + Sync + 'static {
    /// Issue a new token for `principal_id` and return the raw value.
    ///
    /// Only the hash is persisted. The raw token must be delivered to the
    /// caller immediately — it cannot be recovered later.
    fn issue<'a>(
        &'a self,
        principal_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<String>> + Send + 'a>>;

    /// Verify `raw_token` and return its principal ID, or `None` for unknown
    /// or revoked tokens.
    ///
    /// `Err` is reserved for storage failures, not for "token not found".
    fn verify<'a>(
        &'a self,
        raw_token: &'a str,
    ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<Option<String>>> + Send + 'a>>;

    /// Revoke a token so that subsequent requests are rejected.
    ///
    /// The bundled stores return `Ok(())` even when the token is unknown.
    fn revoke<'a>(
        &'a self,
        raw_token: &'a str,
    ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send + 'a>>;
}
908
909/// Compute the SHA-256 hash of a raw API token as a lowercase 64-char hex string.
910///
911/// The hash is deterministic: the same input always produces the same output.
912/// Only the hash is ever stored; the raw token is never persisted.
913///
914/// # Examples
915///
916/// ```rust
917/// use autumn_web::auth::hash_api_token;
918///
919/// let h = hash_api_token("my_token");
920/// assert_eq!(h.len(), 64);
921/// assert_eq!(h, hash_api_token("my_token")); // deterministic
922/// ```
923#[must_use]
924pub fn hash_api_token(raw: &str) -> String {
925    use sha2::Digest as _;
926    sha2::Sha256::digest(raw.as_bytes())
927        .iter()
928        .fold(String::with_capacity(64), |mut s, b| {
929            use std::fmt::Write as _;
930            let _ = write!(s, "{b:02x}");
931            s
932        })
933}
934
935/// Generate a 256-bit random raw API token as a lowercase hex string.
936///
937/// Uses two UUID v4 values (128 bits each) concatenated for a 64-char result.
938fn generate_raw_token() -> String {
939    let u1 = uuid::Uuid::new_v4();
940    let u2 = uuid::Uuid::new_v4();
941    format!("{}{}", u1.simple(), u2.simple())
942}
943
/// In-memory API token store for development and testing.
///
/// Tokens are stored as SHA-256 hashes mapped to principal IDs inside a
/// `RwLock`-protected `HashMap`. **Not suitable for production** — state is
/// lost on restart and is not shared across processes.
///
/// Cloning is cheap: the map sits behind an `Arc`, so every clone is a handle
/// to the *same* token map. A token issued through one clone verifies through
/// any other.
///
/// # Examples
///
/// ```rust
/// use std::sync::Arc;
/// use autumn_web::auth::{ApiTokenStore, InMemoryApiTokenStore};
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let store = Arc::new(InMemoryApiTokenStore::default());
/// let token = store.issue("user:1").await.unwrap();
/// assert_eq!(store.verify(&token).await.unwrap(), Some("user:1".to_owned()));
/// store.revoke(&token).await.unwrap();
/// assert_eq!(store.verify(&token).await.unwrap(), None);
/// # });
/// ```
#[derive(Clone, Default)]
pub struct InMemoryApiTokenStore {
    // hash → principal_id
    tokens: Arc<std::sync::RwLock<std::collections::HashMap<String, String>>>,
}
977
978impl ApiTokenStore for InMemoryApiTokenStore {
979    fn issue<'a>(
980        &'a self,
981        principal_id: &'a str,
982    ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<String>> + Send + 'a>> {
983        Box::pin(async move {
984            let raw = generate_raw_token();
985            let hash = hash_api_token(&raw);
986            self.tokens
987                .write()
988                .expect("api token store lock poisoned")
989                .insert(hash, principal_id.to_owned());
990            Ok(raw)
991        })
992    }
993
994    fn verify<'a>(
995        &'a self,
996        raw_token: &'a str,
997    ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<Option<String>>> + Send + 'a>> {
998        Box::pin(async move {
999            let hash = hash_api_token(raw_token);
1000            Ok(self
1001                .tokens
1002                .read()
1003                .expect("api token store lock poisoned")
1004                .get(&hash)
1005                .cloned())
1006        })
1007    }
1008
1009    fn revoke<'a>(
1010        &'a self,
1011        raw_token: &'a str,
1012    ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send + 'a>> {
1013        Box::pin(async move {
1014            let hash = hash_api_token(raw_token);
1015            self.tokens
1016                .write()
1017                .expect("api token store lock poisoned")
1018                .remove(&hash);
1019            Ok(())
1020        })
1021    }
1022}
1023
1024/// Issue a new API token for `principal_id` using `store`.
1025///
1026/// Returns the raw token string that must be transmitted to the client once.
1027///
1028/// # Errors
1029///
1030/// Propagates any error from the underlying store.
1031pub async fn issue_api_token(
1032    store: &dyn ApiTokenStore,
1033    principal_id: &str,
1034) -> crate::AutumnResult<String> {
1035    store.issue(principal_id).await
1036}
1037
1038/// Revoke a previously issued API token using `store`.
1039///
1040/// After revocation [`RequireApiToken`] rejects requests presenting this token.
1041///
1042/// # Errors
1043///
1044/// Propagates any error from the underlying store.
1045pub async fn revoke_api_token(
1046    store: &dyn ApiTokenStore,
1047    raw_token: &str,
1048) -> crate::AutumnResult<()> {
1049    store.revoke(raw_token).await
1050}
1051
/// Private marker inserted into request extensions by [`RequireApiToken`] after
/// a bearer token is successfully verified.
///
/// Wraps the principal ID returned by [`ApiTokenStore::verify`]; the
/// [`ApiToken`] extractor reads it back. The type is private, so only code in
/// this module can construct it. Handlers therefore cannot fake a verified
/// principal by inserting their own extension.
#[derive(Clone)]
struct ApiTokenPrincipal(String);
1056
1057/// Extractor that yields the verified principal ID from a bearer-protected route.
1058///
1059/// The principal ID is inserted by [`RequireApiToken`] after verifying the
1060/// `Authorization: Bearer <token>` header. Without `RequireApiToken` on the
1061/// route this extractor returns `401 Unauthorized`.
1062///
1063/// # Examples
1064///
1065/// ```rust,no_run
1066/// use autumn_web::prelude::*;
1067/// use autumn_web::auth::ApiToken;
1068///
1069/// #[get("/whoami")]
1070/// async fn whoami(ApiToken(principal): ApiToken) -> String {
1071///     format!("authenticated as {principal}")
1072/// }
1073/// ```
1074#[derive(Debug, Clone)]
1075pub struct ApiToken(pub String);
1076
1077impl<S> FromRequestParts<S> for ApiToken
1078where
1079    S: Send + Sync,
1080{
1081    type Rejection = AuthRejection;
1082
1083    fn from_request_parts(
1084        parts: &mut Parts,
1085        _state: &S,
1086    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
1087        let principal = parts.extensions.get::<ApiTokenPrincipal>().cloned();
1088        async move { principal.map(|p| Self(p.0)).ok_or(AuthRejection) }
1089    }
1090}
1091
1092/// Tower [`Layer`](tower::Layer) that validates `Authorization: Bearer <token>`
1093/// on every inbound request.
1094///
1095/// On success the verified principal ID is inserted into request extensions
1096/// so handlers can retrieve it via the [`ApiToken`] extractor.
1097/// Requests with a missing, malformed, or revoked token are rejected with
1098/// `401 Unauthorized` using the same Problem Details contract as
1099/// [`AuthRejection`].
1100///
1101/// Composes with [`RequireAuth`] and session middleware without conflict.
1102///
1103/// # Examples
1104///
1105/// ```rust,no_run
1106/// use std::sync::Arc;
1107/// use autumn_web::auth::{InMemoryApiTokenStore, RequireApiToken};
1108/// use autumn_web::reexports::axum::{Router, routing::get};
1109/// use autumn_web::AppState;
1110///
1111/// let store = Arc::new(InMemoryApiTokenStore::default());
1112/// let api_routes = Router::<AppState>::new()
1113///     .route("/data", get(|| async { "ok" }))
1114///     .layer(RequireApiToken::new(store));
1115/// ```
1116#[derive(Clone)]
1117pub struct RequireApiToken {
1118    store: Arc<dyn ApiTokenStore>,
1119}
1120
1121impl RequireApiToken {
1122    /// Create a new [`RequireApiToken`] layer backed by `store`.
1123    ///
1124    /// Accepts any `Arc<S>` where `S: ApiTokenStore`, so callers do not need
1125    /// to explicitly cast to `Arc<dyn ApiTokenStore>`.
1126    #[must_use]
1127    pub fn new<S: ApiTokenStore + 'static>(store: Arc<S>) -> Self {
1128        Self { store }
1129    }
1130}
1131
1132impl<S> tower::Layer<S> for RequireApiToken {
1133    type Service = RequireApiTokenService<S>;
1134
1135    fn layer(&self, inner: S) -> Self::Service {
1136        RequireApiTokenService {
1137            inner,
1138            store: Arc::clone(&self.store),
1139        }
1140    }
1141}
1142
/// Tower service produced by [`RequireApiToken`].
///
/// Checks each request's `Authorization: Bearer` header against `store` before
/// forwarding it to `inner`.
#[derive(Clone)]
pub struct RequireApiTokenService<S> {
    // The wrapped service; only called after the token verifies.
    inner: S,
    // Token backend shared with the `RequireApiToken` layer that built this service.
    store: Arc<dyn ApiTokenStore>,
}
1149
1150impl<S, ResBody> tower::Service<axum::extract::Request> for RequireApiTokenService<S>
1151where
1152    S: tower::Service<axum::extract::Request, Response = Response<ResBody>>
1153        + Clone
1154        + Send
1155        + 'static,
1156    S::Future: Send + 'static,
1157    S::Error: Send + 'static,
1158    ResBody: From<String> + Default + Send + 'static,
1159{
1160    type Response = Response<ResBody>;
1161    type Error = S::Error;
1162    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1163
1164    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1165        self.inner.poll_ready(cx)
1166    }
1167
1168    fn call(&mut self, mut req: axum::extract::Request) -> Self::Future {
1169        let store = Arc::clone(&self.store);
1170        let mut inner = self.inner.clone();
1171        std::mem::swap(&mut self.inner, &mut inner);
1172
1173        Box::pin(async move {
1174            // Parse "Authorization: Bearer <token>"
1175            let raw_token = req
1176                .headers()
1177                .get(http::header::AUTHORIZATION)
1178                .and_then(|v| v.to_str().ok())
1179                .and_then(parse_bearer_token)
1180                .map(str::to_owned);
1181
1182            let Some(raw_token) = raw_token else {
1183                let (request_id, instance) = api_token_problem_context(&req);
1184                return Ok(api_token_unauthorized_response(request_id, instance));
1185            };
1186
1187            match store.verify(&raw_token).await {
1188                Ok(Some(principal_id)) => {
1189                    req.extensions_mut().insert(ApiTokenPrincipal(principal_id));
1190                    inner.call(req).await
1191                }
1192                Ok(None) => {
1193                    let (request_id, instance) = api_token_problem_context(&req);
1194                    Ok(api_token_unauthorized_response(request_id, instance))
1195                }
1196                Err(err) => {
1197                    let (request_id, instance) = api_token_problem_context(&req);
1198                    Ok(api_token_error_response(&err, request_id, instance))
1199                }
1200            }
1201        })
1202    }
1203}
1204
/// Extract the token from an `Authorization` header value that uses the `Bearer` scheme.
///
/// The scheme name is matched case-insensitively. RFC 6750 §2.1 separates the
/// scheme and the token with one or more spaces, so extra separating spaces are
/// skipped and trailing whitespace is ignored.
///
/// Returns `None` when:
/// - the header has no space-separated token,
/// - the scheme is not `Bearer`, or
/// - the token is empty, so blank credentials never reach the store.
fn parse_bearer_token(header: &str) -> Option<&str> {
    let (scheme, rest) = header.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = rest.trim_start_matches(' ').trim_end();
    (!token.is_empty()).then_some(token)
}
1209
1210/// Build a `401 Unauthorized` response using the standard Problem Details body.
1211fn api_token_unauthorized_response<ResBody: From<String> + Default>(
1212    request_id: Option<String>,
1213    instance: Option<String>,
1214) -> Response<ResBody> {
1215    let body = crate::error::problem_details_json_string(
1216        StatusCode::UNAUTHORIZED,
1217        "authentication required",
1218        None,
1219        None,
1220        request_id,
1221        instance,
1222        true,
1223    );
1224    Response::builder()
1225        .status(StatusCode::UNAUTHORIZED)
1226        .header(http::header::CONTENT_TYPE, "application/problem+json")
1227        .body(ResBody::from(body))
1228        .unwrap_or_default()
1229}
1230
1231/// Build a Problem Details response from the API token store error.
1232fn api_token_error_response<ResBody: From<String> + Default>(
1233    err: &crate::AutumnError,
1234    request_id: Option<String>,
1235    instance: Option<String>,
1236) -> Response<ResBody> {
1237    let status = err.status();
1238    let message = err.to_string();
1239    let body = crate::error::problem_details_json_string(
1240        status,
1241        message.clone(),
1242        None,
1243        None,
1244        request_id,
1245        instance,
1246        true,
1247    );
1248    let mut response = Response::builder()
1249        .status(status)
1250        .header(http::header::CONTENT_TYPE, "application/problem+json")
1251        .body(ResBody::from(body))
1252        .unwrap_or_default();
1253    response
1254        .extensions_mut()
1255        .insert(crate::middleware::AutumnErrorInfo {
1256            status,
1257            message,
1258            details: None,
1259            problem_type: None,
1260        });
1261    response
1262}
1263
1264fn api_token_problem_context(req: &axum::extract::Request) -> (Option<String>, Option<String>) {
1265    (
1266        req.extensions()
1267            .get::<crate::middleware::RequestId>()
1268            .map(std::string::ToString::to_string),
1269        Some(req.uri().path().to_owned()),
1270    )
1271}
1272
1273// ─────────────────────────────────────────────────────────────────────────────
1274// Diesel-backed API Token Store
1275// ─────────────────────────────────────────────────────────────────────────────
1276
/// Embedded Diesel migrations for the `api_tokens` table.
///
/// The migrations are embedded at compile time from this crate's `migrations`
/// directory by `diesel_migrations::embed_migrations!`.
///
/// Include this in your application's `.migrations()` call. Dev/test startup
/// migration checks can then create and validate the `api_tokens` table
/// alongside your own migrations:
///
/// ```rust,ignore
/// use autumn_web::auth::API_TOKEN_MIGRATIONS;
///
/// #[autumn_web::main]
/// async fn main() {
///     autumn_web::app()
///         .migrations(API_TOKEN_MIGRATIONS)
///         .run()
///         .await;
/// }
/// ```
///
/// In production, `autumn migrate` applies the matching framework migration
/// before token commands or `DbApiTokenStore` need the table.
#[cfg(feature = "db")]
pub const API_TOKEN_MIGRATIONS: diesel_migrations::EmbeddedMigrations =
    diesel_migrations::embed_migrations!("migrations");
1299
1300#[cfg(feature = "db")]
1301mod db_store {
1302    use std::future::Future;
1303    use std::pin::Pin;
1304
1305    use diesel::OptionalExtension as _;
1306    use diesel::prelude::*;
1307    use diesel_async::AsyncPgConnection;
1308    use diesel_async::RunQueryDsl;
1309    use diesel_async::pooled_connection::deadpool::Pool;
1310
1311    use super::{ApiTokenStore, generate_raw_token, hash_api_token};
1312    use crate::error::AutumnError;
1313
1314    diesel::table! {
1315        api_tokens (id) {
1316            id -> Int8,
1317            token_hash -> Text,
1318            principal_id -> Text,
1319            created_at -> Timestamp,
1320            revoked_at -> Nullable<Timestamp>,
1321        }
1322    }
1323
1324    #[derive(Insertable)]
1325    #[diesel(table_name = api_tokens)]
1326    struct NewApiToken<'a> {
1327        token_hash: &'a str,
1328        principal_id: &'a str,
1329    }
1330
1331    /// Postgres-backed [`ApiTokenStore`].
1332    ///
1333    /// Tokens are hashed at rest (SHA-256) and never stored in plaintext.
1334    /// Suitable for production deployments where token state must survive
1335    /// process restarts and be shared across instances.
1336    ///
1337    /// # Setup
1338    ///
1339    /// Pass [`super::API_TOKEN_MIGRATIONS`] to your app builder so dev/test
1340    /// startup migration checks can create and validate the `api_tokens`
1341    /// table automatically. In production, run `autumn migrate`; the CLI
1342    /// applies the matching framework migration explicitly.
1343    ///
1344    /// ```rust,ignore
1345    /// use autumn_web::auth::{API_TOKEN_MIGRATIONS, DbApiTokenStore};
1346    /// use autumn_web::db::Pool;
1347    ///
1348    /// let store = DbApiTokenStore::new(pool.clone());
1349    /// autumn_web::app()
1350    ///     .migrations(API_TOKEN_MIGRATIONS)
1351    ///     .run()
1352    ///     .await;
1353    /// ```
1354    #[derive(Clone)]
1355    pub struct DbApiTokenStore {
1356        pool: Pool<AsyncPgConnection>,
1357    }
1358
1359    impl DbApiTokenStore {
1360        /// Create a [`DbApiTokenStore`] backed by `pool`.
1361        #[must_use]
1362        pub const fn new(pool: Pool<AsyncPgConnection>) -> Self {
1363            Self { pool }
1364        }
1365    }
1366
1367    impl ApiTokenStore for DbApiTokenStore {
1368        fn issue<'a>(
1369            &'a self,
1370            principal_id: &'a str,
1371        ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<String>> + Send + 'a>> {
1372            Box::pin(async move {
1373                let raw = generate_raw_token();
1374                let hash = hash_api_token(&raw);
1375                let mut conn = self
1376                    .pool
1377                    .get()
1378                    .await
1379                    .map_err(|e| AutumnError::internal_server_error_msg(e.to_string()))?;
1380                diesel::insert_into(api_tokens::table)
1381                    .values(NewApiToken {
1382                        token_hash: &hash,
1383                        principal_id,
1384                    })
1385                    .execute(&mut conn)
1386                    .await
1387                    .map_err(|e| AutumnError::internal_server_error_msg(e.to_string()))?;
1388                Ok(raw)
1389            })
1390        }
1391
1392        fn verify<'a>(
1393            &'a self,
1394            raw_token: &'a str,
1395        ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<Option<String>>> + Send + 'a>>
1396        {
1397            Box::pin(async move {
1398                let hash = hash_api_token(raw_token);
1399                let mut conn = self
1400                    .pool
1401                    .get()
1402                    .await
1403                    .map_err(|e| AutumnError::internal_server_error_msg(e.to_string()))?;
1404                let principal: Option<String> = api_tokens::table
1405                    .filter(api_tokens::token_hash.eq(&hash))
1406                    .filter(api_tokens::revoked_at.is_null())
1407                    .select(api_tokens::principal_id)
1408                    .first(&mut conn)
1409                    .await
1410                    .optional()
1411                    .map_err(|e| AutumnError::internal_server_error_msg(e.to_string()))?;
1412                Ok(principal)
1413            })
1414        }
1415
1416        fn revoke<'a>(
1417            &'a self,
1418            raw_token: &'a str,
1419        ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send + 'a>> {
1420            Box::pin(async move {
1421                let hash = hash_api_token(raw_token);
1422                let mut conn = self
1423                    .pool
1424                    .get()
1425                    .await
1426                    .map_err(|e| AutumnError::internal_server_error_msg(e.to_string()))?;
1427                diesel::update(api_tokens::table)
1428                    .filter(api_tokens::token_hash.eq(&hash))
1429                    .set(api_tokens::revoked_at.eq(diesel::dsl::now.nullable()))
1430                    .execute(&mut conn)
1431                    .await
1432                    .map_err(|e| AutumnError::internal_server_error_msg(e.to_string()))?;
1433                Ok(())
1434            })
1435        }
1436    }
1437}
1438
1439#[cfg(feature = "db")]
1440pub use db_store::DbApiTokenStore;
1441
1442#[cfg(test)]
1443mod tests {
1444    use super::*;
1445
1446    #[tokio::test]
1447    async fn hash_and_verify_password() {
1448        let hash = hash_password("test_password").await.unwrap();
1449        assert!(hash.starts_with("$2b$"));
1450        assert!(verify_password("test_password", &hash).await.unwrap());
1451        assert!(!verify_password("wrong_password", &hash).await.unwrap());
1452    }
1453
1454    #[tokio::test]
1455    async fn verify_invalid_hash_returns_false() {
1456        let result = verify_password("test", "not-a-valid-hash").await;
1457        assert!(result.is_ok());
1458        assert!(!result.unwrap());
1459    }
1460
1461    #[tokio::test]
1462    async fn verify_password_rejects_invalid_hash_format_safely() {
1463        // Test short hash
1464        let result = verify_password("test", "short").await;
1465        assert!(result.is_ok());
1466        assert!(!result.unwrap());
1467
1468        // Test hash with correct length but not starting with $
1469        let bad_prefix = "a".repeat(60);
1470        let result = verify_password("test", &bad_prefix).await;
1471        assert!(result.is_ok());
1472        assert!(!result.unwrap());
1473
1474        // Test hash with incorrect length but starting with $
1475        let bad_length = "$2b$12$short";
1476        let result = verify_password("test", bad_length).await;
1477        assert!(result.is_ok());
1478        assert!(!result.unwrap());
1479    }
1480
1481    #[test]
1482    fn auth_config_defaults() {
1483        let config = AuthConfig::default();
1484        assert_eq!(config.bcrypt_cost, 12);
1485        assert_eq!(config.session_key, "user_id");
1486        #[cfg(feature = "oauth2")]
1487        assert!(config.oauth2.providers.is_empty());
1488    }
1489
1490    #[cfg(feature = "oauth2")]
1491    #[test]
1492    fn oauth2_config_deserializes_provider_tables() {
1493        let cfg: crate::config::AutumnConfig = toml::from_str(
1494            r#"
1495            [auth.oauth2.github]
1496            client_id = "cid"
1497            client_secret = "secret"
1498            authorize_url = "https://github.com/login/oauth/authorize"
1499            token_url = "https://github.com/login/oauth/access_token"
1500            redirect_uri = "http://localhost:3000/auth/github/callback"
1501            "#,
1502        )
1503        .unwrap();
1504        let provider = cfg.auth.oauth2.providers.get("github").unwrap();
1505        assert_eq!(provider.client_id, "cid");
1506        assert_eq!(provider.scope, "");
1507        assert!(provider.issuer.is_none());
1508        assert!(provider.jwks_url.is_none());
1509    }
1510
1511    #[cfg(feature = "oauth2")]
1512    #[tokio::test]
1513    async fn oauth2_authorize_url_sets_state_and_nonce() {
1514        let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
1515        let provider = OAuth2ProviderConfig {
1516            client_id: "cid".into(),
1517            client_secret: "secret".into(),
1518            authorize_url: "https://idp.example/authorize".into(),
1519            token_url: "https://idp.example/token".into(),
1520            userinfo_url: None,
1521            redirect_uri: "http://localhost:3000/callback".into(),
1522            scope: "openid profile".into(),
1523            issuer: None,
1524            jwks_url: None,
1525        };
1526        let url = oauth2_authorize_url(&session, "github", &provider)
1527            .await
1528            .unwrap();
1529        assert!(url.contains("response_type=code"));
1530        assert!(session.get("oauth2:github:state").await.is_some());
1531        assert!(session.get("oauth2:github:nonce").await.is_some());
1532    }
1533
1534    #[cfg(feature = "oauth2")]
1535    #[tokio::test]
1536    async fn oauth2_authorize_url_omits_scope_when_empty() {
1537        let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
1538        let provider = OAuth2ProviderConfig {
1539            client_id: "cid".into(),
1540            client_secret: "secret".into(),
1541            authorize_url: "https://idp.example/authorize".into(),
1542            token_url: "https://idp.example/token".into(),
1543            userinfo_url: None,
1544            redirect_uri: "http://localhost:3000/callback".into(),
1545            scope: String::new(),
1546            issuer: None,
1547            jwks_url: None,
1548        };
1549        let url = oauth2_authorize_url(&session, "github", &provider)
1550            .await
1551            .unwrap();
1552        assert!(!url.contains("scope="));
1553    }
1554
1555    #[cfg(feature = "oauth2")]
1556    #[tokio::test]
1557    async fn validate_id_token_requires_oidc_metadata() {
1558        let provider = OAuth2ProviderConfig {
1559            client_id: "cid".into(),
1560            client_secret: "secret".into(),
1561            authorize_url: "https://idp.example/authorize".into(),
1562            token_url: "https://idp.example/token".into(),
1563            userinfo_url: None,
1564            redirect_uri: "http://localhost:3000/callback".into(),
1565            scope: "openid profile".into(),
1566            issuer: None,
1567            jwks_url: None,
1568        };
1569        let err = validate_and_decode_id_token("bad.token.value", &provider)
1570            .await
1571            .unwrap_err();
1572        assert_eq!(err.to_string(), "provider.issuer required for oidc");
1573    }
1574
1575    #[cfg(feature = "oauth2")]
1576    #[test]
1577    fn parse_oauth2_token_response_supports_form_encoded_payload() {
1578        let token = parse_oauth2_token_response(
1579            Some("application/x-www-form-urlencoded"),
1580            "access_token=abc123&token_type=bearer&id_token=xyz789&extra_field=ignored",
1581        )
1582        .unwrap();
1583        assert_eq!(token.access_token, "abc123");
1584        assert_eq!(token.token_type.as_deref(), Some("bearer"));
1585        assert_eq!(token.id_token.as_deref(), Some("xyz789"));
1586    }
1587
1588    #[cfg(feature = "oauth2")]
1589    #[test]
1590    fn parse_oauth2_token_response_fails_without_access_token() {
1591        let err = parse_oauth2_token_response(
1592            Some("application/x-www-form-urlencoded"),
1593            "token_type=bearer&id_token=xyz789",
1594        )
1595        .unwrap_err();
1596        assert_eq!(err.to_string(), "token response missing access_token");
1597    }
1598
1599    #[cfg(feature = "oauth2")]
1600    #[test]
1601    fn extract_subject_allows_userinfo_id_fallback() {
1602        let claims = serde_json::json!({ "id": 42 });
1603        let subject = extract_subject(&claims, IdentitySource::UserInfo).unwrap();
1604        assert_eq!(subject, "42");
1605    }
1606
1607    #[cfg(feature = "oauth2")]
1608    #[tokio::test]
1609    async fn validate_callback_state_preserves_state_on_mismatch() {
1610        // An attacker hitting the callback with a wrong state must NOT
1611        // consume the real state stored in the session; the legitimate
1612        // provider redirect must still succeed.
1613        let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
1614        session
1615            .insert("oauth2:github:state".to_owned(), "real-state".to_owned())
1616            .await;
1617        let bad_callback = OAuth2Callback {
1618            code: "c".into(),
1619            state: "wrong-state".into(),
1620        };
1621        let err = validate_callback_state(&session, "github", &bad_callback)
1622            .await
1623            .unwrap_err();
1624        assert!(err.to_string().contains("state mismatch"));
1625        // Real state must still be present after the failed attempt.
1626        assert_eq!(
1627            session.get("oauth2:github:state").await.as_deref(),
1628            Some("real-state")
1629        );
1630    }
1631
1632    #[cfg(feature = "oauth2")]
1633    #[tokio::test]
1634    async fn validate_oidc_nonce_rejects_missing_nonce_for_id_token() {
1635        // ID-token logins must fail when there is no stored nonce (e.g.,
1636        // session was partially cleared or forged).
1637        let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
1638        // No nonce key inserted — simulates a cleared / missing session.
1639        let claims = serde_json::json!({ "nonce": "any" });
1640        let err = validate_oidc_nonce(&session, "github", &claims, IdentitySource::IdToken)
1641            .await
1642            .unwrap_err();
1643        assert!(err.to_string().contains("nonce missing from session"));
1644    }
1645
1646    #[cfg(feature = "oauth2")]
1647    #[test]
1648    fn extract_subject_requires_sub_for_id_token() {
1649        let claims = serde_json::json!({ "id": "abc" });
1650        let err = extract_subject(&claims, IdentitySource::IdToken).unwrap_err();
1651        assert_eq!(err.to_string(), "missing sub claim");
1652    }
1653
1654    #[test]
1655    fn auth_rejection_is_401() {
1656        let rejection = AuthRejection;
1657        let response = rejection.into_response();
1658        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
1659    }
1660
1661    #[test]
1662    fn auth_rejection_display() {
1663        assert_eq!(AuthRejection.to_string(), "authentication required");
1664    }
1665
    /// `Auth<TestUser>` rejects with 401 when no `TestUser` is in the request extensions.
    #[tokio::test]
    async fn auth_extractor_returns_401_when_no_user() {
        use crate::state::AppState;
        use axum::Router;
        use axum::body::Body;
        use axum::routing::get;
        use tower::ServiceExt;

        #[derive(Clone)]
        struct TestUser {
            name: String,
        }

        async fn handler(Auth(user): Auth<TestUser>) -> String {
            user.name
        }

        // Minimal AppState for the router: no DB pools, no shared cache, and
        // default/placeholder values for the remaining fields.
        let state = AppState {
            extensions: std::sync::Arc::new(std::sync::RwLock::new(
                std::collections::HashMap::new(),
            )),
            #[cfg(feature = "db")]
            pool: None,
            #[cfg(feature = "db")]
            replica_pool: None,
            profile: None,
            started_at: std::time::Instant::now(),
            health_detailed: false,
            probes: crate::probe::ProbeState::ready_for_test(),
            metrics: crate::middleware::MetricsCollector::new(),
            log_levels: crate::actuator::LogLevels::new("info"),
            task_registry: crate::actuator::TaskRegistry::new(),
            job_registry: crate::actuator::JobRegistry::new(),
            config_props: crate::actuator::ConfigProperties::default(),
            #[cfg(feature = "ws")]
            channels: crate::channels::Channels::new(32),
            #[cfg(feature = "ws")]
            shutdown: tokio_util::sync::CancellationToken::new(),
            policy_registry: crate::authorization::PolicyRegistry::default(),
            forbidden_response: crate::authorization::ForbiddenResponse::default(),
            auth_session_key: "user_id".to_owned(),
            shared_cache: None,
        };

        // No middleware inserts a TestUser, so extraction in `handler` must fail.
        let app = Router::new().route("/", get(handler)).with_state(state);

        let response = app
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
1724
1725    #[tokio::test]
1726    async fn auth_extractor_returns_user_when_present() {
1727        use crate::state::AppState;
1728        use axum::Router;
1729        use axum::body::Body;
1730        use axum::routing::get;
1731        use tower::ServiceExt;
1732
1733        #[derive(Clone)]
1734        struct TestUser {
1735            name: String,
1736        }
1737
1738        async fn handler(Auth(user): Auth<TestUser>) -> String {
1739            user.name
1740        }
1741
1742        let state = AppState {
1743            extensions: std::sync::Arc::new(std::sync::RwLock::new(
1744                std::collections::HashMap::new(),
1745            )),
1746            #[cfg(feature = "db")]
1747            pool: None,
1748            #[cfg(feature = "db")]
1749            replica_pool: None,
1750            profile: None,
1751            started_at: std::time::Instant::now(),
1752            health_detailed: false,
1753            probes: crate::probe::ProbeState::ready_for_test(),
1754            metrics: crate::middleware::MetricsCollector::new(),
1755            log_levels: crate::actuator::LogLevels::new("info"),
1756            task_registry: crate::actuator::TaskRegistry::new(),
1757            job_registry: crate::actuator::JobRegistry::new(),
1758            config_props: crate::actuator::ConfigProperties::default(),
1759            #[cfg(feature = "ws")]
1760            channels: crate::channels::Channels::new(32),
1761            #[cfg(feature = "ws")]
1762            shutdown: tokio_util::sync::CancellationToken::new(),
1763            policy_registry: crate::authorization::PolicyRegistry::default(),
1764            forbidden_response: crate::authorization::ForbiddenResponse::default(),
1765            auth_session_key: "user_id".to_owned(),
1766            shared_cache: None,
1767        };
1768
1769        // Middleware that inserts a user into extensions
1770        let app = Router::new()
1771            .route("/", get(handler))
1772            .layer(axum::middleware::from_fn(
1773                |mut req: axum::extract::Request, next: axum::middleware::Next| async move {
1774                    req.extensions_mut().insert(TestUser {
1775                        name: "alice".into(),
1776                    });
1777                    next.run(req).await
1778                },
1779            ))
1780            .with_state(state);
1781
1782        let response = app
1783            .oneshot(
1784                http::Request::builder()
1785                    .uri("/")
1786                    .body(Body::empty())
1787                    .unwrap(),
1788            )
1789            .await
1790            .unwrap();
1791
1792        assert_eq!(response.status(), StatusCode::OK);
1793        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
1794            .await
1795            .unwrap();
1796        assert_eq!(std::str::from_utf8(&body).unwrap(), "alice");
1797    }
1798
1799    #[tokio::test]
1800    async fn require_auth_rejects_unauthenticated() {
1801        use axum::Router;
1802        use axum::body::Body;
1803        use axum::routing::get;
1804        use tower::ServiceExt;
1805
1806        use crate::session::{MemoryStore, SessionConfig, SessionLayer};
1807        use crate::state::AppState;
1808
1809        let state = AppState {
1810            extensions: std::sync::Arc::new(std::sync::RwLock::new(
1811                std::collections::HashMap::new(),
1812            )),
1813            #[cfg(feature = "db")]
1814            pool: None,
1815            #[cfg(feature = "db")]
1816            replica_pool: None,
1817            profile: None,
1818            started_at: std::time::Instant::now(),
1819            health_detailed: false,
1820            probes: crate::probe::ProbeState::ready_for_test(),
1821            metrics: crate::middleware::MetricsCollector::new(),
1822            log_levels: crate::actuator::LogLevels::new("info"),
1823            task_registry: crate::actuator::TaskRegistry::new(),
1824            job_registry: crate::actuator::JobRegistry::new(),
1825            config_props: crate::actuator::ConfigProperties::default(),
1826            #[cfg(feature = "ws")]
1827            channels: crate::channels::Channels::new(32),
1828            #[cfg(feature = "ws")]
1829            shutdown: tokio_util::sync::CancellationToken::new(),
1830            policy_registry: crate::authorization::PolicyRegistry::default(),
1831            forbidden_response: crate::authorization::ForbiddenResponse::default(),
1832            auth_session_key: "user_id".to_owned(),
1833            shared_cache: None,
1834        };
1835
1836        let app = Router::new()
1837            .route("/protected", get(|| async { "secret" }))
1838            .layer(RequireAuth::new("user_id"))
1839            .layer(SessionLayer::new(
1840                MemoryStore::new(),
1841                SessionConfig::default(),
1842            ))
1843            .with_state(state);
1844
1845        let response = app
1846            .oneshot(
1847                http::Request::builder()
1848                    .uri("/protected")
1849                    .body(Body::empty())
1850                    .unwrap(),
1851            )
1852            .await
1853            .unwrap();
1854
1855        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
1856    }
1857
1858    // ── __check_secured tests ────────────────────────────────
1859
1860    #[tokio::test]
1861    async fn check_secured_rejects_unauthenticated() {
1862        let session =
1863            crate::session::Session::new_for_test(String::new(), std::collections::HashMap::new());
1864        let result = __check_secured(&session, &[]).await;
1865        assert!(result.is_err());
1866        let err = result.unwrap_err();
1867        assert_eq!(err.status(), StatusCode::UNAUTHORIZED);
1868        assert_eq!(err.to_string(), "authentication required");
1869    }
1870
1871    #[tokio::test]
1872    async fn check_secured_allows_authenticated() {
1873        let data = std::collections::HashMap::from([("user_id".into(), "42".into())]);
1874        let session = crate::session::Session::new_for_test("sess".into(), data);
1875        let result = __check_secured(&session, &[]).await;
1876        assert!(result.is_ok());
1877    }
1878
1879    #[tokio::test]
1880    async fn check_secured_rejects_wrong_role() {
1881        let data = std::collections::HashMap::from([
1882            ("user_id".into(), "42".into()),
1883            ("role".into(), "viewer".into()),
1884        ]);
1885        let session = crate::session::Session::new_for_test("sess".into(), data);
1886        let result = __check_secured(&session, &["admin"]).await;
1887        assert!(result.is_err());
1888        let err = result.unwrap_err();
1889        assert_eq!(err.status(), StatusCode::FORBIDDEN);
1890        assert_eq!(err.to_string(), "insufficient permissions");
1891    }
1892
1893    #[tokio::test]
1894    async fn check_secured_allows_matching_role() {
1895        let data = std::collections::HashMap::from([
1896            ("user_id".into(), "42".into()),
1897            ("role".into(), "admin".into()),
1898        ]);
1899        let session = crate::session::Session::new_for_test("sess".into(), data);
1900        let result = __check_secured(&session, &["admin"]).await;
1901        assert!(result.is_ok());
1902    }
1903
1904    #[tokio::test]
1905    async fn check_secured_allows_any_of_multiple_roles() {
1906        let data = std::collections::HashMap::from([
1907            ("user_id".into(), "42".into()),
1908            ("role".into(), "editor".into()),
1909        ]);
1910        let session = crate::session::Session::new_for_test("sess".into(), data);
1911        let result = __check_secured(&session, &["admin", "editor"]).await;
1912        assert!(result.is_ok());
1913    }
1914
1915    // ── #[secured] macro integration tests ──────────────────────
1916
1917    #[tokio::test]
1918    async fn secured_macro_rejects_unauthenticated() {
1919        use axum::Router;
1920        use axum::body::Body;
1921        use axum::routing::get;
1922        use tower::ServiceExt;
1923
1924        use crate::session::{MemoryStore, SessionConfig, SessionLayer};
1925        use crate::state::AppState;
1926
1927        #[autumn_macros::secured]
1928        async fn protected_handler() -> crate::AutumnResult<&'static str> {
1929            Ok("secret")
1930        }
1931
1932        let state = AppState {
1933            extensions: std::sync::Arc::new(std::sync::RwLock::new(
1934                std::collections::HashMap::new(),
1935            )),
1936            #[cfg(feature = "db")]
1937            pool: None,
1938            #[cfg(feature = "db")]
1939            replica_pool: None,
1940            profile: None,
1941            started_at: std::time::Instant::now(),
1942            health_detailed: false,
1943            probes: crate::probe::ProbeState::ready_for_test(),
1944            metrics: crate::middleware::MetricsCollector::new(),
1945            log_levels: crate::actuator::LogLevels::new("info"),
1946            task_registry: crate::actuator::TaskRegistry::new(),
1947            job_registry: crate::actuator::JobRegistry::new(),
1948            config_props: crate::actuator::ConfigProperties::default(),
1949            #[cfg(feature = "ws")]
1950            channels: crate::channels::Channels::new(32),
1951            #[cfg(feature = "ws")]
1952            shutdown: tokio_util::sync::CancellationToken::new(),
1953            policy_registry: crate::authorization::PolicyRegistry::default(),
1954            forbidden_response: crate::authorization::ForbiddenResponse::default(),
1955            auth_session_key: "user_id".to_owned(),
1956            shared_cache: None,
1957        };
1958
1959        let app = Router::new()
1960            .route("/", get(protected_handler))
1961            .layer(SessionLayer::new(
1962                MemoryStore::new(),
1963                SessionConfig::default(),
1964            ))
1965            .with_state(state);
1966
1967        let response = app
1968            .oneshot(
1969                http::Request::builder()
1970                    .uri("/")
1971                    .body(Body::empty())
1972                    .unwrap(),
1973            )
1974            .await
1975            .unwrap();
1976
1977        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
1978    }
1979
1980    #[tokio::test]
1981    async fn secured_macro_allows_authenticated() {
1982        use axum::Router;
1983        use axum::body::Body;
1984        use axum::routing::get;
1985        use http::header::COOKIE;
1986        use tower::ServiceExt;
1987
1988        use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
1989        use crate::state::AppState;
1990
1991        #[autumn_macros::secured]
1992        async fn protected_handler() -> crate::AutumnResult<&'static str> {
1993            Ok("secret")
1994        }
1995
1996        let store = MemoryStore::new();
1997        store
1998            .save(
1999                "sess1",
2000                std::collections::HashMap::from([("user_id".into(), "42".into())]),
2001            )
2002            .await
2003            .unwrap();
2004
2005        let state = AppState {
2006            extensions: std::sync::Arc::new(std::sync::RwLock::new(
2007                std::collections::HashMap::new(),
2008            )),
2009            #[cfg(feature = "db")]
2010            pool: None,
2011            #[cfg(feature = "db")]
2012            replica_pool: None,
2013            profile: None,
2014            started_at: std::time::Instant::now(),
2015            health_detailed: false,
2016            probes: crate::probe::ProbeState::ready_for_test(),
2017            metrics: crate::middleware::MetricsCollector::new(),
2018            log_levels: crate::actuator::LogLevels::new("info"),
2019            task_registry: crate::actuator::TaskRegistry::new(),
2020            job_registry: crate::actuator::JobRegistry::new(),
2021            config_props: crate::actuator::ConfigProperties::default(),
2022            #[cfg(feature = "ws")]
2023            channels: crate::channels::Channels::new(32),
2024            #[cfg(feature = "ws")]
2025            shutdown: tokio_util::sync::CancellationToken::new(),
2026            policy_registry: crate::authorization::PolicyRegistry::default(),
2027            forbidden_response: crate::authorization::ForbiddenResponse::default(),
2028            auth_session_key: "user_id".to_owned(),
2029            shared_cache: None,
2030        };
2031
2032        let app = Router::new()
2033            .route("/", get(protected_handler))
2034            .layer(SessionLayer::new(store, SessionConfig::default()))
2035            .with_state(state);
2036
2037        let response = app
2038            .oneshot(
2039                http::Request::builder()
2040                    .uri("/")
2041                    .header(COOKIE, "autumn.sid=sess1")
2042                    .body(Body::empty())
2043                    .unwrap(),
2044            )
2045            .await
2046            .unwrap();
2047
2048        assert_eq!(response.status(), StatusCode::OK);
2049        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2050            .await
2051            .unwrap();
2052        assert_eq!(std::str::from_utf8(&body).unwrap(), "secret");
2053    }
2054
2055    #[tokio::test]
2056    async fn secured_macro_honors_configured_auth_session_key() {
2057        use axum::Router;
2058        use axum::body::Body;
2059        use axum::routing::get;
2060        use http::header::COOKIE;
2061        use tower::ServiceExt;
2062
2063        use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
2064        use crate::state::AppState;
2065
2066        #[autumn_macros::secured]
2067        async fn account_handler() -> crate::AutumnResult<&'static str> {
2068            Ok("account")
2069        }
2070
2071        let store = MemoryStore::new();
2072        store
2073            .save(
2074                "sess1",
2075                std::collections::HashMap::from([
2076                    ("uid".into(), "42".into()),
2077                    ("account_id".into(), "42".into()),
2078                ]),
2079            )
2080            .await
2081            .unwrap();
2082
2083        let state = AppState {
2084            extensions: std::sync::Arc::new(std::sync::RwLock::new(
2085                std::collections::HashMap::new(),
2086            )),
2087            #[cfg(feature = "db")]
2088            pool: None,
2089            #[cfg(feature = "db")]
2090            replica_pool: None,
2091            profile: None,
2092            started_at: std::time::Instant::now(),
2093            health_detailed: false,
2094            probes: crate::probe::ProbeState::ready_for_test(),
2095            metrics: crate::middleware::MetricsCollector::new(),
2096            log_levels: crate::actuator::LogLevels::new("info"),
2097            task_registry: crate::actuator::TaskRegistry::new(),
2098            job_registry: crate::actuator::JobRegistry::new(),
2099            config_props: crate::actuator::ConfigProperties::default(),
2100            #[cfg(feature = "ws")]
2101            channels: crate::channels::Channels::new(32),
2102            #[cfg(feature = "ws")]
2103            shutdown: tokio_util::sync::CancellationToken::new(),
2104            policy_registry: crate::authorization::PolicyRegistry::default(),
2105            forbidden_response: crate::authorization::ForbiddenResponse::default(),
2106            auth_session_key: "uid".to_owned(),
2107            shared_cache: None,
2108        };
2109
2110        let app = Router::new()
2111            .route("/account", get(account_handler))
2112            .layer(SessionLayer::new(store, SessionConfig::default()))
2113            .with_state(state);
2114
2115        let response = app
2116            .oneshot(
2117                http::Request::builder()
2118                    .uri("/account")
2119                    .header(COOKIE, "autumn.sid=sess1")
2120                    .body(Body::empty())
2121                    .unwrap(),
2122            )
2123            .await
2124            .unwrap();
2125
2126        assert_eq!(response.status(), StatusCode::OK);
2127        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2128            .await
2129            .unwrap();
2130        assert_eq!(std::str::from_utf8(&body).unwrap(), "account");
2131    }
2132
2133    #[tokio::test]
2134    async fn secured_macro_with_role_rejects_wrong_role() {
2135        use axum::Router;
2136        use axum::body::Body;
2137        use axum::routing::get;
2138        use http::header::COOKIE;
2139        use tower::ServiceExt;
2140
2141        use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
2142        use crate::state::AppState;
2143
2144        #[autumn_macros::secured("admin")]
2145        async fn admin_only() -> crate::AutumnResult<&'static str> {
2146            Ok("admin area")
2147        }
2148
2149        let store = MemoryStore::new();
2150        store
2151            .save(
2152                "sess1",
2153                std::collections::HashMap::from([
2154                    ("user_id".into(), "42".into()),
2155                    ("role".into(), "viewer".into()),
2156                ]),
2157            )
2158            .await
2159            .unwrap();
2160
2161        let state = AppState {
2162            extensions: std::sync::Arc::new(std::sync::RwLock::new(
2163                std::collections::HashMap::new(),
2164            )),
2165            #[cfg(feature = "db")]
2166            pool: None,
2167            #[cfg(feature = "db")]
2168            replica_pool: None,
2169            profile: None,
2170            started_at: std::time::Instant::now(),
2171            health_detailed: false,
2172            probes: crate::probe::ProbeState::ready_for_test(),
2173            metrics: crate::middleware::MetricsCollector::new(),
2174            log_levels: crate::actuator::LogLevels::new("info"),
2175            task_registry: crate::actuator::TaskRegistry::new(),
2176            job_registry: crate::actuator::JobRegistry::new(),
2177            config_props: crate::actuator::ConfigProperties::default(),
2178            #[cfg(feature = "ws")]
2179            channels: crate::channels::Channels::new(32),
2180            #[cfg(feature = "ws")]
2181            shutdown: tokio_util::sync::CancellationToken::new(),
2182            policy_registry: crate::authorization::PolicyRegistry::default(),
2183            forbidden_response: crate::authorization::ForbiddenResponse::default(),
2184            auth_session_key: "user_id".to_owned(),
2185            shared_cache: None,
2186        };
2187
2188        let app = Router::new()
2189            .route("/", get(admin_only))
2190            .layer(SessionLayer::new(store, SessionConfig::default()))
2191            .with_state(state);
2192
2193        let response = app
2194            .oneshot(
2195                http::Request::builder()
2196                    .uri("/")
2197                    .header(COOKIE, "autumn.sid=sess1")
2198                    .body(Body::empty())
2199                    .unwrap(),
2200            )
2201            .await
2202            .unwrap();
2203
2204        assert_eq!(response.status(), StatusCode::FORBIDDEN);
2205    }
2206
2207    #[tokio::test]
2208    async fn secured_macro_with_multiple_roles_allows_match() {
2209        use axum::Router;
2210        use axum::body::Body;
2211        use axum::routing::get;
2212        use http::header::COOKIE;
2213        use tower::ServiceExt;
2214
2215        use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
2216        use crate::state::AppState;
2217
2218        #[autumn_macros::secured("admin", "editor")]
2219        async fn content_handler() -> crate::AutumnResult<&'static str> {
2220            Ok("content")
2221        }
2222
2223        let store = MemoryStore::new();
2224        store
2225            .save(
2226                "sess1",
2227                std::collections::HashMap::from([
2228                    ("user_id".into(), "42".into()),
2229                    ("role".into(), "editor".into()),
2230                ]),
2231            )
2232            .await
2233            .unwrap();
2234
2235        let state = AppState {
2236            extensions: std::sync::Arc::new(std::sync::RwLock::new(
2237                std::collections::HashMap::new(),
2238            )),
2239            #[cfg(feature = "db")]
2240            pool: None,
2241            #[cfg(feature = "db")]
2242            replica_pool: None,
2243            profile: None,
2244            started_at: std::time::Instant::now(),
2245            health_detailed: false,
2246            probes: crate::probe::ProbeState::ready_for_test(),
2247            metrics: crate::middleware::MetricsCollector::new(),
2248            log_levels: crate::actuator::LogLevels::new("info"),
2249            task_registry: crate::actuator::TaskRegistry::new(),
2250            job_registry: crate::actuator::JobRegistry::new(),
2251            config_props: crate::actuator::ConfigProperties::default(),
2252            #[cfg(feature = "ws")]
2253            channels: crate::channels::Channels::new(32),
2254            #[cfg(feature = "ws")]
2255            shutdown: tokio_util::sync::CancellationToken::new(),
2256            policy_registry: crate::authorization::PolicyRegistry::default(),
2257            forbidden_response: crate::authorization::ForbiddenResponse::default(),
2258            auth_session_key: "user_id".to_owned(),
2259            shared_cache: None,
2260        };
2261
2262        let app = Router::new()
2263            .route("/", get(content_handler))
2264            .layer(SessionLayer::new(store, SessionConfig::default()))
2265            .with_state(state);
2266
2267        let response = app
2268            .oneshot(
2269                http::Request::builder()
2270                    .uri("/")
2271                    .header(COOKIE, "autumn.sid=sess1")
2272                    .body(Body::empty())
2273                    .unwrap(),
2274            )
2275            .await
2276            .unwrap();
2277
2278        assert_eq!(response.status(), StatusCode::OK);
2279        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2280            .await
2281            .unwrap();
2282        assert_eq!(std::str::from_utf8(&body).unwrap(), "content");
2283    }
2284
2285    #[tokio::test]
2286    async fn require_auth_allows_authenticated() {
2287        use axum::Router;
2288        use axum::body::Body;
2289        use axum::routing::get;
2290        use http::header::COOKIE;
2291        use tower::ServiceExt;
2292
2293        use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
2294        use crate::state::AppState;
2295
2296        let store = MemoryStore::new();
2297        // Pre-populate a session with user_id
2298        let mut session_data = std::collections::HashMap::new();
2299        session_data.insert("user_id".into(), "42".into());
2300        store.save("valid-session", session_data).await.unwrap();
2301
2302        let state = AppState {
2303            extensions: std::sync::Arc::new(std::sync::RwLock::new(
2304                std::collections::HashMap::new(),
2305            )),
2306            #[cfg(feature = "db")]
2307            pool: None,
2308            #[cfg(feature = "db")]
2309            replica_pool: None,
2310            profile: None,
2311            started_at: std::time::Instant::now(),
2312            health_detailed: false,
2313            probes: crate::probe::ProbeState::ready_for_test(),
2314            metrics: crate::middleware::MetricsCollector::new(),
2315            log_levels: crate::actuator::LogLevels::new("info"),
2316            task_registry: crate::actuator::TaskRegistry::new(),
2317            job_registry: crate::actuator::JobRegistry::new(),
2318            config_props: crate::actuator::ConfigProperties::default(),
2319            #[cfg(feature = "ws")]
2320            channels: crate::channels::Channels::new(32),
2321            #[cfg(feature = "ws")]
2322            shutdown: tokio_util::sync::CancellationToken::new(),
2323            policy_registry: crate::authorization::PolicyRegistry::default(),
2324            forbidden_response: crate::authorization::ForbiddenResponse::default(),
2325            auth_session_key: "user_id".to_owned(),
2326            shared_cache: None,
2327        };
2328
2329        let app = Router::new()
2330            .route("/protected", get(|| async { "secret" }))
2331            .layer(RequireAuth::new("user_id"))
2332            .layer(SessionLayer::new(store, SessionConfig::default()))
2333            .with_state(state);
2334
2335        let response = app
2336            .oneshot(
2337                http::Request::builder()
2338                    .uri("/protected")
2339                    .header(COOKIE, "autumn.sid=valid-session")
2340                    .body(Body::empty())
2341                    .unwrap(),
2342            )
2343            .await
2344            .unwrap();
2345
2346        assert_eq!(response.status(), StatusCode::OK);
2347        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2348            .await
2349            .unwrap();
2350        assert_eq!(std::str::from_utf8(&body).unwrap(), "secret");
2351    }
2352
2353    #[tokio::test]
2354    async fn require_auth_poll_ready_propagates() {
2355        use std::task::{Context, Poll};
2356        use tower::{Layer, Service};
2357
2358        #[derive(Clone)]
2359        struct MockService {
2360            ready: bool,
2361            poll_count: std::sync::Arc<std::sync::atomic::AtomicUsize>,
2362        }
2363
2364        impl Service<axum::extract::Request> for MockService {
2365            type Response = axum::response::Response;
2366            type Error = std::convert::Infallible;
2367            type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
2368
2369            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
2370                self.poll_count
2371                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2372                if self.ready {
2373                    Poll::Ready(Ok(()))
2374                } else {
2375                    Poll::Pending
2376                }
2377            }
2378
2379            fn call(&mut self, _req: axum::extract::Request) -> Self::Future {
2380                std::future::ready(Ok(axum::response::Response::new(axum::body::Body::empty())))
2381            }
2382        }
2383
2384        let layer = RequireAuth::new("user_id");
2385        let poll_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
2386        let mock_service = MockService {
2387            ready: false,
2388            poll_count: poll_count.clone(),
2389        };
2390        let mut service = layer.layer(mock_service);
2391
2392        let waker = futures::task::noop_waker();
2393        let mut cx = Context::from_waker(&waker);
2394
2395        // When inner is not ready, RequireAuthService should not be ready
2396        let poll = service.poll_ready(&mut cx);
2397        assert!(poll.is_pending());
2398        assert_eq!(poll_count.load(std::sync::atomic::Ordering::SeqCst), 1);
2399
2400        // When inner is ready, RequireAuthService should be ready
2401        let mock_service_ready = MockService {
2402            ready: true,
2403            poll_count: poll_count.clone(),
2404        };
2405        let mut service_ready = layer.layer(mock_service_ready);
2406        let poll_ready = service_ready.poll_ready(&mut cx);
2407        assert!(poll_ready.is_ready());
2408        assert_eq!(poll_count.load(std::sync::atomic::Ordering::SeqCst), 2);
2409    }
2410
2411    #[tokio::test]
2412    async fn auth_rejection_into_response() {
2413        let rejection = AuthRejection;
2414        let response = rejection.into_response();
2415        assert_eq!(response.status(), axum::http::StatusCode::UNAUTHORIZED);
2416        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2417            .await
2418            .unwrap();
2419        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2420        assert_eq!(json["status"], 401);
2421        assert_eq!(json["detail"], "authentication required");
2422        assert_eq!(json["code"], "autumn.unauthorized");
2423    }
2424
2425    #[test]
2426    fn test_auth_config_defaults() {
2427        let config = AuthConfig::default();
2428        assert_eq!(config.bcrypt_cost, DEFAULT_BCRYPT_COST);
2429        assert_eq!(config.session_key, "user_id");
2430    }
2431
2432    #[tokio::test]
2433    async fn test_hash_password() {
2434        let test_input = uuid::Uuid::new_v4().to_string();
2435
2436        // Test hashing
2437        let hash = super::hash_password(&test_input)
2438            .await
2439            .expect("Failed to hash password");
2440        assert!(hash.starts_with("$2b$"));
2441
2442        // Test verification with correct password
2443        let is_valid = super::verify_password(&test_input, &hash)
2444            .await
2445            .expect("Failed to verify password");
2446        assert!(is_valid, "Password should be verified successfully");
2447
2448        // Test verification with incorrect password
2449        let is_invalid = super::verify_password(&uuid::Uuid::new_v4().to_string(), &hash)
2450            .await
2451            .expect("Failed to verify wrong password");
2452        assert!(!is_invalid, "Wrong password should not be verified");
2453    }
2454
2455    #[tokio::test]
2456    async fn test_hash_password_empty() {
2457        let test_input = String::new();
2458        let hash = super::hash_password(&test_input)
2459            .await
2460            .expect("Failed to hash empty password");
2461        assert!(hash.starts_with("$2b$"));
2462
2463        let is_valid = super::verify_password(&test_input, &hash)
2464            .await
2465            .expect("Failed to verify empty password");
2466        assert!(is_valid, "Empty password should be verified successfully");
2467    }
2468
2469    #[tokio::test]
2470    async fn test_hash_password_long() {
2471        // bcrypt truncates after 72 bytes. We just want to ensure it doesn't crash.
2472        let test_input = "a".repeat(100);
2473        let hash = super::hash_password(&test_input)
2474            .await
2475            .expect("Failed to hash long password");
2476        assert!(hash.starts_with("$2b$"));
2477
2478        let is_valid = super::verify_password(&test_input, &hash)
2479            .await
2480            .expect("Failed to verify long password");
2481        assert!(is_valid, "Long password should be verified successfully");
2482    }
2483
2484    #[tokio::test]
2485    async fn test_hash_password_unicode() {
2486        // Test with non-ascii characters
2487        let test_input = format!("{}🚀my_secrët_passwörd🔑", uuid::Uuid::new_v4());
2488        let hash = super::hash_password(&test_input)
2489            .await
2490            .expect("Failed to hash unicode password");
2491        assert!(hash.starts_with("$2b$"));
2492
2493        let is_valid = super::verify_password(&test_input, &hash)
2494            .await
2495            .expect("Failed to verify unicode password");
2496        assert!(is_valid, "Unicode password should be verified successfully");
2497    }
2498
2499    #[tokio::test]
2500    async fn test_verify_password_invalid_hash() {
2501        // Ensure that providing invalid hashes doesn't crash or cause issues, but returns an error/false
2502        let test_input = uuid::Uuid::new_v4().to_string();
2503
2504        // Invalid prefix
2505        let result = super::verify_password(&test_input, "invalid_hash_string").await;
2506        assert!(result.is_err() || !result.unwrap());
2507
2508        // Truncated hash
2509        let result2 = super::verify_password(&test_input, "$2b$04$").await;
2510        assert!(result2.is_err() || !result2.unwrap());
2511    }
2512}
2513
2514// ── API token tests ───────────────────────────────────────────────────────────
2515
#[cfg(test)]
mod api_token_tests {
    //! Tests for API-token hashing, the in-memory token store, the
    //! `RequireApiToken` middleware, the `ApiToken` extractor, and their
    //! composition with session auth.

    use std::sync::Arc;

    use axum::body::Body;
    use http::StatusCode;
    use tower::ServiceExt;

    use super::{
        ApiToken, ApiTokenStore, InMemoryApiTokenStore, RequireApiToken, hash_api_token,
        issue_api_token, revoke_api_token,
    };

    // ── Shared fixtures ──────────────────────────────────────────────────────

    /// Router with a single `GET {path}` route whose handler answers `"ok"`.
    fn ok_router(path: &str) -> axum::Router {
        axum::Router::new().route(path, axum::routing::get(|| async { "ok" }))
    }

    /// Builds a body-less `GET` request for `uri`. When `authorization` is
    /// `Some`, it is sent as the `Authorization` header.
    fn request(uri: &str, authorization: Option<&str>) -> http::Request<Body> {
        let builder = http::Request::builder().uri(uri);
        let builder = match authorization {
            Some(value) => builder.header(http::header::AUTHORIZATION, value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Drives `app` with exactly one request and returns its response.
    async fn send(app: axum::Router, req: http::Request<Body>) -> axum::response::Response {
        app.oneshot(req).await.unwrap()
    }

    /// Returns the `Content-Type` header, or `Some("")` if it is not valid
    /// visible ASCII. Returns `None` when the header is absent.
    fn content_type(response: &axum::response::Response) -> Option<&str> {
        response
            .headers()
            .get(http::header::CONTENT_TYPE)
            .map(|value| value.to_str().unwrap_or_default())
    }

    /// Reads the entire response body and parses it as JSON.
    async fn json_body(response: axum::response::Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Reads the entire response body as UTF-8 text.
    async fn text_body(response: axum::response::Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        std::str::from_utf8(&bytes).unwrap().to_owned()
    }

    /// Token store whose every operation fails with a `service_unavailable`
    /// error. Used to check that store failures propagate to the client.
    struct FailingApiTokenStore;

    impl FailingApiTokenStore {
        /// The error every operation of this store resolves to.
        fn unavailable<T>() -> crate::AutumnResult<T> {
            Err(crate::AutumnError::service_unavailable_msg(
                "api token store unavailable",
            ))
        }
    }

    impl ApiTokenStore for FailingApiTokenStore {
        fn issue<'a>(
            &'a self,
            _principal_id: &'a str,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = crate::AutumnResult<String>> + Send + 'a>,
        > {
            Box::pin(async { Self::unavailable() })
        }

        fn verify<'a>(
            &'a self,
            _raw_token: &'a str,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = crate::AutumnResult<Option<String>>> + Send + 'a>,
        > {
            Box::pin(async { Self::unavailable() })
        }

        fn revoke<'a>(
            &'a self,
            _raw_token: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::AutumnResult<()>> + Send + 'a>>
        {
            Box::pin(async { Self::unavailable() })
        }
    }

    // ── hash_api_token ───────────────────────────────────────────────────────

    #[test]
    fn hash_api_token_is_deterministic() {
        let h1 = hash_api_token("abc123");
        let h2 = hash_api_token("abc123");
        assert_eq!(h1, h2);
    }

    #[test]
    fn hash_api_token_produces_64_char_hex() {
        let hash = hash_api_token("any_raw_token");
        assert_eq!(hash.len(), 64, "SHA-256 hex must be 64 chars");
        // `char::is_ascii_hexdigit` also accepts 'A'..='F', so spell out the
        // lowercase range to enforce what the message promises.
        assert!(
            hash.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
            "hash must be lowercase hex digits"
        );
    }

    #[test]
    fn hash_api_token_differs_from_input() {
        let raw = "my_raw_token";
        assert_ne!(hash_api_token(raw), raw);
    }

    #[test]
    fn hash_api_token_different_inputs_produce_different_hashes() {
        assert_ne!(hash_api_token("token_a"), hash_api_token("token_b"));
    }

    // ── InMemoryApiTokenStore ────────────────────────────────────────────────

    #[tokio::test]
    async fn in_memory_store_issue_returns_unique_tokens() {
        let store = InMemoryApiTokenStore::default();
        let t1 = store.issue("user:1").await.unwrap();
        let t2 = store.issue("user:1").await.unwrap();
        assert_ne!(t1, t2, "each issued token must be unique");
        assert!(t1.len() >= 32, "token must have sufficient entropy");
    }

    #[tokio::test]
    async fn in_memory_store_verify_returns_principal_for_valid_token() {
        let store = InMemoryApiTokenStore::default();
        let raw = store.issue("user:42").await.unwrap();
        let principal = store.verify(&raw).await.unwrap();
        assert_eq!(principal, Some("user:42".to_owned()));
    }

    #[tokio::test]
    async fn in_memory_store_verify_returns_none_for_unknown_token() {
        let store = InMemoryApiTokenStore::default();
        let result = store.verify("not_a_real_token").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn in_memory_store_revoke_invalidates_token() {
        let store = InMemoryApiTokenStore::default();
        let raw = store.issue("user:7").await.unwrap();
        assert_eq!(
            store.verify(&raw).await.unwrap(),
            Some("user:7".to_owned()),
            "token must be valid before revoking"
        );
        store.revoke(&raw).await.unwrap();
        assert_eq!(store.verify(&raw).await.unwrap(), None);
    }

    #[tokio::test]
    async fn in_memory_store_raw_token_not_stored_verbatim() {
        // The store's internals are not visible here. This is only a
        // near-miss check: a token differing by one trailing character hashes
        // differently, so lookup must return None.
        let store = InMemoryApiTokenStore::default();
        let raw = store.issue("user:1").await.unwrap();
        let tampered = format!("{raw}x");
        assert_eq!(store.verify(&tampered).await.unwrap(), None);
    }

    #[tokio::test]
    async fn issue_api_token_helper_issues_verifiable_token() {
        let store = InMemoryApiTokenStore::default();
        let raw = issue_api_token(&store, "user:5").await.unwrap();
        assert_eq!(store.verify(&raw).await.unwrap(), Some("user:5".to_owned()));
    }

    #[tokio::test]
    async fn revoke_api_token_helper_revokes_token() {
        let store = InMemoryApiTokenStore::default();
        let raw = store.issue("user:6").await.unwrap();
        revoke_api_token(&store, &raw).await.unwrap();
        assert_eq!(store.verify(&raw).await.unwrap(), None);
    }

    // ── RequireApiToken middleware ───────────────────────────────────────────

    #[tokio::test]
    async fn require_api_token_rejects_missing_authorization_header() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = ok_router("/").layer(RequireApiToken::new(store));

        let response = send(app, request("/", None)).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_api_token_rejects_non_bearer_scheme() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = ok_router("/").layer(RequireApiToken::new(store));

        let response = send(app, request("/", Some("Basic dXNlcjpwYXNz"))).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_api_token_rejects_unknown_bearer_token() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = ok_router("/").layer(RequireApiToken::new(store));

        let response = send(app, request("/", Some("Bearer unknown_token_xyz"))).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_api_token_propagates_store_verify_errors() {
        let store = Arc::new(FailingApiTokenStore);
        let app = ok_router("/").layer(RequireApiToken::new(store));

        let response = send(app, request("/", Some("Bearer valid_client_token"))).await;

        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(content_type(&response), Some("application/problem+json"));

        let json = json_body(response).await;
        assert_eq!(json["status"], 503);
        assert_eq!(json["code"], "autumn.service_unavailable");
        assert_eq!(json["detail"], "api token store unavailable");
    }

    #[tokio::test]
    async fn require_api_token_allows_valid_bearer_token() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("user:1").await.unwrap();
        let app = ok_router("/").layer(RequireApiToken::new(Arc::clone(&store)));

        let authorization = format!("Bearer {raw}");
        let response = send(app, request("/", Some(&authorization))).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_api_token_accepts_case_insensitive_bearer_scheme() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("user:1").await.unwrap();

        for scheme in ["bearer", "bEaReR"] {
            let app = ok_router("/").layer(RequireApiToken::new(Arc::clone(&store)));

            let authorization = format!("{scheme} {raw}");
            let response = send(app, request("/", Some(&authorization))).await;

            assert_eq!(response.status(), StatusCode::OK, "scheme {scheme}");
        }
    }

    #[tokio::test]
    async fn require_api_token_rejects_revoked_token() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("user:1").await.unwrap();
        store.revoke(&raw).await.unwrap();
        let app = ok_router("/").layer(RequireApiToken::new(Arc::clone(&store)));

        let authorization = format!("Bearer {raw}");
        let response = send(app, request("/", Some(&authorization))).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_api_token_401_response_has_problem_details() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = ok_router("/").layer(RequireApiToken::new(store));

        let response = send(app, request("/", None)).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(content_type(&response), Some("application/problem+json"));
        let json = json_body(response).await;
        assert_eq!(json["status"], 401);
        assert_eq!(json["code"], "autumn.unauthorized");
        assert!(json["detail"].as_str().is_some());
    }

    #[tokio::test]
    async fn require_api_token_401_problem_details_include_request_context() {
        use crate::middleware::RequestIdLayer;

        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = ok_router("/api/private")
            .layer(RequireApiToken::new(store))
            .layer(RequestIdLayer);

        let response = send(app, request("/api/private", None)).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let request_id = response
            .headers()
            .get("x-request-id")
            .and_then(|value| value.to_str().ok())
            .expect("request id header should be present")
            .to_owned();
        let json = json_body(response).await;
        assert_eq!(json["request_id"], request_id);
        assert_eq!(json["instance"], "/api/private");
    }

    // ── ApiToken extractor ───────────────────────────────────────────────────

    #[tokio::test]
    async fn api_token_extractor_yields_principal_id_to_handler() {
        async fn handler(ApiToken(principal): ApiToken) -> String {
            principal
        }

        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("user:99").await.unwrap();
        let app = axum::Router::new()
            .route("/", axum::routing::get(handler))
            .layer(RequireApiToken::new(Arc::clone(&store)));

        let authorization = format!("Bearer {raw}");
        let response = send(app, request("/", Some(&authorization))).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(text_body(response).await, "user:99");
    }

    #[tokio::test]
    async fn api_token_extractor_rejects_when_no_principal_in_extensions() {
        async fn handler(ApiToken(principal): ApiToken) -> String {
            principal
        }

        // No `RequireApiToken` layer, so nothing puts a principal into the
        // request extensions.
        let app = axum::Router::new().route("/", axum::routing::get(handler));

        let response = send(app, request("/", None)).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    // ── Composition with session auth ────────────────────────────────────────

    #[tokio::test]
    async fn api_token_and_session_auth_compose_without_conflict() {
        use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};

        async fn api_handler(ApiToken(principal): ApiToken) -> String {
            principal
        }

        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("api_user").await.unwrap();

        let session_store = MemoryStore::new();
        session_store
            .save(
                "sess1",
                std::collections::HashMap::from([("user_id".into(), "session_user".into())]),
            )
            .await
            .unwrap();

        // NOTE(review): the request below carries only an `Authorization`
        // header and no session cookie. `sess1` is therefore presumably never
        // loaded by `SessionLayer`, so this mainly shows the two layers can be
        // stacked. To exercise both at once, attach the session cookie (name
        // per `SessionConfig::default()`) — TODO confirm.
        let app = axum::Router::new()
            .route(
                "/api",
                axum::routing::get(api_handler).layer(RequireApiToken::new(Arc::clone(&store))),
            )
            .layer(SessionLayer::new(session_store, SessionConfig::default()));

        let authorization = format!("Bearer {raw}");
        let response = send(app, request("/api", Some(&authorization))).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(text_body(response).await, "api_user");
    }

    // ── poll_ready propagation ───────────────────────────────────────────────

    #[tokio::test]
    async fn require_api_token_poll_ready_propagates_to_inner() {
        use std::task::{Context, Poll};
        use tower::{Layer, Service};

        /// Inner service whose readiness is fixed at construction time.
        #[derive(Clone)]
        struct MockService {
            ready: bool,
        }

        impl tower::Service<axum::extract::Request> for MockService {
            type Response = axum::response::Response;
            type Error = std::convert::Infallible;
            type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.ready {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            }

            fn call(&mut self, _req: axum::extract::Request) -> Self::Future {
                std::future::ready(Ok(axum::response::Response::new(axum::body::Body::empty())))
            }
        }

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        // One layer can wrap several inner services. The wrapper's readiness
        // must mirror whichever inner service it wraps.
        let store = Arc::new(InMemoryApiTokenStore::default());
        let layer = RequireApiToken::new(store);

        let mut pending_svc = layer.layer(MockService { ready: false });
        assert!(pending_svc.poll_ready(&mut cx).is_pending());

        let mut ready_svc = layer.layer(MockService { ready: true });
        assert!(ready_svc.poll_ready(&mut cx).is_ready());
    }
}