Skip to main content

pylon_auth/
lib.rs

1pub mod api_key;
2pub mod apple_jwt;
3pub mod captcha;
4pub mod cookie;
5pub mod email;
6pub mod jwt;
7pub mod oidc_provider;
8pub mod org;
9pub mod password;
10pub mod phone;
11pub mod provider;
12pub mod rate_limit;
13pub mod scim;
14pub mod siwe;
15pub mod stripe;
16pub mod totp;
17pub mod verification;
18pub mod webauthn;
19
20pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
21
22use serde::{Deserialize, Serialize};
23
24// ---------------------------------------------------------------------------
25// Auth context — the identity available to runtime operations
26// ---------------------------------------------------------------------------
27
/// The auth context for a request. Represents who is making the request.
///
/// **Do NOT derive `Deserialize` on this type.** If the server ever parses an
/// `AuthContext` from client-supplied JSON, a client can set `is_admin=true`
/// or add roles and bypass every policy. Identity must come from
/// server-minted sessions (`Session::to_auth_context`) or explicit
/// constructors, never from deserialization.
///
/// `Serialize` is safe because sending the resolved context BACK to the
/// client exposes nothing the server didn't already decide.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
    /// The authenticated user ID, or None for public/anonymous access.
    /// For guest contexts this is `Some(guest_id)` — a stable
    /// anonymous identifier, NOT a real user.
    pub user_id: Option<String>,
    /// Whether this is an admin context (bypasses policies).
    pub is_admin: bool,
    /// True for `AuthContext::guest()` — anonymous-with-stable-id, used
    /// for cart state and similar pre-login persistence. Routes guarded
    /// by `AuthMode::User` reject guests; only `is_authenticated()` ==
    /// "real signed-in user" should pass auth-required gates.
    ///
    /// Omitted from the serialized form when `false`.
    // NOTE(review): `default` only influences deserialization, which this
    // type deliberately does not derive — the attribute is inert here.
    #[serde(default, skip_serializing_if = "is_false")]
    pub is_guest: bool,
    /// Roles granted to this user. Empty for anonymous.
    pub roles: Vec<String>,
    /// Active tenant id (for multi-tenant apps). Set when the user has
    /// selected an organization for the current session.
    /// Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// API key id when the request was authenticated via a `pk.…`
    /// bearer token. Set so policies + management endpoints can
    /// distinguish "user-via-session" from "user-via-key" — e.g.
    /// password change is forbidden via API key.
    /// Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key_id: Option<String>,
    /// Comma-separated scope string from the API key. Application
    /// policies decide what scopes mean — pylon only carries them.
    /// Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key_scopes: Option<String>,
}
69
/// Serde `skip_serializing_if` predicate: `true` when the flag is unset.
fn is_false(b: &bool) -> bool {
    !*b
}
73
74impl AuthContext {
75    /// Create an anonymous/public auth context.
76    pub fn anonymous() -> Self {
77        Self {
78            user_id: None,
79            is_admin: false,
80            is_guest: false,
81            roles: Vec::new(),
82            tenant_id: None,
83            api_key_id: None,
84            api_key_scopes: None,
85        }
86    }
87
88    /// Create an authenticated auth context.
89    pub fn authenticated(user_id: String) -> Self {
90        Self {
91            user_id: Some(user_id),
92            is_admin: false,
93            is_guest: false,
94            roles: Vec::new(),
95            tenant_id: None,
96            api_key_id: None,
97            api_key_scopes: None,
98        }
99    }
100
101    /// Create an authenticated context backed by an API key. Policies +
102    /// auth-management endpoints can detect this via `is_api_key_auth()`.
103    pub fn from_api_key(user_id: String, key_id: String, scopes: Option<String>) -> Self {
104        Self {
105            user_id: Some(user_id),
106            is_admin: false,
107            is_guest: false,
108            roles: Vec::new(),
109            tenant_id: None,
110            api_key_id: Some(key_id),
111            api_key_scopes: scopes,
112        }
113    }
114
115    /// True iff this request was authenticated by an API key (not a
116    /// session cookie / bearer session token).
117    pub fn is_api_key_auth(&self) -> bool {
118        self.api_key_id.is_some()
119    }
120
121    /// Create a guest auth context with a persistent anonymous ID.
122    /// Guests carry an opaque stable id (cart/session continuity) but
123    /// are NOT considered authenticated — `is_authenticated()` returns
124    /// false and `AuthMode::User` rejects them.
125    pub fn guest(guest_id: String) -> Self {
126        Self {
127            user_id: Some(guest_id),
128            is_admin: false,
129            is_guest: true,
130            roles: Vec::new(),
131            tenant_id: None,
132            api_key_id: None,
133            api_key_scopes: None,
134        }
135    }
136
137    /// Create an admin auth context that bypasses all policies.
138    pub fn admin() -> Self {
139        Self {
140            user_id: Some("__admin__".into()),
141            is_admin: true,
142            is_guest: false,
143            roles: vec!["admin".into()],
144            tenant_id: None,
145            api_key_id: None,
146            api_key_scopes: None,
147        }
148    }
149
150    /// Convenience: build a user context from a user id.
151    pub fn user(user_id: String) -> Self {
152        Self::authenticated(user_id)
153    }
154
155    /// Active tenant id (None when the user hasn't selected an org).
156    pub fn tenant_id(&self) -> Option<&str> {
157        self.tenant_id.as_deref()
158    }
159
160    /// Attach a tenant id to the context (chainable).
161    pub fn with_tenant(mut self, tenant_id: String) -> Self {
162        self.tenant_id = Some(tenant_id);
163        self
164    }
165
166    /// Check if this context represents an authenticated user.
167    /// Guests intentionally return `false` — they have a stable anonymous
168    /// id but never gain user-level access.
169    pub fn is_authenticated(&self) -> bool {
170        self.user_id.is_some() && !self.is_guest
171    }
172
173    /// Check if the user has a specific role. Admins have every role implicitly.
174    pub fn has_role(&self, role: &str) -> bool {
175        self.is_admin || self.roles.iter().any(|r| r == role)
176    }
177
178    /// Check if the user has ANY of the given roles.
179    pub fn has_any_role(&self, roles: &[&str]) -> bool {
180        self.is_admin || roles.iter().any(|r| self.has_role(r))
181    }
182
183    /// Attach roles to the context (chainable).
184    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
185        self.roles = roles;
186        self
187    }
188}
189
190// ---------------------------------------------------------------------------
191// Constant-time comparison
192// ---------------------------------------------------------------------------
193
/// Compare two byte slices without short-circuiting on the first mismatch,
/// to prevent timing attacks on secret values.
///
/// Slices of different length are rejected up front, so the *length* of the
/// inputs is observable. For equal-length inputs every byte pair is XOR-ed
/// and OR-ed into one accumulator, so the amount of work does not depend on
/// where (or whether) the contents differ.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // Any differing bit anywhere leaves at least one bit set in `diff`.
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    diff == 0
}
209
210// ---------------------------------------------------------------------------
211// Auth mode — matches the route "auth" field values
212// ---------------------------------------------------------------------------
213
/// The auth mode declared on a route.
///
/// Parsed from the manifest strings `"public"` and `"user"` by
/// `AuthMode::from_str`; enforced against a request's `AuthContext` by
/// `AuthMode::check`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// Anyone can access — `check` always passes.
    Public,
    /// Only authenticated users can access — `check` passes only when
    /// `AuthContext::is_authenticated()` is true (guests are rejected).
    User,
}
222
223impl AuthMode {
224    /// Parse from the manifest auth string.
225    #[allow(clippy::should_implement_trait)]
226    pub fn from_str(s: &str) -> Option<Self> {
227        match s {
228            "public" => Some(AuthMode::Public),
229            "user" => Some(AuthMode::User),
230            _ => None,
231        }
232    }
233
234    /// Check if the given auth context satisfies this mode.
235    pub fn check(&self, ctx: &AuthContext) -> bool {
236        match self {
237            AuthMode::Public => true,
238            AuthMode::User => ctx.is_authenticated(),
239        }
240    }
241}
242
243// ---------------------------------------------------------------------------
244// Session — opaque token session
245// ---------------------------------------------------------------------------
246
/// A session token and its associated user.
///
/// Unlike `AuthContext`, this type derives `Deserialize`. The fields
/// marked `#[serde(default)]` may be absent from the serialized form and
/// then fall back to their type's default (`0` / `None`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Opaque session token (produced by `generate_token()` in the
    /// constructors).
    pub token: String,
    /// Id of the user this session belongs to.
    pub user_id: String,
    /// Unix epoch seconds at which this session expires. 0 = never.
    #[serde(default)]
    pub expires_at: u64,
    /// Optional user-agent / device tag recorded at session creation.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Unix epoch seconds when the session was created.
    #[serde(default)]
    pub created_at: u64,
    /// Active tenant id (selected organization). Set via
    /// `/api/auth/select-org`. Flows into `AuthContext.tenant_id` which
    /// powers row-scoped policies like `data.orgId == auth.tenantId`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
267
268impl Session {
269    /// Default session lifetime: 30 days.
270    pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
271
272    /// Create a new session with a generated token and default 30-day expiry.
273    pub fn new(user_id: String) -> Self {
274        let now = now_secs();
275        Self {
276            token: generate_token(),
277            user_id,
278            expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
279            device: None,
280            created_at: now,
281            tenant_id: None,
282        }
283    }
284
285    /// Create a session with a specific lifetime.
286    pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
287        let now = now_secs();
288        Self {
289            token: generate_token(),
290            user_id,
291            expires_at: if lifetime_secs == 0 {
292                0
293            } else {
294                now.saturating_add(lifetime_secs)
295            },
296            device: None,
297            created_at: now,
298            tenant_id: None,
299        }
300    }
301
302    /// Convert this session to an auth context, carrying the selected
303    /// tenant so row-scoped policies see `auth.tenantId`.
304    pub fn to_auth_context(&self) -> AuthContext {
305        let ctx = AuthContext::authenticated(self.user_id.clone());
306        match &self.tenant_id {
307            Some(t) => ctx.with_tenant(t.clone()),
308            None => ctx,
309        }
310    }
311
312    /// Returns true if the session has passed its expires_at time.
313    /// Boundary is inclusive (`>=`) to match the rest of the codebase
314    /// (`magic_codes.expires_at <= now`, `oauth_state.expires_at <= now`).
315    pub fn is_expired(&self) -> bool {
316        self.expires_at != 0 && now_secs() >= self.expires_at
317    }
318}
319
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Yields 0 if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
327
328// ---------------------------------------------------------------------------
329// OAuth provider config
330// ---------------------------------------------------------------------------
331
332#[derive(Debug, Clone, Default, Serialize, Deserialize)]
333pub struct OAuthConfig {
334    pub provider: String,
335    pub client_id: String,
336    pub client_secret: String,
337    pub redirect_uri: String,
338    /// Optional scope override — replaces the spec's default scope
339    /// when set. Use cases: requesting `repo` on GitHub for app
340    /// installation flows, requesting `https://www.googleapis.com/...`
341    /// scopes on Google for app-specific data access.
342    #[serde(default, skip_serializing_if = "Option::is_none")]
343    pub scopes_override: Option<String>,
344    /// Tenant id for Microsoft/Entra. Defaults to `common`. Single-
345    /// tenant apps use a directory GUID; multi-tenant work-only apps
346    /// use `organizations`.
347    #[serde(default, skip_serializing_if = "Option::is_none")]
348    pub tenant: Option<String>,
349    /// Apple-specific extras (team id, key id, ES256 PEM). Required
350    /// for Sign in with Apple — ignored for any other provider.
351    #[serde(default, skip_serializing_if = "Option::is_none")]
352    pub apple: Option<provider::AppleConfig>,
353    /// OIDC issuer URL when this config targets a generic-OIDC
354    /// provider (Auth0, Okta, Keycloak, Cognito, etc.). When set,
355    /// the runtime fetches `<issuer>/.well-known/openid-configuration`
356    /// and synthesizes a [`provider::ProviderSpec`] from the
357    /// discovered endpoints.
358    #[serde(default, skip_serializing_if = "Option::is_none")]
359    pub oidc_issuer: Option<String>,
360}
361
362impl OAuthConfig {
363    /// Resolve the [`provider::ProviderSpec`] backing this config. For
364    /// `oidc_issuer`-configured providers, falls through to the OIDC
365    /// discovery cache. Errors propagate so misconfigured providers
366    /// fail loudly at first use rather than silently 404'ing later.
367    fn resolved_spec(&self) -> Result<provider::ResolvedSpec, String> {
368        if let Some(issuer) = self.oidc_issuer.as_deref() {
369            return provider::oidc_cache::resolve(issuer);
370        }
371        provider::find_spec(&self.provider)
372            .map(provider::ResolvedSpec::Static)
373            .ok_or_else(|| format!("unknown OAuth provider: {}", self.provider))
374    }
375
376    /// Build a [`provider::ProviderConfig`] view of `self` for the
377    /// helpers in [`provider`] that take the runtime config.
378    fn provider_cfg(&self) -> provider::ProviderConfig {
379        provider::ProviderConfig {
380            provider: self.provider.clone(),
381            client_id: self.client_id.clone(),
382            client_secret: self.client_secret.clone(),
383            redirect_uri: self.redirect_uri.clone(),
384            scopes_override: self.scopes_override.clone(),
385            tenant: self.tenant.clone(),
386            apple: self.apple.clone(),
387            oidc_issuer: self.oidc_issuer.clone(),
388        }
389    }
390
391    /// Generate the authorization URL for the provider.
392    ///
393    /// Callers MUST append a `&state=<random>` parameter and validate it in the
394    /// callback to prevent CSRF attacks. See `OAuthStateStore` for a minimal
395    /// implementation.
396    ///
397    /// For PKCE-required providers (Twitter/X, Kick), callers should
398    /// prefer [`Self::auth_url_with_pkce`] so the `code_challenge`
399    /// pair survives to the callback.
400    pub fn auth_url(&self) -> String {
401        match self.build_auth_url(None) {
402            Ok(u) => u,
403            Err(_) => String::new(),
404        }
405    }
406
407    /// Generate the authorization URL with a CSRF state parameter attached.
408    pub fn auth_url_with_state(&self, state: &str) -> String {
409        let base = self.auth_url();
410        if base.is_empty() {
411            return base;
412        }
413        format!("{}&state={}", base, url_encode(state))
414    }
415
416    /// Generate the authorization URL with state + a freshly minted
417    /// PKCE pair when the provider requires it. Returns
418    /// `(url, code_verifier)` — the verifier MUST be persisted in
419    /// the OAuth state record and replayed in the token exchange.
420    pub fn auth_url_with_pkce(&self, state: &str) -> Result<(String, Option<String>), String> {
421        let spec = self.resolved_spec()?;
422        let pkce = if spec.requires_pkce() {
423            Some(generate_pkce())
424        } else {
425            None
426        };
427        let challenge = pkce.as_ref().map(|p| p.code_challenge.as_str());
428        let mut url = self.build_auth_url(challenge)?;
429        if !state.is_empty() {
430            url.push_str(&format!("&state={}", url_encode(state)));
431        }
432        Ok((url, pkce.map(|p| p.code_verifier)))
433    }
434
435    fn build_auth_url(&self, pkce_challenge: Option<&str>) -> Result<String, String> {
436        let spec = self.resolved_spec()?;
437        let cfg = self.provider_cfg();
438        let auth = provider::resolve_endpoint(spec.auth_url(), &cfg);
439        if auth.is_empty() {
440            return Err(format!(
441                "provider {} has no authorization endpoint",
442                self.provider
443            ));
444        }
445        let scopes_default = spec.scopes().to_string();
446        let scopes_raw = self.scopes_override.as_deref().unwrap_or(&scopes_default);
447        // Re-join scopes with the provider's separator (TikTok uses
448        // commas, everyone else uses spaces). Splitting on whitespace
449        // first lets developers always specify scopes the human way.
450        let scopes_joined = scopes_raw
451            .split_whitespace()
452            .collect::<Vec<_>>()
453            .join(spec.scope_separator());
454
455        let mut url = format!(
456            "{auth}?{cid_param}={cid}&redirect_uri={ruri}&response_type=code&scope={scope}",
457            cid_param = spec.client_id_param(),
458            cid = url_encode(&self.client_id),
459            ruri = url_encode(&self.redirect_uri),
460            scope = url_encode(&scopes_joined),
461        );
462        if !spec.auth_query_extra().is_empty() {
463            url.push('&');
464            url.push_str(spec.auth_query_extra());
465        }
466        if let Some(challenge) = pkce_challenge {
467            url.push_str("&code_challenge=");
468            url.push_str(challenge);
469            url.push_str("&code_challenge_method=S256");
470        }
471        Ok(url)
472    }
473
474    /// Generate the token exchange URL.
475    pub fn token_url(&self) -> String {
476        match self.resolved_spec() {
477            Ok(spec) => provider::resolve_endpoint(spec.token_url(), &self.provider_cfg()),
478            Err(_) => String::new(),
479        }
480    }
481
482    /// URL for the userinfo endpoint, which returns the authenticated user's profile.
483    pub fn userinfo_url(&self) -> String {
484        match self.resolved_spec() {
485            Ok(spec) => match spec.userinfo_url() {
486                Some(u) => provider::resolve_endpoint(u, &self.provider_cfg()),
487                None => String::new(),
488            },
489            Err(_) => String::new(),
490        }
491    }
492
    /// Exchange an authorization code for the full token set
    /// (`access_token`, optional `refresh_token`, optional `id_token`,
    /// `expires_in`, `scope`).
    ///
    /// This variant sends no PKCE `code_verifier` — it forwards `None`
    /// to [`Self::exchange_code_full_pkce`]. When the provider uses PKCE,
    /// call [`Self::exchange_code_full_pkce`] directly and supply the
    /// verifier previously returned from [`Self::auth_url_with_pkce`].
    pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
        self.exchange_code_full_pkce(code, None)
    }
501
    /// Exchange an authorization code for the full token set, optionally
    /// replaying a PKCE `code_verifier`.
    ///
    /// The request shape (form vs. JSON body, where the client
    /// credentials go) is selected by the provider spec's
    /// `token_exchange()` value. When `code_verifier` is `Some`, it is
    /// added to the request in every shape.
    ///
    /// # Errors
    /// Returns `Err` when the provider spec cannot be resolved, when the
    /// Apple shape is selected but `apple` config is missing or the client
    /// secret cannot be minted, when the HTTP call fails (the error is
    /// passed through `sanitize_token_error`), or when the response cannot
    /// be parsed by `parse_token_response`.
    pub fn exchange_code_full_pkce(
        &self,
        code: &str,
        code_verifier: Option<&str>,
    ) -> Result<TokenSet, String> {
        let spec = self.resolved_spec()?;
        let cfg = self.provider_cfg();
        let token_url = provider::resolve_endpoint(spec.token_url(), &cfg);
        // Pre-rendered `&code_verifier=…` suffix for the form-encoded
        // shapes; empty when no verifier was supplied.
        let pkce_field = code_verifier
            .map(|v| format!("&code_verifier={}", url_encode(v)))
            .unwrap_or_default();

        let out = match spec.token_exchange() {
            // Form body carrying the client id (under the provider's
            // parameter name) and the client secret.
            provider::TokenExchangeShape::Standard => {
                let body = format!(
                    "code={code}&{cid_param}={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
                    code = url_encode(code),
                    cid_param = spec.client_id_param(),
                    cid = url_encode(&self.client_id),
                    secret = url_encode(&self.client_secret),
                    ruri = url_encode(&self.redirect_uri),
                    pkce = pkce_field,
                );
                // NOTE(review): meaning of the trailing `true` flag is defined
                // by `http_post_form`, which is outside this view — confirm.
                http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
            }
            // Same form layout, but the client secret is minted from the
            // `apple` config by `apple_jwt::mint_client_secret` instead of
            // using `self.client_secret`.
            provider::TokenExchangeShape::AppleJwt => {
                let apple = self.apple.as_ref().ok_or(
                    "apple provider requires `apple` config (team_id, key_id, private_key_pem)",
                )?;
                let signed_secret = apple_jwt::mint_client_secret(apple, &self.client_id)?;
                let body = format!(
                    "code={code}&client_id={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
                    code = url_encode(code),
                    cid = url_encode(&self.client_id),
                    secret = url_encode(&signed_secret),
                    ruri = url_encode(&self.redirect_uri),
                    pkce = pkce_field,
                );
                http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
            }
            // Form body without credentials; client id and secret are
            // handed to `http_post_form_basic` separately (presumably sent
            // as an HTTP Basic header — confirm in that helper).
            provider::TokenExchangeShape::BasicAuth => {
                let body = format!(
                    "code={code}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
                    code = url_encode(code),
                    ruri = url_encode(&self.redirect_uri),
                    pkce = pkce_field,
                );
                http_post_form_basic(&token_url, &body, &self.client_id, &self.client_secret)
                    .map_err(sanitize_token_error)?
            }
            // JSON body carrying the client credentials alongside the code.
            provider::TokenExchangeShape::JsonBody => {
                let mut json = serde_json::Map::new();
                json.insert("grant_type".into(), "authorization_code".into());
                json.insert("code".into(), code.into());
                json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
                json.insert("client_id".into(), self.client_id.clone().into());
                json.insert("client_secret".into(), self.client_secret.clone().into());
                if let Some(v) = code_verifier {
                    json.insert("code_verifier".into(), v.to_string().into());
                }
                let body = serde_json::Value::Object(json).to_string();
                http_post_json(&token_url, &body, None).map_err(sanitize_token_error)?
            }
            // JSON body without credentials; the client id/secret pair is
            // passed to `http_post_json` as its optional third argument.
            provider::TokenExchangeShape::BasicAuthJsonBody => {
                let mut json = serde_json::Map::new();
                json.insert("grant_type".into(), "authorization_code".into());
                json.insert("code".into(), code.into());
                json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
                if let Some(v) = code_verifier {
                    json.insert("code_verifier".into(), v.to_string().into());
                }
                let body = serde_json::Value::Object(json).to_string();
                http_post_json(
                    &token_url,
                    &body,
                    Some((&self.client_id, &self.client_secret)),
                )
                .map_err(sanitize_token_error)?
            }
        };
        parse_token_response(&out)
    }
584
585    /// Exchange an authorization code for an access token. Thin wrapper
586    /// around [`OAuthConfig::exchange_code_full`] for callers that only
587    /// need the access token.
588    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
589        Ok(self.exchange_code_full(code)?.access_token)
590    }
591
592    /// Fetch the authenticated user's email + display name using an access token.
593    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
594        let info = self.fetch_userinfo_full(access_token)?;
595        Ok((info.email, info.name))
596    }
597
    /// Fetch the authenticated user's full identity info — email + name +
    /// the provider-stable account ID. Uses the spec's
    /// [`provider::UserinfoParser`] so adding a new provider is a
    /// table change, not a new branch.
    ///
    /// No id_token is forwarded by this method. Providers whose identity
    /// lives in the id_token (Apple) therefore return an error here; for
    /// those, call [`Self::fetch_userinfo_with_id_token`] with the
    /// `id_token` from the token response.
    pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
        // Delegates with `id_token = None`. Apple has no userinfo
        // endpoint and needs the id_token, so the Apple branch of the
        // delegate fails with "apple login requires the id_token from
        // the token response" when reached through this wrapper.
        self.fetch_userinfo_with_id_token(access_token, None)
    }
608
    /// Fetch userinfo, falling back to the supplied id_token JWT when
    /// the provider has no userinfo endpoint (Apple). The id_token
    /// is the one returned by [`Self::exchange_code_full`] in
    /// [`TokenSet::id_token`].
    ///
    /// Dispatch is driven by the provider spec's `userinfo_parser()`:
    /// Apple and Linear are handled by dedicated helpers up front; every
    /// other provider has its userinfo endpoint called with the bearer
    /// token and the JSON response is mapped into a [`UserInfo`].
    ///
    /// # Errors
    /// Returns `Err` when the provider spec cannot be resolved, when an
    /// Apple login is attempted without `id_token`, when the provider has
    /// no userinfo endpoint, when the HTTP call fails or the response is
    /// not valid JSON, or when a required identity field (id / email) is
    /// missing from the response.
    pub fn fetch_userinfo_with_id_token(
        &self,
        access_token: &str,
        id_token: Option<&str>,
    ) -> Result<UserInfo, String> {
        let spec = self.resolved_spec()?;
        let cfg = self.provider_cfg();

        // Apple — identity lives in the id_token, not a userinfo endpoint.
        if matches!(spec.userinfo_parser(), provider::UserinfoParser::AppleIdToken) {
            let token = id_token
                .ok_or("apple login requires the id_token from the token response")?;
            return parse_apple_id_token(token, &self.provider);
        }

        // Linear is GraphQL — the userinfo "GET" is actually a POST
        // with a fixed query.
        if matches!(spec.userinfo_parser(), provider::UserinfoParser::LinearGraphql) {
            return fetch_linear_userinfo(&self.provider, access_token);
        }

        let url = match spec.userinfo_url() {
            Some(u) => provider::resolve_endpoint(u, &cfg),
            None => return Err(format!("provider {} has no userinfo endpoint", self.provider)),
        };
        // The spec decides whether the endpoint is called with GET or POST;
        // the access token is handed to the bearer helper either way.
        let out = match spec.userinfo_method() {
            provider::UserinfoMethod::Get => http_get_bearer(&url, access_token),
            provider::UserinfoMethod::Post => http_post_bearer(&url, access_token),
        }
        .map_err(sanitize_token_error)?;
        let parsed: serde_json::Value =
            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;

        match spec.userinfo_parser() {
            // OIDC-style claims: `email` and `sub` are required strings,
            // `name` is optional.
            provider::UserinfoParser::Oidc => {
                let email = parsed
                    .get("email")
                    .and_then(|v| v.as_str())
                    .ok_or("no email in userinfo")?
                    .to_string();
                let name = parsed
                    .get("name")
                    .and_then(|v| v.as_str())
                    .map(String::from);
                let provider_account_id = parsed
                    .get("sub")
                    .and_then(|v| v.as_str())
                    .ok_or("no sub in userinfo")?
                    .to_string();
                Ok(UserInfo {
                    provider: self.provider.clone(),
                    provider_account_id,
                    email,
                    name,
                })
            }
            provider::UserinfoParser::GitHub => {
                // Display name: `name` if it is a string, else `login`.
                let name = parsed
                    .get("name")
                    .and_then(|v| v.as_str())
                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
                    .map(String::from);
                let email = parsed
                    .get("email")
                    .and_then(|v| v.as_str())
                    .map(String::from);
                // When the profile has no email string, fall back to
                // `fetch_github_primary_email`; its error is discarded and
                // reported as "no accessible email".
                let email = email
                    .or_else(|| fetch_github_primary_email(access_token).ok())
                    .ok_or("no accessible email on GitHub account")?;
                // `id` is accepted as either a JSON integer or a string and
                // normalised to a string; any other JSON type (or an empty
                // string) counts as missing.
                let provider_account_id = parsed
                    .get("id")
                    .map(|v| {
                        v.as_i64()
                            .map(|n| n.to_string())
                            .or_else(|| v.as_str().map(String::from))
                            .unwrap_or_default()
                    })
                    .filter(|s| !s.is_empty())
                    .ok_or("no id in userinfo")?;
                Ok(UserInfo {
                    provider: self.provider.clone(),
                    provider_account_id,
                    email,
                    name,
                })
            }
            // Table-driven providers: the spec supplies the paths passed to
            // `json_pointer_string` for id, email and (optionally) name.
            provider::UserinfoParser::Custom {
                id_path,
                email_path,
                name_path,
            } => {
                let provider_account_id = json_pointer_string(&parsed, id_path)
                    .ok_or_else(|| format!("no id at {id_path} in userinfo"))?;
                let raw_email = json_pointer_string(&parsed, email_path)
                    .ok_or_else(|| format!("no email at {email_path} in userinfo"))?;
                // Twitter/Reddit don't expose real emails — they map a
                // username into the email slot. Tag it so account
                // policies can distinguish "real verified email" from
                // "we made this up." `.invalid` is reserved by RFC 6761.
                let email = if !raw_email.contains('@') {
                    let domain = match self.provider.as_str() {
                        "twitter" => "x.invalid",
                        "reddit" => "reddit.invalid",
                        // Any other provider returning a non-email value is
                        // rejected rather than given a synthesized address.
                        other => return Err(format!(
                            "{other}: userinfo `email` field is not an email address (got {raw_email:?}); refusing to synthesize",
                        )),
                    };
                    format!("{raw_email}@{domain}")
                } else {
                    raw_email
                };
                let name = name_path.and_then(|p| json_pointer_string(&parsed, p));
                Ok(UserInfo {
                    provider: self.provider.clone(),
                    provider_account_id,
                    email,
                    name,
                })
            }
            // Both variants return early at the top of this function.
            provider::UserinfoParser::AppleIdToken => unreachable!("handled above"),
            provider::UserinfoParser::LinearGraphql => unreachable!("handled above"),
        }
    }
736}
737
/// PKCE pair — the verifier stays server-side until token exchange,
/// the (S256-hashed) challenge goes on the auth URL.
struct PkcePair {
    /// Random secret; replayed as `code_verifier` in the token exchange.
    code_verifier: String,
    /// Encoded SHA-256 digest of `code_verifier`; sent as
    /// `code_challenge` on the authorization URL.
    code_challenge: String,
}
744
745/// Generate a PKCE pair: random 43-char verifier + S256 challenge.
746/// RFC 7636 §4.1 permits 43–128 chars from `[A-Za-z0-9-._~]`. 32
747/// random bytes URL-base64-encoded comes out to exactly 43 chars.
748fn generate_pkce() -> PkcePair {
749    use rand::RngCore;
750    let mut bytes = [0u8; 32];
751    rand::thread_rng().fill_bytes(&mut bytes);
752    let code_verifier = apple_jwt::base64_url(bytes);
753    use sha2::{Digest, Sha256};
754    let mut hasher = Sha256::new();
755    hasher.update(code_verifier.as_bytes());
756    let code_challenge = apple_jwt::base64_url(hasher.finalize());
757    PkcePair {
758        code_verifier,
759        code_challenge,
760    }
761}
762
763/// Decode an Apple id_token JWT and pull the identity claims.
764///
765/// **Trust assumption:** the caller MUST have obtained this token
766/// via the back-channel `/auth/token` exchange (mutually authenticated
767/// TLS to `appleid.apple.com`). Under that assumption no third party
768/// can have substituted a forged JWT, so we skip signature
769/// verification.
770///
771/// **DO NOT call this on a JWT supplied by the client** (e.g. a
772/// "post your id_token to me" mobile-SDK flow). For those paths,
773/// implement Apple JWKS verification: fetch
774/// `https://appleid.apple.com/auth/keys`, verify the RS256
775/// signature, then check `iss == "https://appleid.apple.com"`,
776/// `aud == client_id`, and `exp > now`. Pylon doesn't ship that
777/// verifier yet — apps that need it can compose `crate::jwt::verify`
778/// against a JWKS-loaded RSA key.
779///
780/// This function is private (`fn`, not `pub fn`) precisely so it
781/// can't be misused by an external caller. The only call site is
782/// [`OAuthConfig::fetch_userinfo_with_id_token`] which is reached
783/// only via the OAuth callback handler, which only processes
784/// back-channel-exchanged tokens.
785fn parse_apple_id_token(id_token: &str, provider: &str) -> Result<UserInfo, String> {
786    let mut parts = id_token.split('.');
787    let _header = parts.next().ok_or("apple id_token: missing header")?;
788    let claims_b64 = parts.next().ok_or("apple id_token: missing claims")?;
789    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
790    let claims_bytes = URL_SAFE_NO_PAD
791        .decode(claims_b64)
792        .map_err(|e| format!("apple id_token claims not base64: {e}"))?;
793    let claims: serde_json::Value = serde_json::from_slice(&claims_bytes)
794        .map_err(|e| format!("apple id_token claims not JSON: {e}"))?;
795    let provider_account_id = claims
796        .get("sub")
797        .and_then(|v| v.as_str())
798        .ok_or("apple id_token: missing sub")?
799        .to_string();
800    let email = claims
801        .get("email")
802        .and_then(|v| v.as_str())
803        .ok_or("apple id_token: missing email (was the `email` scope requested?)")?
804        .to_string();
805    Ok(UserInfo {
806        provider: provider.to_string(),
807        provider_account_id,
808        email,
809        name: None, // Apple sends `name` as a separate form field on FIRST signup only.
810    })
811}
812
813/// Strip provider error bodies of secrets before they propagate to
814/// logs / `oauth_error_message` redirect URLs.
815///
816/// **Why:** Several token endpoints echo the request body (or pieces
817/// of it) on auth failure. Without this, a misconfigured deployment
818/// can leak `client_secret`, the Apple JWT, or even the auth `code`
819/// into the user's browser history and CDN logs.
820///
821/// Covers both shapes echoed by real providers:
822///   - form / query: `client_secret=sk_…`
823///   - JSON: `"client_secret":"sk_…"` (Notion, Atlassian)
824fn sanitize_token_error(err: String) -> String {
825    const SENSITIVE: &[&str] = &[
826        "client_secret",
827        "code_verifier",
828        "client_assertion",
829        "refresh_token",
830        "access_token",
831        "id_token",
832        // The auth `code` itself is single-use but still sensitive
833        // until the token endpoint consumes it — and many providers
834        // echo it back on a 4xx token-exchange error before the
835        // attacker has had a chance to redeem it.
836        "code",
837    ];
838    let mut out = err;
839    for key in SENSITIVE {
840        out = redact_param_form(&out, key);
841        out = redact_param_json(&out, key);
842    }
843    out
844}
845
/// Rewrite every `key=value` occurrence (form / query-string shape)
/// in `input` as `key=***`.
///
/// A value runs until the first `&`, newline, `"`, space or `'` (or
/// the end of input); the terminator itself is kept. Matching is a
/// plain substring match on `key=`. The walk advances one whole
/// `char` at a time, so multibyte text around a sensitive key cannot
/// cause a slicing panic.
fn redact_param_form(input: &str, key: &str) -> String {
    let needle = format!("{key}=");
    let mut redacted = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(first) = rest.chars().next() {
        match rest.strip_prefix(needle.as_str()) {
            Some(after_key) => {
                redacted.push_str(&needle);
                redacted.push_str("***");
                // Drop the value: everything up to (not including) the
                // first terminator, or the whole remainder if none.
                let value_len = after_key
                    .find(|c: char| matches!(c, '&' | '\n' | '"' | ' ' | '\''))
                    .unwrap_or(after_key.len());
                rest = &after_key[value_len..];
            }
            None => {
                // No key here — copy one full char to stay UTF-8 aligned.
                redacted.push(first);
                rest = &rest[first.len_utf8()..];
            }
        }
    }
    redacted
}
876
/// Rewrite every string-valued JSON member `"key":"…"` in `input` as
/// `"key":"***"`.
///
/// Matching is case-sensitive and tolerates whitespace on either side
/// of the `:`. Occurrences of `"key"` that are not followed by a
/// string value (array elements, numeric values, …) are copied
/// through unchanged. Backslash escapes inside the value are honored
/// when locating the closing quote; if the closing quote is missing
/// the rest of the input is dropped after `***`.
fn redact_param_json(input: &str, key: &str) -> String {
    /// Advance `pos` past any run of whitespace in `text`.
    fn skip_whitespace(text: &str, mut pos: usize) -> usize {
        for ch in text[pos..].chars() {
            if !ch.is_whitespace() {
                break;
            }
            pos += ch.len_utf8();
        }
        pos
    }

    let quoted_key = format!("\"{key}\"");
    let mut redacted = String::with_capacity(input.len());
    let mut pos = 0;

    while pos < input.len() {
        let rest = &input[pos..];
        if !rest.starts_with(&quoted_key) {
            // Ordinary text: copy one whole char to stay UTF-8 aligned.
            let ch = rest.chars().next().expect("non-empty");
            redacted.push(ch);
            pos += ch.len_utf8();
            continue;
        }

        // `"key"` found. Expect `:` next (whitespace allowed before it);
        // otherwise this is not a member name — copy what was scanned.
        let colon_at = skip_whitespace(input, pos + quoted_key.len());
        if !input[colon_at..].starts_with(':') {
            redacted.push_str(&input[pos..colon_at]);
            pos = colon_at;
            continue;
        }

        // After the colon the value must open with `"`; non-string
        // values are left as they are.
        let quote_at = skip_whitespace(input, colon_at + 1);
        if !input[quote_at..].starts_with('"') {
            redacted.push_str(&input[pos..quote_at]);
            pos = quote_at;
            continue;
        }

        // Keep `"key":"` verbatim, then substitute the value.
        let value_start = quote_at + 1;
        redacted.push_str(&input[pos..value_start]);
        redacted.push_str("***");

        // Locate the first quote not preceded by an unescaped backslash.
        let mut escaped = false;
        let mut closing = None;
        for (offset, ch) in input[value_start..].char_indices() {
            if ch == '"' && !escaped {
                closing = Some(value_start + offset);
                break;
            }
            escaped = ch == '\\' && !escaped;
        }

        match closing {
            Some(end) => {
                redacted.push('"');
                pos = end + 1;
            }
            // Malformed JSON: redact through end of input to be safe.
            None => pos = input.len(),
        }
    }
    redacted
}
951
952/// Linear's userinfo lives behind a GraphQL endpoint — the bearer
953/// token is the same OAuth access token, but the request is a POST
954/// with a fixed query. Kept as a separate fn so the main fetcher
955/// stays uniform across the other parsers.
956fn fetch_linear_userinfo(provider: &str, access_token: &str) -> Result<UserInfo, String> {
957    let body = r#"{"query":"query { viewer { id email name } }"}"#;
958    let agent = ureq_agent();
959    let resp = agent
960        .post("https://api.linear.app/graphql")
961        .set("Authorization", &format!("Bearer {access_token}"))
962        .set("Content-Type", "application/json")
963        .set("Accept", "application/json")
964        .send_string(body)
965        .map_err(|e| format!("linear graphql: {e}"))?;
966    let out = resp.into_string().map_err(|e| format!("read body: {e}"))?;
967    let parsed: serde_json::Value = serde_json::from_str(&out)
968        .map_err(|e| format!("linear graphql not JSON: {e}"))?;
969    let viewer = parsed
970        .pointer("/data/viewer")
971        .ok_or("linear graphql: no /data/viewer")?;
972    let provider_account_id = viewer
973        .get("id")
974        .and_then(|v| v.as_str())
975        .ok_or("linear graphql: no id")?
976        .to_string();
977    let email = viewer
978        .get("email")
979        .and_then(|v| v.as_str())
980        .ok_or("linear graphql: no email")?
981        .to_string();
982    let name = viewer.get("name").and_then(|v| v.as_str()).map(String::from);
983    Ok(UserInfo {
984        provider: provider.to_string(),
985        provider_account_id,
986        email,
987        name,
988    })
989}
990
991/// JSON-pointer (RFC 6901) string extraction. Returns `None` for
992/// missing paths or non-string values. Numeric ids (Discord's `id`,
993/// Roblox's `sub`) are coerced to strings.
994fn json_pointer_string(v: &serde_json::Value, path: &str) -> Option<String> {
995    let node = v.pointer(path)?;
996    if let Some(s) = node.as_str() {
997        return Some(s.to_string());
998    }
999    if let Some(n) = node.as_i64() {
1000        return Some(n.to_string());
1001    }
1002    if let Some(n) = node.as_u64() {
1003        return Some(n.to_string());
1004    }
1005    None
1006}
1007
/// Resolved identity returned by [`OAuthConfig::fetch_userinfo_full`].
/// `provider_account_id` is the provider-stable subject id (Google `sub`,
/// GitHub numeric `id`) — what the account store keys on so a renamed
/// email doesn't orphan the pylon account.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Id of the OAuth provider that produced this identity.
    pub provider: String,
    /// Provider-stable subject id; numeric ids are carried as decimal text.
    pub provider_account_id: String,
    /// Email reported by the provider. For providers that only expose a
    /// username (Twitter, Reddit) the generic parser synthesizes
    /// `<username>@x.invalid` / `<username>@reddit.invalid` instead.
    pub email: String,
    /// Display name when the provider supplied one; the Apple id_token
    /// parser always leaves this `None`.
    pub name: Option<String>,
}
1019
/// Token bundle returned by [`OAuthConfig::exchange_code_full`]. Stored
/// on the matching `Account` row so `refresh_token` is available for
/// silent re-auth and `expires_at` is checked before each provider call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSet {
    /// The access token; the only field the token response must contain.
    pub access_token: String,
    /// Refresh token, when the response carried one as a string.
    pub refresh_token: Option<String>,
    /// OIDC id_token, when the response carried one as a string.
    pub id_token: Option<String>,
    /// Unix epoch seconds at which the access token expires. `None` when
    /// the provider didn't return `expires_in` (GitHub's classic OAuth
    /// app tokens are non-expiring).
    pub expires_at: Option<u64>,
    /// Granted scope string exactly as the provider reported it, if any.
    pub scope: Option<String>,
}
1034
1035fn parse_token_response(body: &str) -> Result<TokenSet, String> {
1036    // Most providers return JSON; GitHub Classic apps return form-urlencoded
1037    // unless you ask with Accept: application/json (which we do).
1038    let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
1039        // Fall back to form-urlencoded: access_token=...&scope=...&token_type=...
1040        let mut map = serde_json::Map::new();
1041        for pair in body.split('&') {
1042            if let Some((k, v)) = pair.split_once('=') {
1043                map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
1044            }
1045        }
1046        serde_json::Value::Object(map)
1047    });
1048
1049    let access_token = json
1050        .get("access_token")
1051        .and_then(|v| v.as_str())
1052        .ok_or_else(|| format!("no access_token in token response: {body}"))?
1053        .to_string();
1054    let refresh_token = json
1055        .get("refresh_token")
1056        .and_then(|v| v.as_str())
1057        .map(String::from);
1058    let id_token = json
1059        .get("id_token")
1060        .and_then(|v| v.as_str())
1061        .map(String::from);
1062    let expires_at = json
1063        .get("expires_in")
1064        .and_then(|v| {
1065            v.as_u64()
1066                .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
1067        })
1068        .map(|secs| now_secs().saturating_add(secs));
1069    let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
1070    Ok(TokenSet {
1071        access_token,
1072        refresh_token,
1073        id_token,
1074        expires_at,
1075        scope,
1076    })
1077}
1078
/// Percent-encode `s` for use in a URL component.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through; every
/// other byte of the UTF-8 encoding becomes `%XX` with uppercase hex.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for &byte in s.as_bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
1091
/// Timeout for OAuth / userinfo HTTP calls. Short enough that a hung
/// provider doesn't block a login indefinitely; long enough to absorb
/// typical internet latency.
///
/// `ureq_agent` applies this same 10-second value separately to the
/// connect, read and write phases.
const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
1096
1097fn ureq_agent() -> ureq::Agent {
1098    ureq::AgentBuilder::new()
1099        .timeout_connect(HTTP_TIMEOUT)
1100        .timeout_read(HTTP_TIMEOUT)
1101        .timeout_write(HTTP_TIMEOUT)
1102        .user_agent("pylon/0.1")
1103        .build()
1104}
1105
1106fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
1107    let agent = ureq_agent();
1108    let mut req = agent
1109        .post(url)
1110        .set("Content-Type", "application/x-www-form-urlencoded");
1111    if accept_json {
1112        req = req.set("Accept", "application/json");
1113    }
1114    match req.send_string(body) {
1115        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1116        Err(ureq::Error::Status(code, resp)) => {
1117            let body = resp.into_string().unwrap_or_default();
1118            Err(format!("HTTP {code}: {body}"))
1119        }
1120        Err(e) => Err(format!("HTTP error: {e}")),
1121    }
1122}
1123
1124/// POST a form body using HTTP Basic auth for the client credentials.
1125/// Used by Spotify, Reddit, Figma, Zoom, PayPal — providers that
1126/// mandate Basic auth on the token endpoint.
1127fn http_post_form_basic(
1128    url: &str,
1129    body: &str,
1130    client_id: &str,
1131    client_secret: &str,
1132) -> Result<String, String> {
1133    use base64::{engine::general_purpose::STANDARD, Engine};
1134    let creds = format!("{client_id}:{client_secret}");
1135    let basic = STANDARD.encode(creds.as_bytes());
1136    let agent = ureq_agent();
1137    match agent
1138        .post(url)
1139        .set("Content-Type", "application/x-www-form-urlencoded")
1140        .set("Accept", "application/json")
1141        .set("Authorization", &format!("Basic {basic}"))
1142        .send_string(body)
1143    {
1144        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1145        Err(ureq::Error::Status(code, resp)) => {
1146            let body = resp.into_string().unwrap_or_default();
1147            Err(format!("HTTP {code}: {body}"))
1148        }
1149        Err(e) => Err(format!("HTTP error: {e}")),
1150    }
1151}
1152
1153/// POST a JSON body, optionally with HTTP Basic auth. Used by
1154/// Notion (Basic + JSON) and Atlassian (JSON only) — both reject
1155/// form-encoded bodies on their token endpoints.
1156fn http_post_json(
1157    url: &str,
1158    body: &str,
1159    basic_creds: Option<(&str, &str)>,
1160) -> Result<String, String> {
1161    let agent = ureq_agent();
1162    let mut req = agent
1163        .post(url)
1164        .set("Content-Type", "application/json")
1165        .set("Accept", "application/json");
1166    if let Some((id, secret)) = basic_creds {
1167        use base64::{engine::general_purpose::STANDARD, Engine};
1168        let creds = STANDARD.encode(format!("{id}:{secret}").as_bytes());
1169        req = req.set("Authorization", &format!("Basic {creds}"));
1170    }
1171    // Notion requires the API version header on every call, even the
1172    // token exchange. Using a recent stable version.
1173    req = req.set("Notion-Version", "2022-06-28");
1174    match req.send_string(body) {
1175        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1176        Err(ureq::Error::Status(code, resp)) => {
1177            let body = resp.into_string().unwrap_or_default();
1178            Err(format!("HTTP {code}: {body}"))
1179        }
1180        Err(e) => Err(format!("HTTP error: {e}")),
1181    }
1182}
1183
1184/// POST with empty body + bearer auth. Used for Dropbox userinfo
1185/// (an RPC-style endpoint that requires POST instead of GET).
1186fn http_post_bearer(url: &str, token: &str) -> Result<String, String> {
1187    let agent = ureq_agent();
1188    match agent
1189        .post(url)
1190        .set("Authorization", &format!("Bearer {token}"))
1191        .set("Accept", "application/json")
1192        .call()
1193    {
1194        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1195        Err(ureq::Error::Status(code, resp)) => {
1196            let body = resp.into_string().unwrap_or_default();
1197            Err(format!("HTTP {code}: {body}"))
1198        }
1199        Err(e) => Err(format!("HTTP error: {e}")),
1200    }
1201}
1202
1203fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
1204    let agent = ureq_agent();
1205    match agent
1206        .get(url)
1207        .set("Authorization", &format!("Bearer {token}"))
1208        .set("Accept", "application/json")
1209        .call()
1210    {
1211        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1212        Err(ureq::Error::Status(code, resp)) => {
1213            let body = resp.into_string().unwrap_or_default();
1214            Err(format!("HTTP {code}: {body}"))
1215        }
1216        Err(e) => Err(format!("HTTP error: {e}")),
1217    }
1218}
1219
1220fn fetch_github_primary_email(token: &str) -> Result<String, String> {
1221    let out = http_get_bearer("https://api.github.com/user/emails", token)?;
1222    let emails: serde_json::Value =
1223        serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
1224    emails
1225        .as_array()
1226        .and_then(|arr| {
1227            arr.iter()
1228                .find(|e| {
1229                    e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
1230                        && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
1231                })
1232                .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
1233        })
1234        .ok_or_else(|| "no primary verified email on GitHub".into())
1235}
1236
1237/// OAuth provider registry.
1238pub struct OAuthRegistry {
1239    providers: std::collections::HashMap<String, OAuthConfig>,
1240}
1241
1242impl Default for OAuthRegistry {
1243    fn default() -> Self {
1244        Self::new()
1245    }
1246}
1247
1248impl OAuthRegistry {
1249    pub fn new() -> Self {
1250        Self {
1251            providers: std::collections::HashMap::new(),
1252        }
1253    }
1254
1255    pub fn register(&mut self, config: OAuthConfig) {
1256        self.providers.insert(config.provider.clone(), config);
1257    }
1258
1259    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
1260        self.providers.get(provider)
1261    }
1262
1263    /// Build from environment variables.
1264    ///
1265    /// For each builtin provider (and any `oidc_issuer`-configured
1266    /// IdP), looks for `PYLON_OAUTH_<PROVIDER>_CLIENT_ID` /
1267    /// `_CLIENT_SECRET` / `_REDIRECT`. Apple additionally requires
1268    /// `_TEAM_ID`, `_KEY_ID`, `_PRIVATE_KEY` (PEM contents or path).
1269    /// Microsoft accepts an optional `_TENANT`.
1270    ///
1271    /// Generic OIDC: any env var matching
1272    /// `PYLON_OAUTH_<NAME>_OIDC_ISSUER` registers a provider with id
1273    /// `<name>` (lowercased) using the discovered endpoints. Useful
1274    /// for Auth0, Okta, Keycloak, Cognito, Logto, Authentik, etc.
1275    pub fn from_env() -> Self {
1276        let mut reg = Self::new();
1277
1278        for spec in provider::builtin::all() {
1279            let upper = spec.id.to_ascii_uppercase();
1280            let prefix = format!("PYLON_OAUTH_{upper}");
1281            let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1282                Ok(v) => v,
1283                Err(_) => continue,
1284            };
1285            let secret = match std::env::var(format!("{prefix}_CLIENT_SECRET")) {
1286                Ok(v) => v,
1287                // Apple's "client_secret" is synthesized — allow blank.
1288                Err(_) if spec.id == "apple" => String::new(),
1289                Err(_) => continue,
1290            };
1291            let redirect_uri = std::env::var(format!("{prefix}_REDIRECT")).unwrap_or_else(|_| {
1292                format!("http://localhost:3000/api/auth/callback/{}", spec.id)
1293            });
1294            let scopes_override = std::env::var(format!("{prefix}_SCOPES")).ok();
1295            let tenant = std::env::var(format!("{prefix}_TENANT")).ok();
1296
1297            let apple = if spec.id == "apple" {
1298                match (
1299                    std::env::var(format!("{prefix}_TEAM_ID")),
1300                    std::env::var(format!("{prefix}_KEY_ID")),
1301                    std::env::var(format!("{prefix}_PRIVATE_KEY")),
1302                ) {
1303                    (Ok(team_id), Ok(key_id), Ok(private_key_pem)) => Some(provider::AppleConfig {
1304                        team_id,
1305                        key_id,
1306                        private_key_pem,
1307                    }),
1308                    _ => continue, // Apple requires the JWT material to function.
1309                }
1310            } else {
1311                None
1312            };
1313
1314            reg.register(OAuthConfig {
1315                provider: spec.id.to_string(),
1316                client_id: id,
1317                client_secret: secret,
1318                redirect_uri,
1319                scopes_override,
1320                tenant,
1321                apple,
1322                oidc_issuer: None,
1323            });
1324        }
1325
1326        // Generic OIDC providers — scan PYLON_OAUTH_<NAME>_OIDC_ISSUER.
1327        for (key, issuer) in std::env::vars() {
1328            let Some(rest) = key.strip_prefix("PYLON_OAUTH_") else {
1329                continue;
1330            };
1331            let Some(name_upper) = rest.strip_suffix("_OIDC_ISSUER") else {
1332                continue;
1333            };
1334            let name = name_upper.to_ascii_lowercase();
1335            if provider::find_spec(&name).is_some() {
1336                continue; // already handled as a builtin
1337            }
1338            let prefix = format!("PYLON_OAUTH_{name_upper}");
1339            let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1340                Ok(v) => v,
1341                Err(_) => continue,
1342            };
1343            let secret = std::env::var(format!("{prefix}_CLIENT_SECRET")).unwrap_or_default();
1344            let redirect_uri = std::env::var(format!("{prefix}_REDIRECT"))
1345                .unwrap_or_else(|_| format!("http://localhost:3000/api/auth/callback/{name}"));
1346            reg.register(OAuthConfig {
1347                provider: name,
1348                client_id: id,
1349                client_secret: secret,
1350                redirect_uri,
1351                scopes_override: std::env::var(format!("{prefix}_SCOPES")).ok(),
1352                tenant: None,
1353                apple: None,
1354                oidc_issuer: Some(issuer),
1355            });
1356        }
1357
1358        reg
1359    }
1360
1361    /// Iterate over registered provider ids — used by routes/auth.rs
1362    /// to expose `/api/auth/providers` and to validate
1363    /// `/api/auth/login/<id>` paths against the configured set.
1364    pub fn ids(&self) -> impl Iterator<Item = &str> {
1365        self.providers.keys().map(|s| s.as_str())
1366    }
1367
1368    /// Process-wide cached registry. Built once on first use from
1369    /// `from_env`; subsequent calls are zero-cost. Routes use this
1370    /// to avoid the ~150 syscalls `from_env` does per call.
1371    ///
1372    /// **Trade-off:** env changes after server start aren't picked up
1373    /// without a restart — same as every other Pylon env-var path.
1374    pub fn shared() -> &'static OAuthRegistry {
1375        static CELL: std::sync::OnceLock<OAuthRegistry> = std::sync::OnceLock::new();
1376        CELL.get_or_init(Self::from_env)
1377    }
1378}
1379
1380// ---------------------------------------------------------------------------
1381// OAuth state store — CSRF protection for OAuth flows
1382// ---------------------------------------------------------------------------
1383
/// One stored OAuth state record. Carries the post-callback redirect
/// URLs alongside the provider so the callback handler doesn't need to
/// consult an env var to know where to send the user. Both URLs are
/// validated against `PYLON_TRUSTED_ORIGINS` at create time, so the
/// callback can trust them without re-checking.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
    /// Provider id the flow was started for. `OAuthStateStore::validate`
    /// rejects the record when the callback's provider differs.
    pub provider: String,
    /// URL the callback redirects to on success. The frontend supplies
    /// this via `?callback=` on the start request.
    pub callback_url: String,
    /// URL the callback redirects to on failure. Defaults to
    /// `callback_url` when the frontend doesn't pass an explicit
    /// `?error_callback=`. The error code + message ride along as
    /// query params (`?oauth_error=X&oauth_error_message=Y`).
    pub error_callback_url: String,
    /// PKCE code_verifier when the provider requires PKCE. Set by the
    /// `/api/auth/login/<provider>` start route via
    /// [`OAuthConfig::auth_url_with_pkce`]; replayed on token exchange
    /// in the callback. `None` for non-PKCE providers.
    pub pkce_verifier: Option<String>,
    /// Unix epoch seconds at which the record stops being valid.
    /// `OAuthStateStore::create_with_pkce` sets it to creation time
    /// plus 600 seconds.
    pub expires_at: u64,
}
1407
/// Backing store for OAuth state records. Default impl keeps them in
/// memory (fine for tests + dev); the runtime swaps in a SQLite or
/// Postgres backend so a restart in the middle of an OAuth handshake
/// doesn't leave the user with "invalid state" on the callback.
///
/// Implementations must be `Send + Sync`; both methods take `&self`.
pub trait OAuthStateBackend: Send + Sync {
    /// Persist a state record under `token`.
    fn put(&self, token: &str, state: &OAuthState);
    /// Atomic compare-and-consume: returns the stored record if the
    /// token exists and hasn't expired, then removes it. Returning
    /// `None` means either the token never existed or it has already
    /// been used / expired.
    ///
    /// `now_unix_secs` is the caller's current time in Unix epoch
    /// seconds, compared against the record's `expires_at`.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
}
1421
1422/// In-memory backend (default). Lost on restart.
1423pub struct InMemoryOAuthBackend {
1424    states: Mutex<HashMap<String, OAuthState>>,
1425}
1426
1427impl InMemoryOAuthBackend {
1428    pub fn new() -> Self {
1429        Self {
1430            states: Mutex::new(HashMap::new()),
1431        }
1432    }
1433}
1434
1435impl Default for InMemoryOAuthBackend {
1436    fn default() -> Self {
1437        Self::new()
1438    }
1439}
1440
1441impl OAuthStateBackend for InMemoryOAuthBackend {
1442    fn put(&self, token: &str, state: &OAuthState) {
1443        self.states
1444            .lock()
1445            .unwrap()
1446            .insert(token.to_string(), state.clone());
1447    }
1448    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
1449        let mut s = self.states.lock().unwrap();
1450        let entry = s.remove(token)?;
1451        if entry.expires_at <= now_unix_secs {
1452            return None;
1453        }
1454        Some(entry)
1455    }
1456}
1457
1458/// Stores OAuth state parameters to prevent CSRF attacks on the callback.
1459///
1460/// State tokens are short-lived (10 minutes) and single-use. Backed by an
1461/// `OAuthStateBackend`; defaults to in-memory but the runtime persists them
1462/// to SQLite (or Postgres when `DATABASE_URL` is set) so they survive a
1463/// restart that happens mid-OAuth-handshake.
1464pub struct OAuthStateStore {
1465    backend: Box<dyn OAuthStateBackend>,
1466}
1467
1468impl Default for OAuthStateStore {
1469    fn default() -> Self {
1470        Self::new()
1471    }
1472}
1473
1474impl OAuthStateStore {
1475    pub fn new() -> Self {
1476        Self {
1477            backend: Box::new(InMemoryOAuthBackend::new()),
1478        }
1479    }
1480
1481    pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
1482        Self { backend }
1483    }
1484
1485    /// Generate and store a new state record. Returns the random
1486    /// state token (the value the OAuth provider echoes back as
1487    /// `?state=…` on the callback).
1488    ///
1489    /// Caller is responsible for validating `callback_url` and
1490    /// `error_callback_url` against the trusted-origins allowlist
1491    /// BEFORE calling this — the store trusts what it's given.
1492    pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
1493        self.create_with_pkce(provider, callback_url, error_callback_url, None)
1494    }
1495
1496    /// Same as [`Self::create`] but accepts a PKCE verifier to stash
1497    /// alongside the state record. The callback handler reads it back
1498    /// out and replays it in the token exchange.
1499    pub fn create_with_pkce(
1500        &self,
1501        provider: &str,
1502        callback_url: &str,
1503        error_callback_url: &str,
1504        pkce_verifier: Option<String>,
1505    ) -> String {
1506        use std::time::{SystemTime, UNIX_EPOCH};
1507        let token = generate_token();
1508        let now = SystemTime::now()
1509            .duration_since(UNIX_EPOCH)
1510            .unwrap_or_default()
1511            .as_secs();
1512        let state = OAuthState {
1513            provider: provider.to_string(),
1514            callback_url: callback_url.to_string(),
1515            error_callback_url: error_callback_url.to_string(),
1516            pkce_verifier,
1517            expires_at: now + 600,
1518        };
1519        self.backend.put(&token, &state);
1520        token
1521    }
1522
1523    /// Validate and consume a state token. Returns the stored record
1524    /// iff the token existed, has not expired, AND matches
1525    /// `expected_provider`. The token is removed either way to make
1526    /// replay impossible.
1527    pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
1528        use std::time::{SystemTime, UNIX_EPOCH};
1529        let now = SystemTime::now()
1530            .duration_since(UNIX_EPOCH)
1531            .unwrap_or_default()
1532            .as_secs();
1533        let entry = self.backend.take(state, now)?;
1534        if entry.provider != expected_provider {
1535            return None;
1536        }
1537        Some(entry)
1538    }
1539}
1540
1541/// Validate that `url` has an origin (scheme://host[:port]) listed in
1542/// `trusted_origins`. Returns `Ok(url)` when trusted (echoes input for
1543/// chaining), `Err` with a code/message when not. Used by the OAuth
1544/// start endpoint to gate `?callback=` + `?error_callback=` values
1545/// before storing them in the state record.
1546///
1547/// `trusted_origins` entries are origin strings like
1548/// `"https://app.example.com"` or `"http://localhost:3000"` — no
1549/// trailing slash, no path. A `url` like
1550/// `"http://localhost:3000/dashboard?x=1"` matches the
1551/// `"http://localhost:3000"` entry.
1552///
1553/// Borrowed wholesale from better-auth's `trustedOrigins` model:
1554/// explicit allowlist, no implicit "same-origin trust," no env-var
1555/// magic. An open-redirect via OAuth is one of the easier auth bugs
1556/// to ship by accident.
1557pub fn validate_trusted_redirect(
1558    url: &str,
1559    trusted_origins: &[String],
1560) -> Result<(), TrustedOriginError> {
1561    if url.is_empty() {
1562        return Err(TrustedOriginError::Empty);
1563    }
1564    // Must be absolute http(s) URL — no relative paths, no schemes
1565    // like javascript:, file:, data:.
1566    if !url.starts_with("http://") && !url.starts_with("https://") {
1567        return Err(TrustedOriginError::NotHttp);
1568    }
1569    let url_origin = origin_of(url);
1570    if trusted_origins.iter().any(|t| t == &url_origin) {
1571        Ok(())
1572    } else {
1573        Err(TrustedOriginError::NotTrusted { origin: url_origin })
1574    }
1575}
1576
/// Reasons a redirect URL might be rejected by [`validate_trusted_redirect`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The URL string was empty.
    Empty,
    /// The URL does not start with `http://` or `https://`.
    NotHttp,
    /// The URL's origin is not in the trusted-origins list. `origin` is
    /// the origin that was extracted from the rejected URL.
    NotTrusted { origin: String },
}
1584
1585impl std::fmt::Display for TrustedOriginError {
1586    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1587        match self {
1588            TrustedOriginError::Empty => write!(f, "redirect URL is empty"),
1589            TrustedOriginError::NotHttp => {
1590                write!(f, "redirect URL must use http:// or https:// scheme")
1591            }
1592            TrustedOriginError::NotTrusted { origin } => write!(
1593                f,
1594                "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS"
1595            ),
1596        }
1597    }
1598}
1599
/// Extract the origin (`scheme://host[:port]`) from a URL string by
/// dropping everything from the first `/`, `?`, or `#` that follows the
/// `://` separator.
///
/// This is best-effort string slicing, not a full URL parser. When the
/// input contains no `://` at all, it is returned with any trailing `/`
/// characters trimmed. Public so router crates can reuse the same logic
/// when comparing redirect URLs against the trusted-origins list.
pub fn origin_of(url: &str) -> String {
    if let Some((scheme, remainder)) = url.split_once("://") {
        // `split` always yields at least one (possibly empty) piece, so
        // the fallback is never taken; it only avoids an unwrap.
        let authority = remainder
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(remainder);
        format!("{scheme}://{authority}")
    } else {
        url.trim_end_matches('/').to_string()
    }
}
1615
1616// ---------------------------------------------------------------------------
1617// Magic code auth — email verification codes
1618// ---------------------------------------------------------------------------
1619
/// Pluggable storage for magic-code records. In-memory is the default
/// (fine for dev); persistent backends (SQLite, Postgres) live in
/// `pylon-runtime` so a server restart between "send code" and "verify
/// code" doesn't invalidate the user's pending login.
///
/// All methods are infallible from the caller's perspective — durability
/// is best-effort. A backend that fails to write should log; the
/// in-memory cache remains authoritative for the current process.
pub trait MagicCodeBackend: Send + Sync {
    /// Replace any existing code for `email` with `code`.
    fn put(&self, email: &str, code: &MagicCode);
    /// Look up the current code for `email`. Returns `None` if absent.
    fn get(&self, email: &str) -> Option<MagicCode>;
    /// Remove the code for `email` (called on successful verify or
    /// expiry). Idempotent — missing key is not an error.
    fn remove(&self, email: &str);
    /// Increment the stored `attempts` counter for `email` without
    /// touching other fields. Called on the verify-failed path so the
    /// `MAX_ATTEMPTS` limit is enforced across restarts.
    fn bump_attempts(&self, email: &str);
    /// Return the stored records. `MagicCodeStore::with_backend` calls
    /// this on construction to hydrate its in-memory cache and skips
    /// any record that has already expired, so a backend does not have
    /// to filter by expiry itself.
    fn load_all(&self) -> Vec<MagicCode>;
}
1644
/// In-memory backend for magic codes. This is the backend
/// `MagicCodeStore::new` installs by default; nothing is persisted, so
/// records are lost when the process exits.
pub struct InMemoryMagicCodeBackend {
    /// Stored records, keyed by the email passed to `put`.
    codes: Mutex<HashMap<String, MagicCode>>,
}
1650
1651impl InMemoryMagicCodeBackend {
1652    pub fn new() -> Self {
1653        Self {
1654            codes: Mutex::new(HashMap::new()),
1655        }
1656    }
1657}
1658
1659impl Default for InMemoryMagicCodeBackend {
1660    fn default() -> Self {
1661        Self::new()
1662    }
1663}
1664
1665impl MagicCodeBackend for InMemoryMagicCodeBackend {
1666    fn put(&self, email: &str, code: &MagicCode) {
1667        self.codes
1668            .lock()
1669            .unwrap()
1670            .insert(email.to_string(), code.clone());
1671    }
1672    fn get(&self, email: &str) -> Option<MagicCode> {
1673        self.codes.lock().unwrap().get(email).cloned()
1674    }
1675    fn remove(&self, email: &str) {
1676        self.codes.lock().unwrap().remove(email);
1677    }
1678    fn bump_attempts(&self, email: &str) {
1679        if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
1680            c.attempts = c.attempts.saturating_add(1);
1681        }
1682    }
1683    fn load_all(&self) -> Vec<MagicCode> {
1684        self.codes.lock().unwrap().values().cloned().collect()
1685    }
1686}
1687
/// A magic-code store. Wraps a `MagicCodeBackend` (in-memory by default)
/// and applies the verify/cooldown semantics. Hydrates the in-memory
/// cache from the backend on construction so durable backends survive
/// restart without losing in-flight codes.
pub struct MagicCodeStore {
    /// In-process copy of the pending codes, keyed by email. Create and
    /// verify decisions are made against this map.
    cache: Mutex<HashMap<String, MagicCode>>,
    /// Storage that mirrors cache mutations (`put`, `remove`,
    /// `bump_attempts`) and supplies the initial records via `load_all`.
    backend: Box<dyn MagicCodeBackend>,
}
1696
/// One pending magic-code record for a single email address.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Email the code was issued for; also the key it is stored under.
    pub email: String,
    /// The code the user must present to verify.
    pub code: String,
    /// Timestamp (seconds, same clock as `now_secs()`) at or after which
    /// the code is treated as expired.
    pub expires_at: u64,
    /// Failed verify attempts against this code. Once it reaches
    /// `MAX_ATTEMPTS` the code is invalidated.
    pub attempts: u32,
}
1706
/// Maximum failed verify attempts per code before it's burned. 5 is a
/// common bound — lets the user fix typos without enabling realistic
/// brute-force against a 6-digit code space. Once a code's `attempts`
/// reaches this value, verification returns
/// `MagicCodeError::TooManyAttempts` even for the correct code.
const MAX_ATTEMPTS: u32 = 5;
1711
/// Minimum seconds between successive `create()` calls for the same email.
/// Throttles magic-code spam (user can't be flooded with login codes).
/// A request made inside this window, while the previous code is still
/// unexpired, is rejected with `MagicCodeError::Throttled`.
const CREATE_COOLDOWN_SECS: u64 = 60;
1715
/// Reasons a magic code could not be issued or verified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// There is no code on record for this email.
    NotFound,
    /// The code is present but `MAX_ATTEMPTS` failed verifies have occurred.
    TooManyAttempts,
    /// The code did not match.
    BadCode,
    /// A code was on record but its `expires_at` has passed.
    Expired,
    /// Another code was requested too recently. Wait and try again.
    /// `retry_after_secs` is the remaining part of the
    /// `CREATE_COOLDOWN_SECS` window.
    Throttled { retry_after_secs: u64 },
}
1729
1730impl Default for MagicCodeStore {
1731    fn default() -> Self {
1732        Self::new()
1733    }
1734}
1735
1736impl MagicCodeStore {
1737    pub fn new() -> Self {
1738        Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1739    }
1740
1741    /// Build a magic-code store backed by a persistent backend. Existing
1742    /// live codes are hydrated into the in-memory cache on construction
1743    /// so a server restart between "send" and "verify" doesn't kill the
1744    /// user's pending login.
1745    pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1746        let now = now_secs();
1747        let mut cache = HashMap::new();
1748        for c in backend.load_all() {
1749            if c.expires_at > now {
1750                cache.insert(c.email.clone(), c);
1751            }
1752        }
1753        Self {
1754            cache: Mutex::new(cache),
1755            backend,
1756        }
1757    }
1758
1759    /// Generate a 6-digit code for an email and return it. Subject to a
1760    /// per-email cooldown — returns the error-shape via `try_create`.
1761    pub fn create(&self, email: &str) -> String {
1762        // Back-compat wrapper: same signature as before, but we still burn
1763        // the cooldown if one is active. Use `try_create` for a Result shape.
1764        self.try_create(email).unwrap_or_else(|_| String::new())
1765    }
1766
1767    /// Create a magic code, enforcing per-email cooldown. Returns the code
1768    /// or an error describing why one couldn't be issued.
1769    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1770        let now = now_secs();
1771
1772        let mut codes = self.cache.lock().unwrap();
1773
1774        // Cooldown check: if a live code exists and was created less than
1775        // CREATE_COOLDOWN_SECS ago, throttle. The age-of-code is
1776        // `expires_at - 600 + cooldown` since expires_at is create_time + 600.
1777        if let Some(existing) = codes.get(email) {
1778            if existing.expires_at > now {
1779                let created_at = existing.expires_at.saturating_sub(600);
1780                let age = now.saturating_sub(created_at);
1781                if age < CREATE_COOLDOWN_SECS {
1782                    return Err(MagicCodeError::Throttled {
1783                        retry_after_secs: CREATE_COOLDOWN_SECS - age,
1784                    });
1785                }
1786            }
1787        }
1788
1789        let code = generate_magic_code();
1790        let mc = MagicCode {
1791            email: email.to_string(),
1792            code: code.clone(),
1793            expires_at: now + 600, // 10 minutes
1794            attempts: 0,
1795        };
1796        codes.insert(email.to_string(), mc.clone());
1797        // Persist after the cache mutation lands. Backend write is
1798        // best-effort — if it fails the code still works for this
1799        // process; only a restart in the next 10 minutes would lose it.
1800        self.backend.put(email, &mc);
1801        Ok(code)
1802    }
1803
1804    /// Verify a code for an email. Returns true if valid and not expired.
1805    /// Uses constant-time comparison to prevent timing attacks.
1806    /// Back-compat wrapper around [`try_verify`].
1807    pub fn verify(&self, email: &str, code: &str) -> bool {
1808        matches!(self.try_verify(email, code), Ok(()))
1809    }
1810
1811    /// Verify a code. Returns a typed error so callers can surface specific
1812    /// messages. On the MAX_ATTEMPTS-th failure, the code is burned — even
1813    /// correct subsequent attempts return `TooManyAttempts`.
1814    /// Every magic code currently in the cache. Powers the Studio
1815    /// "Auth tables" view; not for app use. Includes expired codes —
1816    /// the cache only drops them on next verify attempt for that email.
1817    pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1818        self.cache
1819            .lock()
1820            .map(|m| m.values().cloned().collect())
1821            .unwrap_or_default()
1822    }
1823
1824    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1825        let now = now_secs();
1826        let mut codes = self.cache.lock().unwrap();
1827
1828        let mc = match codes.get_mut(email) {
1829            Some(m) => m,
1830            None => return Err(MagicCodeError::NotFound),
1831        };
1832
1833        if mc.attempts >= MAX_ATTEMPTS {
1834            return Err(MagicCodeError::TooManyAttempts);
1835        }
1836        if mc.expires_at <= now {
1837            codes.remove(email);
1838            self.backend.remove(email);
1839            return Err(MagicCodeError::Expired);
1840        }
1841
1842        let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1843        if !ok {
1844            mc.attempts += 1;
1845            self.backend.bump_attempts(email);
1846            // Burn the code at MAX_ATTEMPTS so retries can't hit max.
1847            if mc.attempts >= MAX_ATTEMPTS {
1848                return Err(MagicCodeError::TooManyAttempts);
1849            }
1850            return Err(MagicCodeError::BadCode);
1851        }
1852
1853        // Correct code — consume it.
1854        codes.remove(email);
1855        self.backend.remove(email);
1856        Ok(())
1857    }
1858}
1859
1860// ---------------------------------------------------------------------------
1861// Cryptographic helpers — CSPRNG-based token and code generation
1862// ---------------------------------------------------------------------------
1863
/// Render `bytes` as a lowercase hexadecimal string, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
1867
1868/// Generate a 6-digit magic code using a CSPRNG.
1869fn generate_magic_code() -> String {
1870    use rand::Rng;
1871    let mut rng = rand::thread_rng();
1872    let code: u32 = rng.gen_range(0..1_000_000);
1873    format!("{:06}", code)
1874}
1875
1876/// Generate a session token with 256 bits of entropy from a CSPRNG.
1877fn generate_token() -> String {
1878    use rand::Rng;
1879    let mut rng = rand::thread_rng();
1880    let bytes: [u8; 32] = rng.gen();
1881    format!("pylon_{}", hex_encode(&bytes))
1882}
1883
1884// ---------------------------------------------------------------------------
1885// Session store — in-memory for dev
1886// ---------------------------------------------------------------------------
1887
1888use std::collections::HashMap;
1889use std::sync::Mutex;
1890
/// Pluggable storage backend for sessions. The default is in-memory; apps
/// deploying for real should supply a persistent backend (e.g. SQLite or
/// Redis) so users don't log out on server restart.
pub trait SessionBackend: Send + Sync {
    /// Return the stored sessions. `SessionStore::with_backend` calls
    /// this once on construction and skips any that are already expired.
    fn load_all(&self) -> Vec<Session>;
    /// Persist `session`. The store calls this both for newly created
    /// sessions and after mutating an existing one, so an implementation
    /// presumably has to overwrite any row with the same token —
    /// NOTE(review): confirm against the concrete backends.
    fn save(&self, session: &Session);
    /// Delete the session stored under `token`.
    fn remove(&self, token: &str);
}
1899
/// A session store. In-memory by default; optionally backed by a
/// persistent [`SessionBackend`].
///
/// The in-memory map is always authoritative — reads don't touch the
/// backend. The backend receives every `save`/`remove`, making it a
/// write-through cache. On construction via [`SessionStore::with_backend`],
/// the store hydrates from the backend so sessions survive restart.
pub struct SessionStore {
    /// Sessions keyed by their token; the authoritative copy.
    sessions: Mutex<HashMap<String, Session>>,
    /// Persistent mirror of `sessions`; `None` for a purely in-memory store.
    backend: Option<Box<dyn SessionBackend>>,
    /// Default lifetime for new sessions (seconds). Sourced from the
    /// manifest's `auth.session.expires_in` config at server boot;
    /// falls back to `Session::DEFAULT_LIFETIME_SECS` (30 days).
    default_lifetime_secs: u64,
}
1915
1916impl Default for SessionStore {
1917    fn default() -> Self {
1918        Self::new()
1919    }
1920}
1921
1922impl SessionStore {
1923    pub fn new() -> Self {
1924        Self {
1925            sessions: Mutex::new(HashMap::new()),
1926            backend: None,
1927            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1928        }
1929    }
1930
1931    /// Override the default session lifetime. Used by `pylon-runtime`'s
1932    /// server bootstrap to apply the manifest's `auth.session.expires_in`.
1933    pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
1934        self.default_lifetime_secs = lifetime_secs;
1935        self
1936    }
1937
1938    /// Build a session store backed by a persistent store. Existing sessions
1939    /// are loaded from the backend on construction; every future mutation
1940    /// writes through.
1941    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1942        let mut map = HashMap::new();
1943        for s in backend.load_all() {
1944            if !s.is_expired() {
1945                map.insert(s.token.clone(), s);
1946            }
1947        }
1948        Self {
1949            sessions: Mutex::new(map),
1950            backend: Some(backend),
1951            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1952        }
1953    }
1954
1955    /// Create a session for a user and return it. Uses the store's
1956    /// configured `default_lifetime_secs` (from the manifest's
1957    /// `auth.session.expires_in`, default 30 days).
1958    pub fn create(&self, user_id: String) -> Session {
1959        let session = Session::with_lifetime(user_id, self.default_lifetime_secs);
1960        let mut sessions = self.sessions.lock().unwrap();
1961        sessions.insert(session.token.clone(), session.clone());
1962        if let Some(b) = &self.backend {
1963            b.save(&session);
1964        }
1965        session
1966    }
1967
1968    /// Look up a session by token. Returns None if the session is expired.
1969    pub fn get(&self, token: &str) -> Option<Session> {
1970        let mut sessions = self.sessions.lock().unwrap();
1971        match sessions.get(token) {
1972            Some(s) if s.is_expired() => {
1973                sessions.remove(token);
1974                None
1975            }
1976            Some(s) => Some(s.clone()),
1977            None => None,
1978        }
1979    }
1980
1981    /// Resolve a token to an auth context.
1982    /// Returns anonymous context if the token is invalid, missing, or expired.
1983    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1984        match token {
1985            Some(t) => match self.get(t) {
1986                Some(session) => session.to_auth_context(),
1987                None => AuthContext::anonymous(),
1988            },
1989            None => AuthContext::anonymous(),
1990        }
1991    }
1992
1993    /// Refresh a session — issues a new token, copies user/device, extends expiry.
1994    /// The old token is revoked. Returns the new session or None if the old
1995    /// token is missing/expired.
1996    pub fn refresh(&self, old_token: &str) -> Option<Session> {
1997        let mut sessions = self.sessions.lock().unwrap();
1998        let old = sessions.remove(old_token)?;
1999        if let Some(b) = &self.backend {
2000            b.remove(old_token);
2001        }
2002        if old.is_expired() {
2003            return None;
2004        }
2005        // Use the store's configured lifetime so a manifest-set
2006        // `auth.session.expires_in` survives session refresh. Previous
2007        // bug: `Session::new(...)` baked in 30 days regardless of
2008        // config — apps with a custom lifetime got the right value on
2009        // first sign-in and lost it on the next refresh.
2010        let mut new = Session::with_lifetime(old.user_id.clone(), self.default_lifetime_secs);
2011        new.device = old.device.clone();
2012        sessions.insert(new.token.clone(), new.clone());
2013        if let Some(b) = &self.backend {
2014            b.save(&new);
2015        }
2016        Some(new)
2017    }
2018
2019    /// Every session in the store, including expired ones, with no
2020    /// filtering. Powers the Studio "Auth tables" view so operators
2021    /// can see orphaned sessions / debug stuck logins. Don't use for
2022    /// app code — `list_for_user` is the right surface there.
2023    pub fn list_all_unfiltered(&self) -> Vec<Session> {
2024        self.sessions
2025            .lock()
2026            .map(|m| m.values().cloned().collect())
2027            .unwrap_or_default()
2028    }
2029
2030    /// List all active sessions for a user.
2031    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
2032        let sessions = self.sessions.lock().unwrap();
2033        sessions
2034            .values()
2035            .filter(|s| s.user_id == user_id && !s.is_expired())
2036            .cloned()
2037            .collect()
2038    }
2039
2040    /// Revoke all sessions for a user. Returns the count removed.
2041    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
2042        let mut sessions = self.sessions.lock().unwrap();
2043        let tokens: Vec<String> = sessions
2044            .iter()
2045            .filter_map(|(t, s)| {
2046                if s.user_id == user_id {
2047                    Some(t.clone())
2048                } else {
2049                    None
2050                }
2051            })
2052            .collect();
2053        let n = tokens.len();
2054        for t in &tokens {
2055            sessions.remove(t);
2056            if let Some(b) = &self.backend {
2057                b.remove(t);
2058            }
2059        }
2060        n
2061    }
2062
2063    /// Sweep expired sessions. Returns the count removed.
2064    pub fn sweep_expired(&self) -> usize {
2065        let mut sessions = self.sessions.lock().unwrap();
2066        let expired: Vec<String> = sessions
2067            .iter()
2068            .filter_map(|(t, s)| {
2069                if s.is_expired() {
2070                    Some(t.clone())
2071                } else {
2072                    None
2073                }
2074            })
2075            .collect();
2076        let n = expired.len();
2077        for t in &expired {
2078            sessions.remove(t);
2079            if let Some(b) = &self.backend {
2080                b.remove(t);
2081            }
2082        }
2083        n
2084    }
2085
2086    /// Attach a device label to a session (typically on login from a browser).
2087    pub fn set_device(&self, token: &str, device: String) -> bool {
2088        let mut sessions = self.sessions.lock().unwrap();
2089        if let Some(s) = sessions.get_mut(token) {
2090            s.device = Some(device);
2091            if let Some(b) = &self.backend {
2092                b.save(s);
2093            }
2094            true
2095        } else {
2096            false
2097        }
2098    }
2099
2100    /// Create a guest session with a generated anonymous ID.
2101    pub fn create_guest(&self) -> Session {
2102        use rand::Rng;
2103        let mut rng = rand::thread_rng();
2104        let bytes: [u8; 16] = rng.gen();
2105        let guest_id = format!("guest_{}", hex_encode(&bytes));
2106        self.create(guest_id)
2107    }
2108
2109    /// Upgrade a guest session to a real user. Replaces the user_id.
2110    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
2111        let mut sessions = self.sessions.lock().unwrap();
2112        if let Some(session) = sessions.get_mut(token) {
2113            session.user_id = real_user_id;
2114            if let Some(b) = &self.backend {
2115                b.save(session);
2116            }
2117            true
2118        } else {
2119            false
2120        }
2121    }
2122
2123    /// Switch the session's active tenant (organization). `None` clears it.
2124    /// Callers should verify the user actually has membership in the target
2125    /// tenant BEFORE invoking this — the session store takes the value on
2126    /// trust. Returns true if the session exists, false otherwise.
2127    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
2128        let mut sessions = self.sessions.lock().unwrap();
2129        if let Some(session) = sessions.get_mut(token) {
2130            session.tenant_id = tenant_id;
2131            if let Some(b) = &self.backend {
2132                b.save(session);
2133            }
2134            true
2135        } else {
2136            false
2137        }
2138    }
2139
2140    /// Remove a session.
2141    pub fn revoke(&self, token: &str) -> bool {
2142        let mut sessions = self.sessions.lock().unwrap();
2143        let removed = sessions.remove(token).is_some();
2144        if removed {
2145            if let Some(b) = &self.backend {
2146                b.remove(token);
2147            }
2148        }
2149        removed
2150    }
2151}
2152
2153// ---------------------------------------------------------------------------
2154// OAuth account links — better-auth's `account` table equivalent
2155// ---------------------------------------------------------------------------
2156
/// A persisted account link. Schema-aligned with better-auth's `account`
/// table (verified against https://www.better-auth.com/docs/concepts/database
/// at the time of writing) so users migrating from better-auth see the
/// same field names + meanings:
///
/// - `provider_id` — the provider's name (`"google"`, `"github"`, plus
///   `"credential"` once email/password auth lands). Matches
///   better-auth's `providerId`.
/// - `account_id` — the PROVIDER'S ID for the user (Google `sub`,
///   GitHub numeric `id`, or for email/password the user's own id).
///   Matches better-auth's `accountId`. NOT the row PK.
/// - `id` — the row PK, generated. Lets the row be referenced
///   independently of the (provider_id, account_id) natural key.
/// - `password` — bcrypt/argon2 hash for `provider_id="credential"`
///   rows; `None` for OAuth links. Reserves the column so adding
///   email/password auth doesn't need a schema migration.
///
/// Account vs. user: a single User row can have many Account rows
/// (Google + GitHub + a password — all linked to one pylon user).
/// Provider lookup is by `(provider_id, account_id)` — NOT email — so
/// a user changing their Google address keeps the same pylon account.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Account {
    /// Row primary key, generated when the link is built.
    pub id: String,
    /// The pylon user this provider identity is linked to.
    pub user_id: String,
    /// Provider name — `"google"`, `"github"`, `"credential"`, etc.
    /// (better-auth: `providerId`)
    pub provider_id: String,
    /// Provider's id for the user — Google `sub`, GitHub numeric `id`,
    /// or for `provider_id="credential"` the user's own id. (better-auth: `accountId`)
    pub account_id: String,
    /// Access token from the provider's token response, if any.
    pub access_token: Option<String>,
    /// Refresh token from the provider's token response, if one was issued.
    pub refresh_token: Option<String>,
    /// ID token from the provider's token response, if one was issued.
    pub id_token: Option<String>,
    /// Unix epoch seconds at which `access_token` expires. `None` for
    /// non-expiring tokens (GitHub Classic apps) or for password rows.
    pub access_token_expires_at: Option<u64>,
    /// Unix epoch seconds at which `refresh_token` expires. `None` when
    /// the provider doesn't expire refresh tokens (most don't, but
    /// Microsoft Identity Platform does after 90 days of inactivity).
    pub refresh_token_expires_at: Option<u64>,
    /// Scope string reported with the token bundle, if any.
    pub scope: Option<String>,
    /// Bcrypt/argon2 hash for email/password rows. `None` for OAuth.
    /// Always `None` today — present so adding password auth later
    /// doesn't require a schema migration.
    pub password: Option<String>,
    /// Unix epoch seconds when this account was first linked.
    pub created_at: u64,
    /// Unix epoch seconds when the token bundle was last refreshed.
    pub updated_at: u64,
}
2208
2209impl Account {
2210    /// Build a new account link from a freshly-completed OAuth handshake.
2211    /// Generates a fresh row id; the `(provider_id, account_id)` pair is
2212    /// what later lookups key on.
2213    pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
2214        let now = now_secs();
2215        Self {
2216            id: generate_token(),
2217            user_id,
2218            provider_id: info.provider.clone(),
2219            account_id: info.provider_account_id.clone(),
2220            access_token: Some(tokens.access_token.clone()),
2221            refresh_token: tokens.refresh_token.clone(),
2222            id_token: tokens.id_token.clone(),
2223            access_token_expires_at: tokens.expires_at,
2224            refresh_token_expires_at: None,
2225            scope: tokens.scope.clone(),
2226            password: None,
2227            created_at: now,
2228            updated_at: now,
2229        }
2230    }
2231
2232    /// True if `access_token_expires_at` is set and has passed.
2233    /// Non-expiring tokens (GitHub Classic) report `false` — caller
2234    /// should treat them as "valid until proven otherwise" and refresh
2235    /// on 401.
2236    pub fn access_token_expired(&self) -> bool {
2237        match self.access_token_expires_at {
2238            Some(ts) => now_secs() >= ts,
2239            None => false,
2240        }
2241    }
2242}
2243
2244/// Pluggable storage for account links. In-memory default ships with
2245/// the crate; SQLite + Postgres impls live in `pylon-runtime`.
2246pub trait AccountBackend: Send + Sync {
2247    /// Insert or refresh an account link. The `(provider_id, account_id)`
2248    /// pair is the natural key — repeated calls for the same pair
2249    /// update the token bundle and `updated_at` on the existing row.
2250    fn upsert(&self, account: &Account);
2251    /// Find an account by provider identity. Returns `None` if the user
2252    /// hasn't linked this provider yet.
2253    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
2254    /// Every account linked to a user. The `/api/auth/me` endpoint uses
2255    /// this to render "you're connected via Google + GitHub" in the UI
2256    /// and to gate "unlink" affordances behind "user has another way to
2257    /// sign in" checks.
2258    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
2259    /// Remove a single provider link. Returns `true` if a row was removed.
2260    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
2261    /// Remove every account link for a user. Used during account
2262    /// deletion to ensure no OAuth references survive past a user row
2263    /// delete. Default implementation walks `find_for_user` + `unlink`;
2264    /// SQL backends can override with a single DELETE.
2265    fn delete_for_user(&self, user_id: &str) -> usize {
2266        let accounts = self.find_for_user(user_id);
2267        let n = accounts.len();
2268        for a in accounts {
2269            self.unlink(&a.provider_id, &a.account_id);
2270        }
2271        n
2272    }
2273    /// Every account in the store. Used by `AccountStore::list_all_unfiltered`
2274    /// to power the Studio admin inspector. Backends that can stream
2275    /// (SQLite, Postgres) just `SELECT *`; the in-memory backend
2276    /// returns its full map.
2277    fn list_all(&self) -> Vec<Account>;
2278}
2279
/// In-memory account backend (default). Lost on restart — production
/// deployments should swap in a persistent backend so refresh tokens
/// survive a redeploy.
pub struct InMemoryAccountBackend {
    /// Keyed by `(provider_id, account_id)`. A separate map keyed on
    /// user_id would speed up `find_for_user`, which currently scans
    /// every stored account; at framework scale that scan is
    /// considered acceptable.
    accounts: Mutex<HashMap<(String, String), Account>>,
}
2289
2290impl InMemoryAccountBackend {
2291    pub fn new() -> Self {
2292        Self {
2293            accounts: Mutex::new(HashMap::new()),
2294        }
2295    }
2296}
2297
2298impl Default for InMemoryAccountBackend {
2299    fn default() -> Self {
2300        Self::new()
2301    }
2302}
2303
2304impl AccountBackend for InMemoryAccountBackend {
2305    fn upsert(&self, account: &Account) {
2306        let key = (account.provider_id.clone(), account.account_id.clone());
2307        self.accounts.lock().unwrap().insert(key, account.clone());
2308    }
2309    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2310        self.accounts
2311            .lock()
2312            .unwrap()
2313            .get(&(provider_id.to_string(), account_id.to_string()))
2314            .cloned()
2315    }
2316    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2317        self.accounts
2318            .lock()
2319            .unwrap()
2320            .values()
2321            .filter(|a| a.user_id == user_id)
2322            .cloned()
2323            .collect()
2324    }
2325    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2326        self.accounts
2327            .lock()
2328            .unwrap()
2329            .remove(&(provider_id.to_string(), account_id.to_string()))
2330            .is_some()
2331    }
2332    fn list_all(&self) -> Vec<Account> {
2333        self.accounts.lock().unwrap().values().cloned().collect()
2334    }
2335}
2336
/// Account store. Wraps an `AccountBackend` and provides the methods the
/// OAuth callback / API endpoints actually call.
pub struct AccountStore {
    /// Storage backend that the store's methods delegate to. Boxed as
    /// a trait object so the concrete backend can be chosen at
    /// construction time.
    backend: Box<dyn AccountBackend>,
}
2342
2343impl Default for AccountStore {
2344    fn default() -> Self {
2345        Self::new()
2346    }
2347}
2348
2349impl AccountStore {
2350    pub fn new() -> Self {
2351        Self {
2352            backend: Box::new(InMemoryAccountBackend::new()),
2353        }
2354    }
2355    pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
2356        Self { backend }
2357    }
2358    pub fn upsert(&self, account: &Account) {
2359        self.backend.upsert(account);
2360    }
2361    pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2362        self.backend.find_by_provider(provider_id, account_id)
2363    }
2364    pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2365        self.backend.find_for_user(user_id)
2366    }
2367    pub fn delete_for_user(&self, user_id: &str) -> usize {
2368        self.backend.delete_for_user(user_id)
2369    }
2370
2371    pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2372        self.backend.unlink(provider_id, account_id)
2373    }
2374
2375    /// Every account in the store. Powers the Studio "Auth tables"
2376    /// view; not for app use. Implemented by walking the backend's
2377    /// per-user index — doable because account counts per user are
2378    /// small (typically ≤ 5) and total account count tracks user
2379    /// count.
2380    ///
2381    /// We don't add a `list_all` method to the `AccountBackend` trait
2382    /// because the in-memory + sqlite + postgres impls would each
2383    /// need a separate implementation, and the operational use case
2384    /// (Studio inspector) is narrow enough to live behind a wrapper
2385    /// that walks the underlying store directly. For PG/SQLite that
2386    /// means a `SELECT * FROM _pylon_accounts` — which the backends
2387    /// can grow if we ever need this at scale.
2388    pub fn list_all_unfiltered(&self) -> Vec<Account> {
2389        self.backend.list_all()
2390    }
2391}
2392
2393// ---------------------------------------------------------------------------
2394// Tests
2395// ---------------------------------------------------------------------------
2396
2397#[cfg(test)]
2398mod tests {
2399    use super::*;
2400
2401    #[test]
2402    fn anonymous_context() {
2403        let ctx = AuthContext::anonymous();
2404        assert!(!ctx.is_authenticated());
2405        assert!(ctx.user_id.is_none());
2406    }
2407
2408    #[test]
2409    fn authenticated_context() {
2410        let ctx = AuthContext::authenticated("user-1".into());
2411        assert!(ctx.is_authenticated());
2412        assert_eq!(ctx.user_id, Some("user-1".into()));
2413    }
2414
2415    #[test]
2416    fn from_api_key_carries_scope_metadata() {
2417        let ctx = AuthContext::from_api_key(
2418            "user-1".into(),
2419            "key_abc".into(),
2420            Some("read,write".into()),
2421        );
2422        assert!(ctx.is_authenticated());
2423        assert!(ctx.is_api_key_auth());
2424        assert_eq!(ctx.user_id.as_deref(), Some("user-1"));
2425        assert_eq!(ctx.api_key_id.as_deref(), Some("key_abc"));
2426        assert_eq!(ctx.api_key_scopes.as_deref(), Some("read,write"));
2427    }
2428
2429    #[test]
2430    fn session_auth_is_not_api_key_auth() {
2431        let ctx = AuthContext::authenticated("user-1".into());
2432        assert!(!ctx.is_api_key_auth());
2433        assert!(ctx.api_key_id.is_none());
2434    }
2435
2436    #[test]
2437    fn auth_mode_public_allows_anonymous() {
2438        let mode = AuthMode::Public;
2439        assert!(mode.check(&AuthContext::anonymous()));
2440        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2441    }
2442
2443    #[test]
2444    fn auth_mode_user_requires_authenticated() {
2445        let mode = AuthMode::User;
2446        assert!(!mode.check(&AuthContext::anonymous()));
2447        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2448    }
2449
2450    #[test]
2451    fn auth_mode_from_str() {
2452        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
2453        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
2454        assert_eq!(AuthMode::from_str("admin"), None);
2455    }
2456
2457    #[test]
2458    fn session_store_create_and_get() {
2459        let store = SessionStore::new();
2460        let session = store.create("user-1".into());
2461        assert!(!session.token.is_empty());
2462        assert!(session.token.starts_with("pylon_"));
2463
2464        let retrieved = store.get(&session.token).unwrap();
2465        assert_eq!(retrieved.user_id, "user-1");
2466    }
2467
2468    #[test]
2469    fn session_store_resolve() {
2470        let store = SessionStore::new();
2471        let session = store.create("user-1".into());
2472
2473        let ctx = store.resolve(Some(&session.token));
2474        assert!(ctx.is_authenticated());
2475        assert_eq!(ctx.user_id, Some("user-1".into()));
2476
2477        let anon = store.resolve(None);
2478        assert!(!anon.is_authenticated());
2479
2480        let bad = store.resolve(Some("invalid-token"));
2481        assert!(!bad.is_authenticated());
2482    }
2483
2484    #[test]
2485    fn session_store_revoke() {
2486        let store = SessionStore::new();
2487        let session = store.create("user-1".into());
2488
2489        assert!(store.revoke(&session.token));
2490        assert!(store.get(&session.token).is_none());
2491        assert!(!store.revoke(&session.token)); // already revoked
2492    }
2493
2494    #[test]
2495    fn session_to_auth_context() {
2496        let session = Session::new("user-42".into());
2497        let ctx = session.to_auth_context();
2498        assert_eq!(ctx.user_id, Some("user-42".into()));
2499    }
2500
2501    // -- Admin context --
2502
2503    #[test]
2504    fn admin_context() {
2505        let ctx = AuthContext::admin();
2506        assert!(ctx.is_admin);
2507        assert!(ctx.is_authenticated());
2508    }
2509
2510    #[test]
2511    fn anonymous_not_admin() {
2512        let ctx = AuthContext::anonymous();
2513        assert!(!ctx.is_admin);
2514    }
2515
2516    #[test]
2517    fn authenticated_not_admin() {
2518        let ctx = AuthContext::authenticated("user-1".into());
2519        assert!(!ctx.is_admin);
2520    }
2521
2522    // -- Magic codes --
2523
2524    #[test]
2525    fn magic_code_create_and_verify() {
2526        let store = MagicCodeStore::new();
2527        let code = store.create("test@example.com");
2528        assert_eq!(code.len(), 6);
2529        assert!(store.verify("test@example.com", &code));
2530    }
2531
2532    #[test]
2533    fn magic_code_wrong_code_rejected() {
2534        let store = MagicCodeStore::new();
2535        store.create("test@example.com");
2536        assert!(!store.verify("test@example.com", "000000"));
2537    }
2538
2539    #[test]
2540    fn magic_code_wrong_email_rejected() {
2541        let store = MagicCodeStore::new();
2542        let code = store.create("test@example.com");
2543        assert!(!store.verify("other@example.com", &code));
2544    }
2545
2546    #[test]
2547    fn magic_code_consumed_after_verify() {
2548        let store = MagicCodeStore::new();
2549        let code = store.create("test@example.com");
2550        assert!(store.verify("test@example.com", &code));
2551        // Second verify should fail — code consumed.
2552        assert!(!store.verify("test@example.com", &code));
2553    }
2554
2555    #[test]
2556    fn magic_code_different_emails_independent() {
2557        let store = MagicCodeStore::new();
2558        let code1 = store.create("alice@example.com");
2559        let code2 = store.create("bob@example.com");
2560        // Each email has its own code.
2561        assert!(store.verify("alice@example.com", &code1));
2562        assert!(store.verify("bob@example.com", &code2));
2563    }
2564
2565    // -- Constant-time comparison --
2566
2567    #[test]
2568    fn constant_time_eq_equal() {
2569        assert!(constant_time_eq(b"hello", b"hello"));
2570        assert!(constant_time_eq(b"", b""));
2571    }
2572
2573    #[test]
2574    fn constant_time_eq_not_equal() {
2575        assert!(!constant_time_eq(b"hello", b"world"));
2576        assert!(!constant_time_eq(b"hello", b"hell"));
2577        assert!(!constant_time_eq(b"a", b"b"));
2578    }
2579
2580    // -- Token generation --
2581
2582    #[test]
2583    fn generated_tokens_are_unique() {
2584        let t1 = generate_token();
2585        let t2 = generate_token();
2586        assert_ne!(t1, t2);
2587        assert!(t1.starts_with("pylon_"));
2588        assert!(t2.starts_with("pylon_"));
2589        // 256 bits = 64 hex chars + "pylon_" prefix (6 chars)
2590        assert_eq!(t1.len(), 6 + 64);
2591    }
2592
2593    // -- OAuth registry --
2594
2595    #[test]
2596    fn oauth_registry_empty() {
2597        let reg = OAuthRegistry::new();
2598        assert!(reg.get("google").is_none());
2599    }
2600
2601    #[test]
2602    fn oauth_registry_register_and_get() {
2603        let mut reg = OAuthRegistry::new();
2604        reg.register(OAuthConfig {
2605            provider: "google".into(),
2606            client_id: "test-id".into(),
2607            client_secret: "test-secret".into(),
2608            redirect_uri: "http://localhost/callback".into(),
2609            ..Default::default()
2610        });
2611        let config = reg.get("google").unwrap();
2612        assert_eq!(config.client_id, "test-id");
2613        assert!(config.auth_url().contains("accounts.google.com"));
2614    }
2615
2616    // -- Spec-driven provider routing --
2617
2618    /// Every builtin provider must produce a non-empty auth_url +
2619    /// token_url when wired with placeholder credentials. This is the
2620    /// regression test for the table-driven refactor: a typo in any
2621    /// `ProviderSpec` field that breaks URL formatting will trip here
2622    /// before it reaches a user.
2623    #[test]
2624    fn every_builtin_provider_routes_through_oauth_config() {
2625        for spec in provider::builtin::all() {
2626            let cfg = OAuthConfig {
2627                provider: spec.id.into(),
2628                client_id: "cid".into(),
2629                client_secret: "csecret".into(),
2630                redirect_uri: "https://app/cb".into(),
2631                tenant: if spec.id == "microsoft" {
2632                    Some("contoso".into())
2633                } else {
2634                    None
2635                },
2636                apple: if spec.id == "apple" {
2637                    Some(provider::AppleConfig {
2638                        team_id: "T".into(),
2639                        key_id: "K".into(),
2640                        private_key_pem: "no".into(),
2641                    })
2642                } else {
2643                    None
2644                },
2645                ..Default::default()
2646            };
2647            let auth = cfg.auth_url();
2648            assert!(!auth.is_empty(), "{}: empty auth_url", spec.id);
2649            // TikTok uses `client_key`; everyone else uses `client_id`.
2650            let expected_param = format!("{}=cid", spec.client_id_param);
2651            assert!(
2652                auth.contains(&expected_param),
2653                "{}: missing {}; got auth_url: {}",
2654                spec.id,
2655                expected_param,
2656                auth,
2657            );
2658            assert!(!cfg.token_url().is_empty(), "{}: empty token_url", spec.id);
2659            // Apple requires response_mode=form_post in the auth URL.
2660            if spec.id == "apple" {
2661                assert!(
2662                    auth.contains("response_mode=form_post"),
2663                    "apple auth_url must include response_mode=form_post; got {auth}"
2664                );
2665            }
2666        }
2667    }
2668
2669    /// Microsoft uses `{tenant}` placeholder substitution — the
2670    /// configured tenant must end up in both auth + token URLs.
2671    #[test]
2672    fn microsoft_tenant_placeholder_resolves() {
2673        let cfg = OAuthConfig {
2674            provider: "microsoft".into(),
2675            client_id: "id".into(),
2676            client_secret: "secret".into(),
2677            redirect_uri: "https://app/cb".into(),
2678            tenant: Some("contoso.onmicrosoft.com".into()),
2679            ..Default::default()
2680        };
2681        assert!(cfg.auth_url().contains("/contoso.onmicrosoft.com/"));
2682        assert!(cfg.token_url().contains("/contoso.onmicrosoft.com/"));
2683    }
2684
2685    /// Microsoft without a tenant defaults to `common` (any account).
2686    #[test]
2687    fn microsoft_default_tenant_common() {
2688        let cfg = OAuthConfig {
2689            provider: "microsoft".into(),
2690            client_id: "id".into(),
2691            client_secret: "secret".into(),
2692            redirect_uri: "https://app/cb".into(),
2693            ..Default::default()
2694        };
2695        assert!(cfg.auth_url().contains("/common/"));
2696        assert!(cfg.token_url().contains("/common/"));
2697    }
2698
2699    /// `scopes_override` replaces the spec default — used for GitHub
2700    /// `repo` scope or Google calendar scopes.
2701    #[test]
2702    fn scopes_override_replaces_spec_default() {
2703        let cfg = OAuthConfig {
2704            provider: "github".into(),
2705            client_id: "id".into(),
2706            client_secret: "secret".into(),
2707            redirect_uri: "https://app/cb".into(),
2708            scopes_override: Some("repo user:email".into()),
2709            ..Default::default()
2710        };
2711        let auth = cfg.auth_url();
2712        // url-encoded "repo user:email" → "repo%20user%3Aemail"
2713        assert!(auth.contains("scope=repo%20user%3Aemail"), "got: {auth}");
2714    }
2715
2716    /// Apple's `client_secret` is minted as a JWT — passing a bad PEM
2717    /// must surface the signing error, not silently send the literal
2718    /// string. The mint path is tested in `apple_jwt::tests`; this
2719    /// asserts the wiring delegates to it.
2720    #[test]
2721    fn apple_exchange_requires_apple_config() {
2722        let cfg = OAuthConfig {
2723            provider: "apple".into(),
2724            client_id: "com.example.app".into(),
2725            client_secret: String::new(),
2726            redirect_uri: "https://app/cb".into(),
2727            apple: None, // missing!
2728            ..Default::default()
2729        };
2730        let err = cfg.exchange_code_full("x").unwrap_err();
2731        assert!(err.contains("apple provider requires"), "got: {err}");
2732    }
2733
2734    /// OIDC discovery cache: priming with a synthetic spec lets us
2735    /// route an issuer-configured provider without touching the
2736    /// network. Validates that `oidc_issuer` short-circuits the
2737    /// builtin lookup.
2738    #[test]
2739    fn oidc_issuer_uses_discovered_endpoints() {
2740        let issuer = "https://acme.test.invalid";
2741        provider::oidc_cache::insert_for_test(
2742            issuer,
2743            provider::DiscoveredSpec {
2744                auth_url: "https://acme.test.invalid/authorize".into(),
2745                token_url: "https://acme.test.invalid/oauth/token".into(),
2746                userinfo_url: Some("https://acme.test.invalid/userinfo".into()),
2747                scopes: "openid email profile".into(),
2748                userinfo_parser: provider::UserinfoParser::Oidc,
2749                token_exchange: provider::TokenExchangeShape::Standard,
2750            },
2751        );
2752        let cfg = OAuthConfig {
2753            provider: "auth0".into(), // not a builtin id
2754            client_id: "id".into(),
2755            client_secret: "secret".into(),
2756            redirect_uri: "https://app/cb".into(),
2757            oidc_issuer: Some(issuer.into()),
2758            ..Default::default()
2759        };
2760        assert!(cfg.auth_url().starts_with("https://acme.test.invalid/authorize?"));
2761        assert_eq!(cfg.token_url(), "https://acme.test.invalid/oauth/token");
2762        assert_eq!(cfg.userinfo_url(), "https://acme.test.invalid/userinfo");
2763    }
2764
2765    // -- Codex review regression tests (P1/P2 from Wave 1 review) --
2766
2767    /// P1: Apple's auth URL MUST include response_mode=form_post when
2768    /// requesting name/email scopes, otherwise Apple rejects with
2769    /// "invalid_request".
2770    #[test]
2771    fn apple_auth_url_includes_form_post() {
2772        let cfg = OAuthConfig {
2773            provider: "apple".into(),
2774            client_id: "com.example.app".into(),
2775            client_secret: String::new(),
2776            redirect_uri: "https://app/cb".into(),
2777            apple: Some(provider::AppleConfig {
2778                team_id: "T".into(),
2779                key_id: "K".into(),
2780                private_key_pem: "no".into(),
2781            }),
2782            ..Default::default()
2783        };
2784        let auth = cfg.auth_url();
2785        assert!(auth.contains("response_mode=form_post"), "got: {auth}");
2786        // Apple identity comes from id_token, so userinfo_url is empty.
2787        assert_eq!(cfg.userinfo_url(), "");
2788    }
2789
2790    /// P1: Apple identity is extracted from the id_token JWT
2791    /// (Apple has no userinfo endpoint). `fetch_userinfo_with_id_token`
2792    /// must decode the claims; `fetch_userinfo_full` (no id_token)
2793    /// must surface a clear error.
2794    #[test]
2795    fn apple_id_token_decode_extracts_identity() {
2796        // Synthesize an unsigned JWT with realistic Apple claims.
2797        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"{\"alg\":\"none\"}");
2798        use base64::Engine;
2799        let claims = serde_json::json!({
2800            "iss": "https://appleid.apple.com",
2801            "sub": "001234.abc.def",
2802            "aud": "com.example.app",
2803            "email": "user@privaterelay.appleid.com",
2804            "email_verified": "true",
2805        });
2806        let claims_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
2807            .encode(claims.to_string().as_bytes());
2808        let id_token = format!("{header}.{claims_b64}.signature_ignored");
2809
2810        let cfg = OAuthConfig {
2811            provider: "apple".into(),
2812            client_id: "com.example.app".into(),
2813            client_secret: String::new(),
2814            redirect_uri: "https://app/cb".into(),
2815            apple: Some(provider::AppleConfig {
2816                team_id: "T".into(),
2817                key_id: "K".into(),
2818                private_key_pem: "no".into(),
2819            }),
2820            ..Default::default()
2821        };
2822        let info = cfg
2823            .fetch_userinfo_with_id_token("ignored", Some(&id_token))
2824            .expect("apple id_token decode");
2825        assert_eq!(info.provider_account_id, "001234.abc.def");
2826        assert_eq!(info.email, "user@privaterelay.appleid.com");
2827
2828        // Without an id_token the call must fail loud, not silently
2829        // try to hit a non-existent userinfo endpoint.
2830        let err = cfg.fetch_userinfo_full("token").unwrap_err();
2831        assert!(err.contains("apple login requires"), "got: {err}");
2832    }
2833
2834    /// P1: Twitter/X requires PKCE — `auth_url_with_pkce` must mint a
2835    /// verifier, embed the SHA-256 challenge in the auth URL, and
2836    /// return the verifier for the callback to replay.
2837    #[test]
2838    fn twitter_auth_url_includes_pkce() {
2839        let cfg = OAuthConfig {
2840            provider: "twitter".into(),
2841            client_id: "tw_client".into(),
2842            client_secret: "tw_secret".into(),
2843            redirect_uri: "https://app/cb".into(),
2844            ..Default::default()
2845        };
2846        let (url, verifier) = cfg.auth_url_with_pkce("state123").expect("twitter pkce");
2847        let v = verifier.expect("twitter must produce verifier");
2848        assert!(v.len() >= 43, "PKCE verifier must be 43+ chars: got {v}");
2849        assert!(url.contains("code_challenge="), "got: {url}");
2850        assert!(url.contains("code_challenge_method=S256"), "got: {url}");
2851
2852        // Non-PKCE provider must NOT add a code_challenge.
2853        let google = OAuthConfig {
2854            provider: "google".into(),
2855            client_id: "g".into(),
2856            client_secret: "g".into(),
2857            redirect_uri: "https://app/cb".into(),
2858            ..Default::default()
2859        };
2860        let (gurl, gverifier) = google.auth_url_with_pkce("st").expect("google");
2861        assert!(gverifier.is_none(), "google should not add PKCE");
2862        assert!(!gurl.contains("code_challenge"), "got: {gurl}");
2863    }
2864
2865    /// P2: TikTok uses `client_key` (not `client_id`) and joins
2866    /// scopes with commas (not spaces).
2867    #[test]
2868    fn tiktok_uses_client_key_and_comma_scopes() {
2869        let cfg = OAuthConfig {
2870            provider: "tiktok".into(),
2871            client_id: "tk_client".into(),
2872            client_secret: "tk_secret".into(),
2873            redirect_uri: "https://app/cb".into(),
2874            scopes_override: Some("user.info.basic video.list".into()),
2875            ..Default::default()
2876        };
2877        let auth = cfg.auth_url();
2878        assert!(auth.contains("client_key=tk_client"), "got: {auth}");
2879        // Comma-separated, url-encoded → "user.info.basic%2Cvideo.list"
2880        assert!(auth.contains("user.info.basic%2Cvideo.list"), "got: {auth}");
2881        // Should NOT use the standard space separator.
2882        assert!(!auth.contains("user.info.basic%20video.list"), "got: {auth}");
2883    }
2884
2885    /// P2: `code` MUST be url-encoded in the token-exchange body.
2886    /// Auth codes can contain reserved characters (`+`, `=`, `/`) that
2887    /// would otherwise corrupt the form body.
2888    #[test]
2889    fn token_exchange_url_encodes_code() {
2890        // We can't hit the network in a unit test, so this asserts
2891        // via the `apple_exchange_requires_apple_config` shape — if
2892        // we DID have a working apple config, encoding would happen
2893        // before the network call. Instead, verify by calling the
2894        // helper used internally:
2895        let raw = "code+with/special=chars";
2896        let encoded = url_encode(raw);
2897        assert!(!encoded.contains('+'));
2898        assert!(!encoded.contains('/'));
2899        assert!(!encoded.contains('='));
2900        assert!(encoded.contains("%2B"));
2901        assert!(encoded.contains("%2F"));
2902        assert!(encoded.contains("%3D"));
2903    }
2904
2905    /// P1: Token-endpoint error bodies must NOT propagate
2906    /// `client_secret`, `code_verifier`, or other sensitive form
2907    /// fields that providers sometimes echo back on auth failure.
2908    #[test]
2909    fn sanitize_token_error_redacts_secrets() {
2910        let raw = "HTTP 400: error=invalid_grant&client_secret=sk_real_secret_value&code_verifier=verifierxyz&hint=check%20your%20code";
2911        let scrubbed = sanitize_token_error(raw.into());
2912        assert!(!scrubbed.contains("sk_real_secret_value"));
2913        assert!(!scrubbed.contains("verifierxyz"));
2914        assert!(scrubbed.contains("client_secret=***"));
2915        assert!(scrubbed.contains("code_verifier=***"));
2916        // Non-sensitive context preserved.
2917        assert!(scrubbed.contains("invalid_grant"));
2918        assert!(scrubbed.contains("hint=check%20your%20code"));
2919    }
2920
2921    /// P1 (codex round-2): JSON-shaped error bodies (Notion,
2922    /// Atlassian) must also have their secret fields redacted.
2923    #[test]
2924    fn sanitize_token_error_redacts_json_secrets() {
2925        let raw = r#"HTTP 400: {"error":"invalid_grant","client_secret":"sk_jsonleak","refresh_token":"rt_abcxyz","id_token":"ey.payload.sig"}"#;
2926        let scrubbed = sanitize_token_error(raw.into());
2927        assert!(!scrubbed.contains("sk_jsonleak"), "got: {scrubbed}");
2928        assert!(!scrubbed.contains("rt_abcxyz"), "got: {scrubbed}");
2929        assert!(!scrubbed.contains("ey.payload.sig"), "got: {scrubbed}");
2930        assert!(scrubbed.contains(r#""client_secret":"***""#), "got: {scrubbed}");
2931        assert!(scrubbed.contains(r#""refresh_token":"***""#), "got: {scrubbed}");
2932        assert!(scrubbed.contains(r#""id_token":"***""#), "got: {scrubbed}");
2933        assert!(scrubbed.contains("invalid_grant"));
2934    }
2935
2936    /// P2 (codex round-2): redact_param_form must NOT panic on
2937    /// multibyte chars before the sensitive key. Earlier byte-index
2938    /// implementation hit `panicked at byte index N is not a char
2939    /// boundary` on bodies with emoji or non-ASCII text.
2940    #[test]
2941    fn sanitize_token_error_handles_utf8() {
2942        let raw = "HTTP 400: ⚠️ provider says the secret is wrong: client_secret=sk_x";
2943        let scrubbed = sanitize_token_error(raw.into());
2944        assert!(scrubbed.contains("⚠️"), "non-ASCII chars must survive: {scrubbed}");
2945        assert!(!scrubbed.contains("sk_x"));
2946        assert!(scrubbed.contains("client_secret=***"));
2947    }
2948
2949    /// P2: OIDC discovery must respect
2950    /// `token_endpoint_auth_methods_supported`. When the IdP
2951    /// publishes `client_secret_post`, use Standard form bodies.
2952    /// When omitted (the spec default), use BasicAuth.
2953    #[test]
2954    fn oidc_discovery_picks_token_auth_method() {
2955        let json_post = r#"{
2956            "issuer": "https://acme.test/",
2957            "authorization_endpoint": "https://acme.test/auth",
2958            "token_endpoint": "https://acme.test/token",
2959            "token_endpoint_auth_methods_supported": ["client_secret_post"]
2960        }"#;
2961        let spec = provider::OidcDiscoveryDoc::parse(json_post).unwrap().into_spec();
2962        assert!(matches!(
2963            spec.token_exchange,
2964            provider::TokenExchangeShape::Standard
2965        ));
2966
2967        // Default (omitted) → BasicAuth.
2968        let json_default = r#"{
2969            "issuer": "https://acme.test/",
2970            "authorization_endpoint": "https://acme.test/auth",
2971            "token_endpoint": "https://acme.test/token"
2972        }"#;
2973        let spec = provider::OidcDiscoveryDoc::parse(json_default)
2974            .unwrap()
2975            .into_spec();
2976        assert!(matches!(
2977            spec.token_exchange,
2978            provider::TokenExchangeShape::BasicAuth
2979        ));
2980    }
2981
2982    /// P2: OIDC discovery missing required endpoints must fail loud,
2983    /// not silently produce empty URLs that would 404 every login.
2984    #[test]
2985    fn oidc_discovery_rejects_incomplete_doc() {
2986        // Missing token_endpoint.
2987        let json = r#"{
2988            "issuer": "https://acme.test/",
2989            "authorization_endpoint": "https://acme.test/auth"
2990        }"#;
2991        let err = provider::OidcDiscoveryDoc::parse(json).unwrap_err();
2992        assert!(err.contains("token_endpoint"), "got: {err}");
2993    }
2994
2995    /// `OAuthRegistry::from_env` must auto-discover every provider
2996    /// whose env vars are set — not just google/github. Smoke-test
2997    /// with Discord since it covers the simple-builtin path.
2998    #[test]
2999    fn from_env_picks_up_discord() {
3000        // Use a unique prefix so this doesn't collide with a real
3001        // dev environment variable. Set+restore in scope.
3002        let key_id = "PYLON_OAUTH_DISCORD_CLIENT_ID";
3003        let key_secret = "PYLON_OAUTH_DISCORD_CLIENT_SECRET";
3004        // SAFETY: tests run single-threaded for env mutation isn't
3005        // strictly true, but this provider is unique enough that
3006        // contention is unlikely. Cleanup happens at end.
3007        std::env::set_var(key_id, "discord-test-id");
3008        std::env::set_var(key_secret, "discord-test-secret");
3009
3010        let reg = OAuthRegistry::from_env();
3011        let discord = reg.get("discord").expect("discord registered");
3012        assert_eq!(discord.client_id, "discord-test-id");
3013        assert!(discord.auth_url().contains("discord.com"));
3014
3015        std::env::remove_var(key_id);
3016        std::env::remove_var(key_secret);
3017    }
3018
3019    // -- Guest auth --
3020
3021    #[test]
3022    fn guest_session() {
3023        let store = SessionStore::new();
3024        let session = store.create_guest();
3025        assert!(session.user_id.starts_with("guest_"));
3026        assert!(!session.token.is_empty());
3027
3028        let ctx = store.resolve(Some(&session.token));
3029        assert!(ctx.is_authenticated());
3030        assert!(ctx.user_id.unwrap().starts_with("guest_"));
3031    }
3032
3033    #[test]
3034    fn upgrade_guest_to_real_user() {
3035        let store = SessionStore::new();
3036        let session = store.create_guest();
3037        assert!(session.user_id.starts_with("guest_"));
3038
3039        let upgraded = store.upgrade(&session.token, "real-user-123".into());
3040        assert!(upgraded);
3041
3042        let ctx = store.resolve(Some(&session.token));
3043        assert_eq!(ctx.user_id, Some("real-user-123".into()));
3044    }
3045
3046    #[test]
3047    fn upgrade_invalid_token_fails() {
3048        let store = SessionStore::new();
3049        let upgraded = store.upgrade("nonexistent-token", "user".into());
3050        assert!(!upgraded);
3051    }
3052
3053    #[test]
3054    fn guest_context() {
3055        let ctx = AuthContext::guest("guest_123".into());
3056        // Guests carry a stable id but are NOT authenticated — routes
3057        // guarded by AuthMode::User must reject them.
3058        assert!(!ctx.is_authenticated());
3059        assert!(ctx.is_guest);
3060        assert!(!ctx.is_admin);
3061        assert_eq!(ctx.user_id, Some("guest_123".into()));
3062        assert!(!AuthMode::User.check(&ctx));
3063        assert!(AuthMode::Public.check(&ctx));
3064    }
3065
3066    #[test]
3067    fn oauth_token_urls() {
3068        let google = OAuthConfig {
3069            provider: "google".into(),
3070            client_id: "x".into(),
3071            client_secret: "x".into(),
3072            redirect_uri: "x".into(),
3073            ..Default::default()
3074        };
3075        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
3076        let github = OAuthConfig {
3077            provider: "github".into(),
3078            client_id: "x".into(),
3079            client_secret: "x".into(),
3080            redirect_uri: "x".into(),
3081            ..Default::default()
3082        };
3083        assert_eq!(
3084            github.token_url(),
3085            "https://github.com/login/oauth/access_token"
3086        );
3087        let unknown = OAuthConfig {
3088            provider: "unknown".into(),
3089            client_id: "x".into(),
3090            client_secret: "x".into(),
3091            redirect_uri: "x".into(),
3092            ..Default::default()
3093        };
3094        assert_eq!(unknown.token_url(), "");
3095        assert!(unknown.auth_url().is_empty());
3096    }
3097
3098    #[test]
3099    fn oauth_auth_url_github() {
3100        let config = OAuthConfig {
3101            provider: "github".into(),
3102            client_id: "gh-id".into(),
3103            client_secret: "gh-secret".into(),
3104            redirect_uri: "http://localhost/cb".into(),
3105            ..Default::default()
3106        };
3107        assert!(config.auth_url().contains("github.com"));
3108        assert!(config.auth_url().contains("gh-id"));
3109    }
3110
3111    #[test]
3112    fn oauth_auth_url_with_state() {
3113        let config = OAuthConfig {
3114            provider: "google".into(),
3115            client_id: "test-id".into(),
3116            client_secret: "test-secret".into(),
3117            redirect_uri: "http://localhost/cb".into(),
3118            ..Default::default()
3119        };
3120        let url = config.auth_url_with_state("random_state_123");
3121        assert!(url.contains("&state=random_state_123"));
3122    }
3123
3124    #[test]
3125    fn oauth_state_store_create_and_validate() {
3126        let store = OAuthStateStore::new();
3127        let token = store.create("google", "https://app/cb", "https://app/login");
3128        let rec = store.validate(&token, "google").expect("valid first time");
3129        assert_eq!(rec.callback_url, "https://app/cb");
3130        assert_eq!(rec.error_callback_url, "https://app/login");
3131        // Second validation should fail — single-use.
3132        assert!(store.validate(&token, "google").is_none());
3133    }
3134
3135    #[test]
3136    fn oauth_state_store_wrong_provider_rejected() {
3137        let store = OAuthStateStore::new();
3138        let token = store.create("google", "https://app/cb", "https://app/cb");
3139        assert!(store.validate(&token, "github").is_none());
3140    }
3141
3142    #[test]
3143    fn oauth_state_store_invalid_state_rejected() {
3144        let store = OAuthStateStore::new();
3145        assert!(store.validate("nonexistent", "google").is_none());
3146    }
3147
3148    #[test]
3149    fn validate_trusted_redirect_basics() {
3150        let trusted = vec!["http://localhost:3000".to_string()];
3151        assert!(validate_trusted_redirect("http://localhost:3000/dashboard", &trusted).is_ok());
3152        assert!(validate_trusted_redirect("http://localhost:3000", &trusted).is_ok());
3153        assert!(validate_trusted_redirect("http://localhost:3000/x?y=1", &trusted).is_ok());
3154
3155        // Wrong port → wrong origin.
3156        assert!(matches!(
3157            validate_trusted_redirect("http://localhost:4321/dashboard", &trusted),
3158            Err(TrustedOriginError::NotTrusted { .. })
3159        ));
3160        // Non-http scheme rejected even before trusted check (defense
3161        // against javascript:, file:, data:).
3162        assert!(matches!(
3163            validate_trusted_redirect("javascript:alert(1)", &trusted),
3164            Err(TrustedOriginError::NotHttp)
3165        ));
3166        assert!(matches!(
3167            validate_trusted_redirect("", &trusted),
3168            Err(TrustedOriginError::Empty)
3169        ));
3170    }
3171}