axum_gate/gate/oauth2/
mod.rs

1//! OAuth2 login flow with `/login` and `/callback` routes. Cookie templates are validated when building routes to fail fast on insecure combinations.
2//!
3//! Example: insert account before JWT using a repository (ensures stable `account_id` in cookie)
4//!
5//! ```rust
6//! use axum_gate::prelude::*;
7//! use axum_gate::repositories::memory::MemoryAccountRepository;
8//! use std::sync::Arc;
9//!
10//! let jwt_codec = Arc::new(JsonWebToken::<JwtClaims<Account<Role, Group>>>::default());
11//!
12//! // Build an account repository (e.g., in-memory for examples)
13//! let account_repo = Arc::new(MemoryAccountRepository::<Role, Group>::default());
14//!
15//! // Configure gate with repository-backed insertion before JWT issuance
16//! let gate = Gate::oauth2::<Role, Group>()
17//!     .auth_url("https://provider.example.com/oauth2/authorize")
18//!     .token_url("https://provider.example.com/oauth2/token")
19//!     .client_id("CLIENT_ID")
20//!     .client_secret("CLIENT_SECRET")
21//!     .redirect_url("http://localhost:3000/auth/callback")
22//!     .add_scope("openid")
23//!     // Provide JWT codec and TTL as usual
24//!     .with_jwt_codec("my-app", jwt_codec, 60 * 60 * 24)
25//!     // Persist or load the account before encoding the JWT
26//!     .with_account_repository(Arc::clone(&account_repo))
27//!     // Map provider tokens to your domain account (e.g., via userinfo)
28//!     .with_account_mapper(|_token| {
29//!         Box::pin(async move {
30//!             Ok(Account::<Role, Group>::new("user@example.com", &[Role::User], &[]))
31//!         })
32//!     });
33//! ```
34//!
35//!
36//! This module provides an OAuth2Gate builder that mounts routes to perform an
37//! Authorization Code + PKCE flow. On successful callback, it can:
38//! - Map the token response to an `Account<R, G>` via a user-supplied mapper
39//! - Mint a first-party JWT via a user-supplied codec (helper provided)
40//! - Optionally insert or load the account before issuing the JWT (via `with_account_repository` or `with_account_inserter`) so the cookie includes a stable `account_id`
41//! - Set a secure auth cookie using the crate’s cookie template
42//! - Optionally redirect to a configured post-login URL
43//!
44//! Usage (minimal):
45//!
46//! - Configure the gate (auth url, token url, client credentials, redirect url, scopes)
47//! - Provide an account mapper and JWT codec to issue a first-party session
48//! - Optionally provide an account inserter or repository to persist/load an account before JWT, ensuring a stable `account_id` in the session cookie
49//! - Build its routes for a base path like `/auth` and merge the returned router into your app
50//!
51//! Example (issuing first-party cookie):
52//!
53//! ```rust
54//! use axum::{Router, routing::get};
55//! use axum_gate::prelude::*;
56//! use std::sync::Arc;
57//!
58//! let jwt_codec = Arc::new(JsonWebToken::<JwtClaims<Account<Role, Group>>>::default());
59//!
60//! let gate = Gate::oauth2::<Role, Group>()
61//!     .auth_url("https://provider.example.com/oauth2/authorize")
62//!     .token_url("https://provider.example.com/oauth2/token")
63//!     .client_id("CLIENT_ID")
64//!     .client_secret("CLIENT_SECRET") // optional for public clients
65//!     .redirect_url("http://localhost:3000/auth/callback")
66//!     .add_scope("openid")
67//!     .with_post_login_redirect("/")
68//!     .with_jwt_codec("my-app", Arc::clone(&jwt_codec), 60 * 60 * 24) // 24h TTL
69//!     .with_account_mapper(|_token| {
70//!         // Map provider token response to your domain Account<R, G>.
71//!         // For plain OAuth2, you might call the provider's userinfo API here.
72//!         // Example (pseudo):
73//!         // let user = fetch_userinfo(token.access_token())?;
74//!         // Ok(Account::new(&user.email, &[Role::User], &[]))
75//!         Box::pin(async move {
76//!             Ok(Account::<Role, Group>::new("user@example.com", &[Role::User], &[]))
77//!         })
78//!     });
79//!
80//! // routes() returns Result<Router<()>, OAuth2Error>; its paths already include "/auth", so merge
81//! let auth_router = gate.routes("/auth").expect("valid OAuth2 config");
82//! let app = Router::<()>::new().merge(auth_router);
83//! ```
84//!
85//! Security and cookie configuration:
86//! - State and PKCE cookies use secure, short-lived, HttpOnly defaults with SameSite=Lax (good for OAuth redirects).
87//! - You can fully customize state/PKCE cookie attributes (name, path, domain, SameSite, Secure, HttpOnly, Max-Age)
88//!   via `CookieTemplate` helpers on the builder.
89//! - The first-party auth cookie template remains configurable via `with_cookie_template` or `configure_cookie_template`.
90//!
91//! Example: customize state/PKCE cookies
92//! ```rust
93//! use axum_gate::prelude::*;
94//! use cookie::{SameSite, time::Duration};
95//!
96//! let gate = Gate::oauth2::<Role, Group>()
97//!     // ... provider endpoints and client config ...
98//!     // Optional: custom names (multi-provider setups)
99//!     .with_cookie_names("my-oauth-state", "my-oauth-pkce")
100//!     // Optional: fine-tune state cookie (shorter TTL, SameSite)
101//!     .configure_state_cookie_template(|t| {
102//!         t.same_site(SameSite::Lax)
103//!          .max_age(Duration::minutes(5))
104//!     })
105//!     .unwrap()
106//!     // Optional: fine-tune PKCE cookie similarly
107//!     .configure_pkce_cookie_template(|t| {
108//!         t.same_site(SameSite::Lax)
109//!          .max_age(Duration::minutes(5))
110//!     })
111//!     .unwrap();
112//! ```
113//!
114//! Note: In production, serve over HTTPS and prefer `Secure=true`. If you set `SameSite=None` you must also set `Secure=true`
115//! (browser enforcement); `CookieTemplate::validate()` guards against insecure combinations.
116
117use crate::accounts::{Account, AccountRepository};
118use crate::authz::AccessHierarchy;
119use crate::codecs::Codec;
120use crate::codecs::jwt::{JwtClaims, RegisteredClaims};
121use crate::cookie_template::CookieTemplate;
122pub mod errors;
123use self::errors::{OAuth2CookieKind, OAuth2Error, Result as OAuth2Result};
124
125use axum::{
126    Extension, Router,
127    extract::Query,
128    response::{IntoResponse, Redirect},
129    routing::get,
130};
131use axum_extra::extract::CookieJar;
132use chrono::Utc;
133use cookie::{SameSite, time::Duration};
134use http::StatusCode;
135use oauth2::{
136    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
137    PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, StandardTokenResponse, TokenResponse,
138    TokenUrl, basic::BasicClient, basic::BasicTokenType,
139};
140use serde::Deserialize;
141use std::fmt::Display;
142use std::future::Future;
143use std::marker::PhantomData;
144use std::pin::Pin;
145use std::sync::Arc;
146use tracing::{debug, error};
147
148/// Default cookie name for CSRF state during OAuth2 authorization.
149const DEFAULT_STATE_COOKIE: &str = "oauth-state";
150
151/// Default cookie name for PKCE verifier during OAuth2 authorization.
152const DEFAULT_PKCE_COOKIE: &str = "oauth-pkce";
153
154/// Type alias for an account encoding function.
155type AccountEncoderFn<R, G> = Arc<dyn Fn(Account<R, G>) -> OAuth2Result<String> + Send + Sync>;
156/// Type alias for an account mapper function.
157type AccountMapperFn<R, G> = Arc<
158    dyn for<'a> Fn(
159            &'a StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>,
160        )
161            -> Pin<Box<dyn Future<Output = OAuth2Result<Account<R, G>>> + Send + 'a>>
162        + Send
163        + Sync,
164>;
165/// Type alias for an async account persistence function invoked before JWT issuance.
166///
167/// This closure should persist or load the account (idempotently), and return the account
168/// that should be encoded into the first‑party JWT (typically with a stable `account_id`).
169type AccountPersistFn<R, G> = Arc<
170    dyn Fn(Account<R, G>) -> Pin<Box<dyn Future<Output = OAuth2Result<Account<R, G>>> + Send>>
171        + Send
172        + Sync,
173>;
174
175/// Public builder for configuring OAuth2 routes and session issuance.
176#[derive(Clone)]
177#[must_use]
178pub struct OAuth2Gate<R, G>
179where
180    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
181    G: Eq + Clone + Send + Sync + 'static,
182{
183    // OAuth2 endpoints and client config
184    auth_url: Option<String>,
185    token_url: Option<String>,
186    client_id: Option<String>,
187    client_secret: Option<String>,
188    redirect_url: Option<String>,
189    scopes: Vec<String>,
190
191    // CSRF/PKCE cookie templates
192    state_cookie_template: CookieTemplate,
193    pkce_cookie_template: CookieTemplate,
194
195    // First-party session issuance (optional)
196    auth_cookie_template: CookieTemplate,
197    post_login_redirect: Option<String>,
198    mapper: Option<AccountMapperFn<R, G>>,
199    account_inserter: Option<AccountPersistFn<R, G>>,
200    jwt_encoder: Option<AccountEncoderFn<R, G>>,
201
202    _phantom: PhantomData<(R, G)>,
203}
204
205impl<R, G> Default for OAuth2Gate<R, G>
206where
207    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
208    G: Eq + Clone + Send + Sync + 'static,
209{
210    fn default() -> Self {
211        Self {
212            auth_url: None,
213            token_url: None,
214            client_id: None,
215            client_secret: None,
216            redirect_url: None,
217            scopes: Vec::new(),
218            state_cookie_template: CookieTemplate::recommended()
219                .name(DEFAULT_STATE_COOKIE)
220                .same_site(SameSite::Lax)
221                .max_age(Duration::minutes(10)),
222            pkce_cookie_template: CookieTemplate::recommended()
223                .name(DEFAULT_PKCE_COOKIE)
224                .same_site(SameSite::Lax)
225                .max_age(Duration::minutes(10)),
226            auth_cookie_template: CookieTemplate::recommended(),
227            post_login_redirect: None,
228            mapper: None,
229            account_inserter: None,
230            jwt_encoder: None,
231            _phantom: PhantomData,
232        }
233    }
234}
235
236impl<R, G> OAuth2Gate<R, G>
237where
238    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
239    G: Eq + Clone + Send + Sync + 'static,
240{
241    /// Create a new, empty builder.
242    pub fn new() -> Self {
243        Self::default()
244    }
245
246    /// Set the authorization endpoint URL.
247    pub fn auth_url(mut self, url: impl Into<String>) -> Self {
248        self.auth_url = Some(url.into());
249        self
250    }
251
252    /// Set the token endpoint URL.
253    pub fn token_url(mut self, url: impl Into<String>) -> Self {
254        self.token_url = Some(url.into());
255        self
256    }
257
258    /// Set the OAuth2 client ID.
259    pub fn client_id(mut self, id: impl Into<String>) -> Self {
260        self.client_id = Some(id.into());
261        self
262    }
263
264    /// Set the OAuth2 client secret (optional for public clients).
265    pub fn client_secret(mut self, secret: impl Into<String>) -> Self {
266        self.client_secret = Some(secret.into());
267        self
268    }
269
270    /// Set the redirect URL that your provider will call after user authorization.
271    pub fn redirect_url(mut self, url: impl Into<String>) -> Self {
272        self.redirect_url = Some(url.into());
273        self
274    }
275
276    /// Add a scope to request from the provider.
277    pub fn add_scope(mut self, scope: impl Into<String>) -> Self {
278        self.scopes.push(scope.into());
279        self
280    }
281
282    /// Set custom cookie names for state/PKCE (primarily for multi-provider setups).
283    ///
284    /// This also updates the underlying cookie templates to use the provided names.
285    pub fn with_cookie_names(
286        mut self,
287        state_cookie: impl Into<String>,
288        pkce_cookie: impl Into<String>,
289    ) -> Self {
290        let state_name: String = state_cookie.into();
291        let pkce_name: String = pkce_cookie.into();
292
293        self.state_cookie_template = self.state_cookie_template.name(state_name);
294        self.pkce_cookie_template = self.pkce_cookie_template.name(pkce_name);
295        self
296    }
297
298    /// Configure the state cookie template directly.
299    pub fn with_state_cookie_template(mut self, template: CookieTemplate) -> Self {
300        self.state_cookie_template = template;
301        self
302    }
303
304    /// Convenience to configure the state cookie template via the high-level builder.
305    pub fn configure_state_cookie_template<F>(mut self, f: F) -> OAuth2Result<Self>
306    where
307        F: FnOnce(CookieTemplate) -> CookieTemplate,
308    {
309        let template = f(CookieTemplate::recommended());
310        template
311            .validate()
312            .map_err(|e| OAuth2Error::cookie_invalid(OAuth2CookieKind::State, e.to_string()))?;
313
314        self.state_cookie_template = template;
315        Ok(self)
316    }
317
318    /// Configure the PKCE cookie template directly.
319    pub fn with_pkce_cookie_template(mut self, template: CookieTemplate) -> Self {
320        self.pkce_cookie_template = template;
321        self
322    }
323
324    /// Convenience to configure the PKCE cookie template via the high-level builder.
325    pub fn configure_pkce_cookie_template<F>(mut self, f: F) -> OAuth2Result<Self>
326    where
327        F: FnOnce(CookieTemplate) -> CookieTemplate,
328    {
329        let template = f(CookieTemplate::recommended());
330        template
331            .validate()
332            .map_err(|e| OAuth2Error::cookie_invalid(OAuth2CookieKind::Pkce, e.to_string()))?;
333
334        self.pkce_cookie_template = template;
335        Ok(self)
336    }
337
338    /// Configure the auth cookie template used to store the first-party JWT.
339    pub fn with_cookie_template(mut self, template: CookieTemplate) -> Self {
340        self.auth_cookie_template = template;
341        self
342    }
343
344    /// Convenience to configure the auth cookie template via the high-level builder.
345    pub fn configure_cookie_template<F>(mut self, f: F) -> OAuth2Result<Self>
346    where
347        F: FnOnce(CookieTemplate) -> CookieTemplate,
348    {
349        let template = f(CookieTemplate::recommended());
350        template
351            .validate()
352            .map_err(|e| OAuth2Error::cookie_invalid(OAuth2CookieKind::Auth, e.to_string()))?;
353
354        self.auth_cookie_template = template;
355        Ok(self)
356    }
357
358    /// Configure a post-login redirect URL (e.g., "/").
359    pub fn with_post_login_redirect(mut self, url: impl Into<String>) -> Self {
360        self.post_login_redirect = Some(url.into());
361        self
362    }
363
364    /// Provide an async account mapper that converts the token response to an Account<R, G>.
365    ///
366    /// This allows performing async I/O (e.g., calling a provider user info endpoint) without blocking.
367    pub fn with_account_mapper<F>(mut self, f: F) -> Self
368    where
369        F: Send + Sync + 'static,
370        for<'a> F: Fn(
371            &'a StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>,
372        )
373            -> Pin<Box<dyn Future<Output = OAuth2Result<Account<R, G>>> + Send + 'a>>,
374    {
375        let f = Arc::new(f);
376        self.mapper = Some(Arc::new(move |token_resp| (f)(token_resp)));
377        self
378    }
379
380    /// Provide an async account inserter that persists or loads an account before JWT issuance.
381    ///
382    /// The closure is called after mapping the provider token to an Account and before encoding the JWT.
383    /// It should return the persisted or loaded Account (with a stable account_id).
384    pub fn with_account_inserter<F, Fut>(mut self, f: F) -> Self
385    where
386        F: Fn(Account<R, G>) -> Fut + Send + Sync + 'static,
387        Fut: Future<Output = OAuth2Result<Account<R, G>>> + Send + 'static,
388    {
389        self.account_inserter = Some(Arc::new(move |account: Account<R, G>| Box::pin(f(account))));
390        self
391    }
392
393    /// Convenience: insert into an AccountRepository on first login (idempotent).
394    ///
395    /// Queries by user_id; if missing, stores the provided account. Returns the existing or stored account.
396    pub fn with_account_repository<AccRepo>(mut self, account_repository: Arc<AccRepo>) -> Self
397    where
398        AccRepo: AccountRepository<R, G> + Send + Sync + 'static,
399    {
400        self.account_inserter = Some(Arc::new(move |account: Account<R, G>| {
401            let repo = Arc::clone(&account_repository);
402            Box::pin(async move {
403                match repo.query_account_by_user_id(&account.user_id).await {
404                    Ok(Some(existing)) => Ok(existing),
405                    Ok(None) => match repo.store_account(account).await {
406                        Ok(Some(stored)) => Ok(stored),
407                        Ok(None) => Err(OAuth2Error::account_persistence(
408                            "account repo returned None on store",
409                        )),
410                        Err(e) => Err(OAuth2Error::account_persistence(e.to_string())),
411                    },
412                    Err(e) => Err(OAuth2Error::account_persistence(e.to_string())),
413                }
414            })
415        }));
416        self
417    }
418
419    /// Provide a JWT codec and issuer; sets up a type-erased encoder closure.
420    ///
421    /// This helper uses your provided codec to mint a first-party session JWT from an Account<R, G>. The `ttl_secs` here sets expiry and overrides `with_jwt_ttl_secs`.
422    pub fn with_jwt_codec<C>(mut self, issuer: &str, codec: Arc<C>, ttl_secs: u64) -> Self
423    where
424        C: Codec<Payload = JwtClaims<Account<R, G>>> + Send + Sync + 'static,
425    {
426        let issuer = issuer.to_string();
427        self.jwt_encoder = Some(Arc::new(move |account: Account<R, G>| {
428            let exp = Utc::now().timestamp() as u64 + ttl_secs;
429            let registered = RegisteredClaims::new(&issuer, exp);
430            let claims = JwtClaims::new(account, registered);
431            let bytes = codec
432                .encode(&claims)
433                .map_err(|e| OAuth2Error::jwt_encoding(e.to_string()))?;
434            let token = String::from_utf8(bytes).map_err(|_| OAuth2Error::JwtNotUtf8)?;
435            Ok(token)
436        }));
437        self
438    }
439
440    /// Build and return an axum Router with `/login` and `/callback` routes nested under `base_path`.
441    ///
442    /// Example:
443    /// - base_path: "/auth" → routes are "/auth/login" and "/auth/callback"
444    pub fn routes(&self, base_path: &str) -> OAuth2Result<Router<()>> {
445        // Validate presence of required config and store raw values in handler state
446        let auth_url = self
447            .auth_url
448            .clone()
449            .ok_or_else(|| OAuth2Error::missing("auth_url"))?;
450        let token_url = self
451            .token_url
452            .clone()
453            .ok_or_else(|| OAuth2Error::missing("token_url"))?;
454        let client_id = self
455            .client_id
456            .clone()
457            .ok_or_else(|| OAuth2Error::missing("client_id"))?;
458        let redirect_url = self
459            .redirect_url
460            .clone()
461            .ok_or_else(|| OAuth2Error::missing("redirect_url"))?;
462
463        // Validate cookie templates to prevent insecure SameSite=None + Secure=false, etc.
464        self.state_cookie_template
465            .validate()
466            .map_err(|e| OAuth2Error::cookie_invalid(OAuth2CookieKind::State, e.to_string()))?;
467        self.pkce_cookie_template
468            .validate()
469            .map_err(|e| OAuth2Error::cookie_invalid(OAuth2CookieKind::Pkce, e.to_string()))?;
470        self.auth_cookie_template
471            .validate()
472            .map_err(|e| OAuth2Error::cookie_invalid(OAuth2CookieKind::Auth, e.to_string()))?;
473
474        let handler_state = Arc::new(OAuth2HandlerState::<R, G> {
475            auth_url,
476            token_url,
477            client_id,
478            client_secret: self.client_secret.clone(),
479            redirect_url,
480            scopes: self.scopes.clone(),
481            state_cookie_template: self.state_cookie_template.clone(),
482            pkce_cookie_template: self.pkce_cookie_template.clone(),
483            auth_cookie_template: self.auth_cookie_template.clone(),
484            post_login_redirect: self.post_login_redirect.clone(),
485            mapper: self.mapper.clone(),
486            account_inserter: self.account_inserter.clone(),
487            jwt_encoder: self.jwt_encoder.clone(),
488        });
489
490        let base = base_path.trim_end_matches('/');
491        let login_path = format!("{base}/login");
492        let callback_path = format!("{base}/callback");
493
494        let router = Router::<()>::new()
495            .route(&login_path, get(login_handler::<R, G>))
496            .route(&callback_path, get(callback_handler::<R, G>))
497            .layer(Extension(handler_state));
498
499        Ok(router)
500    }
501}
502
503/// Shared handler state injected via `Extension`.
504#[derive(Clone)]
505struct OAuth2HandlerState<R, G>
506where
507    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
508    G: Eq + Clone + Send + Sync + 'static,
509{
510    // Raw OAuth2 config; client is constructed in handlers
511    auth_url: String,
512    token_url: String,
513    client_id: String,
514    client_secret: Option<String>,
515    redirect_url: String,
516    scopes: Vec<String>,
517
518    state_cookie_template: CookieTemplate,
519    pkce_cookie_template: CookieTemplate,
520
521    // Session issuance
522    auth_cookie_template: CookieTemplate,
523    post_login_redirect: Option<String>,
524    mapper: Option<AccountMapperFn<R, G>>,
525    account_inserter: Option<AccountPersistFn<R, G>>,
526    jwt_encoder: Option<AccountEncoderFn<R, G>>,
527}
528
529/// Query parameters delivered by the provider to the redirect/callback endpoint.
530#[derive(Deserialize, Debug)]
531struct CallbackQuery {
532    code: Option<String>,
533    state: Option<String>,
534    error: Option<String>,
535    error_description: Option<String>,
536}
537
538/// Generates PKCE/state cookies and redirects to the provider's authorization endpoint.
539async fn login_handler<R, G>(
540    Extension(st): Extension<Arc<OAuth2HandlerState<R, G>>>,
541    jar: CookieJar,
542) -> impl IntoResponse
543where
544    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
545    G: Eq + Clone + Send + Sync + 'static,
546{
547    let auth_url = match AuthUrl::new(st.auth_url.clone()) {
548        Ok(u) => u,
549        Err(e) => {
550            {
551                let err = self::errors::OAuth2Error::invalid_url("auth_url", e.to_string());
552                error!(
553                    "{}",
554                    crate::errors::UserFriendlyError::developer_message(&err)
555                );
556            }
557            return (StatusCode::INTERNAL_SERVER_ERROR, "OAuth2 misconfigured").into_response();
558        }
559    };
560    let token_url = match TokenUrl::new(st.token_url.clone()) {
561        Ok(u) => u,
562        Err(e) => {
563            {
564                let err = self::errors::OAuth2Error::invalid_url("token_url", e.to_string());
565                error!(
566                    "{}",
567                    crate::errors::UserFriendlyError::developer_message(&err)
568                );
569            }
570            return (StatusCode::INTERNAL_SERVER_ERROR, "OAuth2 misconfigured").into_response();
571        }
572    };
573    let redirect_url = match RedirectUrl::new(st.redirect_url.clone()) {
574        Ok(u) => u,
575        Err(e) => {
576            {
577                let err = self::errors::OAuth2Error::invalid_url("redirect_url", e.to_string());
578                error!(
579                    "{}",
580                    crate::errors::UserFriendlyError::developer_message(&err)
581                );
582            }
583            return (StatusCode::INTERNAL_SERVER_ERROR, "OAuth2 misconfigured").into_response();
584        }
585    };
586    let mut client = BasicClient::new(ClientId::new(st.client_id.clone()))
587        .set_auth_uri(auth_url)
588        .set_token_uri(token_url)
589        .set_redirect_uri(redirect_url);
590    if let Some(secret) = &st.client_secret {
591        client = client.set_client_secret(ClientSecret::new(secret.clone()));
592    }
593
594    // CSRF state
595    let csrf = CsrfToken::new_random();
596    // PKCE challenge + verifier
597    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
598
599    let mut req = client
600        .authorize_url(|| csrf.clone())
601        .set_pkce_challenge(pkce_challenge);
602
603    for s in &st.scopes {
604        req = req.add_scope(Scope::new(s.clone()));
605    }
606
607    let (auth_url, csrf_token) = req.url();
608
609    // Prepare cookies using configured templates (short-lived by default)
610    let state_cookie = st
611        .state_cookie_template
612        .build_with_value(csrf_token.secret());
613
614    let pkce_cookie = st
615        .pkce_cookie_template
616        .build_with_value(pkce_verifier.secret());
617
618    let jar = jar.add(state_cookie).add(pkce_cookie);
619
620    (jar, Redirect::to(auth_url.as_str())).into_response()
621}
622
623/// Validates state and PKCE, exchanges code for tokens, optionally mints a first-party JWT,
624/// installs auth cookie, clears ephemeral cookies, and redirects (if configured).
625async fn callback_handler<R, G>(
626    Extension(st): Extension<Arc<OAuth2HandlerState<R, G>>>,
627    jar: CookieJar,
628    Query(q): Query<CallbackQuery>,
629) -> impl IntoResponse
630where
631    R: AccessHierarchy + Eq + std::fmt::Display + Send + Sync + 'static,
632    G: Eq + Clone + Send + Sync + 'static,
633{
634    // Load state + pkce verifier from cookies
635    let state_cookie = jar.get(st.state_cookie_template.cookie_name_ref());
636    let pkce_cookie = jar.get(st.pkce_cookie_template.cookie_name_ref());
637
638    let Some(state_cookie) = state_cookie else {
639        error!("Missing state cookie");
640        let state_removal = st.state_cookie_template.build_removal();
641        let pkce_removal = st.pkce_cookie_template.build_removal();
642        let jar = jar.add(state_removal).add(pkce_removal);
643        return (jar, (StatusCode::BAD_REQUEST, "Missing state")).into_response();
644    };
645
646    let Some(pkce_cookie) = pkce_cookie else {
647        error!("Missing PKCE cookie");
648        let state_removal = st.state_cookie_template.build_removal();
649        let pkce_removal = st.pkce_cookie_template.build_removal();
650        let jar = jar.add(state_removal).add(pkce_removal);
651        return (jar, (StatusCode::BAD_REQUEST, "Missing PKCE")).into_response();
652    };
653
654    // If provider returned an error, clear cookies and return a safe error.
655    if let Some(err) = q.error.as_deref() {
656        let state_removal = st.state_cookie_template.build_removal();
657        let pkce_removal = st.pkce_cookie_template.build_removal();
658        let jar = jar.add(state_removal).add(pkce_removal);
659        error!(
660            "OAuth2 provider returned error: {err} {:?}",
661            q.error_description.as_deref()
662        );
663        return (
664            jar,
665            (StatusCode::BAD_REQUEST, "OAuth2 authorization failed"),
666        )
667            .into_response();
668    }
669
670    // Compare state from query and cookie; require state param
671    match q.state.as_deref() {
672        Some(state) if state_cookie.value() == state => {}
673        _ => {
674            error!("State missing or mismatch");
675            let state_removal = st.state_cookie_template.build_removal();
676            let pkce_removal = st.pkce_cookie_template.build_removal();
677            let jar = jar.add(state_removal).add(pkce_removal);
678            return (
679                jar,
680                (StatusCode::BAD_REQUEST, "OAuth2 authorization failed"),
681            )
682                .into_response();
683        }
684    }
685
686    let Some(code_str) = q.code.clone() else {
687        let state_removal = st.state_cookie_template.build_removal();
688        let pkce_removal = st.pkce_cookie_template.build_removal();
689        let jar = jar.add(state_removal).add(pkce_removal);
690        return (
691            jar,
692            (StatusCode::BAD_REQUEST, "OAuth2 authorization failed"),
693        )
694            .into_response();
695    };
696    let code = AuthorizationCode::new(code_str);
697    let pkce_verifier = PkceCodeVerifier::new(pkce_cookie.value().to_string());
698
699    // Exchange code for tokens
700    let auth_url = match AuthUrl::new(st.auth_url.clone()) {
701        Ok(u) => u,
702        Err(e) => {
703            {
704                let err = self::errors::OAuth2Error::invalid_url("auth_url", e.to_string());
705                error!(
706                    "{}",
707                    crate::errors::UserFriendlyError::developer_message(&err)
708                );
709            }
710            let state_removal = st.state_cookie_template.build_removal();
711            let pkce_removal = st.pkce_cookie_template.build_removal();
712            let jar = jar.add(state_removal).add(pkce_removal);
713            return (
714                jar,
715                (StatusCode::INTERNAL_SERVER_ERROR, "OAuth2 misconfigured"),
716            )
717                .into_response();
718        }
719    };
720    let token_url = match TokenUrl::new(st.token_url.clone()) {
721        Ok(u) => u,
722        Err(e) => {
723            {
724                let err = self::errors::OAuth2Error::invalid_url("token_url", e.to_string());
725                error!(
726                    "{}",
727                    crate::errors::UserFriendlyError::developer_message(&err)
728                );
729            }
730            let state_removal = st.state_cookie_template.build_removal();
731            let pkce_removal = st.pkce_cookie_template.build_removal();
732            let jar = jar.add(state_removal).add(pkce_removal);
733            return (
734                jar,
735                (StatusCode::INTERNAL_SERVER_ERROR, "OAuth2 misconfigured"),
736            )
737                .into_response();
738        }
739    };
740    let redirect_url = match RedirectUrl::new(st.redirect_url.clone()) {
741        Ok(u) => u,
742        Err(e) => {
743            {
744                let err = self::errors::OAuth2Error::invalid_url("redirect_url", e.to_string());
745                error!(
746                    "{}",
747                    crate::errors::UserFriendlyError::developer_message(&err)
748                );
749            }
750            let state_removal = st.state_cookie_template.build_removal();
751            let pkce_removal = st.pkce_cookie_template.build_removal();
752            let jar = jar.add(state_removal).add(pkce_removal);
753            return (
754                jar,
755                (StatusCode::INTERNAL_SERVER_ERROR, "OAuth2 misconfigured"),
756            )
757                .into_response();
758        }
759    };
760    let mut client = BasicClient::new(ClientId::new(st.client_id.clone()))
761        .set_auth_uri(auth_url)
762        .set_token_uri(token_url)
763        .set_redirect_uri(redirect_url);
764    if let Some(secret) = &st.client_secret {
765        client = client.set_client_secret(ClientSecret::new(secret.clone()));
766    }
767
768    match client
769        .exchange_code(code)
770        .set_pkce_verifier(pkce_verifier)
771        .request_async(&|req: oauth2::HttpRequest| async move {
772            let client = reqwest::Client::builder()
773                .timeout(std::time::Duration::from_secs(10))
774                .build()?;
775            let url = req.uri().to_string();
776            let builder = client.request(req.method().clone(), url);
777            let resp = builder
778                .headers(req.headers().clone())
779                .body(req.body().clone())
780                .send()
781                .await?;
782            let status = resp.status();
783            let headers = resp.headers().clone();
784            let body = resp.bytes().await?.to_vec();
785            let mut resp_out = http::Response::new(body);
786            *resp_out.status_mut() = status;
787            *resp_out.headers_mut() = headers;
788            Ok::<http::Response<Vec<u8>>, reqwest::Error>(resp_out)
789        })
790        .await
791    {
792        Ok(token_resp) => {
793            debug!(
794                "OAuth2 token response received (scopes: {:?})",
795                token_resp.scopes()
796            );
797
798            // Clear ephemeral cookies (state/pkce) using configured templates
799            let state_removal = st.state_cookie_template.build_removal();
800            let pkce_removal = st.pkce_cookie_template.build_removal();
801
802            let mut jar = jar.add(state_removal).add(pkce_removal);
803
804            // Try session issuance if configured
805            if let (Some(mapper), Some(jwt_encoder)) = (&st.mapper, &st.jwt_encoder) {
806                // 1) Map provider tokens -> Account<R, G>
807                match (mapper)(&token_resp).await {
808                    Ok(mapped_account) => {
809                        // 2) Optionally persist/load account before JWT issuance (to get stable account_id)
810                        let account_result = if let Some(inserter) = &st.account_inserter {
811                            (inserter)(mapped_account).await
812                        } else {
813                            Ok(mapped_account)
814                        };
815
816                        // 3) Encode JWT using the (possibly persisted) account
817                        match account_result.and_then(|account| jwt_encoder(account)) {
818                            Ok(token) => {
819                                // Prepare auth cookie using template flags
820                                let auth_cookie = st.auth_cookie_template.build_with_value(&token);
821
822                                jar = jar.add(auth_cookie);
823
824                                if let Some(url) = &st.post_login_redirect {
825                                    return (jar, Redirect::to(url)).into_response();
826                                } else {
827                                    return (jar, (StatusCode::OK, "OAuth2 login OK"))
828                                        .into_response();
829                                }
830                            }
831                            Err(e) => {
832                                error!(
833                                    "OAuth2 session issuance failed [{}]: {}",
834                                    crate::errors::UserFriendlyError::support_code(&e),
835                                    crate::errors::UserFriendlyError::developer_message(&e),
836                                );
837                                return (
838                                    jar,
839                                    (StatusCode::BAD_GATEWAY, "OAuth2 session issuance failed"),
840                                )
841                                    .into_response();
842                            }
843                        }
844                    }
845                    Err(e) => {
846                        error!(
847                            "OAuth2 account mapping failed [{}]: {}",
848                            crate::errors::UserFriendlyError::support_code(&e),
849                            crate::errors::UserFriendlyError::developer_message(&e),
850                        );
851                        return (
852                            jar,
853                            (StatusCode::BAD_GATEWAY, "OAuth2 account mapping failed"),
854                        )
855                            .into_response();
856                    }
857                }
858            }
859
860            // If no session issuance configured, return OK
861            (jar, (StatusCode::OK, "OAuth2 callback OK")).into_response()
862        }
863        Err(err) => {
864            let oe = self::errors::OAuth2Error::token_exchange(err.to_string());
865            error!(
866                "OAuth2 token exchange failed [{}]: {}",
867                crate::errors::UserFriendlyError::support_code(&oe),
868                crate::errors::UserFriendlyError::developer_message(&oe),
869            );
870            let state_removal = st.state_cookie_template.build_removal();
871            let pkce_removal = st.pkce_cookie_template.build_removal();
872            let jar = jar.add(state_removal).add(pkce_removal);
873            (
874                jar,
875                (StatusCode::BAD_GATEWAY, "OAuth2 token exchange failed"),
876            )
877                .into_response()
878        }
879    }
880}
881
882// Debug implementation avoids leaking secrets
883impl<R, G> std::fmt::Debug for OAuth2Gate<R, G>
884where
885    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
886    G: Eq + Clone + Send + Sync + 'static,
887{
888    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
889        f.debug_struct("OAuth2Gate")
890            .field("auth_url", &self.auth_url)
891            .field("token_url", &self.token_url)
892            .field(
893                "client_id",
894                &self.client_id.as_ref().map(|_| "<configured>"),
895            )
896            .field(
897                "client_secret",
898                &self.client_secret.as_ref().map(|_| "<redacted>"),
899            )
900            .field("redirect_url", &self.redirect_url)
901            .field("scopes", &self.scopes)
902            .field(
903                "state_cookie_name",
904                &self.state_cookie_template.cookie_name_ref(),
905            )
906            .field(
907                "pkce_cookie_name",
908                &self.pkce_cookie_template.cookie_name_ref(),
909            )
910            .field(
911                "auth_cookie_name",
912                &self.auth_cookie_template.cookie_name_ref(),
913            )
914            .field("post_login_redirect", &self.post_login_redirect)
915            .finish()
916    }
917}
918
#[cfg(test)]
mod tests {

    use super::OAuth2Gate;
    use crate::cookie_template::CookieTemplate;
    use crate::prelude::{Group, Role};
    // Only the debug-gated tests below construct SameSite values; gating the import
    // avoids an unused-import warning in release test builds.
    #[cfg(debug_assertions)]
    use cookie::SameSite;

    #[test]
    fn cookie_template_recommended_is_valid_in_debug_defaults() {
        // recommended() must always yield a template that passes validation. This test
        // is not build-gated; in debug builds recommended() uses Secure=false,
        // SameSite=Lax, which is an accepted combination.
        let t = CookieTemplate::recommended();
        assert!(t.validate().is_ok());
    }

    #[test]
    #[cfg(debug_assertions)]
    fn cookie_template_insecure_none_is_rejected() {
        // SameSite=None must require Secure=true; validate() should reject this in debug defaults.
        let t = CookieTemplate::recommended().same_site(SameSite::None);
        assert!(t.validate().is_err());
    }

    #[test]
    #[cfg(debug_assertions)]
    fn routes_validation_rejects_invalid_cookie_templates() {
        // Auth cookie using SameSite=None without Secure must be rejected when building routes.
        let gate = OAuth2Gate::<Role, Group>::new()
            .auth_url("https://provider.example.com/oauth2/authorize")
            .token_url("https://provider.example.com/oauth2/token")
            .client_id("id")
            .redirect_url("http://localhost:3000/auth/callback")
            .with_cookie_template(CookieTemplate::recommended().same_site(SameSite::None));
        assert!(gate.routes("/auth").is_err());
    }

    #[test]
    fn debug_redacts_client_secret() {
        let gate = OAuth2Gate::<Role, Group>::new()
            .auth_url("https://provider.example.com/oauth2/authorize")
            .token_url("https://provider.example.com/oauth2/token")
            .client_id("id")
            .client_secret("super-secret")
            .redirect_url("http://localhost:3000/auth/callback");
        let dbg = format!("{:?}", gate);
        assert!(dbg.contains("<redacted>"));
        assert!(!dbg.contains("super-secret"));
    }

    #[test]
    fn debug_masks_client_id_and_reports_absent_secret() {
        // A distinctive client id so the negative `contains` check is meaningful
        // (a short id like "id" would collide with field names such as "client_id").
        let gate = OAuth2Gate::<Role, Group>::new()
            .auth_url("https://provider.example.com/oauth2/authorize")
            .token_url("https://provider.example.com/oauth2/token")
            .client_id("my-distinct-client-identifier")
            .redirect_url("http://localhost:3000/auth/callback");
        let dbg = format!("{:?}", gate);
        assert!(dbg.contains("<configured>"));
        assert!(!dbg.contains("my-distinct-client-identifier"));
        // No secret was configured, so the redacted field must render as `None`.
        assert!(dbg.contains("client_secret: None"));
        assert!(!dbg.contains("<redacted>"));
    }
}