Skip to main content

pylon_auth/
lib.rs

1pub mod cookie;
2pub mod email;
3pub mod password;
4
5pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
6
7use serde::{Deserialize, Serialize};
8
9// ---------------------------------------------------------------------------
10// Auth context — the identity available to runtime operations
11// ---------------------------------------------------------------------------
12
13/// The auth context for a request. Represents who is making the request.
14///
15/// **Do NOT derive `Deserialize` on this type.** If the server ever parses an
16/// `AuthContext` from client-supplied JSON, a client can set `is_admin=true`
17/// or add roles and bypass every policy. Identity must come from
18/// server-minted sessions (`Session::to_auth_context`) or explicit
19/// constructors, never from deserialization.
20///
21/// `Serialize` is safe because sending the resolved context BACK to the
22/// client exposes nothing the server didn't already decide.
23#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
24pub struct AuthContext {
25    /// The authenticated user ID, or None for public/anonymous access.
26    /// For guest contexts this is `Some(guest_id)` — a stable
27    /// anonymous identifier, NOT a real user.
28    pub user_id: Option<String>,
29    /// Whether this is an admin context (bypasses policies).
30    pub is_admin: bool,
31    /// True for `AuthContext::guest()` — anonymous-with-stable-id, used
32    /// for cart state and similar pre-login persistence. Routes guarded
33    /// by `AuthMode::User` reject guests; only `is_authenticated()` ==
34    /// "real signed-in user" should pass auth-required gates.
35    #[serde(default, skip_serializing_if = "is_false")]
36    pub is_guest: bool,
37    /// Roles granted to this user. Empty for anonymous.
38    pub roles: Vec<String>,
39    /// Active tenant id (for multi-tenant apps). Set when the user has
40    /// selected an organization for the current session.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub tenant_id: Option<String>,
43}
44
/// Serde `skip_serializing_if` predicate: true when the flag is unset, so
/// `false` booleans are left out of the serialized form.
fn is_false(flag: &bool) -> bool {
    !*flag
}
48
49impl AuthContext {
50    /// Create an anonymous/public auth context.
51    pub fn anonymous() -> Self {
52        Self {
53            user_id: None,
54            is_admin: false,
55            is_guest: false,
56            roles: Vec::new(),
57            tenant_id: None,
58        }
59    }
60
61    /// Create an authenticated auth context.
62    pub fn authenticated(user_id: String) -> Self {
63        Self {
64            user_id: Some(user_id),
65            is_admin: false,
66            is_guest: false,
67            roles: Vec::new(),
68            tenant_id: None,
69        }
70    }
71
72    /// Create a guest auth context with a persistent anonymous ID.
73    /// Guests carry an opaque stable id (cart/session continuity) but
74    /// are NOT considered authenticated — `is_authenticated()` returns
75    /// false and `AuthMode::User` rejects them.
76    pub fn guest(guest_id: String) -> Self {
77        Self {
78            user_id: Some(guest_id),
79            is_admin: false,
80            is_guest: true,
81            roles: Vec::new(),
82            tenant_id: None,
83        }
84    }
85
86    /// Create an admin auth context that bypasses all policies.
87    pub fn admin() -> Self {
88        Self {
89            user_id: Some("__admin__".into()),
90            is_admin: true,
91            is_guest: false,
92            roles: vec!["admin".into()],
93            tenant_id: None,
94        }
95    }
96
97    /// Convenience: build a user context from a user id.
98    pub fn user(user_id: String) -> Self {
99        Self::authenticated(user_id)
100    }
101
102    /// Active tenant id (None when the user hasn't selected an org).
103    pub fn tenant_id(&self) -> Option<&str> {
104        self.tenant_id.as_deref()
105    }
106
107    /// Attach a tenant id to the context (chainable).
108    pub fn with_tenant(mut self, tenant_id: String) -> Self {
109        self.tenant_id = Some(tenant_id);
110        self
111    }
112
113    /// Check if this context represents an authenticated user.
114    /// Guests intentionally return `false` — they have a stable anonymous
115    /// id but never gain user-level access.
116    pub fn is_authenticated(&self) -> bool {
117        self.user_id.is_some() && !self.is_guest
118    }
119
120    /// Check if the user has a specific role. Admins have every role implicitly.
121    pub fn has_role(&self, role: &str) -> bool {
122        self.is_admin || self.roles.iter().any(|r| r == role)
123    }
124
125    /// Check if the user has ANY of the given roles.
126    pub fn has_any_role(&self, roles: &[&str]) -> bool {
127        self.is_admin || roles.iter().any(|r| self.has_role(r))
128    }
129
130    /// Attach roles to the context (chainable).
131    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
132        self.roles = roles;
133        self
134    }
135}
136
137// ---------------------------------------------------------------------------
138// Constant-time comparison
139// ---------------------------------------------------------------------------
140
/// Compare two byte slices without short-circuiting on the first mismatch,
/// to avoid leaking match position through timing.
///
/// Inputs of different lengths are rejected immediately, so length itself is
/// observable. For equal-length inputs, every byte pair is XOR-ed into a
/// single accumulator and the verdict is taken only after the full pass.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
156
157// ---------------------------------------------------------------------------
158// Auth mode — matches the route "auth" field values
159// ---------------------------------------------------------------------------
160
161/// The auth mode declared on a route.
162#[derive(Debug, Clone, PartialEq, Eq)]
163pub enum AuthMode {
164    /// Anyone can access.
165    Public,
166    /// Only authenticated users can access.
167    User,
168}
169
170impl AuthMode {
171    /// Parse from the manifest auth string.
172    #[allow(clippy::should_implement_trait)]
173    pub fn from_str(s: &str) -> Option<Self> {
174        match s {
175            "public" => Some(AuthMode::Public),
176            "user" => Some(AuthMode::User),
177            _ => None,
178        }
179    }
180
181    /// Check if the given auth context satisfies this mode.
182    pub fn check(&self, ctx: &AuthContext) -> bool {
183        match self {
184            AuthMode::Public => true,
185            AuthMode::User => ctx.is_authenticated(),
186        }
187    }
188}
189
190// ---------------------------------------------------------------------------
191// Session — opaque token session
192// ---------------------------------------------------------------------------
193
194/// A session token and its associated user.
195#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
196pub struct Session {
197    pub token: String,
198    pub user_id: String,
199    /// Unix epoch seconds at which this session expires. 0 = never.
200    #[serde(default)]
201    pub expires_at: u64,
202    /// Optional user-agent / device tag recorded at session creation.
203    #[serde(default, skip_serializing_if = "Option::is_none")]
204    pub device: Option<String>,
205    /// Unix epoch seconds when the session was created.
206    #[serde(default)]
207    pub created_at: u64,
208    /// Active tenant id (selected organization). Set via
209    /// `/api/auth/select-org`. Flows into `AuthContext.tenant_id` which
210    /// powers row-scoped policies like `data.orgId == auth.tenantId`.
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    pub tenant_id: Option<String>,
213}
214
215impl Session {
216    /// Default session lifetime: 30 days.
217    pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
218
219    /// Create a new session with a generated token and default 30-day expiry.
220    pub fn new(user_id: String) -> Self {
221        let now = now_secs();
222        Self {
223            token: generate_token(),
224            user_id,
225            expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
226            device: None,
227            created_at: now,
228            tenant_id: None,
229        }
230    }
231
232    /// Create a session with a specific lifetime.
233    pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
234        let now = now_secs();
235        Self {
236            token: generate_token(),
237            user_id,
238            expires_at: if lifetime_secs == 0 {
239                0
240            } else {
241                now.saturating_add(lifetime_secs)
242            },
243            device: None,
244            created_at: now,
245            tenant_id: None,
246        }
247    }
248
249    /// Convert this session to an auth context, carrying the selected
250    /// tenant so row-scoped policies see `auth.tenantId`.
251    pub fn to_auth_context(&self) -> AuthContext {
252        let ctx = AuthContext::authenticated(self.user_id.clone());
253        match &self.tenant_id {
254            Some(t) => ctx.with_tenant(t.clone()),
255            None => ctx,
256        }
257    }
258
259    /// Returns true if the session has passed its expires_at time.
260    /// Boundary is inclusive (`>=`) to match the rest of the codebase
261    /// (`magic_codes.expires_at <= now`, `oauth_state.expires_at <= now`).
262    pub fn is_expired(&self) -> bool {
263        self.expires_at != 0 && now_secs() >= self.expires_at
264    }
265}
266
/// Current wall-clock time in whole seconds since the Unix epoch.
/// A system clock set before 1970 yields 0 instead of panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
274
275// ---------------------------------------------------------------------------
276// OAuth provider config
277// ---------------------------------------------------------------------------
278
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct OAuthConfig {
281    pub provider: String,
282    pub client_id: String,
283    pub client_secret: String,
284    pub redirect_uri: String,
285}
286
287impl OAuthConfig {
288    /// Generate the authorization URL for the provider.
289    ///
290    /// Callers MUST append a `&state=<random>` parameter and validate it in the
291    /// callback to prevent CSRF attacks. See `OAuthStateStore` for a minimal
292    /// implementation.
293    pub fn auth_url(&self) -> String {
294        match self.provider.as_str() {
295            "google" => format!(
296                "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
297                self.client_id, self.redirect_uri
298            ),
299            "github" => format!(
300                "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
301                self.client_id, self.redirect_uri
302            ),
303            _ => String::new(),
304        }
305    }
306
307    /// Generate the authorization URL with a CSRF state parameter attached.
308    pub fn auth_url_with_state(&self, state: &str) -> String {
309        let base = self.auth_url();
310        if base.is_empty() {
311            return base;
312        }
313        format!("{}&state={}", base, state)
314    }
315
316    /// Generate the token exchange URL.
317    pub fn token_url(&self) -> &str {
318        match self.provider.as_str() {
319            "google" => "https://oauth2.googleapis.com/token",
320            "github" => "https://github.com/login/oauth/access_token",
321            _ => "",
322        }
323    }
324
325    /// URL for the userinfo endpoint, which returns the authenticated user's profile.
326    pub fn userinfo_url(&self) -> &str {
327        match self.provider.as_str() {
328            "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
329            "github" => "https://api.github.com/user",
330            _ => "",
331        }
332    }
333
334    /// Exchange an authorization code for an access token.
335    ///
336    /// Uses the system `curl` binary so the auth crate stays free of HTTP
337    /// client dependencies. Returns the provider-specific access token string
338    /// (extracted from the JSON response).
339    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
340        let body = match self.provider.as_str() {
341            "google" => format!(
342                "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
343                url_encode(&self.client_id),
344                url_encode(&self.client_secret),
345                url_encode(&self.redirect_uri)
346            ),
347            "github" => format!(
348                "code={code}&client_id={}&client_secret={}&redirect_uri={}",
349                url_encode(&self.client_id),
350                url_encode(&self.client_secret),
351                url_encode(&self.redirect_uri)
352            ),
353            _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
354        };
355
356        let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
357        extract_access_token(&out)
358    }
359
360    /// Fetch the authenticated user's email + display name using an access token.
361    /// Returns `(email, display_name)`.
362    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
363        let out = http_get_bearer(self.userinfo_url(), access_token)?;
364        let parsed: serde_json::Value =
365            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
366        match self.provider.as_str() {
367            "google" => {
368                let email = parsed
369                    .get("email")
370                    .and_then(|v| v.as_str())
371                    .ok_or("no email in userinfo")?
372                    .to_string();
373                let name = parsed
374                    .get("name")
375                    .and_then(|v| v.as_str())
376                    .map(String::from);
377                Ok((email, name))
378            }
379            "github" => {
380                let name = parsed
381                    .get("name")
382                    .and_then(|v| v.as_str())
383                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
384                    .map(String::from);
385                let email = parsed
386                    .get("email")
387                    .and_then(|v| v.as_str())
388                    .map(String::from);
389                // GitHub may return a null email if the user hasn't published one;
390                // in that case the caller should hit /user/emails with the same token.
391                let email = email
392                    .or_else(|| fetch_github_primary_email(access_token).ok())
393                    .ok_or("no accessible email on GitHub account")?;
394                Ok((email, name))
395            }
396            _ => Err(format!("unknown provider: {}", self.provider)),
397        }
398    }
399}
400
/// Percent-encode `s` for use in a URL query or form body.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other byte, including each byte of a multi-byte UTF-8
/// character, becomes `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
413
414/// Timeout for OAuth / userinfo HTTP calls. Short enough that a hung
415/// provider doesn't block a login indefinitely; long enough to absorb
416/// typical internet latency.
417const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
418
419fn ureq_agent() -> ureq::Agent {
420    ureq::AgentBuilder::new()
421        .timeout_connect(HTTP_TIMEOUT)
422        .timeout_read(HTTP_TIMEOUT)
423        .timeout_write(HTTP_TIMEOUT)
424        .user_agent("pylon/0.1")
425        .build()
426}
427
428fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
429    let agent = ureq_agent();
430    let mut req = agent
431        .post(url)
432        .set("Content-Type", "application/x-www-form-urlencoded");
433    if accept_json {
434        req = req.set("Accept", "application/json");
435    }
436    match req.send_string(body) {
437        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
438        Err(ureq::Error::Status(code, resp)) => {
439            let body = resp.into_string().unwrap_or_default();
440            Err(format!("HTTP {code}: {body}"))
441        }
442        Err(e) => Err(format!("HTTP error: {e}")),
443    }
444}
445
446fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
447    let agent = ureq_agent();
448    match agent
449        .get(url)
450        .set("Authorization", &format!("Bearer {token}"))
451        .set("Accept", "application/json")
452        .call()
453    {
454        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
455        Err(ureq::Error::Status(code, resp)) => {
456            let body = resp.into_string().unwrap_or_default();
457            Err(format!("HTTP {code}: {body}"))
458        }
459        Err(e) => Err(format!("HTTP error: {e}")),
460    }
461}
462
463fn fetch_github_primary_email(token: &str) -> Result<String, String> {
464    let out = http_get_bearer("https://api.github.com/user/emails", token)?;
465    let emails: serde_json::Value =
466        serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
467    emails
468        .as_array()
469        .and_then(|arr| {
470            arr.iter()
471                .find(|e| {
472                    e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
473                        && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
474                })
475                .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
476        })
477        .ok_or_else(|| "no primary verified email on GitHub".into())
478}
479
480fn extract_access_token(body: &str) -> Result<String, String> {
481    if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
482        if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
483            return Ok(t.to_string());
484        }
485    }
486    // GitHub can return url-encoded: access_token=...&scope=...&token_type=bearer
487    for pair in body.split('&') {
488        if let Some(val) = pair.strip_prefix("access_token=") {
489            return Ok(val.to_string());
490        }
491    }
492    Err(format!("no access_token in token response: {body}"))
493}
494
495/// OAuth provider registry.
496pub struct OAuthRegistry {
497    providers: std::collections::HashMap<String, OAuthConfig>,
498}
499
500impl Default for OAuthRegistry {
501    fn default() -> Self {
502        Self::new()
503    }
504}
505
506impl OAuthRegistry {
507    pub fn new() -> Self {
508        Self {
509            providers: std::collections::HashMap::new(),
510        }
511    }
512
513    pub fn register(&mut self, config: OAuthConfig) {
514        self.providers.insert(config.provider.clone(), config);
515    }
516
517    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
518        self.providers.get(provider)
519    }
520
521    /// Build from environment variables.
522    /// Looks for PYLON_OAUTH_GOOGLE_CLIENT_ID, etc.
523    pub fn from_env() -> Self {
524        let mut reg = Self::new();
525
526        // Google
527        if let (Ok(id), Ok(secret)) = (
528            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
529            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
530        ) {
531            reg.register(OAuthConfig {
532                provider: "google".into(),
533                client_id: id,
534                client_secret: secret,
535                redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
536                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
537            });
538        }
539
540        // GitHub
541        if let (Ok(id), Ok(secret)) = (
542            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
543            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
544        ) {
545            reg.register(OAuthConfig {
546                provider: "github".into(),
547                client_id: id,
548                client_secret: secret,
549                redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
550                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
551            });
552        }
553
554        reg
555    }
556}
557
558// ---------------------------------------------------------------------------
559// OAuth state store — CSRF protection for OAuth flows
560// ---------------------------------------------------------------------------
561
/// Persistence for pending OAuth `state` tokens.
///
/// The default `InMemoryOAuthBackend` is fine for tests and dev. The runtime
/// swaps in a SQLite-backed implementation so a restart in the middle of an
/// OAuth handshake doesn't leave the user with "invalid state" on the
/// callback. Same pattern as `SessionBackend`.
pub trait OAuthStateBackend: Send + Sync {
    /// Record `token` as a pending state for `provider`, accepted until
    /// `expires_at` (Unix epoch seconds).
    fn put(&self, token: &str, provider: &str, expires_at: u64);
    /// Atomic compare-and-consume: if `token` exists and hasn't expired as of
    /// `now_unix_secs`, remove it and return its provider. `None` means the
    /// token never existed, already expired, or was already used.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String>;
}
573
574/// In-memory backend (default). Lost on restart.
575pub struct InMemoryOAuthBackend {
576    states: Mutex<HashMap<String, OAuthState>>,
577}
578
579impl InMemoryOAuthBackend {
580    pub fn new() -> Self {
581        Self {
582            states: Mutex::new(HashMap::new()),
583        }
584    }
585}
586
587impl Default for InMemoryOAuthBackend {
588    fn default() -> Self {
589        Self::new()
590    }
591}
592
593impl OAuthStateBackend for InMemoryOAuthBackend {
594    fn put(&self, token: &str, provider: &str, expires_at: u64) {
595        self.states.lock().unwrap().insert(
596            token.to_string(),
597            OAuthState {
598                provider: provider.to_string(),
599                expires_at,
600            },
601        );
602    }
603    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String> {
604        let mut s = self.states.lock().unwrap();
605        let entry = s.remove(token)?;
606        if entry.expires_at <= now_unix_secs {
607            return None;
608        }
609        Some(entry.provider)
610    }
611}
612
613/// Stores OAuth state parameters to prevent CSRF attacks on the callback.
614///
615/// State tokens are short-lived (10 minutes) and single-use. Backed by an
616/// `OAuthStateBackend`; defaults to in-memory but the runtime persists them
617/// to SQLite so they survive a restart that happens mid-OAuth-handshake.
618pub struct OAuthStateStore {
619    backend: Box<dyn OAuthStateBackend>,
620}
621
622pub struct OAuthState {
623    pub provider: String,
624    pub expires_at: u64,
625}
626
627impl Default for OAuthStateStore {
628    fn default() -> Self {
629        Self::new()
630    }
631}
632
633impl OAuthStateStore {
634    pub fn new() -> Self {
635        Self {
636            backend: Box::new(InMemoryOAuthBackend::new()),
637        }
638    }
639
640    pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
641        Self { backend }
642    }
643
644    /// Generate and store a new state parameter. Returns the random state string.
645    pub fn create(&self, provider: &str) -> String {
646        use std::time::{SystemTime, UNIX_EPOCH};
647        let token = generate_token();
648        let now = SystemTime::now()
649            .duration_since(UNIX_EPOCH)
650            .unwrap_or_default()
651            .as_secs();
652        self.backend.put(&token, provider, now + 600);
653        token
654    }
655
656    /// Validate and consume a state parameter. Returns true iff the state
657    /// existed, has not expired, and matches `expected_provider`. The token
658    /// is removed either way to make replay impossible.
659    pub fn validate(&self, state: &str, expected_provider: &str) -> bool {
660        use std::time::{SystemTime, UNIX_EPOCH};
661        let now = SystemTime::now()
662            .duration_since(UNIX_EPOCH)
663            .unwrap_or_default()
664            .as_secs();
665        match self.backend.take(state, now) {
666            Some(provider) => provider == expected_provider,
667            None => false,
668        }
669    }
670}
671
672// ---------------------------------------------------------------------------
673// Magic code auth — email verification codes
674// ---------------------------------------------------------------------------
675
676/// An in-memory magic code store for development.
677pub struct MagicCodeStore {
678    codes: Mutex<HashMap<String, MagicCode>>,
679}
680
681#[derive(Debug, Clone)]
682pub struct MagicCode {
683    pub email: String,
684    pub code: String,
685    pub expires_at: u64,
686    /// Failed verify attempts against this code. Once it reaches
687    /// `MAX_ATTEMPTS` the code is invalidated.
688    pub attempts: u32,
689}
690
691/// Maximum verify attempts per code before it's burned. 5 is a common bound —
692/// lets the user fix typos without enabling realistic brute-force against a
693/// 6-digit code space.
694const MAX_ATTEMPTS: u32 = 5;
695
696/// Minimum seconds between successive `create()` calls for the same email.
697/// Throttles magic-code spam (user can't be flooded with login codes).
698const CREATE_COOLDOWN_SECS: u64 = 60;
699
700#[derive(Debug, Clone, PartialEq, Eq)]
701pub enum MagicCodeError {
702    /// There is no active code for this email, or it expired.
703    NotFound,
704    /// The code is present but `MAX_ATTEMPTS` failed verifies have occurred.
705    TooManyAttempts,
706    /// The code did not match.
707    BadCode,
708    /// The code expired since it was created.
709    Expired,
710    /// Another code was requested too recently. Wait and try again.
711    Throttled { retry_after_secs: u64 },
712}
713
714impl Default for MagicCodeStore {
715    fn default() -> Self {
716        Self::new()
717    }
718}
719
720impl MagicCodeStore {
721    pub fn new() -> Self {
722        Self {
723            codes: Mutex::new(HashMap::new()),
724        }
725    }
726
727    /// Generate a 6-digit code for an email and return it. Subject to a
728    /// per-email cooldown — returns the error-shape via `try_create`.
729    pub fn create(&self, email: &str) -> String {
730        // Back-compat wrapper: same signature as before, but we still burn
731        // the cooldown if one is active. Use `try_create` for a Result shape.
732        self.try_create(email).unwrap_or_else(|_| String::new())
733    }
734
735    /// Create a magic code, enforcing per-email cooldown. Returns the code
736    /// or an error describing why one couldn't be issued.
737    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
738        let now = now_secs();
739
740        let mut codes = self.codes.lock().unwrap();
741
742        // Cooldown check: if a live code exists and was created less than
743        // CREATE_COOLDOWN_SECS ago, throttle. The age-of-code is
744        // `expires_at - 600 + cooldown` since expires_at is create_time + 600.
745        if let Some(existing) = codes.get(email) {
746            if existing.expires_at > now {
747                let created_at = existing.expires_at.saturating_sub(600);
748                let age = now.saturating_sub(created_at);
749                if age < CREATE_COOLDOWN_SECS {
750                    return Err(MagicCodeError::Throttled {
751                        retry_after_secs: CREATE_COOLDOWN_SECS - age,
752                    });
753                }
754            }
755        }
756
757        let code = generate_magic_code();
758        let mc = MagicCode {
759            email: email.to_string(),
760            code: code.clone(),
761            expires_at: now + 600, // 10 minutes
762            attempts: 0,
763        };
764        codes.insert(email.to_string(), mc);
765        Ok(code)
766    }
767
768    /// Verify a code for an email. Returns true if valid and not expired.
769    /// Uses constant-time comparison to prevent timing attacks.
770    /// Back-compat wrapper around [`try_verify`].
771    pub fn verify(&self, email: &str, code: &str) -> bool {
772        matches!(self.try_verify(email, code), Ok(()))
773    }
774
775    /// Verify a code. Returns a typed error so callers can surface specific
776    /// messages. On the MAX_ATTEMPTS-th failure, the code is burned — even
777    /// correct subsequent attempts return `TooManyAttempts`.
778    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
779        let now = now_secs();
780        let mut codes = self.codes.lock().unwrap();
781
782        let mc = match codes.get_mut(email) {
783            Some(m) => m,
784            None => return Err(MagicCodeError::NotFound),
785        };
786
787        if mc.attempts >= MAX_ATTEMPTS {
788            return Err(MagicCodeError::TooManyAttempts);
789        }
790        if mc.expires_at <= now {
791            codes.remove(email);
792            return Err(MagicCodeError::Expired);
793        }
794
795        let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
796        if !ok {
797            mc.attempts += 1;
798            // Burn the code at MAX_ATTEMPTS so retries can't hit max.
799            if mc.attempts >= MAX_ATTEMPTS {
800                return Err(MagicCodeError::TooManyAttempts);
801            }
802            return Err(MagicCodeError::BadCode);
803        }
804
805        // Correct code — consume it.
806        codes.remove(email);
807        Ok(())
808    }
809}
810
811// ---------------------------------------------------------------------------
812// Cryptographic helpers — CSPRNG-based token and code generation
813// ---------------------------------------------------------------------------
814
/// Lowercase hexadecimal encoding: two characters per input byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
818
819/// Generate a 6-digit magic code using a CSPRNG.
820fn generate_magic_code() -> String {
821    use rand::Rng;
822    let mut rng = rand::thread_rng();
823    let code: u32 = rng.gen_range(0..1_000_000);
824    format!("{:06}", code)
825}
826
827/// Generate a session token with 256 bits of entropy from a CSPRNG.
828fn generate_token() -> String {
829    use rand::Rng;
830    let mut rng = rand::thread_rng();
831    let bytes: [u8; 32] = rng.gen();
832    format!("pylon_{}", hex_encode(&bytes))
833}
834
835// ---------------------------------------------------------------------------
836// Session store — in-memory for dev
837// ---------------------------------------------------------------------------
838
839use std::collections::HashMap;
840use std::sync::Mutex;
841
/// Pluggable persistent storage for sessions. Without one, [`SessionStore`]
/// keeps sessions only in memory, so apps deploying for real should supply a
/// persistent backend (e.g. SQLite or Redis) so users don't log out on
/// server restart.
///
/// [`SessionStore`] keeps its own in-memory map as the source of truth for
/// reads and mirrors its mutations into the backend. The store invokes
/// `save`/`remove` while holding its internal session mutex, so
/// implementations should be quick and must not call back into the same
/// `SessionStore` (the std `Mutex` is not re-entrant).
///
/// None of the methods can report failure; implementations are responsible
/// for handling (e.g. logging) their own I/O errors.
pub trait SessionBackend: Send + Sync {
    /// Return every persisted session. Used by [`SessionStore::with_backend`]
    /// to hydrate the store; expired sessions in the result are not served.
    fn load_all(&self) -> Vec<Session>;
    /// Persist `session`, replacing any earlier copy with the same token.
    /// The store calls this both for brand-new sessions and after mutating an
    /// existing one (device label, user id, tenant), so it must behave as an
    /// upsert keyed by `session.token`.
    fn save(&self, session: &Session);
    /// Delete the persisted session identified by `token` (e.g. on revoke,
    /// refresh, or sweep of expired sessions).
    fn remove(&self, token: &str);
}
850
/// A session store. In-memory by default; optionally backed by a
/// persistent [`SessionBackend`].
///
/// The in-memory map is always authoritative — reads don't touch the
/// backend. Mutations are also written through to the backend, which acts
/// as the durable copy. On construction via [`SessionStore::with_backend`],
/// the store hydrates from the backend so sessions survive restart.
pub struct SessionStore {
    /// Live sessions keyed by their token.
    sessions: Mutex<HashMap<String, Session>>,
    /// Optional persistent store; `None` for a purely in-memory store.
    backend: Option<Box<dyn SessionBackend>>,
}
862
863impl Default for SessionStore {
864    fn default() -> Self {
865        Self::new()
866    }
867}
868
869impl SessionStore {
870    pub fn new() -> Self {
871        Self {
872            sessions: Mutex::new(HashMap::new()),
873            backend: None,
874        }
875    }
876
877    /// Build a session store backed by a persistent store. Existing sessions
878    /// are loaded from the backend on construction; every future mutation
879    /// writes through.
880    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
881        let mut map = HashMap::new();
882        for s in backend.load_all() {
883            if !s.is_expired() {
884                map.insert(s.token.clone(), s);
885            }
886        }
887        Self {
888            sessions: Mutex::new(map),
889            backend: Some(backend),
890        }
891    }
892
893    /// Create a session for a user and return it.
894    pub fn create(&self, user_id: String) -> Session {
895        let session = Session::new(user_id);
896        let mut sessions = self.sessions.lock().unwrap();
897        sessions.insert(session.token.clone(), session.clone());
898        if let Some(b) = &self.backend {
899            b.save(&session);
900        }
901        session
902    }
903
904    /// Look up a session by token. Returns None if the session is expired.
905    pub fn get(&self, token: &str) -> Option<Session> {
906        let mut sessions = self.sessions.lock().unwrap();
907        match sessions.get(token) {
908            Some(s) if s.is_expired() => {
909                sessions.remove(token);
910                None
911            }
912            Some(s) => Some(s.clone()),
913            None => None,
914        }
915    }
916
917    /// Resolve a token to an auth context.
918    /// Returns anonymous context if the token is invalid, missing, or expired.
919    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
920        match token {
921            Some(t) => match self.get(t) {
922                Some(session) => session.to_auth_context(),
923                None => AuthContext::anonymous(),
924            },
925            None => AuthContext::anonymous(),
926        }
927    }
928
929    /// Refresh a session — issues a new token, copies user/device, extends expiry.
930    /// The old token is revoked. Returns the new session or None if the old
931    /// token is missing/expired.
932    pub fn refresh(&self, old_token: &str) -> Option<Session> {
933        let mut sessions = self.sessions.lock().unwrap();
934        let old = sessions.remove(old_token)?;
935        if let Some(b) = &self.backend {
936            b.remove(old_token);
937        }
938        if old.is_expired() {
939            return None;
940        }
941        let mut new = Session::new(old.user_id.clone());
942        new.device = old.device.clone();
943        sessions.insert(new.token.clone(), new.clone());
944        if let Some(b) = &self.backend {
945            b.save(&new);
946        }
947        Some(new)
948    }
949
950    /// List all active sessions for a user.
951    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
952        let sessions = self.sessions.lock().unwrap();
953        sessions
954            .values()
955            .filter(|s| s.user_id == user_id && !s.is_expired())
956            .cloned()
957            .collect()
958    }
959
960    /// Revoke all sessions for a user. Returns the count removed.
961    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
962        let mut sessions = self.sessions.lock().unwrap();
963        let tokens: Vec<String> = sessions
964            .iter()
965            .filter_map(|(t, s)| {
966                if s.user_id == user_id {
967                    Some(t.clone())
968                } else {
969                    None
970                }
971            })
972            .collect();
973        let n = tokens.len();
974        for t in &tokens {
975            sessions.remove(t);
976            if let Some(b) = &self.backend {
977                b.remove(t);
978            }
979        }
980        n
981    }
982
983    /// Sweep expired sessions. Returns the count removed.
984    pub fn sweep_expired(&self) -> usize {
985        let mut sessions = self.sessions.lock().unwrap();
986        let expired: Vec<String> = sessions
987            .iter()
988            .filter_map(|(t, s)| {
989                if s.is_expired() {
990                    Some(t.clone())
991                } else {
992                    None
993                }
994            })
995            .collect();
996        let n = expired.len();
997        for t in &expired {
998            sessions.remove(t);
999            if let Some(b) = &self.backend {
1000                b.remove(t);
1001            }
1002        }
1003        n
1004    }
1005
1006    /// Attach a device label to a session (typically on login from a browser).
1007    pub fn set_device(&self, token: &str, device: String) -> bool {
1008        let mut sessions = self.sessions.lock().unwrap();
1009        if let Some(s) = sessions.get_mut(token) {
1010            s.device = Some(device);
1011            if let Some(b) = &self.backend {
1012                b.save(s);
1013            }
1014            true
1015        } else {
1016            false
1017        }
1018    }
1019
1020    /// Create a guest session with a generated anonymous ID.
1021    pub fn create_guest(&self) -> Session {
1022        use rand::Rng;
1023        let mut rng = rand::thread_rng();
1024        let bytes: [u8; 16] = rng.gen();
1025        let guest_id = format!("guest_{}", hex_encode(&bytes));
1026        self.create(guest_id)
1027    }
1028
1029    /// Upgrade a guest session to a real user. Replaces the user_id.
1030    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1031        let mut sessions = self.sessions.lock().unwrap();
1032        if let Some(session) = sessions.get_mut(token) {
1033            session.user_id = real_user_id;
1034            if let Some(b) = &self.backend {
1035                b.save(session);
1036            }
1037            true
1038        } else {
1039            false
1040        }
1041    }
1042
1043    /// Switch the session's active tenant (organization). `None` clears it.
1044    /// Callers should verify the user actually has membership in the target
1045    /// tenant BEFORE invoking this — the session store takes the value on
1046    /// trust. Returns true if the session exists, false otherwise.
1047    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1048        let mut sessions = self.sessions.lock().unwrap();
1049        if let Some(session) = sessions.get_mut(token) {
1050            session.tenant_id = tenant_id;
1051            if let Some(b) = &self.backend {
1052                b.save(session);
1053            }
1054            true
1055        } else {
1056            false
1057        }
1058    }
1059
1060    /// Remove a session.
1061    pub fn revoke(&self, token: &str) -> bool {
1062        let mut sessions = self.sessions.lock().unwrap();
1063        let removed = sessions.remove(token).is_some();
1064        if removed {
1065            if let Some(b) = &self.backend {
1066                b.remove(token);
1067            }
1068        }
1069        removed
1070    }
1071}
1072
1073// ---------------------------------------------------------------------------
1074// Tests
1075// ---------------------------------------------------------------------------
1076
// Unit tests for this module. Many of them exercise items defined earlier in
// the file (AuthContext, AuthMode, MagicCodeStore, the OAuth types), so they
// also pin those public APIs.
#[cfg(test)]
mod tests {
    use super::*;

    // -- Auth context & auth modes --

    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }

    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }

    #[test]
    fn auth_mode_public_allows_anonymous() {
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        // "admin" is not an accepted auth-mode string.
        assert_eq!(AuthMode::from_str("admin"), None);
    }

    // -- Session store --

    #[test]
    fn session_store_create_and_get() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(!session.token.is_empty());
        assert!(session.token.starts_with("pylon_"));

        let retrieved = store.get(&session.token).unwrap();
        assert_eq!(retrieved.user_id, "user-1");
    }

    #[test]
    fn session_store_resolve() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));

        // Missing and unknown tokens both fall back to an anonymous context.
        let anon = store.resolve(None);
        assert!(!anon.is_authenticated());

        let bad = store.resolve(Some("invalid-token"));
        assert!(!bad.is_authenticated());
    }

    #[test]
    fn session_store_revoke() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        assert!(store.revoke(&session.token));
        assert!(store.get(&session.token).is_none());
        assert!(!store.revoke(&session.token)); // already revoked
    }

    #[test]
    fn session_to_auth_context() {
        let session = Session::new("user-42".into());
        let ctx = session.to_auth_context();
        assert_eq!(ctx.user_id, Some("user-42".into()));
    }

    // -- Admin context --

    #[test]
    fn admin_context() {
        let ctx = AuthContext::admin();
        assert!(ctx.is_admin);
        assert!(ctx.is_authenticated());
    }

    #[test]
    fn anonymous_not_admin() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_admin);
    }

    #[test]
    fn authenticated_not_admin() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_admin);
    }

    // -- Magic codes --

    #[test]
    fn magic_code_create_and_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert_eq!(code.len(), 6);
        assert!(store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_wrong_code_rejected() {
        let store = MagicCodeStore::new();
        store.create("test@example.com");
        assert!(!store.verify("test@example.com", "000000"));
    }

    #[test]
    fn magic_code_wrong_email_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(!store.verify("other@example.com", &code));
    }

    #[test]
    fn magic_code_consumed_after_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(store.verify("test@example.com", &code));
        // Second verify should fail — code consumed.
        assert!(!store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_different_emails_independent() {
        let store = MagicCodeStore::new();
        let code1 = store.create("alice@example.com");
        let code2 = store.create("bob@example.com");
        // Each email has its own code.
        assert!(store.verify("alice@example.com", &code1));
        assert!(store.verify("bob@example.com", &code2));
    }

    // -- Constant-time comparison --

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"));
        // Differing lengths must compare unequal, too.
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"a", b"b"));
    }

    // -- Token generation --

    #[test]
    fn generated_tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert!(t1.starts_with("pylon_"));
        assert!(t2.starts_with("pylon_"));
        // 256 bits = 64 hex chars + "pylon_" prefix (6 chars)
        assert_eq!(t1.len(), 6 + 64);
    }

    // -- OAuth registry --

    #[test]
    fn oauth_registry_empty() {
        let reg = OAuthRegistry::new();
        assert!(reg.get("google").is_none());
    }

    #[test]
    fn oauth_registry_register_and_get() {
        let mut reg = OAuthRegistry::new();
        reg.register(OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/callback".into(),
        });
        let config = reg.get("google").unwrap();
        assert_eq!(config.client_id, "test-id");
        assert!(config.auth_url().contains("accounts.google.com"));
    }

    // -- Guest auth --

    #[test]
    fn guest_session() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));
        assert!(!session.token.is_empty());

        // NOTE(review): this expects a `create_guest` session to resolve as
        // authenticated, while `guest_context` below expects
        // `AuthContext::guest` NOT to be authenticated (and to be rejected by
        // `AuthMode::User`). Confirm whether `Session::to_auth_context` should
        // map guest sessions to `AuthContext::guest`; if so, this assertion
        // needs to flip.
        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert!(ctx.user_id.unwrap().starts_with("guest_"));
    }

    #[test]
    fn upgrade_guest_to_real_user() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));

        let upgraded = store.upgrade(&session.token, "real-user-123".into());
        assert!(upgraded);

        // The same token now resolves to the upgraded user id.
        let ctx = store.resolve(Some(&session.token));
        assert_eq!(ctx.user_id, Some("real-user-123".into()));
    }

    #[test]
    fn upgrade_invalid_token_fails() {
        let store = SessionStore::new();
        let upgraded = store.upgrade("nonexistent-token", "user".into());
        assert!(!upgraded);
    }

    #[test]
    fn guest_context() {
        let ctx = AuthContext::guest("guest_123".into());
        // Guests carry a stable id but are NOT authenticated — routes
        // guarded by AuthMode::User must reject them.
        assert!(!ctx.is_authenticated());
        assert!(ctx.is_guest);
        assert!(!ctx.is_admin);
        assert_eq!(ctx.user_id, Some("guest_123".into()));
        assert!(!AuthMode::User.check(&ctx));
        assert!(AuthMode::Public.check(&ctx));
    }

    // -- OAuth URLs and state --

    #[test]
    fn oauth_token_urls() {
        let google = OAuthConfig {
            provider: "google".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
        let github = OAuthConfig {
            provider: "github".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(
            github.token_url(),
            "https://github.com/login/oauth/access_token"
        );
        // Unknown providers yield empty URLs rather than an error.
        let unknown = OAuthConfig {
            provider: "unknown".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(unknown.token_url(), "");
        assert!(unknown.auth_url().is_empty());
    }

    #[test]
    fn oauth_auth_url_github() {
        let config = OAuthConfig {
            provider: "github".into(),
            client_id: "gh-id".into(),
            client_secret: "gh-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        assert!(config.auth_url().contains("github.com"));
        assert!(config.auth_url().contains("gh-id"));
    }

    #[test]
    fn oauth_auth_url_with_state() {
        let config = OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        let url = config.auth_url_with_state("random_state_123");
        assert!(url.contains("&state=random_state_123"));
    }

    #[test]
    fn oauth_state_store_create_and_validate() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(store.validate(&state, "google"));
        // Second validation should fail — consumed.
        assert!(!store.validate(&state, "google"));
    }

    #[test]
    fn oauth_state_store_wrong_provider_rejected() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(!store.validate(&state, "github"));
    }

    #[test]
    fn oauth_state_store_invalid_state_rejected() {
        let store = OAuthStateStore::new();
        assert!(!store.validate("nonexistent", "google"));
    }
}