// pylon_auth/lib.rs

1pub mod cookie;
2pub mod email;
3pub mod password;
4
5pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
6
7use serde::{Deserialize, Serialize};
8
9// ---------------------------------------------------------------------------
10// Auth context — the identity available to runtime operations
11// ---------------------------------------------------------------------------
12
13/// The auth context for a request. Represents who is making the request.
14///
15/// **Do NOT derive `Deserialize` on this type.** If the server ever parses an
16/// `AuthContext` from client-supplied JSON, a client can set `is_admin=true`
17/// or add roles and bypass every policy. Identity must come from
18/// server-minted sessions (`Session::to_auth_context`) or explicit
19/// constructors, never from deserialization.
20///
21/// `Serialize` is safe because sending the resolved context BACK to the
22/// client exposes nothing the server didn't already decide.
23#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
24pub struct AuthContext {
25    /// The authenticated user ID, or None for public/anonymous access.
26    /// For guest contexts this is `Some(guest_id)` — a stable
27    /// anonymous identifier, NOT a real user.
28    pub user_id: Option<String>,
29    /// Whether this is an admin context (bypasses policies).
30    pub is_admin: bool,
31    /// True for `AuthContext::guest()` — anonymous-with-stable-id, used
32    /// for cart state and similar pre-login persistence. Routes guarded
33    /// by `AuthMode::User` reject guests; only `is_authenticated()` ==
34    /// "real signed-in user" should pass auth-required gates.
35    #[serde(default, skip_serializing_if = "is_false")]
36    pub is_guest: bool,
37    /// Roles granted to this user. Empty for anonymous.
38    pub roles: Vec<String>,
39    /// Active tenant id (for multi-tenant apps). Set when the user has
40    /// selected an organization for the current session.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub tenant_id: Option<String>,
43}
44
45fn is_false(b: &bool) -> bool {
46    !b
47}
48
49impl AuthContext {
50    /// Create an anonymous/public auth context.
51    pub fn anonymous() -> Self {
52        Self {
53            user_id: None,
54            is_admin: false,
55            is_guest: false,
56            roles: Vec::new(),
57            tenant_id: None,
58        }
59    }
60
61    /// Create an authenticated auth context.
62    pub fn authenticated(user_id: String) -> Self {
63        Self {
64            user_id: Some(user_id),
65            is_admin: false,
66            is_guest: false,
67            roles: Vec::new(),
68            tenant_id: None,
69        }
70    }
71
72    /// Create a guest auth context with a persistent anonymous ID.
73    /// Guests carry an opaque stable id (cart/session continuity) but
74    /// are NOT considered authenticated — `is_authenticated()` returns
75    /// false and `AuthMode::User` rejects them.
76    pub fn guest(guest_id: String) -> Self {
77        Self {
78            user_id: Some(guest_id),
79            is_admin: false,
80            is_guest: true,
81            roles: Vec::new(),
82            tenant_id: None,
83        }
84    }
85
86    /// Create an admin auth context that bypasses all policies.
87    pub fn admin() -> Self {
88        Self {
89            user_id: Some("__admin__".into()),
90            is_admin: true,
91            is_guest: false,
92            roles: vec!["admin".into()],
93            tenant_id: None,
94        }
95    }
96
97    /// Convenience: build a user context from a user id.
98    pub fn user(user_id: String) -> Self {
99        Self::authenticated(user_id)
100    }
101
102    /// Active tenant id (None when the user hasn't selected an org).
103    pub fn tenant_id(&self) -> Option<&str> {
104        self.tenant_id.as_deref()
105    }
106
107    /// Attach a tenant id to the context (chainable).
108    pub fn with_tenant(mut self, tenant_id: String) -> Self {
109        self.tenant_id = Some(tenant_id);
110        self
111    }
112
113    /// Check if this context represents an authenticated user.
114    /// Guests intentionally return `false` — they have a stable anonymous
115    /// id but never gain user-level access.
116    pub fn is_authenticated(&self) -> bool {
117        self.user_id.is_some() && !self.is_guest
118    }
119
120    /// Check if the user has a specific role. Admins have every role implicitly.
121    pub fn has_role(&self, role: &str) -> bool {
122        self.is_admin || self.roles.iter().any(|r| r == role)
123    }
124
125    /// Check if the user has ANY of the given roles.
126    pub fn has_any_role(&self, roles: &[&str]) -> bool {
127        self.is_admin || roles.iter().any(|r| self.has_role(r))
128    }
129
130    /// Attach roles to the context (chainable).
131    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
132        self.roles = roles;
133        self
134    }
135}

// ---------------------------------------------------------------------------
// Constant-time comparison
// ---------------------------------------------------------------------------

/// Constant-time byte comparison to prevent timing attacks.
///
/// Returns `true` iff `a` and `b` have the same length and identical bytes.
///
/// The length check leaks whether the two slices are the same length. The
/// content comparison, however, always examines every byte, no matter where
/// (or whether) a mismatch occurs. The running difference is passed through
/// [`std::hint::black_box`] on every step so the optimizer cannot rewrite
/// the fold into an early-exit comparison. Like any such barrier this is a
/// best-effort hardening, not a formal guarantee.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (x, y)| std::hint::black_box(acc | (x ^ y)));
    diff == 0
}
156
157// ---------------------------------------------------------------------------
158// Auth mode — matches the route "auth" field values
159// ---------------------------------------------------------------------------
160
161/// The auth mode declared on a route.
162#[derive(Debug, Clone, PartialEq, Eq)]
163pub enum AuthMode {
164    /// Anyone can access.
165    Public,
166    /// Only authenticated users can access.
167    User,
168}
169
170impl AuthMode {
171    /// Parse from the manifest auth string.
172    #[allow(clippy::should_implement_trait)]
173    pub fn from_str(s: &str) -> Option<Self> {
174        match s {
175            "public" => Some(AuthMode::Public),
176            "user" => Some(AuthMode::User),
177            _ => None,
178        }
179    }
180
181    /// Check if the given auth context satisfies this mode.
182    pub fn check(&self, ctx: &AuthContext) -> bool {
183        match self {
184            AuthMode::Public => true,
185            AuthMode::User => ctx.is_authenticated(),
186        }
187    }
188}
189
190// ---------------------------------------------------------------------------
191// Session — opaque token session
192// ---------------------------------------------------------------------------
193
194/// A session token and its associated user.
195#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
196pub struct Session {
197    pub token: String,
198    pub user_id: String,
199    /// Unix epoch seconds at which this session expires. 0 = never.
200    #[serde(default)]
201    pub expires_at: u64,
202    /// Optional user-agent / device tag recorded at session creation.
203    #[serde(default, skip_serializing_if = "Option::is_none")]
204    pub device: Option<String>,
205    /// Unix epoch seconds when the session was created.
206    #[serde(default)]
207    pub created_at: u64,
208    /// Active tenant id (selected organization). Set via
209    /// `/api/auth/select-org`. Flows into `AuthContext.tenant_id` which
210    /// powers row-scoped policies like `data.orgId == auth.tenantId`.
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    pub tenant_id: Option<String>,
213}
214
215impl Session {
216    /// Default session lifetime: 30 days.
217    pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
218
219    /// Create a new session with a generated token and default 30-day expiry.
220    pub fn new(user_id: String) -> Self {
221        let now = now_secs();
222        Self {
223            token: generate_token(),
224            user_id,
225            expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
226            device: None,
227            created_at: now,
228            tenant_id: None,
229        }
230    }
231
232    /// Create a session with a specific lifetime.
233    pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
234        let now = now_secs();
235        Self {
236            token: generate_token(),
237            user_id,
238            expires_at: if lifetime_secs == 0 {
239                0
240            } else {
241                now.saturating_add(lifetime_secs)
242            },
243            device: None,
244            created_at: now,
245            tenant_id: None,
246        }
247    }
248
249    /// Convert this session to an auth context, carrying the selected
250    /// tenant so row-scoped policies see `auth.tenantId`.
251    pub fn to_auth_context(&self) -> AuthContext {
252        let ctx = AuthContext::authenticated(self.user_id.clone());
253        match &self.tenant_id {
254            Some(t) => ctx.with_tenant(t.clone()),
255            None => ctx,
256        }
257    }
258
259    /// Returns true if the session has passed its expires_at time.
260    /// Boundary is inclusive (`>=`) to match the rest of the codebase
261    /// (`magic_codes.expires_at <= now`, `oauth_state.expires_at <= now`).
262    pub fn is_expired(&self) -> bool {
263        self.expires_at != 0 && now_secs() >= self.expires_at
264    }
265}

/// Current wall-clock time in whole seconds since the Unix epoch.
/// A system clock set before 1970 yields `0` instead of panicking.
fn now_secs() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
274
275// ---------------------------------------------------------------------------
276// OAuth provider config
277// ---------------------------------------------------------------------------
278
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct OAuthConfig {
281    pub provider: String,
282    pub client_id: String,
283    pub client_secret: String,
284    pub redirect_uri: String,
285}
286
287impl OAuthConfig {
288    /// Generate the authorization URL for the provider.
289    ///
290    /// Callers MUST append a `&state=<random>` parameter and validate it in the
291    /// callback to prevent CSRF attacks. See `OAuthStateStore` for a minimal
292    /// implementation.
293    pub fn auth_url(&self) -> String {
294        match self.provider.as_str() {
295            "google" => format!(
296                "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
297                self.client_id, self.redirect_uri
298            ),
299            "github" => format!(
300                "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
301                self.client_id, self.redirect_uri
302            ),
303            _ => String::new(),
304        }
305    }
306
307    /// Generate the authorization URL with a CSRF state parameter attached.
308    pub fn auth_url_with_state(&self, state: &str) -> String {
309        let base = self.auth_url();
310        if base.is_empty() {
311            return base;
312        }
313        format!("{}&state={}", base, state)
314    }
315
316    /// Generate the token exchange URL.
317    pub fn token_url(&self) -> &str {
318        match self.provider.as_str() {
319            "google" => "https://oauth2.googleapis.com/token",
320            "github" => "https://github.com/login/oauth/access_token",
321            _ => "",
322        }
323    }
324
325    /// URL for the userinfo endpoint, which returns the authenticated user's profile.
326    pub fn userinfo_url(&self) -> &str {
327        match self.provider.as_str() {
328            "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
329            "github" => "https://api.github.com/user",
330            _ => "",
331        }
332    }
333
334    /// Exchange an authorization code for the full token set
335    /// (`access_token`, optional `refresh_token`, optional `id_token`,
336    /// `expires_in`, `scope`). The longer struct is what the
337    /// account-store needs to persist; the legacy
338    /// [`OAuthConfig::exchange_code`] returns just the access token for
339    /// callers that don't care.
340    pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
341        let body = match self.provider.as_str() {
342            "google" => format!(
343                "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
344                url_encode(&self.client_id),
345                url_encode(&self.client_secret),
346                url_encode(&self.redirect_uri)
347            ),
348            "github" => format!(
349                "code={code}&client_id={}&client_secret={}&redirect_uri={}",
350                url_encode(&self.client_id),
351                url_encode(&self.client_secret),
352                url_encode(&self.redirect_uri)
353            ),
354            _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
355        };
356
357        let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
358        parse_token_response(&out)
359    }
360
361    /// Exchange an authorization code for an access token. Thin wrapper
362    /// around [`OAuthConfig::exchange_code_full`] for callers that only
363    /// need the access token (existing pre-account-store call sites).
364    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
365        let body = match self.provider.as_str() {
366            "google" => format!(
367                "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
368                url_encode(&self.client_id),
369                url_encode(&self.client_secret),
370                url_encode(&self.redirect_uri)
371            ),
372            "github" => format!(
373                "code={code}&client_id={}&client_secret={}&redirect_uri={}",
374                url_encode(&self.client_id),
375                url_encode(&self.client_secret),
376                url_encode(&self.redirect_uri)
377            ),
378            _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
379        };
380
381        let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
382        extract_access_token(&out)
383    }
384
385    /// Fetch the authenticated user's email + display name using an access token.
386    /// Returns `(email, display_name)`. Use [`OAuthConfig::fetch_userinfo_full`]
387    /// when you also need the provider-stable account ID for account
388    /// linking — the (`provider`, `provider_account_id`) pair is what
389    /// keeps a renamed-email user matched to the same row.
390    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
391        let info = self.fetch_userinfo_full(access_token)?;
392        Ok((info.email, info.name))
393    }
394
395    /// Fetch the authenticated user's full identity info — email + name +
396    /// the provider-stable account ID (Google's `sub`, GitHub's `id`).
397    /// `provider_account_id` is what the account-store keys on, NOT the
398    /// email; otherwise a user changing their Google address would orphan
399    /// their existing pylon account.
400    pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
401        let out = http_get_bearer(self.userinfo_url(), access_token)?;
402        let parsed: serde_json::Value =
403            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
404        match self.provider.as_str() {
405            "google" => {
406                let email = parsed
407                    .get("email")
408                    .and_then(|v| v.as_str())
409                    .ok_or("no email in userinfo")?
410                    .to_string();
411                let name = parsed
412                    .get("name")
413                    .and_then(|v| v.as_str())
414                    .map(String::from);
415                let provider_account_id = parsed
416                    .get("sub")
417                    .and_then(|v| v.as_str())
418                    .ok_or("no sub in userinfo")?
419                    .to_string();
420                Ok(UserInfo {
421                    provider: self.provider.clone(),
422                    provider_account_id,
423                    email,
424                    name,
425                })
426            }
427            "github" => {
428                let name = parsed
429                    .get("name")
430                    .and_then(|v| v.as_str())
431                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
432                    .map(String::from);
433                let email = parsed
434                    .get("email")
435                    .and_then(|v| v.as_str())
436                    .map(String::from);
437                // GitHub may return a null email if the user hasn't published one;
438                // in that case the caller should hit /user/emails with the same token.
439                let email = email
440                    .or_else(|| fetch_github_primary_email(access_token).ok())
441                    .ok_or("no accessible email on GitHub account")?;
442                // GitHub's `id` field is a numeric user ID — the stable
443                // account identifier even if the user renames themselves.
444                let provider_account_id = parsed
445                    .get("id")
446                    .map(|v| {
447                        v.as_i64()
448                            .map(|n| n.to_string())
449                            .or_else(|| v.as_str().map(String::from))
450                            .unwrap_or_default()
451                    })
452                    .filter(|s| !s.is_empty())
453                    .ok_or("no id in userinfo")?;
454                Ok(UserInfo {
455                    provider: self.provider.clone(),
456                    provider_account_id,
457                    email,
458                    name,
459                })
460            }
461            _ => Err(format!("unknown provider: {}", self.provider)),
462        }
463    }
464}

/// Resolved identity returned by [`OAuthConfig::fetch_userinfo_full`].
/// `provider_account_id` is the provider-stable subject id (Google `sub`,
/// GitHub numeric `id`) — what the account store keys on so a renamed
/// email doesn't orphan the pylon account.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Provider key the identity came from (e.g. `"google"`, `"github"`).
    pub provider: String,
    /// Provider-stable account id; the account-store key (never the email).
    pub provider_account_id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Display name, when the provider supplies one.
    pub name: Option<String>,
}

/// Token bundle returned by [`OAuthConfig::exchange_code_full`]. Stored
/// on the matching `Account` row so `refresh_token` is available for
/// silent re-auth and `expires_at` is checked before each provider call.
///
/// `Debug` is implemented by hand. The three token fields are credentials,
/// so `{:?}` prints `"<redacted>"` in their place (or `None` when an
/// optional token is absent) rather than leaking them into logs.
#[derive(Clone, PartialEq, Eq)]
pub struct TokenSet {
    pub access_token: String,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
    /// Unix epoch seconds at which the access token expires. `None` when
    /// the provider didn't return `expires_in` (GitHub's classic OAuth
    /// app tokens are non-expiring).
    pub expires_at: Option<u64>,
    pub scope: Option<String>,
}

impl std::fmt::Debug for TokenSet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("TokenSet")
            .field("access_token", &REDACTED)
            .field("refresh_token", &self.refresh_token.as_ref().map(|_| REDACTED))
            .field("id_token", &self.id_token.as_ref().map(|_| REDACTED))
            .field("expires_at", &self.expires_at)
            .field("scope", &self.scope)
            .finish()
    }
}
492
493fn parse_token_response(body: &str) -> Result<TokenSet, String> {
494    // Most providers return JSON; GitHub Classic apps return form-urlencoded
495    // unless you ask with Accept: application/json (which we do).
496    let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
497        // Fall back to form-urlencoded: access_token=...&scope=...&token_type=...
498        let mut map = serde_json::Map::new();
499        for pair in body.split('&') {
500            if let Some((k, v)) = pair.split_once('=') {
501                map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
502            }
503        }
504        serde_json::Value::Object(map)
505    });
506
507    let access_token = json
508        .get("access_token")
509        .and_then(|v| v.as_str())
510        .ok_or_else(|| format!("no access_token in token response: {body}"))?
511        .to_string();
512    let refresh_token = json
513        .get("refresh_token")
514        .and_then(|v| v.as_str())
515        .map(String::from);
516    let id_token = json
517        .get("id_token")
518        .and_then(|v| v.as_str())
519        .map(String::from);
520    let expires_at = json
521        .get("expires_in")
522        .and_then(|v| {
523            v.as_u64()
524                .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
525        })
526        .map(|secs| now_secs().saturating_add(secs));
527    let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
528    Ok(TokenSet {
529        access_token,
530        refresh_token,
531        id_token,
532        expires_at,
533        scope,
534    })
535}

/// Percent-encode `s` for a URL query or form body. Only the RFC 3986
/// unreserved set (`A-Z a-z 0-9 - _ . ~`) passes through unchanged. Every
/// other byte, including each byte of a multi-byte UTF-8 sequence, becomes
/// `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
549
550/// Timeout for OAuth / userinfo HTTP calls. Short enough that a hung
551/// provider doesn't block a login indefinitely; long enough to absorb
552/// typical internet latency.
553const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
554
555fn ureq_agent() -> ureq::Agent {
556    ureq::AgentBuilder::new()
557        .timeout_connect(HTTP_TIMEOUT)
558        .timeout_read(HTTP_TIMEOUT)
559        .timeout_write(HTTP_TIMEOUT)
560        .user_agent("pylon/0.1")
561        .build()
562}
563
564fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
565    let agent = ureq_agent();
566    let mut req = agent
567        .post(url)
568        .set("Content-Type", "application/x-www-form-urlencoded");
569    if accept_json {
570        req = req.set("Accept", "application/json");
571    }
572    match req.send_string(body) {
573        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
574        Err(ureq::Error::Status(code, resp)) => {
575            let body = resp.into_string().unwrap_or_default();
576            Err(format!("HTTP {code}: {body}"))
577        }
578        Err(e) => Err(format!("HTTP error: {e}")),
579    }
580}
581
582fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
583    let agent = ureq_agent();
584    match agent
585        .get(url)
586        .set("Authorization", &format!("Bearer {token}"))
587        .set("Accept", "application/json")
588        .call()
589    {
590        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
591        Err(ureq::Error::Status(code, resp)) => {
592            let body = resp.into_string().unwrap_or_default();
593            Err(format!("HTTP {code}: {body}"))
594        }
595        Err(e) => Err(format!("HTTP error: {e}")),
596    }
597}
598
599fn fetch_github_primary_email(token: &str) -> Result<String, String> {
600    let out = http_get_bearer("https://api.github.com/user/emails", token)?;
601    let emails: serde_json::Value =
602        serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
603    emails
604        .as_array()
605        .and_then(|arr| {
606            arr.iter()
607                .find(|e| {
608                    e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
609                        && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
610                })
611                .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
612        })
613        .ok_or_else(|| "no primary verified email on GitHub".into())
614}
615
616fn extract_access_token(body: &str) -> Result<String, String> {
617    if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
618        if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
619            return Ok(t.to_string());
620        }
621    }
622    // GitHub can return url-encoded: access_token=...&scope=...&token_type=bearer
623    for pair in body.split('&') {
624        if let Some(val) = pair.strip_prefix("access_token=") {
625            return Ok(val.to_string());
626        }
627    }
628    Err(format!("no access_token in token response: {body}"))
629}
630
631/// OAuth provider registry.
632pub struct OAuthRegistry {
633    providers: std::collections::HashMap<String, OAuthConfig>,
634}
635
636impl Default for OAuthRegistry {
637    fn default() -> Self {
638        Self::new()
639    }
640}
641
642impl OAuthRegistry {
643    pub fn new() -> Self {
644        Self {
645            providers: std::collections::HashMap::new(),
646        }
647    }
648
649    pub fn register(&mut self, config: OAuthConfig) {
650        self.providers.insert(config.provider.clone(), config);
651    }
652
653    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
654        self.providers.get(provider)
655    }
656
657    /// Build from environment variables.
658    /// Looks for PYLON_OAUTH_GOOGLE_CLIENT_ID, etc.
659    pub fn from_env() -> Self {
660        let mut reg = Self::new();
661
662        // Google
663        if let (Ok(id), Ok(secret)) = (
664            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
665            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
666        ) {
667            reg.register(OAuthConfig {
668                provider: "google".into(),
669                client_id: id,
670                client_secret: secret,
671                redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
672                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
673            });
674        }
675
676        // GitHub
677        if let (Ok(id), Ok(secret)) = (
678            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
679            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
680        ) {
681            reg.register(OAuthConfig {
682                provider: "github".into(),
683                client_id: id,
684                client_secret: secret,
685                redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
686                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
687            });
688        }
689
690        reg
691    }
692}

// ---------------------------------------------------------------------------
// OAuth state store — CSRF protection for OAuth flows
// ---------------------------------------------------------------------------

/// Backing store for OAuth state tokens. Default impl keeps them in memory
/// (fine for tests + dev); the runtime swaps in a SQLite-backed impl so a
/// restart in the middle of an OAuth handshake doesn't leave the user with
/// "invalid state" on the callback. Same pattern as `SessionBackend`.
pub trait OAuthStateBackend: Send + Sync {
    /// Remember `token` as a pending state for `provider`, valid until
    /// `expires_at` (Unix epoch seconds).
    fn put(&self, token: &str, provider: &str, expires_at: u64);
    /// Atomic compare-and-consume: returns the stored provider if the token
    /// exists and hasn't expired, then removes it. Returning `None` means
    /// the token never existed, has expired, or has already been used.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String>;
}

/// In-memory backend (default). Lost on restart.
///
/// Every `take` also sweeps all entries that have expired as of the caller's
/// `now_unix_secs`. State tokens from abandoned logins (the user never
/// reaches the callback) are therefore purged instead of accumulating for
/// the lifetime of the process.
pub struct InMemoryOAuthBackend {
    states: Mutex<HashMap<String, OAuthState>>,
}

impl InMemoryOAuthBackend {
    pub fn new() -> Self {
        Self {
            states: Mutex::new(HashMap::new()),
        }
    }

    /// Lock the state map, recovering the guard if the mutex is poisoned.
    /// The map holds independent plain-data entries with no cross-entry
    /// invariants, so a panic elsewhere while it was locked cannot leave it
    /// in a state later calls would misread. One panicking thread should
    /// not permanently break OAuth logins.
    fn lock_states(&self) -> std::sync::MutexGuard<'_, HashMap<String, OAuthState>> {
        self.states
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for InMemoryOAuthBackend {
    fn default() -> Self {
        Self::new()
    }
}

impl OAuthStateBackend for InMemoryOAuthBackend {
    fn put(&self, token: &str, provider: &str, expires_at: u64) {
        self.lock_states().insert(
            token.to_string(),
            OAuthState {
                provider: provider.to_string(),
                expires_at,
            },
        );
    }

    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String> {
        let mut states = self.lock_states();
        // Remove first so the token is consumed even when it turns out expired.
        let entry = states.remove(token);
        // Sweep every other expired entry while the lock is held anyway.
        states.retain(|_, state| state.expires_at > now_unix_secs);
        entry
            .filter(|state| state.expires_at > now_unix_secs)
            .map(|state| state.provider)
    }
}

/// Stores OAuth state parameters to prevent CSRF attacks on the callback.
///
/// State tokens are short-lived (10 minutes) and single-use. Backed by an
/// `OAuthStateBackend`; defaults to in-memory but the runtime persists them
/// to SQLite so they survive a restart that happens mid-OAuth-handshake.
pub struct OAuthStateStore {
    backend: Box<dyn OAuthStateBackend>,
}

/// A pending OAuth state: the provider it was minted for and its expiry
/// (Unix epoch seconds).
pub struct OAuthState {
    pub provider: String,
    pub expires_at: u64,
}
762
763impl Default for OAuthStateStore {
764    fn default() -> Self {
765        Self::new()
766    }
767}
768
769impl OAuthStateStore {
770    pub fn new() -> Self {
771        Self {
772            backend: Box::new(InMemoryOAuthBackend::new()),
773        }
774    }
775
776    pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
777        Self { backend }
778    }
779
780    /// Generate and store a new state parameter. Returns the random state string.
781    pub fn create(&self, provider: &str) -> String {
782        use std::time::{SystemTime, UNIX_EPOCH};
783        let token = generate_token();
784        let now = SystemTime::now()
785            .duration_since(UNIX_EPOCH)
786            .unwrap_or_default()
787            .as_secs();
788        self.backend.put(&token, provider, now + 600);
789        token
790    }
791
792    /// Validate and consume a state parameter. Returns true iff the state
793    /// existed, has not expired, and matches `expected_provider`. The token
794    /// is removed either way to make replay impossible.
795    pub fn validate(&self, state: &str, expected_provider: &str) -> bool {
796        use std::time::{SystemTime, UNIX_EPOCH};
797        let now = SystemTime::now()
798            .duration_since(UNIX_EPOCH)
799            .unwrap_or_default()
800            .as_secs();
801        match self.backend.take(state, now) {
802            Some(provider) => provider == expected_provider,
803            None => false,
804        }
805    }
806}
807
808// ---------------------------------------------------------------------------
809// Magic code auth — email verification codes
810// ---------------------------------------------------------------------------
811
812/// Pluggable storage for magic-code records. In-memory is the default
813/// (fine for dev); persistent backends (SQLite, Postgres) live in
814/// `pylon-runtime` so a server restart between "send code" and "verify
815/// code" doesn't invalidate the user's pending login.
816///
817/// All methods are infallible from the caller's perspective — durability
818/// is best-effort. A backend that fails to write should log; the
819/// in-memory cache remains authoritative for the current process.
820pub trait MagicCodeBackend: Send + Sync {
821    /// Replace any existing code for `email` with `code`.
822    fn put(&self, email: &str, code: &MagicCode);
823    /// Look up the current code for `email`. Returns `None` if absent.
824    fn get(&self, email: &str) -> Option<MagicCode>;
825    /// Remove the code for `email` (called on successful verify or
826    /// expiry). Idempotent — missing key is not an error.
827    fn remove(&self, email: &str);
828    /// Persist an attempts++ on the existing record without touching
829    /// other fields. Used by the verify-failed path to enforce
830    /// `MAX_ATTEMPTS` across restarts.
831    fn bump_attempts(&self, email: &str);
832    /// Load all live records on construction. Lets `MagicCodeStore::with_backend`
833    /// hydrate the in-memory cache from durable storage on startup.
834    fn load_all(&self) -> Vec<MagicCode>;
835}
836
837/// In-memory backend for magic codes. The default — also used as the
838/// authoritative cache by `MagicCodeStore`.
839pub struct InMemoryMagicCodeBackend {
840    codes: Mutex<HashMap<String, MagicCode>>,
841}
842
843impl InMemoryMagicCodeBackend {
844    pub fn new() -> Self {
845        Self {
846            codes: Mutex::new(HashMap::new()),
847        }
848    }
849}
850
851impl Default for InMemoryMagicCodeBackend {
852    fn default() -> Self {
853        Self::new()
854    }
855}
856
857impl MagicCodeBackend for InMemoryMagicCodeBackend {
858    fn put(&self, email: &str, code: &MagicCode) {
859        self.codes
860            .lock()
861            .unwrap()
862            .insert(email.to_string(), code.clone());
863    }
864    fn get(&self, email: &str) -> Option<MagicCode> {
865        self.codes.lock().unwrap().get(email).cloned()
866    }
867    fn remove(&self, email: &str) {
868        self.codes.lock().unwrap().remove(email);
869    }
870    fn bump_attempts(&self, email: &str) {
871        if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
872            c.attempts = c.attempts.saturating_add(1);
873        }
874    }
875    fn load_all(&self) -> Vec<MagicCode> {
876        self.codes.lock().unwrap().values().cloned().collect()
877    }
878}
879
880/// A magic-code store. Wraps a `MagicCodeBackend` (in-memory by default)
881/// and applies the verify/cooldown semantics. Hydrates the in-memory
882/// cache from the backend on construction so durable backends survive
883/// restart without losing in-flight codes.
884pub struct MagicCodeStore {
885    cache: Mutex<HashMap<String, MagicCode>>,
886    backend: Box<dyn MagicCodeBackend>,
887}
888
889#[derive(Debug, Clone)]
890pub struct MagicCode {
891    pub email: String,
892    pub code: String,
893    pub expires_at: u64,
894    /// Failed verify attempts against this code. Once it reaches
895    /// `MAX_ATTEMPTS` the code is invalidated.
896    pub attempts: u32,
897}
898
899/// Maximum verify attempts per code before it's burned. 5 is a common bound —
900/// lets the user fix typos without enabling realistic brute-force against a
901/// 6-digit code space.
902const MAX_ATTEMPTS: u32 = 5;
903
904/// Minimum seconds between successive `create()` calls for the same email.
905/// Throttles magic-code spam (user can't be flooded with login codes).
906const CREATE_COOLDOWN_SECS: u64 = 60;
907
908#[derive(Debug, Clone, PartialEq, Eq)]
909pub enum MagicCodeError {
910    /// There is no active code for this email, or it expired.
911    NotFound,
912    /// The code is present but `MAX_ATTEMPTS` failed verifies have occurred.
913    TooManyAttempts,
914    /// The code did not match.
915    BadCode,
916    /// The code expired since it was created.
917    Expired,
918    /// Another code was requested too recently. Wait and try again.
919    Throttled { retry_after_secs: u64 },
920}
921
922impl Default for MagicCodeStore {
923    fn default() -> Self {
924        Self::new()
925    }
926}
927
928impl MagicCodeStore {
929    pub fn new() -> Self {
930        Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
931    }
932
933    /// Build a magic-code store backed by a persistent backend. Existing
934    /// live codes are hydrated into the in-memory cache on construction
935    /// so a server restart between "send" and "verify" doesn't kill the
936    /// user's pending login.
937    pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
938        let now = now_secs();
939        let mut cache = HashMap::new();
940        for c in backend.load_all() {
941            if c.expires_at > now {
942                cache.insert(c.email.clone(), c);
943            }
944        }
945        Self {
946            cache: Mutex::new(cache),
947            backend,
948        }
949    }
950
951    /// Generate a 6-digit code for an email and return it. Subject to a
952    /// per-email cooldown — returns the error-shape via `try_create`.
953    pub fn create(&self, email: &str) -> String {
954        // Back-compat wrapper: same signature as before, but we still burn
955        // the cooldown if one is active. Use `try_create` for a Result shape.
956        self.try_create(email).unwrap_or_else(|_| String::new())
957    }
958
959    /// Create a magic code, enforcing per-email cooldown. Returns the code
960    /// or an error describing why one couldn't be issued.
961    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
962        let now = now_secs();
963
964        let mut codes = self.cache.lock().unwrap();
965
966        // Cooldown check: if a live code exists and was created less than
967        // CREATE_COOLDOWN_SECS ago, throttle. The age-of-code is
968        // `expires_at - 600 + cooldown` since expires_at is create_time + 600.
969        if let Some(existing) = codes.get(email) {
970            if existing.expires_at > now {
971                let created_at = existing.expires_at.saturating_sub(600);
972                let age = now.saturating_sub(created_at);
973                if age < CREATE_COOLDOWN_SECS {
974                    return Err(MagicCodeError::Throttled {
975                        retry_after_secs: CREATE_COOLDOWN_SECS - age,
976                    });
977                }
978            }
979        }
980
981        let code = generate_magic_code();
982        let mc = MagicCode {
983            email: email.to_string(),
984            code: code.clone(),
985            expires_at: now + 600, // 10 minutes
986            attempts: 0,
987        };
988        codes.insert(email.to_string(), mc.clone());
989        // Persist after the cache mutation lands. Backend write is
990        // best-effort — if it fails the code still works for this
991        // process; only a restart in the next 10 minutes would lose it.
992        self.backend.put(email, &mc);
993        Ok(code)
994    }
995
996    /// Verify a code for an email. Returns true if valid and not expired.
997    /// Uses constant-time comparison to prevent timing attacks.
998    /// Back-compat wrapper around [`try_verify`].
999    pub fn verify(&self, email: &str, code: &str) -> bool {
1000        matches!(self.try_verify(email, code), Ok(()))
1001    }
1002
1003    /// Verify a code. Returns a typed error so callers can surface specific
1004    /// messages. On the MAX_ATTEMPTS-th failure, the code is burned — even
1005    /// correct subsequent attempts return `TooManyAttempts`.
1006    /// Every magic code currently in the cache. Powers the Studio
1007    /// "Auth tables" view; not for app use. Includes expired codes —
1008    /// the cache only drops them on next verify attempt for that email.
1009    pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1010        self.cache
1011            .lock()
1012            .map(|m| m.values().cloned().collect())
1013            .unwrap_or_default()
1014    }
1015
1016    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1017        let now = now_secs();
1018        let mut codes = self.cache.lock().unwrap();
1019
1020        let mc = match codes.get_mut(email) {
1021            Some(m) => m,
1022            None => return Err(MagicCodeError::NotFound),
1023        };
1024
1025        if mc.attempts >= MAX_ATTEMPTS {
1026            return Err(MagicCodeError::TooManyAttempts);
1027        }
1028        if mc.expires_at <= now {
1029            codes.remove(email);
1030            self.backend.remove(email);
1031            return Err(MagicCodeError::Expired);
1032        }
1033
1034        let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1035        if !ok {
1036            mc.attempts += 1;
1037            self.backend.bump_attempts(email);
1038            // Burn the code at MAX_ATTEMPTS so retries can't hit max.
1039            if mc.attempts >= MAX_ATTEMPTS {
1040                return Err(MagicCodeError::TooManyAttempts);
1041            }
1042            return Err(MagicCodeError::BadCode);
1043        }
1044
1045        // Correct code — consume it.
1046        codes.remove(email);
1047        self.backend.remove(email);
1048        Ok(())
1049    }
1050}
1051
1052// ---------------------------------------------------------------------------
1053// Cryptographic helpers — CSPRNG-based token and code generation
1054// ---------------------------------------------------------------------------
1055
/// Lowercase hex encoding of `bytes`: two characters per byte, with no
/// prefix or separator. Writes into one pre-sized buffer instead of
/// allocating a temporary `String` per byte via `format!`.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    out
}
1059
1060/// Generate a 6-digit magic code using a CSPRNG.
1061fn generate_magic_code() -> String {
1062    use rand::Rng;
1063    let mut rng = rand::thread_rng();
1064    let code: u32 = rng.gen_range(0..1_000_000);
1065    format!("{:06}", code)
1066}
1067
1068/// Generate a session token with 256 bits of entropy from a CSPRNG.
1069fn generate_token() -> String {
1070    use rand::Rng;
1071    let mut rng = rand::thread_rng();
1072    let bytes: [u8; 32] = rng.gen();
1073    format!("pylon_{}", hex_encode(&bytes))
1074}
1075
1076// ---------------------------------------------------------------------------
1077// Session store — in-memory for dev
1078// ---------------------------------------------------------------------------
1079
1080use std::collections::HashMap;
1081use std::sync::Mutex;
1082
1083/// Pluggable storage backend for sessions. The default is in-memory; apps
1084/// deploying for real should supply a persistent backend (e.g. SQLite or
1085/// Redis) so users don't log out on server restart.
1086pub trait SessionBackend: Send + Sync {
1087    fn load_all(&self) -> Vec<Session>;
1088    fn save(&self, session: &Session);
1089    fn remove(&self, token: &str);
1090}
1091
1092/// A session store. In-memory by default; optionally backed by a
1093/// persistent [`SessionBackend`].
1094///
1095/// The in-memory map is always authoritative — reads don't touch the
1096/// backend. The backend receives every `save`/`remove`, making it a
1097/// write-through cache. On construction via [`SessionStore::with_backend`],
1098/// the store hydrates from the backend so sessions survive restart.
1099pub struct SessionStore {
1100    sessions: Mutex<HashMap<String, Session>>,
1101    backend: Option<Box<dyn SessionBackend>>,
1102}
1103
1104impl Default for SessionStore {
1105    fn default() -> Self {
1106        Self::new()
1107    }
1108}
1109
1110impl SessionStore {
1111    pub fn new() -> Self {
1112        Self {
1113            sessions: Mutex::new(HashMap::new()),
1114            backend: None,
1115        }
1116    }
1117
1118    /// Build a session store backed by a persistent store. Existing sessions
1119    /// are loaded from the backend on construction; every future mutation
1120    /// writes through.
1121    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1122        let mut map = HashMap::new();
1123        for s in backend.load_all() {
1124            if !s.is_expired() {
1125                map.insert(s.token.clone(), s);
1126            }
1127        }
1128        Self {
1129            sessions: Mutex::new(map),
1130            backend: Some(backend),
1131        }
1132    }
1133
1134    /// Create a session for a user and return it.
1135    pub fn create(&self, user_id: String) -> Session {
1136        let session = Session::new(user_id);
1137        let mut sessions = self.sessions.lock().unwrap();
1138        sessions.insert(session.token.clone(), session.clone());
1139        if let Some(b) = &self.backend {
1140            b.save(&session);
1141        }
1142        session
1143    }
1144
1145    /// Look up a session by token. Returns None if the session is expired.
1146    pub fn get(&self, token: &str) -> Option<Session> {
1147        let mut sessions = self.sessions.lock().unwrap();
1148        match sessions.get(token) {
1149            Some(s) if s.is_expired() => {
1150                sessions.remove(token);
1151                None
1152            }
1153            Some(s) => Some(s.clone()),
1154            None => None,
1155        }
1156    }
1157
1158    /// Resolve a token to an auth context.
1159    /// Returns anonymous context if the token is invalid, missing, or expired.
1160    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1161        match token {
1162            Some(t) => match self.get(t) {
1163                Some(session) => session.to_auth_context(),
1164                None => AuthContext::anonymous(),
1165            },
1166            None => AuthContext::anonymous(),
1167        }
1168    }
1169
1170    /// Refresh a session — issues a new token, copies user/device, extends expiry.
1171    /// The old token is revoked. Returns the new session or None if the old
1172    /// token is missing/expired.
1173    pub fn refresh(&self, old_token: &str) -> Option<Session> {
1174        let mut sessions = self.sessions.lock().unwrap();
1175        let old = sessions.remove(old_token)?;
1176        if let Some(b) = &self.backend {
1177            b.remove(old_token);
1178        }
1179        if old.is_expired() {
1180            return None;
1181        }
1182        let mut new = Session::new(old.user_id.clone());
1183        new.device = old.device.clone();
1184        sessions.insert(new.token.clone(), new.clone());
1185        if let Some(b) = &self.backend {
1186            b.save(&new);
1187        }
1188        Some(new)
1189    }
1190
1191    /// Every session in the store, including expired ones, with no
1192    /// filtering. Powers the Studio "Auth tables" view so operators
1193    /// can see orphaned sessions / debug stuck logins. Don't use for
1194    /// app code — `list_for_user` is the right surface there.
1195    pub fn list_all_unfiltered(&self) -> Vec<Session> {
1196        self.sessions
1197            .lock()
1198            .map(|m| m.values().cloned().collect())
1199            .unwrap_or_default()
1200    }
1201
1202    /// List all active sessions for a user.
1203    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
1204        let sessions = self.sessions.lock().unwrap();
1205        sessions
1206            .values()
1207            .filter(|s| s.user_id == user_id && !s.is_expired())
1208            .cloned()
1209            .collect()
1210    }
1211
1212    /// Revoke all sessions for a user. Returns the count removed.
1213    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
1214        let mut sessions = self.sessions.lock().unwrap();
1215        let tokens: Vec<String> = sessions
1216            .iter()
1217            .filter_map(|(t, s)| {
1218                if s.user_id == user_id {
1219                    Some(t.clone())
1220                } else {
1221                    None
1222                }
1223            })
1224            .collect();
1225        let n = tokens.len();
1226        for t in &tokens {
1227            sessions.remove(t);
1228            if let Some(b) = &self.backend {
1229                b.remove(t);
1230            }
1231        }
1232        n
1233    }
1234
1235    /// Sweep expired sessions. Returns the count removed.
1236    pub fn sweep_expired(&self) -> usize {
1237        let mut sessions = self.sessions.lock().unwrap();
1238        let expired: Vec<String> = sessions
1239            .iter()
1240            .filter_map(|(t, s)| {
1241                if s.is_expired() {
1242                    Some(t.clone())
1243                } else {
1244                    None
1245                }
1246            })
1247            .collect();
1248        let n = expired.len();
1249        for t in &expired {
1250            sessions.remove(t);
1251            if let Some(b) = &self.backend {
1252                b.remove(t);
1253            }
1254        }
1255        n
1256    }
1257
1258    /// Attach a device label to a session (typically on login from a browser).
1259    pub fn set_device(&self, token: &str, device: String) -> bool {
1260        let mut sessions = self.sessions.lock().unwrap();
1261        if let Some(s) = sessions.get_mut(token) {
1262            s.device = Some(device);
1263            if let Some(b) = &self.backend {
1264                b.save(s);
1265            }
1266            true
1267        } else {
1268            false
1269        }
1270    }
1271
1272    /// Create a guest session with a generated anonymous ID.
1273    pub fn create_guest(&self) -> Session {
1274        use rand::Rng;
1275        let mut rng = rand::thread_rng();
1276        let bytes: [u8; 16] = rng.gen();
1277        let guest_id = format!("guest_{}", hex_encode(&bytes));
1278        self.create(guest_id)
1279    }
1280
1281    /// Upgrade a guest session to a real user. Replaces the user_id.
1282    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1283        let mut sessions = self.sessions.lock().unwrap();
1284        if let Some(session) = sessions.get_mut(token) {
1285            session.user_id = real_user_id;
1286            if let Some(b) = &self.backend {
1287                b.save(session);
1288            }
1289            true
1290        } else {
1291            false
1292        }
1293    }
1294
1295    /// Switch the session's active tenant (organization). `None` clears it.
1296    /// Callers should verify the user actually has membership in the target
1297    /// tenant BEFORE invoking this — the session store takes the value on
1298    /// trust. Returns true if the session exists, false otherwise.
1299    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1300        let mut sessions = self.sessions.lock().unwrap();
1301        if let Some(session) = sessions.get_mut(token) {
1302            session.tenant_id = tenant_id;
1303            if let Some(b) = &self.backend {
1304                b.save(session);
1305            }
1306            true
1307        } else {
1308            false
1309        }
1310    }
1311
1312    /// Remove a session.
1313    pub fn revoke(&self, token: &str) -> bool {
1314        let mut sessions = self.sessions.lock().unwrap();
1315        let removed = sessions.remove(token).is_some();
1316        if removed {
1317            if let Some(b) = &self.backend {
1318                b.remove(token);
1319            }
1320        }
1321        removed
1322    }
1323}
1324
1325// ---------------------------------------------------------------------------
1326// OAuth account links — better-auth's `account` table equivalent
1327// ---------------------------------------------------------------------------
1328
/// One link between a pylon user and a sign-in method. It uses the same
/// field names and meanings as better-auth's `account` table (checked
/// against https://www.better-auth.com/docs/concepts/database when this was
/// written), so apps migrating from better-auth find familiar columns.
///
/// Key fields:
///
/// - `id` — generated row primary key, independent of the natural key.
/// - `provider_id` — which provider: `"google"`, `"github"`, or
///   `"credential"` for email/password. better-auth: `providerId`.
/// - `account_id` — the provider's identifier for the user (Google `sub`,
///   GitHub numeric `id`; for `"credential"` rows, the user's own id).
///   better-auth: `accountId`. This is NOT the row PK.
/// - `password` — password hash for `"credential"` rows, `None` for OAuth
///   links. The column exists so password auth needs no schema migration.
///
/// One user may own many accounts (e.g. Google + GitHub + a password).
/// Lookups go through `(provider_id, account_id)` rather than email, so a
/// user who changes the address on their Google account keeps the same
/// pylon account.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Account {
    /// Generated row primary key.
    pub id: String,
    /// The pylon user this link belongs to.
    pub user_id: String,
    /// Provider name — `"google"`, `"github"`, `"credential"`, etc.
    /// (better-auth: `providerId`)
    pub provider_id: String,
    /// The provider's id for the user: Google `sub`, GitHub numeric `id`,
    /// or for `provider_id="credential"` the user's own id.
    /// (better-auth: `accountId`)
    pub account_id: String,
    /// OAuth access token, if any.
    pub access_token: Option<String>,
    /// OAuth refresh token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// OIDC id token, if the provider issued one.
    pub id_token: Option<String>,
    /// Unix epoch seconds at which `access_token` expires. `None` for
    /// non-expiring tokens (GitHub Classic apps) or for password rows.
    pub access_token_expires_at: Option<u64>,
    /// Unix epoch seconds at which `refresh_token` expires. `None` when
    /// the provider doesn't expire refresh tokens. Most don't, but the
    /// Microsoft Identity Platform does after 90 days of inactivity.
    pub refresh_token_expires_at: Option<u64>,
    /// Granted OAuth scopes as reported by the provider.
    pub scope: Option<String>,
    /// Password hash (bcrypt/argon2) for `"credential"` rows. `None` for
    /// OAuth links; `Account::new` always leaves it `None`.
    pub password: Option<String>,
    /// Unix epoch seconds when this account was first linked.
    pub created_at: u64,
    /// Unix epoch seconds when the token bundle was last refreshed.
    pub updated_at: u64,
}
1380
1381impl Account {
1382    /// Build a new account link from a freshly-completed OAuth handshake.
1383    /// Generates a fresh row id; the `(provider_id, account_id)` pair is
1384    /// what later lookups key on.
1385    pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
1386        let now = now_secs();
1387        Self {
1388            id: generate_token(),
1389            user_id,
1390            provider_id: info.provider.clone(),
1391            account_id: info.provider_account_id.clone(),
1392            access_token: Some(tokens.access_token.clone()),
1393            refresh_token: tokens.refresh_token.clone(),
1394            id_token: tokens.id_token.clone(),
1395            access_token_expires_at: tokens.expires_at,
1396            refresh_token_expires_at: None,
1397            scope: tokens.scope.clone(),
1398            password: None,
1399            created_at: now,
1400            updated_at: now,
1401        }
1402    }
1403
1404    /// True if `access_token_expires_at` is set and has passed.
1405    /// Non-expiring tokens (GitHub Classic) report `false` — caller
1406    /// should treat them as "valid until proven otherwise" and refresh
1407    /// on 401.
1408    pub fn access_token_expired(&self) -> bool {
1409        match self.access_token_expires_at {
1410            Some(ts) => now_secs() >= ts,
1411            None => false,
1412        }
1413    }
1414}
1415
1416/// Pluggable storage for account links. In-memory default ships with
1417/// the crate; SQLite + Postgres impls live in `pylon-runtime`.
1418pub trait AccountBackend: Send + Sync {
1419    /// Insert or refresh an account link. The `(provider_id, account_id)`
1420    /// pair is the natural key — repeated calls for the same pair
1421    /// update the token bundle and `updated_at` on the existing row.
1422    fn upsert(&self, account: &Account);
1423    /// Find an account by provider identity. Returns `None` if the user
1424    /// hasn't linked this provider yet.
1425    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
1426    /// Every account linked to a user. The `/api/auth/me` endpoint uses
1427    /// this to render "you're connected via Google + GitHub" in the UI
1428    /// and to gate "unlink" affordances behind "user has another way to
1429    /// sign in" checks.
1430    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
1431    /// Remove a single provider link. Returns `true` if a row was removed.
1432    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
1433    /// Every account in the store. Used by `AccountStore::list_all_unfiltered`
1434    /// to power the Studio admin inspector. Backends that can stream
1435    /// (SQLite, Postgres) just `SELECT *`; the in-memory backend
1436    /// returns its full map.
1437    fn list_all(&self) -> Vec<Account>;
1438}
1439
1440/// In-memory account backend (default). Lost on restart — production
1441/// deployments should swap in a persistent backend so refresh tokens
1442/// survive a redeploy.
1443pub struct InMemoryAccountBackend {
1444    /// Keyed by `(provider_id, account_id)`. A separate map keyed on
1445    /// user_id would speed up `find_for_user` but at framework scale
1446    /// the linear scan of (typically ≤ 5) accounts per user is fine.
1447    accounts: Mutex<HashMap<(String, String), Account>>,
1448}
1449
1450impl InMemoryAccountBackend {
1451    pub fn new() -> Self {
1452        Self {
1453            accounts: Mutex::new(HashMap::new()),
1454        }
1455    }
1456}
1457
1458impl Default for InMemoryAccountBackend {
1459    fn default() -> Self {
1460        Self::new()
1461    }
1462}
1463
1464impl AccountBackend for InMemoryAccountBackend {
1465    fn upsert(&self, account: &Account) {
1466        let key = (account.provider_id.clone(), account.account_id.clone());
1467        self.accounts.lock().unwrap().insert(key, account.clone());
1468    }
1469    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1470        self.accounts
1471            .lock()
1472            .unwrap()
1473            .get(&(provider_id.to_string(), account_id.to_string()))
1474            .cloned()
1475    }
1476    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1477        self.accounts
1478            .lock()
1479            .unwrap()
1480            .values()
1481            .filter(|a| a.user_id == user_id)
1482            .cloned()
1483            .collect()
1484    }
1485    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1486        self.accounts
1487            .lock()
1488            .unwrap()
1489            .remove(&(provider_id.to_string(), account_id.to_string()))
1490            .is_some()
1491    }
1492    fn list_all(&self) -> Vec<Account> {
1493        self.accounts.lock().unwrap().values().cloned().collect()
1494    }
1495}
1496
1497/// Account store. Wraps an `AccountBackend` and provides the methods the
1498/// OAuth callback / API endpoints actually call.
1499pub struct AccountStore {
1500    backend: Box<dyn AccountBackend>,
1501}
1502
1503impl Default for AccountStore {
1504    fn default() -> Self {
1505        Self::new()
1506    }
1507}
1508
1509impl AccountStore {
1510    pub fn new() -> Self {
1511        Self {
1512            backend: Box::new(InMemoryAccountBackend::new()),
1513        }
1514    }
1515    pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
1516        Self { backend }
1517    }
1518    pub fn upsert(&self, account: &Account) {
1519        self.backend.upsert(account);
1520    }
1521    pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1522        self.backend.find_by_provider(provider_id, account_id)
1523    }
1524    pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1525        self.backend.find_for_user(user_id)
1526    }
1527    pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1528        self.backend.unlink(provider_id, account_id)
1529    }
1530
1531    /// Every account in the store. Powers the Studio "Auth tables"
1532    /// view; not for app use. Implemented by walking the backend's
1533    /// per-user index — doable because account counts per user are
1534    /// small (typically ≤ 5) and total account count tracks user
1535    /// count.
1536    ///
1537    /// We don't add a `list_all` method to the `AccountBackend` trait
1538    /// because the in-memory + sqlite + postgres impls would each
1539    /// need a separate implementation, and the operational use case
1540    /// (Studio inspector) is narrow enough to live behind a wrapper
1541    /// that walks the underlying store directly. For PG/SQLite that
1542    /// means a `SELECT * FROM _pylon_accounts` — which the backends
1543    /// can grow if we ever need this at scale.
1544    pub fn list_all_unfiltered(&self) -> Vec<Account> {
1545        self.backend.list_all()
1546    }
1547}
1548
1549// ---------------------------------------------------------------------------
1550// Tests
1551// ---------------------------------------------------------------------------
1552
1553#[cfg(test)]
1554mod tests {
1555    use super::*;
1556
1557    #[test]
1558    fn anonymous_context() {
1559        let ctx = AuthContext::anonymous();
1560        assert!(!ctx.is_authenticated());
1561        assert!(ctx.user_id.is_none());
1562    }
1563
1564    #[test]
1565    fn authenticated_context() {
1566        let ctx = AuthContext::authenticated("user-1".into());
1567        assert!(ctx.is_authenticated());
1568        assert_eq!(ctx.user_id, Some("user-1".into()));
1569    }
1570
1571    #[test]
1572    fn auth_mode_public_allows_anonymous() {
1573        let mode = AuthMode::Public;
1574        assert!(mode.check(&AuthContext::anonymous()));
1575        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
1576    }
1577
1578    #[test]
1579    fn auth_mode_user_requires_authenticated() {
1580        let mode = AuthMode::User;
1581        assert!(!mode.check(&AuthContext::anonymous()));
1582        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
1583    }
1584
1585    #[test]
1586    fn auth_mode_from_str() {
1587        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
1588        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
1589        assert_eq!(AuthMode::from_str("admin"), None);
1590    }
1591
1592    #[test]
1593    fn session_store_create_and_get() {
1594        let store = SessionStore::new();
1595        let session = store.create("user-1".into());
1596        assert!(!session.token.is_empty());
1597        assert!(session.token.starts_with("pylon_"));
1598
1599        let retrieved = store.get(&session.token).unwrap();
1600        assert_eq!(retrieved.user_id, "user-1");
1601    }
1602
1603    #[test]
1604    fn session_store_resolve() {
1605        let store = SessionStore::new();
1606        let session = store.create("user-1".into());
1607
1608        let ctx = store.resolve(Some(&session.token));
1609        assert!(ctx.is_authenticated());
1610        assert_eq!(ctx.user_id, Some("user-1".into()));
1611
1612        let anon = store.resolve(None);
1613        assert!(!anon.is_authenticated());
1614
1615        let bad = store.resolve(Some("invalid-token"));
1616        assert!(!bad.is_authenticated());
1617    }
1618
1619    #[test]
1620    fn session_store_revoke() {
1621        let store = SessionStore::new();
1622        let session = store.create("user-1".into());
1623
1624        assert!(store.revoke(&session.token));
1625        assert!(store.get(&session.token).is_none());
1626        assert!(!store.revoke(&session.token)); // already revoked
1627    }
1628
1629    #[test]
1630    fn session_to_auth_context() {
1631        let session = Session::new("user-42".into());
1632        let ctx = session.to_auth_context();
1633        assert_eq!(ctx.user_id, Some("user-42".into()));
1634    }
1635
1636    // -- Admin context --
1637
1638    #[test]
1639    fn admin_context() {
1640        let ctx = AuthContext::admin();
1641        assert!(ctx.is_admin);
1642        assert!(ctx.is_authenticated());
1643    }
1644
1645    #[test]
1646    fn anonymous_not_admin() {
1647        let ctx = AuthContext::anonymous();
1648        assert!(!ctx.is_admin);
1649    }
1650
1651    #[test]
1652    fn authenticated_not_admin() {
1653        let ctx = AuthContext::authenticated("user-1".into());
1654        assert!(!ctx.is_admin);
1655    }
1656
1657    // -- Magic codes --
1658
1659    #[test]
1660    fn magic_code_create_and_verify() {
1661        let store = MagicCodeStore::new();
1662        let code = store.create("test@example.com");
1663        assert_eq!(code.len(), 6);
1664        assert!(store.verify("test@example.com", &code));
1665    }
1666
1667    #[test]
1668    fn magic_code_wrong_code_rejected() {
1669        let store = MagicCodeStore::new();
1670        store.create("test@example.com");
1671        assert!(!store.verify("test@example.com", "000000"));
1672    }
1673
1674    #[test]
1675    fn magic_code_wrong_email_rejected() {
1676        let store = MagicCodeStore::new();
1677        let code = store.create("test@example.com");
1678        assert!(!store.verify("other@example.com", &code));
1679    }
1680
1681    #[test]
1682    fn magic_code_consumed_after_verify() {
1683        let store = MagicCodeStore::new();
1684        let code = store.create("test@example.com");
1685        assert!(store.verify("test@example.com", &code));
1686        // Second verify should fail — code consumed.
1687        assert!(!store.verify("test@example.com", &code));
1688    }
1689
1690    #[test]
1691    fn magic_code_different_emails_independent() {
1692        let store = MagicCodeStore::new();
1693        let code1 = store.create("alice@example.com");
1694        let code2 = store.create("bob@example.com");
1695        // Each email has its own code.
1696        assert!(store.verify("alice@example.com", &code1));
1697        assert!(store.verify("bob@example.com", &code2));
1698    }
1699
1700    // -- Constant-time comparison --
1701
1702    #[test]
1703    fn constant_time_eq_equal() {
1704        assert!(constant_time_eq(b"hello", b"hello"));
1705        assert!(constant_time_eq(b"", b""));
1706    }
1707
1708    #[test]
1709    fn constant_time_eq_not_equal() {
1710        assert!(!constant_time_eq(b"hello", b"world"));
1711        assert!(!constant_time_eq(b"hello", b"hell"));
1712        assert!(!constant_time_eq(b"a", b"b"));
1713    }
1714
1715    // -- Token generation --
1716
1717    #[test]
1718    fn generated_tokens_are_unique() {
1719        let t1 = generate_token();
1720        let t2 = generate_token();
1721        assert_ne!(t1, t2);
1722        assert!(t1.starts_with("pylon_"));
1723        assert!(t2.starts_with("pylon_"));
1724        // 256 bits = 64 hex chars + "pylon_" prefix (6 chars)
1725        assert_eq!(t1.len(), 6 + 64);
1726    }
1727
1728    // -- OAuth registry --
1729
1730    #[test]
1731    fn oauth_registry_empty() {
1732        let reg = OAuthRegistry::new();
1733        assert!(reg.get("google").is_none());
1734    }
1735
1736    #[test]
1737    fn oauth_registry_register_and_get() {
1738        let mut reg = OAuthRegistry::new();
1739        reg.register(OAuthConfig {
1740            provider: "google".into(),
1741            client_id: "test-id".into(),
1742            client_secret: "test-secret".into(),
1743            redirect_uri: "http://localhost/callback".into(),
1744        });
1745        let config = reg.get("google").unwrap();
1746        assert_eq!(config.client_id, "test-id");
1747        assert!(config.auth_url().contains("accounts.google.com"));
1748    }
1749
1750    // -- Guest auth --
1751
1752    #[test]
1753    fn guest_session() {
1754        let store = SessionStore::new();
1755        let session = store.create_guest();
1756        assert!(session.user_id.starts_with("guest_"));
1757        assert!(!session.token.is_empty());
1758
1759        let ctx = store.resolve(Some(&session.token));
1760        assert!(ctx.is_authenticated());
1761        assert!(ctx.user_id.unwrap().starts_with("guest_"));
1762    }
1763
1764    #[test]
1765    fn upgrade_guest_to_real_user() {
1766        let store = SessionStore::new();
1767        let session = store.create_guest();
1768        assert!(session.user_id.starts_with("guest_"));
1769
1770        let upgraded = store.upgrade(&session.token, "real-user-123".into());
1771        assert!(upgraded);
1772
1773        let ctx = store.resolve(Some(&session.token));
1774        assert_eq!(ctx.user_id, Some("real-user-123".into()));
1775    }
1776
1777    #[test]
1778    fn upgrade_invalid_token_fails() {
1779        let store = SessionStore::new();
1780        let upgraded = store.upgrade("nonexistent-token", "user".into());
1781        assert!(!upgraded);
1782    }
1783
1784    #[test]
1785    fn guest_context() {
1786        let ctx = AuthContext::guest("guest_123".into());
1787        // Guests carry a stable id but are NOT authenticated — routes
1788        // guarded by AuthMode::User must reject them.
1789        assert!(!ctx.is_authenticated());
1790        assert!(ctx.is_guest);
1791        assert!(!ctx.is_admin);
1792        assert_eq!(ctx.user_id, Some("guest_123".into()));
1793        assert!(!AuthMode::User.check(&ctx));
1794        assert!(AuthMode::Public.check(&ctx));
1795    }
1796
1797    #[test]
1798    fn oauth_token_urls() {
1799        let google = OAuthConfig {
1800            provider: "google".into(),
1801            client_id: "x".into(),
1802            client_secret: "x".into(),
1803            redirect_uri: "x".into(),
1804        };
1805        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
1806        let github = OAuthConfig {
1807            provider: "github".into(),
1808            client_id: "x".into(),
1809            client_secret: "x".into(),
1810            redirect_uri: "x".into(),
1811        };
1812        assert_eq!(
1813            github.token_url(),
1814            "https://github.com/login/oauth/access_token"
1815        );
1816        let unknown = OAuthConfig {
1817            provider: "unknown".into(),
1818            client_id: "x".into(),
1819            client_secret: "x".into(),
1820            redirect_uri: "x".into(),
1821        };
1822        assert_eq!(unknown.token_url(), "");
1823        assert!(unknown.auth_url().is_empty());
1824    }
1825
1826    #[test]
1827    fn oauth_auth_url_github() {
1828        let config = OAuthConfig {
1829            provider: "github".into(),
1830            client_id: "gh-id".into(),
1831            client_secret: "gh-secret".into(),
1832            redirect_uri: "http://localhost/cb".into(),
1833        };
1834        assert!(config.auth_url().contains("github.com"));
1835        assert!(config.auth_url().contains("gh-id"));
1836    }
1837
1838    #[test]
1839    fn oauth_auth_url_with_state() {
1840        let config = OAuthConfig {
1841            provider: "google".into(),
1842            client_id: "test-id".into(),
1843            client_secret: "test-secret".into(),
1844            redirect_uri: "http://localhost/cb".into(),
1845        };
1846        let url = config.auth_url_with_state("random_state_123");
1847        assert!(url.contains("&state=random_state_123"));
1848    }
1849
1850    #[test]
1851    fn oauth_state_store_create_and_validate() {
1852        let store = OAuthStateStore::new();
1853        let state = store.create("google");
1854        assert!(store.validate(&state, "google"));
1855        // Second validation should fail — consumed.
1856        assert!(!store.validate(&state, "google"));
1857    }
1858
1859    #[test]
1860    fn oauth_state_store_wrong_provider_rejected() {
1861        let store = OAuthStateStore::new();
1862        let state = store.create("google");
1863        assert!(!store.validate(&state, "github"));
1864    }
1865
1866    #[test]
1867    fn oauth_state_store_invalid_state_rejected() {
1868        let store = OAuthStateStore::new();
1869        assert!(!store.validate("nonexistent", "google"));
1870    }
1871}