Skip to main content

pylon_auth/
lib.rs

1pub mod cookie;
2pub mod email;
3pub mod password;
4
5pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
6
7use serde::{Deserialize, Serialize};
8
9// ---------------------------------------------------------------------------
10// Auth context — the identity available to runtime operations
11// ---------------------------------------------------------------------------
12
13/// The auth context for a request. Represents who is making the request.
14///
15/// **Do NOT derive `Deserialize` on this type.** If the server ever parses an
16/// `AuthContext` from client-supplied JSON, a client can set `is_admin=true`
17/// or add roles and bypass every policy. Identity must come from
18/// server-minted sessions (`Session::to_auth_context`) or explicit
19/// constructors, never from deserialization.
20///
21/// `Serialize` is safe because sending the resolved context BACK to the
22/// client exposes nothing the server didn't already decide.
23#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
24pub struct AuthContext {
25    /// The authenticated user ID, or None for public/anonymous access.
26    /// For guest contexts this is `Some(guest_id)` — a stable
27    /// anonymous identifier, NOT a real user.
28    pub user_id: Option<String>,
29    /// Whether this is an admin context (bypasses policies).
30    pub is_admin: bool,
31    /// True for `AuthContext::guest()` — anonymous-with-stable-id, used
32    /// for cart state and similar pre-login persistence. Routes guarded
33    /// by `AuthMode::User` reject guests; only `is_authenticated()` ==
34    /// "real signed-in user" should pass auth-required gates.
35    #[serde(default, skip_serializing_if = "is_false")]
36    pub is_guest: bool,
37    /// Roles granted to this user. Empty for anonymous.
38    pub roles: Vec<String>,
39    /// Active tenant id (for multi-tenant apps). Set when the user has
40    /// selected an organization for the current session.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub tenant_id: Option<String>,
43}
44
/// Serde `skip_serializing_if` predicate: true when the flag is unset, so
/// `false` booleans are left out of serialized output.
fn is_false(flag: &bool) -> bool {
    !*flag
}
48
49impl AuthContext {
50    /// Create an anonymous/public auth context.
51    pub fn anonymous() -> Self {
52        Self {
53            user_id: None,
54            is_admin: false,
55            is_guest: false,
56            roles: Vec::new(),
57            tenant_id: None,
58        }
59    }
60
61    /// Create an authenticated auth context.
62    pub fn authenticated(user_id: String) -> Self {
63        Self {
64            user_id: Some(user_id),
65            is_admin: false,
66            is_guest: false,
67            roles: Vec::new(),
68            tenant_id: None,
69        }
70    }
71
72    /// Create a guest auth context with a persistent anonymous ID.
73    /// Guests carry an opaque stable id (cart/session continuity) but
74    /// are NOT considered authenticated — `is_authenticated()` returns
75    /// false and `AuthMode::User` rejects them.
76    pub fn guest(guest_id: String) -> Self {
77        Self {
78            user_id: Some(guest_id),
79            is_admin: false,
80            is_guest: true,
81            roles: Vec::new(),
82            tenant_id: None,
83        }
84    }
85
86    /// Create an admin auth context that bypasses all policies.
87    pub fn admin() -> Self {
88        Self {
89            user_id: Some("__admin__".into()),
90            is_admin: true,
91            is_guest: false,
92            roles: vec!["admin".into()],
93            tenant_id: None,
94        }
95    }
96
97    /// Convenience: build a user context from a user id.
98    pub fn user(user_id: String) -> Self {
99        Self::authenticated(user_id)
100    }
101
102    /// Active tenant id (None when the user hasn't selected an org).
103    pub fn tenant_id(&self) -> Option<&str> {
104        self.tenant_id.as_deref()
105    }
106
107    /// Attach a tenant id to the context (chainable).
108    pub fn with_tenant(mut self, tenant_id: String) -> Self {
109        self.tenant_id = Some(tenant_id);
110        self
111    }
112
113    /// Check if this context represents an authenticated user.
114    /// Guests intentionally return `false` — they have a stable anonymous
115    /// id but never gain user-level access.
116    pub fn is_authenticated(&self) -> bool {
117        self.user_id.is_some() && !self.is_guest
118    }
119
120    /// Check if the user has a specific role. Admins have every role implicitly.
121    pub fn has_role(&self, role: &str) -> bool {
122        self.is_admin || self.roles.iter().any(|r| r == role)
123    }
124
125    /// Check if the user has ANY of the given roles.
126    pub fn has_any_role(&self, roles: &[&str]) -> bool {
127        self.is_admin || roles.iter().any(|r| self.has_role(r))
128    }
129
130    /// Attach roles to the context (chainable).
131    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
132        self.roles = roles;
133        self
134    }
135}
136
137// ---------------------------------------------------------------------------
138// Constant-time comparison
139// ---------------------------------------------------------------------------
140
/// Compare two byte slices without an early exit on the first mismatch,
/// to avoid timing side channels.
///
/// Slices of different length are rejected immediately, so length itself is
/// observable. For equal lengths, every byte pair is XOR-ed and OR-ed into
/// one accumulator, whatever the position of the first difference.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
156
157// ---------------------------------------------------------------------------
158// Auth mode — matches the route "auth" field values
159// ---------------------------------------------------------------------------
160
/// Access requirement a route declares through its manifest `auth` field.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// No identity required: every caller passes.
    Public,
    /// Caller must be a signed-in user (anonymous and guest contexts fail).
    User,
}
169
170impl AuthMode {
171    /// Parse from the manifest auth string.
172    #[allow(clippy::should_implement_trait)]
173    pub fn from_str(s: &str) -> Option<Self> {
174        match s {
175            "public" => Some(AuthMode::Public),
176            "user" => Some(AuthMode::User),
177            _ => None,
178        }
179    }
180
181    /// Check if the given auth context satisfies this mode.
182    pub fn check(&self, ctx: &AuthContext) -> bool {
183        match self {
184            AuthMode::Public => true,
185            AuthMode::User => ctx.is_authenticated(),
186        }
187    }
188}
189
190// ---------------------------------------------------------------------------
191// Session — opaque token session
192// ---------------------------------------------------------------------------
193
194/// A session token and its associated user.
195#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
196pub struct Session {
197    pub token: String,
198    pub user_id: String,
199    /// Unix epoch seconds at which this session expires. 0 = never.
200    #[serde(default)]
201    pub expires_at: u64,
202    /// Optional user-agent / device tag recorded at session creation.
203    #[serde(default, skip_serializing_if = "Option::is_none")]
204    pub device: Option<String>,
205    /// Unix epoch seconds when the session was created.
206    #[serde(default)]
207    pub created_at: u64,
208    /// Active tenant id (selected organization). Set via
209    /// `/api/auth/select-org`. Flows into `AuthContext.tenant_id` which
210    /// powers row-scoped policies like `data.orgId == auth.tenantId`.
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    pub tenant_id: Option<String>,
213}
214
215impl Session {
216    /// Default session lifetime: 30 days.
217    pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
218
219    /// Create a new session with a generated token and default 30-day expiry.
220    pub fn new(user_id: String) -> Self {
221        let now = now_secs();
222        Self {
223            token: generate_token(),
224            user_id,
225            expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
226            device: None,
227            created_at: now,
228            tenant_id: None,
229        }
230    }
231
232    /// Create a session with a specific lifetime.
233    pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
234        let now = now_secs();
235        Self {
236            token: generate_token(),
237            user_id,
238            expires_at: if lifetime_secs == 0 {
239                0
240            } else {
241                now.saturating_add(lifetime_secs)
242            },
243            device: None,
244            created_at: now,
245            tenant_id: None,
246        }
247    }
248
249    /// Convert this session to an auth context, carrying the selected
250    /// tenant so row-scoped policies see `auth.tenantId`.
251    pub fn to_auth_context(&self) -> AuthContext {
252        let ctx = AuthContext::authenticated(self.user_id.clone());
253        match &self.tenant_id {
254            Some(t) => ctx.with_tenant(t.clone()),
255            None => ctx,
256        }
257    }
258
259    /// Returns true if the session has passed its expires_at time.
260    /// Boundary is inclusive (`>=`) to match the rest of the codebase
261    /// (`magic_codes.expires_at <= now`, `oauth_state.expires_at <= now`).
262    pub fn is_expired(&self) -> bool {
263        self.expires_at != 0 && now_secs() >= self.expires_at
264    }
265}
266
/// Current wall-clock time as whole seconds since the Unix epoch.
/// A clock set before 1970 yields 0 instead of an error.
fn now_secs() -> u64 {
    use std::time::UNIX_EPOCH;
    UNIX_EPOCH
        .elapsed()
        .map_or(0, |since_epoch| since_epoch.as_secs())
}
274
275// ---------------------------------------------------------------------------
276// OAuth provider config
277// ---------------------------------------------------------------------------
278
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct OAuthConfig {
281    pub provider: String,
282    pub client_id: String,
283    pub client_secret: String,
284    pub redirect_uri: String,
285}
286
287impl OAuthConfig {
288    /// Generate the authorization URL for the provider.
289    ///
290    /// Callers MUST append a `&state=<random>` parameter and validate it in the
291    /// callback to prevent CSRF attacks. See `OAuthStateStore` for a minimal
292    /// implementation.
293    pub fn auth_url(&self) -> String {
294        match self.provider.as_str() {
295            "google" => format!(
296                "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
297                self.client_id, self.redirect_uri
298            ),
299            "github" => format!(
300                "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
301                self.client_id, self.redirect_uri
302            ),
303            _ => String::new(),
304        }
305    }
306
307    /// Generate the authorization URL with a CSRF state parameter attached.
308    pub fn auth_url_with_state(&self, state: &str) -> String {
309        let base = self.auth_url();
310        if base.is_empty() {
311            return base;
312        }
313        format!("{}&state={}", base, state)
314    }
315
316    /// Generate the token exchange URL.
317    pub fn token_url(&self) -> &str {
318        match self.provider.as_str() {
319            "google" => "https://oauth2.googleapis.com/token",
320            "github" => "https://github.com/login/oauth/access_token",
321            _ => "",
322        }
323    }
324
325    /// URL for the userinfo endpoint, which returns the authenticated user's profile.
326    pub fn userinfo_url(&self) -> &str {
327        match self.provider.as_str() {
328            "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
329            "github" => "https://api.github.com/user",
330            _ => "",
331        }
332    }
333
334    /// Exchange an authorization code for the full token set
335    /// (`access_token`, optional `refresh_token`, optional `id_token`,
336    /// `expires_in`, `scope`). The longer struct is what the
337    /// account-store needs to persist; the legacy
338    /// [`OAuthConfig::exchange_code`] returns just the access token for
339    /// callers that don't care.
340    pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
341        let body = match self.provider.as_str() {
342            "google" => format!(
343                "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
344                url_encode(&self.client_id),
345                url_encode(&self.client_secret),
346                url_encode(&self.redirect_uri)
347            ),
348            "github" => format!(
349                "code={code}&client_id={}&client_secret={}&redirect_uri={}",
350                url_encode(&self.client_id),
351                url_encode(&self.client_secret),
352                url_encode(&self.redirect_uri)
353            ),
354            _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
355        };
356
357        let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
358        parse_token_response(&out)
359    }
360
361    /// Exchange an authorization code for an access token. Thin wrapper
362    /// around [`OAuthConfig::exchange_code_full`] for callers that only
363    /// need the access token (existing pre-account-store call sites).
364    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
365        let body = match self.provider.as_str() {
366            "google" => format!(
367                "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
368                url_encode(&self.client_id),
369                url_encode(&self.client_secret),
370                url_encode(&self.redirect_uri)
371            ),
372            "github" => format!(
373                "code={code}&client_id={}&client_secret={}&redirect_uri={}",
374                url_encode(&self.client_id),
375                url_encode(&self.client_secret),
376                url_encode(&self.redirect_uri)
377            ),
378            _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
379        };
380
381        let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
382        extract_access_token(&out)
383    }
384
385    /// Fetch the authenticated user's email + display name using an access token.
386    /// Returns `(email, display_name)`. Use [`OAuthConfig::fetch_userinfo_full`]
387    /// when you also need the provider-stable account ID for account
388    /// linking — the (`provider`, `provider_account_id`) pair is what
389    /// keeps a renamed-email user matched to the same row.
390    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
391        let info = self.fetch_userinfo_full(access_token)?;
392        Ok((info.email, info.name))
393    }
394
395    /// Fetch the authenticated user's full identity info — email + name +
396    /// the provider-stable account ID (Google's `sub`, GitHub's `id`).
397    /// `provider_account_id` is what the account-store keys on, NOT the
398    /// email; otherwise a user changing their Google address would orphan
399    /// their existing pylon account.
400    pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
401        let out = http_get_bearer(self.userinfo_url(), access_token)?;
402        let parsed: serde_json::Value =
403            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
404        match self.provider.as_str() {
405            "google" => {
406                let email = parsed
407                    .get("email")
408                    .and_then(|v| v.as_str())
409                    .ok_or("no email in userinfo")?
410                    .to_string();
411                let name = parsed
412                    .get("name")
413                    .and_then(|v| v.as_str())
414                    .map(String::from);
415                let provider_account_id = parsed
416                    .get("sub")
417                    .and_then(|v| v.as_str())
418                    .ok_or("no sub in userinfo")?
419                    .to_string();
420                Ok(UserInfo {
421                    provider: self.provider.clone(),
422                    provider_account_id,
423                    email,
424                    name,
425                })
426            }
427            "github" => {
428                let name = parsed
429                    .get("name")
430                    .and_then(|v| v.as_str())
431                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
432                    .map(String::from);
433                let email = parsed
434                    .get("email")
435                    .and_then(|v| v.as_str())
436                    .map(String::from);
437                // GitHub may return a null email if the user hasn't published one;
438                // in that case the caller should hit /user/emails with the same token.
439                let email = email
440                    .or_else(|| fetch_github_primary_email(access_token).ok())
441                    .ok_or("no accessible email on GitHub account")?;
442                // GitHub's `id` field is a numeric user ID — the stable
443                // account identifier even if the user renames themselves.
444                let provider_account_id = parsed
445                    .get("id")
446                    .map(|v| {
447                        v.as_i64()
448                            .map(|n| n.to_string())
449                            .or_else(|| v.as_str().map(String::from))
450                            .unwrap_or_default()
451                    })
452                    .filter(|s| !s.is_empty())
453                    .ok_or("no id in userinfo")?;
454                Ok(UserInfo {
455                    provider: self.provider.clone(),
456                    provider_account_id,
457                    email,
458                    name,
459                })
460            }
461            _ => Err(format!("unknown provider: {}", self.provider)),
462        }
463    }
464}
465
/// Identity resolved by [`OAuthConfig::fetch_userinfo_full`].
///
/// `provider_account_id` is the provider's stable subject id (Google `sub`,
/// GitHub numeric `id`). The account store keys on it, so a changed email
/// does not orphan the pylon account.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UserInfo {
    /// Provider key the identity came from, e.g. `"google"`.
    pub provider: String,
    /// Provider-stable account identifier.
    pub provider_account_id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Display name, when the provider supplies one.
    pub name: Option<String>,
}
477
/// Token bundle returned by [`OAuthConfig::exchange_code_full`].
///
/// Stored on the matching `Account` row, so `refresh_token` is available
/// for silent re-auth and `expires_at` can be checked before each provider
/// call.
///
/// `Debug` is written by hand so the access, refresh and id tokens are
/// redacted. Optional tokens show only whether they are present.
#[derive(Clone, PartialEq, Eq)]
pub struct TokenSet {
    pub access_token: String,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
    /// Unix epoch seconds at which the access token expires. `None` when
    /// the provider didn't return `expires_in` (GitHub's classic OAuth
    /// app tokens are non-expiring).
    pub expires_at: Option<u64>,
    pub scope: Option<String>,
}

impl std::fmt::Debug for TokenSet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let redact = |token: &Option<String>| token.as_ref().map(|_| "<redacted>");
        f.debug_struct("TokenSet")
            .field("access_token", &"<redacted>")
            .field("refresh_token", &redact(&self.refresh_token))
            .field("id_token", &redact(&self.id_token))
            .field("expires_at", &self.expires_at)
            .field("scope", &self.scope)
            .finish()
    }
}
492
493fn parse_token_response(body: &str) -> Result<TokenSet, String> {
494    // Most providers return JSON; GitHub Classic apps return form-urlencoded
495    // unless you ask with Accept: application/json (which we do).
496    let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
497        // Fall back to form-urlencoded: access_token=...&scope=...&token_type=...
498        let mut map = serde_json::Map::new();
499        for pair in body.split('&') {
500            if let Some((k, v)) = pair.split_once('=') {
501                map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
502            }
503        }
504        serde_json::Value::Object(map)
505    });
506
507    let access_token = json
508        .get("access_token")
509        .and_then(|v| v.as_str())
510        .ok_or_else(|| format!("no access_token in token response: {body}"))?
511        .to_string();
512    let refresh_token = json
513        .get("refresh_token")
514        .and_then(|v| v.as_str())
515        .map(String::from);
516    let id_token = json
517        .get("id_token")
518        .and_then(|v| v.as_str())
519        .map(String::from);
520    let expires_at = json
521        .get("expires_in")
522        .and_then(|v| {
523            v.as_u64()
524                .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
525        })
526        .map(|secs| now_secs().saturating_add(secs));
527    let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
528    Ok(TokenSet {
529        access_token,
530        refresh_token,
531        id_token,
532        expires_at,
533        scope,
534    })
535}
536
/// Percent-encode `s` following RFC 3986.
///
/// Unreserved bytes (`A-Z a-z 0-9 - _ . ~`) pass through unchanged. Every
/// other byte, including each byte of a multi-byte UTF-8 character, becomes
/// `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    s.bytes().fold(String::with_capacity(s.len()), |mut encoded, b| {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(b));
        } else {
            encoded.push_str(&format!("%{b:02X}"));
        }
        encoded
    })
}
549
/// Per-phase timeout (connect, read and write) for every OAuth and userinfo
/// request. Bounded so an unresponsive provider cannot hang a login forever,
/// yet generous enough for ordinary internet round trips.
const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10_000);
554
555fn ureq_agent() -> ureq::Agent {
556    ureq::AgentBuilder::new()
557        .timeout_connect(HTTP_TIMEOUT)
558        .timeout_read(HTTP_TIMEOUT)
559        .timeout_write(HTTP_TIMEOUT)
560        .user_agent("pylon/0.1")
561        .build()
562}
563
564fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
565    let agent = ureq_agent();
566    let mut req = agent
567        .post(url)
568        .set("Content-Type", "application/x-www-form-urlencoded");
569    if accept_json {
570        req = req.set("Accept", "application/json");
571    }
572    match req.send_string(body) {
573        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
574        Err(ureq::Error::Status(code, resp)) => {
575            let body = resp.into_string().unwrap_or_default();
576            Err(format!("HTTP {code}: {body}"))
577        }
578        Err(e) => Err(format!("HTTP error: {e}")),
579    }
580}
581
582fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
583    let agent = ureq_agent();
584    match agent
585        .get(url)
586        .set("Authorization", &format!("Bearer {token}"))
587        .set("Accept", "application/json")
588        .call()
589    {
590        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
591        Err(ureq::Error::Status(code, resp)) => {
592            let body = resp.into_string().unwrap_or_default();
593            Err(format!("HTTP {code}: {body}"))
594        }
595        Err(e) => Err(format!("HTTP error: {e}")),
596    }
597}
598
599fn fetch_github_primary_email(token: &str) -> Result<String, String> {
600    let out = http_get_bearer("https://api.github.com/user/emails", token)?;
601    let emails: serde_json::Value =
602        serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
603    emails
604        .as_array()
605        .and_then(|arr| {
606            arr.iter()
607                .find(|e| {
608                    e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
609                        && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
610                })
611                .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
612        })
613        .ok_or_else(|| "no primary verified email on GitHub".into())
614}
615
616fn extract_access_token(body: &str) -> Result<String, String> {
617    if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
618        if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
619            return Ok(t.to_string());
620        }
621    }
622    // GitHub can return url-encoded: access_token=...&scope=...&token_type=bearer
623    for pair in body.split('&') {
624        if let Some(val) = pair.strip_prefix("access_token=") {
625            return Ok(val.to_string());
626        }
627    }
628    Err(format!("no access_token in token response: {body}"))
629}
630
631/// OAuth provider registry.
632pub struct OAuthRegistry {
633    providers: std::collections::HashMap<String, OAuthConfig>,
634}
635
636impl Default for OAuthRegistry {
637    fn default() -> Self {
638        Self::new()
639    }
640}
641
642impl OAuthRegistry {
643    pub fn new() -> Self {
644        Self {
645            providers: std::collections::HashMap::new(),
646        }
647    }
648
649    pub fn register(&mut self, config: OAuthConfig) {
650        self.providers.insert(config.provider.clone(), config);
651    }
652
653    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
654        self.providers.get(provider)
655    }
656
657    /// Build from environment variables.
658    /// Looks for PYLON_OAUTH_GOOGLE_CLIENT_ID, etc.
659    pub fn from_env() -> Self {
660        let mut reg = Self::new();
661
662        // Google
663        if let (Ok(id), Ok(secret)) = (
664            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
665            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
666        ) {
667            reg.register(OAuthConfig {
668                provider: "google".into(),
669                client_id: id,
670                client_secret: secret,
671                redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
672                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
673            });
674        }
675
676        // GitHub
677        if let (Ok(id), Ok(secret)) = (
678            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
679            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
680        ) {
681            reg.register(OAuthConfig {
682                provider: "github".into(),
683                client_id: id,
684                client_secret: secret,
685                redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
686                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
687            });
688        }
689
690        reg
691    }
692}
693
694// ---------------------------------------------------------------------------
695// OAuth state store — CSRF protection for OAuth flows
696// ---------------------------------------------------------------------------
697
/// One pending OAuth handshake, stored under its random state token.
///
/// It carries the post-callback redirect URLs alongside the provider, so
/// the callback handler does not need to consult an env var to know where
/// to send the user.
///
/// The store does NOT validate these URLs: `OAuthStateStore::create` stores
/// whatever it is given. The OAuth start endpoint must check both URLs
/// against the trusted-origins allowlist (see `validate_trusted_redirect`)
/// BEFORE creating the record. The allowlist is presumably configured via
/// `PYLON_TRUSTED_ORIGINS` — confirm in the runtime. Only then can the
/// callback trust them without re-checking.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
    /// Provider key the handshake was started for, e.g. `"google"`.
    pub provider: String,
    /// URL the callback redirects to on success. The frontend supplies
    /// this via `?callback=` on the start request.
    pub callback_url: String,
    /// URL the callback redirects to on failure. Defaults to
    /// `callback_url` when the frontend doesn't pass an explicit
    /// `?error_callback=`. The error code + message ride along as
    /// query params (`?oauth_error=X&oauth_error_message=Y`).
    pub error_callback_url: String,
    /// Unix epoch seconds after which the state is no longer accepted.
    pub expires_at: u64,
}
716
717/// Backing store for OAuth state records. Default impl keeps them in
718/// memory (fine for tests + dev); the runtime swaps in a SQLite or
719/// Postgres backend so a restart in the middle of an OAuth handshake
720/// doesn't leave the user with "invalid state" on the callback.
721pub trait OAuthStateBackend: Send + Sync {
722    /// Persist a state record under `token`.
723    fn put(&self, token: &str, state: &OAuthState);
724    /// Atomic compare-and-consume: returns the stored record if the
725    /// token exists and hasn't expired, then removes it. Returning
726    /// `None` means either the token never existed or it has already
727    /// been used / expired.
728    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
729}
730
731/// In-memory backend (default). Lost on restart.
732pub struct InMemoryOAuthBackend {
733    states: Mutex<HashMap<String, OAuthState>>,
734}
735
736impl InMemoryOAuthBackend {
737    pub fn new() -> Self {
738        Self {
739            states: Mutex::new(HashMap::new()),
740        }
741    }
742}
743
744impl Default for InMemoryOAuthBackend {
745    fn default() -> Self {
746        Self::new()
747    }
748}
749
750impl OAuthStateBackend for InMemoryOAuthBackend {
751    fn put(&self, token: &str, state: &OAuthState) {
752        self.states
753            .lock()
754            .unwrap()
755            .insert(token.to_string(), state.clone());
756    }
757    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
758        let mut s = self.states.lock().unwrap();
759        let entry = s.remove(token)?;
760        if entry.expires_at <= now_unix_secs {
761            return None;
762        }
763        Some(entry)
764    }
765}
766
767/// Stores OAuth state parameters to prevent CSRF attacks on the callback.
768///
769/// State tokens are short-lived (10 minutes) and single-use. Backed by an
770/// `OAuthStateBackend`; defaults to in-memory but the runtime persists them
771/// to SQLite (or Postgres when `DATABASE_URL` is set) so they survive a
772/// restart that happens mid-OAuth-handshake.
773pub struct OAuthStateStore {
774    backend: Box<dyn OAuthStateBackend>,
775}
776
777impl Default for OAuthStateStore {
778    fn default() -> Self {
779        Self::new()
780    }
781}
782
783impl OAuthStateStore {
784    pub fn new() -> Self {
785        Self {
786            backend: Box::new(InMemoryOAuthBackend::new()),
787        }
788    }
789
790    pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
791        Self { backend }
792    }
793
794    /// Generate and store a new state record. Returns the random
795    /// state token (the value the OAuth provider echoes back as
796    /// `?state=…` on the callback).
797    ///
798    /// Caller is responsible for validating `callback_url` and
799    /// `error_callback_url` against the trusted-origins allowlist
800    /// BEFORE calling this — the store trusts what it's given.
801    pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
802        use std::time::{SystemTime, UNIX_EPOCH};
803        let token = generate_token();
804        let now = SystemTime::now()
805            .duration_since(UNIX_EPOCH)
806            .unwrap_or_default()
807            .as_secs();
808        let state = OAuthState {
809            provider: provider.to_string(),
810            callback_url: callback_url.to_string(),
811            error_callback_url: error_callback_url.to_string(),
812            expires_at: now + 600,
813        };
814        self.backend.put(&token, &state);
815        token
816    }
817
818    /// Validate and consume a state token. Returns the stored record
819    /// iff the token existed, has not expired, AND matches
820    /// `expected_provider`. The token is removed either way to make
821    /// replay impossible.
822    pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
823        use std::time::{SystemTime, UNIX_EPOCH};
824        let now = SystemTime::now()
825            .duration_since(UNIX_EPOCH)
826            .unwrap_or_default()
827            .as_secs();
828        let entry = self.backend.take(state, now)?;
829        if entry.provider != expected_provider {
830            return None;
831        }
832        Some(entry)
833    }
834}
835
836/// Validate that `url` has an origin (scheme://host[:port]) listed in
837/// `trusted_origins`. Returns `Ok(url)` when trusted (echoes input for
838/// chaining), `Err` with a code/message when not. Used by the OAuth
839/// start endpoint to gate `?callback=` + `?error_callback=` values
840/// before storing them in the state record.
841///
842/// `trusted_origins` entries are origin strings like
843/// `"https://app.example.com"` or `"http://localhost:3000"` — no
844/// trailing slash, no path. A `url` like
845/// `"http://localhost:3000/dashboard?x=1"` matches the
846/// `"http://localhost:3000"` entry.
847///
848/// Borrowed wholesale from better-auth's `trustedOrigins` model:
849/// explicit allowlist, no implicit "same-origin trust," no env-var
850/// magic. An open-redirect via OAuth is one of the easier auth bugs
851/// to ship by accident.
852pub fn validate_trusted_redirect(
853    url: &str,
854    trusted_origins: &[String],
855) -> Result<(), TrustedOriginError> {
856    if url.is_empty() {
857        return Err(TrustedOriginError::Empty);
858    }
859    // Must be absolute http(s) URL — no relative paths, no schemes
860    // like javascript:, file:, data:.
861    if !url.starts_with("http://") && !url.starts_with("https://") {
862        return Err(TrustedOriginError::NotHttp);
863    }
864    let url_origin = origin_of(url);
865    if trusted_origins.iter().any(|t| t == &url_origin) {
866        Ok(())
867    } else {
868        Err(TrustedOriginError::NotTrusted { origin: url_origin })
869    }
870}
871
872/// Reasons a redirect URL might be rejected by [`validate_trusted_redirect`].
873#[derive(Debug, Clone, PartialEq, Eq)]
874pub enum TrustedOriginError {
875    Empty,
876    NotHttp,
877    NotTrusted { origin: String },
878}
879
880impl std::fmt::Display for TrustedOriginError {
881    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
882        match self {
883            TrustedOriginError::Empty => write!(f, "redirect URL is empty"),
884            TrustedOriginError::NotHttp => {
885                write!(f, "redirect URL must use http:// or https:// scheme")
886            }
887            TrustedOriginError::NotTrusted { origin } => write!(
888                f,
889                "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS"
890            ),
891        }
892    }
893}
894
/// Extract the origin (`scheme://host[:port]`) from a URL string by
/// cutting at the first `/`, `?` or `#` after the first `://`.
///
/// This is best-effort string slicing with no full URL parser dependency.
/// The authority is kept verbatim: userinfo (`user@`) is not stripped, and
/// scheme/host case and default ports are not normalized. A string with no
/// `://` is returned with trailing `/` characters trimmed.
///
/// Public so router crates can reuse the same logic when comparing
/// redirect URLs against the trusted-origins list.
pub fn origin_of(url: &str) -> String {
    match url.split_once("://") {
        None => url.trim_end_matches('/').to_string(),
        Some((scheme, remainder)) => {
            let authority_end = remainder
                .find(|ch: char| matches!(ch, '/' | '?' | '#'))
                .unwrap_or(remainder.len());
            format!("{}://{}", scheme, &remainder[..authority_end])
        }
    }
}
910
911// ---------------------------------------------------------------------------
912// Magic code auth — email verification codes
913// ---------------------------------------------------------------------------
914
/// Pluggable storage for magic-code records. In-memory is the default
/// (fine for dev); persistent backends (SQLite, Postgres) live in
/// `pylon-runtime` so a server restart between "send code" and "verify
/// code" doesn't invalidate the user's pending login.
///
/// All methods are infallible from the caller's perspective — durability
/// is best-effort. A backend that fails to write should log; the
/// in-memory cache remains authoritative for the current process.
///
/// Records are keyed by email. `MagicCodeStore` writes through to the
/// backend (`put`, `remove`, `bump_attempts`) after updating its own
/// in-memory cache. It calls `load_all` when built via
/// `MagicCodeStore::with_backend`. Its verify path reads the cache, not
/// `get`.
pub trait MagicCodeBackend: Send + Sync {
    /// Replace any existing code for `email` with `code`.
    fn put(&self, email: &str, code: &MagicCode);
    /// Look up the current code for `email`. Returns `None` if absent.
    fn get(&self, email: &str) -> Option<MagicCode>;
    /// Remove the code for `email` (called on successful verify or
    /// expiry). Idempotent — missing key is not an error.
    fn remove(&self, email: &str);
    /// Persist an attempts++ on the existing record without touching
    /// other fields. Used by the verify-failed path to enforce
    /// `MAX_ATTEMPTS` across restarts. Should do nothing when no record
    /// exists for `email` (the in-memory backend behaves this way).
    fn bump_attempts(&self, email: &str);
    /// Load all live records on construction. Lets `MagicCodeStore::with_backend`
    /// hydrate the in-memory cache from durable storage on startup.
    /// `with_backend` re-checks `expires_at` itself and ignores expired
    /// records, so returning them is harmless.
    fn load_all(&self) -> Vec<MagicCode>;
}
939
940/// In-memory backend for magic codes. The default — also used as the
941/// authoritative cache by `MagicCodeStore`.
942pub struct InMemoryMagicCodeBackend {
943    codes: Mutex<HashMap<String, MagicCode>>,
944}
945
946impl InMemoryMagicCodeBackend {
947    pub fn new() -> Self {
948        Self {
949            codes: Mutex::new(HashMap::new()),
950        }
951    }
952}
953
954impl Default for InMemoryMagicCodeBackend {
955    fn default() -> Self {
956        Self::new()
957    }
958}
959
960impl MagicCodeBackend for InMemoryMagicCodeBackend {
961    fn put(&self, email: &str, code: &MagicCode) {
962        self.codes
963            .lock()
964            .unwrap()
965            .insert(email.to_string(), code.clone());
966    }
967    fn get(&self, email: &str) -> Option<MagicCode> {
968        self.codes.lock().unwrap().get(email).cloned()
969    }
970    fn remove(&self, email: &str) {
971        self.codes.lock().unwrap().remove(email);
972    }
973    fn bump_attempts(&self, email: &str) {
974        if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
975            c.attempts = c.attempts.saturating_add(1);
976        }
977    }
978    fn load_all(&self) -> Vec<MagicCode> {
979        self.codes.lock().unwrap().values().cloned().collect()
980    }
981}
982
/// A magic-code store. Wraps a `MagicCodeBackend` (in-memory by default)
/// and applies the verify/cooldown semantics. Hydrates the in-memory
/// cache from the backend on construction so durable backends survive
/// restart without losing in-flight codes.
pub struct MagicCodeStore {
    /// Per-process view keyed by email. Every cooldown and verify
    /// decision reads from here.
    cache: Mutex<HashMap<String, MagicCode>>,
    /// Durable write-through copy. The store reads it only once, via
    /// `load_all`, at construction.
    backend: Box<dyn MagicCodeBackend>,
}
991
/// One issued magic code together with its verification state.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints `code` in plaintext.
/// Keep this type out of production logs.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Address the code was issued to. Also the key under which the store
    /// and its backend file the record.
    pub email: String,
    /// The code itself. `MagicCodeStore` issues 6-digit, zero-padded
    /// decimal strings.
    pub code: String,
    /// Absolute expiry in seconds, on the same clock as `now_secs()`.
    /// `MagicCodeStore` sets it to 600 seconds (10 minutes) after issue.
    pub expires_at: u64,
    /// Failed verify attempts against this code. Once it reaches
    /// `MAX_ATTEMPTS` the code is invalidated.
    pub attempts: u32,
}
1001
/// Maximum verify attempts per code before it's burned. 5 is a common bound —
/// lets the user fix typos without enabling realistic brute-force against a
/// 6-digit code space.
///
/// The counter lives on the individual code, so issuing a fresh code resets
/// it to 0. Combined with `CREATE_COOLDOWN_SECS`, the guess budget is
/// roughly `MAX_ATTEMPTS` per cooldown window per email.
const MAX_ATTEMPTS: u32 = 5;

/// Minimum seconds between successive `create()` calls for the same email.
/// Throttles magic-code spam (user can't be flooded with login codes).
/// Measured from the creation time of the still-live code. Once the
/// previous code has expired, a new one can be issued immediately.
const CREATE_COOLDOWN_SECS: u64 = 60;
1010
/// Failure reasons reported by `MagicCodeStore::try_create` and
/// `MagicCodeStore::try_verify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is on record for this email. It was never issued, was
    /// already consumed by a successful verify, or was evicted after
    /// expiring.
    NotFound,
    /// The code is present but `MAX_ATTEMPTS` failed verifies have occurred.
    TooManyAttempts,
    /// The code did not match.
    BadCode,
    /// The code was on record but past its `expires_at` when verification
    /// was attempted. The record is discarded as part of this error.
    Expired,
    /// Another code was requested too recently. Wait `retry_after_secs`
    /// seconds and try again.
    Throttled { retry_after_secs: u64 },
}
1024
1025impl Default for MagicCodeStore {
1026    fn default() -> Self {
1027        Self::new()
1028    }
1029}
1030
1031impl MagicCodeStore {
1032    pub fn new() -> Self {
1033        Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1034    }
1035
1036    /// Build a magic-code store backed by a persistent backend. Existing
1037    /// live codes are hydrated into the in-memory cache on construction
1038    /// so a server restart between "send" and "verify" doesn't kill the
1039    /// user's pending login.
1040    pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1041        let now = now_secs();
1042        let mut cache = HashMap::new();
1043        for c in backend.load_all() {
1044            if c.expires_at > now {
1045                cache.insert(c.email.clone(), c);
1046            }
1047        }
1048        Self {
1049            cache: Mutex::new(cache),
1050            backend,
1051        }
1052    }
1053
1054    /// Generate a 6-digit code for an email and return it. Subject to a
1055    /// per-email cooldown — returns the error-shape via `try_create`.
1056    pub fn create(&self, email: &str) -> String {
1057        // Back-compat wrapper: same signature as before, but we still burn
1058        // the cooldown if one is active. Use `try_create` for a Result shape.
1059        self.try_create(email).unwrap_or_else(|_| String::new())
1060    }
1061
1062    /// Create a magic code, enforcing per-email cooldown. Returns the code
1063    /// or an error describing why one couldn't be issued.
1064    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1065        let now = now_secs();
1066
1067        let mut codes = self.cache.lock().unwrap();
1068
1069        // Cooldown check: if a live code exists and was created less than
1070        // CREATE_COOLDOWN_SECS ago, throttle. The age-of-code is
1071        // `expires_at - 600 + cooldown` since expires_at is create_time + 600.
1072        if let Some(existing) = codes.get(email) {
1073            if existing.expires_at > now {
1074                let created_at = existing.expires_at.saturating_sub(600);
1075                let age = now.saturating_sub(created_at);
1076                if age < CREATE_COOLDOWN_SECS {
1077                    return Err(MagicCodeError::Throttled {
1078                        retry_after_secs: CREATE_COOLDOWN_SECS - age,
1079                    });
1080                }
1081            }
1082        }
1083
1084        let code = generate_magic_code();
1085        let mc = MagicCode {
1086            email: email.to_string(),
1087            code: code.clone(),
1088            expires_at: now + 600, // 10 minutes
1089            attempts: 0,
1090        };
1091        codes.insert(email.to_string(), mc.clone());
1092        // Persist after the cache mutation lands. Backend write is
1093        // best-effort — if it fails the code still works for this
1094        // process; only a restart in the next 10 minutes would lose it.
1095        self.backend.put(email, &mc);
1096        Ok(code)
1097    }
1098
1099    /// Verify a code for an email. Returns true if valid and not expired.
1100    /// Uses constant-time comparison to prevent timing attacks.
1101    /// Back-compat wrapper around [`try_verify`].
1102    pub fn verify(&self, email: &str, code: &str) -> bool {
1103        matches!(self.try_verify(email, code), Ok(()))
1104    }
1105
1106    /// Verify a code. Returns a typed error so callers can surface specific
1107    /// messages. On the MAX_ATTEMPTS-th failure, the code is burned — even
1108    /// correct subsequent attempts return `TooManyAttempts`.
1109    /// Every magic code currently in the cache. Powers the Studio
1110    /// "Auth tables" view; not for app use. Includes expired codes —
1111    /// the cache only drops them on next verify attempt for that email.
1112    pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1113        self.cache
1114            .lock()
1115            .map(|m| m.values().cloned().collect())
1116            .unwrap_or_default()
1117    }
1118
1119    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1120        let now = now_secs();
1121        let mut codes = self.cache.lock().unwrap();
1122
1123        let mc = match codes.get_mut(email) {
1124            Some(m) => m,
1125            None => return Err(MagicCodeError::NotFound),
1126        };
1127
1128        if mc.attempts >= MAX_ATTEMPTS {
1129            return Err(MagicCodeError::TooManyAttempts);
1130        }
1131        if mc.expires_at <= now {
1132            codes.remove(email);
1133            self.backend.remove(email);
1134            return Err(MagicCodeError::Expired);
1135        }
1136
1137        let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1138        if !ok {
1139            mc.attempts += 1;
1140            self.backend.bump_attempts(email);
1141            // Burn the code at MAX_ATTEMPTS so retries can't hit max.
1142            if mc.attempts >= MAX_ATTEMPTS {
1143                return Err(MagicCodeError::TooManyAttempts);
1144            }
1145            return Err(MagicCodeError::BadCode);
1146        }
1147
1148        // Correct code — consume it.
1149        codes.remove(email);
1150        self.backend.remove(email);
1151        Ok(())
1152    }
1153}
1154
1155// ---------------------------------------------------------------------------
1156// Cryptographic helpers — CSPRNG-based token and code generation
1157// ---------------------------------------------------------------------------
1158
/// Lowercase hex encoding: two characters per byte, no prefix or separator.
///
/// Looks nibbles up in a table and writes into a preallocated buffer.
/// The previous `format!` per byte allocated a temporary `String` for
/// every input byte.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
1162
1163/// Generate a 6-digit magic code using a CSPRNG.
1164fn generate_magic_code() -> String {
1165    use rand::Rng;
1166    let mut rng = rand::thread_rng();
1167    let code: u32 = rng.gen_range(0..1_000_000);
1168    format!("{:06}", code)
1169}
1170
1171/// Generate a session token with 256 bits of entropy from a CSPRNG.
1172fn generate_token() -> String {
1173    use rand::Rng;
1174    let mut rng = rand::thread_rng();
1175    let bytes: [u8; 32] = rng.gen();
1176    format!("pylon_{}", hex_encode(&bytes))
1177}
1178
1179// ---------------------------------------------------------------------------
// Session store — in-memory by default, optional persistent backend
1181// ---------------------------------------------------------------------------
1182
1183use std::collections::HashMap;
1184use std::sync::Mutex;
1185
/// Pluggable storage backend for sessions. The default is in-memory; apps
/// deploying for real should supply a persistent backend (e.g. SQLite or
/// Redis) so users don't log out on server restart.
///
/// `SessionStore` treats the backend as write-through storage. After
/// construction it only calls `save` and `remove`; it never reads back.
pub trait SessionBackend: Send + Sync {
    /// Return every stored session. Called once by
    /// `SessionStore::with_backend` to hydrate the in-memory map.
    fn load_all(&self) -> Vec<Session>;
    /// Insert or overwrite `session`. Called on create, refresh and every
    /// field mutation (device, tenant, guest upgrade).
    /// NOTE(review): presumably keyed by `session.token` — confirm in impls.
    fn save(&self, session: &Session);
    /// Delete the session stored under `token`.
    fn remove(&self, token: &str);
}
1194
/// A session store. In-memory by default; optionally backed by a
/// persistent [`SessionBackend`].
///
/// The in-memory map is always authoritative — reads don't touch the
/// backend. The backend receives every `save`/`remove`, making it a
/// write-through cache. On construction via [`SessionStore::with_backend`],
/// the store hydrates from the backend so sessions survive restart.
pub struct SessionStore {
    /// Live sessions keyed by token.
    sessions: Mutex<HashMap<String, Session>>,
    /// Optional durable copy; `None` for the pure in-memory store.
    backend: Option<Box<dyn SessionBackend>>,
    /// Default lifetime for new sessions (seconds). Sourced from the
    /// manifest's `auth.session.expires_in` config at server boot;
    /// falls back to `Session::DEFAULT_LIFETIME_SECS` (30 days).
    /// Both `new` and `with_backend` initialize it to that default, so
    /// `with_lifetime` must be applied after choosing a constructor.
    default_lifetime_secs: u64,
}

impl Default for SessionStore {
    /// Equivalent to `SessionStore::new()`: in-memory, no backend.
    fn default() -> Self {
        Self::new()
    }
}
1216
1217impl SessionStore {
1218    pub fn new() -> Self {
1219        Self {
1220            sessions: Mutex::new(HashMap::new()),
1221            backend: None,
1222            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1223        }
1224    }
1225
1226    /// Override the default session lifetime. Used by `pylon-runtime`'s
1227    /// server bootstrap to apply the manifest's `auth.session.expires_in`.
1228    pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
1229        self.default_lifetime_secs = lifetime_secs;
1230        self
1231    }
1232
1233    /// Build a session store backed by a persistent store. Existing sessions
1234    /// are loaded from the backend on construction; every future mutation
1235    /// writes through.
1236    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1237        let mut map = HashMap::new();
1238        for s in backend.load_all() {
1239            if !s.is_expired() {
1240                map.insert(s.token.clone(), s);
1241            }
1242        }
1243        Self {
1244            sessions: Mutex::new(map),
1245            backend: Some(backend),
1246            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1247        }
1248    }
1249
1250    /// Create a session for a user and return it. Uses the store's
1251    /// configured `default_lifetime_secs` (from the manifest's
1252    /// `auth.session.expires_in`, default 30 days).
1253    pub fn create(&self, user_id: String) -> Session {
1254        let session = Session::with_lifetime(user_id, self.default_lifetime_secs);
1255        let mut sessions = self.sessions.lock().unwrap();
1256        sessions.insert(session.token.clone(), session.clone());
1257        if let Some(b) = &self.backend {
1258            b.save(&session);
1259        }
1260        session
1261    }
1262
1263    /// Look up a session by token. Returns None if the session is expired.
1264    pub fn get(&self, token: &str) -> Option<Session> {
1265        let mut sessions = self.sessions.lock().unwrap();
1266        match sessions.get(token) {
1267            Some(s) if s.is_expired() => {
1268                sessions.remove(token);
1269                None
1270            }
1271            Some(s) => Some(s.clone()),
1272            None => None,
1273        }
1274    }
1275
1276    /// Resolve a token to an auth context.
1277    /// Returns anonymous context if the token is invalid, missing, or expired.
1278    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1279        match token {
1280            Some(t) => match self.get(t) {
1281                Some(session) => session.to_auth_context(),
1282                None => AuthContext::anonymous(),
1283            },
1284            None => AuthContext::anonymous(),
1285        }
1286    }
1287
1288    /// Refresh a session — issues a new token, copies user/device, extends expiry.
1289    /// The old token is revoked. Returns the new session or None if the old
1290    /// token is missing/expired.
1291    pub fn refresh(&self, old_token: &str) -> Option<Session> {
1292        let mut sessions = self.sessions.lock().unwrap();
1293        let old = sessions.remove(old_token)?;
1294        if let Some(b) = &self.backend {
1295            b.remove(old_token);
1296        }
1297        if old.is_expired() {
1298            return None;
1299        }
1300        let mut new = Session::new(old.user_id.clone());
1301        new.device = old.device.clone();
1302        sessions.insert(new.token.clone(), new.clone());
1303        if let Some(b) = &self.backend {
1304            b.save(&new);
1305        }
1306        Some(new)
1307    }
1308
1309    /// Every session in the store, including expired ones, with no
1310    /// filtering. Powers the Studio "Auth tables" view so operators
1311    /// can see orphaned sessions / debug stuck logins. Don't use for
1312    /// app code — `list_for_user` is the right surface there.
1313    pub fn list_all_unfiltered(&self) -> Vec<Session> {
1314        self.sessions
1315            .lock()
1316            .map(|m| m.values().cloned().collect())
1317            .unwrap_or_default()
1318    }
1319
1320    /// List all active sessions for a user.
1321    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
1322        let sessions = self.sessions.lock().unwrap();
1323        sessions
1324            .values()
1325            .filter(|s| s.user_id == user_id && !s.is_expired())
1326            .cloned()
1327            .collect()
1328    }
1329
1330    /// Revoke all sessions for a user. Returns the count removed.
1331    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
1332        let mut sessions = self.sessions.lock().unwrap();
1333        let tokens: Vec<String> = sessions
1334            .iter()
1335            .filter_map(|(t, s)| {
1336                if s.user_id == user_id {
1337                    Some(t.clone())
1338                } else {
1339                    None
1340                }
1341            })
1342            .collect();
1343        let n = tokens.len();
1344        for t in &tokens {
1345            sessions.remove(t);
1346            if let Some(b) = &self.backend {
1347                b.remove(t);
1348            }
1349        }
1350        n
1351    }
1352
1353    /// Sweep expired sessions. Returns the count removed.
1354    pub fn sweep_expired(&self) -> usize {
1355        let mut sessions = self.sessions.lock().unwrap();
1356        let expired: Vec<String> = sessions
1357            .iter()
1358            .filter_map(|(t, s)| {
1359                if s.is_expired() {
1360                    Some(t.clone())
1361                } else {
1362                    None
1363                }
1364            })
1365            .collect();
1366        let n = expired.len();
1367        for t in &expired {
1368            sessions.remove(t);
1369            if let Some(b) = &self.backend {
1370                b.remove(t);
1371            }
1372        }
1373        n
1374    }
1375
1376    /// Attach a device label to a session (typically on login from a browser).
1377    pub fn set_device(&self, token: &str, device: String) -> bool {
1378        let mut sessions = self.sessions.lock().unwrap();
1379        if let Some(s) = sessions.get_mut(token) {
1380            s.device = Some(device);
1381            if let Some(b) = &self.backend {
1382                b.save(s);
1383            }
1384            true
1385        } else {
1386            false
1387        }
1388    }
1389
1390    /// Create a guest session with a generated anonymous ID.
1391    pub fn create_guest(&self) -> Session {
1392        use rand::Rng;
1393        let mut rng = rand::thread_rng();
1394        let bytes: [u8; 16] = rng.gen();
1395        let guest_id = format!("guest_{}", hex_encode(&bytes));
1396        self.create(guest_id)
1397    }
1398
1399    /// Upgrade a guest session to a real user. Replaces the user_id.
1400    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1401        let mut sessions = self.sessions.lock().unwrap();
1402        if let Some(session) = sessions.get_mut(token) {
1403            session.user_id = real_user_id;
1404            if let Some(b) = &self.backend {
1405                b.save(session);
1406            }
1407            true
1408        } else {
1409            false
1410        }
1411    }
1412
1413    /// Switch the session's active tenant (organization). `None` clears it.
1414    /// Callers should verify the user actually has membership in the target
1415    /// tenant BEFORE invoking this — the session store takes the value on
1416    /// trust. Returns true if the session exists, false otherwise.
1417    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1418        let mut sessions = self.sessions.lock().unwrap();
1419        if let Some(session) = sessions.get_mut(token) {
1420            session.tenant_id = tenant_id;
1421            if let Some(b) = &self.backend {
1422                b.save(session);
1423            }
1424            true
1425        } else {
1426            false
1427        }
1428    }
1429
1430    /// Remove a session.
1431    pub fn revoke(&self, token: &str) -> bool {
1432        let mut sessions = self.sessions.lock().unwrap();
1433        let removed = sessions.remove(token).is_some();
1434        if removed {
1435            if let Some(b) = &self.backend {
1436                b.remove(token);
1437            }
1438        }
1439        removed
1440    }
1441}
1442
1443// ---------------------------------------------------------------------------
1444// OAuth account links — better-auth's `account` table equivalent
1445// ---------------------------------------------------------------------------
1446
/// A persisted account link. Schema-aligned with better-auth's `account`
/// table (verified against https://www.better-auth.com/docs/concepts/database
/// at the time of writing) so users migrating from better-auth see the
/// same field names + meanings:
///
/// - `provider_id` — the provider's name (`"google"`, `"github"`, plus
///   `"credential"` once email/password auth lands). Matches
///   better-auth's `providerId`.
/// - `account_id` — the PROVIDER'S ID for the user (Google `sub`,
///   GitHub numeric `id`, or for email/password the user's own id).
///   Matches better-auth's `accountId`. NOT the row PK.
/// - `id` — the row PK, generated. Lets the row be referenced
///   independently of the (provider_id, account_id) natural key.
/// - `password` — bcrypt/argon2 hash for `provider_id="credential"`
///   rows; `None` for OAuth links. Reserves the column so adding
///   email/password auth doesn't need a schema migration.
///
/// Account vs. user: a single User row can have many Account rows
/// (Google + GitHub + a password — all linked to one pylon user).
/// Provider lookup is by `(provider_id, account_id)` — NOT email — so
/// a user changing their Google address keeps the same pylon account.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints `access_token`,
/// `refresh_token`, `id_token` and `password` in plaintext. Keep this type
/// out of logs, or replace the derive with a redacting impl.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Account {
    /// Row primary key. `Account::new` fills it from `generate_token()`, so
    /// it has the same `pylon_`-prefixed hex shape as a session token.
    pub id: String,
    /// The pylon user this provider identity is linked to.
    pub user_id: String,
    /// Provider name — `"google"`, `"github"`, `"credential"`, etc.
    /// (better-auth: `providerId`)
    pub provider_id: String,
    /// Provider's id for the user — Google `sub`, GitHub numeric `id`,
    /// or for `provider_id="credential"` the user's own id. (better-auth: `accountId`)
    pub account_id: String,
    /// Access token from the latest token exchange. `Account::new` always
    /// sets it to `Some`.
    pub access_token: Option<String>,
    /// Refresh token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// `id_token` from the token response, if any (presumably only from
    /// OpenID Connect providers — confirm).
    pub id_token: Option<String>,
    /// Unix epoch seconds at which `access_token` expires. `None` for
    /// non-expiring tokens (GitHub Classic apps) or for password rows.
    pub access_token_expires_at: Option<u64>,
    /// Unix epoch seconds at which `refresh_token` expires. `None` when
    /// the provider doesn't expire refresh tokens (most don't, but
    /// Microsoft Identity Platform does after 90 days of inactivity).
    /// `Account::new` always leaves this `None`.
    pub refresh_token_expires_at: Option<u64>,
    /// Granted scopes as reported in the token response (format is
    /// provider-defined — TODO confirm separator before parsing).
    pub scope: Option<String>,
    /// Bcrypt/argon2 hash for email/password rows. `None` for OAuth.
    /// Always `None` today — present so adding password auth later
    /// doesn't require a schema migration.
    pub password: Option<String>,
    /// Unix epoch seconds when this account was first linked.
    pub created_at: u64,
    /// Unix epoch seconds when the token bundle was last refreshed.
    pub updated_at: u64,
}
1498
1499impl Account {
1500    /// Build a new account link from a freshly-completed OAuth handshake.
1501    /// Generates a fresh row id; the `(provider_id, account_id)` pair is
1502    /// what later lookups key on.
1503    pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
1504        let now = now_secs();
1505        Self {
1506            id: generate_token(),
1507            user_id,
1508            provider_id: info.provider.clone(),
1509            account_id: info.provider_account_id.clone(),
1510            access_token: Some(tokens.access_token.clone()),
1511            refresh_token: tokens.refresh_token.clone(),
1512            id_token: tokens.id_token.clone(),
1513            access_token_expires_at: tokens.expires_at,
1514            refresh_token_expires_at: None,
1515            scope: tokens.scope.clone(),
1516            password: None,
1517            created_at: now,
1518            updated_at: now,
1519        }
1520    }
1521
1522    /// True if `access_token_expires_at` is set and has passed.
1523    /// Non-expiring tokens (GitHub Classic) report `false` — caller
1524    /// should treat them as "valid until proven otherwise" and refresh
1525    /// on 401.
1526    pub fn access_token_expired(&self) -> bool {
1527        match self.access_token_expires_at {
1528            Some(ts) => now_secs() >= ts,
1529            None => false,
1530        }
1531    }
1532}
1533
/// Pluggable storage for account links. In-memory default ships with
/// the crate; SQLite + Postgres impls live in `pylon-runtime`.
///
/// Implementations must be shareable across threads (`Send + Sync`).
/// All methods are infallible at the type level; how a backend reports
/// storage errors is up to the implementation.
pub trait AccountBackend: Send + Sync {
    /// Insert or refresh an account link. The `(provider_id, account_id)`
    /// pair is the natural key — repeated calls for the same pair
    /// update the token bundle and `updated_at` on the existing row.
    /// The existing row's identity (`id`) and first-link time
    /// (`created_at`) are therefore expected to survive a refresh.
    fn upsert(&self, account: &Account);
    /// Find an account by provider identity. Returns `None` if the user
    /// hasn't linked this provider yet.
    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
    /// Every account linked to a user. The `/api/auth/me` endpoint uses
    /// this to render "you're connected via Google + GitHub" in the UI
    /// and to gate "unlink" affordances behind "user has another way to
    /// sign in" checks.
    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
    /// Remove a single provider link. Returns `true` if a row was removed.
    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
    /// Every account in the store. Used by `AccountStore::list_all_unfiltered`
    /// to power the Studio admin inspector. Backends that can stream
    /// (SQLite, Postgres) just `SELECT *`; the in-memory backend
    /// returns its full map.
    fn list_all(&self) -> Vec<Account>;
}
1557
1558/// In-memory account backend (default). Lost on restart — production
1559/// deployments should swap in a persistent backend so refresh tokens
1560/// survive a redeploy.
1561pub struct InMemoryAccountBackend {
1562    /// Keyed by `(provider_id, account_id)`. A separate map keyed on
1563    /// user_id would speed up `find_for_user` but at framework scale
1564    /// the linear scan of (typically ≤ 5) accounts per user is fine.
1565    accounts: Mutex<HashMap<(String, String), Account>>,
1566}
1567
1568impl InMemoryAccountBackend {
1569    pub fn new() -> Self {
1570        Self {
1571            accounts: Mutex::new(HashMap::new()),
1572        }
1573    }
1574}
1575
1576impl Default for InMemoryAccountBackend {
1577    fn default() -> Self {
1578        Self::new()
1579    }
1580}
1581
1582impl AccountBackend for InMemoryAccountBackend {
1583    fn upsert(&self, account: &Account) {
1584        let key = (account.provider_id.clone(), account.account_id.clone());
1585        self.accounts.lock().unwrap().insert(key, account.clone());
1586    }
1587    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1588        self.accounts
1589            .lock()
1590            .unwrap()
1591            .get(&(provider_id.to_string(), account_id.to_string()))
1592            .cloned()
1593    }
1594    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1595        self.accounts
1596            .lock()
1597            .unwrap()
1598            .values()
1599            .filter(|a| a.user_id == user_id)
1600            .cloned()
1601            .collect()
1602    }
1603    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1604        self.accounts
1605            .lock()
1606            .unwrap()
1607            .remove(&(provider_id.to_string(), account_id.to_string()))
1608            .is_some()
1609    }
1610    fn list_all(&self) -> Vec<Account> {
1611        self.accounts.lock().unwrap().values().cloned().collect()
1612    }
1613}
1614
1615/// Account store. Wraps an `AccountBackend` and provides the methods the
1616/// OAuth callback / API endpoints actually call.
1617pub struct AccountStore {
1618    backend: Box<dyn AccountBackend>,
1619}
1620
1621impl Default for AccountStore {
1622    fn default() -> Self {
1623        Self::new()
1624    }
1625}
1626
1627impl AccountStore {
1628    pub fn new() -> Self {
1629        Self {
1630            backend: Box::new(InMemoryAccountBackend::new()),
1631        }
1632    }
1633    pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
1634        Self { backend }
1635    }
1636    pub fn upsert(&self, account: &Account) {
1637        self.backend.upsert(account);
1638    }
1639    pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1640        self.backend.find_by_provider(provider_id, account_id)
1641    }
1642    pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1643        self.backend.find_for_user(user_id)
1644    }
1645    pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1646        self.backend.unlink(provider_id, account_id)
1647    }
1648
1649    /// Every account in the store. Powers the Studio "Auth tables"
1650    /// view; not for app use. Implemented by walking the backend's
1651    /// per-user index — doable because account counts per user are
1652    /// small (typically ≤ 5) and total account count tracks user
1653    /// count.
1654    ///
1655    /// We don't add a `list_all` method to the `AccountBackend` trait
1656    /// because the in-memory + sqlite + postgres impls would each
1657    /// need a separate implementation, and the operational use case
1658    /// (Studio inspector) is narrow enough to live behind a wrapper
1659    /// that walks the underlying store directly. For PG/SQLite that
1660    /// means a `SELECT * FROM _pylon_accounts` — which the backends
1661    /// can grow if we ever need this at scale.
1662    pub fn list_all_unfiltered(&self) -> Vec<Account> {
1663        self.backend.list_all()
1664    }
1665}
1666
1667// ---------------------------------------------------------------------------
1668// Tests
1669// ---------------------------------------------------------------------------
1670
1671#[cfg(test)]
1672mod tests {
1673    use super::*;
1674
1675    #[test]
1676    fn anonymous_context() {
1677        let ctx = AuthContext::anonymous();
1678        assert!(!ctx.is_authenticated());
1679        assert!(ctx.user_id.is_none());
1680    }
1681
1682    #[test]
1683    fn authenticated_context() {
1684        let ctx = AuthContext::authenticated("user-1".into());
1685        assert!(ctx.is_authenticated());
1686        assert_eq!(ctx.user_id, Some("user-1".into()));
1687    }
1688
1689    #[test]
1690    fn auth_mode_public_allows_anonymous() {
1691        let mode = AuthMode::Public;
1692        assert!(mode.check(&AuthContext::anonymous()));
1693        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
1694    }
1695
1696    #[test]
1697    fn auth_mode_user_requires_authenticated() {
1698        let mode = AuthMode::User;
1699        assert!(!mode.check(&AuthContext::anonymous()));
1700        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
1701    }
1702
1703    #[test]
1704    fn auth_mode_from_str() {
1705        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
1706        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
1707        assert_eq!(AuthMode::from_str("admin"), None);
1708    }
1709
1710    #[test]
1711    fn session_store_create_and_get() {
1712        let store = SessionStore::new();
1713        let session = store.create("user-1".into());
1714        assert!(!session.token.is_empty());
1715        assert!(session.token.starts_with("pylon_"));
1716
1717        let retrieved = store.get(&session.token).unwrap();
1718        assert_eq!(retrieved.user_id, "user-1");
1719    }
1720
1721    #[test]
1722    fn session_store_resolve() {
1723        let store = SessionStore::new();
1724        let session = store.create("user-1".into());
1725
1726        let ctx = store.resolve(Some(&session.token));
1727        assert!(ctx.is_authenticated());
1728        assert_eq!(ctx.user_id, Some("user-1".into()));
1729
1730        let anon = store.resolve(None);
1731        assert!(!anon.is_authenticated());
1732
1733        let bad = store.resolve(Some("invalid-token"));
1734        assert!(!bad.is_authenticated());
1735    }
1736
1737    #[test]
1738    fn session_store_revoke() {
1739        let store = SessionStore::new();
1740        let session = store.create("user-1".into());
1741
1742        assert!(store.revoke(&session.token));
1743        assert!(store.get(&session.token).is_none());
1744        assert!(!store.revoke(&session.token)); // already revoked
1745    }
1746
1747    #[test]
1748    fn session_to_auth_context() {
1749        let session = Session::new("user-42".into());
1750        let ctx = session.to_auth_context();
1751        assert_eq!(ctx.user_id, Some("user-42".into()));
1752    }
1753
1754    // -- Admin context --
1755
1756    #[test]
1757    fn admin_context() {
1758        let ctx = AuthContext::admin();
1759        assert!(ctx.is_admin);
1760        assert!(ctx.is_authenticated());
1761    }
1762
1763    #[test]
1764    fn anonymous_not_admin() {
1765        let ctx = AuthContext::anonymous();
1766        assert!(!ctx.is_admin);
1767    }
1768
1769    #[test]
1770    fn authenticated_not_admin() {
1771        let ctx = AuthContext::authenticated("user-1".into());
1772        assert!(!ctx.is_admin);
1773    }
1774
1775    // -- Magic codes --
1776
1777    #[test]
1778    fn magic_code_create_and_verify() {
1779        let store = MagicCodeStore::new();
1780        let code = store.create("test@example.com");
1781        assert_eq!(code.len(), 6);
1782        assert!(store.verify("test@example.com", &code));
1783    }
1784
1785    #[test]
1786    fn magic_code_wrong_code_rejected() {
1787        let store = MagicCodeStore::new();
1788        store.create("test@example.com");
1789        assert!(!store.verify("test@example.com", "000000"));
1790    }
1791
1792    #[test]
1793    fn magic_code_wrong_email_rejected() {
1794        let store = MagicCodeStore::new();
1795        let code = store.create("test@example.com");
1796        assert!(!store.verify("other@example.com", &code));
1797    }
1798
1799    #[test]
1800    fn magic_code_consumed_after_verify() {
1801        let store = MagicCodeStore::new();
1802        let code = store.create("test@example.com");
1803        assert!(store.verify("test@example.com", &code));
1804        // Second verify should fail — code consumed.
1805        assert!(!store.verify("test@example.com", &code));
1806    }
1807
1808    #[test]
1809    fn magic_code_different_emails_independent() {
1810        let store = MagicCodeStore::new();
1811        let code1 = store.create("alice@example.com");
1812        let code2 = store.create("bob@example.com");
1813        // Each email has its own code.
1814        assert!(store.verify("alice@example.com", &code1));
1815        assert!(store.verify("bob@example.com", &code2));
1816    }
1817
1818    // -- Constant-time comparison --
1819
1820    #[test]
1821    fn constant_time_eq_equal() {
1822        assert!(constant_time_eq(b"hello", b"hello"));
1823        assert!(constant_time_eq(b"", b""));
1824    }
1825
1826    #[test]
1827    fn constant_time_eq_not_equal() {
1828        assert!(!constant_time_eq(b"hello", b"world"));
1829        assert!(!constant_time_eq(b"hello", b"hell"));
1830        assert!(!constant_time_eq(b"a", b"b"));
1831    }
1832
1833    // -- Token generation --
1834
1835    #[test]
1836    fn generated_tokens_are_unique() {
1837        let t1 = generate_token();
1838        let t2 = generate_token();
1839        assert_ne!(t1, t2);
1840        assert!(t1.starts_with("pylon_"));
1841        assert!(t2.starts_with("pylon_"));
1842        // 256 bits = 64 hex chars + "pylon_" prefix (6 chars)
1843        assert_eq!(t1.len(), 6 + 64);
1844    }
1845
1846    // -- OAuth registry --
1847
1848    #[test]
1849    fn oauth_registry_empty() {
1850        let reg = OAuthRegistry::new();
1851        assert!(reg.get("google").is_none());
1852    }
1853
1854    #[test]
1855    fn oauth_registry_register_and_get() {
1856        let mut reg = OAuthRegistry::new();
1857        reg.register(OAuthConfig {
1858            provider: "google".into(),
1859            client_id: "test-id".into(),
1860            client_secret: "test-secret".into(),
1861            redirect_uri: "http://localhost/callback".into(),
1862        });
1863        let config = reg.get("google").unwrap();
1864        assert_eq!(config.client_id, "test-id");
1865        assert!(config.auth_url().contains("accounts.google.com"));
1866    }
1867
1868    // -- Guest auth --
1869
1870    #[test]
1871    fn guest_session() {
1872        let store = SessionStore::new();
1873        let session = store.create_guest();
1874        assert!(session.user_id.starts_with("guest_"));
1875        assert!(!session.token.is_empty());
1876
1877        let ctx = store.resolve(Some(&session.token));
1878        assert!(ctx.is_authenticated());
1879        assert!(ctx.user_id.unwrap().starts_with("guest_"));
1880    }
1881
1882    #[test]
1883    fn upgrade_guest_to_real_user() {
1884        let store = SessionStore::new();
1885        let session = store.create_guest();
1886        assert!(session.user_id.starts_with("guest_"));
1887
1888        let upgraded = store.upgrade(&session.token, "real-user-123".into());
1889        assert!(upgraded);
1890
1891        let ctx = store.resolve(Some(&session.token));
1892        assert_eq!(ctx.user_id, Some("real-user-123".into()));
1893    }
1894
1895    #[test]
1896    fn upgrade_invalid_token_fails() {
1897        let store = SessionStore::new();
1898        let upgraded = store.upgrade("nonexistent-token", "user".into());
1899        assert!(!upgraded);
1900    }
1901
1902    #[test]
1903    fn guest_context() {
1904        let ctx = AuthContext::guest("guest_123".into());
1905        // Guests carry a stable id but are NOT authenticated — routes
1906        // guarded by AuthMode::User must reject them.
1907        assert!(!ctx.is_authenticated());
1908        assert!(ctx.is_guest);
1909        assert!(!ctx.is_admin);
1910        assert_eq!(ctx.user_id, Some("guest_123".into()));
1911        assert!(!AuthMode::User.check(&ctx));
1912        assert!(AuthMode::Public.check(&ctx));
1913    }
1914
1915    #[test]
1916    fn oauth_token_urls() {
1917        let google = OAuthConfig {
1918            provider: "google".into(),
1919            client_id: "x".into(),
1920            client_secret: "x".into(),
1921            redirect_uri: "x".into(),
1922        };
1923        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
1924        let github = OAuthConfig {
1925            provider: "github".into(),
1926            client_id: "x".into(),
1927            client_secret: "x".into(),
1928            redirect_uri: "x".into(),
1929        };
1930        assert_eq!(
1931            github.token_url(),
1932            "https://github.com/login/oauth/access_token"
1933        );
1934        let unknown = OAuthConfig {
1935            provider: "unknown".into(),
1936            client_id: "x".into(),
1937            client_secret: "x".into(),
1938            redirect_uri: "x".into(),
1939        };
1940        assert_eq!(unknown.token_url(), "");
1941        assert!(unknown.auth_url().is_empty());
1942    }
1943
1944    #[test]
1945    fn oauth_auth_url_github() {
1946        let config = OAuthConfig {
1947            provider: "github".into(),
1948            client_id: "gh-id".into(),
1949            client_secret: "gh-secret".into(),
1950            redirect_uri: "http://localhost/cb".into(),
1951        };
1952        assert!(config.auth_url().contains("github.com"));
1953        assert!(config.auth_url().contains("gh-id"));
1954    }
1955
1956    #[test]
1957    fn oauth_auth_url_with_state() {
1958        let config = OAuthConfig {
1959            provider: "google".into(),
1960            client_id: "test-id".into(),
1961            client_secret: "test-secret".into(),
1962            redirect_uri: "http://localhost/cb".into(),
1963        };
1964        let url = config.auth_url_with_state("random_state_123");
1965        assert!(url.contains("&state=random_state_123"));
1966    }
1967
1968    #[test]
1969    fn oauth_state_store_create_and_validate() {
1970        let store = OAuthStateStore::new();
1971        let token = store.create("google", "https://app/cb", "https://app/login");
1972        let rec = store.validate(&token, "google").expect("valid first time");
1973        assert_eq!(rec.callback_url, "https://app/cb");
1974        assert_eq!(rec.error_callback_url, "https://app/login");
1975        // Second validation should fail — single-use.
1976        assert!(store.validate(&token, "google").is_none());
1977    }
1978
1979    #[test]
1980    fn oauth_state_store_wrong_provider_rejected() {
1981        let store = OAuthStateStore::new();
1982        let token = store.create("google", "https://app/cb", "https://app/cb");
1983        assert!(store.validate(&token, "github").is_none());
1984    }
1985
1986    #[test]
1987    fn oauth_state_store_invalid_state_rejected() {
1988        let store = OAuthStateStore::new();
1989        assert!(store.validate("nonexistent", "google").is_none());
1990    }
1991
1992    #[test]
1993    fn validate_trusted_redirect_basics() {
1994        let trusted = vec!["http://localhost:3000".to_string()];
1995        assert!(validate_trusted_redirect("http://localhost:3000/dashboard", &trusted).is_ok());
1996        assert!(validate_trusted_redirect("http://localhost:3000", &trusted).is_ok());
1997        assert!(validate_trusted_redirect("http://localhost:3000/x?y=1", &trusted).is_ok());
1998
1999        // Wrong port → wrong origin.
2000        assert!(matches!(
2001            validate_trusted_redirect("http://localhost:4321/dashboard", &trusted),
2002            Err(TrustedOriginError::NotTrusted { .. })
2003        ));
2004        // Non-http scheme rejected even before trusted check (defense
2005        // against javascript:, file:, data:).
2006        assert!(matches!(
2007            validate_trusted_redirect("javascript:alert(1)", &trusted),
2008            Err(TrustedOriginError::NotHttp)
2009        ));
2010        assert!(matches!(
2011            validate_trusted_redirect("", &trusted),
2012            Err(TrustedOriginError::Empty)
2013        ));
2014    }
2015}