Skip to main content

pylon_auth/
lib.rs

1pub mod api_key;
2pub mod apple_jwt;
3pub mod captcha;
4pub mod cookie;
5pub mod email;
6pub mod jwt;
7pub mod oidc_provider;
8pub mod org;
9pub mod password;
10pub mod phone;
11pub mod provider;
12pub mod scim;
13pub mod siwe;
14pub mod stripe;
15pub mod totp;
16pub mod webauthn;
17
18pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
19
20use serde::{Deserialize, Serialize};
21
22// ---------------------------------------------------------------------------
23// Auth context — the identity available to runtime operations
24// ---------------------------------------------------------------------------
25
26/// The auth context for a request. Represents who is making the request.
27///
28/// **Do NOT derive `Deserialize` on this type.** If the server ever parses an
29/// `AuthContext` from client-supplied JSON, a client can set `is_admin=true`
30/// or add roles and bypass every policy. Identity must come from
31/// server-minted sessions (`Session::to_auth_context`) or explicit
32/// constructors, never from deserialization.
33///
34/// `Serialize` is safe because sending the resolved context BACK to the
35/// client exposes nothing the server didn't already decide.
36#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
37pub struct AuthContext {
38    /// The authenticated user ID, or None for public/anonymous access.
39    /// For guest contexts this is `Some(guest_id)` — a stable
40    /// anonymous identifier, NOT a real user.
41    pub user_id: Option<String>,
42    /// Whether this is an admin context (bypasses policies).
43    pub is_admin: bool,
44    /// True for `AuthContext::guest()` — anonymous-with-stable-id, used
45    /// for cart state and similar pre-login persistence. Routes guarded
46    /// by `AuthMode::User` reject guests; only `is_authenticated()` ==
47    /// "real signed-in user" should pass auth-required gates.
48    #[serde(default, skip_serializing_if = "is_false")]
49    pub is_guest: bool,
50    /// Roles granted to this user. Empty for anonymous.
51    pub roles: Vec<String>,
52    /// Active tenant id (for multi-tenant apps). Set when the user has
53    /// selected an organization for the current session.
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub tenant_id: Option<String>,
56    /// API key id when the request was authenticated via a `pk.…`
57    /// bearer token. Set so policies + management endpoints can
58    /// distinguish "user-via-session" from "user-via-key" — e.g.
59    /// password change is forbidden via API key.
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub api_key_id: Option<String>,
62    /// Comma-separated scope string from the API key. Application
63    /// policies decide what scopes mean — pylon only carries them.
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub api_key_scopes: Option<String>,
66}
67
/// Predicate for serde's `skip_serializing_if`: true when the flag is unset.
/// Takes `&bool` because serde hands the field over by reference.
fn is_false(flag: &bool) -> bool {
    !*flag
}
71
72impl AuthContext {
73    /// Create an anonymous/public auth context.
74    pub fn anonymous() -> Self {
75        Self {
76            user_id: None,
77            is_admin: false,
78            is_guest: false,
79            roles: Vec::new(),
80            tenant_id: None,
81            api_key_id: None,
82            api_key_scopes: None,
83        }
84    }
85
86    /// Create an authenticated auth context.
87    pub fn authenticated(user_id: String) -> Self {
88        Self {
89            user_id: Some(user_id),
90            is_admin: false,
91            is_guest: false,
92            roles: Vec::new(),
93            tenant_id: None,
94            api_key_id: None,
95            api_key_scopes: None,
96        }
97    }
98
99    /// Create an authenticated context backed by an API key. Policies +
100    /// auth-management endpoints can detect this via `is_api_key_auth()`.
101    pub fn from_api_key(user_id: String, key_id: String, scopes: Option<String>) -> Self {
102        Self {
103            user_id: Some(user_id),
104            is_admin: false,
105            is_guest: false,
106            roles: Vec::new(),
107            tenant_id: None,
108            api_key_id: Some(key_id),
109            api_key_scopes: scopes,
110        }
111    }
112
113    /// True iff this request was authenticated by an API key (not a
114    /// session cookie / bearer session token).
115    pub fn is_api_key_auth(&self) -> bool {
116        self.api_key_id.is_some()
117    }
118
119    /// Create a guest auth context with a persistent anonymous ID.
120    /// Guests carry an opaque stable id (cart/session continuity) but
121    /// are NOT considered authenticated — `is_authenticated()` returns
122    /// false and `AuthMode::User` rejects them.
123    pub fn guest(guest_id: String) -> Self {
124        Self {
125            user_id: Some(guest_id),
126            is_admin: false,
127            is_guest: true,
128            roles: Vec::new(),
129            tenant_id: None,
130            api_key_id: None,
131            api_key_scopes: None,
132        }
133    }
134
135    /// Create an admin auth context that bypasses all policies.
136    pub fn admin() -> Self {
137        Self {
138            user_id: Some("__admin__".into()),
139            is_admin: true,
140            is_guest: false,
141            roles: vec!["admin".into()],
142            tenant_id: None,
143            api_key_id: None,
144            api_key_scopes: None,
145        }
146    }
147
148    /// Convenience: build a user context from a user id.
149    pub fn user(user_id: String) -> Self {
150        Self::authenticated(user_id)
151    }
152
153    /// Active tenant id (None when the user hasn't selected an org).
154    pub fn tenant_id(&self) -> Option<&str> {
155        self.tenant_id.as_deref()
156    }
157
158    /// Attach a tenant id to the context (chainable).
159    pub fn with_tenant(mut self, tenant_id: String) -> Self {
160        self.tenant_id = Some(tenant_id);
161        self
162    }
163
164    /// Check if this context represents an authenticated user.
165    /// Guests intentionally return `false` — they have a stable anonymous
166    /// id but never gain user-level access.
167    pub fn is_authenticated(&self) -> bool {
168        self.user_id.is_some() && !self.is_guest
169    }
170
171    /// Check if the user has a specific role. Admins have every role implicitly.
172    pub fn has_role(&self, role: &str) -> bool {
173        self.is_admin || self.roles.iter().any(|r| r == role)
174    }
175
176    /// Check if the user has ANY of the given roles.
177    pub fn has_any_role(&self, roles: &[&str]) -> bool {
178        self.is_admin || roles.iter().any(|r| self.has_role(r))
179    }
180
181    /// Attach roles to the context (chainable).
182    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
183        self.roles = roles;
184        self
185    }
186}
187
188// ---------------------------------------------------------------------------
189// Constant-time comparison
190// ---------------------------------------------------------------------------
191
/// Compare two byte slices without bailing out at the first mismatch, so
/// the comparison does not reveal where the inputs diverge.
///
/// Slices of different length are rejected up front, which does leak
/// whether the lengths match. For equal lengths every byte pair is
/// XOR-ed and OR-accumulated, and the accumulator is inspected once at
/// the end.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    diff == 0
}
207
208// ---------------------------------------------------------------------------
209// Auth mode — matches the route "auth" field values
210// ---------------------------------------------------------------------------
211
/// Access requirement a route declares through its `auth` field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// No identity required — open to every caller.
    Public,
    /// Caller must be an authenticated user.
    User,
}
220
221impl AuthMode {
222    /// Parse from the manifest auth string.
223    #[allow(clippy::should_implement_trait)]
224    pub fn from_str(s: &str) -> Option<Self> {
225        match s {
226            "public" => Some(AuthMode::Public),
227            "user" => Some(AuthMode::User),
228            _ => None,
229        }
230    }
231
232    /// Check if the given auth context satisfies this mode.
233    pub fn check(&self, ctx: &AuthContext) -> bool {
234        match self {
235            AuthMode::Public => true,
236            AuthMode::User => ctx.is_authenticated(),
237        }
238    }
239}
240
241// ---------------------------------------------------------------------------
242// Session — opaque token session
243// ---------------------------------------------------------------------------
244
245/// A session token and its associated user.
246#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
247pub struct Session {
248    pub token: String,
249    pub user_id: String,
250    /// Unix epoch seconds at which this session expires. 0 = never.
251    #[serde(default)]
252    pub expires_at: u64,
253    /// Optional user-agent / device tag recorded at session creation.
254    #[serde(default, skip_serializing_if = "Option::is_none")]
255    pub device: Option<String>,
256    /// Unix epoch seconds when the session was created.
257    #[serde(default)]
258    pub created_at: u64,
259    /// Active tenant id (selected organization). Set via
260    /// `/api/auth/select-org`. Flows into `AuthContext.tenant_id` which
261    /// powers row-scoped policies like `data.orgId == auth.tenantId`.
262    #[serde(default, skip_serializing_if = "Option::is_none")]
263    pub tenant_id: Option<String>,
264}
265
266impl Session {
267    /// Default session lifetime: 30 days.
268    pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
269
270    /// Create a new session with a generated token and default 30-day expiry.
271    pub fn new(user_id: String) -> Self {
272        let now = now_secs();
273        Self {
274            token: generate_token(),
275            user_id,
276            expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
277            device: None,
278            created_at: now,
279            tenant_id: None,
280        }
281    }
282
283    /// Create a session with a specific lifetime.
284    pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
285        let now = now_secs();
286        Self {
287            token: generate_token(),
288            user_id,
289            expires_at: if lifetime_secs == 0 {
290                0
291            } else {
292                now.saturating_add(lifetime_secs)
293            },
294            device: None,
295            created_at: now,
296            tenant_id: None,
297        }
298    }
299
300    /// Convert this session to an auth context, carrying the selected
301    /// tenant so row-scoped policies see `auth.tenantId`.
302    pub fn to_auth_context(&self) -> AuthContext {
303        let ctx = AuthContext::authenticated(self.user_id.clone());
304        match &self.tenant_id {
305            Some(t) => ctx.with_tenant(t.clone()),
306            None => ctx,
307        }
308    }
309
310    /// Returns true if the session has passed its expires_at time.
311    /// Boundary is inclusive (`>=`) to match the rest of the codebase
312    /// (`magic_codes.expires_at <= now`, `oauth_state.expires_at <= now`).
313    pub fn is_expired(&self) -> bool {
314        self.expires_at != 0 && now_secs() >= self.expires_at
315    }
316}
317
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Falls back to `0` if the system clock reads earlier than the epoch.
fn now_secs() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    since_epoch.map(|elapsed| elapsed.as_secs()).unwrap_or(0)
}
325
326// ---------------------------------------------------------------------------
327// OAuth provider config
328// ---------------------------------------------------------------------------
329
330#[derive(Debug, Clone, Default, Serialize, Deserialize)]
331pub struct OAuthConfig {
332    pub provider: String,
333    pub client_id: String,
334    pub client_secret: String,
335    pub redirect_uri: String,
336    /// Optional scope override — replaces the spec's default scope
337    /// when set. Use cases: requesting `repo` on GitHub for app
338    /// installation flows, requesting `https://www.googleapis.com/...`
339    /// scopes on Google for app-specific data access.
340    #[serde(default, skip_serializing_if = "Option::is_none")]
341    pub scopes_override: Option<String>,
342    /// Tenant id for Microsoft/Entra. Defaults to `common`. Single-
343    /// tenant apps use a directory GUID; multi-tenant work-only apps
344    /// use `organizations`.
345    #[serde(default, skip_serializing_if = "Option::is_none")]
346    pub tenant: Option<String>,
347    /// Apple-specific extras (team id, key id, ES256 PEM). Required
348    /// for Sign in with Apple — ignored for any other provider.
349    #[serde(default, skip_serializing_if = "Option::is_none")]
350    pub apple: Option<provider::AppleConfig>,
351    /// OIDC issuer URL when this config targets a generic-OIDC
352    /// provider (Auth0, Okta, Keycloak, Cognito, etc.). When set,
353    /// the runtime fetches `<issuer>/.well-known/openid-configuration`
354    /// and synthesizes a [`provider::ProviderSpec`] from the
355    /// discovered endpoints.
356    #[serde(default, skip_serializing_if = "Option::is_none")]
357    pub oidc_issuer: Option<String>,
358}
359
360impl OAuthConfig {
361    /// Resolve the [`provider::ProviderSpec`] backing this config. For
362    /// `oidc_issuer`-configured providers, falls through to the OIDC
363    /// discovery cache. Errors propagate so misconfigured providers
364    /// fail loudly at first use rather than silently 404'ing later.
365    fn resolved_spec(&self) -> Result<provider::ResolvedSpec, String> {
366        if let Some(issuer) = self.oidc_issuer.as_deref() {
367            return provider::oidc_cache::resolve(issuer);
368        }
369        provider::find_spec(&self.provider)
370            .map(provider::ResolvedSpec::Static)
371            .ok_or_else(|| format!("unknown OAuth provider: {}", self.provider))
372    }
373
374    /// Build a [`provider::ProviderConfig`] view of `self` for the
375    /// helpers in [`provider`] that take the runtime config.
376    fn provider_cfg(&self) -> provider::ProviderConfig {
377        provider::ProviderConfig {
378            provider: self.provider.clone(),
379            client_id: self.client_id.clone(),
380            client_secret: self.client_secret.clone(),
381            redirect_uri: self.redirect_uri.clone(),
382            scopes_override: self.scopes_override.clone(),
383            tenant: self.tenant.clone(),
384            apple: self.apple.clone(),
385            oidc_issuer: self.oidc_issuer.clone(),
386        }
387    }
388
389    /// Generate the authorization URL for the provider.
390    ///
391    /// Callers MUST append a `&state=<random>` parameter and validate it in the
392    /// callback to prevent CSRF attacks. See `OAuthStateStore` for a minimal
393    /// implementation.
394    ///
395    /// For PKCE-required providers (Twitter/X, Kick), callers should
396    /// prefer [`Self::auth_url_with_pkce`] so the `code_challenge`
397    /// pair survives to the callback.
398    pub fn auth_url(&self) -> String {
399        match self.build_auth_url(None) {
400            Ok(u) => u,
401            Err(_) => String::new(),
402        }
403    }
404
405    /// Generate the authorization URL with a CSRF state parameter attached.
406    pub fn auth_url_with_state(&self, state: &str) -> String {
407        let base = self.auth_url();
408        if base.is_empty() {
409            return base;
410        }
411        format!("{}&state={}", base, url_encode(state))
412    }
413
414    /// Generate the authorization URL with state + a freshly minted
415    /// PKCE pair when the provider requires it. Returns
416    /// `(url, code_verifier)` — the verifier MUST be persisted in
417    /// the OAuth state record and replayed in the token exchange.
418    pub fn auth_url_with_pkce(&self, state: &str) -> Result<(String, Option<String>), String> {
419        let spec = self.resolved_spec()?;
420        let pkce = if spec.requires_pkce() {
421            Some(generate_pkce())
422        } else {
423            None
424        };
425        let challenge = pkce.as_ref().map(|p| p.code_challenge.as_str());
426        let mut url = self.build_auth_url(challenge)?;
427        if !state.is_empty() {
428            url.push_str(&format!("&state={}", url_encode(state)));
429        }
430        Ok((url, pkce.map(|p| p.code_verifier)))
431    }
432
433    fn build_auth_url(&self, pkce_challenge: Option<&str>) -> Result<String, String> {
434        let spec = self.resolved_spec()?;
435        let cfg = self.provider_cfg();
436        let auth = provider::resolve_endpoint(spec.auth_url(), &cfg);
437        if auth.is_empty() {
438            return Err(format!(
439                "provider {} has no authorization endpoint",
440                self.provider
441            ));
442        }
443        let scopes_default = spec.scopes().to_string();
444        let scopes_raw = self.scopes_override.as_deref().unwrap_or(&scopes_default);
445        // Re-join scopes with the provider's separator (TikTok uses
446        // commas, everyone else uses spaces). Splitting on whitespace
447        // first lets developers always specify scopes the human way.
448        let scopes_joined = scopes_raw
449            .split_whitespace()
450            .collect::<Vec<_>>()
451            .join(spec.scope_separator());
452
453        let mut url = format!(
454            "{auth}?{cid_param}={cid}&redirect_uri={ruri}&response_type=code&scope={scope}",
455            cid_param = spec.client_id_param(),
456            cid = url_encode(&self.client_id),
457            ruri = url_encode(&self.redirect_uri),
458            scope = url_encode(&scopes_joined),
459        );
460        if !spec.auth_query_extra().is_empty() {
461            url.push('&');
462            url.push_str(spec.auth_query_extra());
463        }
464        if let Some(challenge) = pkce_challenge {
465            url.push_str("&code_challenge=");
466            url.push_str(challenge);
467            url.push_str("&code_challenge_method=S256");
468        }
469        Ok(url)
470    }
471
472    /// Generate the token exchange URL.
473    pub fn token_url(&self) -> String {
474        match self.resolved_spec() {
475            Ok(spec) => provider::resolve_endpoint(spec.token_url(), &self.provider_cfg()),
476            Err(_) => String::new(),
477        }
478    }
479
480    /// URL for the userinfo endpoint, which returns the authenticated user's profile.
481    pub fn userinfo_url(&self) -> String {
482        match self.resolved_spec() {
483            Ok(spec) => match spec.userinfo_url() {
484                Some(u) => provider::resolve_endpoint(u, &self.provider_cfg()),
485                None => String::new(),
486            },
487            Err(_) => String::new(),
488        }
489    }
490
491    /// Exchange an authorization code for the full token set
492    /// (`access_token`, optional `refresh_token`, optional `id_token`,
493    /// `expires_in`, `scope`). When the provider uses PKCE,
494    /// `code_verifier` MUST be supplied (the value previously returned
495    /// from [`Self::auth_url_with_pkce`]).
496    pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
497        self.exchange_code_full_pkce(code, None)
498    }
499
500    pub fn exchange_code_full_pkce(
501        &self,
502        code: &str,
503        code_verifier: Option<&str>,
504    ) -> Result<TokenSet, String> {
505        let spec = self.resolved_spec()?;
506        let cfg = self.provider_cfg();
507        let token_url = provider::resolve_endpoint(spec.token_url(), &cfg);
508        let pkce_field = code_verifier
509            .map(|v| format!("&code_verifier={}", url_encode(v)))
510            .unwrap_or_default();
511
512        let out = match spec.token_exchange() {
513            provider::TokenExchangeShape::Standard => {
514                let body = format!(
515                    "code={code}&{cid_param}={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
516                    code = url_encode(code),
517                    cid_param = spec.client_id_param(),
518                    cid = url_encode(&self.client_id),
519                    secret = url_encode(&self.client_secret),
520                    ruri = url_encode(&self.redirect_uri),
521                    pkce = pkce_field,
522                );
523                http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
524            }
525            provider::TokenExchangeShape::AppleJwt => {
526                let apple = self.apple.as_ref().ok_or(
527                    "apple provider requires `apple` config (team_id, key_id, private_key_pem)",
528                )?;
529                let signed_secret = apple_jwt::mint_client_secret(apple, &self.client_id)?;
530                let body = format!(
531                    "code={code}&client_id={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
532                    code = url_encode(code),
533                    cid = url_encode(&self.client_id),
534                    secret = url_encode(&signed_secret),
535                    ruri = url_encode(&self.redirect_uri),
536                    pkce = pkce_field,
537                );
538                http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
539            }
540            provider::TokenExchangeShape::BasicAuth => {
541                let body = format!(
542                    "code={code}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
543                    code = url_encode(code),
544                    ruri = url_encode(&self.redirect_uri),
545                    pkce = pkce_field,
546                );
547                http_post_form_basic(&token_url, &body, &self.client_id, &self.client_secret)
548                    .map_err(sanitize_token_error)?
549            }
550            provider::TokenExchangeShape::JsonBody => {
551                let mut json = serde_json::Map::new();
552                json.insert("grant_type".into(), "authorization_code".into());
553                json.insert("code".into(), code.into());
554                json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
555                json.insert("client_id".into(), self.client_id.clone().into());
556                json.insert("client_secret".into(), self.client_secret.clone().into());
557                if let Some(v) = code_verifier {
558                    json.insert("code_verifier".into(), v.to_string().into());
559                }
560                let body = serde_json::Value::Object(json).to_string();
561                http_post_json(&token_url, &body, None).map_err(sanitize_token_error)?
562            }
563            provider::TokenExchangeShape::BasicAuthJsonBody => {
564                let mut json = serde_json::Map::new();
565                json.insert("grant_type".into(), "authorization_code".into());
566                json.insert("code".into(), code.into());
567                json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
568                if let Some(v) = code_verifier {
569                    json.insert("code_verifier".into(), v.to_string().into());
570                }
571                let body = serde_json::Value::Object(json).to_string();
572                http_post_json(
573                    &token_url,
574                    &body,
575                    Some((&self.client_id, &self.client_secret)),
576                )
577                .map_err(sanitize_token_error)?
578            }
579        };
580        parse_token_response(&out)
581    }
582
583    /// Exchange an authorization code for an access token. Thin wrapper
584    /// around [`OAuthConfig::exchange_code_full`] for callers that only
585    /// need the access token.
586    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
587        Ok(self.exchange_code_full(code)?.access_token)
588    }
589
590    /// Fetch the authenticated user's email + display name using an access token.
591    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
592        let info = self.fetch_userinfo_full(access_token)?;
593        Ok((info.email, info.name))
594    }
595
596    /// Fetch the authenticated user's full identity info — email + name +
597    /// the provider-stable account ID. Uses the spec's
598    /// [`provider::UserinfoParser`] so adding a new provider is a
599    /// table change, not a new branch.
600    pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
601        // The id_token from the token response carries the identity
602        // for Apple and similar; route to the dedicated entry point.
603        // Apple's userinfo_url is None — this is the supported path.
604        self.fetch_userinfo_with_id_token(access_token, None)
605    }
606
607    /// Fetch userinfo, falling back to the supplied id_token JWT when
608    /// the provider has no userinfo endpoint (Apple). The id_token
609    /// is the one returned by [`Self::exchange_code_full`] in
610    /// [`TokenSet::id_token`].
611    pub fn fetch_userinfo_with_id_token(
612        &self,
613        access_token: &str,
614        id_token: Option<&str>,
615    ) -> Result<UserInfo, String> {
616        let spec = self.resolved_spec()?;
617        let cfg = self.provider_cfg();
618
619        // Apple — identity lives in the id_token, not a userinfo endpoint.
620        if matches!(spec.userinfo_parser(), provider::UserinfoParser::AppleIdToken) {
621            let token = id_token
622                .ok_or("apple login requires the id_token from the token response")?;
623            return parse_apple_id_token(token, &self.provider);
624        }
625
626        // Linear is GraphQL — the userinfo "GET" is actually a POST
627        // with a fixed query.
628        if matches!(spec.userinfo_parser(), provider::UserinfoParser::LinearGraphql) {
629            return fetch_linear_userinfo(&self.provider, access_token);
630        }
631
632        let url = match spec.userinfo_url() {
633            Some(u) => provider::resolve_endpoint(u, &cfg),
634            None => return Err(format!("provider {} has no userinfo endpoint", self.provider)),
635        };
636        let out = match spec.userinfo_method() {
637            provider::UserinfoMethod::Get => http_get_bearer(&url, access_token),
638            provider::UserinfoMethod::Post => http_post_bearer(&url, access_token),
639        }
640        .map_err(sanitize_token_error)?;
641        let parsed: serde_json::Value =
642            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
643
644        match spec.userinfo_parser() {
645            provider::UserinfoParser::Oidc => {
646                let email = parsed
647                    .get("email")
648                    .and_then(|v| v.as_str())
649                    .ok_or("no email in userinfo")?
650                    .to_string();
651                let name = parsed
652                    .get("name")
653                    .and_then(|v| v.as_str())
654                    .map(String::from);
655                let provider_account_id = parsed
656                    .get("sub")
657                    .and_then(|v| v.as_str())
658                    .ok_or("no sub in userinfo")?
659                    .to_string();
660                Ok(UserInfo {
661                    provider: self.provider.clone(),
662                    provider_account_id,
663                    email,
664                    name,
665                })
666            }
667            provider::UserinfoParser::GitHub => {
668                let name = parsed
669                    .get("name")
670                    .and_then(|v| v.as_str())
671                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
672                    .map(String::from);
673                let email = parsed
674                    .get("email")
675                    .and_then(|v| v.as_str())
676                    .map(String::from);
677                let email = email
678                    .or_else(|| fetch_github_primary_email(access_token).ok())
679                    .ok_or("no accessible email on GitHub account")?;
680                let provider_account_id = parsed
681                    .get("id")
682                    .map(|v| {
683                        v.as_i64()
684                            .map(|n| n.to_string())
685                            .or_else(|| v.as_str().map(String::from))
686                            .unwrap_or_default()
687                    })
688                    .filter(|s| !s.is_empty())
689                    .ok_or("no id in userinfo")?;
690                Ok(UserInfo {
691                    provider: self.provider.clone(),
692                    provider_account_id,
693                    email,
694                    name,
695                })
696            }
697            provider::UserinfoParser::Custom {
698                id_path,
699                email_path,
700                name_path,
701            } => {
702                let provider_account_id = json_pointer_string(&parsed, id_path)
703                    .ok_or_else(|| format!("no id at {id_path} in userinfo"))?;
704                let raw_email = json_pointer_string(&parsed, email_path)
705                    .ok_or_else(|| format!("no email at {email_path} in userinfo"))?;
706                // Twitter/Reddit don't expose real emails — they map a
707                // username into the email slot. Tag it so account
708                // policies can distinguish "real verified email" from
709                // "we made this up." `.invalid` is reserved by RFC 6761.
710                let email = if !raw_email.contains('@') {
711                    let domain = match self.provider.as_str() {
712                        "twitter" => "x.invalid",
713                        "reddit" => "reddit.invalid",
714                        other => return Err(format!(
715                            "{other}: userinfo `email` field is not an email address (got {raw_email:?}); refusing to synthesize",
716                        )),
717                    };
718                    format!("{raw_email}@{domain}")
719                } else {
720                    raw_email
721                };
722                let name = name_path.and_then(|p| json_pointer_string(&parsed, p));
723                Ok(UserInfo {
724                    provider: self.provider.clone(),
725                    provider_account_id,
726                    email,
727                    name,
728                })
729            }
730            provider::UserinfoParser::AppleIdToken => unreachable!("handled above"),
731            provider::UserinfoParser::LinearGraphql => unreachable!("handled above"),
732        }
733    }
734}
735
/// The two halves of a PKCE exchange.
struct PkcePair {
    /// Secret half: kept server-side until the token exchange.
    code_verifier: String,
    /// Public half: the S256 hash of the verifier, sent on the auth URL.
    code_challenge: String,
}
742
743/// Generate a PKCE pair: random 43-char verifier + S256 challenge.
744/// RFC 7636 §4.1 permits 43–128 chars from `[A-Za-z0-9-._~]`. 32
745/// random bytes URL-base64-encoded comes out to exactly 43 chars.
746fn generate_pkce() -> PkcePair {
747    use rand::RngCore;
748    let mut bytes = [0u8; 32];
749    rand::thread_rng().fill_bytes(&mut bytes);
750    let code_verifier = apple_jwt::base64_url(bytes);
751    use sha2::{Digest, Sha256};
752    let mut hasher = Sha256::new();
753    hasher.update(code_verifier.as_bytes());
754    let code_challenge = apple_jwt::base64_url(hasher.finalize());
755    PkcePair {
756        code_verifier,
757        code_challenge,
758    }
759}
760
761/// Decode an Apple id_token JWT and pull the identity claims.
762///
763/// **Trust assumption:** the caller MUST have obtained this token
764/// via the back-channel `/auth/token` exchange (mutually authenticated
765/// TLS to `appleid.apple.com`). Under that assumption no third party
766/// can have substituted a forged JWT, so we skip signature
767/// verification.
768///
769/// **DO NOT call this on a JWT supplied by the client** (e.g. a
770/// "post your id_token to me" mobile-SDK flow). For those paths,
771/// implement Apple JWKS verification: fetch
772/// `https://appleid.apple.com/auth/keys`, verify the RS256
773/// signature, then check `iss == "https://appleid.apple.com"`,
774/// `aud == client_id`, and `exp > now`. Pylon doesn't ship that
775/// verifier yet — apps that need it can compose `crate::jwt::verify`
776/// against a JWKS-loaded RSA key.
777///
778/// This function is private (`fn`, not `pub fn`) precisely so it
779/// can't be misused by an external caller. The only call site is
780/// [`OAuthConfig::fetch_userinfo_with_id_token`] which is reached
781/// only via the OAuth callback handler, which only processes
782/// back-channel-exchanged tokens.
783fn parse_apple_id_token(id_token: &str, provider: &str) -> Result<UserInfo, String> {
784    let mut parts = id_token.split('.');
785    let _header = parts.next().ok_or("apple id_token: missing header")?;
786    let claims_b64 = parts.next().ok_or("apple id_token: missing claims")?;
787    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
788    let claims_bytes = URL_SAFE_NO_PAD
789        .decode(claims_b64)
790        .map_err(|e| format!("apple id_token claims not base64: {e}"))?;
791    let claims: serde_json::Value = serde_json::from_slice(&claims_bytes)
792        .map_err(|e| format!("apple id_token claims not JSON: {e}"))?;
793    let provider_account_id = claims
794        .get("sub")
795        .and_then(|v| v.as_str())
796        .ok_or("apple id_token: missing sub")?
797        .to_string();
798    let email = claims
799        .get("email")
800        .and_then(|v| v.as_str())
801        .ok_or("apple id_token: missing email (was the `email` scope requested?)")?
802        .to_string();
803    Ok(UserInfo {
804        provider: provider.to_string(),
805        provider_account_id,
806        email,
807        name: None, // Apple sends `name` as a separate form field on FIRST signup only.
808    })
809}
810
811/// Strip provider error bodies of secrets before they propagate to
812/// logs / `oauth_error_message` redirect URLs.
813///
814/// **Why:** Several token endpoints echo the request body (or pieces
815/// of it) on auth failure. Without this, a misconfigured deployment
816/// can leak `client_secret`, the Apple JWT, or even the auth `code`
817/// into the user's browser history and CDN logs.
818///
819/// Covers both shapes echoed by real providers:
820///   - form / query: `client_secret=sk_…`
821///   - JSON: `"client_secret":"sk_…"` (Notion, Atlassian)
822fn sanitize_token_error(err: String) -> String {
823    const SENSITIVE: &[&str] = &[
824        "client_secret",
825        "code_verifier",
826        "client_assertion",
827        "refresh_token",
828        "access_token",
829        "id_token",
830        // The auth `code` itself is single-use but still sensitive
831        // until the token endpoint consumes it — and many providers
832        // echo it back on a 4xx token-exchange error before the
833        // attacker has had a chance to redeem it.
834        "code",
835    ];
836    let mut out = err;
837    for key in SENSITIVE {
838        out = redact_param_form(&out, key);
839        out = redact_param_json(&out, key);
840    }
841    out
842}
843
/// Mask every `key=value` occurrence (form / query-string shape) in
/// `input`, emitting `key=***` in its place.
///
/// A value runs until the first `&`, newline, `"`, space, or `'`, or
/// to the end of the input; the terminator itself is kept. Matching
/// is a plain substring search for `key=`, so a longer parameter name
/// that ends in `key` (e.g. `error_code=` for `code`) is masked too.
/// All slicing happens on `char` boundaries, so multibyte text around
/// a match cannot cause a panic.
fn redact_param_form(input: &str, key: &str) -> String {
    let needle = format!("{key}=");
    let mut out = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(first) = rest.chars().next() {
        match rest.strip_prefix(needle.as_str()) {
            Some(after_key) => {
                out.push_str(&needle);
                out.push_str("***");
                // Drop the value; resume at the terminator (or at EOF).
                let value_len = after_key
                    .find(|c: char| matches!(c, '&' | '\n' | '"' | ' ' | '\''))
                    .unwrap_or(after_key.len());
                rest = &after_key[value_len..];
            }
            None => {
                // Copy one whole char so `rest` stays UTF-8 aligned.
                out.push(first);
                rest = &rest[first.len_utf8()..];
            }
        }
    }
    out
}
874
/// Mask the string value of every `"key": "…"` member in `input`,
/// emitting `"key": "***"` in its place.
///
/// Matching is case-sensitive and tolerates whitespace on either side
/// of the `:`. Occurrences of `"key"` that are not followed by a colon
/// and a string value (array elements, numeric values, …) are copied
/// through untouched. Backslash-escaped quotes inside the value are
/// honoured when looking for the closing quote; if no closing quote
/// exists, everything after the opening quote is dropped.
fn redact_param_json(input: &str, key: &str) -> String {
    let needle = format!("\"{key}\"");
    let mut out = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(first) = rest.chars().next() {
        let Some(after_key) = rest.strip_prefix(needle.as_str()) else {
            // Not at a key: copy one whole char and move on.
            out.push(first);
            rest = &rest[first.len_utf8()..];
            continue;
        };

        // `"key"`, optional whitespace, then a colon is required.
        let at_colon = after_key.trim_start();
        let Some(after_colon) = at_colon.strip_prefix(':') else {
            // Not a key-value form (could be in an array, etc.).
            out.push_str(&rest[..rest.len() - at_colon.len()]);
            rest = at_colon;
            continue;
        };

        // Optional whitespace, then the opening quote of a string value.
        let at_value = after_colon.trim_start();
        let Some(value) = at_value.strip_prefix('"') else {
            // Non-string value: copy what was scanned verbatim.
            out.push_str(&rest[..rest.len() - at_value.len()]);
            rest = at_value;
            continue;
        };

        // Keep everything up to and including the opening quote.
        out.push_str(&rest[..rest.len() - value.len()]);
        out.push_str("***");

        // Locate the first quote that is not escaped by a backslash.
        let mut escaped = false;
        let closing = value
            .char_indices()
            .find(|&(_, ch)| {
                if ch == '"' && !escaped {
                    return true;
                }
                escaped = ch == '\\' && !escaped;
                false
            })
            .map(|(idx, _)| idx);

        rest = match closing {
            Some(end) => {
                out.push('"');
                &value[end + 1..]
            }
            // Malformed JSON: redact to end of input to be safe.
            None => "",
        };
    }
    out
}
949
950/// Linear's userinfo lives behind a GraphQL endpoint — the bearer
951/// token is the same OAuth access token, but the request is a POST
952/// with a fixed query. Kept as a separate fn so the main fetcher
953/// stays uniform across the other parsers.
954fn fetch_linear_userinfo(provider: &str, access_token: &str) -> Result<UserInfo, String> {
955    let body = r#"{"query":"query { viewer { id email name } }"}"#;
956    let agent = ureq_agent();
957    let resp = agent
958        .post("https://api.linear.app/graphql")
959        .set("Authorization", &format!("Bearer {access_token}"))
960        .set("Content-Type", "application/json")
961        .set("Accept", "application/json")
962        .send_string(body)
963        .map_err(|e| format!("linear graphql: {e}"))?;
964    let out = resp.into_string().map_err(|e| format!("read body: {e}"))?;
965    let parsed: serde_json::Value = serde_json::from_str(&out)
966        .map_err(|e| format!("linear graphql not JSON: {e}"))?;
967    let viewer = parsed
968        .pointer("/data/viewer")
969        .ok_or("linear graphql: no /data/viewer")?;
970    let provider_account_id = viewer
971        .get("id")
972        .and_then(|v| v.as_str())
973        .ok_or("linear graphql: no id")?
974        .to_string();
975    let email = viewer
976        .get("email")
977        .and_then(|v| v.as_str())
978        .ok_or("linear graphql: no email")?
979        .to_string();
980    let name = viewer.get("name").and_then(|v| v.as_str()).map(String::from);
981    Ok(UserInfo {
982        provider: provider.to_string(),
983        provider_account_id,
984        email,
985        name,
986    })
987}
988
989/// JSON-pointer (RFC 6901) string extraction. Returns `None` for
990/// missing paths or non-string values. Numeric ids (Discord's `id`,
991/// Roblox's `sub`) are coerced to strings.
992fn json_pointer_string(v: &serde_json::Value, path: &str) -> Option<String> {
993    let node = v.pointer(path)?;
994    if let Some(s) = node.as_str() {
995        return Some(s.to_string());
996    }
997    if let Some(n) = node.as_i64() {
998        return Some(n.to_string());
999    }
1000    if let Some(n) = node.as_u64() {
1001        return Some(n.to_string());
1002    }
1003    None
1004}
1005
/// Identity resolved from an OAuth provider, as returned by
/// [`OAuthConfig::fetch_userinfo_full`].
///
/// `provider_account_id` is the provider-stable subject id (Google
/// `sub`, GitHub numeric `id`); the account store keys on it so a
/// renamed email doesn't orphan the pylon account.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UserInfo {
    /// Id of the provider that produced this identity.
    pub provider: String,
    /// Provider-stable subject identifier for the account.
    pub provider_account_id: String,
    /// Email address reported for the account.
    pub email: String,
    /// Display name, when the provider supplied one.
    pub name: Option<String>,
}
1017
/// Tokens returned by [`OAuthConfig::exchange_code_full`].
///
/// Stored on the matching `Account` row so `refresh_token` is
/// available for silent re-auth and `expires_at` can be checked
/// before each provider call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TokenSet {
    /// Bearer token for calls to the provider.
    pub access_token: String,
    /// Refresh token, when the provider issued one.
    pub refresh_token: Option<String>,
    /// OIDC id_token, when the provider issued one.
    pub id_token: Option<String>,
    /// Unix epoch seconds at which the access token expires. `None`
    /// when the provider didn't return `expires_in` (GitHub's classic
    /// OAuth app tokens are non-expiring).
    pub expires_at: Option<u64>,
    /// Granted scope string, when the provider reported it.
    pub scope: Option<String>,
}
1032
1033fn parse_token_response(body: &str) -> Result<TokenSet, String> {
1034    // Most providers return JSON; GitHub Classic apps return form-urlencoded
1035    // unless you ask with Accept: application/json (which we do).
1036    let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
1037        // Fall back to form-urlencoded: access_token=...&scope=...&token_type=...
1038        let mut map = serde_json::Map::new();
1039        for pair in body.split('&') {
1040            if let Some((k, v)) = pair.split_once('=') {
1041                map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
1042            }
1043        }
1044        serde_json::Value::Object(map)
1045    });
1046
1047    let access_token = json
1048        .get("access_token")
1049        .and_then(|v| v.as_str())
1050        .ok_or_else(|| format!("no access_token in token response: {body}"))?
1051        .to_string();
1052    let refresh_token = json
1053        .get("refresh_token")
1054        .and_then(|v| v.as_str())
1055        .map(String::from);
1056    let id_token = json
1057        .get("id_token")
1058        .and_then(|v| v.as_str())
1059        .map(String::from);
1060    let expires_at = json
1061        .get("expires_in")
1062        .and_then(|v| {
1063            v.as_u64()
1064                .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
1065        })
1066        .map(|secs| now_secs().saturating_add(secs));
1067    let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
1068    Ok(TokenSet {
1069        access_token,
1070        refresh_token,
1071        id_token,
1072        expires_at,
1073        scope,
1074    })
1075}
1076
/// Percent-encode `s` byte by byte.
///
/// ASCII letters, digits, and `-`, `_`, `.`, `~` pass through
/// unchanged; every other byte (including each byte of a multibyte
/// UTF-8 character) becomes `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
1089
/// Upper bound applied to each phase (connect, read, write) of an
/// OAuth / userinfo HTTP call. Short enough that a hung provider
/// doesn't block a login indefinitely; long enough to absorb typical
/// internet latency.
const HTTP_TIMEOUT: std::time::Duration = {
    const SECONDS: u64 = 10;
    std::time::Duration::from_secs(SECONDS)
};
1094
1095fn ureq_agent() -> ureq::Agent {
1096    ureq::AgentBuilder::new()
1097        .timeout_connect(HTTP_TIMEOUT)
1098        .timeout_read(HTTP_TIMEOUT)
1099        .timeout_write(HTTP_TIMEOUT)
1100        .user_agent("pylon/0.1")
1101        .build()
1102}
1103
1104fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
1105    let agent = ureq_agent();
1106    let mut req = agent
1107        .post(url)
1108        .set("Content-Type", "application/x-www-form-urlencoded");
1109    if accept_json {
1110        req = req.set("Accept", "application/json");
1111    }
1112    match req.send_string(body) {
1113        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1114        Err(ureq::Error::Status(code, resp)) => {
1115            let body = resp.into_string().unwrap_or_default();
1116            Err(format!("HTTP {code}: {body}"))
1117        }
1118        Err(e) => Err(format!("HTTP error: {e}")),
1119    }
1120}
1121
1122/// POST a form body using HTTP Basic auth for the client credentials.
1123/// Used by Spotify, Reddit, Figma, Zoom, PayPal — providers that
1124/// mandate Basic auth on the token endpoint.
1125fn http_post_form_basic(
1126    url: &str,
1127    body: &str,
1128    client_id: &str,
1129    client_secret: &str,
1130) -> Result<String, String> {
1131    use base64::{engine::general_purpose::STANDARD, Engine};
1132    let creds = format!("{client_id}:{client_secret}");
1133    let basic = STANDARD.encode(creds.as_bytes());
1134    let agent = ureq_agent();
1135    match agent
1136        .post(url)
1137        .set("Content-Type", "application/x-www-form-urlencoded")
1138        .set("Accept", "application/json")
1139        .set("Authorization", &format!("Basic {basic}"))
1140        .send_string(body)
1141    {
1142        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1143        Err(ureq::Error::Status(code, resp)) => {
1144            let body = resp.into_string().unwrap_or_default();
1145            Err(format!("HTTP {code}: {body}"))
1146        }
1147        Err(e) => Err(format!("HTTP error: {e}")),
1148    }
1149}
1150
1151/// POST a JSON body, optionally with HTTP Basic auth. Used by
1152/// Notion (Basic + JSON) and Atlassian (JSON only) — both reject
1153/// form-encoded bodies on their token endpoints.
1154fn http_post_json(
1155    url: &str,
1156    body: &str,
1157    basic_creds: Option<(&str, &str)>,
1158) -> Result<String, String> {
1159    let agent = ureq_agent();
1160    let mut req = agent
1161        .post(url)
1162        .set("Content-Type", "application/json")
1163        .set("Accept", "application/json");
1164    if let Some((id, secret)) = basic_creds {
1165        use base64::{engine::general_purpose::STANDARD, Engine};
1166        let creds = STANDARD.encode(format!("{id}:{secret}").as_bytes());
1167        req = req.set("Authorization", &format!("Basic {creds}"));
1168    }
1169    // Notion requires the API version header on every call, even the
1170    // token exchange. Using a recent stable version.
1171    req = req.set("Notion-Version", "2022-06-28");
1172    match req.send_string(body) {
1173        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1174        Err(ureq::Error::Status(code, resp)) => {
1175            let body = resp.into_string().unwrap_or_default();
1176            Err(format!("HTTP {code}: {body}"))
1177        }
1178        Err(e) => Err(format!("HTTP error: {e}")),
1179    }
1180}
1181
1182/// POST with empty body + bearer auth. Used for Dropbox userinfo
1183/// (an RPC-style endpoint that requires POST instead of GET).
1184fn http_post_bearer(url: &str, token: &str) -> Result<String, String> {
1185    let agent = ureq_agent();
1186    match agent
1187        .post(url)
1188        .set("Authorization", &format!("Bearer {token}"))
1189        .set("Accept", "application/json")
1190        .call()
1191    {
1192        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1193        Err(ureq::Error::Status(code, resp)) => {
1194            let body = resp.into_string().unwrap_or_default();
1195            Err(format!("HTTP {code}: {body}"))
1196        }
1197        Err(e) => Err(format!("HTTP error: {e}")),
1198    }
1199}
1200
1201fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
1202    let agent = ureq_agent();
1203    match agent
1204        .get(url)
1205        .set("Authorization", &format!("Bearer {token}"))
1206        .set("Accept", "application/json")
1207        .call()
1208    {
1209        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1210        Err(ureq::Error::Status(code, resp)) => {
1211            let body = resp.into_string().unwrap_or_default();
1212            Err(format!("HTTP {code}: {body}"))
1213        }
1214        Err(e) => Err(format!("HTTP error: {e}")),
1215    }
1216}
1217
1218fn fetch_github_primary_email(token: &str) -> Result<String, String> {
1219    let out = http_get_bearer("https://api.github.com/user/emails", token)?;
1220    let emails: serde_json::Value =
1221        serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
1222    emails
1223        .as_array()
1224        .and_then(|arr| {
1225            arr.iter()
1226                .find(|e| {
1227                    e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
1228                        && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
1229                })
1230                .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
1231        })
1232        .ok_or_else(|| "no primary verified email on GitHub".into())
1233}
1234
1235/// OAuth provider registry.
1236pub struct OAuthRegistry {
1237    providers: std::collections::HashMap<String, OAuthConfig>,
1238}
1239
1240impl Default for OAuthRegistry {
1241    fn default() -> Self {
1242        Self::new()
1243    }
1244}
1245
1246impl OAuthRegistry {
1247    pub fn new() -> Self {
1248        Self {
1249            providers: std::collections::HashMap::new(),
1250        }
1251    }
1252
1253    pub fn register(&mut self, config: OAuthConfig) {
1254        self.providers.insert(config.provider.clone(), config);
1255    }
1256
1257    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
1258        self.providers.get(provider)
1259    }
1260
1261    /// Build from environment variables.
1262    ///
1263    /// For each builtin provider (and any `oidc_issuer`-configured
1264    /// IdP), looks for `PYLON_OAUTH_<PROVIDER>_CLIENT_ID` /
1265    /// `_CLIENT_SECRET` / `_REDIRECT`. Apple additionally requires
1266    /// `_TEAM_ID`, `_KEY_ID`, `_PRIVATE_KEY` (PEM contents or path).
1267    /// Microsoft accepts an optional `_TENANT`.
1268    ///
1269    /// Generic OIDC: any env var matching
1270    /// `PYLON_OAUTH_<NAME>_OIDC_ISSUER` registers a provider with id
1271    /// `<name>` (lowercased) using the discovered endpoints. Useful
1272    /// for Auth0, Okta, Keycloak, Cognito, Logto, Authentik, etc.
1273    pub fn from_env() -> Self {
1274        let mut reg = Self::new();
1275
1276        for spec in provider::builtin::all() {
1277            let upper = spec.id.to_ascii_uppercase();
1278            let prefix = format!("PYLON_OAUTH_{upper}");
1279            let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1280                Ok(v) => v,
1281                Err(_) => continue,
1282            };
1283            let secret = match std::env::var(format!("{prefix}_CLIENT_SECRET")) {
1284                Ok(v) => v,
1285                // Apple's "client_secret" is synthesized — allow blank.
1286                Err(_) if spec.id == "apple" => String::new(),
1287                Err(_) => continue,
1288            };
1289            let redirect_uri = std::env::var(format!("{prefix}_REDIRECT")).unwrap_or_else(|_| {
1290                format!("http://localhost:3000/api/auth/callback/{}", spec.id)
1291            });
1292            let scopes_override = std::env::var(format!("{prefix}_SCOPES")).ok();
1293            let tenant = std::env::var(format!("{prefix}_TENANT")).ok();
1294
1295            let apple = if spec.id == "apple" {
1296                match (
1297                    std::env::var(format!("{prefix}_TEAM_ID")),
1298                    std::env::var(format!("{prefix}_KEY_ID")),
1299                    std::env::var(format!("{prefix}_PRIVATE_KEY")),
1300                ) {
1301                    (Ok(team_id), Ok(key_id), Ok(private_key_pem)) => Some(provider::AppleConfig {
1302                        team_id,
1303                        key_id,
1304                        private_key_pem,
1305                    }),
1306                    _ => continue, // Apple requires the JWT material to function.
1307                }
1308            } else {
1309                None
1310            };
1311
1312            reg.register(OAuthConfig {
1313                provider: spec.id.to_string(),
1314                client_id: id,
1315                client_secret: secret,
1316                redirect_uri,
1317                scopes_override,
1318                tenant,
1319                apple,
1320                oidc_issuer: None,
1321            });
1322        }
1323
1324        // Generic OIDC providers — scan PYLON_OAUTH_<NAME>_OIDC_ISSUER.
1325        for (key, issuer) in std::env::vars() {
1326            let Some(rest) = key.strip_prefix("PYLON_OAUTH_") else {
1327                continue;
1328            };
1329            let Some(name_upper) = rest.strip_suffix("_OIDC_ISSUER") else {
1330                continue;
1331            };
1332            let name = name_upper.to_ascii_lowercase();
1333            if provider::find_spec(&name).is_some() {
1334                continue; // already handled as a builtin
1335            }
1336            let prefix = format!("PYLON_OAUTH_{name_upper}");
1337            let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1338                Ok(v) => v,
1339                Err(_) => continue,
1340            };
1341            let secret = std::env::var(format!("{prefix}_CLIENT_SECRET")).unwrap_or_default();
1342            let redirect_uri = std::env::var(format!("{prefix}_REDIRECT"))
1343                .unwrap_or_else(|_| format!("http://localhost:3000/api/auth/callback/{name}"));
1344            reg.register(OAuthConfig {
1345                provider: name,
1346                client_id: id,
1347                client_secret: secret,
1348                redirect_uri,
1349                scopes_override: std::env::var(format!("{prefix}_SCOPES")).ok(),
1350                tenant: None,
1351                apple: None,
1352                oidc_issuer: Some(issuer),
1353            });
1354        }
1355
1356        reg
1357    }
1358
1359    /// Iterate over registered provider ids — used by routes/auth.rs
1360    /// to expose `/api/auth/providers` and to validate
1361    /// `/api/auth/login/<id>` paths against the configured set.
1362    pub fn ids(&self) -> impl Iterator<Item = &str> {
1363        self.providers.keys().map(|s| s.as_str())
1364    }
1365
1366    /// Process-wide cached registry. Built once on first use from
1367    /// `from_env`; subsequent calls are zero-cost. Routes use this
1368    /// to avoid the ~150 syscalls `from_env` does per call.
1369    ///
1370    /// **Trade-off:** env changes after server start aren't picked up
1371    /// without a restart — same as every other Pylon env-var path.
1372    pub fn shared() -> &'static OAuthRegistry {
1373        static CELL: std::sync::OnceLock<OAuthRegistry> = std::sync::OnceLock::new();
1374        CELL.get_or_init(Self::from_env)
1375    }
1376}
1377
1378// ---------------------------------------------------------------------------
1379// OAuth state store — CSRF protection for OAuth flows
1380// ---------------------------------------------------------------------------
1381
/// A single in-flight OAuth handshake, stored under its state token.
///
/// Carries the post-callback redirect URLs alongside the provider so
/// the callback handler doesn't need to consult an env var to know
/// where to send the user. The store itself does not validate the
/// URLs; callers are expected to have checked them against the
/// trusted-origins allowlist before the record is created.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OAuthState {
    /// Provider id the handshake was started for.
    pub provider: String,
    /// URL the callback redirects to on success. The frontend supplies
    /// this via `?callback=` on the start request.
    pub callback_url: String,
    /// URL the callback redirects to on failure. Defaults to
    /// `callback_url` when the frontend doesn't pass an explicit
    /// `?error_callback=`. The error code + message ride along as
    /// query params (`?oauth_error=X&oauth_error_message=Y`).
    pub error_callback_url: String,
    /// PKCE code_verifier when the provider requires PKCE. Set by the
    /// `/api/auth/login/<provider>` start route via
    /// [`OAuthConfig::auth_url_with_pkce`]; replayed on token exchange
    /// in the callback. `None` for non-PKCE providers.
    pub pkce_verifier: Option<String>,
    /// Unix epoch seconds after which the record is no longer valid.
    pub expires_at: u64,
}
1405
1406/// Backing store for OAuth state records. Default impl keeps them in
1407/// memory (fine for tests + dev); the runtime swaps in a SQLite or
1408/// Postgres backend so a restart in the middle of an OAuth handshake
1409/// doesn't leave the user with "invalid state" on the callback.
1410pub trait OAuthStateBackend: Send + Sync {
1411    /// Persist a state record under `token`.
1412    fn put(&self, token: &str, state: &OAuthState);
1413    /// Atomic compare-and-consume: returns the stored record if the
1414    /// token exists and hasn't expired, then removes it. Returning
1415    /// `None` means either the token never existed or it has already
1416    /// been used / expired.
1417    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
1418}
1419
1420/// In-memory backend (default). Lost on restart.
1421pub struct InMemoryOAuthBackend {
1422    states: Mutex<HashMap<String, OAuthState>>,
1423}
1424
1425impl InMemoryOAuthBackend {
1426    pub fn new() -> Self {
1427        Self {
1428            states: Mutex::new(HashMap::new()),
1429        }
1430    }
1431}
1432
1433impl Default for InMemoryOAuthBackend {
1434    fn default() -> Self {
1435        Self::new()
1436    }
1437}
1438
1439impl OAuthStateBackend for InMemoryOAuthBackend {
1440    fn put(&self, token: &str, state: &OAuthState) {
1441        self.states
1442            .lock()
1443            .unwrap()
1444            .insert(token.to_string(), state.clone());
1445    }
1446    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
1447        let mut s = self.states.lock().unwrap();
1448        let entry = s.remove(token)?;
1449        if entry.expires_at <= now_unix_secs {
1450            return None;
1451        }
1452        Some(entry)
1453    }
1454}
1455
1456/// Stores OAuth state parameters to prevent CSRF attacks on the callback.
1457///
1458/// State tokens are short-lived (10 minutes) and single-use. Backed by an
1459/// `OAuthStateBackend`; defaults to in-memory but the runtime persists them
1460/// to SQLite (or Postgres when `DATABASE_URL` is set) so they survive a
1461/// restart that happens mid-OAuth-handshake.
1462pub struct OAuthStateStore {
1463    backend: Box<dyn OAuthStateBackend>,
1464}
1465
1466impl Default for OAuthStateStore {
1467    fn default() -> Self {
1468        Self::new()
1469    }
1470}
1471
1472impl OAuthStateStore {
1473    pub fn new() -> Self {
1474        Self {
1475            backend: Box::new(InMemoryOAuthBackend::new()),
1476        }
1477    }
1478
1479    pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
1480        Self { backend }
1481    }
1482
1483    /// Generate and store a new state record. Returns the random
1484    /// state token (the value the OAuth provider echoes back as
1485    /// `?state=…` on the callback).
1486    ///
1487    /// Caller is responsible for validating `callback_url` and
1488    /// `error_callback_url` against the trusted-origins allowlist
1489    /// BEFORE calling this — the store trusts what it's given.
1490    pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
1491        self.create_with_pkce(provider, callback_url, error_callback_url, None)
1492    }
1493
1494    /// Same as [`Self::create`] but accepts a PKCE verifier to stash
1495    /// alongside the state record. The callback handler reads it back
1496    /// out and replays it in the token exchange.
1497    pub fn create_with_pkce(
1498        &self,
1499        provider: &str,
1500        callback_url: &str,
1501        error_callback_url: &str,
1502        pkce_verifier: Option<String>,
1503    ) -> String {
1504        use std::time::{SystemTime, UNIX_EPOCH};
1505        let token = generate_token();
1506        let now = SystemTime::now()
1507            .duration_since(UNIX_EPOCH)
1508            .unwrap_or_default()
1509            .as_secs();
1510        let state = OAuthState {
1511            provider: provider.to_string(),
1512            callback_url: callback_url.to_string(),
1513            error_callback_url: error_callback_url.to_string(),
1514            pkce_verifier,
1515            expires_at: now + 600,
1516        };
1517        self.backend.put(&token, &state);
1518        token
1519    }
1520
1521    /// Validate and consume a state token. Returns the stored record
1522    /// iff the token existed, has not expired, AND matches
1523    /// `expected_provider`. The token is removed either way to make
1524    /// replay impossible.
1525    pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
1526        use std::time::{SystemTime, UNIX_EPOCH};
1527        let now = SystemTime::now()
1528            .duration_since(UNIX_EPOCH)
1529            .unwrap_or_default()
1530            .as_secs();
1531        let entry = self.backend.take(state, now)?;
1532        if entry.provider != expected_provider {
1533            return None;
1534        }
1535        Some(entry)
1536    }
1537}
1538
1539/// Validate that `url` has an origin (scheme://host[:port]) listed in
1540/// `trusted_origins`. Returns `Ok(url)` when trusted (echoes input for
1541/// chaining), `Err` with a code/message when not. Used by the OAuth
1542/// start endpoint to gate `?callback=` + `?error_callback=` values
1543/// before storing them in the state record.
1544///
1545/// `trusted_origins` entries are origin strings like
1546/// `"https://app.example.com"` or `"http://localhost:3000"` — no
1547/// trailing slash, no path. A `url` like
1548/// `"http://localhost:3000/dashboard?x=1"` matches the
1549/// `"http://localhost:3000"` entry.
1550///
1551/// Borrowed wholesale from better-auth's `trustedOrigins` model:
1552/// explicit allowlist, no implicit "same-origin trust," no env-var
1553/// magic. An open-redirect via OAuth is one of the easier auth bugs
1554/// to ship by accident.
1555pub fn validate_trusted_redirect(
1556    url: &str,
1557    trusted_origins: &[String],
1558) -> Result<(), TrustedOriginError> {
1559    if url.is_empty() {
1560        return Err(TrustedOriginError::Empty);
1561    }
1562    // Must be absolute http(s) URL — no relative paths, no schemes
1563    // like javascript:, file:, data:.
1564    if !url.starts_with("http://") && !url.starts_with("https://") {
1565        return Err(TrustedOriginError::NotHttp);
1566    }
1567    let url_origin = origin_of(url);
1568    if trusted_origins.iter().any(|t| t == &url_origin) {
1569        Ok(())
1570    } else {
1571        Err(TrustedOriginError::NotTrusted { origin: url_origin })
1572    }
1573}
1574
/// Reasons a redirect URL might be rejected by [`validate_trusted_redirect`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The URL was the empty string.
    Empty,
    /// The URL does not begin with `http://` or `https://`.
    NotHttp,
    /// The URL's origin matched no allowlist entry. `origin` is the
    /// origin that was extracted from the rejected URL.
    NotTrusted { origin: String },
}
1582
1583impl std::fmt::Display for TrustedOriginError {
1584    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1585        match self {
1586            TrustedOriginError::Empty => write!(f, "redirect URL is empty"),
1587            TrustedOriginError::NotHttp => {
1588                write!(f, "redirect URL must use http:// or https:// scheme")
1589            }
1590            TrustedOriginError::NotTrusted { origin } => write!(
1591                f,
1592                "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS"
1593            ),
1594        }
1595    }
1596}
1597
/// Reduce a URL string to its origin (`scheme://host[:port]`) by
/// dropping everything from the first `/`, `?` or `#` that follows the
/// `://` separator. This is plain string slicing, not a URL parser:
/// anything between `://` and that first delimiter (userinfo included)
/// is kept verbatim. Input without `://` is returned with trailing
/// slashes trimmed. Public so router crates can apply the same logic
/// when comparing redirect URLs against the trusted-origins list.
pub fn origin_of(url: &str) -> String {
    match url.split_once("://") {
        None => url.trim_end_matches('/').to_string(),
        Some((scheme, rest)) => {
            // `split` always yields at least one piece, so the fallback is never hit.
            let authority = rest
                .split(|c: char| matches!(c, '/' | '?' | '#'))
                .next()
                .unwrap_or(rest);
            format!("{scheme}://{authority}")
        }
    }
}
1613
1614// ---------------------------------------------------------------------------
1615// Magic code auth — email verification codes
1616// ---------------------------------------------------------------------------
1617
/// Pluggable storage for magic-code records. In-memory is the default
/// (fine for dev); persistent backends (SQLite, Postgres) live in
/// `pylon-runtime` so a server restart between "send code" and "verify
/// code" doesn't invalidate the user's pending login.
///
/// All methods are infallible from the caller's perspective — none of
/// them returns a `Result`, so durability is best-effort. A backend
/// that fails to write should log; `MagicCodeStore` answers create and
/// verify requests from its own in-memory cache, which remains
/// authoritative for the current process.
///
/// Records are keyed by email address.
pub trait MagicCodeBackend: Send + Sync {
    /// Replace any existing code for `email` with `code`.
    fn put(&self, email: &str, code: &MagicCode);
    /// Look up the current code for `email`. Returns `None` if absent.
    fn get(&self, email: &str) -> Option<MagicCode>;
    /// Remove the code for `email` (called on successful verify or
    /// expiry). Idempotent — missing key is not an error.
    fn remove(&self, email: &str);
    /// Persist an attempts++ on the existing record without touching
    /// other fields. Used by the verify-failed path to enforce
    /// `MAX_ATTEMPTS` across restarts.
    fn bump_attempts(&self, email: &str);
    /// Return every stored record. `MagicCodeStore::with_backend` calls
    /// this once on construction to hydrate its in-memory cache and
    /// discards already-expired records itself, so implementations may
    /// return expired rows.
    fn load_all(&self) -> Vec<MagicCode>;
}
1642
/// In-memory backend for magic codes. The default: `MagicCodeStore::new`
/// installs it as the store's backend. Note that `MagicCodeStore` also
/// keeps its own separate cache map; this type is the backend side only.
pub struct InMemoryMagicCodeBackend {
    /// Stored records, keyed by email address.
    codes: Mutex<HashMap<String, MagicCode>>,
}
1648
1649impl InMemoryMagicCodeBackend {
1650    pub fn new() -> Self {
1651        Self {
1652            codes: Mutex::new(HashMap::new()),
1653        }
1654    }
1655}
1656
1657impl Default for InMemoryMagicCodeBackend {
1658    fn default() -> Self {
1659        Self::new()
1660    }
1661}
1662
1663impl MagicCodeBackend for InMemoryMagicCodeBackend {
1664    fn put(&self, email: &str, code: &MagicCode) {
1665        self.codes
1666            .lock()
1667            .unwrap()
1668            .insert(email.to_string(), code.clone());
1669    }
1670    fn get(&self, email: &str) -> Option<MagicCode> {
1671        self.codes.lock().unwrap().get(email).cloned()
1672    }
1673    fn remove(&self, email: &str) {
1674        self.codes.lock().unwrap().remove(email);
1675    }
1676    fn bump_attempts(&self, email: &str) {
1677        if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
1678            c.attempts = c.attempts.saturating_add(1);
1679        }
1680    }
1681    fn load_all(&self) -> Vec<MagicCode> {
1682        self.codes.lock().unwrap().values().cloned().collect()
1683    }
1684}
1685
/// A magic-code store. Wraps a `MagicCodeBackend` (in-memory by default)
/// and applies the verify/cooldown semantics. Hydrates the in-memory
/// cache from the backend on construction so durable backends survive
/// restart without losing in-flight codes.
pub struct MagicCodeStore {
    /// Working set of records keyed by email. Create and verify read
    /// and mutate this map; the backend is only written to.
    cache: Mutex<HashMap<String, MagicCode>>,
    /// Storage that receives a mirror of each cache mutation (`put`,
    /// `remove`, `bump_attempts`) and seeds the cache on construction.
    backend: Box<dyn MagicCodeBackend>,
}
1694
/// One pending magic-code login for a single email address.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Address the code was issued for; also the key it is stored under.
    pub email: String,
    /// The code the user must present, stored as issued.
    pub code: String,
    /// Expiry instant, in seconds on the same clock as `now_secs()`.
    /// The record is treated as expired once `expires_at <= now`.
    pub expires_at: u64,
    /// Failed verify attempts against this code. Once it reaches
    /// `MAX_ATTEMPTS` the code is invalidated.
    pub attempts: u32,
}
1704
/// Maximum failed verify attempts per code before it's burned. 5 is a
/// common bound — lets the user fix typos without enabling realistic
/// brute-force against a 6-digit code space. Once a record's `attempts`
/// reaches this value, verification is refused even for the correct code.
const MAX_ATTEMPTS: u32 = 5;
1709
/// Minimum seconds between successive code creations for the same email
/// while the previous code is still live. Throttles magic-code spam
/// (user can't be flooded with login codes). Requests inside the window
/// are rejected with `MagicCodeError::Throttled`.
const CREATE_COOLDOWN_SECS: u64 = 60;
1713
/// Why a magic code could not be issued or verified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// There is no record for this email.
    NotFound,
    /// The code is present but `MAX_ATTEMPTS` failed verifies have occurred.
    TooManyAttempts,
    /// The code did not match.
    BadCode,
    /// The record for this email is past its `expires_at`. The record
    /// is dropped when this is reported.
    Expired,
    /// Another code was requested too recently. Wait
    /// `retry_after_secs` seconds and try again.
    Throttled { retry_after_secs: u64 },
}
1727
1728impl Default for MagicCodeStore {
1729    fn default() -> Self {
1730        Self::new()
1731    }
1732}
1733
1734impl MagicCodeStore {
1735    pub fn new() -> Self {
1736        Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1737    }
1738
1739    /// Build a magic-code store backed by a persistent backend. Existing
1740    /// live codes are hydrated into the in-memory cache on construction
1741    /// so a server restart between "send" and "verify" doesn't kill the
1742    /// user's pending login.
1743    pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1744        let now = now_secs();
1745        let mut cache = HashMap::new();
1746        for c in backend.load_all() {
1747            if c.expires_at > now {
1748                cache.insert(c.email.clone(), c);
1749            }
1750        }
1751        Self {
1752            cache: Mutex::new(cache),
1753            backend,
1754        }
1755    }
1756
1757    /// Generate a 6-digit code for an email and return it. Subject to a
1758    /// per-email cooldown — returns the error-shape via `try_create`.
1759    pub fn create(&self, email: &str) -> String {
1760        // Back-compat wrapper: same signature as before, but we still burn
1761        // the cooldown if one is active. Use `try_create` for a Result shape.
1762        self.try_create(email).unwrap_or_else(|_| String::new())
1763    }
1764
1765    /// Create a magic code, enforcing per-email cooldown. Returns the code
1766    /// or an error describing why one couldn't be issued.
1767    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1768        let now = now_secs();
1769
1770        let mut codes = self.cache.lock().unwrap();
1771
1772        // Cooldown check: if a live code exists and was created less than
1773        // CREATE_COOLDOWN_SECS ago, throttle. The age-of-code is
1774        // `expires_at - 600 + cooldown` since expires_at is create_time + 600.
1775        if let Some(existing) = codes.get(email) {
1776            if existing.expires_at > now {
1777                let created_at = existing.expires_at.saturating_sub(600);
1778                let age = now.saturating_sub(created_at);
1779                if age < CREATE_COOLDOWN_SECS {
1780                    return Err(MagicCodeError::Throttled {
1781                        retry_after_secs: CREATE_COOLDOWN_SECS - age,
1782                    });
1783                }
1784            }
1785        }
1786
1787        let code = generate_magic_code();
1788        let mc = MagicCode {
1789            email: email.to_string(),
1790            code: code.clone(),
1791            expires_at: now + 600, // 10 minutes
1792            attempts: 0,
1793        };
1794        codes.insert(email.to_string(), mc.clone());
1795        // Persist after the cache mutation lands. Backend write is
1796        // best-effort — if it fails the code still works for this
1797        // process; only a restart in the next 10 minutes would lose it.
1798        self.backend.put(email, &mc);
1799        Ok(code)
1800    }
1801
1802    /// Verify a code for an email. Returns true if valid and not expired.
1803    /// Uses constant-time comparison to prevent timing attacks.
1804    /// Back-compat wrapper around [`try_verify`].
1805    pub fn verify(&self, email: &str, code: &str) -> bool {
1806        matches!(self.try_verify(email, code), Ok(()))
1807    }
1808
1809    /// Verify a code. Returns a typed error so callers can surface specific
1810    /// messages. On the MAX_ATTEMPTS-th failure, the code is burned — even
1811    /// correct subsequent attempts return `TooManyAttempts`.
1812    /// Every magic code currently in the cache. Powers the Studio
1813    /// "Auth tables" view; not for app use. Includes expired codes —
1814    /// the cache only drops them on next verify attempt for that email.
1815    pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1816        self.cache
1817            .lock()
1818            .map(|m| m.values().cloned().collect())
1819            .unwrap_or_default()
1820    }
1821
1822    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1823        let now = now_secs();
1824        let mut codes = self.cache.lock().unwrap();
1825
1826        let mc = match codes.get_mut(email) {
1827            Some(m) => m,
1828            None => return Err(MagicCodeError::NotFound),
1829        };
1830
1831        if mc.attempts >= MAX_ATTEMPTS {
1832            return Err(MagicCodeError::TooManyAttempts);
1833        }
1834        if mc.expires_at <= now {
1835            codes.remove(email);
1836            self.backend.remove(email);
1837            return Err(MagicCodeError::Expired);
1838        }
1839
1840        let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1841        if !ok {
1842            mc.attempts += 1;
1843            self.backend.bump_attempts(email);
1844            // Burn the code at MAX_ATTEMPTS so retries can't hit max.
1845            if mc.attempts >= MAX_ATTEMPTS {
1846                return Err(MagicCodeError::TooManyAttempts);
1847            }
1848            return Err(MagicCodeError::BadCode);
1849        }
1850
1851        // Correct code — consume it.
1852        codes.remove(email);
1853        self.backend.remove(email);
1854        Ok(())
1855    }
1856}
1857
1858// ---------------------------------------------------------------------------
1859// Cryptographic helpers — CSPRNG-based token and code generation
1860// ---------------------------------------------------------------------------
1861
/// Encode `bytes` as lower-case hexadecimal, two digits per byte.
///
/// Writes into a single preallocated buffer instead of building (and
/// then concatenating) one temporary `String` per input byte.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
1865
1866/// Generate a 6-digit magic code using a CSPRNG.
1867fn generate_magic_code() -> String {
1868    use rand::Rng;
1869    let mut rng = rand::thread_rng();
1870    let code: u32 = rng.gen_range(0..1_000_000);
1871    format!("{:06}", code)
1872}
1873
1874/// Generate a session token with 256 bits of entropy from a CSPRNG.
1875fn generate_token() -> String {
1876    use rand::Rng;
1877    let mut rng = rand::thread_rng();
1878    let bytes: [u8; 32] = rng.gen();
1879    format!("pylon_{}", hex_encode(&bytes))
1880}
1881
1882// ---------------------------------------------------------------------------
1883// Session store — in-memory for dev
1884// ---------------------------------------------------------------------------
1885
1886use std::collections::HashMap;
1887use std::sync::Mutex;
1888
/// Pluggable storage backend for sessions. The default is in-memory; apps
/// deploying for real should supply a persistent backend (e.g. SQLite or
/// Redis) so users don't log out on server restart.
///
/// No method returns a `Result`, so a backend has no way to report a
/// failed write to the store.
pub trait SessionBackend: Send + Sync {
    /// Return every stored session. Called once by
    /// `SessionStore::with_backend`, which discards expired sessions
    /// itself.
    fn load_all(&self) -> Vec<Session>;
    /// Persist `session`. The store calls this both for newly created
    /// sessions and after mutating an existing one (device, user id,
    /// tenant), so it must behave as insert-or-replace for the
    /// session's token.
    fn save(&self, session: &Session);
    /// Delete the session stored under `token`.
    fn remove(&self, token: &str);
}
1897
/// A session store. In-memory by default; optionally backed by a
/// persistent [`SessionBackend`].
///
/// The in-memory map is always authoritative — reads don't touch the
/// backend. Creation, refresh, revocation, sweeping and field updates
/// are written through to the backend, making it a write-through
/// cache. On construction via [`SessionStore::with_backend`], the store
/// hydrates from the backend so sessions survive restart.
pub struct SessionStore {
    /// Live sessions keyed by session token.
    sessions: Mutex<HashMap<String, Session>>,
    /// Optional persistent mirror; `None` for a purely in-memory store.
    backend: Option<Box<dyn SessionBackend>>,
    /// Default lifetime for new sessions (seconds). Sourced from the
    /// manifest's `auth.session.expires_in` config at server boot;
    /// falls back to `Session::DEFAULT_LIFETIME_SECS` (30 days).
    default_lifetime_secs: u64,
}
1913
1914impl Default for SessionStore {
1915    fn default() -> Self {
1916        Self::new()
1917    }
1918}
1919
1920impl SessionStore {
1921    pub fn new() -> Self {
1922        Self {
1923            sessions: Mutex::new(HashMap::new()),
1924            backend: None,
1925            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1926        }
1927    }
1928
1929    /// Override the default session lifetime. Used by `pylon-runtime`'s
1930    /// server bootstrap to apply the manifest's `auth.session.expires_in`.
1931    pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
1932        self.default_lifetime_secs = lifetime_secs;
1933        self
1934    }
1935
1936    /// Build a session store backed by a persistent store. Existing sessions
1937    /// are loaded from the backend on construction; every future mutation
1938    /// writes through.
1939    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1940        let mut map = HashMap::new();
1941        for s in backend.load_all() {
1942            if !s.is_expired() {
1943                map.insert(s.token.clone(), s);
1944            }
1945        }
1946        Self {
1947            sessions: Mutex::new(map),
1948            backend: Some(backend),
1949            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1950        }
1951    }
1952
1953    /// Create a session for a user and return it. Uses the store's
1954    /// configured `default_lifetime_secs` (from the manifest's
1955    /// `auth.session.expires_in`, default 30 days).
1956    pub fn create(&self, user_id: String) -> Session {
1957        let session = Session::with_lifetime(user_id, self.default_lifetime_secs);
1958        let mut sessions = self.sessions.lock().unwrap();
1959        sessions.insert(session.token.clone(), session.clone());
1960        if let Some(b) = &self.backend {
1961            b.save(&session);
1962        }
1963        session
1964    }
1965
1966    /// Look up a session by token. Returns None if the session is expired.
1967    pub fn get(&self, token: &str) -> Option<Session> {
1968        let mut sessions = self.sessions.lock().unwrap();
1969        match sessions.get(token) {
1970            Some(s) if s.is_expired() => {
1971                sessions.remove(token);
1972                None
1973            }
1974            Some(s) => Some(s.clone()),
1975            None => None,
1976        }
1977    }
1978
1979    /// Resolve a token to an auth context.
1980    /// Returns anonymous context if the token is invalid, missing, or expired.
1981    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1982        match token {
1983            Some(t) => match self.get(t) {
1984                Some(session) => session.to_auth_context(),
1985                None => AuthContext::anonymous(),
1986            },
1987            None => AuthContext::anonymous(),
1988        }
1989    }
1990
1991    /// Refresh a session — issues a new token, copies user/device, extends expiry.
1992    /// The old token is revoked. Returns the new session or None if the old
1993    /// token is missing/expired.
1994    pub fn refresh(&self, old_token: &str) -> Option<Session> {
1995        let mut sessions = self.sessions.lock().unwrap();
1996        let old = sessions.remove(old_token)?;
1997        if let Some(b) = &self.backend {
1998            b.remove(old_token);
1999        }
2000        if old.is_expired() {
2001            return None;
2002        }
2003        // Use the store's configured lifetime so a manifest-set
2004        // `auth.session.expires_in` survives session refresh. Previous
2005        // bug: `Session::new(...)` baked in 30 days regardless of
2006        // config — apps with a custom lifetime got the right value on
2007        // first sign-in and lost it on the next refresh.
2008        let mut new = Session::with_lifetime(old.user_id.clone(), self.default_lifetime_secs);
2009        new.device = old.device.clone();
2010        sessions.insert(new.token.clone(), new.clone());
2011        if let Some(b) = &self.backend {
2012            b.save(&new);
2013        }
2014        Some(new)
2015    }
2016
2017    /// Every session in the store, including expired ones, with no
2018    /// filtering. Powers the Studio "Auth tables" view so operators
2019    /// can see orphaned sessions / debug stuck logins. Don't use for
2020    /// app code — `list_for_user` is the right surface there.
2021    pub fn list_all_unfiltered(&self) -> Vec<Session> {
2022        self.sessions
2023            .lock()
2024            .map(|m| m.values().cloned().collect())
2025            .unwrap_or_default()
2026    }
2027
2028    /// List all active sessions for a user.
2029    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
2030        let sessions = self.sessions.lock().unwrap();
2031        sessions
2032            .values()
2033            .filter(|s| s.user_id == user_id && !s.is_expired())
2034            .cloned()
2035            .collect()
2036    }
2037
2038    /// Revoke all sessions for a user. Returns the count removed.
2039    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
2040        let mut sessions = self.sessions.lock().unwrap();
2041        let tokens: Vec<String> = sessions
2042            .iter()
2043            .filter_map(|(t, s)| {
2044                if s.user_id == user_id {
2045                    Some(t.clone())
2046                } else {
2047                    None
2048                }
2049            })
2050            .collect();
2051        let n = tokens.len();
2052        for t in &tokens {
2053            sessions.remove(t);
2054            if let Some(b) = &self.backend {
2055                b.remove(t);
2056            }
2057        }
2058        n
2059    }
2060
2061    /// Sweep expired sessions. Returns the count removed.
2062    pub fn sweep_expired(&self) -> usize {
2063        let mut sessions = self.sessions.lock().unwrap();
2064        let expired: Vec<String> = sessions
2065            .iter()
2066            .filter_map(|(t, s)| {
2067                if s.is_expired() {
2068                    Some(t.clone())
2069                } else {
2070                    None
2071                }
2072            })
2073            .collect();
2074        let n = expired.len();
2075        for t in &expired {
2076            sessions.remove(t);
2077            if let Some(b) = &self.backend {
2078                b.remove(t);
2079            }
2080        }
2081        n
2082    }
2083
2084    /// Attach a device label to a session (typically on login from a browser).
2085    pub fn set_device(&self, token: &str, device: String) -> bool {
2086        let mut sessions = self.sessions.lock().unwrap();
2087        if let Some(s) = sessions.get_mut(token) {
2088            s.device = Some(device);
2089            if let Some(b) = &self.backend {
2090                b.save(s);
2091            }
2092            true
2093        } else {
2094            false
2095        }
2096    }
2097
2098    /// Create a guest session with a generated anonymous ID.
2099    pub fn create_guest(&self) -> Session {
2100        use rand::Rng;
2101        let mut rng = rand::thread_rng();
2102        let bytes: [u8; 16] = rng.gen();
2103        let guest_id = format!("guest_{}", hex_encode(&bytes));
2104        self.create(guest_id)
2105    }
2106
2107    /// Upgrade a guest session to a real user. Replaces the user_id.
2108    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
2109        let mut sessions = self.sessions.lock().unwrap();
2110        if let Some(session) = sessions.get_mut(token) {
2111            session.user_id = real_user_id;
2112            if let Some(b) = &self.backend {
2113                b.save(session);
2114            }
2115            true
2116        } else {
2117            false
2118        }
2119    }
2120
2121    /// Switch the session's active tenant (organization). `None` clears it.
2122    /// Callers should verify the user actually has membership in the target
2123    /// tenant BEFORE invoking this — the session store takes the value on
2124    /// trust. Returns true if the session exists, false otherwise.
2125    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
2126        let mut sessions = self.sessions.lock().unwrap();
2127        if let Some(session) = sessions.get_mut(token) {
2128            session.tenant_id = tenant_id;
2129            if let Some(b) = &self.backend {
2130                b.save(session);
2131            }
2132            true
2133        } else {
2134            false
2135        }
2136    }
2137
2138    /// Remove a session.
2139    pub fn revoke(&self, token: &str) -> bool {
2140        let mut sessions = self.sessions.lock().unwrap();
2141        let removed = sessions.remove(token).is_some();
2142        if removed {
2143            if let Some(b) = &self.backend {
2144                b.remove(token);
2145            }
2146        }
2147        removed
2148    }
2149}
2150
2151// ---------------------------------------------------------------------------
2152// OAuth account links — better-auth's `account` table equivalent
2153// ---------------------------------------------------------------------------
2154
/// A persisted account link. Schema-aligned with better-auth's `account`
/// table (verified against https://www.better-auth.com/docs/concepts/database
/// at the time of writing) so users migrating from better-auth see the
/// same field names + meanings:
///
/// - `provider_id` — the provider's name (`"google"`, `"github"`, plus
///   `"credential"` once email/password auth lands). Matches
///   better-auth's `providerId`.
/// - `account_id` — the PROVIDER'S ID for the user (Google `sub`,
///   GitHub numeric `id`, or for email/password the user's own id).
///   Matches better-auth's `accountId`. NOT the row PK.
/// - `id` — the row PK, generated. Lets the row be referenced
///   independently of the (provider_id, account_id) natural key.
/// - `password` — bcrypt/argon2 hash for `provider_id="credential"`
///   rows; `None` for OAuth links. Reserves the column so adding
///   email/password auth doesn't need a schema migration.
///
/// Account vs. user: a single User row can have many Account rows
/// (Google + GitHub + a password — all linked to one pylon user).
/// Provider lookup is by `(provider_id, account_id)` — NOT email — so
/// a user changing their Google address keeps the same pylon account.
///
/// NOTE(review): the derived `Debug` output includes the token and
/// password fields verbatim — avoid logging whole `Account` values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Account {
    /// Generated row id (primary key); not the provider's id.
    pub id: String,
    /// Id of the pylon user this link belongs to.
    pub user_id: String,
    /// Provider name — `"google"`, `"github"`, `"credential"`, etc.
    /// (better-auth: `providerId`)
    pub provider_id: String,
    /// Provider's id for the user — Google `sub`, GitHub numeric `id`,
    /// or for `provider_id="credential"` the user's own id. (better-auth: `accountId`)
    pub account_id: String,
    /// Access token from the provider's token response, if any.
    pub access_token: Option<String>,
    /// Refresh token from the provider's token response, if any.
    pub refresh_token: Option<String>,
    /// ID token from the provider's token response, if any.
    pub id_token: Option<String>,
    /// Unix epoch seconds at which `access_token` expires. `None` for
    /// non-expiring tokens (GitHub Classic apps) or for password rows.
    pub access_token_expires_at: Option<u64>,
    /// Unix epoch seconds at which `refresh_token` expires. `None` when
    /// the provider doesn't expire refresh tokens (most don't, but
    /// Microsoft Identity Platform does after 90 days of inactivity).
    pub refresh_token_expires_at: Option<u64>,
    /// Scope value from the provider's token response, if any.
    pub scope: Option<String>,
    /// Bcrypt/argon2 hash for email/password rows. `None` for OAuth.
    /// Always `None` today — present so adding password auth later
    /// doesn't require a schema migration.
    pub password: Option<String>,
    /// Unix epoch seconds when this account was first linked.
    pub created_at: u64,
    /// Unix epoch seconds when the token bundle was last refreshed.
    pub updated_at: u64,
}
2206
2207impl Account {
2208    /// Build a new account link from a freshly-completed OAuth handshake.
2209    /// Generates a fresh row id; the `(provider_id, account_id)` pair is
2210    /// what later lookups key on.
2211    pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
2212        let now = now_secs();
2213        Self {
2214            id: generate_token(),
2215            user_id,
2216            provider_id: info.provider.clone(),
2217            account_id: info.provider_account_id.clone(),
2218            access_token: Some(tokens.access_token.clone()),
2219            refresh_token: tokens.refresh_token.clone(),
2220            id_token: tokens.id_token.clone(),
2221            access_token_expires_at: tokens.expires_at,
2222            refresh_token_expires_at: None,
2223            scope: tokens.scope.clone(),
2224            password: None,
2225            created_at: now,
2226            updated_at: now,
2227        }
2228    }
2229
2230    /// True if `access_token_expires_at` is set and has passed.
2231    /// Non-expiring tokens (GitHub Classic) report `false` — caller
2232    /// should treat them as "valid until proven otherwise" and refresh
2233    /// on 401.
2234    pub fn access_token_expired(&self) -> bool {
2235        match self.access_token_expires_at {
2236            Some(ts) => now_secs() >= ts,
2237            None => false,
2238        }
2239    }
2240}
2241
2242/// Pluggable storage for account links. In-memory default ships with
2243/// the crate; SQLite + Postgres impls live in `pylon-runtime`.
2244pub trait AccountBackend: Send + Sync {
2245    /// Insert or refresh an account link. The `(provider_id, account_id)`
2246    /// pair is the natural key — repeated calls for the same pair
2247    /// update the token bundle and `updated_at` on the existing row.
2248    fn upsert(&self, account: &Account);
2249    /// Find an account by provider identity. Returns `None` if the user
2250    /// hasn't linked this provider yet.
2251    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
2252    /// Every account linked to a user. The `/api/auth/me` endpoint uses
2253    /// this to render "you're connected via Google + GitHub" in the UI
2254    /// and to gate "unlink" affordances behind "user has another way to
2255    /// sign in" checks.
2256    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
2257    /// Remove a single provider link. Returns `true` if a row was removed.
2258    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
2259    /// Remove every account link for a user. Used during account
2260    /// deletion to ensure no OAuth references survive past a user row
2261    /// delete. Default implementation walks `find_for_user` + `unlink`;
2262    /// SQL backends can override with a single DELETE.
2263    fn delete_for_user(&self, user_id: &str) -> usize {
2264        let accounts = self.find_for_user(user_id);
2265        let n = accounts.len();
2266        for a in accounts {
2267            self.unlink(&a.provider_id, &a.account_id);
2268        }
2269        n
2270    }
2271    /// Every account in the store. Used by `AccountStore::list_all_unfiltered`
2272    /// to power the Studio admin inspector. Backends that can stream
2273    /// (SQLite, Postgres) just `SELECT *`; the in-memory backend
2274    /// returns its full map.
2275    fn list_all(&self) -> Vec<Account>;
2276}
2277
/// In-memory account backend (default). Lost on restart — production
/// deployments should swap in a persistent backend so refresh tokens
/// survive a redeploy.
pub struct InMemoryAccountBackend {
    /// Keyed by `(provider_id, account_id)`. A separate map keyed on
    /// user_id would speed up `find_for_user`, which currently scans
    /// every stored account; at framework scale that linear scan is
    /// considered acceptable.
    accounts: Mutex<HashMap<(String, String), Account>>,
}
2287
2288impl InMemoryAccountBackend {
2289    pub fn new() -> Self {
2290        Self {
2291            accounts: Mutex::new(HashMap::new()),
2292        }
2293    }
2294}
2295
2296impl Default for InMemoryAccountBackend {
2297    fn default() -> Self {
2298        Self::new()
2299    }
2300}
2301
2302impl AccountBackend for InMemoryAccountBackend {
2303    fn upsert(&self, account: &Account) {
2304        let key = (account.provider_id.clone(), account.account_id.clone());
2305        self.accounts.lock().unwrap().insert(key, account.clone());
2306    }
2307    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2308        self.accounts
2309            .lock()
2310            .unwrap()
2311            .get(&(provider_id.to_string(), account_id.to_string()))
2312            .cloned()
2313    }
2314    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2315        self.accounts
2316            .lock()
2317            .unwrap()
2318            .values()
2319            .filter(|a| a.user_id == user_id)
2320            .cloned()
2321            .collect()
2322    }
2323    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2324        self.accounts
2325            .lock()
2326            .unwrap()
2327            .remove(&(provider_id.to_string(), account_id.to_string()))
2328            .is_some()
2329    }
2330    fn list_all(&self) -> Vec<Account> {
2331        self.accounts.lock().unwrap().values().cloned().collect()
2332    }
2333}
2334
2335/// Account store. Wraps an `AccountBackend` and provides the methods the
2336/// OAuth callback / API endpoints actually call.
2337pub struct AccountStore {
2338    backend: Box<dyn AccountBackend>,
2339}
2340
2341impl Default for AccountStore {
2342    fn default() -> Self {
2343        Self::new()
2344    }
2345}
2346
2347impl AccountStore {
2348    pub fn new() -> Self {
2349        Self {
2350            backend: Box::new(InMemoryAccountBackend::new()),
2351        }
2352    }
2353    pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
2354        Self { backend }
2355    }
2356    pub fn upsert(&self, account: &Account) {
2357        self.backend.upsert(account);
2358    }
2359    pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2360        self.backend.find_by_provider(provider_id, account_id)
2361    }
2362    pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2363        self.backend.find_for_user(user_id)
2364    }
2365    pub fn delete_for_user(&self, user_id: &str) -> usize {
2366        self.backend.delete_for_user(user_id)
2367    }
2368
2369    pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2370        self.backend.unlink(provider_id, account_id)
2371    }
2372
2373    /// Every account in the store. Powers the Studio "Auth tables"
2374    /// view; not for app use. Implemented by walking the backend's
2375    /// per-user index — doable because account counts per user are
2376    /// small (typically ≤ 5) and total account count tracks user
2377    /// count.
2378    ///
2379    /// We don't add a `list_all` method to the `AccountBackend` trait
2380    /// because the in-memory + sqlite + postgres impls would each
2381    /// need a separate implementation, and the operational use case
2382    /// (Studio inspector) is narrow enough to live behind a wrapper
2383    /// that walks the underlying store directly. For PG/SQLite that
2384    /// means a `SELECT * FROM _pylon_accounts` — which the backends
2385    /// can grow if we ever need this at scale.
2386    pub fn list_all_unfiltered(&self) -> Vec<Account> {
2387        self.backend.list_all()
2388    }
2389}
2390
2391// ---------------------------------------------------------------------------
2392// Tests
2393// ---------------------------------------------------------------------------
2394
2395#[cfg(test)]
2396mod tests {
2397    use super::*;
2398
2399    #[test]
2400    fn anonymous_context() {
2401        let ctx = AuthContext::anonymous();
2402        assert!(!ctx.is_authenticated());
2403        assert!(ctx.user_id.is_none());
2404    }
2405
2406    #[test]
2407    fn authenticated_context() {
2408        let ctx = AuthContext::authenticated("user-1".into());
2409        assert!(ctx.is_authenticated());
2410        assert_eq!(ctx.user_id, Some("user-1".into()));
2411    }
2412
2413    #[test]
2414    fn from_api_key_carries_scope_metadata() {
2415        let ctx = AuthContext::from_api_key(
2416            "user-1".into(),
2417            "key_abc".into(),
2418            Some("read,write".into()),
2419        );
2420        assert!(ctx.is_authenticated());
2421        assert!(ctx.is_api_key_auth());
2422        assert_eq!(ctx.user_id.as_deref(), Some("user-1"));
2423        assert_eq!(ctx.api_key_id.as_deref(), Some("key_abc"));
2424        assert_eq!(ctx.api_key_scopes.as_deref(), Some("read,write"));
2425    }
2426
2427    #[test]
2428    fn session_auth_is_not_api_key_auth() {
2429        let ctx = AuthContext::authenticated("user-1".into());
2430        assert!(!ctx.is_api_key_auth());
2431        assert!(ctx.api_key_id.is_none());
2432    }
2433
2434    #[test]
2435    fn auth_mode_public_allows_anonymous() {
2436        let mode = AuthMode::Public;
2437        assert!(mode.check(&AuthContext::anonymous()));
2438        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2439    }
2440
2441    #[test]
2442    fn auth_mode_user_requires_authenticated() {
2443        let mode = AuthMode::User;
2444        assert!(!mode.check(&AuthContext::anonymous()));
2445        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2446    }
2447
2448    #[test]
2449    fn auth_mode_from_str() {
2450        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
2451        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
2452        assert_eq!(AuthMode::from_str("admin"), None);
2453    }
2454
2455    #[test]
2456    fn session_store_create_and_get() {
2457        let store = SessionStore::new();
2458        let session = store.create("user-1".into());
2459        assert!(!session.token.is_empty());
2460        assert!(session.token.starts_with("pylon_"));
2461
2462        let retrieved = store.get(&session.token).unwrap();
2463        assert_eq!(retrieved.user_id, "user-1");
2464    }
2465
2466    #[test]
2467    fn session_store_resolve() {
2468        let store = SessionStore::new();
2469        let session = store.create("user-1".into());
2470
2471        let ctx = store.resolve(Some(&session.token));
2472        assert!(ctx.is_authenticated());
2473        assert_eq!(ctx.user_id, Some("user-1".into()));
2474
2475        let anon = store.resolve(None);
2476        assert!(!anon.is_authenticated());
2477
2478        let bad = store.resolve(Some("invalid-token"));
2479        assert!(!bad.is_authenticated());
2480    }
2481
2482    #[test]
2483    fn session_store_revoke() {
2484        let store = SessionStore::new();
2485        let session = store.create("user-1".into());
2486
2487        assert!(store.revoke(&session.token));
2488        assert!(store.get(&session.token).is_none());
2489        assert!(!store.revoke(&session.token)); // already revoked
2490    }
2491
2492    #[test]
2493    fn session_to_auth_context() {
2494        let session = Session::new("user-42".into());
2495        let ctx = session.to_auth_context();
2496        assert_eq!(ctx.user_id, Some("user-42".into()));
2497    }
2498
2499    // -- Admin context --
2500
2501    #[test]
2502    fn admin_context() {
2503        let ctx = AuthContext::admin();
2504        assert!(ctx.is_admin);
2505        assert!(ctx.is_authenticated());
2506    }
2507
2508    #[test]
2509    fn anonymous_not_admin() {
2510        let ctx = AuthContext::anonymous();
2511        assert!(!ctx.is_admin);
2512    }
2513
2514    #[test]
2515    fn authenticated_not_admin() {
2516        let ctx = AuthContext::authenticated("user-1".into());
2517        assert!(!ctx.is_admin);
2518    }
2519
2520    // -- Magic codes --
2521
2522    #[test]
2523    fn magic_code_create_and_verify() {
2524        let store = MagicCodeStore::new();
2525        let code = store.create("test@example.com");
2526        assert_eq!(code.len(), 6);
2527        assert!(store.verify("test@example.com", &code));
2528    }
2529
2530    #[test]
2531    fn magic_code_wrong_code_rejected() {
2532        let store = MagicCodeStore::new();
2533        store.create("test@example.com");
2534        assert!(!store.verify("test@example.com", "000000"));
2535    }
2536
2537    #[test]
2538    fn magic_code_wrong_email_rejected() {
2539        let store = MagicCodeStore::new();
2540        let code = store.create("test@example.com");
2541        assert!(!store.verify("other@example.com", &code));
2542    }
2543
2544    #[test]
2545    fn magic_code_consumed_after_verify() {
2546        let store = MagicCodeStore::new();
2547        let code = store.create("test@example.com");
2548        assert!(store.verify("test@example.com", &code));
2549        // Second verify should fail — code consumed.
2550        assert!(!store.verify("test@example.com", &code));
2551    }
2552
2553    #[test]
2554    fn magic_code_different_emails_independent() {
2555        let store = MagicCodeStore::new();
2556        let code1 = store.create("alice@example.com");
2557        let code2 = store.create("bob@example.com");
2558        // Each email has its own code.
2559        assert!(store.verify("alice@example.com", &code1));
2560        assert!(store.verify("bob@example.com", &code2));
2561    }
2562
2563    // -- Constant-time comparison --
2564
2565    #[test]
2566    fn constant_time_eq_equal() {
2567        assert!(constant_time_eq(b"hello", b"hello"));
2568        assert!(constant_time_eq(b"", b""));
2569    }
2570
2571    #[test]
2572    fn constant_time_eq_not_equal() {
2573        assert!(!constant_time_eq(b"hello", b"world"));
2574        assert!(!constant_time_eq(b"hello", b"hell"));
2575        assert!(!constant_time_eq(b"a", b"b"));
2576    }
2577
2578    // -- Token generation --
2579
2580    #[test]
2581    fn generated_tokens_are_unique() {
2582        let t1 = generate_token();
2583        let t2 = generate_token();
2584        assert_ne!(t1, t2);
2585        assert!(t1.starts_with("pylon_"));
2586        assert!(t2.starts_with("pylon_"));
2587        // 256 bits = 64 hex chars + "pylon_" prefix (6 chars)
2588        assert_eq!(t1.len(), 6 + 64);
2589    }
2590
2591    // -- OAuth registry --
2592
2593    #[test]
2594    fn oauth_registry_empty() {
2595        let reg = OAuthRegistry::new();
2596        assert!(reg.get("google").is_none());
2597    }
2598
2599    #[test]
2600    fn oauth_registry_register_and_get() {
2601        let mut reg = OAuthRegistry::new();
2602        reg.register(OAuthConfig {
2603            provider: "google".into(),
2604            client_id: "test-id".into(),
2605            client_secret: "test-secret".into(),
2606            redirect_uri: "http://localhost/callback".into(),
2607            ..Default::default()
2608        });
2609        let config = reg.get("google").unwrap();
2610        assert_eq!(config.client_id, "test-id");
2611        assert!(config.auth_url().contains("accounts.google.com"));
2612    }
2613
2614    // -- Spec-driven provider routing --
2615
2616    /// Every builtin provider must produce a non-empty auth_url +
2617    /// token_url when wired with placeholder credentials. This is the
2618    /// regression test for the table-driven refactor: a typo in any
2619    /// `ProviderSpec` field that breaks URL formatting will trip here
2620    /// before it reaches a user.
2621    #[test]
2622    fn every_builtin_provider_routes_through_oauth_config() {
2623        for spec in provider::builtin::all() {
2624            let cfg = OAuthConfig {
2625                provider: spec.id.into(),
2626                client_id: "cid".into(),
2627                client_secret: "csecret".into(),
2628                redirect_uri: "https://app/cb".into(),
2629                tenant: if spec.id == "microsoft" {
2630                    Some("contoso".into())
2631                } else {
2632                    None
2633                },
2634                apple: if spec.id == "apple" {
2635                    Some(provider::AppleConfig {
2636                        team_id: "T".into(),
2637                        key_id: "K".into(),
2638                        private_key_pem: "no".into(),
2639                    })
2640                } else {
2641                    None
2642                },
2643                ..Default::default()
2644            };
2645            let auth = cfg.auth_url();
2646            assert!(!auth.is_empty(), "{}: empty auth_url", spec.id);
2647            // TikTok uses `client_key`; everyone else uses `client_id`.
2648            let expected_param = format!("{}=cid", spec.client_id_param);
2649            assert!(
2650                auth.contains(&expected_param),
2651                "{}: missing {}; got auth_url: {}",
2652                spec.id,
2653                expected_param,
2654                auth,
2655            );
2656            assert!(!cfg.token_url().is_empty(), "{}: empty token_url", spec.id);
2657            // Apple requires response_mode=form_post in the auth URL.
2658            if spec.id == "apple" {
2659                assert!(
2660                    auth.contains("response_mode=form_post"),
2661                    "apple auth_url must include response_mode=form_post; got {auth}"
2662                );
2663            }
2664        }
2665    }
2666
2667    /// Microsoft uses `{tenant}` placeholder substitution — the
2668    /// configured tenant must end up in both auth + token URLs.
2669    #[test]
2670    fn microsoft_tenant_placeholder_resolves() {
2671        let cfg = OAuthConfig {
2672            provider: "microsoft".into(),
2673            client_id: "id".into(),
2674            client_secret: "secret".into(),
2675            redirect_uri: "https://app/cb".into(),
2676            tenant: Some("contoso.onmicrosoft.com".into()),
2677            ..Default::default()
2678        };
2679        assert!(cfg.auth_url().contains("/contoso.onmicrosoft.com/"));
2680        assert!(cfg.token_url().contains("/contoso.onmicrosoft.com/"));
2681    }
2682
2683    /// Microsoft without a tenant defaults to `common` (any account).
2684    #[test]
2685    fn microsoft_default_tenant_common() {
2686        let cfg = OAuthConfig {
2687            provider: "microsoft".into(),
2688            client_id: "id".into(),
2689            client_secret: "secret".into(),
2690            redirect_uri: "https://app/cb".into(),
2691            ..Default::default()
2692        };
2693        assert!(cfg.auth_url().contains("/common/"));
2694        assert!(cfg.token_url().contains("/common/"));
2695    }
2696
2697    /// `scopes_override` replaces the spec default — used for GitHub
2698    /// `repo` scope or Google calendar scopes.
2699    #[test]
2700    fn scopes_override_replaces_spec_default() {
2701        let cfg = OAuthConfig {
2702            provider: "github".into(),
2703            client_id: "id".into(),
2704            client_secret: "secret".into(),
2705            redirect_uri: "https://app/cb".into(),
2706            scopes_override: Some("repo user:email".into()),
2707            ..Default::default()
2708        };
2709        let auth = cfg.auth_url();
2710        // url-encoded "repo user:email" → "repo%20user%3Aemail"
2711        assert!(auth.contains("scope=repo%20user%3Aemail"), "got: {auth}");
2712    }
2713
2714    /// Apple's `client_secret` is minted as a JWT — passing a bad PEM
2715    /// must surface the signing error, not silently send the literal
2716    /// string. The mint path is tested in `apple_jwt::tests`; this
2717    /// asserts the wiring delegates to it.
2718    #[test]
2719    fn apple_exchange_requires_apple_config() {
2720        let cfg = OAuthConfig {
2721            provider: "apple".into(),
2722            client_id: "com.example.app".into(),
2723            client_secret: String::new(),
2724            redirect_uri: "https://app/cb".into(),
2725            apple: None, // missing!
2726            ..Default::default()
2727        };
2728        let err = cfg.exchange_code_full("x").unwrap_err();
2729        assert!(err.contains("apple provider requires"), "got: {err}");
2730    }
2731
2732    /// OIDC discovery cache: priming with a synthetic spec lets us
2733    /// route an issuer-configured provider without touching the
2734    /// network. Validates that `oidc_issuer` short-circuits the
2735    /// builtin lookup.
2736    #[test]
2737    fn oidc_issuer_uses_discovered_endpoints() {
2738        let issuer = "https://acme.test.invalid";
2739        provider::oidc_cache::insert_for_test(
2740            issuer,
2741            provider::DiscoveredSpec {
2742                auth_url: "https://acme.test.invalid/authorize".into(),
2743                token_url: "https://acme.test.invalid/oauth/token".into(),
2744                userinfo_url: Some("https://acme.test.invalid/userinfo".into()),
2745                scopes: "openid email profile".into(),
2746                userinfo_parser: provider::UserinfoParser::Oidc,
2747                token_exchange: provider::TokenExchangeShape::Standard,
2748            },
2749        );
2750        let cfg = OAuthConfig {
2751            provider: "auth0".into(), // not a builtin id
2752            client_id: "id".into(),
2753            client_secret: "secret".into(),
2754            redirect_uri: "https://app/cb".into(),
2755            oidc_issuer: Some(issuer.into()),
2756            ..Default::default()
2757        };
2758        assert!(cfg.auth_url().starts_with("https://acme.test.invalid/authorize?"));
2759        assert_eq!(cfg.token_url(), "https://acme.test.invalid/oauth/token");
2760        assert_eq!(cfg.userinfo_url(), "https://acme.test.invalid/userinfo");
2761    }
2762
2763    // -- Codex review regression tests (P1/P2 from Wave 1 review) --
2764
2765    /// P1: Apple's auth URL MUST include response_mode=form_post when
2766    /// requesting name/email scopes, otherwise Apple rejects with
2767    /// "invalid_request".
2768    #[test]
2769    fn apple_auth_url_includes_form_post() {
2770        let cfg = OAuthConfig {
2771            provider: "apple".into(),
2772            client_id: "com.example.app".into(),
2773            client_secret: String::new(),
2774            redirect_uri: "https://app/cb".into(),
2775            apple: Some(provider::AppleConfig {
2776                team_id: "T".into(),
2777                key_id: "K".into(),
2778                private_key_pem: "no".into(),
2779            }),
2780            ..Default::default()
2781        };
2782        let auth = cfg.auth_url();
2783        assert!(auth.contains("response_mode=form_post"), "got: {auth}");
2784        // Apple identity comes from id_token, so userinfo_url is empty.
2785        assert_eq!(cfg.userinfo_url(), "");
2786    }
2787
2788    /// P1: Apple identity is extracted from the id_token JWT
2789    /// (Apple has no userinfo endpoint). `fetch_userinfo_with_id_token`
2790    /// must decode the claims; `fetch_userinfo_full` (no id_token)
2791    /// must surface a clear error.
2792    #[test]
2793    fn apple_id_token_decode_extracts_identity() {
2794        // Synthesize an unsigned JWT with realistic Apple claims.
2795        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"{\"alg\":\"none\"}");
2796        use base64::Engine;
2797        let claims = serde_json::json!({
2798            "iss": "https://appleid.apple.com",
2799            "sub": "001234.abc.def",
2800            "aud": "com.example.app",
2801            "email": "user@privaterelay.appleid.com",
2802            "email_verified": "true",
2803        });
2804        let claims_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
2805            .encode(claims.to_string().as_bytes());
2806        let id_token = format!("{header}.{claims_b64}.signature_ignored");
2807
2808        let cfg = OAuthConfig {
2809            provider: "apple".into(),
2810            client_id: "com.example.app".into(),
2811            client_secret: String::new(),
2812            redirect_uri: "https://app/cb".into(),
2813            apple: Some(provider::AppleConfig {
2814                team_id: "T".into(),
2815                key_id: "K".into(),
2816                private_key_pem: "no".into(),
2817            }),
2818            ..Default::default()
2819        };
2820        let info = cfg
2821            .fetch_userinfo_with_id_token("ignored", Some(&id_token))
2822            .expect("apple id_token decode");
2823        assert_eq!(info.provider_account_id, "001234.abc.def");
2824        assert_eq!(info.email, "user@privaterelay.appleid.com");
2825
2826        // Without an id_token the call must fail loud, not silently
2827        // try to hit a non-existent userinfo endpoint.
2828        let err = cfg.fetch_userinfo_full("token").unwrap_err();
2829        assert!(err.contains("apple login requires"), "got: {err}");
2830    }
2831
2832    /// P1: Twitter/X requires PKCE — `auth_url_with_pkce` must mint a
2833    /// verifier, embed the SHA-256 challenge in the auth URL, and
2834    /// return the verifier for the callback to replay.
2835    #[test]
2836    fn twitter_auth_url_includes_pkce() {
2837        let cfg = OAuthConfig {
2838            provider: "twitter".into(),
2839            client_id: "tw_client".into(),
2840            client_secret: "tw_secret".into(),
2841            redirect_uri: "https://app/cb".into(),
2842            ..Default::default()
2843        };
2844        let (url, verifier) = cfg.auth_url_with_pkce("state123").expect("twitter pkce");
2845        let v = verifier.expect("twitter must produce verifier");
2846        assert!(v.len() >= 43, "PKCE verifier must be 43+ chars: got {v}");
2847        assert!(url.contains("code_challenge="), "got: {url}");
2848        assert!(url.contains("code_challenge_method=S256"), "got: {url}");
2849
2850        // Non-PKCE provider must NOT add a code_challenge.
2851        let google = OAuthConfig {
2852            provider: "google".into(),
2853            client_id: "g".into(),
2854            client_secret: "g".into(),
2855            redirect_uri: "https://app/cb".into(),
2856            ..Default::default()
2857        };
2858        let (gurl, gverifier) = google.auth_url_with_pkce("st").expect("google");
2859        assert!(gverifier.is_none(), "google should not add PKCE");
2860        assert!(!gurl.contains("code_challenge"), "got: {gurl}");
2861    }
2862
2863    /// P2: TikTok uses `client_key` (not `client_id`) and joins
2864    /// scopes with commas (not spaces).
2865    #[test]
2866    fn tiktok_uses_client_key_and_comma_scopes() {
2867        let cfg = OAuthConfig {
2868            provider: "tiktok".into(),
2869            client_id: "tk_client".into(),
2870            client_secret: "tk_secret".into(),
2871            redirect_uri: "https://app/cb".into(),
2872            scopes_override: Some("user.info.basic video.list".into()),
2873            ..Default::default()
2874        };
2875        let auth = cfg.auth_url();
2876        assert!(auth.contains("client_key=tk_client"), "got: {auth}");
2877        // Comma-separated, url-encoded → "user.info.basic%2Cvideo.list"
2878        assert!(auth.contains("user.info.basic%2Cvideo.list"), "got: {auth}");
2879        // Should NOT use the standard space separator.
2880        assert!(!auth.contains("user.info.basic%20video.list"), "got: {auth}");
2881    }
2882
2883    /// P2: `code` MUST be url-encoded in the token-exchange body.
2884    /// Auth codes can contain reserved characters (`+`, `=`, `/`) that
2885    /// would otherwise corrupt the form body.
2886    #[test]
2887    fn token_exchange_url_encodes_code() {
2888        // We can't hit the network in a unit test, so this asserts
2889        // via the `apple_exchange_requires_apple_config` shape — if
2890        // we DID have a working apple config, encoding would happen
2891        // before the network call. Instead, verify by calling the
2892        // helper used internally:
2893        let raw = "code+with/special=chars";
2894        let encoded = url_encode(raw);
2895        assert!(!encoded.contains('+'));
2896        assert!(!encoded.contains('/'));
2897        assert!(!encoded.contains('='));
2898        assert!(encoded.contains("%2B"));
2899        assert!(encoded.contains("%2F"));
2900        assert!(encoded.contains("%3D"));
2901    }
2902
2903    /// P1: Token-endpoint error bodies must NOT propagate
2904    /// `client_secret`, `code_verifier`, or other sensitive form
2905    /// fields that providers sometimes echo back on auth failure.
2906    #[test]
2907    fn sanitize_token_error_redacts_secrets() {
2908        let raw = "HTTP 400: error=invalid_grant&client_secret=sk_real_secret_value&code_verifier=verifierxyz&hint=check%20your%20code";
2909        let scrubbed = sanitize_token_error(raw.into());
2910        assert!(!scrubbed.contains("sk_real_secret_value"));
2911        assert!(!scrubbed.contains("verifierxyz"));
2912        assert!(scrubbed.contains("client_secret=***"));
2913        assert!(scrubbed.contains("code_verifier=***"));
2914        // Non-sensitive context preserved.
2915        assert!(scrubbed.contains("invalid_grant"));
2916        assert!(scrubbed.contains("hint=check%20your%20code"));
2917    }
2918
2919    /// P1 (codex round-2): JSON-shaped error bodies (Notion,
2920    /// Atlassian) must also have their secret fields redacted.
2921    #[test]
2922    fn sanitize_token_error_redacts_json_secrets() {
2923        let raw = r#"HTTP 400: {"error":"invalid_grant","client_secret":"sk_jsonleak","refresh_token":"rt_abcxyz","id_token":"ey.payload.sig"}"#;
2924        let scrubbed = sanitize_token_error(raw.into());
2925        assert!(!scrubbed.contains("sk_jsonleak"), "got: {scrubbed}");
2926        assert!(!scrubbed.contains("rt_abcxyz"), "got: {scrubbed}");
2927        assert!(!scrubbed.contains("ey.payload.sig"), "got: {scrubbed}");
2928        assert!(scrubbed.contains(r#""client_secret":"***""#), "got: {scrubbed}");
2929        assert!(scrubbed.contains(r#""refresh_token":"***""#), "got: {scrubbed}");
2930        assert!(scrubbed.contains(r#""id_token":"***""#), "got: {scrubbed}");
2931        assert!(scrubbed.contains("invalid_grant"));
2932    }
2933
2934    /// P2 (codex round-2): redact_param_form must NOT panic on
2935    /// multibyte chars before the sensitive key. Earlier byte-index
2936    /// implementation hit `panicked at byte index N is not a char
2937    /// boundary` on bodies with emoji or non-ASCII text.
2938    #[test]
2939    fn sanitize_token_error_handles_utf8() {
2940        let raw = "HTTP 400: ⚠️ provider says the secret is wrong: client_secret=sk_x";
2941        let scrubbed = sanitize_token_error(raw.into());
2942        assert!(scrubbed.contains("⚠️"), "non-ASCII chars must survive: {scrubbed}");
2943        assert!(!scrubbed.contains("sk_x"));
2944        assert!(scrubbed.contains("client_secret=***"));
2945    }
2946
2947    /// P2: OIDC discovery must respect
2948    /// `token_endpoint_auth_methods_supported`. When the IdP
2949    /// publishes `client_secret_post`, use Standard form bodies.
2950    /// When omitted (the spec default), use BasicAuth.
2951    #[test]
2952    fn oidc_discovery_picks_token_auth_method() {
2953        let json_post = r#"{
2954            "issuer": "https://acme.test/",
2955            "authorization_endpoint": "https://acme.test/auth",
2956            "token_endpoint": "https://acme.test/token",
2957            "token_endpoint_auth_methods_supported": ["client_secret_post"]
2958        }"#;
2959        let spec = provider::OidcDiscoveryDoc::parse(json_post).unwrap().into_spec();
2960        assert!(matches!(
2961            spec.token_exchange,
2962            provider::TokenExchangeShape::Standard
2963        ));
2964
2965        // Default (omitted) → BasicAuth.
2966        let json_default = r#"{
2967            "issuer": "https://acme.test/",
2968            "authorization_endpoint": "https://acme.test/auth",
2969            "token_endpoint": "https://acme.test/token"
2970        }"#;
2971        let spec = provider::OidcDiscoveryDoc::parse(json_default)
2972            .unwrap()
2973            .into_spec();
2974        assert!(matches!(
2975            spec.token_exchange,
2976            provider::TokenExchangeShape::BasicAuth
2977        ));
2978    }
2979
2980    /// P2: OIDC discovery missing required endpoints must fail loud,
2981    /// not silently produce empty URLs that would 404 every login.
2982    #[test]
2983    fn oidc_discovery_rejects_incomplete_doc() {
2984        // Missing token_endpoint.
2985        let json = r#"{
2986            "issuer": "https://acme.test/",
2987            "authorization_endpoint": "https://acme.test/auth"
2988        }"#;
2989        let err = provider::OidcDiscoveryDoc::parse(json).unwrap_err();
2990        assert!(err.contains("token_endpoint"), "got: {err}");
2991    }
2992
2993    /// `OAuthRegistry::from_env` must auto-discover every provider
2994    /// whose env vars are set — not just google/github. Smoke-test
2995    /// with Discord since it covers the simple-builtin path.
2996    #[test]
2997    fn from_env_picks_up_discord() {
2998        // Use a unique prefix so this doesn't collide with a real
2999        // dev environment variable. Set+restore in scope.
3000        let key_id = "PYLON_OAUTH_DISCORD_CLIENT_ID";
3001        let key_secret = "PYLON_OAUTH_DISCORD_CLIENT_SECRET";
3002        // SAFETY: tests run single-threaded for env mutation isn't
3003        // strictly true, but this provider is unique enough that
3004        // contention is unlikely. Cleanup happens at end.
3005        std::env::set_var(key_id, "discord-test-id");
3006        std::env::set_var(key_secret, "discord-test-secret");
3007
3008        let reg = OAuthRegistry::from_env();
3009        let discord = reg.get("discord").expect("discord registered");
3010        assert_eq!(discord.client_id, "discord-test-id");
3011        assert!(discord.auth_url().contains("discord.com"));
3012
3013        std::env::remove_var(key_id);
3014        std::env::remove_var(key_secret);
3015    }
3016
3017    // -- Guest auth --
3018
3019    #[test]
3020    fn guest_session() {
3021        let store = SessionStore::new();
3022        let session = store.create_guest();
3023        assert!(session.user_id.starts_with("guest_"));
3024        assert!(!session.token.is_empty());
3025
3026        let ctx = store.resolve(Some(&session.token));
3027        assert!(ctx.is_authenticated());
3028        assert!(ctx.user_id.unwrap().starts_with("guest_"));
3029    }
3030
3031    #[test]
3032    fn upgrade_guest_to_real_user() {
3033        let store = SessionStore::new();
3034        let session = store.create_guest();
3035        assert!(session.user_id.starts_with("guest_"));
3036
3037        let upgraded = store.upgrade(&session.token, "real-user-123".into());
3038        assert!(upgraded);
3039
3040        let ctx = store.resolve(Some(&session.token));
3041        assert_eq!(ctx.user_id, Some("real-user-123".into()));
3042    }
3043
3044    #[test]
3045    fn upgrade_invalid_token_fails() {
3046        let store = SessionStore::new();
3047        let upgraded = store.upgrade("nonexistent-token", "user".into());
3048        assert!(!upgraded);
3049    }
3050
3051    #[test]
3052    fn guest_context() {
3053        let ctx = AuthContext::guest("guest_123".into());
3054        // Guests carry a stable id but are NOT authenticated — routes
3055        // guarded by AuthMode::User must reject them.
3056        assert!(!ctx.is_authenticated());
3057        assert!(ctx.is_guest);
3058        assert!(!ctx.is_admin);
3059        assert_eq!(ctx.user_id, Some("guest_123".into()));
3060        assert!(!AuthMode::User.check(&ctx));
3061        assert!(AuthMode::Public.check(&ctx));
3062    }
3063
3064    #[test]
3065    fn oauth_token_urls() {
3066        let google = OAuthConfig {
3067            provider: "google".into(),
3068            client_id: "x".into(),
3069            client_secret: "x".into(),
3070            redirect_uri: "x".into(),
3071            ..Default::default()
3072        };
3073        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
3074        let github = OAuthConfig {
3075            provider: "github".into(),
3076            client_id: "x".into(),
3077            client_secret: "x".into(),
3078            redirect_uri: "x".into(),
3079            ..Default::default()
3080        };
3081        assert_eq!(
3082            github.token_url(),
3083            "https://github.com/login/oauth/access_token"
3084        );
3085        let unknown = OAuthConfig {
3086            provider: "unknown".into(),
3087            client_id: "x".into(),
3088            client_secret: "x".into(),
3089            redirect_uri: "x".into(),
3090            ..Default::default()
3091        };
3092        assert_eq!(unknown.token_url(), "");
3093        assert!(unknown.auth_url().is_empty());
3094    }
3095
3096    #[test]
3097    fn oauth_auth_url_github() {
3098        let config = OAuthConfig {
3099            provider: "github".into(),
3100            client_id: "gh-id".into(),
3101            client_secret: "gh-secret".into(),
3102            redirect_uri: "http://localhost/cb".into(),
3103            ..Default::default()
3104        };
3105        assert!(config.auth_url().contains("github.com"));
3106        assert!(config.auth_url().contains("gh-id"));
3107    }
3108
3109    #[test]
3110    fn oauth_auth_url_with_state() {
3111        let config = OAuthConfig {
3112            provider: "google".into(),
3113            client_id: "test-id".into(),
3114            client_secret: "test-secret".into(),
3115            redirect_uri: "http://localhost/cb".into(),
3116            ..Default::default()
3117        };
3118        let url = config.auth_url_with_state("random_state_123");
3119        assert!(url.contains("&state=random_state_123"));
3120    }
3121
3122    #[test]
3123    fn oauth_state_store_create_and_validate() {
3124        let store = OAuthStateStore::new();
3125        let token = store.create("google", "https://app/cb", "https://app/login");
3126        let rec = store.validate(&token, "google").expect("valid first time");
3127        assert_eq!(rec.callback_url, "https://app/cb");
3128        assert_eq!(rec.error_callback_url, "https://app/login");
3129        // Second validation should fail — single-use.
3130        assert!(store.validate(&token, "google").is_none());
3131    }
3132
3133    #[test]
3134    fn oauth_state_store_wrong_provider_rejected() {
3135        let store = OAuthStateStore::new();
3136        let token = store.create("google", "https://app/cb", "https://app/cb");
3137        assert!(store.validate(&token, "github").is_none());
3138    }
3139
3140    #[test]
3141    fn oauth_state_store_invalid_state_rejected() {
3142        let store = OAuthStateStore::new();
3143        assert!(store.validate("nonexistent", "google").is_none());
3144    }
3145
3146    #[test]
3147    fn validate_trusted_redirect_basics() {
3148        let trusted = vec!["http://localhost:3000".to_string()];
3149        assert!(validate_trusted_redirect("http://localhost:3000/dashboard", &trusted).is_ok());
3150        assert!(validate_trusted_redirect("http://localhost:3000", &trusted).is_ok());
3151        assert!(validate_trusted_redirect("http://localhost:3000/x?y=1", &trusted).is_ok());
3152
3153        // Wrong port → wrong origin.
3154        assert!(matches!(
3155            validate_trusted_redirect("http://localhost:4321/dashboard", &trusted),
3156            Err(TrustedOriginError::NotTrusted { .. })
3157        ));
3158        // Non-http scheme rejected even before trusted check (defense
3159        // against javascript:, file:, data:).
3160        assert!(matches!(
3161            validate_trusted_redirect("javascript:alert(1)", &trusted),
3162            Err(TrustedOriginError::NotHttp)
3163        ));
3164        assert!(matches!(
3165            validate_trusted_redirect("", &trusted),
3166            Err(TrustedOriginError::Empty)
3167        ));
3168    }
3169}