Skip to main content

pylon_auth/
lib.rs

1pub mod api_key;
2pub mod apple_jwt;
3pub mod audit;
4pub mod captcha;
5pub mod cookie;
6pub mod device;
7pub mod email;
8pub mod email_templates;
9pub mod jwt;
10pub mod oidc_provider;
11pub mod org;
12pub mod password;
13pub mod phone;
14pub mod provider;
15pub mod rate_limit;
16pub mod scim;
17pub mod siwe;
18pub mod stripe;
19pub mod totp;
20pub mod verification;
21pub mod webauthn;
22
23pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
24
25use serde::{Deserialize, Serialize};
26
27// ---------------------------------------------------------------------------
28// Auth context — the identity available to runtime operations
29// ---------------------------------------------------------------------------
30
31/// The auth context for a request. Represents who is making the request.
32///
33/// **Do NOT derive `Deserialize` on this type.** If the server ever parses an
34/// `AuthContext` from client-supplied JSON, a client can set `is_admin=true`
35/// or add roles and bypass every policy. Identity must come from
36/// server-minted sessions (`Session::to_auth_context`) or explicit
37/// constructors, never from deserialization.
38///
39/// `Serialize` is safe because sending the resolved context BACK to the
40/// client exposes nothing the server didn't already decide.
41#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
42pub struct AuthContext {
43    /// The authenticated user ID, or None for public/anonymous access.
44    /// For guest contexts this is `Some(guest_id)` — a stable
45    /// anonymous identifier, NOT a real user.
46    pub user_id: Option<String>,
47    /// Whether this is an admin context (bypasses policies).
48    pub is_admin: bool,
49    /// True for `AuthContext::guest()` — anonymous-with-stable-id, used
50    /// for cart state and similar pre-login persistence. Routes guarded
51    /// by `AuthMode::User` reject guests; only `is_authenticated()` ==
52    /// "real signed-in user" should pass auth-required gates.
53    #[serde(default, skip_serializing_if = "is_false")]
54    pub is_guest: bool,
55    /// Roles granted to this user. Empty for anonymous.
56    pub roles: Vec<String>,
57    /// Active tenant id (for multi-tenant apps). Set when the user has
58    /// selected an organization for the current session.
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub tenant_id: Option<String>,
61    /// API key id when the request was authenticated via a `pk.…`
62    /// bearer token. Set so policies + management endpoints can
63    /// distinguish "user-via-session" from "user-via-key" — e.g.
64    /// password change is forbidden via API key.
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub api_key_id: Option<String>,
67    /// Comma-separated scope string from the API key. Application
68    /// policies decide what scopes mean — pylon only carries them.
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub api_key_scopes: Option<String>,
71}
72
/// serde `skip_serializing_if` predicate: returns true when `flag` is
/// unset, so `false` booleans are left out of the serialized output.
/// Takes `&bool` because serde passes the field by reference.
fn is_false(flag: &bool) -> bool {
    !*flag
}
76
77impl AuthContext {
78    /// Create an anonymous/public auth context.
79    pub fn anonymous() -> Self {
80        Self {
81            user_id: None,
82            is_admin: false,
83            is_guest: false,
84            roles: Vec::new(),
85            tenant_id: None,
86            api_key_id: None,
87            api_key_scopes: None,
88        }
89    }
90
91    /// Create an authenticated auth context.
92    pub fn authenticated(user_id: String) -> Self {
93        Self {
94            user_id: Some(user_id),
95            is_admin: false,
96            is_guest: false,
97            roles: Vec::new(),
98            tenant_id: None,
99            api_key_id: None,
100            api_key_scopes: None,
101        }
102    }
103
104    /// Create an authenticated context backed by an API key. Policies +
105    /// auth-management endpoints can detect this via `is_api_key_auth()`.
106    pub fn from_api_key(user_id: String, key_id: String, scopes: Option<String>) -> Self {
107        Self {
108            user_id: Some(user_id),
109            is_admin: false,
110            is_guest: false,
111            roles: Vec::new(),
112            tenant_id: None,
113            api_key_id: Some(key_id),
114            api_key_scopes: scopes,
115        }
116    }
117
118    /// True iff this request was authenticated by an API key (not a
119    /// session cookie / bearer session token).
120    pub fn is_api_key_auth(&self) -> bool {
121        self.api_key_id.is_some()
122    }
123
124    /// Create a guest auth context with a persistent anonymous ID.
125    /// Guests carry an opaque stable id (cart/session continuity) but
126    /// are NOT considered authenticated — `is_authenticated()` returns
127    /// false and `AuthMode::User` rejects them.
128    pub fn guest(guest_id: String) -> Self {
129        Self {
130            user_id: Some(guest_id),
131            is_admin: false,
132            is_guest: true,
133            roles: Vec::new(),
134            tenant_id: None,
135            api_key_id: None,
136            api_key_scopes: None,
137        }
138    }
139
140    /// Create an admin auth context that bypasses all policies.
141    pub fn admin() -> Self {
142        Self {
143            user_id: Some("__admin__".into()),
144            is_admin: true,
145            is_guest: false,
146            roles: vec!["admin".into()],
147            tenant_id: None,
148            api_key_id: None,
149            api_key_scopes: None,
150        }
151    }
152
153    /// Convenience: build a user context from a user id.
154    pub fn user(user_id: String) -> Self {
155        Self::authenticated(user_id)
156    }
157
158    /// Active tenant id (None when the user hasn't selected an org).
159    pub fn tenant_id(&self) -> Option<&str> {
160        self.tenant_id.as_deref()
161    }
162
163    /// Attach a tenant id to the context (chainable).
164    pub fn with_tenant(mut self, tenant_id: String) -> Self {
165        self.tenant_id = Some(tenant_id);
166        self
167    }
168
169    /// Check if this context represents an authenticated user.
170    /// Guests intentionally return `false` — they have a stable anonymous
171    /// id but never gain user-level access.
172    pub fn is_authenticated(&self) -> bool {
173        self.user_id.is_some() && !self.is_guest
174    }
175
176    /// Check if the user has a specific role. Admins have every role implicitly.
177    pub fn has_role(&self, role: &str) -> bool {
178        self.is_admin || self.roles.iter().any(|r| r == role)
179    }
180
181    /// Check if the user has ANY of the given roles.
182    pub fn has_any_role(&self, roles: &[&str]) -> bool {
183        self.is_admin || roles.iter().any(|r| self.has_role(r))
184    }
185
186    /// Attach roles to the context (chainable).
187    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
188        self.roles = roles;
189        self
190    }
191}
192
193// ---------------------------------------------------------------------------
194// Constant-time comparison
195// ---------------------------------------------------------------------------
196
/// Compares two byte slices without stopping at the first differing byte,
/// so response timing does not reveal where a mismatch occurs.
///
/// Slices of different lengths are rejected immediately, so the *length*
/// is observable through timing. The *contents* are not: every byte pair is
/// XOR-ed and OR-folded into one accumulator before the single final check.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let diff = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
        diff == 0
    } else {
        false
    }
}
212
213// ---------------------------------------------------------------------------
214// Auth mode — matches the route "auth" field values
215// ---------------------------------------------------------------------------
216
/// Access requirement a route declares through its manifest `"auth"` field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// No authentication required: every caller is allowed through.
    Public,
    /// Caller must be a real signed-in user; anonymous callers and guests
    /// are rejected.
    User,
}
225
226impl AuthMode {
227    /// Parse from the manifest auth string.
228    #[allow(clippy::should_implement_trait)]
229    pub fn from_str(s: &str) -> Option<Self> {
230        match s {
231            "public" => Some(AuthMode::Public),
232            "user" => Some(AuthMode::User),
233            _ => None,
234        }
235    }
236
237    /// Check if the given auth context satisfies this mode.
238    pub fn check(&self, ctx: &AuthContext) -> bool {
239        match self {
240            AuthMode::Public => true,
241            AuthMode::User => ctx.is_authenticated(),
242        }
243    }
244}
245
246// ---------------------------------------------------------------------------
247// Session — opaque token session
248// ---------------------------------------------------------------------------
249
250/// A session token and its associated user.
251#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
252pub struct Session {
253    pub token: String,
254    pub user_id: String,
255    /// Unix epoch seconds at which this session expires. 0 = never.
256    #[serde(default)]
257    pub expires_at: u64,
258    /// Optional user-agent / device tag recorded at session creation.
259    #[serde(default, skip_serializing_if = "Option::is_none")]
260    pub device: Option<String>,
261    /// Unix epoch seconds when the session was created.
262    #[serde(default)]
263    pub created_at: u64,
264    /// Active tenant id (selected organization). Set via
265    /// `/api/auth/select-org`. Flows into `AuthContext.tenant_id` which
266    /// powers row-scoped policies like `data.orgId == auth.tenantId`.
267    #[serde(default, skip_serializing_if = "Option::is_none")]
268    pub tenant_id: Option<String>,
269}
270
271impl Session {
272    /// Default session lifetime: 30 days.
273    pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
274
275    /// Create a new session with a generated token and default 30-day expiry.
276    pub fn new(user_id: String) -> Self {
277        let now = now_secs();
278        Self {
279            token: generate_token(),
280            user_id,
281            expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
282            device: None,
283            created_at: now,
284            tenant_id: None,
285        }
286    }
287
288    /// Create a session with a specific lifetime.
289    pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
290        let now = now_secs();
291        Self {
292            token: generate_token(),
293            user_id,
294            expires_at: if lifetime_secs == 0 {
295                0
296            } else {
297                now.saturating_add(lifetime_secs)
298            },
299            device: None,
300            created_at: now,
301            tenant_id: None,
302        }
303    }
304
305    /// Convert this session to an auth context, carrying the selected
306    /// tenant so row-scoped policies see `auth.tenantId`.
307    pub fn to_auth_context(&self) -> AuthContext {
308        let ctx = AuthContext::authenticated(self.user_id.clone());
309        match &self.tenant_id {
310            Some(t) => ctx.with_tenant(t.clone()),
311            None => ctx,
312        }
313    }
314
315    /// Returns true if the session has passed its expires_at time.
316    /// Boundary is inclusive (`>=`) to match the rest of the codebase
317    /// (`magic_codes.expires_at <= now`, `oauth_state.expires_at <= now`).
318    pub fn is_expired(&self) -> bool {
319        self.expires_at != 0 && now_secs() >= self.expires_at
320    }
321}
322
/// Current wall-clock time in whole seconds since the Unix epoch.
/// A system clock set before 1970 yields 0 rather than panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
330
331// ---------------------------------------------------------------------------
332// OAuth provider config
333// ---------------------------------------------------------------------------
334
335#[derive(Debug, Clone, Default, Serialize, Deserialize)]
336pub struct OAuthConfig {
337    pub provider: String,
338    pub client_id: String,
339    pub client_secret: String,
340    pub redirect_uri: String,
341    /// Optional scope override — replaces the spec's default scope
342    /// when set. Use cases: requesting `repo` on GitHub for app
343    /// installation flows, requesting `https://www.googleapis.com/...`
344    /// scopes on Google for app-specific data access.
345    #[serde(default, skip_serializing_if = "Option::is_none")]
346    pub scopes_override: Option<String>,
347    /// Tenant id for Microsoft/Entra. Defaults to `common`. Single-
348    /// tenant apps use a directory GUID; multi-tenant work-only apps
349    /// use `organizations`.
350    #[serde(default, skip_serializing_if = "Option::is_none")]
351    pub tenant: Option<String>,
352    /// Apple-specific extras (team id, key id, ES256 PEM). Required
353    /// for Sign in with Apple — ignored for any other provider.
354    #[serde(default, skip_serializing_if = "Option::is_none")]
355    pub apple: Option<provider::AppleConfig>,
356    /// OIDC issuer URL when this config targets a generic-OIDC
357    /// provider (Auth0, Okta, Keycloak, Cognito, etc.). When set,
358    /// the runtime fetches `<issuer>/.well-known/openid-configuration`
359    /// and synthesizes a [`provider::ProviderSpec`] from the
360    /// discovered endpoints.
361    #[serde(default, skip_serializing_if = "Option::is_none")]
362    pub oidc_issuer: Option<String>,
363}
364
365impl OAuthConfig {
366    /// Resolve the [`provider::ProviderSpec`] backing this config. For
367    /// `oidc_issuer`-configured providers, falls through to the OIDC
368    /// discovery cache. Errors propagate so misconfigured providers
369    /// fail loudly at first use rather than silently 404'ing later.
370    fn resolved_spec(&self) -> Result<provider::ResolvedSpec, String> {
371        if let Some(issuer) = self.oidc_issuer.as_deref() {
372            return provider::oidc_cache::resolve(issuer);
373        }
374        provider::find_spec(&self.provider)
375            .map(provider::ResolvedSpec::Static)
376            .ok_or_else(|| format!("unknown OAuth provider: {}", self.provider))
377    }
378
379    /// Build a [`provider::ProviderConfig`] view of `self` for the
380    /// helpers in [`provider`] that take the runtime config.
381    fn provider_cfg(&self) -> provider::ProviderConfig {
382        provider::ProviderConfig {
383            provider: self.provider.clone(),
384            client_id: self.client_id.clone(),
385            client_secret: self.client_secret.clone(),
386            redirect_uri: self.redirect_uri.clone(),
387            scopes_override: self.scopes_override.clone(),
388            tenant: self.tenant.clone(),
389            apple: self.apple.clone(),
390            oidc_issuer: self.oidc_issuer.clone(),
391        }
392    }
393
394    /// Generate the authorization URL for the provider.
395    ///
396    /// Callers MUST append a `&state=<random>` parameter and validate it in the
397    /// callback to prevent CSRF attacks. See `OAuthStateStore` for a minimal
398    /// implementation.
399    ///
400    /// For PKCE-required providers (Twitter/X, Kick), callers should
401    /// prefer [`Self::auth_url_with_pkce`] so the `code_challenge`
402    /// pair survives to the callback.
403    pub fn auth_url(&self) -> String {
404        match self.build_auth_url(None) {
405            Ok(u) => u,
406            Err(_) => String::new(),
407        }
408    }
409
410    /// Generate the authorization URL with a CSRF state parameter attached.
411    pub fn auth_url_with_state(&self, state: &str) -> String {
412        let base = self.auth_url();
413        if base.is_empty() {
414            return base;
415        }
416        format!("{}&state={}", base, url_encode(state))
417    }
418
419    /// Generate the authorization URL with state + a freshly minted
420    /// PKCE pair when the provider requires it. Returns
421    /// `(url, code_verifier)` — the verifier MUST be persisted in
422    /// the OAuth state record and replayed in the token exchange.
423    pub fn auth_url_with_pkce(&self, state: &str) -> Result<(String, Option<String>), String> {
424        let spec = self.resolved_spec()?;
425        let pkce = if spec.requires_pkce() {
426            Some(generate_pkce())
427        } else {
428            None
429        };
430        let challenge = pkce.as_ref().map(|p| p.code_challenge.as_str());
431        let mut url = self.build_auth_url(challenge)?;
432        if !state.is_empty() {
433            url.push_str(&format!("&state={}", url_encode(state)));
434        }
435        Ok((url, pkce.map(|p| p.code_verifier)))
436    }
437
438    fn build_auth_url(&self, pkce_challenge: Option<&str>) -> Result<String, String> {
439        let spec = self.resolved_spec()?;
440        let cfg = self.provider_cfg();
441        let auth = provider::resolve_endpoint(spec.auth_url(), &cfg);
442        if auth.is_empty() {
443            return Err(format!(
444                "provider {} has no authorization endpoint",
445                self.provider
446            ));
447        }
448        let scopes_default = spec.scopes().to_string();
449        let scopes_raw = self.scopes_override.as_deref().unwrap_or(&scopes_default);
450        // Re-join scopes with the provider's separator (TikTok uses
451        // commas, everyone else uses spaces). Splitting on whitespace
452        // first lets developers always specify scopes the human way.
453        let scopes_joined = scopes_raw
454            .split_whitespace()
455            .collect::<Vec<_>>()
456            .join(spec.scope_separator());
457
458        let mut url = format!(
459            "{auth}?{cid_param}={cid}&redirect_uri={ruri}&response_type=code&scope={scope}",
460            cid_param = spec.client_id_param(),
461            cid = url_encode(&self.client_id),
462            ruri = url_encode(&self.redirect_uri),
463            scope = url_encode(&scopes_joined),
464        );
465        if !spec.auth_query_extra().is_empty() {
466            url.push('&');
467            url.push_str(spec.auth_query_extra());
468        }
469        if let Some(challenge) = pkce_challenge {
470            url.push_str("&code_challenge=");
471            url.push_str(challenge);
472            url.push_str("&code_challenge_method=S256");
473        }
474        Ok(url)
475    }
476
477    /// Generate the token exchange URL.
478    pub fn token_url(&self) -> String {
479        match self.resolved_spec() {
480            Ok(spec) => provider::resolve_endpoint(spec.token_url(), &self.provider_cfg()),
481            Err(_) => String::new(),
482        }
483    }
484
485    /// URL for the userinfo endpoint, which returns the authenticated user's profile.
486    pub fn userinfo_url(&self) -> String {
487        match self.resolved_spec() {
488            Ok(spec) => match spec.userinfo_url() {
489                Some(u) => provider::resolve_endpoint(u, &self.provider_cfg()),
490                None => String::new(),
491            },
492            Err(_) => String::new(),
493        }
494    }
495
496    /// Exchange an authorization code for the full token set
497    /// (`access_token`, optional `refresh_token`, optional `id_token`,
498    /// `expires_in`, `scope`). When the provider uses PKCE,
499    /// `code_verifier` MUST be supplied (the value previously returned
500    /// from [`Self::auth_url_with_pkce`]).
501    pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
502        self.exchange_code_full_pkce(code, None)
503    }
504
505    pub fn exchange_code_full_pkce(
506        &self,
507        code: &str,
508        code_verifier: Option<&str>,
509    ) -> Result<TokenSet, String> {
510        let spec = self.resolved_spec()?;
511        let cfg = self.provider_cfg();
512        let token_url = provider::resolve_endpoint(spec.token_url(), &cfg);
513        let pkce_field = code_verifier
514            .map(|v| format!("&code_verifier={}", url_encode(v)))
515            .unwrap_or_default();
516
517        let out = match spec.token_exchange() {
518            provider::TokenExchangeShape::Standard => {
519                let body = format!(
520                    "code={code}&{cid_param}={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
521                    code = url_encode(code),
522                    cid_param = spec.client_id_param(),
523                    cid = url_encode(&self.client_id),
524                    secret = url_encode(&self.client_secret),
525                    ruri = url_encode(&self.redirect_uri),
526                    pkce = pkce_field,
527                );
528                http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
529            }
530            provider::TokenExchangeShape::AppleJwt => {
531                let apple = self.apple.as_ref().ok_or(
532                    "apple provider requires `apple` config (team_id, key_id, private_key_pem)",
533                )?;
534                let signed_secret = apple_jwt::mint_client_secret(apple, &self.client_id)?;
535                let body = format!(
536                    "code={code}&client_id={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
537                    code = url_encode(code),
538                    cid = url_encode(&self.client_id),
539                    secret = url_encode(&signed_secret),
540                    ruri = url_encode(&self.redirect_uri),
541                    pkce = pkce_field,
542                );
543                http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
544            }
545            provider::TokenExchangeShape::BasicAuth => {
546                let body = format!(
547                    "code={code}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
548                    code = url_encode(code),
549                    ruri = url_encode(&self.redirect_uri),
550                    pkce = pkce_field,
551                );
552                http_post_form_basic(&token_url, &body, &self.client_id, &self.client_secret)
553                    .map_err(sanitize_token_error)?
554            }
555            provider::TokenExchangeShape::JsonBody => {
556                let mut json = serde_json::Map::new();
557                json.insert("grant_type".into(), "authorization_code".into());
558                json.insert("code".into(), code.into());
559                json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
560                json.insert("client_id".into(), self.client_id.clone().into());
561                json.insert("client_secret".into(), self.client_secret.clone().into());
562                if let Some(v) = code_verifier {
563                    json.insert("code_verifier".into(), v.to_string().into());
564                }
565                let body = serde_json::Value::Object(json).to_string();
566                http_post_json(&token_url, &body, None).map_err(sanitize_token_error)?
567            }
568            provider::TokenExchangeShape::BasicAuthJsonBody => {
569                let mut json = serde_json::Map::new();
570                json.insert("grant_type".into(), "authorization_code".into());
571                json.insert("code".into(), code.into());
572                json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
573                if let Some(v) = code_verifier {
574                    json.insert("code_verifier".into(), v.to_string().into());
575                }
576                let body = serde_json::Value::Object(json).to_string();
577                http_post_json(
578                    &token_url,
579                    &body,
580                    Some((&self.client_id, &self.client_secret)),
581                )
582                .map_err(sanitize_token_error)?
583            }
584        };
585        parse_token_response(&out)
586    }
587
588    /// Exchange an authorization code for an access token. Thin wrapper
589    /// around [`OAuthConfig::exchange_code_full`] for callers that only
590    /// need the access token.
591    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
592        Ok(self.exchange_code_full(code)?.access_token)
593    }
594
595    /// Fetch the authenticated user's email + display name using an access token.
596    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
597        let info = self.fetch_userinfo_full(access_token)?;
598        Ok((info.email, info.name))
599    }
600
601    /// Fetch the authenticated user's full identity info — email + name +
602    /// the provider-stable account ID. Uses the spec's
603    /// [`provider::UserinfoParser`] so adding a new provider is a
604    /// table change, not a new branch.
605    pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
606        // The id_token from the token response carries the identity
607        // for Apple and similar; route to the dedicated entry point.
608        // Apple's userinfo_url is None — this is the supported path.
609        self.fetch_userinfo_with_id_token(access_token, None)
610    }
611
612    /// Fetch userinfo, falling back to the supplied id_token JWT when
613    /// the provider has no userinfo endpoint (Apple). The id_token
614    /// is the one returned by [`Self::exchange_code_full`] in
615    /// [`TokenSet::id_token`].
616    pub fn fetch_userinfo_with_id_token(
617        &self,
618        access_token: &str,
619        id_token: Option<&str>,
620    ) -> Result<UserInfo, String> {
621        let spec = self.resolved_spec()?;
622        let cfg = self.provider_cfg();
623
624        // Apple — identity lives in the id_token, not a userinfo endpoint.
625        if matches!(
626            spec.userinfo_parser(),
627            provider::UserinfoParser::AppleIdToken
628        ) {
629            let token =
630                id_token.ok_or("apple login requires the id_token from the token response")?;
631            return parse_apple_id_token(token, &self.provider);
632        }
633
634        // Linear is GraphQL — the userinfo "GET" is actually a POST
635        // with a fixed query.
636        if matches!(
637            spec.userinfo_parser(),
638            provider::UserinfoParser::LinearGraphql
639        ) {
640            return fetch_linear_userinfo(&self.provider, access_token);
641        }
642
643        let url = match spec.userinfo_url() {
644            Some(u) => provider::resolve_endpoint(u, &cfg),
645            None => {
646                return Err(format!(
647                    "provider {} has no userinfo endpoint",
648                    self.provider
649                ))
650            }
651        };
652        let out = match spec.userinfo_method() {
653            provider::UserinfoMethod::Get => http_get_bearer(&url, access_token),
654            provider::UserinfoMethod::Post => http_post_bearer(&url, access_token),
655        }
656        .map_err(sanitize_token_error)?;
657        let parsed: serde_json::Value =
658            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
659
660        match spec.userinfo_parser() {
661            provider::UserinfoParser::Oidc => {
662                let email = parsed
663                    .get("email")
664                    .and_then(|v| v.as_str())
665                    .ok_or("no email in userinfo")?
666                    .to_string();
667                let name = parsed
668                    .get("name")
669                    .and_then(|v| v.as_str())
670                    .map(String::from);
671                let provider_account_id = parsed
672                    .get("sub")
673                    .and_then(|v| v.as_str())
674                    .ok_or("no sub in userinfo")?
675                    .to_string();
676                Ok(UserInfo {
677                    provider: self.provider.clone(),
678                    provider_account_id,
679                    email,
680                    name,
681                })
682            }
683            provider::UserinfoParser::GitHub => {
684                let name = parsed
685                    .get("name")
686                    .and_then(|v| v.as_str())
687                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
688                    .map(String::from);
689                let email = parsed
690                    .get("email")
691                    .and_then(|v| v.as_str())
692                    .map(String::from);
693                let email = email
694                    .or_else(|| fetch_github_primary_email(access_token).ok())
695                    .ok_or("no accessible email on GitHub account")?;
696                let provider_account_id = parsed
697                    .get("id")
698                    .map(|v| {
699                        v.as_i64()
700                            .map(|n| n.to_string())
701                            .or_else(|| v.as_str().map(String::from))
702                            .unwrap_or_default()
703                    })
704                    .filter(|s| !s.is_empty())
705                    .ok_or("no id in userinfo")?;
706                Ok(UserInfo {
707                    provider: self.provider.clone(),
708                    provider_account_id,
709                    email,
710                    name,
711                })
712            }
713            provider::UserinfoParser::Custom {
714                id_path,
715                email_path,
716                name_path,
717            } => {
718                let provider_account_id = json_pointer_string(&parsed, id_path)
719                    .ok_or_else(|| format!("no id at {id_path} in userinfo"))?;
720                let raw_email = json_pointer_string(&parsed, email_path)
721                    .ok_or_else(|| format!("no email at {email_path} in userinfo"))?;
722                // Twitter/Reddit don't expose real emails — they map a
723                // username into the email slot. Tag it so account
724                // policies can distinguish "real verified email" from
725                // "we made this up." `.invalid` is reserved by RFC 6761.
726                let email = if !raw_email.contains('@') {
727                    let domain = match self.provider.as_str() {
728                        "twitter" => "x.invalid",
729                        "reddit" => "reddit.invalid",
730                        other => return Err(format!(
731                            "{other}: userinfo `email` field is not an email address (got {raw_email:?}); refusing to synthesize",
732                        )),
733                    };
734                    format!("{raw_email}@{domain}")
735                } else {
736                    raw_email
737                };
738                let name = name_path.and_then(|p| json_pointer_string(&parsed, p));
739                Ok(UserInfo {
740                    provider: self.provider.clone(),
741                    provider_account_id,
742                    email,
743                    name,
744                })
745            }
746            provider::UserinfoParser::AppleIdToken => unreachable!("handled above"),
747            provider::UserinfoParser::LinearGraphql => unreachable!("handled above"),
748        }
749    }
750}
751
/// One PKCE pair for a single authorization request.
///
/// `code_verifier` stays on the server until the token exchange.
/// `code_challenge` (derived from it with S256) goes on the authorization
/// URL.
struct PkcePair {
    /// Random secret, replayed in the token request.
    code_verifier: String,
    /// Base64url-encoded SHA-256 of `code_verifier`.
    code_challenge: String,
}
758
759/// Generate a PKCE pair: random 43-char verifier + S256 challenge.
760/// RFC 7636 §4.1 permits 43–128 chars from `[A-Za-z0-9-._~]`. 32
761/// random bytes URL-base64-encoded comes out to exactly 43 chars.
762fn generate_pkce() -> PkcePair {
763    use rand::RngCore;
764    let mut bytes = [0u8; 32];
765    rand::thread_rng().fill_bytes(&mut bytes);
766    let code_verifier = apple_jwt::base64_url(bytes);
767    use sha2::{Digest, Sha256};
768    let mut hasher = Sha256::new();
769    hasher.update(code_verifier.as_bytes());
770    let code_challenge = apple_jwt::base64_url(hasher.finalize());
771    PkcePair {
772        code_verifier,
773        code_challenge,
774    }
775}
776
777/// Decode an Apple id_token JWT and pull the identity claims.
778///
779/// **Trust assumption:** the caller MUST have obtained this token
780/// via the back-channel `/auth/token` exchange (mutually authenticated
781/// TLS to `appleid.apple.com`). Under that assumption no third party
782/// can have substituted a forged JWT, so we skip signature
783/// verification.
784///
785/// **DO NOT call this on a JWT supplied by the client** (e.g. a
786/// "post your id_token to me" mobile-SDK flow). For those paths,
787/// implement Apple JWKS verification: fetch
788/// `https://appleid.apple.com/auth/keys`, verify the RS256
789/// signature, then check `iss == "https://appleid.apple.com"`,
790/// `aud == client_id`, and `exp > now`. Pylon doesn't ship that
791/// verifier yet — apps that need it can compose `crate::jwt::verify`
792/// against a JWKS-loaded RSA key.
793///
794/// This function is private (`fn`, not `pub fn`) precisely so it
795/// can't be misused by an external caller. The only call site is
796/// [`OAuthConfig::fetch_userinfo_with_id_token`] which is reached
797/// only via the OAuth callback handler, which only processes
798/// back-channel-exchanged tokens.
799fn parse_apple_id_token(id_token: &str, provider: &str) -> Result<UserInfo, String> {
800    let mut parts = id_token.split('.');
801    let _header = parts.next().ok_or("apple id_token: missing header")?;
802    let claims_b64 = parts.next().ok_or("apple id_token: missing claims")?;
803    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
804    let claims_bytes = URL_SAFE_NO_PAD
805        .decode(claims_b64)
806        .map_err(|e| format!("apple id_token claims not base64: {e}"))?;
807    let claims: serde_json::Value = serde_json::from_slice(&claims_bytes)
808        .map_err(|e| format!("apple id_token claims not JSON: {e}"))?;
809    let provider_account_id = claims
810        .get("sub")
811        .and_then(|v| v.as_str())
812        .ok_or("apple id_token: missing sub")?
813        .to_string();
814    let email = claims
815        .get("email")
816        .and_then(|v| v.as_str())
817        .ok_or("apple id_token: missing email (was the `email` scope requested?)")?
818        .to_string();
819    Ok(UserInfo {
820        provider: provider.to_string(),
821        provider_account_id,
822        email,
823        name: None, // Apple sends `name` as a separate form field on FIRST signup only.
824    })
825}
826
827/// Strip provider error bodies of secrets before they propagate to
828/// logs / `oauth_error_message` redirect URLs.
829///
830/// **Why:** Several token endpoints echo the request body (or pieces
831/// of it) on auth failure. Without this, a misconfigured deployment
832/// can leak `client_secret`, the Apple JWT, or even the auth `code`
833/// into the user's browser history and CDN logs.
834///
835/// Covers both shapes echoed by real providers:
836///   - form / query: `client_secret=sk_…`
837///   - JSON: `"client_secret":"sk_…"` (Notion, Atlassian)
838fn sanitize_token_error(err: String) -> String {
839    const SENSITIVE: &[&str] = &[
840        "client_secret",
841        "code_verifier",
842        "client_assertion",
843        "refresh_token",
844        "access_token",
845        "id_token",
846        // The auth `code` itself is single-use but still sensitive
847        // until the token endpoint consumes it — and many providers
848        // echo it back on a 4xx token-exchange error before the
849        // attacker has had a chance to redeem it.
850        "code",
851    ];
852    let mut out = err;
853    for key in SENSITIVE {
854        out = redact_param_form(&out, key);
855        out = redact_param_json(&out, key);
856    }
857    out
858}
859
/// Mask every form/query-style `key=value` occurrence as `key=***`.
///
/// Behavior:
/// - The value runs until the first `&`, newline, `"`, space or `'`, or to
///   the end of the input; that terminator is kept.
/// - Matching is a plain substring search, so `error_code=` also matches the
///   key `code`. Over-redaction is the safe direction here.
/// - Scanning resumes at the terminator, so a value is never searched for
///   further keys.
/// - All slicing happens at positions returned by `str::find`, which are
///   always char boundaries, so multibyte text cannot cause a panic.
fn redact_param_form(input: &str, key: &str) -> String {
    let needle = format!("{key}=");
    let mut redacted = String::with_capacity(input.len());
    let mut remaining = input;

    while let Some(hit) = remaining.find(&needle) {
        redacted.push_str(&remaining[..hit]);
        redacted.push_str(&needle);
        redacted.push_str("***");

        let value_and_rest = &remaining[hit + needle.len()..];
        let value_len = value_and_rest
            .find(|c: char| matches!(c, '&' | '\n' | '"' | ' ' | '\''))
            .unwrap_or(value_and_rest.len());
        remaining = &value_and_rest[value_len..];
    }

    redacted.push_str(remaining);
    redacted
}
890
/// Replace the value in `"key":"…"` with `***`. Case-sensitive,
/// tolerant of whitespace between `:` and the value (per JSON).
///
/// Whitespace is also skipped between the closing quote of the key and
/// the `:`. Only string-valued fields are rewritten. If `"key"` is not
/// followed by `:` and then `"`, the text is copied through unchanged and
/// scanning continues; this covers numbers, objects, a bare `"key"` in an
/// array, and similar cases.
///
/// Escaped quotes (`\"`) inside the value do not terminate it. If the
/// value has no closing quote, everything after the opening quote is
/// replaced by `***` and no closing quote is emitted.
///
/// The scan is textual (no JSON parser), so it also works when the JSON is
/// embedded in other text such as an `HTTP 400: {…}` error string.
fn redact_param_json(input: &str, key: &str) -> String {
    let needle = format!("\"{key}\"");
    let mut out = String::with_capacity(input.len());
    let mut i = 0;
    while i < input.len() {
        if !input[i..].starts_with(&needle) {
            // Not at a match: copy one whole char so `i` stays on a
            // UTF-8 boundary.
            let (_, ch) = input[i..].char_indices().next().expect("non-empty");
            out.push(ch);
            i += ch.len_utf8();
            continue;
        }
        // Found `"key"`. Walk forward over `:` + optional whitespace,
        // then `"`, then the value, then closing `"`. If anything
        // is off (not actually a string-valued field) bail and
        // copy verbatim.
        let mut j = i + needle.len();
        // optional whitespace
        while let Some((_, ch)) = input[j..].char_indices().next() {
            if !ch.is_whitespace() {
                break;
            }
            j += ch.len_utf8();
        }
        if !input[j..].starts_with(':') {
            // Not a key-value form (could be in an array, etc.).
            out.push_str(&input[i..j]);
            i = j;
            continue;
        }
        j += 1;
        // Optional whitespace between `:` and the value.
        while let Some((_, ch)) = input[j..].char_indices().next() {
            if !ch.is_whitespace() {
                break;
            }
            j += ch.len_utf8();
        }
        if !input[j..].starts_with('"') {
            // Value is not a string (number, object, …): leave it alone.
            out.push_str(&input[i..j]);
            i = j;
            continue;
        }
        let value_start = j + 1;
        // Find the closing `"`, honoring `\"` escapes.
        let mut k = value_start;
        // True when the previous char was an *unescaped* backslash, so a
        // `\\` pair does not escape the following quote.
        let mut prev_backslash = false;
        let mut closing: Option<usize> = None;
        while k < input.len() {
            let (_, ch) = input[k..].char_indices().next().expect("non-empty");
            if ch == '"' && !prev_backslash {
                closing = Some(k);
                break;
            }
            prev_backslash = ch == '\\' && !prev_backslash;
            k += ch.len_utf8();
        }
        match closing {
            Some(end) => {
                // Keep `"key": "`, mask the value, re-emit the closing quote.
                out.push_str(&input[i..value_start]);
                out.push_str("***");
                out.push('"');
                i = end + 1;
            }
            None => {
                // Malformed JSON, redact to end of input to be safe.
                out.push_str(&input[i..value_start]);
                out.push_str("***");
                i = input.len();
            }
        }
    }
    out
}
965
966/// Linear's userinfo lives behind a GraphQL endpoint — the bearer
967/// token is the same OAuth access token, but the request is a POST
968/// with a fixed query. Kept as a separate fn so the main fetcher
969/// stays uniform across the other parsers.
970fn fetch_linear_userinfo(provider: &str, access_token: &str) -> Result<UserInfo, String> {
971    let body = r#"{"query":"query { viewer { id email name } }"}"#;
972    let agent = ureq_agent();
973    let resp = agent
974        .post("https://api.linear.app/graphql")
975        .set("Authorization", &format!("Bearer {access_token}"))
976        .set("Content-Type", "application/json")
977        .set("Accept", "application/json")
978        .send_string(body)
979        .map_err(|e| format!("linear graphql: {e}"))?;
980    let out = resp.into_string().map_err(|e| format!("read body: {e}"))?;
981    let parsed: serde_json::Value =
982        serde_json::from_str(&out).map_err(|e| format!("linear graphql not JSON: {e}"))?;
983    let viewer = parsed
984        .pointer("/data/viewer")
985        .ok_or("linear graphql: no /data/viewer")?;
986    let provider_account_id = viewer
987        .get("id")
988        .and_then(|v| v.as_str())
989        .ok_or("linear graphql: no id")?
990        .to_string();
991    let email = viewer
992        .get("email")
993        .and_then(|v| v.as_str())
994        .ok_or("linear graphql: no email")?
995        .to_string();
996    let name = viewer
997        .get("name")
998        .and_then(|v| v.as_str())
999        .map(String::from);
1000    Ok(UserInfo {
1001        provider: provider.to_string(),
1002        provider_account_id,
1003        email,
1004        name,
1005    })
1006}
1007
1008/// JSON-pointer (RFC 6901) string extraction. Returns `None` for
1009/// missing paths or non-string values. Numeric ids (Discord's `id`,
1010/// Roblox's `sub`) are coerced to strings.
1011fn json_pointer_string(v: &serde_json::Value, path: &str) -> Option<String> {
1012    let node = v.pointer(path)?;
1013    if let Some(s) = node.as_str() {
1014        return Some(s.to_string());
1015    }
1016    if let Some(n) = node.as_i64() {
1017        return Some(n.to_string());
1018    }
1019    if let Some(n) = node.as_u64() {
1020        return Some(n.to_string());
1021    }
1022    None
1023}
1024
/// Identity resolved from a provider, as returned by
/// [`OAuthConfig::fetch_userinfo_full`].
///
/// The account store keys on `provider_account_id`, not `email`, so a user
/// who changes their email at the provider keeps the same pylon account.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UserInfo {
    /// Provider id this identity came from (e.g. `"github"`).
    pub provider: String,
    /// Provider-stable subject id (Google `sub`, GitHub numeric `id`).
    pub provider_account_id: String,
    /// Email reported by the provider.
    pub email: String,
    /// Display name, when the provider supplies one.
    pub name: Option<String>,
}
1036
/// Tokens obtained from a provider's token endpoint, as returned by
/// [`OAuthConfig::exchange_code_full`].
///
/// Persisted on the matching `Account` row. `refresh_token` enables silent
/// re-auth, and `expires_at` is checked before each provider call.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TokenSet {
    /// Bearer token for provider API calls.
    pub access_token: String,
    /// Long-lived token for obtaining new access tokens, if issued.
    pub refresh_token: Option<String>,
    /// OIDC identity token, if issued.
    pub id_token: Option<String>,
    /// Absolute expiry of `access_token` in Unix epoch seconds. `None` when
    /// the provider omitted `expires_in` (GitHub's classic OAuth app tokens
    /// never expire).
    pub expires_at: Option<u64>,
    /// Granted scopes, exactly as the provider reported them.
    pub scope: Option<String>,
}
1051
1052fn parse_token_response(body: &str) -> Result<TokenSet, String> {
1053    // Most providers return JSON; GitHub Classic apps return form-urlencoded
1054    // unless you ask with Accept: application/json (which we do).
1055    let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
1056        // Fall back to form-urlencoded: access_token=...&scope=...&token_type=...
1057        let mut map = serde_json::Map::new();
1058        for pair in body.split('&') {
1059            if let Some((k, v)) = pair.split_once('=') {
1060                map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
1061            }
1062        }
1063        serde_json::Value::Object(map)
1064    });
1065
1066    let access_token = json
1067        .get("access_token")
1068        .and_then(|v| v.as_str())
1069        .ok_or_else(|| format!("no access_token in token response: {body}"))?
1070        .to_string();
1071    let refresh_token = json
1072        .get("refresh_token")
1073        .and_then(|v| v.as_str())
1074        .map(String::from);
1075    let id_token = json
1076        .get("id_token")
1077        .and_then(|v| v.as_str())
1078        .map(String::from);
1079    let expires_at = json
1080        .get("expires_in")
1081        .and_then(|v| {
1082            v.as_u64()
1083                .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
1084        })
1085        .map(|secs| now_secs().saturating_add(secs));
1086    let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
1087    Ok(TokenSet {
1088        access_token,
1089        refresh_token,
1090        id_token,
1091        expires_at,
1092        scope,
1093    })
1094}
1095
/// Percent-encode `s` for use in a URL query or form body.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through.
/// Every other byte of the UTF-8 encoding becomes `%XX` with uppercase
/// hex, including space (`%20`).
fn url_encode(s: &str) -> String {
    s.bytes().fold(String::with_capacity(s.len()), |mut encoded, byte| {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push_str(&format!("%{byte:02X}"));
        }
        encoded
    })
}
1108
1109/// Timeout for OAuth / userinfo HTTP calls. Short enough that a hung
1110/// provider doesn't block a login indefinitely; long enough to absorb
1111/// typical internet latency.
1112const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
1113
1114fn ureq_agent() -> ureq::Agent {
1115    ureq::AgentBuilder::new()
1116        .timeout_connect(HTTP_TIMEOUT)
1117        .timeout_read(HTTP_TIMEOUT)
1118        .timeout_write(HTTP_TIMEOUT)
1119        .user_agent("pylon/0.1")
1120        .build()
1121}
1122
1123fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
1124    let agent = ureq_agent();
1125    let mut req = agent
1126        .post(url)
1127        .set("Content-Type", "application/x-www-form-urlencoded");
1128    if accept_json {
1129        req = req.set("Accept", "application/json");
1130    }
1131    match req.send_string(body) {
1132        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1133        Err(ureq::Error::Status(code, resp)) => {
1134            let body = resp.into_string().unwrap_or_default();
1135            Err(format!("HTTP {code}: {body}"))
1136        }
1137        Err(e) => Err(format!("HTTP error: {e}")),
1138    }
1139}
1140
1141/// POST a form body using HTTP Basic auth for the client credentials.
1142/// Used by Spotify, Reddit, Figma, Zoom, PayPal — providers that
1143/// mandate Basic auth on the token endpoint.
1144fn http_post_form_basic(
1145    url: &str,
1146    body: &str,
1147    client_id: &str,
1148    client_secret: &str,
1149) -> Result<String, String> {
1150    use base64::{engine::general_purpose::STANDARD, Engine};
1151    let creds = format!("{client_id}:{client_secret}");
1152    let basic = STANDARD.encode(creds.as_bytes());
1153    let agent = ureq_agent();
1154    match agent
1155        .post(url)
1156        .set("Content-Type", "application/x-www-form-urlencoded")
1157        .set("Accept", "application/json")
1158        .set("Authorization", &format!("Basic {basic}"))
1159        .send_string(body)
1160    {
1161        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1162        Err(ureq::Error::Status(code, resp)) => {
1163            let body = resp.into_string().unwrap_or_default();
1164            Err(format!("HTTP {code}: {body}"))
1165        }
1166        Err(e) => Err(format!("HTTP error: {e}")),
1167    }
1168}
1169
1170/// POST a JSON body, optionally with HTTP Basic auth. Used by
1171/// Notion (Basic + JSON) and Atlassian (JSON only) — both reject
1172/// form-encoded bodies on their token endpoints.
1173fn http_post_json(
1174    url: &str,
1175    body: &str,
1176    basic_creds: Option<(&str, &str)>,
1177) -> Result<String, String> {
1178    let agent = ureq_agent();
1179    let mut req = agent
1180        .post(url)
1181        .set("Content-Type", "application/json")
1182        .set("Accept", "application/json");
1183    if let Some((id, secret)) = basic_creds {
1184        use base64::{engine::general_purpose::STANDARD, Engine};
1185        let creds = STANDARD.encode(format!("{id}:{secret}").as_bytes());
1186        req = req.set("Authorization", &format!("Basic {creds}"));
1187    }
1188    // Notion requires the API version header on every call, even the
1189    // token exchange. Using a recent stable version.
1190    req = req.set("Notion-Version", "2022-06-28");
1191    match req.send_string(body) {
1192        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1193        Err(ureq::Error::Status(code, resp)) => {
1194            let body = resp.into_string().unwrap_or_default();
1195            Err(format!("HTTP {code}: {body}"))
1196        }
1197        Err(e) => Err(format!("HTTP error: {e}")),
1198    }
1199}
1200
1201/// POST with empty body + bearer auth. Used for Dropbox userinfo
1202/// (an RPC-style endpoint that requires POST instead of GET).
1203fn http_post_bearer(url: &str, token: &str) -> Result<String, String> {
1204    let agent = ureq_agent();
1205    match agent
1206        .post(url)
1207        .set("Authorization", &format!("Bearer {token}"))
1208        .set("Accept", "application/json")
1209        .call()
1210    {
1211        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1212        Err(ureq::Error::Status(code, resp)) => {
1213            let body = resp.into_string().unwrap_or_default();
1214            Err(format!("HTTP {code}: {body}"))
1215        }
1216        Err(e) => Err(format!("HTTP error: {e}")),
1217    }
1218}
1219
1220fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
1221    let agent = ureq_agent();
1222    match agent
1223        .get(url)
1224        .set("Authorization", &format!("Bearer {token}"))
1225        .set("Accept", "application/json")
1226        .call()
1227    {
1228        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1229        Err(ureq::Error::Status(code, resp)) => {
1230            let body = resp.into_string().unwrap_or_default();
1231            Err(format!("HTTP {code}: {body}"))
1232        }
1233        Err(e) => Err(format!("HTTP error: {e}")),
1234    }
1235}
1236
1237fn fetch_github_primary_email(token: &str) -> Result<String, String> {
1238    let out = http_get_bearer("https://api.github.com/user/emails", token)?;
1239    let emails: serde_json::Value =
1240        serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
1241    emails
1242        .as_array()
1243        .and_then(|arr| {
1244            arr.iter()
1245                .find(|e| {
1246                    e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
1247                        && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
1248                })
1249                .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
1250        })
1251        .ok_or_else(|| "no primary verified email on GitHub".into())
1252}
1253
1254/// OAuth provider registry.
1255pub struct OAuthRegistry {
1256    providers: std::collections::HashMap<String, OAuthConfig>,
1257}
1258
1259impl Default for OAuthRegistry {
1260    fn default() -> Self {
1261        Self::new()
1262    }
1263}
1264
1265impl OAuthRegistry {
1266    pub fn new() -> Self {
1267        Self {
1268            providers: std::collections::HashMap::new(),
1269        }
1270    }
1271
1272    pub fn register(&mut self, config: OAuthConfig) {
1273        self.providers.insert(config.provider.clone(), config);
1274    }
1275
1276    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
1277        self.providers.get(provider)
1278    }
1279
1280    /// Build from environment variables.
1281    ///
1282    /// For each builtin provider (and any `oidc_issuer`-configured
1283    /// IdP), looks for `PYLON_OAUTH_<PROVIDER>_CLIENT_ID` /
1284    /// `_CLIENT_SECRET` / `_REDIRECT`. Apple additionally requires
1285    /// `_TEAM_ID`, `_KEY_ID`, `_PRIVATE_KEY` (PEM contents or path).
1286    /// Microsoft accepts an optional `_TENANT`.
1287    ///
1288    /// Generic OIDC: any env var matching
1289    /// `PYLON_OAUTH_<NAME>_OIDC_ISSUER` registers a provider with id
1290    /// `<name>` (lowercased) using the discovered endpoints. Useful
1291    /// for Auth0, Okta, Keycloak, Cognito, Logto, Authentik, etc.
1292    pub fn from_env() -> Self {
1293        let mut reg = Self::new();
1294
1295        for spec in provider::builtin::all() {
1296            let upper = spec.id.to_ascii_uppercase();
1297            let prefix = format!("PYLON_OAUTH_{upper}");
1298            let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1299                Ok(v) => v,
1300                Err(_) => continue,
1301            };
1302            let secret = match std::env::var(format!("{prefix}_CLIENT_SECRET")) {
1303                Ok(v) => v,
1304                // Apple's "client_secret" is synthesized — allow blank.
1305                Err(_) if spec.id == "apple" => String::new(),
1306                Err(_) => continue,
1307            };
1308            let redirect_uri = std::env::var(format!("{prefix}_REDIRECT"))
1309                .unwrap_or_else(|_| format!("http://localhost:3000/api/auth/callback/{}", spec.id));
1310            let scopes_override = std::env::var(format!("{prefix}_SCOPES")).ok();
1311            let tenant = std::env::var(format!("{prefix}_TENANT")).ok();
1312
1313            let apple = if spec.id == "apple" {
1314                match (
1315                    std::env::var(format!("{prefix}_TEAM_ID")),
1316                    std::env::var(format!("{prefix}_KEY_ID")),
1317                    std::env::var(format!("{prefix}_PRIVATE_KEY")),
1318                ) {
1319                    (Ok(team_id), Ok(key_id), Ok(private_key_pem)) => Some(provider::AppleConfig {
1320                        team_id,
1321                        key_id,
1322                        private_key_pem,
1323                    }),
1324                    _ => continue, // Apple requires the JWT material to function.
1325                }
1326            } else {
1327                None
1328            };
1329
1330            reg.register(OAuthConfig {
1331                provider: spec.id.to_string(),
1332                client_id: id,
1333                client_secret: secret,
1334                redirect_uri,
1335                scopes_override,
1336                tenant,
1337                apple,
1338                oidc_issuer: None,
1339            });
1340        }
1341
1342        // Generic OIDC providers — scan PYLON_OAUTH_<NAME>_OIDC_ISSUER.
1343        for (key, issuer) in std::env::vars() {
1344            let Some(rest) = key.strip_prefix("PYLON_OAUTH_") else {
1345                continue;
1346            };
1347            let Some(name_upper) = rest.strip_suffix("_OIDC_ISSUER") else {
1348                continue;
1349            };
1350            let name = name_upper.to_ascii_lowercase();
1351            if provider::find_spec(&name).is_some() {
1352                continue; // already handled as a builtin
1353            }
1354            let prefix = format!("PYLON_OAUTH_{name_upper}");
1355            let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1356                Ok(v) => v,
1357                Err(_) => continue,
1358            };
1359            let secret = std::env::var(format!("{prefix}_CLIENT_SECRET")).unwrap_or_default();
1360            let redirect_uri = std::env::var(format!("{prefix}_REDIRECT"))
1361                .unwrap_or_else(|_| format!("http://localhost:3000/api/auth/callback/{name}"));
1362            reg.register(OAuthConfig {
1363                provider: name,
1364                client_id: id,
1365                client_secret: secret,
1366                redirect_uri,
1367                scopes_override: std::env::var(format!("{prefix}_SCOPES")).ok(),
1368                tenant: None,
1369                apple: None,
1370                oidc_issuer: Some(issuer),
1371            });
1372        }
1373
1374        reg
1375    }
1376
1377    /// Iterate over registered provider ids — used by routes/auth.rs
1378    /// to expose `/api/auth/providers` and to validate
1379    /// `/api/auth/login/<id>` paths against the configured set.
1380    pub fn ids(&self) -> impl Iterator<Item = &str> {
1381        self.providers.keys().map(|s| s.as_str())
1382    }
1383
1384    /// Process-wide cached registry. Built once on first use from
1385    /// `from_env`; subsequent calls are zero-cost. Routes use this
1386    /// to avoid the ~150 syscalls `from_env` does per call.
1387    ///
1388    /// **Trade-off:** env changes after server start aren't picked up
1389    /// without a restart — same as every other Pylon env-var path.
1390    pub fn shared() -> &'static OAuthRegistry {
1391        static CELL: std::sync::OnceLock<OAuthRegistry> = std::sync::OnceLock::new();
1392        CELL.get_or_init(Self::from_env)
1393    }
1394}
1395
1396// ---------------------------------------------------------------------------
1397// OAuth state store — CSRF protection for OAuth flows
1398// ---------------------------------------------------------------------------
1399
/// One stored OAuth state record. Carries the post-callback redirect
/// URLs alongside the provider so the callback handler doesn't need to
/// consult an env var to know where to send the user. Both URLs are
/// validated against `PYLON_TRUSTED_ORIGINS` at create time, so the
/// callback can trust them without re-checking.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
    pub provider: String,
    /// URL the callback redirects to on success. The frontend supplies
    /// this via `?callback=` on the start request.
    pub callback_url: String,
    /// URL the callback redirects to on failure. Defaults to
    /// `callback_url` when the frontend doesn't pass an explicit
    /// `?error_callback=`. The error code + message ride along as
    /// query params (`?oauth_error=X&oauth_error_message=Y`).
    pub error_callback_url: String,
    /// PKCE code_verifier when the provider requires PKCE. Set by the
    /// `/api/auth/login/<provider>` start route via
    /// [`OAuthConfig::auth_url_with_pkce`]; replayed on token exchange
    /// in the callback. `None` for non-PKCE providers.
    pub pkce_verifier: Option<String>,
    /// Unix epoch second at (and after) which the record is invalid.
    pub expires_at: u64,
}

/// Backing store for OAuth state records. Default impl keeps them in
/// memory (fine for tests + dev); the runtime swaps in a SQLite or
/// Postgres backend so a restart in the middle of an OAuth handshake
/// doesn't leave the user with "invalid state" on the callback.
pub trait OAuthStateBackend: Send + Sync {
    /// Persist a state record under `token`.
    fn put(&self, token: &str, state: &OAuthState);
    /// Atomic compare-and-consume: returns the stored record if the
    /// token exists and hasn't expired, then removes it. Returning
    /// `None` means either the token never existed or it has already
    /// been used / expired.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
}

/// In-memory backend (default). Lost on restart.
///
/// Expired records are swept on every `take`, using the caller-supplied
/// clock. Without that sweep, every abandoned login (start route hit,
/// callback never reached) would stay in the map forever.
pub struct InMemoryOAuthBackend {
    states: Mutex<HashMap<String, OAuthState>>,
}

impl InMemoryOAuthBackend {
    /// Empty backend.
    pub fn new() -> Self {
        Self {
            states: Mutex::new(HashMap::new()),
        }
    }
}

impl Default for InMemoryOAuthBackend {
    fn default() -> Self {
        Self::new()
    }
}

impl OAuthStateBackend for InMemoryOAuthBackend {
    fn put(&self, token: &str, state: &OAuthState) {
        self.states
            .lock()
            .unwrap()
            .insert(token.to_string(), state.clone());
    }

    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
        let mut states = self.states.lock().unwrap();
        // Remove first so a token is consumed even if it turns out to be
        // expired — replay stays impossible either way.
        let entry = states.remove(token);
        // Drop every other expired record while we hold the lock. Reusing
        // the caller's `now` keeps this backend free of its own clock.
        states.retain(|_, s| s.expires_at > now_unix_secs);
        entry.filter(|e| e.expires_at > now_unix_secs)
    }
}
1473
1474/// Stores OAuth state parameters to prevent CSRF attacks on the callback.
1475///
1476/// State tokens are short-lived (10 minutes) and single-use. Backed by an
1477/// `OAuthStateBackend`; defaults to in-memory but the runtime persists them
1478/// to SQLite (or Postgres when `DATABASE_URL` is set) so they survive a
1479/// restart that happens mid-OAuth-handshake.
1480pub struct OAuthStateStore {
1481    backend: Box<dyn OAuthStateBackend>,
1482}
1483
1484impl Default for OAuthStateStore {
1485    fn default() -> Self {
1486        Self::new()
1487    }
1488}
1489
1490impl OAuthStateStore {
1491    pub fn new() -> Self {
1492        Self {
1493            backend: Box::new(InMemoryOAuthBackend::new()),
1494        }
1495    }
1496
1497    pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
1498        Self { backend }
1499    }
1500
1501    /// Generate and store a new state record. Returns the random
1502    /// state token (the value the OAuth provider echoes back as
1503    /// `?state=…` on the callback).
1504    ///
1505    /// Caller is responsible for validating `callback_url` and
1506    /// `error_callback_url` against the trusted-origins allowlist
1507    /// BEFORE calling this — the store trusts what it's given.
1508    pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
1509        self.create_with_pkce(provider, callback_url, error_callback_url, None)
1510    }
1511
1512    /// Same as [`Self::create`] but accepts a PKCE verifier to stash
1513    /// alongside the state record. The callback handler reads it back
1514    /// out and replays it in the token exchange.
1515    pub fn create_with_pkce(
1516        &self,
1517        provider: &str,
1518        callback_url: &str,
1519        error_callback_url: &str,
1520        pkce_verifier: Option<String>,
1521    ) -> String {
1522        use std::time::{SystemTime, UNIX_EPOCH};
1523        let token = generate_token();
1524        let now = SystemTime::now()
1525            .duration_since(UNIX_EPOCH)
1526            .unwrap_or_default()
1527            .as_secs();
1528        let state = OAuthState {
1529            provider: provider.to_string(),
1530            callback_url: callback_url.to_string(),
1531            error_callback_url: error_callback_url.to_string(),
1532            pkce_verifier,
1533            expires_at: now + 600,
1534        };
1535        self.backend.put(&token, &state);
1536        token
1537    }
1538
1539    /// Validate and consume a state token. Returns the stored record
1540    /// iff the token existed, has not expired, AND matches
1541    /// `expected_provider`. The token is removed either way to make
1542    /// replay impossible.
1543    pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
1544        use std::time::{SystemTime, UNIX_EPOCH};
1545        let now = SystemTime::now()
1546            .duration_since(UNIX_EPOCH)
1547            .unwrap_or_default()
1548            .as_secs();
1549        let entry = self.backend.take(state, now)?;
1550        if entry.provider != expected_provider {
1551            return None;
1552        }
1553        Some(entry)
1554    }
1555}
1556
1557/// Validate that `url` has an origin (scheme://host[:port]) listed in
1558/// `trusted_origins`. Returns `Ok(url)` when trusted (echoes input for
1559/// chaining), `Err` with a code/message when not. Used by the OAuth
1560/// start endpoint to gate `?callback=` + `?error_callback=` values
1561/// before storing them in the state record.
1562///
1563/// `trusted_origins` entries are origin strings like
1564/// `"https://app.example.com"` or `"http://localhost:3000"` — no
1565/// trailing slash, no path. A `url` like
1566/// `"http://localhost:3000/dashboard?x=1"` matches the
1567/// `"http://localhost:3000"` entry.
1568///
1569/// Borrowed wholesale from better-auth's `trustedOrigins` model:
1570/// explicit allowlist, no implicit "same-origin trust," no env-var
1571/// magic. An open-redirect via OAuth is one of the easier auth bugs
1572/// to ship by accident.
1573pub fn validate_trusted_redirect(
1574    url: &str,
1575    trusted_origins: &[String],
1576) -> Result<(), TrustedOriginError> {
1577    if url.is_empty() {
1578        return Err(TrustedOriginError::Empty);
1579    }
1580    // Must be absolute http(s) URL — no relative paths, no schemes
1581    // like javascript:, file:, data:.
1582    if !url.starts_with("http://") && !url.starts_with("https://") {
1583        return Err(TrustedOriginError::NotHttp);
1584    }
1585    let url_origin = origin_of(url);
1586    if trusted_origins.iter().any(|t| t == &url_origin) {
1587        Ok(())
1588    } else {
1589        Err(TrustedOriginError::NotTrusted { origin: url_origin })
1590    }
1591}
1592
/// Reasons a redirect URL might be rejected by [`validate_trusted_redirect`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The URL was the empty string.
    Empty,
    /// The URL did not start with `http://` or `https://` (relative paths,
    /// `javascript:`, `file:`, `data:`, … all land here).
    NotHttp,
    /// The URL's origin is not allowlisted; `origin` is the value that was
    /// compared against the list.
    NotTrusted { origin: String },
}

impl std::fmt::Display for TrustedOriginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TrustedOriginError::Empty => write!(f, "redirect URL is empty"),
            TrustedOriginError::NotHttp => {
                write!(f, "redirect URL must use http:// or https:// scheme")
            }
            TrustedOriginError::NotTrusted { origin } => write!(
                f,
                "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS"
            ),
        }
    }
}

/// Makes the error usable as `Box<dyn std::error::Error>` and with `?` in
/// callers that erase error types. `Display` carries the whole message and
/// there is no underlying cause, so the default `source()` (None) is right.
impl std::error::Error for TrustedOriginError {}
1615
/// Return the `scheme://host[:port]` prefix of `url`: everything up to, but
/// not including, the first `/`, `?` or `#` after the `://` separator.
///
/// Input without `://` is returned unchanged apart from trailing slashes
/// being trimmed. No normalization is done — letter case, default ports and
/// any `user@` userinfo are kept verbatim. This is plain string slicing, not
/// a full URL parser. It is public so router crates can apply the same logic
/// when comparing redirect URLs against the trusted-origins list.
pub fn origin_of(url: &str) -> String {
    match url.split_once("://") {
        None => url.trim_end_matches('/').to_string(),
        Some((scheme, rest)) => {
            // The authority runs until the first path / query / fragment
            // delimiter, or to the end of the string when there is none.
            let authority = match rest.find(|c: char| matches!(c, '/' | '?' | '#')) {
                Some(end) => &rest[..end],
                None => rest,
            };
            format!("{scheme}://{authority}")
        }
    }
}
1631
1632// ---------------------------------------------------------------------------
1633// Magic code auth — email verification codes
1634// ---------------------------------------------------------------------------
1635
1636/// Pluggable storage for magic-code records. In-memory is the default
1637/// (fine for dev); persistent backends (SQLite, Postgres) live in
1638/// `pylon-runtime` so a server restart between "send code" and "verify
1639/// code" doesn't invalidate the user's pending login.
1640///
1641/// All methods are infallible from the caller's perspective — durability
1642/// is best-effort. A backend that fails to write should log; the
1643/// in-memory cache remains authoritative for the current process.
1644pub trait MagicCodeBackend: Send + Sync {
1645    /// Replace any existing code for `email` with `code`.
1646    fn put(&self, email: &str, code: &MagicCode);
1647    /// Look up the current code for `email`. Returns `None` if absent.
1648    fn get(&self, email: &str) -> Option<MagicCode>;
1649    /// Remove the code for `email` (called on successful verify or
1650    /// expiry). Idempotent — missing key is not an error.
1651    fn remove(&self, email: &str);
1652    /// Persist an attempts++ on the existing record without touching
1653    /// other fields. Used by the verify-failed path to enforce
1654    /// `MAX_ATTEMPTS` across restarts.
1655    fn bump_attempts(&self, email: &str);
1656    /// Load all live records on construction. Lets `MagicCodeStore::with_backend`
1657    /// hydrate the in-memory cache from durable storage on startup.
1658    fn load_all(&self) -> Vec<MagicCode>;
1659}
1660
1661/// In-memory backend for magic codes. The default — also used as the
1662/// authoritative cache by `MagicCodeStore`.
1663pub struct InMemoryMagicCodeBackend {
1664    codes: Mutex<HashMap<String, MagicCode>>,
1665}
1666
1667impl InMemoryMagicCodeBackend {
1668    pub fn new() -> Self {
1669        Self {
1670            codes: Mutex::new(HashMap::new()),
1671        }
1672    }
1673}
1674
1675impl Default for InMemoryMagicCodeBackend {
1676    fn default() -> Self {
1677        Self::new()
1678    }
1679}
1680
1681impl MagicCodeBackend for InMemoryMagicCodeBackend {
1682    fn put(&self, email: &str, code: &MagicCode) {
1683        self.codes
1684            .lock()
1685            .unwrap()
1686            .insert(email.to_string(), code.clone());
1687    }
1688    fn get(&self, email: &str) -> Option<MagicCode> {
1689        self.codes.lock().unwrap().get(email).cloned()
1690    }
1691    fn remove(&self, email: &str) {
1692        self.codes.lock().unwrap().remove(email);
1693    }
1694    fn bump_attempts(&self, email: &str) {
1695        if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
1696            c.attempts = c.attempts.saturating_add(1);
1697        }
1698    }
1699    fn load_all(&self) -> Vec<MagicCode> {
1700        self.codes.lock().unwrap().values().cloned().collect()
1701    }
1702}
1703
1704/// A magic-code store. Wraps a `MagicCodeBackend` (in-memory by default)
1705/// and applies the verify/cooldown semantics. Hydrates the in-memory
1706/// cache from the backend on construction so durable backends survive
1707/// restart without losing in-flight codes.
1708pub struct MagicCodeStore {
1709    cache: Mutex<HashMap<String, MagicCode>>,
1710    backend: Box<dyn MagicCodeBackend>,
1711}
1712
/// One pending magic code and its bookkeeping.
#[derive(Clone, Debug)]
pub struct MagicCode {
    /// Address the code was issued to; also its lookup key.
    pub email: String,
    /// The code as issued: six zero-padded decimal digits.
    pub code: String,
    /// Unix time (seconds) at or after which verification reports the code
    /// as expired.
    pub expires_at: u64,
    /// Failed verify attempts so far. Reaching `MAX_ATTEMPTS` burns the
    /// code: every later verify fails, even with the right digits.
    pub attempts: u32,
}
1722
/// Failed verifies allowed per code before it is burned. Five leaves room for
/// typos while keeping guessing a 6-digit code (10^6 possibilities)
/// impractical.
const MAX_ATTEMPTS: u32 = 5;

/// Seconds that must pass after issuing a code before another one can be
/// issued to the same email, so an inbox can't be flooded with login codes.
const CREATE_COOLDOWN_SECS: u64 = 60;
1731
/// Why a magic code couldn't be issued or verified.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No pending code for this email: never issued, already consumed, or
    /// dropped at startup because it had expired.
    NotFound,
    /// The code has been burned by `MAX_ATTEMPTS` failed verifies.
    TooManyAttempts,
    /// The submitted code was wrong.
    BadCode,
    /// A code existed but its expiry time had passed; it has been discarded.
    Expired,
    /// A code was issued too recently; try again after `retry_after_secs`.
    Throttled { retry_after_secs: u64 },
}
1745
1746impl Default for MagicCodeStore {
1747    fn default() -> Self {
1748        Self::new()
1749    }
1750}
1751
1752impl MagicCodeStore {
1753    pub fn new() -> Self {
1754        Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1755    }
1756
1757    /// Build a magic-code store backed by a persistent backend. Existing
1758    /// live codes are hydrated into the in-memory cache on construction
1759    /// so a server restart between "send" and "verify" doesn't kill the
1760    /// user's pending login.
1761    pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1762        let now = now_secs();
1763        let mut cache = HashMap::new();
1764        for c in backend.load_all() {
1765            if c.expires_at > now {
1766                cache.insert(c.email.clone(), c);
1767            }
1768        }
1769        Self {
1770            cache: Mutex::new(cache),
1771            backend,
1772        }
1773    }
1774
1775    /// Generate a 6-digit code for an email and return it. Subject to a
1776    /// per-email cooldown — returns the error-shape via `try_create`.
1777    pub fn create(&self, email: &str) -> String {
1778        // Back-compat wrapper: same signature as before, but we still burn
1779        // the cooldown if one is active. Use `try_create` for a Result shape.
1780        self.try_create(email).unwrap_or_else(|_| String::new())
1781    }
1782
1783    /// Create a magic code, enforcing per-email cooldown. Returns the code
1784    /// or an error describing why one couldn't be issued.
1785    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1786        let now = now_secs();
1787
1788        let mut codes = self.cache.lock().unwrap();
1789
1790        // Cooldown check: if a live code exists and was created less than
1791        // CREATE_COOLDOWN_SECS ago, throttle. The age-of-code is
1792        // `expires_at - 600 + cooldown` since expires_at is create_time + 600.
1793        if let Some(existing) = codes.get(email) {
1794            if existing.expires_at > now {
1795                let created_at = existing.expires_at.saturating_sub(600);
1796                let age = now.saturating_sub(created_at);
1797                if age < CREATE_COOLDOWN_SECS {
1798                    return Err(MagicCodeError::Throttled {
1799                        retry_after_secs: CREATE_COOLDOWN_SECS - age,
1800                    });
1801                }
1802            }
1803        }
1804
1805        let code = generate_magic_code();
1806        let mc = MagicCode {
1807            email: email.to_string(),
1808            code: code.clone(),
1809            expires_at: now + 600, // 10 minutes
1810            attempts: 0,
1811        };
1812        codes.insert(email.to_string(), mc.clone());
1813        // Persist after the cache mutation lands. Backend write is
1814        // best-effort — if it fails the code still works for this
1815        // process; only a restart in the next 10 minutes would lose it.
1816        self.backend.put(email, &mc);
1817        Ok(code)
1818    }
1819
1820    /// Verify a code for an email. Returns true if valid and not expired.
1821    /// Uses constant-time comparison to prevent timing attacks.
1822    /// Back-compat wrapper around [`try_verify`].
1823    pub fn verify(&self, email: &str, code: &str) -> bool {
1824        matches!(self.try_verify(email, code), Ok(()))
1825    }
1826
1827    /// Verify a code. Returns a typed error so callers can surface specific
1828    /// messages. On the MAX_ATTEMPTS-th failure, the code is burned — even
1829    /// correct subsequent attempts return `TooManyAttempts`.
1830    /// Every magic code currently in the cache. Powers the Studio
1831    /// "Auth tables" view; not for app use. Includes expired codes —
1832    /// the cache only drops them on next verify attempt for that email.
1833    pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1834        self.cache
1835            .lock()
1836            .map(|m| m.values().cloned().collect())
1837            .unwrap_or_default()
1838    }
1839
1840    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1841        let now = now_secs();
1842        let mut codes = self.cache.lock().unwrap();
1843
1844        let mc = match codes.get_mut(email) {
1845            Some(m) => m,
1846            None => return Err(MagicCodeError::NotFound),
1847        };
1848
1849        if mc.attempts >= MAX_ATTEMPTS {
1850            return Err(MagicCodeError::TooManyAttempts);
1851        }
1852        if mc.expires_at <= now {
1853            codes.remove(email);
1854            self.backend.remove(email);
1855            return Err(MagicCodeError::Expired);
1856        }
1857
1858        let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1859        if !ok {
1860            mc.attempts += 1;
1861            self.backend.bump_attempts(email);
1862            // Burn the code at MAX_ATTEMPTS so retries can't hit max.
1863            if mc.attempts >= MAX_ATTEMPTS {
1864                return Err(MagicCodeError::TooManyAttempts);
1865            }
1866            return Err(MagicCodeError::BadCode);
1867        }
1868
1869        // Correct code — consume it.
1870        codes.remove(email);
1871        self.backend.remove(email);
1872        Ok(())
1873    }
1874}
1875
1876// ---------------------------------------------------------------------------
1877// Cryptographic helpers — CSPRNG-based token and code generation
1878// ---------------------------------------------------------------------------
1879
/// Lowercase hex encoding of `bytes`, two characters per byte.
///
/// Writes into one pre-sized buffer through a digit table instead of
/// formatting every byte into its own temporary `String`.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    out
}
1883
1884/// Generate a 6-digit magic code using a CSPRNG.
1885fn generate_magic_code() -> String {
1886    use rand::Rng;
1887    let mut rng = rand::thread_rng();
1888    let code: u32 = rng.gen_range(0..1_000_000);
1889    format!("{:06}", code)
1890}
1891
1892/// Generate a session token with 256 bits of entropy from a CSPRNG.
1893fn generate_token() -> String {
1894    use rand::Rng;
1895    let mut rng = rand::thread_rng();
1896    let bytes: [u8; 32] = rng.gen();
1897    format!("pylon_{}", hex_encode(&bytes))
1898}
1899
1900// ---------------------------------------------------------------------------
// Session store — in-memory by default, with an optional write-through backend
1902// ---------------------------------------------------------------------------
1903
1904use std::collections::HashMap;
1905use std::sync::Mutex;
1906
1907/// Pluggable storage backend for sessions. The default is in-memory; apps
1908/// deploying for real should supply a persistent backend (e.g. SQLite or
1909/// Redis) so users don't log out on server restart.
1910pub trait SessionBackend: Send + Sync {
1911    fn load_all(&self) -> Vec<Session>;
1912    fn save(&self, session: &Session);
1913    fn remove(&self, token: &str);
1914}
1915
1916/// A session store. In-memory by default; optionally backed by a
1917/// persistent [`SessionBackend`].
1918///
1919/// The in-memory map is always authoritative — reads don't touch the
1920/// backend. The backend receives every `save`/`remove`, making it a
1921/// write-through cache. On construction via [`SessionStore::with_backend`],
1922/// the store hydrates from the backend so sessions survive restart.
1923pub struct SessionStore {
1924    sessions: Mutex<HashMap<String, Session>>,
1925    backend: Option<Box<dyn SessionBackend>>,
1926    /// Default lifetime for new sessions (seconds). Sourced from the
1927    /// manifest's `auth.session.expires_in` config at server boot;
1928    /// falls back to `Session::DEFAULT_LIFETIME_SECS` (30 days).
1929    default_lifetime_secs: u64,
1930}
1931
1932impl Default for SessionStore {
1933    fn default() -> Self {
1934        Self::new()
1935    }
1936}
1937
1938impl SessionStore {
1939    pub fn new() -> Self {
1940        Self {
1941            sessions: Mutex::new(HashMap::new()),
1942            backend: None,
1943            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1944        }
1945    }
1946
1947    /// Override the default session lifetime. Used by `pylon-runtime`'s
1948    /// server bootstrap to apply the manifest's `auth.session.expires_in`.
1949    pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
1950        self.default_lifetime_secs = lifetime_secs;
1951        self
1952    }
1953
1954    /// Build a session store backed by a persistent store. Existing sessions
1955    /// are loaded from the backend on construction; every future mutation
1956    /// writes through.
1957    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1958        let mut map = HashMap::new();
1959        for s in backend.load_all() {
1960            if !s.is_expired() {
1961                map.insert(s.token.clone(), s);
1962            }
1963        }
1964        Self {
1965            sessions: Mutex::new(map),
1966            backend: Some(backend),
1967            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1968        }
1969    }
1970
1971    /// Create a session for a user and return it. Uses the store's
1972    /// configured `default_lifetime_secs` (from the manifest's
1973    /// `auth.session.expires_in`, default 30 days).
1974    pub fn create(&self, user_id: String) -> Session {
1975        self.create_with_device(user_id, None)
1976    }
1977
1978    /// Create a session with an attached device label. The label is
1979    /// what `/api/auth/sessions` shows to the user — typically the
1980    /// parsed User-Agent (see [`crate::device::parse_user_agent`]).
1981    /// Pass `None` (or use `create()`) for non-browser flows where
1982    /// no UA is available.
1983    pub fn create_with_device(&self, user_id: String, device: Option<String>) -> Session {
1984        let mut session = Session::with_lifetime(user_id, self.default_lifetime_secs);
1985        session.device = device;
1986        let mut sessions = self.sessions.lock().unwrap();
1987        sessions.insert(session.token.clone(), session.clone());
1988        if let Some(b) = &self.backend {
1989            b.save(&session);
1990        }
1991        session
1992    }
1993
1994    /// Look up a session by token. Returns None if the session is expired.
1995    pub fn get(&self, token: &str) -> Option<Session> {
1996        let mut sessions = self.sessions.lock().unwrap();
1997        match sessions.get(token) {
1998            Some(s) if s.is_expired() => {
1999                sessions.remove(token);
2000                None
2001            }
2002            Some(s) => Some(s.clone()),
2003            None => None,
2004        }
2005    }
2006
2007    /// Resolve a token to an auth context.
2008    /// Returns anonymous context if the token is invalid, missing, or expired.
2009    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
2010        match token {
2011            Some(t) => match self.get(t) {
2012                Some(session) => session.to_auth_context(),
2013                None => AuthContext::anonymous(),
2014            },
2015            None => AuthContext::anonymous(),
2016        }
2017    }
2018
2019    /// Refresh a session — issues a new token, copies user/device, extends expiry.
2020    /// The old token is revoked. Returns the new session or None if the old
2021    /// token is missing/expired.
2022    pub fn refresh(&self, old_token: &str) -> Option<Session> {
2023        let mut sessions = self.sessions.lock().unwrap();
2024        let old = sessions.remove(old_token)?;
2025        if let Some(b) = &self.backend {
2026            b.remove(old_token);
2027        }
2028        if old.is_expired() {
2029            return None;
2030        }
2031        // Use the store's configured lifetime so a manifest-set
2032        // `auth.session.expires_in` survives session refresh. Previous
2033        // bug: `Session::new(...)` baked in 30 days regardless of
2034        // config — apps with a custom lifetime got the right value on
2035        // first sign-in and lost it on the next refresh.
2036        let mut new = Session::with_lifetime(old.user_id.clone(), self.default_lifetime_secs);
2037        new.device = old.device.clone();
2038        sessions.insert(new.token.clone(), new.clone());
2039        if let Some(b) = &self.backend {
2040            b.save(&new);
2041        }
2042        Some(new)
2043    }
2044
2045    /// Every session in the store, including expired ones, with no
2046    /// filtering. Powers the Studio "Auth tables" view so operators
2047    /// can see orphaned sessions / debug stuck logins. Don't use for
2048    /// app code — `list_for_user` is the right surface there.
2049    pub fn list_all_unfiltered(&self) -> Vec<Session> {
2050        self.sessions
2051            .lock()
2052            .map(|m| m.values().cloned().collect())
2053            .unwrap_or_default()
2054    }
2055
2056    /// List all active sessions for a user.
2057    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
2058        let sessions = self.sessions.lock().unwrap();
2059        sessions
2060            .values()
2061            .filter(|s| s.user_id == user_id && !s.is_expired())
2062            .cloned()
2063            .collect()
2064    }
2065
2066    /// Revoke all sessions for a user. Returns the count removed.
2067    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
2068        let mut sessions = self.sessions.lock().unwrap();
2069        let tokens: Vec<String> = sessions
2070            .iter()
2071            .filter_map(|(t, s)| {
2072                if s.user_id == user_id {
2073                    Some(t.clone())
2074                } else {
2075                    None
2076                }
2077            })
2078            .collect();
2079        let n = tokens.len();
2080        for t in &tokens {
2081            sessions.remove(t);
2082            if let Some(b) = &self.backend {
2083                b.remove(t);
2084            }
2085        }
2086        n
2087    }
2088
2089    /// Sweep expired sessions. Returns the count removed.
2090    pub fn sweep_expired(&self) -> usize {
2091        let mut sessions = self.sessions.lock().unwrap();
2092        let expired: Vec<String> = sessions
2093            .iter()
2094            .filter_map(|(t, s)| {
2095                if s.is_expired() {
2096                    Some(t.clone())
2097                } else {
2098                    None
2099                }
2100            })
2101            .collect();
2102        let n = expired.len();
2103        for t in &expired {
2104            sessions.remove(t);
2105            if let Some(b) = &self.backend {
2106                b.remove(t);
2107            }
2108        }
2109        n
2110    }
2111
2112    /// Attach a device label to a session (typically on login from a browser).
2113    pub fn set_device(&self, token: &str, device: String) -> bool {
2114        let mut sessions = self.sessions.lock().unwrap();
2115        if let Some(s) = sessions.get_mut(token) {
2116            s.device = Some(device);
2117            if let Some(b) = &self.backend {
2118                b.save(s);
2119            }
2120            true
2121        } else {
2122            false
2123        }
2124    }
2125
2126    /// Create a guest session with a generated anonymous ID.
2127    pub fn create_guest(&self) -> Session {
2128        use rand::Rng;
2129        let mut rng = rand::thread_rng();
2130        let bytes: [u8; 16] = rng.gen();
2131        let guest_id = format!("guest_{}", hex_encode(&bytes));
2132        self.create(guest_id)
2133    }
2134
2135    /// Upgrade a guest session to a real user. Replaces the user_id.
2136    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
2137        let mut sessions = self.sessions.lock().unwrap();
2138        if let Some(session) = sessions.get_mut(token) {
2139            session.user_id = real_user_id;
2140            if let Some(b) = &self.backend {
2141                b.save(session);
2142            }
2143            true
2144        } else {
2145            false
2146        }
2147    }
2148
2149    /// Switch the session's active tenant (organization). `None` clears it.
2150    /// Callers should verify the user actually has membership in the target
2151    /// tenant BEFORE invoking this — the session store takes the value on
2152    /// trust. Returns true if the session exists, false otherwise.
2153    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
2154        let mut sessions = self.sessions.lock().unwrap();
2155        if let Some(session) = sessions.get_mut(token) {
2156            session.tenant_id = tenant_id;
2157            if let Some(b) = &self.backend {
2158                b.save(session);
2159            }
2160            true
2161        } else {
2162            false
2163        }
2164    }
2165
2166    /// Remove a session.
2167    pub fn revoke(&self, token: &str) -> bool {
2168        let mut sessions = self.sessions.lock().unwrap();
2169        let removed = sessions.remove(token).is_some();
2170        if removed {
2171            if let Some(b) = &self.backend {
2172                b.remove(token);
2173            }
2174        }
2175        removed
2176    }
2177}
2178
2179// ---------------------------------------------------------------------------
2180// OAuth account links — better-auth's `account` table equivalent
2181// ---------------------------------------------------------------------------
2182
/// A persisted account link. Schema-aligned with better-auth's `account`
/// table (verified against https://www.better-auth.com/docs/concepts/database
/// at the time of writing), so users migrating from better-auth see the
/// same field names + meanings:
///
/// - `provider_id` — the provider's name (`"google"`, `"github"`,
///   `"credential"` for email/password). Matches better-auth's `providerId`.
/// - `account_id` — the PROVIDER'S ID for the user (Google `sub`,
///   GitHub numeric `id`, or for email/password the user's own id).
///   Matches better-auth's `accountId`. NOT the row PK.
/// - `id` — the row PK, generated. Lets the row be referenced
///   independently of the (provider_id, account_id) natural key.
/// - `password` — bcrypt/argon2 hash for `provider_id="credential"`
///   rows; `None` for OAuth links.
///
/// Account vs. user: a single User row can have many Account rows
/// (Google + GitHub + a password — all linked to one pylon user).
/// Provider lookup is by `(provider_id, account_id)` — NOT email — so
/// a user changing their Google address keeps the same pylon account.
///
/// `Debug` is hand-written: the three tokens and the password hash print as
/// `"<redacted>"` (or `None`), so `{:?}` in logs, panics or `assert_eq!`
/// failures never leaks credential material.
#[derive(Clone, PartialEq, Eq)]
pub struct Account {
    pub id: String,
    pub user_id: String,
    /// Provider name — `"google"`, `"github"`, `"credential"`, etc.
    /// (better-auth: `providerId`)
    pub provider_id: String,
    /// Provider's id for the user — Google `sub`, GitHub numeric `id`,
    /// or for `provider_id="credential"` the user's own id. (better-auth: `accountId`)
    pub account_id: String,
    pub access_token: Option<String>,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
    /// Unix epoch seconds at which `access_token` expires. `None` for
    /// non-expiring tokens (GitHub Classic apps) or for password rows.
    pub access_token_expires_at: Option<u64>,
    /// Unix epoch seconds at which `refresh_token` expires. `None` when
    /// the provider doesn't expire refresh tokens (most don't, but
    /// Microsoft Identity Platform does after 90 days of inactivity).
    pub refresh_token_expires_at: Option<u64>,
    pub scope: Option<String>,
    /// Bcrypt/argon2 hash for email/password rows. `None` for OAuth links;
    /// `Account::new` (the OAuth path) always leaves it `None`.
    pub password: Option<String>,
    /// Unix epoch seconds when this account was first linked.
    pub created_at: u64,
    /// Unix epoch seconds when the token bundle was last refreshed.
    pub updated_at: u64,
}

impl std::fmt::Debug for Account {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Keep presence visible (`Some`/`None`) while hiding the secret.
        fn redacted(secret: Option<&str>) -> Option<&'static str> {
            secret.map(|_| "<redacted>")
        }
        f.debug_struct("Account")
            .field("id", &self.id)
            .field("user_id", &self.user_id)
            .field("provider_id", &self.provider_id)
            .field("account_id", &self.account_id)
            .field("access_token", &redacted(self.access_token.as_deref()))
            .field("refresh_token", &redacted(self.refresh_token.as_deref()))
            .field("id_token", &redacted(self.id_token.as_deref()))
            .field("access_token_expires_at", &self.access_token_expires_at)
            .field("refresh_token_expires_at", &self.refresh_token_expires_at)
            .field("scope", &self.scope)
            .field("password", &redacted(self.password.as_deref()))
            .field("created_at", &self.created_at)
            .field("updated_at", &self.updated_at)
            .finish()
    }
}
2234
2235impl Account {
2236    /// Build a new account link from a freshly-completed OAuth handshake.
2237    /// Generates a fresh row id; the `(provider_id, account_id)` pair is
2238    /// what later lookups key on.
2239    pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
2240        let now = now_secs();
2241        Self {
2242            id: generate_token(),
2243            user_id,
2244            provider_id: info.provider.clone(),
2245            account_id: info.provider_account_id.clone(),
2246            access_token: Some(tokens.access_token.clone()),
2247            refresh_token: tokens.refresh_token.clone(),
2248            id_token: tokens.id_token.clone(),
2249            access_token_expires_at: tokens.expires_at,
2250            refresh_token_expires_at: None,
2251            scope: tokens.scope.clone(),
2252            password: None,
2253            created_at: now,
2254            updated_at: now,
2255        }
2256    }
2257
2258    /// True if `access_token_expires_at` is set and has passed.
2259    /// Non-expiring tokens (GitHub Classic) report `false` — caller
2260    /// should treat them as "valid until proven otherwise" and refresh
2261    /// on 401.
2262    pub fn access_token_expired(&self) -> bool {
2263        match self.access_token_expires_at {
2264            Some(ts) => now_secs() >= ts,
2265            None => false,
2266        }
2267    }
2268}
2269
2270/// Pluggable storage for account links. In-memory default ships with
2271/// the crate; SQLite + Postgres impls live in `pylon-runtime`.
2272pub trait AccountBackend: Send + Sync {
2273    /// Insert or refresh an account link. The `(provider_id, account_id)`
2274    /// pair is the natural key — repeated calls for the same pair
2275    /// update the token bundle and `updated_at` on the existing row.
2276    fn upsert(&self, account: &Account);
2277    /// Find an account by provider identity. Returns `None` if the user
2278    /// hasn't linked this provider yet.
2279    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
2280    /// Every account linked to a user. The `/api/auth/me` endpoint uses
2281    /// this to render "you're connected via Google + GitHub" in the UI
2282    /// and to gate "unlink" affordances behind "user has another way to
2283    /// sign in" checks.
2284    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
2285    /// Remove a single provider link. Returns `true` if a row was removed.
2286    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
2287    /// Remove every account link for a user. Used during account
2288    /// deletion to ensure no OAuth references survive past a user row
2289    /// delete. Default implementation walks `find_for_user` + `unlink`;
2290    /// SQL backends can override with a single DELETE.
2291    fn delete_for_user(&self, user_id: &str) -> usize {
2292        let accounts = self.find_for_user(user_id);
2293        let n = accounts.len();
2294        for a in accounts {
2295            self.unlink(&a.provider_id, &a.account_id);
2296        }
2297        n
2298    }
2299    /// Every account in the store. Used by `AccountStore::list_all_unfiltered`
2300    /// to power the Studio admin inspector. Backends that can stream
2301    /// (SQLite, Postgres) just `SELECT *`; the in-memory backend
2302    /// returns its full map.
2303    fn list_all(&self) -> Vec<Account>;
2304}
2305
2306/// In-memory account backend (default). Lost on restart — production
2307/// deployments should swap in a persistent backend so refresh tokens
2308/// survive a redeploy.
2309pub struct InMemoryAccountBackend {
2310    /// Keyed by `(provider_id, account_id)`. A separate map keyed on
2311    /// user_id would speed up `find_for_user` but at framework scale
2312    /// the linear scan of (typically ≤ 5) accounts per user is fine.
2313    accounts: Mutex<HashMap<(String, String), Account>>,
2314}
2315
2316impl InMemoryAccountBackend {
2317    pub fn new() -> Self {
2318        Self {
2319            accounts: Mutex::new(HashMap::new()),
2320        }
2321    }
2322}
2323
2324impl Default for InMemoryAccountBackend {
2325    fn default() -> Self {
2326        Self::new()
2327    }
2328}
2329
2330impl AccountBackend for InMemoryAccountBackend {
2331    fn upsert(&self, account: &Account) {
2332        let key = (account.provider_id.clone(), account.account_id.clone());
2333        self.accounts.lock().unwrap().insert(key, account.clone());
2334    }
2335    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2336        self.accounts
2337            .lock()
2338            .unwrap()
2339            .get(&(provider_id.to_string(), account_id.to_string()))
2340            .cloned()
2341    }
2342    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2343        self.accounts
2344            .lock()
2345            .unwrap()
2346            .values()
2347            .filter(|a| a.user_id == user_id)
2348            .cloned()
2349            .collect()
2350    }
2351    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2352        self.accounts
2353            .lock()
2354            .unwrap()
2355            .remove(&(provider_id.to_string(), account_id.to_string()))
2356            .is_some()
2357    }
2358    fn list_all(&self) -> Vec<Account> {
2359        self.accounts.lock().unwrap().values().cloned().collect()
2360    }
2361}
2362
2363/// Account store. Wraps an `AccountBackend` and provides the methods the
2364/// OAuth callback / API endpoints actually call.
2365pub struct AccountStore {
2366    backend: Box<dyn AccountBackend>,
2367}
2368
2369impl Default for AccountStore {
2370    fn default() -> Self {
2371        Self::new()
2372    }
2373}
2374
2375impl AccountStore {
2376    pub fn new() -> Self {
2377        Self {
2378            backend: Box::new(InMemoryAccountBackend::new()),
2379        }
2380    }
2381    pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
2382        Self { backend }
2383    }
2384    pub fn upsert(&self, account: &Account) {
2385        self.backend.upsert(account);
2386    }
2387    pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2388        self.backend.find_by_provider(provider_id, account_id)
2389    }
2390    pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2391        self.backend.find_for_user(user_id)
2392    }
2393    pub fn delete_for_user(&self, user_id: &str) -> usize {
2394        self.backend.delete_for_user(user_id)
2395    }
2396
2397    pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2398        self.backend.unlink(provider_id, account_id)
2399    }
2400
2401    /// Every account in the store. Powers the Studio "Auth tables"
2402    /// view; not for app use. Implemented by walking the backend's
2403    /// per-user index — doable because account counts per user are
2404    /// small (typically ≤ 5) and total account count tracks user
2405    /// count.
2406    ///
2407    /// We don't add a `list_all` method to the `AccountBackend` trait
2408    /// because the in-memory + sqlite + postgres impls would each
2409    /// need a separate implementation, and the operational use case
2410    /// (Studio inspector) is narrow enough to live behind a wrapper
2411    /// that walks the underlying store directly. For PG/SQLite that
2412    /// means a `SELECT * FROM _pylon_accounts` — which the backends
2413    /// can grow if we ever need this at scale.
2414    pub fn list_all_unfiltered(&self) -> Vec<Account> {
2415        self.backend.list_all()
2416    }
2417}
2418
2419// ---------------------------------------------------------------------------
2420// Tests
2421// ---------------------------------------------------------------------------
2422
2423#[cfg(test)]
2424mod tests {
2425    use super::*;
2426
2427    #[test]
2428    fn anonymous_context() {
2429        let ctx = AuthContext::anonymous();
2430        assert!(!ctx.is_authenticated());
2431        assert!(ctx.user_id.is_none());
2432    }
2433
2434    #[test]
2435    fn authenticated_context() {
2436        let ctx = AuthContext::authenticated("user-1".into());
2437        assert!(ctx.is_authenticated());
2438        assert_eq!(ctx.user_id, Some("user-1".into()));
2439    }
2440
2441    #[test]
2442    fn from_api_key_carries_scope_metadata() {
2443        let ctx =
2444            AuthContext::from_api_key("user-1".into(), "key_abc".into(), Some("read,write".into()));
2445        assert!(ctx.is_authenticated());
2446        assert!(ctx.is_api_key_auth());
2447        assert_eq!(ctx.user_id.as_deref(), Some("user-1"));
2448        assert_eq!(ctx.api_key_id.as_deref(), Some("key_abc"));
2449        assert_eq!(ctx.api_key_scopes.as_deref(), Some("read,write"));
2450    }
2451
2452    #[test]
2453    fn session_auth_is_not_api_key_auth() {
2454        let ctx = AuthContext::authenticated("user-1".into());
2455        assert!(!ctx.is_api_key_auth());
2456        assert!(ctx.api_key_id.is_none());
2457    }
2458
2459    #[test]
2460    fn auth_mode_public_allows_anonymous() {
2461        let mode = AuthMode::Public;
2462        assert!(mode.check(&AuthContext::anonymous()));
2463        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2464    }
2465
2466    #[test]
2467    fn auth_mode_user_requires_authenticated() {
2468        let mode = AuthMode::User;
2469        assert!(!mode.check(&AuthContext::anonymous()));
2470        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2471    }
2472
2473    #[test]
2474    fn auth_mode_from_str() {
2475        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
2476        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
2477        assert_eq!(AuthMode::from_str("admin"), None);
2478    }
2479
2480    #[test]
2481    fn session_store_create_and_get() {
2482        let store = SessionStore::new();
2483        let session = store.create("user-1".into());
2484        assert!(!session.token.is_empty());
2485        assert!(session.token.starts_with("pylon_"));
2486
2487        let retrieved = store.get(&session.token).unwrap();
2488        assert_eq!(retrieved.user_id, "user-1");
2489    }
2490
2491    #[test]
2492    fn session_store_resolve() {
2493        let store = SessionStore::new();
2494        let session = store.create("user-1".into());
2495
2496        let ctx = store.resolve(Some(&session.token));
2497        assert!(ctx.is_authenticated());
2498        assert_eq!(ctx.user_id, Some("user-1".into()));
2499
2500        let anon = store.resolve(None);
2501        assert!(!anon.is_authenticated());
2502
2503        let bad = store.resolve(Some("invalid-token"));
2504        assert!(!bad.is_authenticated());
2505    }
2506
2507    #[test]
2508    fn session_store_revoke() {
2509        let store = SessionStore::new();
2510        let session = store.create("user-1".into());
2511
2512        assert!(store.revoke(&session.token));
2513        assert!(store.get(&session.token).is_none());
2514        assert!(!store.revoke(&session.token)); // already revoked
2515    }
2516
2517    #[test]
2518    fn session_to_auth_context() {
2519        let session = Session::new("user-42".into());
2520        let ctx = session.to_auth_context();
2521        assert_eq!(ctx.user_id, Some("user-42".into()));
2522    }
2523
2524    // -- Admin context --
2525
2526    #[test]
2527    fn admin_context() {
2528        let ctx = AuthContext::admin();
2529        assert!(ctx.is_admin);
2530        assert!(ctx.is_authenticated());
2531    }
2532
2533    #[test]
2534    fn anonymous_not_admin() {
2535        let ctx = AuthContext::anonymous();
2536        assert!(!ctx.is_admin);
2537    }
2538
2539    #[test]
2540    fn authenticated_not_admin() {
2541        let ctx = AuthContext::authenticated("user-1".into());
2542        assert!(!ctx.is_admin);
2543    }
2544
2545    // -- Magic codes --
2546
2547    #[test]
2548    fn magic_code_create_and_verify() {
2549        let store = MagicCodeStore::new();
2550        let code = store.create("test@example.com");
2551        assert_eq!(code.len(), 6);
2552        assert!(store.verify("test@example.com", &code));
2553    }
2554
2555    #[test]
2556    fn magic_code_wrong_code_rejected() {
2557        let store = MagicCodeStore::new();
2558        store.create("test@example.com");
2559        assert!(!store.verify("test@example.com", "000000"));
2560    }
2561
2562    #[test]
2563    fn magic_code_wrong_email_rejected() {
2564        let store = MagicCodeStore::new();
2565        let code = store.create("test@example.com");
2566        assert!(!store.verify("other@example.com", &code));
2567    }
2568
2569    #[test]
2570    fn magic_code_consumed_after_verify() {
2571        let store = MagicCodeStore::new();
2572        let code = store.create("test@example.com");
2573        assert!(store.verify("test@example.com", &code));
2574        // Second verify should fail — code consumed.
2575        assert!(!store.verify("test@example.com", &code));
2576    }
2577
2578    #[test]
2579    fn magic_code_different_emails_independent() {
2580        let store = MagicCodeStore::new();
2581        let code1 = store.create("alice@example.com");
2582        let code2 = store.create("bob@example.com");
2583        // Each email has its own code.
2584        assert!(store.verify("alice@example.com", &code1));
2585        assert!(store.verify("bob@example.com", &code2));
2586    }
2587
2588    // -- Constant-time comparison --
2589
2590    #[test]
2591    fn constant_time_eq_equal() {
2592        assert!(constant_time_eq(b"hello", b"hello"));
2593        assert!(constant_time_eq(b"", b""));
2594    }
2595
2596    #[test]
2597    fn constant_time_eq_not_equal() {
2598        assert!(!constant_time_eq(b"hello", b"world"));
2599        assert!(!constant_time_eq(b"hello", b"hell"));
2600        assert!(!constant_time_eq(b"a", b"b"));
2601    }
2602
2603    // -- Token generation --
2604
2605    #[test]
2606    fn generated_tokens_are_unique() {
2607        let t1 = generate_token();
2608        let t2 = generate_token();
2609        assert_ne!(t1, t2);
2610        assert!(t1.starts_with("pylon_"));
2611        assert!(t2.starts_with("pylon_"));
2612        // 256 bits = 64 hex chars + "pylon_" prefix (6 chars)
2613        assert_eq!(t1.len(), 6 + 64);
2614    }
2615
2616    // -- OAuth registry --
2617
2618    #[test]
2619    fn oauth_registry_empty() {
2620        let reg = OAuthRegistry::new();
2621        assert!(reg.get("google").is_none());
2622    }
2623
2624    #[test]
2625    fn oauth_registry_register_and_get() {
2626        let mut reg = OAuthRegistry::new();
2627        reg.register(OAuthConfig {
2628            provider: "google".into(),
2629            client_id: "test-id".into(),
2630            client_secret: "test-secret".into(),
2631            redirect_uri: "http://localhost/callback".into(),
2632            ..Default::default()
2633        });
2634        let config = reg.get("google").unwrap();
2635        assert_eq!(config.client_id, "test-id");
2636        assert!(config.auth_url().contains("accounts.google.com"));
2637    }
2638
2639    // -- Spec-driven provider routing --
2640
2641    /// Every builtin provider must produce a non-empty auth_url +
2642    /// token_url when wired with placeholder credentials. This is the
2643    /// regression test for the table-driven refactor: a typo in any
2644    /// `ProviderSpec` field that breaks URL formatting will trip here
2645    /// before it reaches a user.
2646    #[test]
2647    fn every_builtin_provider_routes_through_oauth_config() {
2648        for spec in provider::builtin::all() {
2649            let cfg = OAuthConfig {
2650                provider: spec.id.into(),
2651                client_id: "cid".into(),
2652                client_secret: "csecret".into(),
2653                redirect_uri: "https://app/cb".into(),
2654                tenant: if spec.id == "microsoft" {
2655                    Some("contoso".into())
2656                } else {
2657                    None
2658                },
2659                apple: if spec.id == "apple" {
2660                    Some(provider::AppleConfig {
2661                        team_id: "T".into(),
2662                        key_id: "K".into(),
2663                        private_key_pem: "no".into(),
2664                    })
2665                } else {
2666                    None
2667                },
2668                ..Default::default()
2669            };
2670            let auth = cfg.auth_url();
2671            assert!(!auth.is_empty(), "{}: empty auth_url", spec.id);
2672            // TikTok uses `client_key`; everyone else uses `client_id`.
2673            let expected_param = format!("{}=cid", spec.client_id_param);
2674            assert!(
2675                auth.contains(&expected_param),
2676                "{}: missing {}; got auth_url: {}",
2677                spec.id,
2678                expected_param,
2679                auth,
2680            );
2681            assert!(!cfg.token_url().is_empty(), "{}: empty token_url", spec.id);
2682            // Apple requires response_mode=form_post in the auth URL.
2683            if spec.id == "apple" {
2684                assert!(
2685                    auth.contains("response_mode=form_post"),
2686                    "apple auth_url must include response_mode=form_post; got {auth}"
2687                );
2688            }
2689        }
2690    }
2691
2692    /// Microsoft uses `{tenant}` placeholder substitution — the
2693    /// configured tenant must end up in both auth + token URLs.
2694    #[test]
2695    fn microsoft_tenant_placeholder_resolves() {
2696        let cfg = OAuthConfig {
2697            provider: "microsoft".into(),
2698            client_id: "id".into(),
2699            client_secret: "secret".into(),
2700            redirect_uri: "https://app/cb".into(),
2701            tenant: Some("contoso.onmicrosoft.com".into()),
2702            ..Default::default()
2703        };
2704        assert!(cfg.auth_url().contains("/contoso.onmicrosoft.com/"));
2705        assert!(cfg.token_url().contains("/contoso.onmicrosoft.com/"));
2706    }
2707
2708    /// Microsoft without a tenant defaults to `common` (any account).
2709    #[test]
2710    fn microsoft_default_tenant_common() {
2711        let cfg = OAuthConfig {
2712            provider: "microsoft".into(),
2713            client_id: "id".into(),
2714            client_secret: "secret".into(),
2715            redirect_uri: "https://app/cb".into(),
2716            ..Default::default()
2717        };
2718        assert!(cfg.auth_url().contains("/common/"));
2719        assert!(cfg.token_url().contains("/common/"));
2720    }
2721
2722    /// `scopes_override` replaces the spec default — used for GitHub
2723    /// `repo` scope or Google calendar scopes.
2724    #[test]
2725    fn scopes_override_replaces_spec_default() {
2726        let cfg = OAuthConfig {
2727            provider: "github".into(),
2728            client_id: "id".into(),
2729            client_secret: "secret".into(),
2730            redirect_uri: "https://app/cb".into(),
2731            scopes_override: Some("repo user:email".into()),
2732            ..Default::default()
2733        };
2734        let auth = cfg.auth_url();
2735        // url-encoded "repo user:email" → "repo%20user%3Aemail"
2736        assert!(auth.contains("scope=repo%20user%3Aemail"), "got: {auth}");
2737    }
2738
2739    /// Apple's `client_secret` is minted as a JWT — passing a bad PEM
2740    /// must surface the signing error, not silently send the literal
2741    /// string. The mint path is tested in `apple_jwt::tests`; this
2742    /// asserts the wiring delegates to it.
2743    #[test]
2744    fn apple_exchange_requires_apple_config() {
2745        let cfg = OAuthConfig {
2746            provider: "apple".into(),
2747            client_id: "com.example.app".into(),
2748            client_secret: String::new(),
2749            redirect_uri: "https://app/cb".into(),
2750            apple: None, // missing!
2751            ..Default::default()
2752        };
2753        let err = cfg.exchange_code_full("x").unwrap_err();
2754        assert!(err.contains("apple provider requires"), "got: {err}");
2755    }
2756
2757    /// OIDC discovery cache: priming with a synthetic spec lets us
2758    /// route an issuer-configured provider without touching the
2759    /// network. Validates that `oidc_issuer` short-circuits the
2760    /// builtin lookup.
2761    #[test]
2762    fn oidc_issuer_uses_discovered_endpoints() {
2763        let issuer = "https://acme.test.invalid";
2764        provider::oidc_cache::insert_for_test(
2765            issuer,
2766            provider::DiscoveredSpec {
2767                auth_url: "https://acme.test.invalid/authorize".into(),
2768                token_url: "https://acme.test.invalid/oauth/token".into(),
2769                userinfo_url: Some("https://acme.test.invalid/userinfo".into()),
2770                scopes: "openid email profile".into(),
2771                userinfo_parser: provider::UserinfoParser::Oidc,
2772                token_exchange: provider::TokenExchangeShape::Standard,
2773            },
2774        );
2775        let cfg = OAuthConfig {
2776            provider: "auth0".into(), // not a builtin id
2777            client_id: "id".into(),
2778            client_secret: "secret".into(),
2779            redirect_uri: "https://app/cb".into(),
2780            oidc_issuer: Some(issuer.into()),
2781            ..Default::default()
2782        };
2783        assert!(cfg
2784            .auth_url()
2785            .starts_with("https://acme.test.invalid/authorize?"));
2786        assert_eq!(cfg.token_url(), "https://acme.test.invalid/oauth/token");
2787        assert_eq!(cfg.userinfo_url(), "https://acme.test.invalid/userinfo");
2788    }
2789
2790    // -- Codex review regression tests (P1/P2 from Wave 1 review) --
2791
2792    /// P1: Apple's auth URL MUST include response_mode=form_post when
2793    /// requesting name/email scopes, otherwise Apple rejects with
2794    /// "invalid_request".
2795    #[test]
2796    fn apple_auth_url_includes_form_post() {
2797        let cfg = OAuthConfig {
2798            provider: "apple".into(),
2799            client_id: "com.example.app".into(),
2800            client_secret: String::new(),
2801            redirect_uri: "https://app/cb".into(),
2802            apple: Some(provider::AppleConfig {
2803                team_id: "T".into(),
2804                key_id: "K".into(),
2805                private_key_pem: "no".into(),
2806            }),
2807            ..Default::default()
2808        };
2809        let auth = cfg.auth_url();
2810        assert!(auth.contains("response_mode=form_post"), "got: {auth}");
2811        // Apple identity comes from id_token, so userinfo_url is empty.
2812        assert_eq!(cfg.userinfo_url(), "");
2813    }
2814
2815    /// P1: Apple identity is extracted from the id_token JWT
2816    /// (Apple has no userinfo endpoint). `fetch_userinfo_with_id_token`
2817    /// must decode the claims; `fetch_userinfo_full` (no id_token)
2818    /// must surface a clear error.
2819    #[test]
2820    fn apple_id_token_decode_extracts_identity() {
2821        // Synthesize an unsigned JWT with realistic Apple claims.
2822        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"{\"alg\":\"none\"}");
2823        use base64::Engine;
2824        let claims = serde_json::json!({
2825            "iss": "https://appleid.apple.com",
2826            "sub": "001234.abc.def",
2827            "aud": "com.example.app",
2828            "email": "user@privaterelay.appleid.com",
2829            "email_verified": "true",
2830        });
2831        let claims_b64 =
2832            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims.to_string().as_bytes());
2833        let id_token = format!("{header}.{claims_b64}.signature_ignored");
2834
2835        let cfg = OAuthConfig {
2836            provider: "apple".into(),
2837            client_id: "com.example.app".into(),
2838            client_secret: String::new(),
2839            redirect_uri: "https://app/cb".into(),
2840            apple: Some(provider::AppleConfig {
2841                team_id: "T".into(),
2842                key_id: "K".into(),
2843                private_key_pem: "no".into(),
2844            }),
2845            ..Default::default()
2846        };
2847        let info = cfg
2848            .fetch_userinfo_with_id_token("ignored", Some(&id_token))
2849            .expect("apple id_token decode");
2850        assert_eq!(info.provider_account_id, "001234.abc.def");
2851        assert_eq!(info.email, "user@privaterelay.appleid.com");
2852
2853        // Without an id_token the call must fail loud, not silently
2854        // try to hit a non-existent userinfo endpoint.
2855        let err = cfg.fetch_userinfo_full("token").unwrap_err();
2856        assert!(err.contains("apple login requires"), "got: {err}");
2857    }
2858
2859    /// P1: Twitter/X requires PKCE — `auth_url_with_pkce` must mint a
2860    /// verifier, embed the SHA-256 challenge in the auth URL, and
2861    /// return the verifier for the callback to replay.
2862    #[test]
2863    fn twitter_auth_url_includes_pkce() {
2864        let cfg = OAuthConfig {
2865            provider: "twitter".into(),
2866            client_id: "tw_client".into(),
2867            client_secret: "tw_secret".into(),
2868            redirect_uri: "https://app/cb".into(),
2869            ..Default::default()
2870        };
2871        let (url, verifier) = cfg.auth_url_with_pkce("state123").expect("twitter pkce");
2872        let v = verifier.expect("twitter must produce verifier");
2873        assert!(v.len() >= 43, "PKCE verifier must be 43+ chars: got {v}");
2874        assert!(url.contains("code_challenge="), "got: {url}");
2875        assert!(url.contains("code_challenge_method=S256"), "got: {url}");
2876
2877        // Non-PKCE provider must NOT add a code_challenge.
2878        let google = OAuthConfig {
2879            provider: "google".into(),
2880            client_id: "g".into(),
2881            client_secret: "g".into(),
2882            redirect_uri: "https://app/cb".into(),
2883            ..Default::default()
2884        };
2885        let (gurl, gverifier) = google.auth_url_with_pkce("st").expect("google");
2886        assert!(gverifier.is_none(), "google should not add PKCE");
2887        assert!(!gurl.contains("code_challenge"), "got: {gurl}");
2888    }
2889
2890    /// P2: TikTok uses `client_key` (not `client_id`) and joins
2891    /// scopes with commas (not spaces).
2892    #[test]
2893    fn tiktok_uses_client_key_and_comma_scopes() {
2894        let cfg = OAuthConfig {
2895            provider: "tiktok".into(),
2896            client_id: "tk_client".into(),
2897            client_secret: "tk_secret".into(),
2898            redirect_uri: "https://app/cb".into(),
2899            scopes_override: Some("user.info.basic video.list".into()),
2900            ..Default::default()
2901        };
2902        let auth = cfg.auth_url();
2903        assert!(auth.contains("client_key=tk_client"), "got: {auth}");
2904        // Comma-separated, url-encoded → "user.info.basic%2Cvideo.list"
2905        assert!(auth.contains("user.info.basic%2Cvideo.list"), "got: {auth}");
2906        // Should NOT use the standard space separator.
2907        assert!(
2908            !auth.contains("user.info.basic%20video.list"),
2909            "got: {auth}"
2910        );
2911    }
2912
2913    /// P2: `code` MUST be url-encoded in the token-exchange body.
2914    /// Auth codes can contain reserved characters (`+`, `=`, `/`) that
2915    /// would otherwise corrupt the form body.
2916    #[test]
2917    fn token_exchange_url_encodes_code() {
2918        // We can't hit the network in a unit test, so this asserts
2919        // via the `apple_exchange_requires_apple_config` shape — if
2920        // we DID have a working apple config, encoding would happen
2921        // before the network call. Instead, verify by calling the
2922        // helper used internally:
2923        let raw = "code+with/special=chars";
2924        let encoded = url_encode(raw);
2925        assert!(!encoded.contains('+'));
2926        assert!(!encoded.contains('/'));
2927        assert!(!encoded.contains('='));
2928        assert!(encoded.contains("%2B"));
2929        assert!(encoded.contains("%2F"));
2930        assert!(encoded.contains("%3D"));
2931    }
2932
2933    /// P1: Token-endpoint error bodies must NOT propagate
2934    /// `client_secret`, `code_verifier`, or other sensitive form
2935    /// fields that providers sometimes echo back on auth failure.
2936    #[test]
2937    fn sanitize_token_error_redacts_secrets() {
2938        let raw = "HTTP 400: error=invalid_grant&client_secret=sk_real_secret_value&code_verifier=verifierxyz&hint=check%20your%20code";
2939        let scrubbed = sanitize_token_error(raw.into());
2940        assert!(!scrubbed.contains("sk_real_secret_value"));
2941        assert!(!scrubbed.contains("verifierxyz"));
2942        assert!(scrubbed.contains("client_secret=***"));
2943        assert!(scrubbed.contains("code_verifier=***"));
2944        // Non-sensitive context preserved.
2945        assert!(scrubbed.contains("invalid_grant"));
2946        assert!(scrubbed.contains("hint=check%20your%20code"));
2947    }
2948
2949    /// P1 (codex round-2): JSON-shaped error bodies (Notion,
2950    /// Atlassian) must also have their secret fields redacted.
2951    #[test]
2952    fn sanitize_token_error_redacts_json_secrets() {
2953        let raw = r#"HTTP 400: {"error":"invalid_grant","client_secret":"sk_jsonleak","refresh_token":"rt_abcxyz","id_token":"ey.payload.sig"}"#;
2954        let scrubbed = sanitize_token_error(raw.into());
2955        assert!(!scrubbed.contains("sk_jsonleak"), "got: {scrubbed}");
2956        assert!(!scrubbed.contains("rt_abcxyz"), "got: {scrubbed}");
2957        assert!(!scrubbed.contains("ey.payload.sig"), "got: {scrubbed}");
2958        assert!(
2959            scrubbed.contains(r#""client_secret":"***""#),
2960            "got: {scrubbed}"
2961        );
2962        assert!(
2963            scrubbed.contains(r#""refresh_token":"***""#),
2964            "got: {scrubbed}"
2965        );
2966        assert!(scrubbed.contains(r#""id_token":"***""#), "got: {scrubbed}");
2967        assert!(scrubbed.contains("invalid_grant"));
2968    }
2969
2970    /// P2 (codex round-2): redact_param_form must NOT panic on
2971    /// multibyte chars before the sensitive key. Earlier byte-index
2972    /// implementation hit `panicked at byte index N is not a char
2973    /// boundary` on bodies with emoji or non-ASCII text.
2974    #[test]
2975    fn sanitize_token_error_handles_utf8() {
2976        let raw = "HTTP 400: ⚠️ provider says the secret is wrong: client_secret=sk_x";
2977        let scrubbed = sanitize_token_error(raw.into());
2978        assert!(
2979            scrubbed.contains("⚠️"),
2980            "non-ASCII chars must survive: {scrubbed}"
2981        );
2982        assert!(!scrubbed.contains("sk_x"));
2983        assert!(scrubbed.contains("client_secret=***"));
2984    }
2985
2986    /// P2: OIDC discovery must respect
2987    /// `token_endpoint_auth_methods_supported`. When the IdP
2988    /// publishes `client_secret_post`, use Standard form bodies.
2989    /// When omitted (the spec default), use BasicAuth.
2990    #[test]
2991    fn oidc_discovery_picks_token_auth_method() {
2992        let json_post = r#"{
2993            "issuer": "https://acme.test/",
2994            "authorization_endpoint": "https://acme.test/auth",
2995            "token_endpoint": "https://acme.test/token",
2996            "token_endpoint_auth_methods_supported": ["client_secret_post"]
2997        }"#;
2998        let spec = provider::OidcDiscoveryDoc::parse(json_post)
2999            .unwrap()
3000            .into_spec();
3001        assert!(matches!(
3002            spec.token_exchange,
3003            provider::TokenExchangeShape::Standard
3004        ));
3005
3006        // Default (omitted) → BasicAuth.
3007        let json_default = r#"{
3008            "issuer": "https://acme.test/",
3009            "authorization_endpoint": "https://acme.test/auth",
3010            "token_endpoint": "https://acme.test/token"
3011        }"#;
3012        let spec = provider::OidcDiscoveryDoc::parse(json_default)
3013            .unwrap()
3014            .into_spec();
3015        assert!(matches!(
3016            spec.token_exchange,
3017            provider::TokenExchangeShape::BasicAuth
3018        ));
3019    }
3020
3021    /// P2: OIDC discovery missing required endpoints must fail loud,
3022    /// not silently produce empty URLs that would 404 every login.
3023    #[test]
3024    fn oidc_discovery_rejects_incomplete_doc() {
3025        // Missing token_endpoint.
3026        let json = r#"{
3027            "issuer": "https://acme.test/",
3028            "authorization_endpoint": "https://acme.test/auth"
3029        }"#;
3030        let err = provider::OidcDiscoveryDoc::parse(json).unwrap_err();
3031        assert!(err.contains("token_endpoint"), "got: {err}");
3032    }
3033
3034    /// `OAuthRegistry::from_env` must auto-discover every provider
3035    /// whose env vars are set — not just google/github. Smoke-test
3036    /// with Discord since it covers the simple-builtin path.
3037    #[test]
3038    fn from_env_picks_up_discord() {
3039        // Use a unique prefix so this doesn't collide with a real
3040        // dev environment variable. Set+restore in scope.
3041        let key_id = "PYLON_OAUTH_DISCORD_CLIENT_ID";
3042        let key_secret = "PYLON_OAUTH_DISCORD_CLIENT_SECRET";
3043        // SAFETY: tests run single-threaded for env mutation isn't
3044        // strictly true, but this provider is unique enough that
3045        // contention is unlikely. Cleanup happens at end.
3046        std::env::set_var(key_id, "discord-test-id");
3047        std::env::set_var(key_secret, "discord-test-secret");
3048
3049        let reg = OAuthRegistry::from_env();
3050        let discord = reg.get("discord").expect("discord registered");
3051        assert_eq!(discord.client_id, "discord-test-id");
3052        assert!(discord.auth_url().contains("discord.com"));
3053
3054        std::env::remove_var(key_id);
3055        std::env::remove_var(key_secret);
3056    }
3057
3058    // -- Guest auth --
3059
3060    #[test]
3061    fn guest_session() {
3062        let store = SessionStore::new();
3063        let session = store.create_guest();
3064        assert!(session.user_id.starts_with("guest_"));
3065        assert!(!session.token.is_empty());
3066
3067        let ctx = store.resolve(Some(&session.token));
3068        assert!(ctx.is_authenticated());
3069        assert!(ctx.user_id.unwrap().starts_with("guest_"));
3070    }
3071
3072    #[test]
3073    fn upgrade_guest_to_real_user() {
3074        let store = SessionStore::new();
3075        let session = store.create_guest();
3076        assert!(session.user_id.starts_with("guest_"));
3077
3078        let upgraded = store.upgrade(&session.token, "real-user-123".into());
3079        assert!(upgraded);
3080
3081        let ctx = store.resolve(Some(&session.token));
3082        assert_eq!(ctx.user_id, Some("real-user-123".into()));
3083    }
3084
3085    #[test]
3086    fn upgrade_invalid_token_fails() {
3087        let store = SessionStore::new();
3088        let upgraded = store.upgrade("nonexistent-token", "user".into());
3089        assert!(!upgraded);
3090    }
3091
3092    #[test]
3093    fn guest_context() {
3094        let ctx = AuthContext::guest("guest_123".into());
3095        // Guests carry a stable id but are NOT authenticated — routes
3096        // guarded by AuthMode::User must reject them.
3097        assert!(!ctx.is_authenticated());
3098        assert!(ctx.is_guest);
3099        assert!(!ctx.is_admin);
3100        assert_eq!(ctx.user_id, Some("guest_123".into()));
3101        assert!(!AuthMode::User.check(&ctx));
3102        assert!(AuthMode::Public.check(&ctx));
3103    }
3104
3105    #[test]
3106    fn oauth_token_urls() {
3107        let google = OAuthConfig {
3108            provider: "google".into(),
3109            client_id: "x".into(),
3110            client_secret: "x".into(),
3111            redirect_uri: "x".into(),
3112            ..Default::default()
3113        };
3114        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
3115        let github = OAuthConfig {
3116            provider: "github".into(),
3117            client_id: "x".into(),
3118            client_secret: "x".into(),
3119            redirect_uri: "x".into(),
3120            ..Default::default()
3121        };
3122        assert_eq!(
3123            github.token_url(),
3124            "https://github.com/login/oauth/access_token"
3125        );
3126        let unknown = OAuthConfig {
3127            provider: "unknown".into(),
3128            client_id: "x".into(),
3129            client_secret: "x".into(),
3130            redirect_uri: "x".into(),
3131            ..Default::default()
3132        };
3133        assert_eq!(unknown.token_url(), "");
3134        assert!(unknown.auth_url().is_empty());
3135    }
3136
3137    #[test]
3138    fn oauth_auth_url_github() {
3139        let config = OAuthConfig {
3140            provider: "github".into(),
3141            client_id: "gh-id".into(),
3142            client_secret: "gh-secret".into(),
3143            redirect_uri: "http://localhost/cb".into(),
3144            ..Default::default()
3145        };
3146        assert!(config.auth_url().contains("github.com"));
3147        assert!(config.auth_url().contains("gh-id"));
3148    }
3149
3150    #[test]
3151    fn oauth_auth_url_with_state() {
3152        let config = OAuthConfig {
3153            provider: "google".into(),
3154            client_id: "test-id".into(),
3155            client_secret: "test-secret".into(),
3156            redirect_uri: "http://localhost/cb".into(),
3157            ..Default::default()
3158        };
3159        let url = config.auth_url_with_state("random_state_123");
3160        assert!(url.contains("&state=random_state_123"));
3161    }
3162
3163    #[test]
3164    fn oauth_state_store_create_and_validate() {
3165        let store = OAuthStateStore::new();
3166        let token = store.create("google", "https://app/cb", "https://app/login");
3167        let rec = store.validate(&token, "google").expect("valid first time");
3168        assert_eq!(rec.callback_url, "https://app/cb");
3169        assert_eq!(rec.error_callback_url, "https://app/login");
3170        // Second validation should fail — single-use.
3171        assert!(store.validate(&token, "google").is_none());
3172    }
3173
3174    #[test]
3175    fn oauth_state_store_wrong_provider_rejected() {
3176        let store = OAuthStateStore::new();
3177        let token = store.create("google", "https://app/cb", "https://app/cb");
3178        assert!(store.validate(&token, "github").is_none());
3179    }
3180
3181    #[test]
3182    fn oauth_state_store_invalid_state_rejected() {
3183        let store = OAuthStateStore::new();
3184        assert!(store.validate("nonexistent", "google").is_none());
3185    }
3186
3187    #[test]
3188    fn validate_trusted_redirect_basics() {
3189        let trusted = vec!["http://localhost:3000".to_string()];
3190        assert!(validate_trusted_redirect("http://localhost:3000/dashboard", &trusted).is_ok());
3191        assert!(validate_trusted_redirect("http://localhost:3000", &trusted).is_ok());
3192        assert!(validate_trusted_redirect("http://localhost:3000/x?y=1", &trusted).is_ok());
3193
3194        // Wrong port → wrong origin.
3195        assert!(matches!(
3196            validate_trusted_redirect("http://localhost:4321/dashboard", &trusted),
3197            Err(TrustedOriginError::NotTrusted { .. })
3198        ));
3199        // Non-http scheme rejected even before trusted check (defense
3200        // against javascript:, file:, data:).
3201        assert!(matches!(
3202            validate_trusted_redirect("javascript:alert(1)", &trusted),
3203            Err(TrustedOriginError::NotHttp)
3204        ));
3205        assert!(matches!(
3206            validate_trusted_redirect("", &trusted),
3207            Err(TrustedOriginError::Empty)
3208        ));
3209    }
3210}