Skip to main content

pylon_auth/
lib.rs

1pub mod email;
2pub mod password;
3
4use serde::{Deserialize, Serialize};
5
6// ---------------------------------------------------------------------------
7// Auth context — the identity available to runtime operations
8// ---------------------------------------------------------------------------
9
/// The auth context for a request. Represents who is making the request.
///
/// **Do NOT derive `Deserialize` on this type.** If the server ever parses an
/// `AuthContext` from client-supplied JSON, a client can set `is_admin=true`
/// or add roles and bypass every policy. Identity must come from
/// server-minted sessions (`Session::to_auth_context`) or explicit
/// constructors, never from deserialization.
///
/// `Serialize` is safe because sending the resolved context BACK to the
/// client exposes nothing the server didn't already decide.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
    /// The authenticated user ID, or None for public/anonymous access.
    /// Guest contexts (`AuthContext::guest`) also store their persistent
    /// anonymous id here, so they count as authenticated.
    pub user_id: Option<String>,
    /// Whether this is an admin context (bypasses policies). When set,
    /// `has_role` / `has_any_role` return true for every role.
    pub is_admin: bool,
    /// Roles granted to this user. Empty for anonymous.
    pub roles: Vec<String>,
    /// Active tenant id (for multi-tenant apps). Set when the user has
    /// selected an organization for the current session. Omitted from the
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
33
34impl AuthContext {
35    /// Create an anonymous/public auth context.
36    pub fn anonymous() -> Self {
37        Self {
38            user_id: None,
39            is_admin: false,
40            roles: Vec::new(),
41            tenant_id: None,
42        }
43    }
44
45    /// Create an authenticated auth context.
46    pub fn authenticated(user_id: String) -> Self {
47        Self {
48            user_id: Some(user_id),
49            is_admin: false,
50            roles: Vec::new(),
51            tenant_id: None,
52        }
53    }
54
55    /// Create a guest auth context with a persistent anonymous ID.
56    pub fn guest(guest_id: String) -> Self {
57        Self {
58            user_id: Some(guest_id),
59            is_admin: false,
60            roles: Vec::new(),
61            tenant_id: None,
62        }
63    }
64
65    /// Create an admin auth context that bypasses all policies.
66    pub fn admin() -> Self {
67        Self {
68            user_id: Some("__admin__".into()),
69            is_admin: true,
70            roles: vec!["admin".into()],
71            tenant_id: None,
72        }
73    }
74
75    /// Convenience: build a user context from a user id.
76    pub fn user(user_id: String) -> Self {
77        Self::authenticated(user_id)
78    }
79
80    /// Active tenant id (None when the user hasn't selected an org).
81    pub fn tenant_id(&self) -> Option<&str> {
82        self.tenant_id.as_deref()
83    }
84
85    /// Attach a tenant id to the context (chainable).
86    pub fn with_tenant(mut self, tenant_id: String) -> Self {
87        self.tenant_id = Some(tenant_id);
88        self
89    }
90
91    /// Check if this context represents an authenticated user.
92    pub fn is_authenticated(&self) -> bool {
93        self.user_id.is_some()
94    }
95
96    /// Check if the user has a specific role. Admins have every role implicitly.
97    pub fn has_role(&self, role: &str) -> bool {
98        self.is_admin || self.roles.iter().any(|r| r == role)
99    }
100
101    /// Check if the user has ANY of the given roles.
102    pub fn has_any_role(&self, roles: &[&str]) -> bool {
103        self.is_admin || roles.iter().any(|r| self.has_role(r))
104    }
105
106    /// Attach roles to the context (chainable).
107    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
108        self.roles = roles;
109        self
110    }
111}
112
113// ---------------------------------------------------------------------------
114// Constant-time comparison
115// ---------------------------------------------------------------------------
116
/// Compares two byte slices without exiting early on the first mismatch,
/// to resist timing attacks.
///
/// Slices of different lengths are rejected immediately, so the length is not
/// hidden. For equal-length inputs every byte pair is XOR-ed and the
/// differences are OR-ed into one accumulator. The scan therefore always runs
/// to the end, no matter where (or whether) the inputs differ.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let same_length = a.len() == b.len();
    same_length && a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
132
133// ---------------------------------------------------------------------------
134// Auth mode — matches the route "auth" field values
135// ---------------------------------------------------------------------------
136
137/// The auth mode declared on a route.
138#[derive(Debug, Clone, PartialEq, Eq)]
139pub enum AuthMode {
140    /// Anyone can access.
141    Public,
142    /// Only authenticated users can access.
143    User,
144}
145
146impl AuthMode {
147    /// Parse from the manifest auth string.
148    #[allow(clippy::should_implement_trait)]
149    pub fn from_str(s: &str) -> Option<Self> {
150        match s {
151            "public" => Some(AuthMode::Public),
152            "user" => Some(AuthMode::User),
153            _ => None,
154        }
155    }
156
157    /// Check if the given auth context satisfies this mode.
158    pub fn check(&self, ctx: &AuthContext) -> bool {
159        match self {
160            AuthMode::Public => true,
161            AuthMode::User => ctx.is_authenticated(),
162        }
163    }
164}
165
166// ---------------------------------------------------------------------------
167// Session — opaque token session
168// ---------------------------------------------------------------------------
169
/// A session token and its associated user.
///
/// NOTE(review): the derived `Debug` prints `token`, which `SessionStore`
/// resolves straight to an `AuthContext`; treat it as a credential and avoid
/// `{:?}`-logging whole sessions.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Opaque session token (`generate_token` format: `pylon_` + 64 hex chars).
    pub token: String,
    /// Id of the user this session authenticates.
    pub user_id: String,
    /// Unix epoch seconds at which this session expires. 0 = never.
    /// A stored record that lacks this field deserializes to 0 (never expires).
    #[serde(default)]
    pub expires_at: u64,
    /// Optional user-agent / device tag recorded at session creation.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Unix epoch seconds when the session was created (0 if absent in a
    /// deserialized record).
    #[serde(default)]
    pub created_at: u64,
    /// Active tenant id (selected organization). Set via
    /// `/api/auth/select-org`. Flows into `AuthContext.tenant_id` which
    /// powers row-scoped policies like `data.orgId == auth.tenantId`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
190
191impl Session {
192    /// Default session lifetime: 30 days.
193    pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
194
195    /// Create a new session with a generated token and default 30-day expiry.
196    pub fn new(user_id: String) -> Self {
197        let now = now_secs();
198        Self {
199            token: generate_token(),
200            user_id,
201            expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
202            device: None,
203            created_at: now,
204            tenant_id: None,
205        }
206    }
207
208    /// Create a session with a specific lifetime.
209    pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
210        let now = now_secs();
211        Self {
212            token: generate_token(),
213            user_id,
214            expires_at: if lifetime_secs == 0 {
215                0
216            } else {
217                now.saturating_add(lifetime_secs)
218            },
219            device: None,
220            created_at: now,
221            tenant_id: None,
222        }
223    }
224
225    /// Convert this session to an auth context, carrying the selected
226    /// tenant so row-scoped policies see `auth.tenantId`.
227    pub fn to_auth_context(&self) -> AuthContext {
228        let ctx = AuthContext::authenticated(self.user_id.clone());
229        match &self.tenant_id {
230            Some(t) => ctx.with_tenant(t.clone()),
231            None => ctx,
232        }
233    }
234
235    /// Returns true if the session has passed its expires_at time.
236    pub fn is_expired(&self) -> bool {
237        self.expires_at != 0 && now_secs() > self.expires_at
238    }
239}
240
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than 1970, this yields 0 instead of
/// panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.map(|elapsed| elapsed.as_secs()).unwrap_or(0)
}
248
249// ---------------------------------------------------------------------------
250// OAuth provider config
251// ---------------------------------------------------------------------------
252
253#[derive(Debug, Clone, Serialize, Deserialize)]
254pub struct OAuthConfig {
255    pub provider: String,
256    pub client_id: String,
257    pub client_secret: String,
258    pub redirect_uri: String,
259}
260
261impl OAuthConfig {
262    /// Generate the authorization URL for the provider.
263    ///
264    /// Callers MUST append a `&state=<random>` parameter and validate it in the
265    /// callback to prevent CSRF attacks. See `OAuthStateStore` for a minimal
266    /// implementation.
267    pub fn auth_url(&self) -> String {
268        match self.provider.as_str() {
269            "google" => format!(
270                "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
271                self.client_id, self.redirect_uri
272            ),
273            "github" => format!(
274                "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
275                self.client_id, self.redirect_uri
276            ),
277            _ => String::new(),
278        }
279    }
280
281    /// Generate the authorization URL with a CSRF state parameter attached.
282    pub fn auth_url_with_state(&self, state: &str) -> String {
283        let base = self.auth_url();
284        if base.is_empty() {
285            return base;
286        }
287        format!("{}&state={}", base, state)
288    }
289
290    /// Generate the token exchange URL.
291    pub fn token_url(&self) -> &str {
292        match self.provider.as_str() {
293            "google" => "https://oauth2.googleapis.com/token",
294            "github" => "https://github.com/login/oauth/access_token",
295            _ => "",
296        }
297    }
298
299    /// URL for the userinfo endpoint, which returns the authenticated user's profile.
300    pub fn userinfo_url(&self) -> &str {
301        match self.provider.as_str() {
302            "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
303            "github" => "https://api.github.com/user",
304            _ => "",
305        }
306    }
307
308    /// Exchange an authorization code for an access token.
309    ///
310    /// Uses the system `curl` binary so the auth crate stays free of HTTP
311    /// client dependencies. Returns the provider-specific access token string
312    /// (extracted from the JSON response).
313    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
314        let body = match self.provider.as_str() {
315            "google" => format!(
316                "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
317                url_encode(&self.client_id),
318                url_encode(&self.client_secret),
319                url_encode(&self.redirect_uri)
320            ),
321            "github" => format!(
322                "code={code}&client_id={}&client_secret={}&redirect_uri={}",
323                url_encode(&self.client_id),
324                url_encode(&self.client_secret),
325                url_encode(&self.redirect_uri)
326            ),
327            _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
328        };
329
330        let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
331        extract_access_token(&out)
332    }
333
334    /// Fetch the authenticated user's email + display name using an access token.
335    /// Returns `(email, display_name)`.
336    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
337        let out = http_get_bearer(self.userinfo_url(), access_token)?;
338        let parsed: serde_json::Value =
339            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
340        match self.provider.as_str() {
341            "google" => {
342                let email = parsed
343                    .get("email")
344                    .and_then(|v| v.as_str())
345                    .ok_or("no email in userinfo")?
346                    .to_string();
347                let name = parsed
348                    .get("name")
349                    .and_then(|v| v.as_str())
350                    .map(String::from);
351                Ok((email, name))
352            }
353            "github" => {
354                let name = parsed
355                    .get("name")
356                    .and_then(|v| v.as_str())
357                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
358                    .map(String::from);
359                let email = parsed
360                    .get("email")
361                    .and_then(|v| v.as_str())
362                    .map(String::from);
363                // GitHub may return a null email if the user hasn't published one;
364                // in that case the caller should hit /user/emails with the same token.
365                let email = email
366                    .or_else(|| fetch_github_primary_email(access_token).ok())
367                    .ok_or("no accessible email on GitHub account")?;
368                Ok((email, name))
369            }
370            _ => Err(format!("unknown provider: {}", self.provider)),
371        }
372    }
373}
374
/// Percent-encodes `s` for use in a URL query string or form body.
///
/// Only the RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass
/// through unchanged. Every other byte of the UTF-8 encoding becomes `%XX`
/// with uppercase hex digits, so a space is `%20`, never `+`.
fn url_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved = byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
387
388/// Timeout for OAuth / userinfo HTTP calls. Short enough that a hung
389/// provider doesn't block a login indefinitely; long enough to absorb
390/// typical internet latency.
391const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
392
393fn ureq_agent() -> ureq::Agent {
394    ureq::AgentBuilder::new()
395        .timeout_connect(HTTP_TIMEOUT)
396        .timeout_read(HTTP_TIMEOUT)
397        .timeout_write(HTTP_TIMEOUT)
398        .user_agent("pylon/0.1")
399        .build()
400}
401
402fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
403    let agent = ureq_agent();
404    let mut req = agent
405        .post(url)
406        .set("Content-Type", "application/x-www-form-urlencoded");
407    if accept_json {
408        req = req.set("Accept", "application/json");
409    }
410    match req.send_string(body) {
411        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
412        Err(ureq::Error::Status(code, resp)) => {
413            let body = resp.into_string().unwrap_or_default();
414            Err(format!("HTTP {code}: {body}"))
415        }
416        Err(e) => Err(format!("HTTP error: {e}")),
417    }
418}
419
420fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
421    let agent = ureq_agent();
422    match agent
423        .get(url)
424        .set("Authorization", &format!("Bearer {token}"))
425        .set("Accept", "application/json")
426        .call()
427    {
428        Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
429        Err(ureq::Error::Status(code, resp)) => {
430            let body = resp.into_string().unwrap_or_default();
431            Err(format!("HTTP {code}: {body}"))
432        }
433        Err(e) => Err(format!("HTTP error: {e}")),
434    }
435}
436
437fn fetch_github_primary_email(token: &str) -> Result<String, String> {
438    let out = http_get_bearer("https://api.github.com/user/emails", token)?;
439    let emails: serde_json::Value =
440        serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
441    emails
442        .as_array()
443        .and_then(|arr| {
444            arr.iter()
445                .find(|e| {
446                    e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
447                        && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
448                })
449                .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
450        })
451        .ok_or_else(|| "no primary verified email on GitHub".into())
452}
453
454fn extract_access_token(body: &str) -> Result<String, String> {
455    if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
456        if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
457            return Ok(t.to_string());
458        }
459    }
460    // GitHub can return url-encoded: access_token=...&scope=...&token_type=bearer
461    for pair in body.split('&') {
462        if let Some(val) = pair.strip_prefix("access_token=") {
463            return Ok(val.to_string());
464        }
465    }
466    Err(format!("no access_token in token response: {body}"))
467}
468
469/// OAuth provider registry.
470pub struct OAuthRegistry {
471    providers: std::collections::HashMap<String, OAuthConfig>,
472}
473
474impl Default for OAuthRegistry {
475    fn default() -> Self {
476        Self::new()
477    }
478}
479
480impl OAuthRegistry {
481    pub fn new() -> Self {
482        Self {
483            providers: std::collections::HashMap::new(),
484        }
485    }
486
487    pub fn register(&mut self, config: OAuthConfig) {
488        self.providers.insert(config.provider.clone(), config);
489    }
490
491    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
492        self.providers.get(provider)
493    }
494
495    /// Build from environment variables.
496    /// Looks for PYLON_OAUTH_GOOGLE_CLIENT_ID, etc.
497    pub fn from_env() -> Self {
498        let mut reg = Self::new();
499
500        // Google
501        if let (Ok(id), Ok(secret)) = (
502            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
503            std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
504        ) {
505            reg.register(OAuthConfig {
506                provider: "google".into(),
507                client_id: id,
508                client_secret: secret,
509                redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
510                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
511            });
512        }
513
514        // GitHub
515        if let (Ok(id), Ok(secret)) = (
516            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
517            std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
518        ) {
519            reg.register(OAuthConfig {
520                provider: "github".into(),
521                client_id: id,
522                client_secret: secret,
523                redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
524                    .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
525            });
526        }
527
528        reg
529    }
530}
531
532// ---------------------------------------------------------------------------
533// OAuth state store — CSRF protection for OAuth flows
534// ---------------------------------------------------------------------------
535
/// Backing store for OAuth state tokens. The default implementation keeps
/// them in memory (fine for tests and dev). The runtime swaps in a
/// SQLite-backed implementation so that a restart in the middle of an OAuth
/// handshake doesn't leave the user with "invalid state" on the callback.
/// Same pattern as `SessionBackend`.
pub trait OAuthStateBackend: Send + Sync {
    /// Record `token` as a valid state for `provider` until `expires_at`
    /// (Unix epoch seconds).
    fn put(&self, token: &str, provider: &str, expires_at: u64);
    /// Atomically look up and remove `token`.
    ///
    /// Returns the stored provider only if the token exists and has not
    /// expired as of `now_unix_secs`. The in-memory backend treats
    /// `expires_at <= now_unix_secs` as expired. Implementations must remove
    /// the token even when it has expired, so a state value can never be
    /// consumed twice.
    ///
    /// `None` means the token never existed, was already used, or has
    /// expired.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String>;
}
547
548/// In-memory backend (default). Lost on restart.
549pub struct InMemoryOAuthBackend {
550    states: Mutex<HashMap<String, OAuthState>>,
551}
552
553impl InMemoryOAuthBackend {
554    pub fn new() -> Self {
555        Self {
556            states: Mutex::new(HashMap::new()),
557        }
558    }
559}
560
561impl Default for InMemoryOAuthBackend {
562    fn default() -> Self {
563        Self::new()
564    }
565}
566
567impl OAuthStateBackend for InMemoryOAuthBackend {
568    fn put(&self, token: &str, provider: &str, expires_at: u64) {
569        self.states.lock().unwrap().insert(
570            token.to_string(),
571            OAuthState {
572                provider: provider.to_string(),
573                expires_at,
574            },
575        );
576    }
577    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String> {
578        let mut s = self.states.lock().unwrap();
579        let entry = s.remove(token)?;
580        if entry.expires_at <= now_unix_secs {
581            return None;
582        }
583        Some(entry.provider)
584    }
585}
586
587/// Stores OAuth state parameters to prevent CSRF attacks on the callback.
588///
589/// State tokens are short-lived (10 minutes) and single-use. Backed by an
590/// `OAuthStateBackend`; defaults to in-memory but the runtime persists them
591/// to SQLite so they survive a restart that happens mid-OAuth-handshake.
592pub struct OAuthStateStore {
593    backend: Box<dyn OAuthStateBackend>,
594}
595
596pub struct OAuthState {
597    pub provider: String,
598    pub expires_at: u64,
599}
600
601impl Default for OAuthStateStore {
602    fn default() -> Self {
603        Self::new()
604    }
605}
606
607impl OAuthStateStore {
608    pub fn new() -> Self {
609        Self {
610            backend: Box::new(InMemoryOAuthBackend::new()),
611        }
612    }
613
614    pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
615        Self { backend }
616    }
617
618    /// Generate and store a new state parameter. Returns the random state string.
619    pub fn create(&self, provider: &str) -> String {
620        use std::time::{SystemTime, UNIX_EPOCH};
621        let token = generate_token();
622        let now = SystemTime::now()
623            .duration_since(UNIX_EPOCH)
624            .unwrap_or_default()
625            .as_secs();
626        self.backend.put(&token, provider, now + 600);
627        token
628    }
629
630    /// Validate and consume a state parameter. Returns true iff the state
631    /// existed, has not expired, and matches `expected_provider`. The token
632    /// is removed either way to make replay impossible.
633    pub fn validate(&self, state: &str, expected_provider: &str) -> bool {
634        use std::time::{SystemTime, UNIX_EPOCH};
635        let now = SystemTime::now()
636            .duration_since(UNIX_EPOCH)
637            .unwrap_or_default()
638            .as_secs();
639        match self.backend.take(state, now) {
640            Some(provider) => provider == expected_provider,
641            None => false,
642        }
643    }
644}
645
646// ---------------------------------------------------------------------------
647// Magic code auth — email verification codes
648// ---------------------------------------------------------------------------
649
/// An in-memory magic code store for development.
///
/// Codes are kept in a `Mutex`-guarded map keyed by email address, so they
/// are lost on restart.
pub struct MagicCodeStore {
    codes: Mutex<HashMap<String, MagicCode>>,
}

/// A pending magic code issued to one email address.
///
/// NOTE(review): the derived `Debug` prints `code`; avoid logging values of
/// this type while the code is still valid.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Address the code was issued for (also its key in `MagicCodeStore`).
    pub email: String,
    /// The code itself: a zero-padded 6-digit decimal string.
    pub code: String,
    /// Unix epoch seconds at which the code stops being accepted
    /// (`expires_at <= now` counts as expired).
    pub expires_at: u64,
    /// Failed verify attempts against this code. Once it reaches
    /// `MAX_ATTEMPTS` the code is invalidated: even the correct code is then
    /// rejected with `TooManyAttempts`.
    pub attempts: u32,
}

/// Maximum verify attempts per code before it's burned. 5 is a common bound —
/// lets the user fix typos without enabling realistic brute-force against a
/// 6-digit code space.
const MAX_ATTEMPTS: u32 = 5;

/// Minimum seconds between successive `create()` calls for the same email.
/// Throttles magic-code spam (user can't be flooded with login codes).
/// Only enforced while the previous code is still unexpired.
const CREATE_COOLDOWN_SECS: u64 = 60;

/// Why a magic code could not be issued or verified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is stored for this email: none was issued, it was already
    /// consumed, or it was cleared after expiring.
    NotFound,
    /// `MAX_ATTEMPTS` failed verifies have been recorded for the code; it
    /// stays burned and even the correct code is rejected.
    TooManyAttempts,
    /// The code did not match.
    BadCode,
    /// The stored code's expiry has passed. The code is removed when this is
    /// reported.
    Expired,
    /// Another code was requested too recently. Wait `retry_after_secs`
    /// seconds and try again.
    Throttled { retry_after_secs: u64 },
}

impl Default for MagicCodeStore {
    fn default() -> Self {
        Self::new()
    }
}
693
694impl MagicCodeStore {
695    pub fn new() -> Self {
696        Self {
697            codes: Mutex::new(HashMap::new()),
698        }
699    }
700
701    /// Generate a 6-digit code for an email and return it. Subject to a
702    /// per-email cooldown — returns the error-shape via `try_create`.
703    pub fn create(&self, email: &str) -> String {
704        // Back-compat wrapper: same signature as before, but we still burn
705        // the cooldown if one is active. Use `try_create` for a Result shape.
706        self.try_create(email).unwrap_or_else(|_| String::new())
707    }
708
709    /// Create a magic code, enforcing per-email cooldown. Returns the code
710    /// or an error describing why one couldn't be issued.
711    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
712        let now = now_secs();
713
714        let mut codes = self.codes.lock().unwrap();
715
716        // Cooldown check: if a live code exists and was created less than
717        // CREATE_COOLDOWN_SECS ago, throttle. The age-of-code is
718        // `expires_at - 600 + cooldown` since expires_at is create_time + 600.
719        if let Some(existing) = codes.get(email) {
720            if existing.expires_at > now {
721                let created_at = existing.expires_at.saturating_sub(600);
722                let age = now.saturating_sub(created_at);
723                if age < CREATE_COOLDOWN_SECS {
724                    return Err(MagicCodeError::Throttled {
725                        retry_after_secs: CREATE_COOLDOWN_SECS - age,
726                    });
727                }
728            }
729        }
730
731        let code = generate_magic_code();
732        let mc = MagicCode {
733            email: email.to_string(),
734            code: code.clone(),
735            expires_at: now + 600, // 10 minutes
736            attempts: 0,
737        };
738        codes.insert(email.to_string(), mc);
739        Ok(code)
740    }
741
742    /// Verify a code for an email. Returns true if valid and not expired.
743    /// Uses constant-time comparison to prevent timing attacks.
744    /// Back-compat wrapper around [`try_verify`].
745    pub fn verify(&self, email: &str, code: &str) -> bool {
746        matches!(self.try_verify(email, code), Ok(()))
747    }
748
749    /// Verify a code. Returns a typed error so callers can surface specific
750    /// messages. On the MAX_ATTEMPTS-th failure, the code is burned — even
751    /// correct subsequent attempts return `TooManyAttempts`.
752    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
753        let now = now_secs();
754        let mut codes = self.codes.lock().unwrap();
755
756        let mc = match codes.get_mut(email) {
757            Some(m) => m,
758            None => return Err(MagicCodeError::NotFound),
759        };
760
761        if mc.attempts >= MAX_ATTEMPTS {
762            return Err(MagicCodeError::TooManyAttempts);
763        }
764        if mc.expires_at <= now {
765            codes.remove(email);
766            return Err(MagicCodeError::Expired);
767        }
768
769        let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
770        if !ok {
771            mc.attempts += 1;
772            // Burn the code at MAX_ATTEMPTS so retries can't hit max.
773            if mc.attempts >= MAX_ATTEMPTS {
774                return Err(MagicCodeError::TooManyAttempts);
775            }
776            return Err(MagicCodeError::BadCode);
777        }
778
779        // Correct code — consume it.
780        codes.remove(email);
781        Ok(())
782    }
783}
784
785// ---------------------------------------------------------------------------
786// Cryptographic helpers — CSPRNG-based token and code generation
787// ---------------------------------------------------------------------------
788
/// Lowercase hexadecimal encoding of `bytes`, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
792
793/// Generate a 6-digit magic code using a CSPRNG.
794fn generate_magic_code() -> String {
795    use rand::Rng;
796    let mut rng = rand::thread_rng();
797    let code: u32 = rng.gen_range(0..1_000_000);
798    format!("{:06}", code)
799}
800
801/// Generate a session token with 256 bits of entropy from a CSPRNG.
802fn generate_token() -> String {
803    use rand::Rng;
804    let mut rng = rand::thread_rng();
805    let bytes: [u8; 32] = rng.gen();
806    format!("pylon_{}", hex_encode(&bytes))
807}
808
809// ---------------------------------------------------------------------------
810// Session store — in-memory for dev
811// ---------------------------------------------------------------------------
812
813use std::collections::HashMap;
814use std::sync::Mutex;
815
/// Pluggable persistence for [`SessionStore`]. Without one, sessions live only
/// in memory; apps deploying for real should supply a persistent backend
/// (e.g. SQLite or Redis) so users don't log out on server restart.
pub trait SessionBackend: Send + Sync {
    /// Return every stored session. `SessionStore::with_backend` calls this
    /// once to hydrate its in-memory map; expired sessions in the result are
    /// not loaded.
    fn load_all(&self) -> Vec<Session>;
    /// Persist `session` (presumably an upsert keyed by `token` — confirm in
    /// each implementation).
    fn save(&self, session: &Session);
    /// Delete the session stored under `token`.
    fn remove(&self, token: &str);
}
824
/// Token → [`Session`] store. In-memory by default; optionally backed by a
/// persistent [`SessionBackend`].
///
/// The in-memory map is always authoritative — lookups never consult the
/// backend. Mutations such as `create` are written through to the backend.
/// On construction via [`SessionStore::with_backend`], the store hydrates
/// from the backend so sessions survive restart.
pub struct SessionStore {
    /// Live sessions keyed by token.
    sessions: Mutex<HashMap<String, Session>>,
    /// Optional persistent store that receives writes.
    backend: Option<Box<dyn SessionBackend>>,
}

impl Default for SessionStore {
    /// Equivalent to `SessionStore::new()`: memory-only, no backend.
    fn default() -> Self {
        Self::new()
    }
}
842
843impl SessionStore {
844    pub fn new() -> Self {
845        Self {
846            sessions: Mutex::new(HashMap::new()),
847            backend: None,
848        }
849    }
850
851    /// Build a session store backed by a persistent store. Existing sessions
852    /// are loaded from the backend on construction; every future mutation
853    /// writes through.
854    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
855        let mut map = HashMap::new();
856        for s in backend.load_all() {
857            if !s.is_expired() {
858                map.insert(s.token.clone(), s);
859            }
860        }
861        Self {
862            sessions: Mutex::new(map),
863            backend: Some(backend),
864        }
865    }
866
867    /// Create a session for a user and return it.
868    pub fn create(&self, user_id: String) -> Session {
869        let session = Session::new(user_id);
870        let mut sessions = self.sessions.lock().unwrap();
871        sessions.insert(session.token.clone(), session.clone());
872        if let Some(b) = &self.backend {
873            b.save(&session);
874        }
875        session
876    }
877
878    /// Look up a session by token. Returns None if the session is expired.
879    pub fn get(&self, token: &str) -> Option<Session> {
880        let mut sessions = self.sessions.lock().unwrap();
881        match sessions.get(token) {
882            Some(s) if s.is_expired() => {
883                sessions.remove(token);
884                None
885            }
886            Some(s) => Some(s.clone()),
887            None => None,
888        }
889    }
890
891    /// Resolve a token to an auth context.
892    /// Returns anonymous context if the token is invalid, missing, or expired.
893    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
894        match token {
895            Some(t) => match self.get(t) {
896                Some(session) => session.to_auth_context(),
897                None => AuthContext::anonymous(),
898            },
899            None => AuthContext::anonymous(),
900        }
901    }
902
903    /// Refresh a session — issues a new token, copies user/device, extends expiry.
904    /// The old token is revoked. Returns the new session or None if the old
905    /// token is missing/expired.
906    pub fn refresh(&self, old_token: &str) -> Option<Session> {
907        let mut sessions = self.sessions.lock().unwrap();
908        let old = sessions.remove(old_token)?;
909        if let Some(b) = &self.backend {
910            b.remove(old_token);
911        }
912        if old.is_expired() {
913            return None;
914        }
915        let mut new = Session::new(old.user_id.clone());
916        new.device = old.device.clone();
917        sessions.insert(new.token.clone(), new.clone());
918        if let Some(b) = &self.backend {
919            b.save(&new);
920        }
921        Some(new)
922    }
923
924    /// List all active sessions for a user.
925    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
926        let sessions = self.sessions.lock().unwrap();
927        sessions
928            .values()
929            .filter(|s| s.user_id == user_id && !s.is_expired())
930            .cloned()
931            .collect()
932    }
933
934    /// Revoke all sessions for a user. Returns the count removed.
935    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
936        let mut sessions = self.sessions.lock().unwrap();
937        let tokens: Vec<String> = sessions
938            .iter()
939            .filter_map(|(t, s)| {
940                if s.user_id == user_id {
941                    Some(t.clone())
942                } else {
943                    None
944                }
945            })
946            .collect();
947        let n = tokens.len();
948        for t in &tokens {
949            sessions.remove(t);
950            if let Some(b) = &self.backend {
951                b.remove(t);
952            }
953        }
954        n
955    }
956
957    /// Sweep expired sessions. Returns the count removed.
958    pub fn sweep_expired(&self) -> usize {
959        let mut sessions = self.sessions.lock().unwrap();
960        let expired: Vec<String> = sessions
961            .iter()
962            .filter_map(|(t, s)| {
963                if s.is_expired() {
964                    Some(t.clone())
965                } else {
966                    None
967                }
968            })
969            .collect();
970        let n = expired.len();
971        for t in &expired {
972            sessions.remove(t);
973            if let Some(b) = &self.backend {
974                b.remove(t);
975            }
976        }
977        n
978    }
979
980    /// Attach a device label to a session (typically on login from a browser).
981    pub fn set_device(&self, token: &str, device: String) -> bool {
982        let mut sessions = self.sessions.lock().unwrap();
983        if let Some(s) = sessions.get_mut(token) {
984            s.device = Some(device);
985            if let Some(b) = &self.backend {
986                b.save(s);
987            }
988            true
989        } else {
990            false
991        }
992    }
993
994    /// Create a guest session with a generated anonymous ID.
995    pub fn create_guest(&self) -> Session {
996        use rand::Rng;
997        let mut rng = rand::thread_rng();
998        let bytes: [u8; 16] = rng.gen();
999        let guest_id = format!("guest_{}", hex_encode(&bytes));
1000        self.create(guest_id)
1001    }
1002
1003    /// Upgrade a guest session to a real user. Replaces the user_id.
1004    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1005        let mut sessions = self.sessions.lock().unwrap();
1006        if let Some(session) = sessions.get_mut(token) {
1007            session.user_id = real_user_id;
1008            if let Some(b) = &self.backend {
1009                b.save(session);
1010            }
1011            true
1012        } else {
1013            false
1014        }
1015    }
1016
1017    /// Switch the session's active tenant (organization). `None` clears it.
1018    /// Callers should verify the user actually has membership in the target
1019    /// tenant BEFORE invoking this — the session store takes the value on
1020    /// trust. Returns true if the session exists, false otherwise.
1021    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1022        let mut sessions = self.sessions.lock().unwrap();
1023        if let Some(session) = sessions.get_mut(token) {
1024            session.tenant_id = tenant_id;
1025            if let Some(b) = &self.backend {
1026                b.save(session);
1027            }
1028            true
1029        } else {
1030            false
1031        }
1032    }
1033
1034    /// Remove a session.
1035    pub fn revoke(&self, token: &str) -> bool {
1036        let mut sessions = self.sessions.lock().unwrap();
1037        let removed = sessions.remove(token).is_some();
1038        if removed {
1039            if let Some(b) = &self.backend {
1040                b.remove(token);
1041            }
1042        }
1043        removed
1044    }
1045}
1046
1047// ---------------------------------------------------------------------------
1048// Tests
1049// ---------------------------------------------------------------------------
1050
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }

    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }

    #[test]
    fn auth_mode_public_allows_anonymous() {
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        assert_eq!(AuthMode::from_str("admin"), None);
    }

    #[test]
    fn session_store_create_and_get() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(!session.token.is_empty());
        assert!(session.token.starts_with("pylon_"));

        let retrieved = store.get(&session.token).unwrap();
        assert_eq!(retrieved.user_id, "user-1");
    }

    #[test]
    fn session_store_resolve() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));

        let anon = store.resolve(None);
        assert!(!anon.is_authenticated());

        let bad = store.resolve(Some("invalid-token"));
        assert!(!bad.is_authenticated());
    }

    #[test]
    fn session_store_revoke() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        assert!(store.revoke(&session.token));
        assert!(store.get(&session.token).is_none());
        assert!(!store.revoke(&session.token)); // already revoked
    }

    #[test]
    fn session_to_auth_context() {
        let session = Session::new("user-42".into());
        let ctx = session.to_auth_context();
        assert_eq!(ctx.user_id, Some("user-42".into()));
    }

    // -- Session lifecycle --

    #[test]
    fn session_refresh_rotates_token() {
        let store = SessionStore::new();
        let old = store.create("user-1".into());
        assert!(store.set_device(&old.token, "laptop".into()));

        let rotated = store.refresh(&old.token).unwrap();
        assert_ne!(rotated.token, old.token);
        assert_eq!(rotated.user_id, "user-1");
        assert_eq!(rotated.device.as_deref(), Some("laptop"));

        // The old token is revoked; only the new one resolves.
        assert!(store.get(&old.token).is_none());
        assert!(store.get(&rotated.token).is_some());
        assert!(store.refresh(&old.token).is_none());
    }

    #[test]
    fn refresh_unknown_token_fails() {
        let store = SessionStore::new();
        assert!(store.refresh("nonexistent-token").is_none());
    }

    #[test]
    fn revoke_all_for_user_only_touches_that_user() {
        let store = SessionStore::new();
        let a1 = store.create("alice".into());
        let a2 = store.create("alice".into());
        let b = store.create("bob".into());
        assert_eq!(store.list_for_user("alice").len(), 2);

        assert_eq!(store.revoke_all_for_user("alice"), 2);
        assert!(store.get(&a1.token).is_none());
        assert!(store.get(&a2.token).is_none());
        assert!(store.get(&b.token).is_some());
        assert!(store.list_for_user("alice").is_empty());
        assert_eq!(store.list_for_user("bob").len(), 1);
    }

    #[test]
    fn set_tenant_and_clear() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        assert!(store.set_tenant(&session.token, Some("org-1".into())));
        assert_eq!(
            store.get(&session.token).unwrap().tenant_id.as_deref(),
            Some("org-1")
        );

        assert!(store.set_tenant(&session.token, None));
        assert!(store.get(&session.token).unwrap().tenant_id.is_none());
    }

    #[test]
    fn mutators_reject_unknown_token() {
        let store = SessionStore::new();
        assert!(!store.set_device("nonexistent-token", "phone".into()));
        assert!(!store.set_tenant("nonexistent-token", Some("org-1".into())));
    }

    #[test]
    fn sweep_keeps_live_sessions() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert_eq!(store.sweep_expired(), 0);
        assert!(store.get(&session.token).is_some());
    }

    // -- Persistent backend --

    /// In-memory `SessionBackend` whose rows are shared with the test, so
    /// the test can observe write-through and simulate a restart.
    #[derive(Clone, Default)]
    struct MemBackend {
        rows: Arc<Mutex<HashMap<String, Session>>>,
    }

    impl MemBackend {
        fn row(&self, token: &str) -> Option<Session> {
            self.rows.lock().unwrap().get(token).cloned()
        }
    }

    impl SessionBackend for MemBackend {
        fn load_all(&self) -> Vec<Session> {
            self.rows.lock().unwrap().values().cloned().collect()
        }

        fn save(&self, session: &Session) {
            self.rows
                .lock()
                .unwrap()
                .insert(session.token.clone(), session.clone());
        }

        fn remove(&self, token: &str) {
            self.rows.lock().unwrap().remove(token);
        }
    }

    #[test]
    fn backend_write_through_and_hydrate() {
        let backend = MemBackend::default();
        let store = SessionStore::with_backend(Box::new(backend.clone()));
        let session = store.create("user-1".into());
        assert!(backend.row(&session.token).is_some());

        // A fresh store over the same backend sees the session ("restart").
        let restarted = SessionStore::with_backend(Box::new(backend.clone()));
        assert_eq!(restarted.get(&session.token).unwrap().user_id, "user-1");

        assert!(restarted.revoke(&session.token));
        assert!(backend.row(&session.token).is_none());
    }

    #[test]
    fn backend_sees_updates_refreshes_and_revocations() {
        let backend = MemBackend::default();
        let store = SessionStore::with_backend(Box::new(backend.clone()));
        let a = store.create("alice".into());
        let b = store.create("bob".into());

        assert!(store.set_device(&b.token, "phone".into()));
        assert_eq!(
            backend.row(&b.token).unwrap().device.as_deref(),
            Some("phone")
        );

        let rotated = store.refresh(&b.token).unwrap();
        assert!(backend.row(&b.token).is_none());
        assert!(backend.row(&rotated.token).is_some());

        assert_eq!(store.revoke_all_for_user("alice"), 1);
        assert!(backend.row(&a.token).is_none());
        assert!(backend.row(&rotated.token).is_some());
    }

    // -- Admin context --

    #[test]
    fn admin_context() {
        let ctx = AuthContext::admin();
        assert!(ctx.is_admin);
        assert!(ctx.is_authenticated());
    }

    #[test]
    fn anonymous_not_admin() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_admin);
    }

    #[test]
    fn authenticated_not_admin() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_admin);
    }

    // -- Magic codes --

    #[test]
    fn magic_code_create_and_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert_eq!(code.len(), 6);
        assert!(store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_wrong_code_rejected() {
        let store = MagicCodeStore::new();
        store.create("test@example.com");
        assert!(!store.verify("test@example.com", "000000"));
    }

    #[test]
    fn magic_code_wrong_email_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(!store.verify("other@example.com", &code));
    }

    #[test]
    fn magic_code_consumed_after_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(store.verify("test@example.com", &code));
        // Second verify should fail — code consumed.
        assert!(!store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_different_emails_independent() {
        let store = MagicCodeStore::new();
        let code1 = store.create("alice@example.com");
        let code2 = store.create("bob@example.com");
        // Each email has its own code.
        assert!(store.verify("alice@example.com", &code1));
        assert!(store.verify("bob@example.com", &code2));
    }

    // -- Constant-time comparison --

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"a", b"b"));
    }

    // -- Token generation --

    #[test]
    fn generated_tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert!(t1.starts_with("pylon_"));
        assert!(t2.starts_with("pylon_"));
        // 256 bits = 64 hex chars + "pylon_" prefix (6 chars)
        assert_eq!(t1.len(), 6 + 64);
    }

    // -- OAuth registry --

    #[test]
    fn oauth_registry_empty() {
        let reg = OAuthRegistry::new();
        assert!(reg.get("google").is_none());
    }

    #[test]
    fn oauth_registry_register_and_get() {
        let mut reg = OAuthRegistry::new();
        reg.register(OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/callback".into(),
        });
        let config = reg.get("google").unwrap();
        assert_eq!(config.client_id, "test-id");
        assert!(config.auth_url().contains("accounts.google.com"));
    }

    // -- Guest auth --

    #[test]
    fn guest_session() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));
        assert!(!session.token.is_empty());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert!(ctx.user_id.unwrap().starts_with("guest_"));
    }

    #[test]
    fn upgrade_guest_to_real_user() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));

        let upgraded = store.upgrade(&session.token, "real-user-123".into());
        assert!(upgraded);

        let ctx = store.resolve(Some(&session.token));
        assert_eq!(ctx.user_id, Some("real-user-123".into()));
    }

    #[test]
    fn upgrade_invalid_token_fails() {
        let store = SessionStore::new();
        let upgraded = store.upgrade("nonexistent-token", "user".into());
        assert!(!upgraded);
    }

    #[test]
    fn guest_context() {
        let ctx = AuthContext::guest("guest_123".into());
        assert!(ctx.is_authenticated());
        assert!(!ctx.is_admin);
        assert_eq!(ctx.user_id, Some("guest_123".into()));
    }

    #[test]
    fn oauth_token_urls() {
        let google = OAuthConfig {
            provider: "google".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
        let github = OAuthConfig {
            provider: "github".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(
            github.token_url(),
            "https://github.com/login/oauth/access_token"
        );
        let unknown = OAuthConfig {
            provider: "unknown".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(unknown.token_url(), "");
        assert!(unknown.auth_url().is_empty());
    }

    #[test]
    fn oauth_auth_url_github() {
        let config = OAuthConfig {
            provider: "github".into(),
            client_id: "gh-id".into(),
            client_secret: "gh-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        assert!(config.auth_url().contains("github.com"));
        assert!(config.auth_url().contains("gh-id"));
    }

    #[test]
    fn oauth_auth_url_with_state() {
        let config = OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        let url = config.auth_url_with_state("random_state_123");
        assert!(url.contains("&state=random_state_123"));
    }

    #[test]
    fn oauth_state_store_create_and_validate() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(store.validate(&state, "google"));
        // Second validation should fail — consumed.
        assert!(!store.validate(&state, "google"));
    }

    #[test]
    fn oauth_state_store_wrong_provider_rejected() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(!store.validate(&state, "github"));
    }

    #[test]
    fn oauth_state_store_invalid_state_rejected() {
        let store = OAuthStateStore::new();
        assert!(!store.validate("nonexistent", "google"));
    }
}