Skip to main content

pylon_auth/
lib.rs

1pub mod cookie;
2pub mod email;
3pub mod password;
4
5pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
6
7use serde::{Deserialize, Serialize};
8
9// ---------------------------------------------------------------------------
10// Auth context — the identity available to runtime operations
11// ---------------------------------------------------------------------------
12
13/// The auth context for a request. Represents who is making the request.
14///
15/// **Do NOT derive `Deserialize` on this type.** If the server ever parses an
16/// `AuthContext` from client-supplied JSON, a client can set `is_admin=true`
17/// or add roles and bypass every policy. Identity must come from
18/// server-minted sessions (`Session::to_auth_context`) or explicit
19/// constructors, never from deserialization.
20///
21/// `Serialize` is safe because sending the resolved context BACK to the
22/// client exposes nothing the server didn't already decide.
23#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
24pub struct AuthContext {
25    /// The authenticated user ID, or None for public/anonymous access.
26    /// For guest contexts this is `Some(guest_id)` — a stable
27    /// anonymous identifier, NOT a real user.
28    pub user_id: Option<String>,
29    /// Whether this is an admin context (bypasses policies).
30    pub is_admin: bool,
31    /// True for `AuthContext::guest()` — anonymous-with-stable-id, used
32    /// for cart state and similar pre-login persistence. Routes guarded
33    /// by `AuthMode::User` reject guests; only `is_authenticated()` ==
34    /// "real signed-in user" should pass auth-required gates.
35    #[serde(default, skip_serializing_if = "is_false")]
36    pub is_guest: bool,
37    /// Roles granted to this user. Empty for anonymous.
38    pub roles: Vec<String>,
39    /// Active tenant id (for multi-tenant apps). Set when the user has
40    /// selected an organization for the current session.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub tenant_id: Option<String>,
43}
44
/// `skip_serializing_if` predicate: true exactly when the flag is unset.
fn is_false(b: &bool) -> bool {
    !*b
}
48
49impl AuthContext {
50    /// Create an anonymous/public auth context.
51    pub fn anonymous() -> Self {
52        Self {
53            user_id: None,
54            is_admin: false,
55            is_guest: false,
56            roles: Vec::new(),
57            tenant_id: None,
58        }
59    }
60
61    /// Create an authenticated auth context.
62    pub fn authenticated(user_id: String) -> Self {
63        Self {
64            user_id: Some(user_id),
65            is_admin: false,
66            is_guest: false,
67            roles: Vec::new(),
68            tenant_id: None,
69        }
70    }
71
72    /// Create a guest auth context with a persistent anonymous ID.
73    /// Guests carry an opaque stable id (cart/session continuity) but
74    /// are NOT considered authenticated — `is_authenticated()` returns
75    /// false and `AuthMode::User` rejects them.
76    pub fn guest(guest_id: String) -> Self {
77        Self {
78            user_id: Some(guest_id),
79            is_admin: false,
80            is_guest: true,
81            roles: Vec::new(),
82            tenant_id: None,
83        }
84    }
85
86    /// Create an admin auth context that bypasses all policies.
87    pub fn admin() -> Self {
88        Self {
89            user_id: Some("__admin__".into()),
90            is_admin: true,
91            is_guest: false,
92            roles: vec!["admin".into()],
93            tenant_id: None,
94        }
95    }
96
97    /// Convenience: build a user context from a user id.
98    pub fn user(user_id: String) -> Self {
99        Self::authenticated(user_id)
100    }
101
102    /// Active tenant id (None when the user hasn't selected an org).
103    pub fn tenant_id(&self) -> Option<&str> {
104        self.tenant_id.as_deref()
105    }
106
107    /// Attach a tenant id to the context (chainable).
108    pub fn with_tenant(mut self, tenant_id: String) -> Self {
109        self.tenant_id = Some(tenant_id);
110        self
111    }
112
113    /// Check if this context represents an authenticated user.
114    /// Guests intentionally return `false` — they have a stable anonymous
115    /// id but never gain user-level access.
116    pub fn is_authenticated(&self) -> bool {
117        self.user_id.is_some() && !self.is_guest
118    }
119
120    /// Check if the user has a specific role. Admins have every role implicitly.
121    pub fn has_role(&self, role: &str) -> bool {
122        self.is_admin || self.roles.iter().any(|r| r == role)
123    }
124
125    /// Check if the user has ANY of the given roles.
126    pub fn has_any_role(&self, roles: &[&str]) -> bool {
127        self.is_admin || roles.iter().any(|r| self.has_role(r))
128    }
129
130    /// Attach roles to the context (chainable).
131    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
132        self.roles = roles;
133        self
134    }
135}
136
137// ---------------------------------------------------------------------------
138// Constant-time comparison
139// ---------------------------------------------------------------------------
140
/// Compares two byte slices without an early exit on the first mismatch,
/// to resist timing attacks.
///
/// Slices of different lengths are rejected immediately, so the length
/// itself is not hidden. For equal lengths, every byte pair is XORed and
/// OR-accumulated, so the same amount of work happens wherever (or whether)
/// the contents differ.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
156
157// ---------------------------------------------------------------------------
158// Auth mode — matches the route "auth" field values
159// ---------------------------------------------------------------------------
160
/// Access level a route declares via its manifest `auth` field.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// No authentication required; every caller is admitted.
    Public,
    /// Requires an authenticated (non-guest) user.
    User,
}
169
170impl AuthMode {
171    /// Parse from the manifest auth string.
172    #[allow(clippy::should_implement_trait)]
173    pub fn from_str(s: &str) -> Option<Self> {
174        match s {
175            "public" => Some(AuthMode::Public),
176            "user" => Some(AuthMode::User),
177            _ => None,
178        }
179    }
180
181    /// Check if the given auth context satisfies this mode.
182    pub fn check(&self, ctx: &AuthContext) -> bool {
183        match self {
184            AuthMode::Public => true,
185            AuthMode::User => ctx.is_authenticated(),
186        }
187    }
188}
189
190// ---------------------------------------------------------------------------
191// Session — opaque token session
192// ---------------------------------------------------------------------------
193
194/// A session token and its associated user.
195#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
196pub struct Session {
197    pub token: String,
198    pub user_id: String,
199    /// Unix epoch seconds at which this session expires. 0 = never.
200    #[serde(default)]
201    pub expires_at: u64,
202    /// Optional user-agent / device tag recorded at session creation.
203    #[serde(default, skip_serializing_if = "Option::is_none")]
204    pub device: Option<String>,
205    /// Unix epoch seconds when the session was created.
206    #[serde(default)]
207    pub created_at: u64,
208    /// Active tenant id (selected organization). Set via
209    /// `/api/auth/select-org`. Flows into `AuthContext.tenant_id` which
210    /// powers row-scoped policies like `data.orgId == auth.tenantId`.
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    pub tenant_id: Option<String>,
213}
214
215impl Session {
216    /// Default session lifetime: 30 days.
217    pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
218
219    /// Create a new session with a generated token and default 30-day expiry.
220    pub fn new(user_id: String) -> Self {
221        let now = now_secs();
222        Self {
223            token: generate_token(),
224            user_id,
225            expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
226            device: None,
227            created_at: now,
228            tenant_id: None,
229        }
230    }
231
232    /// Create a session with a specific lifetime.
233    pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
234        let now = now_secs();
235        Self {
236            token: generate_token(),
237            user_id,
238            expires_at: if lifetime_secs == 0 {
239                0
240            } else {
241                now.saturating_add(lifetime_secs)
242            },
243            device: None,
244            created_at: now,
245            tenant_id: None,
246        }
247    }
248
249    /// Convert this session to an auth context, carrying the selected
250    /// tenant so row-scoped policies see `auth.tenantId`.
251    pub fn to_auth_context(&self) -> AuthContext {
252        let ctx = AuthContext::authenticated(self.user_id.clone());
253        match &self.tenant_id {
254            Some(t) => ctx.with_tenant(t.clone()),
255            None => ctx,
256        }
257    }
258
259    /// Returns true if the session has passed its expires_at time.
260    /// Boundary is inclusive (`>=`) to match the rest of the codebase
261    /// (`magic_codes.expires_at <= now`, `oauth_state.expires_at <= now`).
262    pub fn is_expired(&self) -> bool {
263        self.expires_at != 0 && now_secs() >= self.expires_at
264    }
265}
266
/// Current wall-clock time as whole seconds since the Unix epoch.
/// A clock reading before 1970 is reported as `0` instead of failing.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
274
275// ---------------------------------------------------------------------------
276// OAuth provider config
277// ---------------------------------------------------------------------------
278
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct OAuthConfig {
281    pub provider: String,
282    pub client_id: String,
283    pub client_secret: String,
284    pub redirect_uri: String,
285}
286
287impl OAuthConfig {
288    /// Generate the authorization URL for the provider.
289    ///
290    /// Callers MUST append a `&state=<random>` parameter and validate it in the
291    /// callback to prevent CSRF attacks. See `OAuthStateStore` for a minimal
292    /// implementation.
293    pub fn auth_url(&self) -> String {
294        match self.provider.as_str() {
295            "google" => format!(
296                "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
297                self.client_id, self.redirect_uri
298            ),
299            "github" => format!(
300                "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
301                self.client_id, self.redirect_uri
302            ),
303            _ => String::new(),
304        }
305    }
306
307    /// Generate the authorization URL with a CSRF state parameter attached.
308    pub fn auth_url_with_state(&self, state: &str) -> String {
309        let base = self.auth_url();
310        if base.is_empty() {
311            return base;
312        }
313        format!("{}&state={}", base, state)
314    }
315
316    /// Generate the token exchange URL.
317    pub fn token_url(&self) -> &str {
318        match self.provider.as_str() {
319            "google" => "https://oauth2.googleapis.com/token",
320            "github" => "https://github.com/login/oauth/access_token",
321            _ => "",
322        }
323    }
324
325    /// URL for the userinfo endpoint, which returns the authenticated user's profile.
326    pub fn userinfo_url(&self) -> &str {
327        match self.provider.as_str() {
328            "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
329            "github" => "https://api.github.com/user",
330            _ => "",
331        }
332    }
333
334    /// Exchange an authorization code for the full token set
335    /// (`access_token`, optional `refresh_token`, optional `id_token`,
336    /// `expires_in`, `scope`). The longer struct is what the
337    /// account-store needs to persist; the legacy
338    /// [`OAuthConfig::exchange_code`] returns just the access token for
339    /// callers that don't care.
340    pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
341        let body = match self.provider.as_str() {
342            "google" => format!(
343                "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
344                url_encode(&self.client_id),
345                url_encode(&self.client_secret),
346                url_encode(&self.redirect_uri)
347            ),
348            "github" => format!(
349                "code={code}&client_id={}&client_secret={}&redirect_uri={}",
350                url_encode(&self.client_id),
351                url_encode(&self.client_secret),
352                url_encode(&self.redirect_uri)
353            ),
354            _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
355        };
356
357        let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
358        parse_token_response(&out)
359    }
360
361    /// Exchange an authorization code for an access token. Thin wrapper
362    /// around [`OAuthConfig::exchange_code_full`] for callers that only
363    /// need the access token (existing pre-account-store call sites).
364    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
365        let body = match self.provider.as_str() {
366            "google" => format!(
367                "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
368                url_encode(&self.client_id),
369                url_encode(&self.client_secret),
370                url_encode(&self.redirect_uri)
371            ),
372            "github" => format!(
373                "code={code}&client_id={}&client_secret={}&redirect_uri={}",
374                url_encode(&self.client_id),
375                url_encode(&self.client_secret),
376                url_encode(&self.redirect_uri)
377            ),
378            _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
379        };
380
381        let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
382        extract_access_token(&out)
383    }
384
385    /// Fetch the authenticated user's email + display name using an access token.
386    /// Returns `(email, display_name)`. Use [`OAuthConfig::fetch_userinfo_full`]
387    /// when you also need the provider-stable account ID for account
388    /// linking — the (`provider`, `provider_account_id`) pair is what
389    /// keeps a renamed-email user matched to the same row.
390    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
391        let info = self.fetch_userinfo_full(access_token)?;
392        Ok((info.email, info.name))
393    }
394
395    /// Fetch the authenticated user's full identity info — email + name +
396    /// the provider-stable account ID (Google's `sub`, GitHub's `id`).
397    /// `provider_account_id` is what the account-store keys on, NOT the
398    /// email; otherwise a user changing their Google address would orphan
399    /// their existing pylon account.
400    pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
401        let out = http_get_bearer(self.userinfo_url(), access_token)?;
402        let parsed: serde_json::Value =
403            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
404        match self.provider.as_str() {
405            "google" => {
406                let email = parsed
407                    .get("email")
408                    .and_then(|v| v.as_str())
409                    .ok_or("no email in userinfo")?
410                    .to_string();
411                let name = parsed
412                    .get("name")
413                    .and_then(|v| v.as_str())
414                    .map(String::from);
415                let provider_account_id = parsed
416                    .get("sub")
417                    .and_then(|v| v.as_str())
418                    .ok_or("no sub in userinfo")?
419                    .to_string();
420                Ok(UserInfo {
421                    provider: self.provider.clone(),
422                    provider_account_id,
423                    email,
424                    name,
425                })
426            }
427            "github" => {
428                let name = parsed
429                    .get("name")
430                    .and_then(|v| v.as_str())
431                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
432                    .map(String::from);
433                let email = parsed
434                    .get("email")
435                    .and_then(|v| v.as_str())
436                    .map(String::from);
437                // GitHub may return a null email if the user hasn't published one;
438                // in that case the caller should hit /user/emails with the same token.
439                let email = email
440                    .or_else(|| fetch_github_primary_email(access_token).ok())
441                    .ok_or("no accessible email on GitHub account")?;
442                // GitHub's `id` field is a numeric user ID — the stable
443                // account identifier even if the user renames themselves.
444                let provider_account_id = parsed
445                    .get("id")
446                    .map(|v| {
447                        v.as_i64()
448                            .map(|n| n.to_string())
449                            .or_else(|| v.as_str().map(String::from))
450                            .unwrap_or_default()
451                    })
452                    .filter(|s| !s.is_empty())
453                    .ok_or("no id in userinfo")?;
454                Ok(UserInfo {
455                    provider: self.provider.clone(),
456                    provider_account_id,
457                    email,
458                    name,
459                })
460            }
461            _ => Err(format!("unknown provider: {}", self.provider)),
462        }
463    }
464}
465
/// Identity resolved by [`OAuthConfig::fetch_userinfo_full`].
///
/// `provider_account_id` is the provider's stable subject id (Google `sub`,
/// GitHub numeric `id`). The account store keys on it so that a changed
/// email does not orphan the pylon account.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UserInfo {
    /// Provider key, e.g. `"google"` or `"github"`.
    pub provider: String,
    /// Provider-stable subject id.
    pub provider_account_id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Display name, when the provider supplies one.
    pub name: Option<String>,
}
477
/// Tokens returned by [`OAuthConfig::exchange_code_full`].
///
/// Stored on the matching `Account` row so that `refresh_token` is available
/// for silent re-auth and `expires_at` can be checked before each provider call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TokenSet {
    /// Bearer token for provider API calls.
    pub access_token: String,
    /// Refresh token, when the provider issues one.
    pub refresh_token: Option<String>,
    /// OpenID Connect ID token, when present.
    pub id_token: Option<String>,
    /// Absolute expiry of the access token in Unix epoch seconds. `None`
    /// when the provider sent no `expires_in` (GitHub's classic OAuth app
    /// tokens never expire).
    pub expires_at: Option<u64>,
    /// Granted scope string, as reported by the provider.
    pub scope: Option<String>,
}
492
493fn parse_token_response(body: &str) -> Result<TokenSet, String> {
494    // Most providers return JSON; GitHub Classic apps return form-urlencoded
495    // unless you ask with Accept: application/json (which we do).
496    let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
497        // Fall back to form-urlencoded: access_token=...&scope=...&token_type=...
498        let mut map = serde_json::Map::new();
499        for pair in body.split('&') {
500            if let Some((k, v)) = pair.split_once('=') {
501                map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
502            }
503        }
504        serde_json::Value::Object(map)
505    });
506
507    let access_token = json
508        .get("access_token")
509        .and_then(|v| v.as_str())
510        .ok_or_else(|| format!("no access_token in token response: {body}"))?
511        .to_string();
512    let refresh_token = json
513        .get("refresh_token")
514        .and_then(|v| v.as_str())
515        .map(String::from);
516    let id_token = json
517        .get("id_token")
518        .and_then(|v| v.as_str())
519        .map(String::from);
520    let expires_at = json
521        .get("expires_in")
522        .and_then(|v| {
523            v.as_u64()
524                .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
525        })
526        .map(|secs| now_secs().saturating_add(secs));
527    let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
528    Ok(TokenSet {
529        access_token,
530        refresh_token,
531        id_token,
532        expires_at,
533        scope,
534    })
535}
536
/// Percent-encode `s` byte by byte.
///
/// Only the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) passes through
/// unchanged. Every other byte, including each byte of a multi-byte UTF-8
/// character, becomes `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
549
/// Per-phase limit (connect, read, write) for every OAuth / userinfo HTTP
/// call. It is short enough that a hung provider cannot stall a login
/// indefinitely, and generous enough for ordinary internet latency.
const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10_000);
554
555fn ureq_agent() -> ureq::Agent {
556    ureq::AgentBuilder::new()
557        .timeout_connect(HTTP_TIMEOUT)
558        .timeout_read(HTTP_TIMEOUT)
559        .timeout_write(HTTP_TIMEOUT)
560        .user_agent("pylon/0.1")
561        .build()
562}
563
564fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
565    let agent = ureq_agent();
566    let mut req = agent
567        .post(url)
568        .set("Content-Type", "application/x-www-form-urlencoded");
569    if accept_json {
570        req = req.set("Accept", "application/json");
571    }
572    match req.send_string(body) {
573        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
574        Err(ureq::Error::Status(code, resp)) => {
575            let body = resp.into_string().unwrap_or_default();
576            Err(format!("HTTP {code}: {body}"))
577        }
578        Err(e) => Err(format!("HTTP error: {e}")),
579    }
580}
581
582fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
583    let agent = ureq_agent();
584    match agent
585        .get(url)
586        .set("Authorization", &format!("Bearer {token}"))
587        .set("Accept", "application/json")
588        .call()
589    {
590        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
591        Err(ureq::Error::Status(code, resp)) => {
592            let body = resp.into_string().unwrap_or_default();
593            Err(format!("HTTP {code}: {body}"))
594        }
595        Err(e) => Err(format!("HTTP error: {e}")),
596    }
597}
598
599fn fetch_github_primary_email(token: &str) -> Result<String, String> {
600    let out = http_get_bearer("https://api.github.com/user/emails", token)?;
601    let emails: serde_json::Value =
602        serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
603    emails
604        .as_array()
605        .and_then(|arr| {
606            arr.iter()
607                .find(|e| {
608                    e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
609                        && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
610                })
611                .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
612        })
613        .ok_or_else(|| "no primary verified email on GitHub".into())
614}
615
616fn extract_access_token(body: &str) -> Result<String, String> {
617    if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
618        if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
619            return Ok(t.to_string());
620        }
621    }
622    // GitHub can return url-encoded: access_token=...&scope=...&token_type=bearer
623    for pair in body.split('&') {
624        if let Some(val) = pair.strip_prefix("access_token=") {
625            return Ok(val.to_string());
626        }
627    }
628    Err(format!("no access_token in token response: {body}"))
629}
630
631/// OAuth provider registry.
632pub struct OAuthRegistry {
633    providers: std::collections::HashMap<String, OAuthConfig>,
634}
635
636impl Default for OAuthRegistry {
637    fn default() -> Self {
638        Self::new()
639    }
640}
641
642impl OAuthRegistry {
643    pub fn new() -> Self {
644        Self {
645            providers: std::collections::HashMap::new(),
646        }
647    }
648
649    pub fn register(&mut self, config: OAuthConfig) {
650        self.providers.insert(config.provider.clone(), config);
651    }
652
653    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
654        self.providers.get(provider)
655    }
656
657    /// Build from environment variables.
658    /// Looks for PYLON_OAUTH_GOOGLE_CLIENT_ID, etc.
659    pub fn from_env() -> Self {
660        let mut reg = Self::new();
661
662        // Google
663        if let (Ok(id), Ok(secret)) = (
664            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
665            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
666        ) {
667            reg.register(OAuthConfig {
668                provider: "google".into(),
669                client_id: id,
670                client_secret: secret,
671                redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
672                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
673            });
674        }
675
676        // GitHub
677        if let (Ok(id), Ok(secret)) = (
678            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
679            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
680        ) {
681            reg.register(OAuthConfig {
682                provider: "github".into(),
683                client_id: id,
684                client_secret: secret,
685                redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
686                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
687            });
688        }
689
690        reg
691    }
692}
693
694// ---------------------------------------------------------------------------
695// OAuth state store — CSRF protection for OAuth flows
696// ---------------------------------------------------------------------------
697
/// A pending OAuth handshake, stored under its random state token.
///
/// Besides the provider, it records where the callback should send the
/// user, so the callback handler needs no env var to decide. Both URLs are
/// validated against `PYLON_TRUSTED_ORIGINS` when the record is created,
/// which lets the callback use them without re-checking.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OAuthState {
    /// Provider the handshake was started for.
    pub provider: String,
    /// Redirect target on success, supplied by the frontend as `?callback=`
    /// on the start request.
    pub callback_url: String,
    /// Redirect target on failure. Defaults to `callback_url` when the
    /// frontend passes no `?error_callback=`. The error code and message are
    /// appended as `?oauth_error=X&oauth_error_message=Y`.
    pub error_callback_url: String,
    /// Expiry as Unix epoch seconds.
    pub expires_at: u64,
}
716
717/// Backing store for OAuth state records. Default impl keeps them in
718/// memory (fine for tests + dev); the runtime swaps in a SQLite or
719/// Postgres backend so a restart in the middle of an OAuth handshake
720/// doesn't leave the user with "invalid state" on the callback.
721pub trait OAuthStateBackend: Send + Sync {
722    /// Persist a state record under `token`.
723    fn put(&self, token: &str, state: &OAuthState);
724    /// Atomic compare-and-consume: returns the stored record if the
725    /// token exists and hasn't expired, then removes it. Returning
726    /// `None` means either the token never existed or it has already
727    /// been used / expired.
728    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
729}
730
731/// In-memory backend (default). Lost on restart.
732pub struct InMemoryOAuthBackend {
733    states: Mutex<HashMap<String, OAuthState>>,
734}
735
736impl InMemoryOAuthBackend {
737    pub fn new() -> Self {
738        Self {
739            states: Mutex::new(HashMap::new()),
740        }
741    }
742}
743
744impl Default for InMemoryOAuthBackend {
745    fn default() -> Self {
746        Self::new()
747    }
748}
749
750impl OAuthStateBackend for InMemoryOAuthBackend {
751    fn put(&self, token: &str, state: &OAuthState) {
752        self.states
753            .lock()
754            .unwrap()
755            .insert(token.to_string(), state.clone());
756    }
757    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
758        let mut s = self.states.lock().unwrap();
759        let entry = s.remove(token)?;
760        if entry.expires_at <= now_unix_secs {
761            return None;
762        }
763        Some(entry)
764    }
765}
766
767/// Stores OAuth state parameters to prevent CSRF attacks on the callback.
768///
769/// State tokens are short-lived (10 minutes) and single-use. Backed by an
770/// `OAuthStateBackend`; defaults to in-memory but the runtime persists them
771/// to SQLite (or Postgres when `DATABASE_URL` is set) so they survive a
772/// restart that happens mid-OAuth-handshake.
773pub struct OAuthStateStore {
774    backend: Box<dyn OAuthStateBackend>,
775}
776
777impl Default for OAuthStateStore {
778    fn default() -> Self {
779        Self::new()
780    }
781}
782
783impl OAuthStateStore {
784    pub fn new() -> Self {
785        Self {
786            backend: Box::new(InMemoryOAuthBackend::new()),
787        }
788    }
789
790    pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
791        Self { backend }
792    }
793
794    /// Generate and store a new state record. Returns the random
795    /// state token (the value the OAuth provider echoes back as
796    /// `?state=…` on the callback).
797    ///
798    /// Caller is responsible for validating `callback_url` and
799    /// `error_callback_url` against the trusted-origins allowlist
800    /// BEFORE calling this — the store trusts what it's given.
801    pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
802        use std::time::{SystemTime, UNIX_EPOCH};
803        let token = generate_token();
804        let now = SystemTime::now()
805            .duration_since(UNIX_EPOCH)
806            .unwrap_or_default()
807            .as_secs();
808        let state = OAuthState {
809            provider: provider.to_string(),
810            callback_url: callback_url.to_string(),
811            error_callback_url: error_callback_url.to_string(),
812            expires_at: now + 600,
813        };
814        self.backend.put(&token, &state);
815        token
816    }
817
818    /// Validate and consume a state token. Returns the stored record
819    /// iff the token existed, has not expired, AND matches
820    /// `expected_provider`. The token is removed either way to make
821    /// replay impossible.
822    pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
823        use std::time::{SystemTime, UNIX_EPOCH};
824        let now = SystemTime::now()
825            .duration_since(UNIX_EPOCH)
826            .unwrap_or_default()
827            .as_secs();
828        let entry = self.backend.take(state, now)?;
829        if entry.provider != expected_provider {
830            return None;
831        }
832        Some(entry)
833    }
834}
835
/// Validate that `url` is an absolute http(s) URL whose origin
/// (`scheme://host[:port]`) is listed in `trusted_origins`.
///
/// Returns `Ok(())` when trusted and a [`TrustedOriginError`] describing the
/// rejection otherwise. Used by the OAuth start endpoint to gate
/// `?callback=` + `?error_callback=` values before storing them in the
/// state record.
///
/// `trusted_origins` entries are origin strings like
/// `"https://app.example.com"` or `"http://localhost:3000"` — no path.
/// Surrounding whitespace and trailing slashes on an entry are ignored, so a
/// configured `"https://app.example.com/"` matches instead of silently never
/// matching anything. A `url` like `"http://localhost:3000/dashboard?x=1"`
/// matches the `"http://localhost:3000"` entry.
///
/// The comparison is exact string equality against [`origin_of`]'s output
/// (no case folding, no default-port normalization), so ambiguity fails
/// closed: `"https://app.example.com@evil.com/"` has origin
/// `"https://app.example.com@evil.com"`, which matches nothing.
///
/// Borrowed wholesale from better-auth's `trustedOrigins` model:
/// explicit allowlist, no implicit "same-origin trust," no env-var
/// magic. An open-redirect via OAuth is one of the easier auth bugs
/// to ship by accident.
pub fn validate_trusted_redirect(
    url: &str,
    trusted_origins: &[String],
) -> Result<(), TrustedOriginError> {
    if url.is_empty() {
        return Err(TrustedOriginError::Empty);
    }
    // Must be absolute http(s) URL — no relative paths, no schemes
    // like javascript:, file:, data:. Case-sensitive on purpose: an
    // upper-case scheme is rejected rather than normalized.
    if !url.starts_with("http://") && !url.starts_with("https://") {
        return Err(TrustedOriginError::NotHttp);
    }
    let url_origin = origin_of(url);
    // Normalize the configured entries (not the URL): a stray trailing
    // slash or space in config would otherwise make an entry unmatchable.
    let is_trusted = trusted_origins
        .iter()
        .map(|entry| entry.trim().trim_end_matches('/'))
        .any(|entry| entry == url_origin);
    if is_trusted {
        Ok(())
    } else {
        Err(TrustedOriginError::NotTrusted { origin: url_origin })
    }
}

/// Reasons a redirect URL might be rejected by [`validate_trusted_redirect`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The redirect URL was the empty string.
    Empty,
    /// The URL does not start with `http://` or `https://`.
    NotHttp,
    /// The URL's origin is not on the allowlist; carries the extracted origin.
    NotTrusted { origin: String },
}

impl std::fmt::Display for TrustedOriginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TrustedOriginError::Empty => write!(f, "redirect URL is empty"),
            TrustedOriginError::NotHttp => {
                write!(f, "redirect URL must use http:// or https:// scheme")
            }
            TrustedOriginError::NotTrusted { origin } => write!(
                f,
                "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS"
            ),
        }
    }
}

/// Extract the origin (`scheme://host[:port]`) from a URL string,
/// stripping any path/query/fragment. Best-effort string slicing —
/// no full URL parser dep, so there is no case folding, percent-decoding,
/// or default-port removal, and userinfo is kept (`"https://u@h/x"` yields
/// `"https://u@h"`). Input without `://` is returned with trailing slashes
/// trimmed. Public so router crates can reuse the same logic when comparing
/// redirect URLs against the trusted-origins list.
pub fn origin_of(url: &str) -> String {
    let after_scheme = match url.find("://") {
        Some(i) => i + 3,
        None => return url.trim_end_matches('/').to_string(),
    };
    let rest = &url[after_scheme..];
    // The authority ends at the first path, query, or fragment delimiter.
    let cut = rest
        .find(|c: char| c == '/' || c == '?' || c == '#')
        .unwrap_or(rest.len());
    url[..after_scheme + cut].to_string()
}
910
911// ---------------------------------------------------------------------------
912// Magic code auth — email verification codes
913// ---------------------------------------------------------------------------
914
/// Pluggable storage for magic-code records. In-memory is the default
/// (fine for dev); persistent backends (SQLite, Postgres) live in
/// `pylon-runtime` so a server restart between "send code" and "verify
/// code" doesn't invalidate the user's pending login.
///
/// All methods are infallible from the caller's perspective — durability
/// is best-effort. A backend that fails to write should log; the
/// in-memory cache remains authoritative for the current process.
///
/// Records are keyed by the email string exactly as given, and `put`
/// replaces, so a backend holds at most one record per email.
pub trait MagicCodeBackend: Send + Sync {
    /// Replace any existing code for `email` with `code`.
    fn put(&self, email: &str, code: &MagicCode);
    /// Look up the current code for `email`. Returns `None` if absent.
    fn get(&self, email: &str) -> Option<MagicCode>;
    /// Remove the code for `email` (called on successful verify or
    /// expiry). Idempotent — missing key is not an error.
    fn remove(&self, email: &str);
    /// Persist an attempts++ on the existing record without touching
    /// other fields. Used by the verify-failed path to enforce
    /// `MAX_ATTEMPTS` across restarts. (The in-memory backend ignores
    /// the call when no record exists for `email`.)
    fn bump_attempts(&self, email: &str);
    /// Load all live records on construction. Lets `MagicCodeStore::with_backend`
    /// hydrate the in-memory cache from durable storage on startup.
    /// Returning expired records is harmless: the store skips any whose
    /// `expires_at` has already passed.
    fn load_all(&self) -> Vec<MagicCode>;
}
939
940/// In-memory backend for magic codes. The default — also used as the
941/// authoritative cache by `MagicCodeStore`.
942pub struct InMemoryMagicCodeBackend {
943    codes: Mutex<HashMap<String, MagicCode>>,
944}
945
946impl InMemoryMagicCodeBackend {
947    pub fn new() -> Self {
948        Self {
949            codes: Mutex::new(HashMap::new()),
950        }
951    }
952}
953
954impl Default for InMemoryMagicCodeBackend {
955    fn default() -> Self {
956        Self::new()
957    }
958}
959
960impl MagicCodeBackend for InMemoryMagicCodeBackend {
961    fn put(&self, email: &str, code: &MagicCode) {
962        self.codes
963            .lock()
964            .unwrap()
965            .insert(email.to_string(), code.clone());
966    }
967    fn get(&self, email: &str) -> Option<MagicCode> {
968        self.codes.lock().unwrap().get(email).cloned()
969    }
970    fn remove(&self, email: &str) {
971        self.codes.lock().unwrap().remove(email);
972    }
973    fn bump_attempts(&self, email: &str) {
974        if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
975            c.attempts = c.attempts.saturating_add(1);
976        }
977    }
978    fn load_all(&self) -> Vec<MagicCode> {
979        self.codes.lock().unwrap().values().cloned().collect()
980    }
981}
982
/// A magic-code store. Wraps a `MagicCodeBackend` (in-memory by default)
/// and applies the verify/cooldown semantics. Hydrates the in-memory
/// cache from the backend on construction so durable backends survive
/// restart without losing in-flight codes.
///
/// NOTE(review): records are keyed by the `email` string exactly as passed
/// in — no trimming or case folding happens in this store. If callers don't
/// normalize first, `User@x.com` and `user@x.com` get independent codes,
/// cooldowns, and attempt counters. Confirm callers normalize.
pub struct MagicCodeStore {
    /// Authoritative in-process copy of every record, keyed by email.
    cache: Mutex<HashMap<String, MagicCode>>,
    /// Write-through mirror of the cache; only read back when the store
    /// is constructed.
    backend: Box<dyn MagicCodeBackend>,
}
991
/// One issued magic login code.
///
/// `Debug` is implemented by hand so the secret `code` is redacted: a
/// derived `Debug` would print a live, account-takeover-capable credential
/// into any log line or panic message that formats this struct.
#[derive(Clone)]
pub struct MagicCode {
    /// Address the code was issued for (also the store's map key).
    pub email: String,
    /// The secret code itself (6 decimal digits when generated by this crate).
    pub code: String,
    /// Unix epoch seconds after which the code is no longer accepted.
    pub expires_at: u64,
    /// Failed verify attempts against this code. Once it reaches
    /// `MAX_ATTEMPTS` the code is invalidated.
    pub attempts: u32,
}

impl std::fmt::Debug for MagicCode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MagicCode")
            .field("email", &self.email)
            .field("code", &"[redacted]")
            .field("expires_at", &self.expires_at)
            .field("attempts", &self.attempts)
            .finish()
    }
}
1001
/// Maximum verify attempts per code before it's burned. 5 is a common bound —
/// lets the user fix typos without enabling realistic brute-force against a
/// 6-digit code space.
///
/// With 1,000,000 possible codes, guessing succeeds with probability at most
/// 5 / 1,000,000 per issued code; `CREATE_COOLDOWN_SECS` limits how quickly
/// fresh codes (and therefore fresh guess budgets) can be obtained.
const MAX_ATTEMPTS: u32 = 5;

/// Minimum seconds between successive `create()` calls for the same email.
/// Throttles magic-code spam (user can't be flooded with login codes).
/// Only a still-live code triggers the throttle; once the previous code has
/// expired a new one can be issued immediately.
const CREATE_COOLDOWN_SECS: u64 = 60;
1010
/// Why a magic-code operation failed. `MagicCodeStore::try_create` returns
/// `Throttled`; `MagicCodeStore::try_verify` returns the other variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is stored for this email (never issued, already consumed,
    /// or evicted after expiring).
    NotFound,
    /// `MAX_ATTEMPTS` failed verifies have occurred; the code is burned and
    /// rejects even the correct value until a new code is issued.
    TooManyAttempts,
    /// The code did not match.
    BadCode,
    /// A code was stored but its expiry time has passed; the record is
    /// evicted when this error is reported.
    Expired,
    /// Another code was requested too recently. Wait and try again.
    Throttled { retry_after_secs: u64 },
}
1024
1025impl Default for MagicCodeStore {
1026    fn default() -> Self {
1027        Self::new()
1028    }
1029}
1030
1031impl MagicCodeStore {
1032    pub fn new() -> Self {
1033        Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1034    }
1035
1036    /// Build a magic-code store backed by a persistent backend. Existing
1037    /// live codes are hydrated into the in-memory cache on construction
1038    /// so a server restart between "send" and "verify" doesn't kill the
1039    /// user's pending login.
1040    pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1041        let now = now_secs();
1042        let mut cache = HashMap::new();
1043        for c in backend.load_all() {
1044            if c.expires_at > now {
1045                cache.insert(c.email.clone(), c);
1046            }
1047        }
1048        Self {
1049            cache: Mutex::new(cache),
1050            backend,
1051        }
1052    }
1053
1054    /// Generate a 6-digit code for an email and return it. Subject to a
1055    /// per-email cooldown — returns the error-shape via `try_create`.
1056    pub fn create(&self, email: &str) -> String {
1057        // Back-compat wrapper: same signature as before, but we still burn
1058        // the cooldown if one is active. Use `try_create` for a Result shape.
1059        self.try_create(email).unwrap_or_else(|_| String::new())
1060    }
1061
1062    /// Create a magic code, enforcing per-email cooldown. Returns the code
1063    /// or an error describing why one couldn't be issued.
1064    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1065        let now = now_secs();
1066
1067        let mut codes = self.cache.lock().unwrap();
1068
1069        // Cooldown check: if a live code exists and was created less than
1070        // CREATE_COOLDOWN_SECS ago, throttle. The age-of-code is
1071        // `expires_at - 600 + cooldown` since expires_at is create_time + 600.
1072        if let Some(existing) = codes.get(email) {
1073            if existing.expires_at > now {
1074                let created_at = existing.expires_at.saturating_sub(600);
1075                let age = now.saturating_sub(created_at);
1076                if age < CREATE_COOLDOWN_SECS {
1077                    return Err(MagicCodeError::Throttled {
1078                        retry_after_secs: CREATE_COOLDOWN_SECS - age,
1079                    });
1080                }
1081            }
1082        }
1083
1084        let code = generate_magic_code();
1085        let mc = MagicCode {
1086            email: email.to_string(),
1087            code: code.clone(),
1088            expires_at: now + 600, // 10 minutes
1089            attempts: 0,
1090        };
1091        codes.insert(email.to_string(), mc.clone());
1092        // Persist after the cache mutation lands. Backend write is
1093        // best-effort — if it fails the code still works for this
1094        // process; only a restart in the next 10 minutes would lose it.
1095        self.backend.put(email, &mc);
1096        Ok(code)
1097    }
1098
1099    /// Verify a code for an email. Returns true if valid and not expired.
1100    /// Uses constant-time comparison to prevent timing attacks.
1101    /// Back-compat wrapper around [`try_verify`].
1102    pub fn verify(&self, email: &str, code: &str) -> bool {
1103        matches!(self.try_verify(email, code), Ok(()))
1104    }
1105
1106    /// Verify a code. Returns a typed error so callers can surface specific
1107    /// messages. On the MAX_ATTEMPTS-th failure, the code is burned — even
1108    /// correct subsequent attempts return `TooManyAttempts`.
1109    /// Every magic code currently in the cache. Powers the Studio
1110    /// "Auth tables" view; not for app use. Includes expired codes —
1111    /// the cache only drops them on next verify attempt for that email.
1112    pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1113        self.cache
1114            .lock()
1115            .map(|m| m.values().cloned().collect())
1116            .unwrap_or_default()
1117    }
1118
1119    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1120        let now = now_secs();
1121        let mut codes = self.cache.lock().unwrap();
1122
1123        let mc = match codes.get_mut(email) {
1124            Some(m) => m,
1125            None => return Err(MagicCodeError::NotFound),
1126        };
1127
1128        if mc.attempts >= MAX_ATTEMPTS {
1129            return Err(MagicCodeError::TooManyAttempts);
1130        }
1131        if mc.expires_at <= now {
1132            codes.remove(email);
1133            self.backend.remove(email);
1134            return Err(MagicCodeError::Expired);
1135        }
1136
1137        let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1138        if !ok {
1139            mc.attempts += 1;
1140            self.backend.bump_attempts(email);
1141            // Burn the code at MAX_ATTEMPTS so retries can't hit max.
1142            if mc.attempts >= MAX_ATTEMPTS {
1143                return Err(MagicCodeError::TooManyAttempts);
1144            }
1145            return Err(MagicCodeError::BadCode);
1146        }
1147
1148        // Correct code — consume it.
1149        codes.remove(email);
1150        self.backend.remove(email);
1151        Ok(())
1152    }
1153}
1154
1155// ---------------------------------------------------------------------------
1156// Cryptographic helpers — CSPRNG-based token and code generation
1157// ---------------------------------------------------------------------------
1158
/// Lowercase hex encoding of `bytes` (two characters per byte).
///
/// Writes into a single preallocated buffer; the previous
/// `format!("{:02x}")`-per-byte version allocated a temporary `String` for
/// every input byte.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(HEX[usize::from(b >> 4)]));
        out.push(char::from(HEX[usize::from(b & 0x0f)]));
    }
    out
}
1162
1163/// Generate a 6-digit magic code using a CSPRNG.
1164fn generate_magic_code() -> String {
1165    use rand::Rng;
1166    let mut rng = rand::thread_rng();
1167    let code: u32 = rng.gen_range(0..1_000_000);
1168    format!("{:06}", code)
1169}
1170
1171/// Generate a session token with 256 bits of entropy from a CSPRNG.
1172fn generate_token() -> String {
1173    use rand::Rng;
1174    let mut rng = rand::thread_rng();
1175    let bytes: [u8; 32] = rng.gen();
1176    format!("pylon_{}", hex_encode(&bytes))
1177}
1178
1179// ---------------------------------------------------------------------------
1180// Session store — in-memory for dev
1181// ---------------------------------------------------------------------------
1182
1183use std::collections::HashMap;
1184use std::sync::Mutex;
1185
/// Pluggable storage backend for sessions. The default is in-memory; apps
/// deploying for real should supply a persistent backend (e.g. SQLite or
/// Redis) so users don't log out on server restart.
///
/// `SessionStore` keeps its own in-memory map and treats the backend as a
/// write-through mirror: `load_all` is only called when the store is built.
pub trait SessionBackend: Send + Sync {
    /// Return every persisted session. The store skips any that are
    /// already expired when hydrating.
    fn load_all(&self) -> Vec<Session>;
    /// Persist `session`. Called both for brand-new sessions and after
    /// in-place changes (device, user id, tenant), so implementations
    /// must upsert keyed on `session.token`.
    fn save(&self, session: &Session);
    /// Delete the persisted session identified by `token`.
    fn remove(&self, token: &str);
}
1194
1195/// A session store. In-memory by default; optionally backed by a
1196/// persistent [`SessionBackend`].
1197///
1198/// The in-memory map is always authoritative — reads don't touch the
1199/// backend. The backend receives every `save`/`remove`, making it a
1200/// write-through cache. On construction via [`SessionStore::with_backend`],
1201/// the store hydrates from the backend so sessions survive restart.
1202pub struct SessionStore {
1203    sessions: Mutex<HashMap<String, Session>>,
1204    backend: Option<Box<dyn SessionBackend>>,
1205}
1206
1207impl Default for SessionStore {
1208    fn default() -> Self {
1209        Self::new()
1210    }
1211}
1212
1213impl SessionStore {
1214    pub fn new() -> Self {
1215        Self {
1216            sessions: Mutex::new(HashMap::new()),
1217            backend: None,
1218        }
1219    }
1220
1221    /// Build a session store backed by a persistent store. Existing sessions
1222    /// are loaded from the backend on construction; every future mutation
1223    /// writes through.
1224    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1225        let mut map = HashMap::new();
1226        for s in backend.load_all() {
1227            if !s.is_expired() {
1228                map.insert(s.token.clone(), s);
1229            }
1230        }
1231        Self {
1232            sessions: Mutex::new(map),
1233            backend: Some(backend),
1234        }
1235    }
1236
1237    /// Create a session for a user and return it.
1238    pub fn create(&self, user_id: String) -> Session {
1239        let session = Session::new(user_id);
1240        let mut sessions = self.sessions.lock().unwrap();
1241        sessions.insert(session.token.clone(), session.clone());
1242        if let Some(b) = &self.backend {
1243            b.save(&session);
1244        }
1245        session
1246    }
1247
1248    /// Look up a session by token. Returns None if the session is expired.
1249    pub fn get(&self, token: &str) -> Option<Session> {
1250        let mut sessions = self.sessions.lock().unwrap();
1251        match sessions.get(token) {
1252            Some(s) if s.is_expired() => {
1253                sessions.remove(token);
1254                None
1255            }
1256            Some(s) => Some(s.clone()),
1257            None => None,
1258        }
1259    }
1260
1261    /// Resolve a token to an auth context.
1262    /// Returns anonymous context if the token is invalid, missing, or expired.
1263    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1264        match token {
1265            Some(t) => match self.get(t) {
1266                Some(session) => session.to_auth_context(),
1267                None => AuthContext::anonymous(),
1268            },
1269            None => AuthContext::anonymous(),
1270        }
1271    }
1272
1273    /// Refresh a session — issues a new token, copies user/device, extends expiry.
1274    /// The old token is revoked. Returns the new session or None if the old
1275    /// token is missing/expired.
1276    pub fn refresh(&self, old_token: &str) -> Option<Session> {
1277        let mut sessions = self.sessions.lock().unwrap();
1278        let old = sessions.remove(old_token)?;
1279        if let Some(b) = &self.backend {
1280            b.remove(old_token);
1281        }
1282        if old.is_expired() {
1283            return None;
1284        }
1285        let mut new = Session::new(old.user_id.clone());
1286        new.device = old.device.clone();
1287        sessions.insert(new.token.clone(), new.clone());
1288        if let Some(b) = &self.backend {
1289            b.save(&new);
1290        }
1291        Some(new)
1292    }
1293
1294    /// Every session in the store, including expired ones, with no
1295    /// filtering. Powers the Studio "Auth tables" view so operators
1296    /// can see orphaned sessions / debug stuck logins. Don't use for
1297    /// app code — `list_for_user` is the right surface there.
1298    pub fn list_all_unfiltered(&self) -> Vec<Session> {
1299        self.sessions
1300            .lock()
1301            .map(|m| m.values().cloned().collect())
1302            .unwrap_or_default()
1303    }
1304
1305    /// List all active sessions for a user.
1306    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
1307        let sessions = self.sessions.lock().unwrap();
1308        sessions
1309            .values()
1310            .filter(|s| s.user_id == user_id && !s.is_expired())
1311            .cloned()
1312            .collect()
1313    }
1314
1315    /// Revoke all sessions for a user. Returns the count removed.
1316    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
1317        let mut sessions = self.sessions.lock().unwrap();
1318        let tokens: Vec<String> = sessions
1319            .iter()
1320            .filter_map(|(t, s)| {
1321                if s.user_id == user_id {
1322                    Some(t.clone())
1323                } else {
1324                    None
1325                }
1326            })
1327            .collect();
1328        let n = tokens.len();
1329        for t in &tokens {
1330            sessions.remove(t);
1331            if let Some(b) = &self.backend {
1332                b.remove(t);
1333            }
1334        }
1335        n
1336    }
1337
1338    /// Sweep expired sessions. Returns the count removed.
1339    pub fn sweep_expired(&self) -> usize {
1340        let mut sessions = self.sessions.lock().unwrap();
1341        let expired: Vec<String> = sessions
1342            .iter()
1343            .filter_map(|(t, s)| {
1344                if s.is_expired() {
1345                    Some(t.clone())
1346                } else {
1347                    None
1348                }
1349            })
1350            .collect();
1351        let n = expired.len();
1352        for t in &expired {
1353            sessions.remove(t);
1354            if let Some(b) = &self.backend {
1355                b.remove(t);
1356            }
1357        }
1358        n
1359    }
1360
1361    /// Attach a device label to a session (typically on login from a browser).
1362    pub fn set_device(&self, token: &str, device: String) -> bool {
1363        let mut sessions = self.sessions.lock().unwrap();
1364        if let Some(s) = sessions.get_mut(token) {
1365            s.device = Some(device);
1366            if let Some(b) = &self.backend {
1367                b.save(s);
1368            }
1369            true
1370        } else {
1371            false
1372        }
1373    }
1374
1375    /// Create a guest session with a generated anonymous ID.
1376    pub fn create_guest(&self) -> Session {
1377        use rand::Rng;
1378        let mut rng = rand::thread_rng();
1379        let bytes: [u8; 16] = rng.gen();
1380        let guest_id = format!("guest_{}", hex_encode(&bytes));
1381        self.create(guest_id)
1382    }
1383
1384    /// Upgrade a guest session to a real user. Replaces the user_id.
1385    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1386        let mut sessions = self.sessions.lock().unwrap();
1387        if let Some(session) = sessions.get_mut(token) {
1388            session.user_id = real_user_id;
1389            if let Some(b) = &self.backend {
1390                b.save(session);
1391            }
1392            true
1393        } else {
1394            false
1395        }
1396    }
1397
1398    /// Switch the session's active tenant (organization). `None` clears it.
1399    /// Callers should verify the user actually has membership in the target
1400    /// tenant BEFORE invoking this — the session store takes the value on
1401    /// trust. Returns true if the session exists, false otherwise.
1402    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1403        let mut sessions = self.sessions.lock().unwrap();
1404        if let Some(session) = sessions.get_mut(token) {
1405            session.tenant_id = tenant_id;
1406            if let Some(b) = &self.backend {
1407                b.save(session);
1408            }
1409            true
1410        } else {
1411            false
1412        }
1413    }
1414
1415    /// Remove a session.
1416    pub fn revoke(&self, token: &str) -> bool {
1417        let mut sessions = self.sessions.lock().unwrap();
1418        let removed = sessions.remove(token).is_some();
1419        if removed {
1420            if let Some(b) = &self.backend {
1421                b.remove(token);
1422            }
1423        }
1424        removed
1425    }
1426}
1427
1428// ---------------------------------------------------------------------------
1429// OAuth account links — better-auth's `account` table equivalent
1430// ---------------------------------------------------------------------------
1431
/// A persisted account link. Schema-aligned with better-auth's `account`
/// table (verified against https://www.better-auth.com/docs/concepts/database
/// at the time of writing) so users migrating from better-auth see the
/// same field names + meanings:
///
/// - `provider_id` — the provider's name (`"google"`, `"github"`, plus
///   `"credential"` once email/password auth lands). Matches
///   better-auth's `providerId`.
/// - `account_id` — the PROVIDER'S ID for the user (Google `sub`,
///   GitHub numeric `id`, or for email/password the user's own id).
///   Matches better-auth's `accountId`. NOT the row PK.
/// - `id` — the row PK, generated. Lets the row be referenced
///   independently of the (provider_id, account_id) natural key.
/// - `password` — bcrypt/argon2 hash for `provider_id="credential"`
///   rows; `None` for OAuth links. Reserves the column so adding
///   email/password auth doesn't need a schema migration.
///
/// Account vs. user: a single User row can have many Account rows
/// (Google + GitHub + a password — all linked to one pylon user).
/// Provider lookup is by `(provider_id, account_id)` — NOT email — so
/// a user changing their Google address keeps the same pylon account.
///
/// `Debug` is implemented by hand: the access/refresh/id tokens and the
/// password hash are printed as `Some("[redacted]")` / `None`, so logging
/// an `Account` never leaks credentials.
#[derive(Clone, PartialEq, Eq)]
pub struct Account {
    pub id: String,
    pub user_id: String,
    /// Provider name — `"google"`, `"github"`, `"credential"`, etc.
    /// (better-auth: `providerId`)
    pub provider_id: String,
    /// Provider's id for the user — Google `sub`, GitHub numeric `id`,
    /// or for `provider_id="credential"` the user's own id. (better-auth: `accountId`)
    pub account_id: String,
    pub access_token: Option<String>,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
    /// Unix epoch seconds at which `access_token` expires. `None` for
    /// non-expiring tokens (GitHub Classic apps) or for password rows.
    pub access_token_expires_at: Option<u64>,
    /// Unix epoch seconds at which `refresh_token` expires. `None` when
    /// the provider doesn't expire refresh tokens (most don't, but
    /// Microsoft Identity Platform does after 90 days of inactivity).
    pub refresh_token_expires_at: Option<u64>,
    pub scope: Option<String>,
    /// Bcrypt/argon2 hash for email/password rows. `None` for OAuth.
    /// Always `None` today — present so adding password auth later
    /// doesn't require a schema migration.
    pub password: Option<String>,
    /// Unix epoch seconds when this account was first linked.
    pub created_at: u64,
    /// Unix epoch seconds when the token bundle was last refreshed.
    pub updated_at: u64,
}

impl std::fmt::Debug for Account {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[redacted]";
        // Keep presence information (Some/None) but never the secret itself.
        let mask = |secret: &Option<String>| secret.as_ref().map(|_| REDACTED);
        f.debug_struct("Account")
            .field("id", &self.id)
            .field("user_id", &self.user_id)
            .field("provider_id", &self.provider_id)
            .field("account_id", &self.account_id)
            .field("access_token", &mask(&self.access_token))
            .field("refresh_token", &mask(&self.refresh_token))
            .field("id_token", &mask(&self.id_token))
            .field("access_token_expires_at", &self.access_token_expires_at)
            .field("refresh_token_expires_at", &self.refresh_token_expires_at)
            .field("scope", &self.scope)
            .field("password", &mask(&self.password))
            .field("created_at", &self.created_at)
            .field("updated_at", &self.updated_at)
            .finish()
    }
}
1483
1484impl Account {
1485    /// Build a new account link from a freshly-completed OAuth handshake.
1486    /// Generates a fresh row id; the `(provider_id, account_id)` pair is
1487    /// what later lookups key on.
1488    pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
1489        let now = now_secs();
1490        Self {
1491            id: generate_token(),
1492            user_id,
1493            provider_id: info.provider.clone(),
1494            account_id: info.provider_account_id.clone(),
1495            access_token: Some(tokens.access_token.clone()),
1496            refresh_token: tokens.refresh_token.clone(),
1497            id_token: tokens.id_token.clone(),
1498            access_token_expires_at: tokens.expires_at,
1499            refresh_token_expires_at: None,
1500            scope: tokens.scope.clone(),
1501            password: None,
1502            created_at: now,
1503            updated_at: now,
1504        }
1505    }
1506
1507    /// True if `access_token_expires_at` is set and has passed.
1508    /// Non-expiring tokens (GitHub Classic) report `false` — caller
1509    /// should treat them as "valid until proven otherwise" and refresh
1510    /// on 401.
1511    pub fn access_token_expired(&self) -> bool {
1512        match self.access_token_expires_at {
1513            Some(ts) => now_secs() >= ts,
1514            None => false,
1515        }
1516    }
1517}
1518
/// Pluggable storage for account links. In-memory default ships with
/// the crate; SQLite + Postgres impls live in `pylon-runtime`.
///
/// Every method takes `&self`, so implementations supply their own
/// interior synchronization (the in-memory backend uses a `Mutex`).
pub trait AccountBackend: Send + Sync {
    /// Insert or refresh an account link. The `(provider_id, account_id)`
    /// pair is the natural key — repeated calls for the same pair
    /// update the token bundle and `updated_at` on the existing row.
    ///
    /// NOTE(review): the in-memory backend replaces the entire stored row,
    /// including `id` and `created_at`, with the passed-in value; confirm
    /// whether the persistent backends preserve those columns instead.
    fn upsert(&self, account: &Account);
    /// Find an account by provider identity. Returns `None` if the user
    /// hasn't linked this provider yet.
    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
    /// Every account linked to a user. The `/api/auth/me` endpoint uses
    /// this to render "you're connected via Google + GitHub" in the UI
    /// and to gate "unlink" affordances behind "user has another way to
    /// sign in" checks.
    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
    /// Remove a single provider link. Returns `true` if a row was removed.
    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
    /// Every account in the store. Used by `AccountStore::list_all_unfiltered`
    /// to power the Studio admin inspector. Backends that can stream
    /// (SQLite, Postgres) just `SELECT *`; the in-memory backend
    /// returns its full map.
    fn list_all(&self) -> Vec<Account>;
}
1542
1543/// In-memory account backend (default). Lost on restart — production
1544/// deployments should swap in a persistent backend so refresh tokens
1545/// survive a redeploy.
1546pub struct InMemoryAccountBackend {
1547    /// Keyed by `(provider_id, account_id)`. A separate map keyed on
1548    /// user_id would speed up `find_for_user` but at framework scale
1549    /// the linear scan of (typically ≤ 5) accounts per user is fine.
1550    accounts: Mutex<HashMap<(String, String), Account>>,
1551}
1552
1553impl InMemoryAccountBackend {
1554    pub fn new() -> Self {
1555        Self {
1556            accounts: Mutex::new(HashMap::new()),
1557        }
1558    }
1559}
1560
1561impl Default for InMemoryAccountBackend {
1562    fn default() -> Self {
1563        Self::new()
1564    }
1565}
1566
1567impl AccountBackend for InMemoryAccountBackend {
1568    fn upsert(&self, account: &Account) {
1569        let key = (account.provider_id.clone(), account.account_id.clone());
1570        self.accounts.lock().unwrap().insert(key, account.clone());
1571    }
1572    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1573        self.accounts
1574            .lock()
1575            .unwrap()
1576            .get(&(provider_id.to_string(), account_id.to_string()))
1577            .cloned()
1578    }
1579    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1580        self.accounts
1581            .lock()
1582            .unwrap()
1583            .values()
1584            .filter(|a| a.user_id == user_id)
1585            .cloned()
1586            .collect()
1587    }
1588    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1589        self.accounts
1590            .lock()
1591            .unwrap()
1592            .remove(&(provider_id.to_string(), account_id.to_string()))
1593            .is_some()
1594    }
1595    fn list_all(&self) -> Vec<Account> {
1596        self.accounts.lock().unwrap().values().cloned().collect()
1597    }
1598}
1599
1600/// Account store. Wraps an `AccountBackend` and provides the methods the
1601/// OAuth callback / API endpoints actually call.
1602pub struct AccountStore {
1603    backend: Box<dyn AccountBackend>,
1604}
1605
1606impl Default for AccountStore {
1607    fn default() -> Self {
1608        Self::new()
1609    }
1610}
1611
1612impl AccountStore {
1613    pub fn new() -> Self {
1614        Self {
1615            backend: Box::new(InMemoryAccountBackend::new()),
1616        }
1617    }
1618    pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
1619        Self { backend }
1620    }
1621    pub fn upsert(&self, account: &Account) {
1622        self.backend.upsert(account);
1623    }
1624    pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1625        self.backend.find_by_provider(provider_id, account_id)
1626    }
1627    pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1628        self.backend.find_for_user(user_id)
1629    }
1630    pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1631        self.backend.unlink(provider_id, account_id)
1632    }
1633
1634    /// Every account in the store. Powers the Studio "Auth tables"
1635    /// view; not for app use. Implemented by walking the backend's
1636    /// per-user index — doable because account counts per user are
1637    /// small (typically ≤ 5) and total account count tracks user
1638    /// count.
1639    ///
1640    /// We don't add a `list_all` method to the `AccountBackend` trait
1641    /// because the in-memory + sqlite + postgres impls would each
1642    /// need a separate implementation, and the operational use case
1643    /// (Studio inspector) is narrow enough to live behind a wrapper
1644    /// that walks the underlying store directly. For PG/SQLite that
1645    /// means a `SELECT * FROM _pylon_accounts` — which the backends
1646    /// can grow if we ever need this at scale.
1647    pub fn list_all_unfiltered(&self) -> Vec<Account> {
1648        self.backend.list_all()
1649    }
1650}
1651
1652// ---------------------------------------------------------------------------
1653// Tests
1654// ---------------------------------------------------------------------------
1655
// Unit tests for the auth primitives defined in this file. Several of the
// stores under test mutate on use (magic codes and OAuth state tokens are
// consumed on successful validation; sessions can be revoked or upgraded),
// so the order of calls inside each test is significant.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }

    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }

    #[test]
    fn auth_mode_public_allows_anonymous() {
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    // `AuthMode::from_str` returns `Option`: only "public" and "user" parse;
    // "admin" is not an accepted mode string and yields `None`.
    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        assert_eq!(AuthMode::from_str("admin"), None);
    }

    // Freshly minted session tokens are non-empty and carry the `pylon_`
    // prefix; the token round-trips through `get`.
    #[test]
    fn session_store_create_and_get() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(!session.token.is_empty());
        assert!(session.token.starts_with("pylon_"));

        let retrieved = store.get(&session.token).unwrap();
        assert_eq!(retrieved.user_id, "user-1");
    }

    // `resolve` always yields a context: a missing or unknown token falls
    // back to an unauthenticated one rather than an error.
    #[test]
    fn session_store_resolve() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));

        let anon = store.resolve(None);
        assert!(!anon.is_authenticated());

        let bad = store.resolve(Some("invalid-token"));
        assert!(!bad.is_authenticated());
    }

    #[test]
    fn session_store_revoke() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        assert!(store.revoke(&session.token));
        assert!(store.get(&session.token).is_none());
        assert!(!store.revoke(&session.token)); // already revoked
    }

    #[test]
    fn session_to_auth_context() {
        let session = Session::new("user-42".into());
        let ctx = session.to_auth_context();
        assert_eq!(ctx.user_id, Some("user-42".into()));
    }

    // -- Admin context --

    #[test]
    fn admin_context() {
        let ctx = AuthContext::admin();
        assert!(ctx.is_admin);
        assert!(ctx.is_authenticated());
    }

    #[test]
    fn anonymous_not_admin() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_admin);
    }

    #[test]
    fn authenticated_not_admin() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_admin);
    }

    // -- Magic codes --
    // Codes are 6 characters, bound to the email they were issued for, and
    // consumed by the first successful `verify`.

    #[test]
    fn magic_code_create_and_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert_eq!(code.len(), 6);
        assert!(store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_wrong_code_rejected() {
        let store = MagicCodeStore::new();
        store.create("test@example.com");
        assert!(!store.verify("test@example.com", "000000"));
    }

    #[test]
    fn magic_code_wrong_email_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(!store.verify("other@example.com", &code));
    }

    #[test]
    fn magic_code_consumed_after_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(store.verify("test@example.com", &code));
        // Second verify should fail — code consumed.
        assert!(!store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_different_emails_independent() {
        let store = MagicCodeStore::new();
        let code1 = store.create("alice@example.com");
        let code2 = store.create("bob@example.com");
        // Each email has its own code.
        assert!(store.verify("alice@example.com", &code1));
        assert!(store.verify("bob@example.com", &code2));
    }

    // -- Constant-time comparison --
    // Covers both unequal contents and unequal lengths.

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"a", b"b"));
    }

    // -- Token generation --

    #[test]
    fn generated_tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert!(t1.starts_with("pylon_"));
        assert!(t2.starts_with("pylon_"));
        // 256 bits = 64 hex chars + "pylon_" prefix (6 chars)
        assert_eq!(t1.len(), 6 + 64);
    }

    // -- OAuth registry --

    #[test]
    fn oauth_registry_empty() {
        let reg = OAuthRegistry::new();
        assert!(reg.get("google").is_none());
    }

    #[test]
    fn oauth_registry_register_and_get() {
        let mut reg = OAuthRegistry::new();
        reg.register(OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/callback".into(),
        });
        let config = reg.get("google").unwrap();
        assert_eq!(config.client_id, "test-id");
        assert!(config.auth_url().contains("accounts.google.com"));
    }

    // -- Guest auth --

    // NOTE(review): this asserts that a guest *session* resolved through
    // `SessionStore::resolve` is authenticated, whereas `guest_context`
    // below asserts that `AuthContext::guest` is NOT authenticated (so it
    // fails `AuthMode::User`). Confirm whether resolving a guest session
    // is meant to pass auth-required gates.
    #[test]
    fn guest_session() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));
        assert!(!session.token.is_empty());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert!(ctx.user_id.unwrap().starts_with("guest_"));
    }

    // `upgrade` rebinds an existing token to a new user id: the same token
    // afterwards resolves to the real user.
    #[test]
    fn upgrade_guest_to_real_user() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));

        let upgraded = store.upgrade(&session.token, "real-user-123".into());
        assert!(upgraded);

        let ctx = store.resolve(Some(&session.token));
        assert_eq!(ctx.user_id, Some("real-user-123".into()));
    }

    #[test]
    fn upgrade_invalid_token_fails() {
        let store = SessionStore::new();
        let upgraded = store.upgrade("nonexistent-token", "user".into());
        assert!(!upgraded);
    }

    #[test]
    fn guest_context() {
        let ctx = AuthContext::guest("guest_123".into());
        // Guests carry a stable id but are NOT authenticated — routes
        // guarded by AuthMode::User must reject them.
        assert!(!ctx.is_authenticated());
        assert!(ctx.is_guest);
        assert!(!ctx.is_admin);
        assert_eq!(ctx.user_id, Some("guest_123".into()));
        assert!(!AuthMode::User.check(&ctx));
        assert!(AuthMode::Public.check(&ctx));
    }

    // Known providers map to fixed token endpoints; an unknown provider
    // yields empty URLs rather than an error.
    #[test]
    fn oauth_token_urls() {
        let google = OAuthConfig {
            provider: "google".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
        let github = OAuthConfig {
            provider: "github".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(
            github.token_url(),
            "https://github.com/login/oauth/access_token"
        );
        let unknown = OAuthConfig {
            provider: "unknown".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(unknown.token_url(), "");
        assert!(unknown.auth_url().is_empty());
    }

    #[test]
    fn oauth_auth_url_github() {
        let config = OAuthConfig {
            provider: "github".into(),
            client_id: "gh-id".into(),
            client_secret: "gh-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        assert!(config.auth_url().contains("github.com"));
        assert!(config.auth_url().contains("gh-id"));
    }

    #[test]
    fn oauth_auth_url_with_state() {
        let config = OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        let url = config.auth_url_with_state("random_state_123");
        assert!(url.contains("&state=random_state_123"));
    }

    // OAuth state tokens are single-use and bound to the provider they
    // were created for; `validate` returns the stored callback URLs.
    #[test]
    fn oauth_state_store_create_and_validate() {
        let store = OAuthStateStore::new();
        let token = store.create("google", "https://app/cb", "https://app/login");
        let rec = store.validate(&token, "google").expect("valid first time");
        assert_eq!(rec.callback_url, "https://app/cb");
        assert_eq!(rec.error_callback_url, "https://app/login");
        // Second validation should fail — single-use.
        assert!(store.validate(&token, "google").is_none());
    }

    #[test]
    fn oauth_state_store_wrong_provider_rejected() {
        let store = OAuthStateStore::new();
        let token = store.create("google", "https://app/cb", "https://app/cb");
        assert!(store.validate(&token, "github").is_none());
    }

    #[test]
    fn oauth_state_store_invalid_state_rejected() {
        let store = OAuthStateStore::new();
        assert!(store.validate("nonexistent", "google").is_none());
    }

    // Redirect targets must share scheme+host+port with a trusted origin;
    // path and query are free.
    #[test]
    fn validate_trusted_redirect_basics() {
        let trusted = vec!["http://localhost:3000".to_string()];
        assert!(validate_trusted_redirect("http://localhost:3000/dashboard", &trusted).is_ok());
        assert!(validate_trusted_redirect("http://localhost:3000", &trusted).is_ok());
        assert!(validate_trusted_redirect("http://localhost:3000/x?y=1", &trusted).is_ok());

        // Wrong port → wrong origin.
        assert!(matches!(
            validate_trusted_redirect("http://localhost:4321/dashboard", &trusted),
            Err(TrustedOriginError::NotTrusted { .. })
        ));
        // Non-http scheme rejected even before trusted check (defense
        // against javascript:, file:, data:).
        assert!(matches!(
            validate_trusted_redirect("javascript:alert(1)", &trusted),
            Err(TrustedOriginError::NotHttp)
        ));
        assert!(matches!(
            validate_trusted_redirect("", &trusted),
            Err(TrustedOriginError::Empty)
        ));
    }
}