1pub mod cookie;
2pub mod email;
3pub mod password;
4
5pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
6
7use serde::{Deserialize, Serialize};
8
/// Identity and authorization facts about a caller.
///
/// The type is serialize-only. `is_guest` is left out of the serialized form
/// when it is `false`, and `tenant_id` is left out when it is `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
    /// The caller's user id; `None` for an anonymous context.
    pub user_id: Option<String>,
    /// When `true`, `has_role` and `has_any_role` succeed for every role.
    pub is_admin: bool,
    /// Marks a guest identity. A guest carries a `user_id`, but
    /// `is_authenticated` reports `false` for it.
    #[serde(default, skip_serializing_if = "is_false")]
    pub is_guest: bool,
    /// Role names consulted by `has_role`.
    pub roles: Vec<String>,
    /// Tenant attached to this context, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
44
/// Predicate for serde's `skip_serializing_if`: `true` exactly when the flag
/// is `false`, so unset flags are omitted from the output.
fn is_false(b: &bool) -> bool {
    !*b
}
48
49impl AuthContext {
50 pub fn anonymous() -> Self {
52 Self {
53 user_id: None,
54 is_admin: false,
55 is_guest: false,
56 roles: Vec::new(),
57 tenant_id: None,
58 }
59 }
60
61 pub fn authenticated(user_id: String) -> Self {
63 Self {
64 user_id: Some(user_id),
65 is_admin: false,
66 is_guest: false,
67 roles: Vec::new(),
68 tenant_id: None,
69 }
70 }
71
72 pub fn guest(guest_id: String) -> Self {
77 Self {
78 user_id: Some(guest_id),
79 is_admin: false,
80 is_guest: true,
81 roles: Vec::new(),
82 tenant_id: None,
83 }
84 }
85
86 pub fn admin() -> Self {
88 Self {
89 user_id: Some("__admin__".into()),
90 is_admin: true,
91 is_guest: false,
92 roles: vec!["admin".into()],
93 tenant_id: None,
94 }
95 }
96
97 pub fn user(user_id: String) -> Self {
99 Self::authenticated(user_id)
100 }
101
102 pub fn tenant_id(&self) -> Option<&str> {
104 self.tenant_id.as_deref()
105 }
106
107 pub fn with_tenant(mut self, tenant_id: String) -> Self {
109 self.tenant_id = Some(tenant_id);
110 self
111 }
112
113 pub fn is_authenticated(&self) -> bool {
117 self.user_id.is_some() && !self.is_guest
118 }
119
120 pub fn has_role(&self, role: &str) -> bool {
122 self.is_admin || self.roles.iter().any(|r| r == role)
123 }
124
125 pub fn has_any_role(&self, roles: &[&str]) -> bool {
127 self.is_admin || roles.iter().any(|r| self.has_role(r))
128 }
129
130 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
132 self.roles = roles;
133 self
134 }
135}
136
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length compare unequal immediately. For equal
/// lengths every byte pair is XOR-ed and OR-ed into one accumulator, so the
/// loop does the same amount of work wherever a mismatch occurs.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (left, right)| acc | (left ^ right));
    diff == 0
}
156
157#[derive(Debug, Clone, PartialEq, Eq)]
163pub enum AuthMode {
164 Public,
166 User,
168}
169
170impl AuthMode {
171 #[allow(clippy::should_implement_trait)]
173 pub fn from_str(s: &str) -> Option<Self> {
174 match s {
175 "public" => Some(AuthMode::Public),
176 "user" => Some(AuthMode::User),
177 _ => None,
178 }
179 }
180
181 pub fn check(&self, ctx: &AuthContext) -> bool {
183 match self {
184 AuthMode::Public => true,
185 AuthMode::User => ctx.is_authenticated(),
186 }
187 }
188}
189
/// A login session identified by an opaque token.
///
/// `device` and `tenant_id` are omitted from the serialized form when `None`.
/// All fields except `token` and `user_id` fall back to their default when
/// missing on deserialization.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Opaque session token; also the key used by `SessionStore`.
    pub token: String,
    /// Id of the user this session belongs to.
    pub user_id: String,
    /// Expiry as Unix seconds. `0` means the session never expires (see
    /// `is_expired`); note that a record missing this field deserializes
    /// to `0` as well.
    #[serde(default)]
    pub expires_at: u64,
    /// Optional device label, set through `SessionStore::set_device`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Creation time as Unix seconds.
    #[serde(default)]
    pub created_at: u64,
    /// Tenant attached to the session, copied into the `AuthContext` by
    /// `to_auth_context`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
214
215impl Session {
216 pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
218
219 pub fn new(user_id: String) -> Self {
221 let now = now_secs();
222 Self {
223 token: generate_token(),
224 user_id,
225 expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
226 device: None,
227 created_at: now,
228 tenant_id: None,
229 }
230 }
231
232 pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
234 let now = now_secs();
235 Self {
236 token: generate_token(),
237 user_id,
238 expires_at: if lifetime_secs == 0 {
239 0
240 } else {
241 now.saturating_add(lifetime_secs)
242 },
243 device: None,
244 created_at: now,
245 tenant_id: None,
246 }
247 }
248
249 pub fn to_auth_context(&self) -> AuthContext {
252 let ctx = AuthContext::authenticated(self.user_id.clone());
253 match &self.tenant_id {
254 Some(t) => ctx.with_tenant(t.clone()),
255 None => ctx,
256 }
257 }
258
259 pub fn is_expired(&self) -> bool {
263 self.expires_at != 0 && now_secs() >= self.expires_at
264 }
265}
266
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
274
/// Credentials and callback address for one OAuth provider.
///
/// NOTE(review): the derived `Debug` and `Serialize` output includes
/// `client_secret` in clear text; take care when logging or exporting
/// values of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
    /// Provider key; the methods on this type recognise `"google"` and
    /// `"github"`.
    pub provider: String,
    /// OAuth client id issued by the provider.
    pub client_id: String,
    /// OAuth client secret issued by the provider.
    pub client_secret: String,
    /// Callback URL sent to the provider, stored un-encoded (the token
    /// exchange percent-encodes it).
    pub redirect_uri: String,
}
286
287impl OAuthConfig {
288 pub fn auth_url(&self) -> String {
294 match self.provider.as_str() {
295 "google" => format!(
296 "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
297 self.client_id, self.redirect_uri
298 ),
299 "github" => format!(
300 "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
301 self.client_id, self.redirect_uri
302 ),
303 _ => String::new(),
304 }
305 }
306
307 pub fn auth_url_with_state(&self, state: &str) -> String {
309 let base = self.auth_url();
310 if base.is_empty() {
311 return base;
312 }
313 format!("{}&state={}", base, state)
314 }
315
316 pub fn token_url(&self) -> &str {
318 match self.provider.as_str() {
319 "google" => "https://oauth2.googleapis.com/token",
320 "github" => "https://github.com/login/oauth/access_token",
321 _ => "",
322 }
323 }
324
325 pub fn userinfo_url(&self) -> &str {
327 match self.provider.as_str() {
328 "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
329 "github" => "https://api.github.com/user",
330 _ => "",
331 }
332 }
333
334 pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
341 let body = match self.provider.as_str() {
342 "google" => format!(
343 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
344 url_encode(&self.client_id),
345 url_encode(&self.client_secret),
346 url_encode(&self.redirect_uri)
347 ),
348 "github" => format!(
349 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
350 url_encode(&self.client_id),
351 url_encode(&self.client_secret),
352 url_encode(&self.redirect_uri)
353 ),
354 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
355 };
356
357 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
358 parse_token_response(&out)
359 }
360
361 pub fn exchange_code(&self, code: &str) -> Result<String, String> {
365 let body = match self.provider.as_str() {
366 "google" => format!(
367 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
368 url_encode(&self.client_id),
369 url_encode(&self.client_secret),
370 url_encode(&self.redirect_uri)
371 ),
372 "github" => format!(
373 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
374 url_encode(&self.client_id),
375 url_encode(&self.client_secret),
376 url_encode(&self.redirect_uri)
377 ),
378 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
379 };
380
381 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
382 extract_access_token(&out)
383 }
384
385 pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
391 let info = self.fetch_userinfo_full(access_token)?;
392 Ok((info.email, info.name))
393 }
394
395 pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
401 let out = http_get_bearer(self.userinfo_url(), access_token)?;
402 let parsed: serde_json::Value =
403 serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
404 match self.provider.as_str() {
405 "google" => {
406 let email = parsed
407 .get("email")
408 .and_then(|v| v.as_str())
409 .ok_or("no email in userinfo")?
410 .to_string();
411 let name = parsed
412 .get("name")
413 .and_then(|v| v.as_str())
414 .map(String::from);
415 let provider_account_id = parsed
416 .get("sub")
417 .and_then(|v| v.as_str())
418 .ok_or("no sub in userinfo")?
419 .to_string();
420 Ok(UserInfo {
421 provider: self.provider.clone(),
422 provider_account_id,
423 email,
424 name,
425 })
426 }
427 "github" => {
428 let name = parsed
429 .get("name")
430 .and_then(|v| v.as_str())
431 .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
432 .map(String::from);
433 let email = parsed
434 .get("email")
435 .and_then(|v| v.as_str())
436 .map(String::from);
437 let email = email
440 .or_else(|| fetch_github_primary_email(access_token).ok())
441 .ok_or("no accessible email on GitHub account")?;
442 let provider_account_id = parsed
445 .get("id")
446 .map(|v| {
447 v.as_i64()
448 .map(|n| n.to_string())
449 .or_else(|| v.as_str().map(String::from))
450 .unwrap_or_default()
451 })
452 .filter(|s| !s.is_empty())
453 .ok_or("no id in userinfo")?;
454 Ok(UserInfo {
455 provider: self.provider.clone(),
456 provider_account_id,
457 email,
458 name,
459 })
460 }
461 _ => Err(format!("unknown provider: {}", self.provider)),
462 }
463 }
464}
465
/// Profile data returned by `OAuthConfig::fetch_userinfo_full`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Provider key copied from the `OAuthConfig` that fetched the profile.
    pub provider: String,
    /// Provider-side account identifier (Google `sub`, GitHub `id`).
    pub provider_account_id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Display name, when the provider supplied one.
    pub name: Option<String>,
}
477
/// Tokens extracted from a provider's token response by
/// `parse_token_response`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSet {
    /// Value of `access_token`; always present.
    pub access_token: String,
    /// Value of `refresh_token`, if the response had one.
    pub refresh_token: Option<String>,
    /// Value of `id_token`, if the response had one.
    pub id_token: Option<String>,
    /// Absolute expiry in Unix seconds, computed as "now + `expires_in`"
    /// at parse time; `None` when `expires_in` was absent or unparsable.
    pub expires_at: Option<u64>,
    /// Value of `scope`, if the response had one.
    pub scope: Option<String>,
}
492
493fn parse_token_response(body: &str) -> Result<TokenSet, String> {
494 let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
497 let mut map = serde_json::Map::new();
499 for pair in body.split('&') {
500 if let Some((k, v)) = pair.split_once('=') {
501 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
502 }
503 }
504 serde_json::Value::Object(map)
505 });
506
507 let access_token = json
508 .get("access_token")
509 .and_then(|v| v.as_str())
510 .ok_or_else(|| format!("no access_token in token response: {body}"))?
511 .to_string();
512 let refresh_token = json
513 .get("refresh_token")
514 .and_then(|v| v.as_str())
515 .map(String::from);
516 let id_token = json
517 .get("id_token")
518 .and_then(|v| v.as_str())
519 .map(String::from);
520 let expires_at = json
521 .get("expires_in")
522 .and_then(|v| {
523 v.as_u64()
524 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
525 })
526 .map(|secs| now_secs().saturating_add(secs));
527 let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
528 Ok(TokenSet {
529 access_token,
530 refresh_token,
531 id_token,
532 expires_at,
533 scope,
534 })
535}
536
/// Percent-encodes `s` byte by byte.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through unchanged;
/// every other byte (including each byte of a multi-byte UTF-8 character)
/// becomes `%XX` with upper-case hex digits.
fn url_encode(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
549
/// Timeout applied separately to the connect, read and write phases of
/// every outbound HTTP request made by this module.
const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

/// Builds a fresh `ureq` agent with `HTTP_TIMEOUT` on connect, read and
/// write, and the user agent `pylon/0.1`.
fn ureq_agent() -> ureq::Agent {
    ureq::AgentBuilder::new()
        .timeout_connect(HTTP_TIMEOUT)
        .timeout_read(HTTP_TIMEOUT)
        .timeout_write(HTTP_TIMEOUT)
        .user_agent("pylon/0.1")
        .build()
}
563
564fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
565 let agent = ureq_agent();
566 let mut req = agent
567 .post(url)
568 .set("Content-Type", "application/x-www-form-urlencoded");
569 if accept_json {
570 req = req.set("Accept", "application/json");
571 }
572 match req.send_string(body) {
573 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
574 Err(ureq::Error::Status(code, resp)) => {
575 let body = resp.into_string().unwrap_or_default();
576 Err(format!("HTTP {code}: {body}"))
577 }
578 Err(e) => Err(format!("HTTP error: {e}")),
579 }
580}
581
582fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
583 let agent = ureq_agent();
584 match agent
585 .get(url)
586 .set("Authorization", &format!("Bearer {token}"))
587 .set("Accept", "application/json")
588 .call()
589 {
590 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
591 Err(ureq::Error::Status(code, resp)) => {
592 let body = resp.into_string().unwrap_or_default();
593 Err(format!("HTTP {code}: {body}"))
594 }
595 Err(e) => Err(format!("HTTP error: {e}")),
596 }
597}
598
599fn fetch_github_primary_email(token: &str) -> Result<String, String> {
600 let out = http_get_bearer("https://api.github.com/user/emails", token)?;
601 let emails: serde_json::Value =
602 serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
603 emails
604 .as_array()
605 .and_then(|arr| {
606 arr.iter()
607 .find(|e| {
608 e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
609 && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
610 })
611 .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
612 })
613 .ok_or_else(|| "no primary verified email on GitHub".into())
614}
615
616fn extract_access_token(body: &str) -> Result<String, String> {
617 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
618 if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
619 return Ok(t.to_string());
620 }
621 }
622 for pair in body.split('&') {
624 if let Some(val) = pair.strip_prefix("access_token=") {
625 return Ok(val.to_string());
626 }
627 }
628 Err(format!("no access_token in token response: {body}"))
629}
630
631pub struct OAuthRegistry {
633 providers: std::collections::HashMap<String, OAuthConfig>,
634}
635
636impl Default for OAuthRegistry {
637 fn default() -> Self {
638 Self::new()
639 }
640}
641
642impl OAuthRegistry {
643 pub fn new() -> Self {
644 Self {
645 providers: std::collections::HashMap::new(),
646 }
647 }
648
649 pub fn register(&mut self, config: OAuthConfig) {
650 self.providers.insert(config.provider.clone(), config);
651 }
652
653 pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
654 self.providers.get(provider)
655 }
656
657 pub fn from_env() -> Self {
660 let mut reg = Self::new();
661
662 if let (Ok(id), Ok(secret)) = (
664 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
665 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
666 ) {
667 reg.register(OAuthConfig {
668 provider: "google".into(),
669 client_id: id,
670 client_secret: secret,
671 redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
672 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
673 });
674 }
675
676 if let (Ok(id), Ok(secret)) = (
678 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
679 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
680 ) {
681 reg.register(OAuthConfig {
682 provider: "github".into(),
683 client_id: id,
684 client_secret: secret,
685 redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
686 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
687 });
688 }
689
690 reg
691 }
692}
693
/// Storage for pending OAuth `state` tokens.
///
/// Implementations must be `Send + Sync`; `OAuthStateStore` holds one as a
/// boxed trait object.
pub trait OAuthStateBackend: Send + Sync {
    /// Records `token` for `provider` with an absolute expiry in Unix
    /// seconds.
    fn put(&self, token: &str, provider: &str, expires_at: u64);
    /// Consumes `token` and returns the provider it was issued for.
    ///
    /// In the in-memory implementation the entry is removed whether or not
    /// it has expired, and `None` is returned for unknown or expired tokens.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String>;
}
709
710pub struct InMemoryOAuthBackend {
712 states: Mutex<HashMap<String, OAuthState>>,
713}
714
715impl InMemoryOAuthBackend {
716 pub fn new() -> Self {
717 Self {
718 states: Mutex::new(HashMap::new()),
719 }
720 }
721}
722
723impl Default for InMemoryOAuthBackend {
724 fn default() -> Self {
725 Self::new()
726 }
727}
728
729impl OAuthStateBackend for InMemoryOAuthBackend {
730 fn put(&self, token: &str, provider: &str, expires_at: u64) {
731 self.states.lock().unwrap().insert(
732 token.to_string(),
733 OAuthState {
734 provider: provider.to_string(),
735 expires_at,
736 },
737 );
738 }
739 fn take(&self, token: &str, now_unix_secs: u64) -> Option<String> {
740 let mut s = self.states.lock().unwrap();
741 let entry = s.remove(token)?;
742 if entry.expires_at <= now_unix_secs {
743 return None;
744 }
745 Some(entry.provider)
746 }
747}
748
749pub struct OAuthStateStore {
755 backend: Box<dyn OAuthStateBackend>,
756}
757
758pub struct OAuthState {
759 pub provider: String,
760 pub expires_at: u64,
761}
762
763impl Default for OAuthStateStore {
764 fn default() -> Self {
765 Self::new()
766 }
767}
768
769impl OAuthStateStore {
770 pub fn new() -> Self {
771 Self {
772 backend: Box::new(InMemoryOAuthBackend::new()),
773 }
774 }
775
776 pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
777 Self { backend }
778 }
779
780 pub fn create(&self, provider: &str) -> String {
782 use std::time::{SystemTime, UNIX_EPOCH};
783 let token = generate_token();
784 let now = SystemTime::now()
785 .duration_since(UNIX_EPOCH)
786 .unwrap_or_default()
787 .as_secs();
788 self.backend.put(&token, provider, now + 600);
789 token
790 }
791
792 pub fn validate(&self, state: &str, expected_provider: &str) -> bool {
796 use std::time::{SystemTime, UNIX_EPOCH};
797 let now = SystemTime::now()
798 .duration_since(UNIX_EPOCH)
799 .unwrap_or_default()
800 .as_secs();
801 match self.backend.take(state, now) {
802 Some(provider) => provider == expected_provider,
803 None => false,
804 }
805 }
806}
807
/// Persistence layer for magic codes, keyed by email.
///
/// `MagicCodeStore` keeps its own in-memory cache and mirrors every change
/// to the backend through these methods.
pub trait MagicCodeBackend: Send + Sync {
    /// Stores (or replaces) the code for `email`.
    fn put(&self, email: &str, code: &MagicCode);
    /// Returns the stored code for `email`, if any.
    fn get(&self, email: &str) -> Option<MagicCode>;
    /// Deletes the code for `email`.
    fn remove(&self, email: &str);
    /// Increments the failed-attempt counter for `email`.
    fn bump_attempts(&self, email: &str);
    /// Returns every stored code; used by `MagicCodeStore::with_backend`
    /// to fill its cache.
    fn load_all(&self) -> Vec<MagicCode>;
}
836
837pub struct InMemoryMagicCodeBackend {
840 codes: Mutex<HashMap<String, MagicCode>>,
841}
842
843impl InMemoryMagicCodeBackend {
844 pub fn new() -> Self {
845 Self {
846 codes: Mutex::new(HashMap::new()),
847 }
848 }
849}
850
851impl Default for InMemoryMagicCodeBackend {
852 fn default() -> Self {
853 Self::new()
854 }
855}
856
857impl MagicCodeBackend for InMemoryMagicCodeBackend {
858 fn put(&self, email: &str, code: &MagicCode) {
859 self.codes
860 .lock()
861 .unwrap()
862 .insert(email.to_string(), code.clone());
863 }
864 fn get(&self, email: &str) -> Option<MagicCode> {
865 self.codes.lock().unwrap().get(email).cloned()
866 }
867 fn remove(&self, email: &str) {
868 self.codes.lock().unwrap().remove(email);
869 }
870 fn bump_attempts(&self, email: &str) {
871 if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
872 c.attempts = c.attempts.saturating_add(1);
873 }
874 }
875 fn load_all(&self) -> Vec<MagicCode> {
876 self.codes.lock().unwrap().values().cloned().collect()
877 }
878}
879
/// Issues and verifies emailed one-time codes.
///
/// Holds an in-memory cache keyed by email and mirrors every change to the
/// backend.
pub struct MagicCodeStore {
    /// Working copy of the live codes; all reads go through this map.
    cache: Mutex<HashMap<String, MagicCode>>,
    /// Persistence layer that receives every put/remove/attempt bump.
    backend: Box<dyn MagicCodeBackend>,
}
888
/// One issued magic code and its verification state.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Email address the code was issued for.
    pub email: String,
    /// The code itself (six decimal digits as produced by
    /// `generate_magic_code`).
    pub code: String,
    /// Absolute expiry in Unix seconds.
    pub expires_at: u64,
    /// Number of failed verification attempts so far.
    pub attempts: u32,
}
898
/// Failed verification attempts after which a code is locked
/// (`MagicCodeError::TooManyAttempts`).
const MAX_ATTEMPTS: u32 = 5;

/// Minimum number of seconds between two code requests for the same email
/// while the earlier code is still live.
const CREATE_COOLDOWN_SECS: u64 = 60;
907
/// Reasons a magic-code request or verification can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is cached for the email.
    NotFound,
    /// The code has reached `MAX_ATTEMPTS` failed attempts.
    TooManyAttempts,
    /// The supplied code did not match.
    BadCode,
    /// The code's expiry has passed; the entry has been removed.
    Expired,
    /// A live code was issued less than `CREATE_COOLDOWN_SECS` ago;
    /// `retry_after_secs` is the remaining wait.
    Throttled { retry_after_secs: u64 },
}
921
922impl Default for MagicCodeStore {
923 fn default() -> Self {
924 Self::new()
925 }
926}
927
928impl MagicCodeStore {
929 pub fn new() -> Self {
930 Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
931 }
932
933 pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
938 let now = now_secs();
939 let mut cache = HashMap::new();
940 for c in backend.load_all() {
941 if c.expires_at > now {
942 cache.insert(c.email.clone(), c);
943 }
944 }
945 Self {
946 cache: Mutex::new(cache),
947 backend,
948 }
949 }
950
951 pub fn create(&self, email: &str) -> String {
954 self.try_create(email).unwrap_or_else(|_| String::new())
957 }
958
959 pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
962 let now = now_secs();
963
964 let mut codes = self.cache.lock().unwrap();
965
966 if let Some(existing) = codes.get(email) {
970 if existing.expires_at > now {
971 let created_at = existing.expires_at.saturating_sub(600);
972 let age = now.saturating_sub(created_at);
973 if age < CREATE_COOLDOWN_SECS {
974 return Err(MagicCodeError::Throttled {
975 retry_after_secs: CREATE_COOLDOWN_SECS - age,
976 });
977 }
978 }
979 }
980
981 let code = generate_magic_code();
982 let mc = MagicCode {
983 email: email.to_string(),
984 code: code.clone(),
985 expires_at: now + 600, attempts: 0,
987 };
988 codes.insert(email.to_string(), mc.clone());
989 self.backend.put(email, &mc);
993 Ok(code)
994 }
995
996 pub fn verify(&self, email: &str, code: &str) -> bool {
1000 matches!(self.try_verify(email, code), Ok(()))
1001 }
1002
1003 pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1007 let now = now_secs();
1008 let mut codes = self.cache.lock().unwrap();
1009
1010 let mc = match codes.get_mut(email) {
1011 Some(m) => m,
1012 None => return Err(MagicCodeError::NotFound),
1013 };
1014
1015 if mc.attempts >= MAX_ATTEMPTS {
1016 return Err(MagicCodeError::TooManyAttempts);
1017 }
1018 if mc.expires_at <= now {
1019 codes.remove(email);
1020 self.backend.remove(email);
1021 return Err(MagicCodeError::Expired);
1022 }
1023
1024 let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1025 if !ok {
1026 mc.attempts += 1;
1027 self.backend.bump_attempts(email);
1028 if mc.attempts >= MAX_ATTEMPTS {
1030 return Err(MagicCodeError::TooManyAttempts);
1031 }
1032 return Err(MagicCodeError::BadCode);
1033 }
1034
1035 codes.remove(email);
1037 self.backend.remove(email);
1038 Ok(())
1039 }
1040}
1041
/// Renders `bytes` as lower-case hexadecimal, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX_LOWER: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(HEX_LOWER[usize::from(byte >> 4)]));
        out.push(char::from(HEX_LOWER[usize::from(byte & 0x0f)]));
    }
    out
}
1049
1050fn generate_magic_code() -> String {
1052 use rand::Rng;
1053 let mut rng = rand::thread_rng();
1054 let code: u32 = rng.gen_range(0..1_000_000);
1055 format!("{:06}", code)
1056}
1057
1058fn generate_token() -> String {
1060 use rand::Rng;
1061 let mut rng = rand::thread_rng();
1062 let bytes: [u8; 32] = rng.gen();
1063 format!("pylon_{}", hex_encode(&bytes))
1064}
1065
1066use std::collections::HashMap;
1071use std::sync::Mutex;
1072
/// Persistence layer for sessions.
///
/// `SessionStore` keeps the working set in memory and mirrors changes to
/// the backend through these methods.
pub trait SessionBackend: Send + Sync {
    /// Returns every stored session; used once by
    /// `SessionStore::with_backend` to fill the in-memory map.
    fn load_all(&self) -> Vec<Session>;
    /// Stores `session`, creating or replacing the record for its token.
    fn save(&self, session: &Session);
    /// Deletes the session stored under `token`.
    fn remove(&self, token: &str);
}
1081
/// In-memory session table with optional write-through persistence.
pub struct SessionStore {
    /// Live sessions keyed by token.
    sessions: Mutex<HashMap<String, Session>>,
    /// Optional persistence layer; `None` for a purely in-memory store.
    backend: Option<Box<dyn SessionBackend>>,
}
1093
1094impl Default for SessionStore {
1095 fn default() -> Self {
1096 Self::new()
1097 }
1098}
1099
1100impl SessionStore {
1101 pub fn new() -> Self {
1102 Self {
1103 sessions: Mutex::new(HashMap::new()),
1104 backend: None,
1105 }
1106 }
1107
1108 pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1112 let mut map = HashMap::new();
1113 for s in backend.load_all() {
1114 if !s.is_expired() {
1115 map.insert(s.token.clone(), s);
1116 }
1117 }
1118 Self {
1119 sessions: Mutex::new(map),
1120 backend: Some(backend),
1121 }
1122 }
1123
1124 pub fn create(&self, user_id: String) -> Session {
1126 let session = Session::new(user_id);
1127 let mut sessions = self.sessions.lock().unwrap();
1128 sessions.insert(session.token.clone(), session.clone());
1129 if let Some(b) = &self.backend {
1130 b.save(&session);
1131 }
1132 session
1133 }
1134
1135 pub fn get(&self, token: &str) -> Option<Session> {
1137 let mut sessions = self.sessions.lock().unwrap();
1138 match sessions.get(token) {
1139 Some(s) if s.is_expired() => {
1140 sessions.remove(token);
1141 None
1142 }
1143 Some(s) => Some(s.clone()),
1144 None => None,
1145 }
1146 }
1147
1148 pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1151 match token {
1152 Some(t) => match self.get(t) {
1153 Some(session) => session.to_auth_context(),
1154 None => AuthContext::anonymous(),
1155 },
1156 None => AuthContext::anonymous(),
1157 }
1158 }
1159
1160 pub fn refresh(&self, old_token: &str) -> Option<Session> {
1164 let mut sessions = self.sessions.lock().unwrap();
1165 let old = sessions.remove(old_token)?;
1166 if let Some(b) = &self.backend {
1167 b.remove(old_token);
1168 }
1169 if old.is_expired() {
1170 return None;
1171 }
1172 let mut new = Session::new(old.user_id.clone());
1173 new.device = old.device.clone();
1174 sessions.insert(new.token.clone(), new.clone());
1175 if let Some(b) = &self.backend {
1176 b.save(&new);
1177 }
1178 Some(new)
1179 }
1180
1181 pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
1183 let sessions = self.sessions.lock().unwrap();
1184 sessions
1185 .values()
1186 .filter(|s| s.user_id == user_id && !s.is_expired())
1187 .cloned()
1188 .collect()
1189 }
1190
1191 pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
1193 let mut sessions = self.sessions.lock().unwrap();
1194 let tokens: Vec<String> = sessions
1195 .iter()
1196 .filter_map(|(t, s)| {
1197 if s.user_id == user_id {
1198 Some(t.clone())
1199 } else {
1200 None
1201 }
1202 })
1203 .collect();
1204 let n = tokens.len();
1205 for t in &tokens {
1206 sessions.remove(t);
1207 if let Some(b) = &self.backend {
1208 b.remove(t);
1209 }
1210 }
1211 n
1212 }
1213
1214 pub fn sweep_expired(&self) -> usize {
1216 let mut sessions = self.sessions.lock().unwrap();
1217 let expired: Vec<String> = sessions
1218 .iter()
1219 .filter_map(|(t, s)| {
1220 if s.is_expired() {
1221 Some(t.clone())
1222 } else {
1223 None
1224 }
1225 })
1226 .collect();
1227 let n = expired.len();
1228 for t in &expired {
1229 sessions.remove(t);
1230 if let Some(b) = &self.backend {
1231 b.remove(t);
1232 }
1233 }
1234 n
1235 }
1236
1237 pub fn set_device(&self, token: &str, device: String) -> bool {
1239 let mut sessions = self.sessions.lock().unwrap();
1240 if let Some(s) = sessions.get_mut(token) {
1241 s.device = Some(device);
1242 if let Some(b) = &self.backend {
1243 b.save(s);
1244 }
1245 true
1246 } else {
1247 false
1248 }
1249 }
1250
1251 pub fn create_guest(&self) -> Session {
1253 use rand::Rng;
1254 let mut rng = rand::thread_rng();
1255 let bytes: [u8; 16] = rng.gen();
1256 let guest_id = format!("guest_{}", hex_encode(&bytes));
1257 self.create(guest_id)
1258 }
1259
1260 pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1262 let mut sessions = self.sessions.lock().unwrap();
1263 if let Some(session) = sessions.get_mut(token) {
1264 session.user_id = real_user_id;
1265 if let Some(b) = &self.backend {
1266 b.save(session);
1267 }
1268 true
1269 } else {
1270 false
1271 }
1272 }
1273
1274 pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1279 let mut sessions = self.sessions.lock().unwrap();
1280 if let Some(session) = sessions.get_mut(token) {
1281 session.tenant_id = tenant_id;
1282 if let Some(b) = &self.backend {
1283 b.save(session);
1284 }
1285 true
1286 } else {
1287 false
1288 }
1289 }
1290
1291 pub fn revoke(&self, token: &str) -> bool {
1293 let mut sessions = self.sessions.lock().unwrap();
1294 let removed = sessions.remove(token).is_some();
1295 if removed {
1296 if let Some(b) = &self.backend {
1297 b.remove(token);
1298 }
1299 }
1300 removed
1301 }
1302}
1303
/// A provider account linked to a local user, with the tokens obtained for it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Account {
    /// Locally generated identifier for this link record.
    pub id: String,
    /// Local user the account is linked to.
    pub user_id: String,
    /// Provider name (taken from `UserInfo::provider`).
    pub provider_id: String,
    /// Provider-side account identifier (taken from
    /// `UserInfo::provider_account_id`).
    pub account_id: String,
    /// Most recent access token, if any.
    pub access_token: Option<String>,
    /// Refresh token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// ID token, if the provider issued one.
    pub id_token: Option<String>,
    /// Access-token expiry in Unix seconds; `None` is treated as
    /// "not expired" by `access_token_expired`.
    pub access_token_expires_at: Option<u64>,
    /// Refresh-token expiry in Unix seconds; `Account::new` leaves it `None`.
    pub refresh_token_expires_at: Option<u64>,
    /// Granted scope string, if the provider reported one.
    pub scope: Option<String>,
    /// Password slot; `Account::new` leaves it `None`.
    pub password: Option<String>,
    /// Creation time in Unix seconds.
    pub created_at: u64,
    /// Last-update time in Unix seconds.
    pub updated_at: u64,
}
1359
1360impl Account {
1361 pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
1365 let now = now_secs();
1366 Self {
1367 id: generate_token(),
1368 user_id,
1369 provider_id: info.provider.clone(),
1370 account_id: info.provider_account_id.clone(),
1371 access_token: Some(tokens.access_token.clone()),
1372 refresh_token: tokens.refresh_token.clone(),
1373 id_token: tokens.id_token.clone(),
1374 access_token_expires_at: tokens.expires_at,
1375 refresh_token_expires_at: None,
1376 scope: tokens.scope.clone(),
1377 password: None,
1378 created_at: now,
1379 updated_at: now,
1380 }
1381 }
1382
1383 pub fn access_token_expired(&self) -> bool {
1388 match self.access_token_expires_at {
1389 Some(ts) => now_secs() >= ts,
1390 None => false,
1391 }
1392 }
1393}
1394
/// Persistence layer for linked provider accounts.
pub trait AccountBackend: Send + Sync {
    /// Inserts `account`, or replaces the stored record for the same
    /// provider/account pair.
    fn upsert(&self, account: &Account);
    /// Looks up the account for a provider name and provider-side account id.
    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
    /// Returns all accounts linked to `user_id`.
    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
    /// Removes the account for the provider/account pair; returns whether a
    /// record was removed.
    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
}
1413
1414pub struct InMemoryAccountBackend {
1418 accounts: Mutex<HashMap<(String, String), Account>>,
1422}
1423
1424impl InMemoryAccountBackend {
1425 pub fn new() -> Self {
1426 Self {
1427 accounts: Mutex::new(HashMap::new()),
1428 }
1429 }
1430}
1431
1432impl Default for InMemoryAccountBackend {
1433 fn default() -> Self {
1434 Self::new()
1435 }
1436}
1437
1438impl AccountBackend for InMemoryAccountBackend {
1439 fn upsert(&self, account: &Account) {
1440 let key = (account.provider_id.clone(), account.account_id.clone());
1441 self.accounts.lock().unwrap().insert(key, account.clone());
1442 }
1443 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1444 self.accounts
1445 .lock()
1446 .unwrap()
1447 .get(&(provider_id.to_string(), account_id.to_string()))
1448 .cloned()
1449 }
1450 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1451 self.accounts
1452 .lock()
1453 .unwrap()
1454 .values()
1455 .filter(|a| a.user_id == user_id)
1456 .cloned()
1457 .collect()
1458 }
1459 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1460 self.accounts
1461 .lock()
1462 .unwrap()
1463 .remove(&(provider_id.to_string(), account_id.to_string()))
1464 .is_some()
1465 }
1466}
1467
/// Facade over an `AccountBackend`; every method forwards directly to the
/// backend.
pub struct AccountStore {
    /// Backend that actually stores the accounts.
    backend: Box<dyn AccountBackend>,
}

impl Default for AccountStore {
    /// Same as [`AccountStore::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl AccountStore {
    /// Creates a store backed by an `InMemoryAccountBackend`.
    pub fn new() -> Self {
        Self {
            backend: Box::new(InMemoryAccountBackend::new()),
        }
    }
    /// Creates a store on top of a caller-supplied backend.
    pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
        Self { backend }
    }
    /// Forwards to `AccountBackend::upsert`.
    pub fn upsert(&self, account: &Account) {
        self.backend.upsert(account);
    }
    /// Forwards to `AccountBackend::find_by_provider`.
    pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
        self.backend.find_by_provider(provider_id, account_id)
    }
    /// Forwards to `AccountBackend::find_for_user`.
    pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
        self.backend.find_for_user(user_id)
    }
    /// Forwards to `AccountBackend::unlink`.
    pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
        self.backend.unlink(provider_id, account_id)
    }
}
1502
1503#[cfg(test)]
1508mod tests {
1509 use super::*;
1510
    // An anonymous context has no user id and is not authenticated.
    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }
1517
    // An authenticated context carries the given user id.
    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }
1524
    // `Public` accepts both anonymous and authenticated contexts.
    #[test]
    fn auth_mode_public_allows_anonymous() {
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }
1531
    // `User` rejects anonymous contexts and accepts authenticated ones.
    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }
1538
    // Only "public" and "user" parse; anything else (here "admin") is `None`.
    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        assert_eq!(AuthMode::from_str("admin"), None);
    }
1545
1546 #[test]
1547 fn session_store_create_and_get() {
1548 let store = SessionStore::new();
1549 let session = store.create("user-1".into());
1550 assert!(!session.token.is_empty());
1551 assert!(session.token.starts_with("pylon_"));
1552
1553 let retrieved = store.get(&session.token).unwrap();
1554 assert_eq!(retrieved.user_id, "user-1");
1555 }
1556
1557 #[test]
1558 fn session_store_resolve() {
1559 let store = SessionStore::new();
1560 let session = store.create("user-1".into());
1561
1562 let ctx = store.resolve(Some(&session.token));
1563 assert!(ctx.is_authenticated());
1564 assert_eq!(ctx.user_id, Some("user-1".into()));
1565
1566 let anon = store.resolve(None);
1567 assert!(!anon.is_authenticated());
1568
1569 let bad = store.resolve(Some("invalid-token"));
1570 assert!(!bad.is_authenticated());
1571 }
1572
1573 #[test]
1574 fn session_store_revoke() {
1575 let store = SessionStore::new();
1576 let session = store.create("user-1".into());
1577
1578 assert!(store.revoke(&session.token));
1579 assert!(store.get(&session.token).is_none());
1580 assert!(!store.revoke(&session.token)); }
1582
1583 #[test]
1584 fn session_to_auth_context() {
1585 let session = Session::new("user-42".into());
1586 let ctx = session.to_auth_context();
1587 assert_eq!(ctx.user_id, Some("user-42".into()));
1588 }
1589
1590 #[test]
1593 fn admin_context() {
1594 let ctx = AuthContext::admin();
1595 assert!(ctx.is_admin);
1596 assert!(ctx.is_authenticated());
1597 }
1598
1599 #[test]
1600 fn anonymous_not_admin() {
1601 let ctx = AuthContext::anonymous();
1602 assert!(!ctx.is_admin);
1603 }
1604
1605 #[test]
1606 fn authenticated_not_admin() {
1607 let ctx = AuthContext::authenticated("user-1".into());
1608 assert!(!ctx.is_admin);
1609 }
1610
1611 #[test]
1614 fn magic_code_create_and_verify() {
1615 let store = MagicCodeStore::new();
1616 let code = store.create("test@example.com");
1617 assert_eq!(code.len(), 6);
1618 assert!(store.verify("test@example.com", &code));
1619 }
1620
1621 #[test]
1622 fn magic_code_wrong_code_rejected() {
1623 let store = MagicCodeStore::new();
1624 store.create("test@example.com");
1625 assert!(!store.verify("test@example.com", "000000"));
1626 }
1627
1628 #[test]
1629 fn magic_code_wrong_email_rejected() {
1630 let store = MagicCodeStore::new();
1631 let code = store.create("test@example.com");
1632 assert!(!store.verify("other@example.com", &code));
1633 }
1634
1635 #[test]
1636 fn magic_code_consumed_after_verify() {
1637 let store = MagicCodeStore::new();
1638 let code = store.create("test@example.com");
1639 assert!(store.verify("test@example.com", &code));
1640 assert!(!store.verify("test@example.com", &code));
1642 }
1643
1644 #[test]
1645 fn magic_code_different_emails_independent() {
1646 let store = MagicCodeStore::new();
1647 let code1 = store.create("alice@example.com");
1648 let code2 = store.create("bob@example.com");
1649 assert!(store.verify("alice@example.com", &code1));
1651 assert!(store.verify("bob@example.com", &code2));
1652 }
1653
1654 #[test]
1657 fn constant_time_eq_equal() {
1658 assert!(constant_time_eq(b"hello", b"hello"));
1659 assert!(constant_time_eq(b"", b""));
1660 }
1661
1662 #[test]
1663 fn constant_time_eq_not_equal() {
1664 assert!(!constant_time_eq(b"hello", b"world"));
1665 assert!(!constant_time_eq(b"hello", b"hell"));
1666 assert!(!constant_time_eq(b"a", b"b"));
1667 }
1668
1669 #[test]
1672 fn generated_tokens_are_unique() {
1673 let t1 = generate_token();
1674 let t2 = generate_token();
1675 assert_ne!(t1, t2);
1676 assert!(t1.starts_with("pylon_"));
1677 assert!(t2.starts_with("pylon_"));
1678 assert_eq!(t1.len(), 6 + 64);
1680 }
1681
1682 #[test]
1685 fn oauth_registry_empty() {
1686 let reg = OAuthRegistry::new();
1687 assert!(reg.get("google").is_none());
1688 }
1689
1690 #[test]
1691 fn oauth_registry_register_and_get() {
1692 let mut reg = OAuthRegistry::new();
1693 reg.register(OAuthConfig {
1694 provider: "google".into(),
1695 client_id: "test-id".into(),
1696 client_secret: "test-secret".into(),
1697 redirect_uri: "http://localhost/callback".into(),
1698 });
1699 let config = reg.get("google").unwrap();
1700 assert_eq!(config.client_id, "test-id");
1701 assert!(config.auth_url().contains("accounts.google.com"));
1702 }
1703
1704 #[test]
1707 fn guest_session() {
1708 let store = SessionStore::new();
1709 let session = store.create_guest();
1710 assert!(session.user_id.starts_with("guest_"));
1711 assert!(!session.token.is_empty());
1712
1713 let ctx = store.resolve(Some(&session.token));
1714 assert!(ctx.is_authenticated());
1715 assert!(ctx.user_id.unwrap().starts_with("guest_"));
1716 }
1717
1718 #[test]
1719 fn upgrade_guest_to_real_user() {
1720 let store = SessionStore::new();
1721 let session = store.create_guest();
1722 assert!(session.user_id.starts_with("guest_"));
1723
1724 let upgraded = store.upgrade(&session.token, "real-user-123".into());
1725 assert!(upgraded);
1726
1727 let ctx = store.resolve(Some(&session.token));
1728 assert_eq!(ctx.user_id, Some("real-user-123".into()));
1729 }
1730
1731 #[test]
1732 fn upgrade_invalid_token_fails() {
1733 let store = SessionStore::new();
1734 let upgraded = store.upgrade("nonexistent-token", "user".into());
1735 assert!(!upgraded);
1736 }
1737
1738 #[test]
1739 fn guest_context() {
1740 let ctx = AuthContext::guest("guest_123".into());
1741 assert!(!ctx.is_authenticated());
1744 assert!(ctx.is_guest);
1745 assert!(!ctx.is_admin);
1746 assert_eq!(ctx.user_id, Some("guest_123".into()));
1747 assert!(!AuthMode::User.check(&ctx));
1748 assert!(AuthMode::Public.check(&ctx));
1749 }
1750
1751 #[test]
1752 fn oauth_token_urls() {
1753 let google = OAuthConfig {
1754 provider: "google".into(),
1755 client_id: "x".into(),
1756 client_secret: "x".into(),
1757 redirect_uri: "x".into(),
1758 };
1759 assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
1760 let github = OAuthConfig {
1761 provider: "github".into(),
1762 client_id: "x".into(),
1763 client_secret: "x".into(),
1764 redirect_uri: "x".into(),
1765 };
1766 assert_eq!(
1767 github.token_url(),
1768 "https://github.com/login/oauth/access_token"
1769 );
1770 let unknown = OAuthConfig {
1771 provider: "unknown".into(),
1772 client_id: "x".into(),
1773 client_secret: "x".into(),
1774 redirect_uri: "x".into(),
1775 };
1776 assert_eq!(unknown.token_url(), "");
1777 assert!(unknown.auth_url().is_empty());
1778 }
1779
1780 #[test]
1781 fn oauth_auth_url_github() {
1782 let config = OAuthConfig {
1783 provider: "github".into(),
1784 client_id: "gh-id".into(),
1785 client_secret: "gh-secret".into(),
1786 redirect_uri: "http://localhost/cb".into(),
1787 };
1788 assert!(config.auth_url().contains("github.com"));
1789 assert!(config.auth_url().contains("gh-id"));
1790 }
1791
1792 #[test]
1793 fn oauth_auth_url_with_state() {
1794 let config = OAuthConfig {
1795 provider: "google".into(),
1796 client_id: "test-id".into(),
1797 client_secret: "test-secret".into(),
1798 redirect_uri: "http://localhost/cb".into(),
1799 };
1800 let url = config.auth_url_with_state("random_state_123");
1801 assert!(url.contains("&state=random_state_123"));
1802 }
1803
1804 #[test]
1805 fn oauth_state_store_create_and_validate() {
1806 let store = OAuthStateStore::new();
1807 let state = store.create("google");
1808 assert!(store.validate(&state, "google"));
1809 assert!(!store.validate(&state, "google"));
1811 }
1812
1813 #[test]
1814 fn oauth_state_store_wrong_provider_rejected() {
1815 let store = OAuthStateStore::new();
1816 let state = store.create("google");
1817 assert!(!store.validate(&state, "github"));
1818 }
1819
1820 #[test]
1821 fn oauth_state_store_invalid_state_rejected() {
1822 let store = OAuthStateStore::new();
1823 assert!(!store.validate("nonexistent", "google"));
1824 }
1825}