1pub mod cookie;
2pub mod email;
3pub mod password;
4
5pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
6
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
24pub struct AuthContext {
25 pub user_id: Option<String>,
29 pub is_admin: bool,
31 #[serde(default, skip_serializing_if = "is_false")]
36 pub is_guest: bool,
37 pub roles: Vec<String>,
39 #[serde(skip_serializing_if = "Option::is_none")]
42 pub tenant_id: Option<String>,
43}
44
/// `skip_serializing_if` predicate for boolean fields: true when the flag is unset.
fn is_false(flag: &bool) -> bool {
    !*flag
}
48
49impl AuthContext {
50 pub fn anonymous() -> Self {
52 Self {
53 user_id: None,
54 is_admin: false,
55 is_guest: false,
56 roles: Vec::new(),
57 tenant_id: None,
58 }
59 }
60
61 pub fn authenticated(user_id: String) -> Self {
63 Self {
64 user_id: Some(user_id),
65 is_admin: false,
66 is_guest: false,
67 roles: Vec::new(),
68 tenant_id: None,
69 }
70 }
71
72 pub fn guest(guest_id: String) -> Self {
77 Self {
78 user_id: Some(guest_id),
79 is_admin: false,
80 is_guest: true,
81 roles: Vec::new(),
82 tenant_id: None,
83 }
84 }
85
86 pub fn admin() -> Self {
88 Self {
89 user_id: Some("__admin__".into()),
90 is_admin: true,
91 is_guest: false,
92 roles: vec!["admin".into()],
93 tenant_id: None,
94 }
95 }
96
97 pub fn user(user_id: String) -> Self {
99 Self::authenticated(user_id)
100 }
101
102 pub fn tenant_id(&self) -> Option<&str> {
104 self.tenant_id.as_deref()
105 }
106
107 pub fn with_tenant(mut self, tenant_id: String) -> Self {
109 self.tenant_id = Some(tenant_id);
110 self
111 }
112
113 pub fn is_authenticated(&self) -> bool {
117 self.user_id.is_some() && !self.is_guest
118 }
119
120 pub fn has_role(&self, role: &str) -> bool {
122 self.is_admin || self.roles.iter().any(|r| r == role)
123 }
124
125 pub fn has_any_role(&self, roles: &[&str]) -> bool {
127 self.is_admin || roles.iter().any(|r| self.has_role(r))
128 }
129
130 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
132 self.roles = roles;
133 self
134 }
135}
136
/// Compares two byte slices for equality without stopping at the first mismatch.
///
/// Slices of different length are unequal immediately, so length is not treated
/// as secret. Otherwise every byte pair is XORed and OR-accumulated, so the work
/// done does not depend on where the slices differ.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
156
157#[derive(Debug, Clone, PartialEq, Eq)]
163pub enum AuthMode {
164 Public,
166 User,
168}
169
170impl AuthMode {
171 #[allow(clippy::should_implement_trait)]
173 pub fn from_str(s: &str) -> Option<Self> {
174 match s {
175 "public" => Some(AuthMode::Public),
176 "user" => Some(AuthMode::User),
177 _ => None,
178 }
179 }
180
181 pub fn check(&self, ctx: &AuthContext) -> bool {
183 match self {
184 AuthMode::Public => true,
185 AuthMode::User => ctx.is_authenticated(),
186 }
187 }
188}
189
190#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
196pub struct Session {
197 pub token: String,
198 pub user_id: String,
199 #[serde(default)]
201 pub expires_at: u64,
202 #[serde(default, skip_serializing_if = "Option::is_none")]
204 pub device: Option<String>,
205 #[serde(default)]
207 pub created_at: u64,
208 #[serde(default, skip_serializing_if = "Option::is_none")]
212 pub tenant_id: Option<String>,
213}
214
215impl Session {
216 pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
218
219 pub fn new(user_id: String) -> Self {
221 let now = now_secs();
222 Self {
223 token: generate_token(),
224 user_id,
225 expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
226 device: None,
227 created_at: now,
228 tenant_id: None,
229 }
230 }
231
232 pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
234 let now = now_secs();
235 Self {
236 token: generate_token(),
237 user_id,
238 expires_at: if lifetime_secs == 0 {
239 0
240 } else {
241 now.saturating_add(lifetime_secs)
242 },
243 device: None,
244 created_at: now,
245 tenant_id: None,
246 }
247 }
248
249 pub fn to_auth_context(&self) -> AuthContext {
252 let ctx = AuthContext::authenticated(self.user_id.clone());
253 match &self.tenant_id {
254 Some(t) => ctx.with_tenant(t.clone()),
255 None => ctx,
256 }
257 }
258
259 pub fn is_expired(&self) -> bool {
263 self.expires_at != 0 && now_secs() >= self.expires_at
264 }
265}
266
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct OAuthConfig {
281 pub provider: String,
282 pub client_id: String,
283 pub client_secret: String,
284 pub redirect_uri: String,
285}
286
287impl OAuthConfig {
288 pub fn auth_url(&self) -> String {
294 match self.provider.as_str() {
295 "google" => format!(
296 "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
297 self.client_id, self.redirect_uri
298 ),
299 "github" => format!(
300 "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
301 self.client_id, self.redirect_uri
302 ),
303 _ => String::new(),
304 }
305 }
306
307 pub fn auth_url_with_state(&self, state: &str) -> String {
309 let base = self.auth_url();
310 if base.is_empty() {
311 return base;
312 }
313 format!("{}&state={}", base, state)
314 }
315
316 pub fn token_url(&self) -> &str {
318 match self.provider.as_str() {
319 "google" => "https://oauth2.googleapis.com/token",
320 "github" => "https://github.com/login/oauth/access_token",
321 _ => "",
322 }
323 }
324
325 pub fn userinfo_url(&self) -> &str {
327 match self.provider.as_str() {
328 "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
329 "github" => "https://api.github.com/user",
330 _ => "",
331 }
332 }
333
334 pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
341 let body = match self.provider.as_str() {
342 "google" => format!(
343 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
344 url_encode(&self.client_id),
345 url_encode(&self.client_secret),
346 url_encode(&self.redirect_uri)
347 ),
348 "github" => format!(
349 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
350 url_encode(&self.client_id),
351 url_encode(&self.client_secret),
352 url_encode(&self.redirect_uri)
353 ),
354 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
355 };
356
357 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
358 parse_token_response(&out)
359 }
360
361 pub fn exchange_code(&self, code: &str) -> Result<String, String> {
365 let body = match self.provider.as_str() {
366 "google" => format!(
367 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
368 url_encode(&self.client_id),
369 url_encode(&self.client_secret),
370 url_encode(&self.redirect_uri)
371 ),
372 "github" => format!(
373 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
374 url_encode(&self.client_id),
375 url_encode(&self.client_secret),
376 url_encode(&self.redirect_uri)
377 ),
378 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
379 };
380
381 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
382 extract_access_token(&out)
383 }
384
385 pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
391 let info = self.fetch_userinfo_full(access_token)?;
392 Ok((info.email, info.name))
393 }
394
395 pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
401 let out = http_get_bearer(self.userinfo_url(), access_token)?;
402 let parsed: serde_json::Value =
403 serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
404 match self.provider.as_str() {
405 "google" => {
406 let email = parsed
407 .get("email")
408 .and_then(|v| v.as_str())
409 .ok_or("no email in userinfo")?
410 .to_string();
411 let name = parsed
412 .get("name")
413 .and_then(|v| v.as_str())
414 .map(String::from);
415 let provider_account_id = parsed
416 .get("sub")
417 .and_then(|v| v.as_str())
418 .ok_or("no sub in userinfo")?
419 .to_string();
420 Ok(UserInfo {
421 provider: self.provider.clone(),
422 provider_account_id,
423 email,
424 name,
425 })
426 }
427 "github" => {
428 let name = parsed
429 .get("name")
430 .and_then(|v| v.as_str())
431 .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
432 .map(String::from);
433 let email = parsed
434 .get("email")
435 .and_then(|v| v.as_str())
436 .map(String::from);
437 let email = email
440 .or_else(|| fetch_github_primary_email(access_token).ok())
441 .ok_or("no accessible email on GitHub account")?;
442 let provider_account_id = parsed
445 .get("id")
446 .map(|v| {
447 v.as_i64()
448 .map(|n| n.to_string())
449 .or_else(|| v.as_str().map(String::from))
450 .unwrap_or_default()
451 })
452 .filter(|s| !s.is_empty())
453 .ok_or("no id in userinfo")?;
454 Ok(UserInfo {
455 provider: self.provider.clone(),
456 provider_account_id,
457 email,
458 name,
459 })
460 }
461 _ => Err(format!("unknown provider: {}", self.provider)),
462 }
463 }
464}
465
/// Normalised profile returned by `OAuthConfig::fetch_userinfo_full`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Provider name, e.g. "google" or "github".
    pub provider: String,
    /// Stable account id at the provider (Google `sub`, GitHub `id`).
    pub provider_account_id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Display name, when the provider supplies one.
    pub name: Option<String>,
}
477
/// Tokens obtained from a provider's token endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSet {
    /// Bearer token for provider API calls.
    pub access_token: String,
    /// Refresh token, when the provider issued one.
    pub refresh_token: Option<String>,
    /// OpenID Connect ID token, when present.
    pub id_token: Option<String>,
    /// Absolute Unix seconds, computed from `expires_in` when the response was parsed.
    pub expires_at: Option<u64>,
    /// Scope string as returned by the provider.
    pub scope: Option<String>,
}
492
493fn parse_token_response(body: &str) -> Result<TokenSet, String> {
494 let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
497 let mut map = serde_json::Map::new();
499 for pair in body.split('&') {
500 if let Some((k, v)) = pair.split_once('=') {
501 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
502 }
503 }
504 serde_json::Value::Object(map)
505 });
506
507 let access_token = json
508 .get("access_token")
509 .and_then(|v| v.as_str())
510 .ok_or_else(|| format!("no access_token in token response: {body}"))?
511 .to_string();
512 let refresh_token = json
513 .get("refresh_token")
514 .and_then(|v| v.as_str())
515 .map(String::from);
516 let id_token = json
517 .get("id_token")
518 .and_then(|v| v.as_str())
519 .map(String::from);
520 let expires_at = json
521 .get("expires_in")
522 .and_then(|v| {
523 v.as_u64()
524 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
525 })
526 .map(|secs| now_secs().saturating_add(secs));
527 let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
528 Ok(TokenSet {
529 access_token,
530 refresh_token,
531 id_token,
532 expires_at,
533 scope,
534 })
535}
536
/// Percent-encodes `s` for a URL query component or form body.
///
/// Only RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other byte, including each byte of a multi-byte UTF-8
/// sequence, becomes `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    use std::fmt::Write as _;
    s.bytes().fold(String::with_capacity(s.len()), |mut out, byte| {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(out, "%{byte:02X}");
        }
        out
    })
}
549
550const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
554
555fn ureq_agent() -> ureq::Agent {
556 ureq::AgentBuilder::new()
557 .timeout_connect(HTTP_TIMEOUT)
558 .timeout_read(HTTP_TIMEOUT)
559 .timeout_write(HTTP_TIMEOUT)
560 .user_agent("pylon/0.1")
561 .build()
562}
563
564fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
565 let agent = ureq_agent();
566 let mut req = agent
567 .post(url)
568 .set("Content-Type", "application/x-www-form-urlencoded");
569 if accept_json {
570 req = req.set("Accept", "application/json");
571 }
572 match req.send_string(body) {
573 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
574 Err(ureq::Error::Status(code, resp)) => {
575 let body = resp.into_string().unwrap_or_default();
576 Err(format!("HTTP {code}: {body}"))
577 }
578 Err(e) => Err(format!("HTTP error: {e}")),
579 }
580}
581
582fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
583 let agent = ureq_agent();
584 match agent
585 .get(url)
586 .set("Authorization", &format!("Bearer {token}"))
587 .set("Accept", "application/json")
588 .call()
589 {
590 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
591 Err(ureq::Error::Status(code, resp)) => {
592 let body = resp.into_string().unwrap_or_default();
593 Err(format!("HTTP {code}: {body}"))
594 }
595 Err(e) => Err(format!("HTTP error: {e}")),
596 }
597}
598
599fn fetch_github_primary_email(token: &str) -> Result<String, String> {
600 let out = http_get_bearer("https://api.github.com/user/emails", token)?;
601 let emails: serde_json::Value =
602 serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
603 emails
604 .as_array()
605 .and_then(|arr| {
606 arr.iter()
607 .find(|e| {
608 e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
609 && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
610 })
611 .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
612 })
613 .ok_or_else(|| "no primary verified email on GitHub".into())
614}
615
616fn extract_access_token(body: &str) -> Result<String, String> {
617 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
618 if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
619 return Ok(t.to_string());
620 }
621 }
622 for pair in body.split('&') {
624 if let Some(val) = pair.strip_prefix("access_token=") {
625 return Ok(val.to_string());
626 }
627 }
628 Err(format!("no access_token in token response: {body}"))
629}
630
631pub struct OAuthRegistry {
633 providers: std::collections::HashMap<String, OAuthConfig>,
634}
635
636impl Default for OAuthRegistry {
637 fn default() -> Self {
638 Self::new()
639 }
640}
641
642impl OAuthRegistry {
643 pub fn new() -> Self {
644 Self {
645 providers: std::collections::HashMap::new(),
646 }
647 }
648
649 pub fn register(&mut self, config: OAuthConfig) {
650 self.providers.insert(config.provider.clone(), config);
651 }
652
653 pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
654 self.providers.get(provider)
655 }
656
657 pub fn from_env() -> Self {
660 let mut reg = Self::new();
661
662 if let (Ok(id), Ok(secret)) = (
664 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
665 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
666 ) {
667 reg.register(OAuthConfig {
668 provider: "google".into(),
669 client_id: id,
670 client_secret: secret,
671 redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
672 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
673 });
674 }
675
676 if let (Ok(id), Ok(secret)) = (
678 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
679 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
680 ) {
681 reg.register(OAuthConfig {
682 provider: "github".into(),
683 client_id: id,
684 client_secret: secret,
685 redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
686 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
687 });
688 }
689
690 reg
691 }
692}
693
/// Storage for pending OAuth `state` tokens issued by `OAuthStateStore::create`.
///
/// `OAuthStateStore` calls `take` once per validation, so implementations
/// should make each token single-use.
pub trait OAuthStateBackend: Send + Sync {
    /// Records `token` as issued for `provider`, valid until `expires_at` (Unix seconds).
    fn put(&self, token: &str, provider: &str, expires_at: u64);

    /// Removes `token` and returns its provider if it existed and was still
    /// valid at `now_unix_secs`.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String>;
}
709
710pub struct InMemoryOAuthBackend {
712 states: Mutex<HashMap<String, OAuthState>>,
713}
714
715impl InMemoryOAuthBackend {
716 pub fn new() -> Self {
717 Self {
718 states: Mutex::new(HashMap::new()),
719 }
720 }
721}
722
723impl Default for InMemoryOAuthBackend {
724 fn default() -> Self {
725 Self::new()
726 }
727}
728
729impl OAuthStateBackend for InMemoryOAuthBackend {
730 fn put(&self, token: &str, provider: &str, expires_at: u64) {
731 self.states.lock().unwrap().insert(
732 token.to_string(),
733 OAuthState {
734 provider: provider.to_string(),
735 expires_at,
736 },
737 );
738 }
739 fn take(&self, token: &str, now_unix_secs: u64) -> Option<String> {
740 let mut s = self.states.lock().unwrap();
741 let entry = s.remove(token)?;
742 if entry.expires_at <= now_unix_secs {
743 return None;
744 }
745 Some(entry.provider)
746 }
747}
748
749pub struct OAuthStateStore {
755 backend: Box<dyn OAuthStateBackend>,
756}
757
758pub struct OAuthState {
759 pub provider: String,
760 pub expires_at: u64,
761}
762
763impl Default for OAuthStateStore {
764 fn default() -> Self {
765 Self::new()
766 }
767}
768
769impl OAuthStateStore {
770 pub fn new() -> Self {
771 Self {
772 backend: Box::new(InMemoryOAuthBackend::new()),
773 }
774 }
775
776 pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
777 Self { backend }
778 }
779
780 pub fn create(&self, provider: &str) -> String {
782 use std::time::{SystemTime, UNIX_EPOCH};
783 let token = generate_token();
784 let now = SystemTime::now()
785 .duration_since(UNIX_EPOCH)
786 .unwrap_or_default()
787 .as_secs();
788 self.backend.put(&token, provider, now + 600);
789 token
790 }
791
792 pub fn validate(&self, state: &str, expected_provider: &str) -> bool {
796 use std::time::{SystemTime, UNIX_EPOCH};
797 let now = SystemTime::now()
798 .duration_since(UNIX_EPOCH)
799 .unwrap_or_default()
800 .as_secs();
801 match self.backend.take(state, now) {
802 Some(provider) => provider == expected_provider,
803 None => false,
804 }
805 }
806}
807
808pub trait MagicCodeBackend: Send + Sync {
821 fn put(&self, email: &str, code: &MagicCode);
823 fn get(&self, email: &str) -> Option<MagicCode>;
825 fn remove(&self, email: &str);
828 fn bump_attempts(&self, email: &str);
832 fn load_all(&self) -> Vec<MagicCode>;
835}
836
837pub struct InMemoryMagicCodeBackend {
840 codes: Mutex<HashMap<String, MagicCode>>,
841}
842
843impl InMemoryMagicCodeBackend {
844 pub fn new() -> Self {
845 Self {
846 codes: Mutex::new(HashMap::new()),
847 }
848 }
849}
850
851impl Default for InMemoryMagicCodeBackend {
852 fn default() -> Self {
853 Self::new()
854 }
855}
856
857impl MagicCodeBackend for InMemoryMagicCodeBackend {
858 fn put(&self, email: &str, code: &MagicCode) {
859 self.codes
860 .lock()
861 .unwrap()
862 .insert(email.to_string(), code.clone());
863 }
864 fn get(&self, email: &str) -> Option<MagicCode> {
865 self.codes.lock().unwrap().get(email).cloned()
866 }
867 fn remove(&self, email: &str) {
868 self.codes.lock().unwrap().remove(email);
869 }
870 fn bump_attempts(&self, email: &str) {
871 if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
872 c.attempts = c.attempts.saturating_add(1);
873 }
874 }
875 fn load_all(&self) -> Vec<MagicCode> {
876 self.codes.lock().unwrap().values().cloned().collect()
877 }
878}
879
880pub struct MagicCodeStore {
885 cache: Mutex<HashMap<String, MagicCode>>,
886 backend: Box<dyn MagicCodeBackend>,
887}
888
/// A one-time login code issued to an email address.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Address the code was issued to (also the store's key).
    pub email: String,
    /// The code itself, compared in constant time on verification.
    pub code: String,
    /// Unix seconds at which the code stops being accepted.
    pub expires_at: u64,
    /// Failed verification attempts so far.
    pub attempts: u32,
}

/// Failed guesses allowed per code before verification answers `TooManyAttempts`.
const MAX_ATTEMPTS: u32 = 5;

/// Minimum spacing, in seconds, between codes issued for the same email while
/// the previous one is still valid.
const CREATE_COOLDOWN_SECS: u64 = 60;

/// Why issuing or verifying a magic code failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is pending for the email.
    NotFound,
    /// The attempt limit for the pending code has been reached.
    TooManyAttempts,
    /// The submitted code did not match.
    BadCode,
    /// The pending code has expired (and has been discarded).
    Expired,
    /// A code was issued too recently; retry after the given number of seconds.
    Throttled { retry_after_secs: u64 },
}
921
922impl Default for MagicCodeStore {
923 fn default() -> Self {
924 Self::new()
925 }
926}
927
928impl MagicCodeStore {
929 pub fn new() -> Self {
930 Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
931 }
932
933 pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
938 let now = now_secs();
939 let mut cache = HashMap::new();
940 for c in backend.load_all() {
941 if c.expires_at > now {
942 cache.insert(c.email.clone(), c);
943 }
944 }
945 Self {
946 cache: Mutex::new(cache),
947 backend,
948 }
949 }
950
951 pub fn create(&self, email: &str) -> String {
954 self.try_create(email).unwrap_or_else(|_| String::new())
957 }
958
959 pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
962 let now = now_secs();
963
964 let mut codes = self.cache.lock().unwrap();
965
966 if let Some(existing) = codes.get(email) {
970 if existing.expires_at > now {
971 let created_at = existing.expires_at.saturating_sub(600);
972 let age = now.saturating_sub(created_at);
973 if age < CREATE_COOLDOWN_SECS {
974 return Err(MagicCodeError::Throttled {
975 retry_after_secs: CREATE_COOLDOWN_SECS - age,
976 });
977 }
978 }
979 }
980
981 let code = generate_magic_code();
982 let mc = MagicCode {
983 email: email.to_string(),
984 code: code.clone(),
985 expires_at: now + 600, attempts: 0,
987 };
988 codes.insert(email.to_string(), mc.clone());
989 self.backend.put(email, &mc);
993 Ok(code)
994 }
995
996 pub fn verify(&self, email: &str, code: &str) -> bool {
1000 matches!(self.try_verify(email, code), Ok(()))
1001 }
1002
1003 pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1010 self.cache
1011 .lock()
1012 .map(|m| m.values().cloned().collect())
1013 .unwrap_or_default()
1014 }
1015
1016 pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1017 let now = now_secs();
1018 let mut codes = self.cache.lock().unwrap();
1019
1020 let mc = match codes.get_mut(email) {
1021 Some(m) => m,
1022 None => return Err(MagicCodeError::NotFound),
1023 };
1024
1025 if mc.attempts >= MAX_ATTEMPTS {
1026 return Err(MagicCodeError::TooManyAttempts);
1027 }
1028 if mc.expires_at <= now {
1029 codes.remove(email);
1030 self.backend.remove(email);
1031 return Err(MagicCodeError::Expired);
1032 }
1033
1034 let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1035 if !ok {
1036 mc.attempts += 1;
1037 self.backend.bump_attempts(email);
1038 if mc.attempts >= MAX_ATTEMPTS {
1040 return Err(MagicCodeError::TooManyAttempts);
1041 }
1042 return Err(MagicCodeError::BadCode);
1043 }
1044
1045 codes.remove(email);
1047 self.backend.remove(email);
1048 Ok(())
1049 }
1050}
1051
/// Lowercase hexadecimal encoding: two digits per byte, no separators.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
1059
1060fn generate_magic_code() -> String {
1062 use rand::Rng;
1063 let mut rng = rand::thread_rng();
1064 let code: u32 = rng.gen_range(0..1_000_000);
1065 format!("{:06}", code)
1066}
1067
1068fn generate_token() -> String {
1070 use rand::Rng;
1071 let mut rng = rand::thread_rng();
1072 let bytes: [u8; 32] = rng.gen();
1073 format!("pylon_{}", hex_encode(&bytes))
1074}
1075
1076use std::collections::HashMap;
1081use std::sync::Mutex;
1082
1083pub trait SessionBackend: Send + Sync {
1087 fn load_all(&self) -> Vec<Session>;
1088 fn save(&self, session: &Session);
1089 fn remove(&self, token: &str);
1090}
1091
1092pub struct SessionStore {
1100 sessions: Mutex<HashMap<String, Session>>,
1101 backend: Option<Box<dyn SessionBackend>>,
1102}
1103
1104impl Default for SessionStore {
1105 fn default() -> Self {
1106 Self::new()
1107 }
1108}
1109
1110impl SessionStore {
1111 pub fn new() -> Self {
1112 Self {
1113 sessions: Mutex::new(HashMap::new()),
1114 backend: None,
1115 }
1116 }
1117
1118 pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1122 let mut map = HashMap::new();
1123 for s in backend.load_all() {
1124 if !s.is_expired() {
1125 map.insert(s.token.clone(), s);
1126 }
1127 }
1128 Self {
1129 sessions: Mutex::new(map),
1130 backend: Some(backend),
1131 }
1132 }
1133
1134 pub fn create(&self, user_id: String) -> Session {
1136 let session = Session::new(user_id);
1137 let mut sessions = self.sessions.lock().unwrap();
1138 sessions.insert(session.token.clone(), session.clone());
1139 if let Some(b) = &self.backend {
1140 b.save(&session);
1141 }
1142 session
1143 }
1144
1145 pub fn get(&self, token: &str) -> Option<Session> {
1147 let mut sessions = self.sessions.lock().unwrap();
1148 match sessions.get(token) {
1149 Some(s) if s.is_expired() => {
1150 sessions.remove(token);
1151 None
1152 }
1153 Some(s) => Some(s.clone()),
1154 None => None,
1155 }
1156 }
1157
1158 pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1161 match token {
1162 Some(t) => match self.get(t) {
1163 Some(session) => session.to_auth_context(),
1164 None => AuthContext::anonymous(),
1165 },
1166 None => AuthContext::anonymous(),
1167 }
1168 }
1169
1170 pub fn refresh(&self, old_token: &str) -> Option<Session> {
1174 let mut sessions = self.sessions.lock().unwrap();
1175 let old = sessions.remove(old_token)?;
1176 if let Some(b) = &self.backend {
1177 b.remove(old_token);
1178 }
1179 if old.is_expired() {
1180 return None;
1181 }
1182 let mut new = Session::new(old.user_id.clone());
1183 new.device = old.device.clone();
1184 sessions.insert(new.token.clone(), new.clone());
1185 if let Some(b) = &self.backend {
1186 b.save(&new);
1187 }
1188 Some(new)
1189 }
1190
1191 pub fn list_all_unfiltered(&self) -> Vec<Session> {
1196 self.sessions
1197 .lock()
1198 .map(|m| m.values().cloned().collect())
1199 .unwrap_or_default()
1200 }
1201
1202 pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
1204 let sessions = self.sessions.lock().unwrap();
1205 sessions
1206 .values()
1207 .filter(|s| s.user_id == user_id && !s.is_expired())
1208 .cloned()
1209 .collect()
1210 }
1211
1212 pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
1214 let mut sessions = self.sessions.lock().unwrap();
1215 let tokens: Vec<String> = sessions
1216 .iter()
1217 .filter_map(|(t, s)| {
1218 if s.user_id == user_id {
1219 Some(t.clone())
1220 } else {
1221 None
1222 }
1223 })
1224 .collect();
1225 let n = tokens.len();
1226 for t in &tokens {
1227 sessions.remove(t);
1228 if let Some(b) = &self.backend {
1229 b.remove(t);
1230 }
1231 }
1232 n
1233 }
1234
1235 pub fn sweep_expired(&self) -> usize {
1237 let mut sessions = self.sessions.lock().unwrap();
1238 let expired: Vec<String> = sessions
1239 .iter()
1240 .filter_map(|(t, s)| {
1241 if s.is_expired() {
1242 Some(t.clone())
1243 } else {
1244 None
1245 }
1246 })
1247 .collect();
1248 let n = expired.len();
1249 for t in &expired {
1250 sessions.remove(t);
1251 if let Some(b) = &self.backend {
1252 b.remove(t);
1253 }
1254 }
1255 n
1256 }
1257
1258 pub fn set_device(&self, token: &str, device: String) -> bool {
1260 let mut sessions = self.sessions.lock().unwrap();
1261 if let Some(s) = sessions.get_mut(token) {
1262 s.device = Some(device);
1263 if let Some(b) = &self.backend {
1264 b.save(s);
1265 }
1266 true
1267 } else {
1268 false
1269 }
1270 }
1271
1272 pub fn create_guest(&self) -> Session {
1274 use rand::Rng;
1275 let mut rng = rand::thread_rng();
1276 let bytes: [u8; 16] = rng.gen();
1277 let guest_id = format!("guest_{}", hex_encode(&bytes));
1278 self.create(guest_id)
1279 }
1280
1281 pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1283 let mut sessions = self.sessions.lock().unwrap();
1284 if let Some(session) = sessions.get_mut(token) {
1285 session.user_id = real_user_id;
1286 if let Some(b) = &self.backend {
1287 b.save(session);
1288 }
1289 true
1290 } else {
1291 false
1292 }
1293 }
1294
1295 pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1300 let mut sessions = self.sessions.lock().unwrap();
1301 if let Some(session) = sessions.get_mut(token) {
1302 session.tenant_id = tenant_id;
1303 if let Some(b) = &self.backend {
1304 b.save(session);
1305 }
1306 true
1307 } else {
1308 false
1309 }
1310 }
1311
1312 pub fn revoke(&self, token: &str) -> bool {
1314 let mut sessions = self.sessions.lock().unwrap();
1315 let removed = sessions.remove(token).is_some();
1316 if removed {
1317 if let Some(b) = &self.backend {
1318 b.remove(token);
1319 }
1320 }
1321 removed
1322 }
1323}
1324
/// A provider identity (or credential) linked to a local user.
#[derive(Clone, PartialEq, Eq)]
pub struct Account {
    /// Random local identifier.
    pub id: String,
    /// Local user this account belongs to.
    pub user_id: String,
    /// Provider name, e.g. "google" or "github".
    pub provider_id: String,
    /// Account id at the provider.
    pub account_id: String,
    /// Latest provider access token; redacted from `Debug` output.
    pub access_token: Option<String>,
    /// Provider refresh token; redacted from `Debug` output.
    pub refresh_token: Option<String>,
    /// OpenID Connect ID token; redacted from `Debug` output.
    pub id_token: Option<String>,
    /// Unix seconds after which `access_token` is considered expired.
    pub access_token_expires_at: Option<u64>,
    /// Unix seconds after which `refresh_token` is considered expired.
    pub refresh_token_expires_at: Option<u64>,
    /// Scope granted by the provider.
    pub scope: Option<String>,
    /// Secret credential (presumably a password hash — confirm); redacted from
    /// `Debug` output.
    pub password: Option<String>,
    /// Unix seconds at creation.
    pub created_at: u64,
    /// Unix seconds at last update.
    pub updated_at: u64,
}

// Hand-written instead of derived so `{:?}` shows only whether each secret is
// present, never its value.
impl std::fmt::Debug for Account {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn redact(secret: &Option<String>) -> Option<&'static str> {
            secret.as_ref().map(|_| "<redacted>")
        }
        f.debug_struct("Account")
            .field("id", &self.id)
            .field("user_id", &self.user_id)
            .field("provider_id", &self.provider_id)
            .field("account_id", &self.account_id)
            .field("access_token", &redact(&self.access_token))
            .field("refresh_token", &redact(&self.refresh_token))
            .field("id_token", &redact(&self.id_token))
            .field("access_token_expires_at", &self.access_token_expires_at)
            .field("refresh_token_expires_at", &self.refresh_token_expires_at)
            .field("scope", &self.scope)
            .field("password", &redact(&self.password))
            .field("created_at", &self.created_at)
            .field("updated_at", &self.updated_at)
            .finish()
    }
}
1380
1381impl Account {
1382 pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
1386 let now = now_secs();
1387 Self {
1388 id: generate_token(),
1389 user_id,
1390 provider_id: info.provider.clone(),
1391 account_id: info.provider_account_id.clone(),
1392 access_token: Some(tokens.access_token.clone()),
1393 refresh_token: tokens.refresh_token.clone(),
1394 id_token: tokens.id_token.clone(),
1395 access_token_expires_at: tokens.expires_at,
1396 refresh_token_expires_at: None,
1397 scope: tokens.scope.clone(),
1398 password: None,
1399 created_at: now,
1400 updated_at: now,
1401 }
1402 }
1403
1404 pub fn access_token_expired(&self) -> bool {
1409 match self.access_token_expires_at {
1410 Some(ts) => now_secs() >= ts,
1411 None => false,
1412 }
1413 }
1414}
1415
1416pub trait AccountBackend: Send + Sync {
1419 fn upsert(&self, account: &Account);
1423 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
1426 fn find_for_user(&self, user_id: &str) -> Vec<Account>;
1431 fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
1433 fn list_all(&self) -> Vec<Account>;
1438}
1439
1440pub struct InMemoryAccountBackend {
1444 accounts: Mutex<HashMap<(String, String), Account>>,
1448}
1449
1450impl InMemoryAccountBackend {
1451 pub fn new() -> Self {
1452 Self {
1453 accounts: Mutex::new(HashMap::new()),
1454 }
1455 }
1456}
1457
1458impl Default for InMemoryAccountBackend {
1459 fn default() -> Self {
1460 Self::new()
1461 }
1462}
1463
1464impl AccountBackend for InMemoryAccountBackend {
1465 fn upsert(&self, account: &Account) {
1466 let key = (account.provider_id.clone(), account.account_id.clone());
1467 self.accounts.lock().unwrap().insert(key, account.clone());
1468 }
1469 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1470 self.accounts
1471 .lock()
1472 .unwrap()
1473 .get(&(provider_id.to_string(), account_id.to_string()))
1474 .cloned()
1475 }
1476 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1477 self.accounts
1478 .lock()
1479 .unwrap()
1480 .values()
1481 .filter(|a| a.user_id == user_id)
1482 .cloned()
1483 .collect()
1484 }
1485 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1486 self.accounts
1487 .lock()
1488 .unwrap()
1489 .remove(&(provider_id.to_string(), account_id.to_string()))
1490 .is_some()
1491 }
1492 fn list_all(&self) -> Vec<Account> {
1493 self.accounts.lock().unwrap().values().cloned().collect()
1494 }
1495}
1496
1497pub struct AccountStore {
1500 backend: Box<dyn AccountBackend>,
1501}
1502
1503impl Default for AccountStore {
1504 fn default() -> Self {
1505 Self::new()
1506 }
1507}
1508
1509impl AccountStore {
1510 pub fn new() -> Self {
1511 Self {
1512 backend: Box::new(InMemoryAccountBackend::new()),
1513 }
1514 }
1515 pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
1516 Self { backend }
1517 }
1518 pub fn upsert(&self, account: &Account) {
1519 self.backend.upsert(account);
1520 }
1521 pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1522 self.backend.find_by_provider(provider_id, account_id)
1523 }
1524 pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1525 self.backend.find_for_user(user_id)
1526 }
1527 pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1528 self.backend.unlink(provider_id, account_id)
1529 }
1530
1531 pub fn list_all_unfiltered(&self) -> Vec<Account> {
1545 self.backend.list_all()
1546 }
1547}
1548
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // AuthContext constructors and AuthMode checks
    // ---------------------------------------------------------------------

    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }

    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }

    #[test]
    fn auth_mode_public_allows_anonymous() {
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        // "admin" is not accepted by `from_str`, so it yields `None`.
        assert_eq!(AuthMode::from_str("admin"), None);
    }

    // ---------------------------------------------------------------------
    // SessionStore: create / get / resolve / revoke
    // ---------------------------------------------------------------------

    #[test]
    fn session_store_create_and_get() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(!session.token.is_empty());
        assert!(session.token.starts_with("pylon_"));

        let retrieved = store.get(&session.token).unwrap();
        assert_eq!(retrieved.user_id, "user-1");
    }

    #[test]
    fn session_store_resolve() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));

        // No token at all resolves to an unauthenticated context.
        let anon = store.resolve(None);
        assert!(!anon.is_authenticated());

        // An unknown token must not authenticate either.
        let bad = store.resolve(Some("invalid-token"));
        assert!(!bad.is_authenticated());
    }

    #[test]
    fn session_store_revoke() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        assert!(store.revoke(&session.token));
        assert!(store.get(&session.token).is_none());
        // A second revoke of the same token reports that nothing was removed.
        assert!(!store.revoke(&session.token));
    }

    #[test]
    fn session_to_auth_context() {
        let session = Session::new("user-42".into());
        let ctx = session.to_auth_context();
        assert_eq!(ctx.user_id, Some("user-42".into()));
    }

    // ---------------------------------------------------------------------
    // Admin flag
    // ---------------------------------------------------------------------

    #[test]
    fn admin_context() {
        let ctx = AuthContext::admin();
        assert!(ctx.is_admin);
        assert!(ctx.is_authenticated());
    }

    #[test]
    fn anonymous_not_admin() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_admin);
    }

    #[test]
    fn authenticated_not_admin() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_admin);
    }

    // ---------------------------------------------------------------------
    // MagicCodeStore: email one-time codes
    // ---------------------------------------------------------------------

    #[test]
    fn magic_code_create_and_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert_eq!(code.len(), 6);
        assert!(store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_wrong_code_rejected() {
        let store = MagicCodeStore::new();
        store.create("test@example.com");
        assert!(!store.verify("test@example.com", "000000"));
    }

    #[test]
    fn magic_code_wrong_email_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(!store.verify("other@example.com", &code));
    }

    #[test]
    fn magic_code_consumed_after_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(store.verify("test@example.com", &code));
        // A code is single-use: replaying it after a successful verify fails.
        assert!(!store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_different_emails_independent() {
        let store = MagicCodeStore::new();
        let code1 = store.create("alice@example.com");
        let code2 = store.create("bob@example.com");
        // Verifying (and consuming) alice's code must not affect bob's.
        assert!(store.verify("alice@example.com", &code1));
        assert!(store.verify("bob@example.com", &code2));
    }

    // ---------------------------------------------------------------------
    // constant_time_eq
    // ---------------------------------------------------------------------

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"));
        // Inputs of different lengths must compare unequal.
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"a", b"b"));
    }

    // ---------------------------------------------------------------------
    // Token generation
    // ---------------------------------------------------------------------

    #[test]
    fn generated_tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert!(t1.starts_with("pylon_"));
        assert!(t2.starts_with("pylon_"));
        // "pylon_" prefix (6 chars) followed by 64 more characters.
        assert_eq!(t1.len(), 6 + 64);
    }

    // ---------------------------------------------------------------------
    // OAuthRegistry
    // ---------------------------------------------------------------------

    #[test]
    fn oauth_registry_empty() {
        let reg = OAuthRegistry::new();
        assert!(reg.get("google").is_none());
    }

    #[test]
    fn oauth_registry_register_and_get() {
        let mut reg = OAuthRegistry::new();
        reg.register(OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/callback".into(),
        });
        let config = reg.get("google").unwrap();
        assert_eq!(config.client_id, "test-id");
        assert!(config.auth_url().contains("accounts.google.com"));
    }

    // ---------------------------------------------------------------------
    // Guest sessions and upgrades
    // ---------------------------------------------------------------------

    #[test]
    fn guest_session() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));
        assert!(!session.token.is_empty());

        // NOTE(review): a resolved guest *session* is expected to be
        // authenticated here, whereas `AuthContext::guest()` reports
        // `!is_authenticated()` (see `guest_context`). Confirm this asymmetry
        // is intended.
        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert!(ctx.user_id.unwrap().starts_with("guest_"));
    }

    #[test]
    fn upgrade_guest_to_real_user() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));

        let upgraded = store.upgrade(&session.token, "real-user-123".into());
        assert!(upgraded);

        // The same token now resolves to the upgraded user id.
        let ctx = store.resolve(Some(&session.token));
        assert_eq!(ctx.user_id, Some("real-user-123".into()));
    }

    #[test]
    fn upgrade_invalid_token_fails() {
        let store = SessionStore::new();
        let upgraded = store.upgrade("nonexistent-token", "user".into());
        assert!(!upgraded);
    }

    #[test]
    fn guest_context() {
        let ctx = AuthContext::guest("guest_123".into());
        // Guests carry a user id but are not considered authenticated.
        assert!(!ctx.is_authenticated());
        assert!(ctx.is_guest);
        assert!(!ctx.is_admin);
        assert_eq!(ctx.user_id, Some("guest_123".into()));
        assert!(!AuthMode::User.check(&ctx));
        assert!(AuthMode::Public.check(&ctx));
    }

    // ---------------------------------------------------------------------
    // OAuthConfig URLs
    // ---------------------------------------------------------------------

    #[test]
    fn oauth_token_urls() {
        let google = OAuthConfig {
            provider: "google".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
        let github = OAuthConfig {
            provider: "github".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(
            github.token_url(),
            "https://github.com/login/oauth/access_token"
        );
        // Unknown providers yield empty URLs rather than a guess.
        let unknown = OAuthConfig {
            provider: "unknown".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(unknown.token_url(), "");
        assert!(unknown.auth_url().is_empty());
    }

    #[test]
    fn oauth_auth_url_github() {
        let config = OAuthConfig {
            provider: "github".into(),
            client_id: "gh-id".into(),
            client_secret: "gh-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        assert!(config.auth_url().contains("github.com"));
        assert!(config.auth_url().contains("gh-id"));
    }

    #[test]
    fn oauth_auth_url_with_state() {
        let config = OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        let url = config.auth_url_with_state("random_state_123");
        assert!(url.contains("&state=random_state_123"));
    }

    // ---------------------------------------------------------------------
    // OAuthStateStore (CSRF state values)
    // ---------------------------------------------------------------------

    #[test]
    fn oauth_state_store_create_and_validate() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(store.validate(&state, "google"));
        // State values are single-use: a second validation fails.
        assert!(!store.validate(&state, "google"));
    }

    #[test]
    fn oauth_state_store_wrong_provider_rejected() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(!store.validate(&state, "github"));
    }

    #[test]
    fn oauth_state_store_invalid_state_rejected() {
        let store = OAuthStateStore::new();
        assert!(!store.validate("nonexistent", "google"));
    }
}