1pub mod cookie;
2pub mod email;
3pub mod password;
4
5pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
6
7use serde::{Deserialize, Serialize};
8
/// Identity and authorization facts about the caller of a request.
///
/// Produced by `SessionStore::resolve` (via `Session::to_auth_context`) or built
/// directly with the constructors on this type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
    /// Caller's user id; `None` for anonymous callers.
    pub user_id: Option<String>,
    /// Admin flag; an admin passes every `has_role` / `has_any_role` check.
    pub is_admin: bool,
    /// Guest flag; a guest context is never `is_authenticated`, even with a `user_id`.
    /// Omitted from serialized output when `false`.
    // NOTE(review): `default` only matters for Deserialize, which this type does not derive.
    #[serde(default, skip_serializing_if = "is_false")]
    pub is_guest: bool,
    /// Role names granted to the caller.
    pub roles: Vec<String>,
    /// Active tenant, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
44
/// Serde `skip_serializing_if` predicate: `true` when the flag is unset, so the
/// field is left out of the serialized form.
fn is_false(b: &bool) -> bool {
    let flag = *b;
    !flag
}
48
49impl AuthContext {
50 pub fn anonymous() -> Self {
52 Self {
53 user_id: None,
54 is_admin: false,
55 is_guest: false,
56 roles: Vec::new(),
57 tenant_id: None,
58 }
59 }
60
61 pub fn authenticated(user_id: String) -> Self {
63 Self {
64 user_id: Some(user_id),
65 is_admin: false,
66 is_guest: false,
67 roles: Vec::new(),
68 tenant_id: None,
69 }
70 }
71
72 pub fn guest(guest_id: String) -> Self {
77 Self {
78 user_id: Some(guest_id),
79 is_admin: false,
80 is_guest: true,
81 roles: Vec::new(),
82 tenant_id: None,
83 }
84 }
85
86 pub fn admin() -> Self {
88 Self {
89 user_id: Some("__admin__".into()),
90 is_admin: true,
91 is_guest: false,
92 roles: vec!["admin".into()],
93 tenant_id: None,
94 }
95 }
96
97 pub fn user(user_id: String) -> Self {
99 Self::authenticated(user_id)
100 }
101
102 pub fn tenant_id(&self) -> Option<&str> {
104 self.tenant_id.as_deref()
105 }
106
107 pub fn with_tenant(mut self, tenant_id: String) -> Self {
109 self.tenant_id = Some(tenant_id);
110 self
111 }
112
113 pub fn is_authenticated(&self) -> bool {
117 self.user_id.is_some() && !self.is_guest
118 }
119
120 pub fn has_role(&self, role: &str) -> bool {
122 self.is_admin || self.roles.iter().any(|r| r == role)
123 }
124
125 pub fn has_any_role(&self, roles: &[&str]) -> bool {
127 self.is_admin || roles.iter().any(|r| self.has_role(r))
128 }
129
130 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
132 self.roles = roles;
133 self
134 }
135}
136
/// Compares two byte slices without branching on the first differing byte.
///
/// A length mismatch returns `false` immediately, so lengths are not hidden. For
/// equal lengths, every byte pair is XOR-ed and OR-accumulated, and the result is
/// tested once at the end.
// NOTE(review): the timing property depends on the optimizer not introducing an
// early exit; nothing here forces that.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
156
/// Access requirement for a resource, evaluated with [`AuthMode::check`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// Any caller, including anonymous and guest callers.
    Public,
    /// Only callers for which `AuthContext::is_authenticated` is `true` (guests excluded).
    User,
}
169
170impl AuthMode {
171 #[allow(clippy::should_implement_trait)]
173 pub fn from_str(s: &str) -> Option<Self> {
174 match s {
175 "public" => Some(AuthMode::Public),
176 "user" => Some(AuthMode::User),
177 _ => None,
178 }
179 }
180
181 pub fn check(&self, ctx: &AuthContext) -> bool {
183 match self {
184 AuthMode::Public => true,
185 AuthMode::User => ctx.is_authenticated(),
186 }
187 }
188}
189
/// A login session identified by its opaque `token`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Bearer token; sessions built by this module use `pylon_` followed by 64 hex chars.
    pub token: String,
    /// Owner of the session.
    pub user_id: String,
    /// Expiry in Unix seconds; `0` means "never expires" (see `Session::is_expired`).
    // NOTE(review): stored records missing this field deserialize to 0 and therefore
    // never expire — confirm that is intended for legacy data.
    #[serde(default)]
    pub expires_at: u64,
    /// Optional device label, set through `SessionStore::set_device`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Creation time in Unix seconds (0 when absent from stored data).
    #[serde(default)]
    pub created_at: u64,
    /// Tenant selected for this session, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
214
215impl Session {
216 pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
218
219 pub fn new(user_id: String) -> Self {
221 let now = now_secs();
222 Self {
223 token: generate_token(),
224 user_id,
225 expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
226 device: None,
227 created_at: now,
228 tenant_id: None,
229 }
230 }
231
232 pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
234 let now = now_secs();
235 Self {
236 token: generate_token(),
237 user_id,
238 expires_at: if lifetime_secs == 0 {
239 0
240 } else {
241 now.saturating_add(lifetime_secs)
242 },
243 device: None,
244 created_at: now,
245 tenant_id: None,
246 }
247 }
248
249 pub fn to_auth_context(&self) -> AuthContext {
252 let ctx = AuthContext::authenticated(self.user_id.clone());
253 match &self.tenant_id {
254 Some(t) => ctx.with_tenant(t.clone()),
255 None => ctx,
256 }
257 }
258
259 pub fn is_expired(&self) -> bool {
263 self.expires_at != 0 && now_secs() >= self.expires_at
264 }
265}
266
/// Current wall-clock time as whole seconds since the Unix epoch.
/// A clock set before 1970 yields `0` instead of panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.map(|d| d.as_secs()).unwrap_or(0)
}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct OAuthConfig {
281 pub provider: String,
282 pub client_id: String,
283 pub client_secret: String,
284 pub redirect_uri: String,
285}
286
287impl OAuthConfig {
288 pub fn auth_url(&self) -> String {
294 match self.provider.as_str() {
295 "google" => format!(
296 "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
297 self.client_id, self.redirect_uri
298 ),
299 "github" => format!(
300 "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
301 self.client_id, self.redirect_uri
302 ),
303 _ => String::new(),
304 }
305 }
306
307 pub fn auth_url_with_state(&self, state: &str) -> String {
309 let base = self.auth_url();
310 if base.is_empty() {
311 return base;
312 }
313 format!("{}&state={}", base, state)
314 }
315
316 pub fn token_url(&self) -> &str {
318 match self.provider.as_str() {
319 "google" => "https://oauth2.googleapis.com/token",
320 "github" => "https://github.com/login/oauth/access_token",
321 _ => "",
322 }
323 }
324
325 pub fn userinfo_url(&self) -> &str {
327 match self.provider.as_str() {
328 "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
329 "github" => "https://api.github.com/user",
330 _ => "",
331 }
332 }
333
334 pub fn exchange_code(&self, code: &str) -> Result<String, String> {
340 let body = match self.provider.as_str() {
341 "google" => format!(
342 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
343 url_encode(&self.client_id),
344 url_encode(&self.client_secret),
345 url_encode(&self.redirect_uri)
346 ),
347 "github" => format!(
348 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
349 url_encode(&self.client_id),
350 url_encode(&self.client_secret),
351 url_encode(&self.redirect_uri)
352 ),
353 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
354 };
355
356 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
357 extract_access_token(&out)
358 }
359
360 pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
363 let out = http_get_bearer(self.userinfo_url(), access_token)?;
364 let parsed: serde_json::Value =
365 serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
366 match self.provider.as_str() {
367 "google" => {
368 let email = parsed
369 .get("email")
370 .and_then(|v| v.as_str())
371 .ok_or("no email in userinfo")?
372 .to_string();
373 let name = parsed
374 .get("name")
375 .and_then(|v| v.as_str())
376 .map(String::from);
377 Ok((email, name))
378 }
379 "github" => {
380 let name = parsed
381 .get("name")
382 .and_then(|v| v.as_str())
383 .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
384 .map(String::from);
385 let email = parsed
386 .get("email")
387 .and_then(|v| v.as_str())
388 .map(String::from);
389 let email = email
392 .or_else(|| fetch_github_primary_email(access_token).ok())
393 .ok_or("no accessible email on GitHub account")?;
394 Ok((email, name))
395 }
396 _ => Err(format!("unknown provider: {}", self.provider)),
397 }
398 }
399}
400
/// Percent-encodes `s` byte by byte (RFC 3986 style).
///
/// ASCII letters, digits and `-` `_` `.` `~` pass through. Every other byte, including
/// each byte of a multi-byte UTF-8 character, becomes `%XX` with uppercase hex.
fn url_encode(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
413
414const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
418
419fn ureq_agent() -> ureq::Agent {
420 ureq::AgentBuilder::new()
421 .timeout_connect(HTTP_TIMEOUT)
422 .timeout_read(HTTP_TIMEOUT)
423 .timeout_write(HTTP_TIMEOUT)
424 .user_agent("pylon/0.1")
425 .build()
426}
427
428fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
429 let agent = ureq_agent();
430 let mut req = agent
431 .post(url)
432 .set("Content-Type", "application/x-www-form-urlencoded");
433 if accept_json {
434 req = req.set("Accept", "application/json");
435 }
436 match req.send_string(body) {
437 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
438 Err(ureq::Error::Status(code, resp)) => {
439 let body = resp.into_string().unwrap_or_default();
440 Err(format!("HTTP {code}: {body}"))
441 }
442 Err(e) => Err(format!("HTTP error: {e}")),
443 }
444}
445
446fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
447 let agent = ureq_agent();
448 match agent
449 .get(url)
450 .set("Authorization", &format!("Bearer {token}"))
451 .set("Accept", "application/json")
452 .call()
453 {
454 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
455 Err(ureq::Error::Status(code, resp)) => {
456 let body = resp.into_string().unwrap_or_default();
457 Err(format!("HTTP {code}: {body}"))
458 }
459 Err(e) => Err(format!("HTTP error: {e}")),
460 }
461}
462
463fn fetch_github_primary_email(token: &str) -> Result<String, String> {
464 let out = http_get_bearer("https://api.github.com/user/emails", token)?;
465 let emails: serde_json::Value =
466 serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
467 emails
468 .as_array()
469 .and_then(|arr| {
470 arr.iter()
471 .find(|e| {
472 e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
473 && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
474 })
475 .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
476 })
477 .ok_or_else(|| "no primary verified email on GitHub".into())
478}
479
480fn extract_access_token(body: &str) -> Result<String, String> {
481 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
482 if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
483 return Ok(t.to_string());
484 }
485 }
486 for pair in body.split('&') {
488 if let Some(val) = pair.strip_prefix("access_token=") {
489 return Ok(val.to_string());
490 }
491 }
492 Err(format!("no access_token in token response: {body}"))
493}
494
495pub struct OAuthRegistry {
497 providers: std::collections::HashMap<String, OAuthConfig>,
498}
499
500impl Default for OAuthRegistry {
501 fn default() -> Self {
502 Self::new()
503 }
504}
505
506impl OAuthRegistry {
507 pub fn new() -> Self {
508 Self {
509 providers: std::collections::HashMap::new(),
510 }
511 }
512
513 pub fn register(&mut self, config: OAuthConfig) {
514 self.providers.insert(config.provider.clone(), config);
515 }
516
517 pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
518 self.providers.get(provider)
519 }
520
521 pub fn from_env() -> Self {
524 let mut reg = Self::new();
525
526 if let (Ok(id), Ok(secret)) = (
528 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
529 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
530 ) {
531 reg.register(OAuthConfig {
532 provider: "google".into(),
533 client_id: id,
534 client_secret: secret,
535 redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
536 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
537 });
538 }
539
540 if let (Ok(id), Ok(secret)) = (
542 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
543 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
544 ) {
545 reg.register(OAuthConfig {
546 provider: "github".into(),
547 client_id: id,
548 client_secret: secret,
549 redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
550 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
551 });
552 }
553
554 reg
555 }
556}
557
/// Storage for pending OAuth `state` tokens, pluggable via `OAuthStateStore::with_backend`.
pub trait OAuthStateBackend: Send + Sync {
    /// Records `token` as issued for `provider`, valid until `expires_at` (Unix seconds).
    fn put(&self, token: &str, provider: &str, expires_at: u64);
    /// Consumes `token`, returning its provider if it existed and had not expired at
    /// `now_unix_secs`.
    // NOTE(review): `OAuthStateStore::validate` relies on this being single-use (a
    // second `take` of the same token must return `None`); the in-memory backend
    // removes the entry even when it has expired.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<String>;
}
573
574pub struct InMemoryOAuthBackend {
576 states: Mutex<HashMap<String, OAuthState>>,
577}
578
579impl InMemoryOAuthBackend {
580 pub fn new() -> Self {
581 Self {
582 states: Mutex::new(HashMap::new()),
583 }
584 }
585}
586
587impl Default for InMemoryOAuthBackend {
588 fn default() -> Self {
589 Self::new()
590 }
591}
592
593impl OAuthStateBackend for InMemoryOAuthBackend {
594 fn put(&self, token: &str, provider: &str, expires_at: u64) {
595 self.states.lock().unwrap().insert(
596 token.to_string(),
597 OAuthState {
598 provider: provider.to_string(),
599 expires_at,
600 },
601 );
602 }
603 fn take(&self, token: &str, now_unix_secs: u64) -> Option<String> {
604 let mut s = self.states.lock().unwrap();
605 let entry = s.remove(token)?;
606 if entry.expires_at <= now_unix_secs {
607 return None;
608 }
609 Some(entry.provider)
610 }
611}
612
613pub struct OAuthStateStore {
619 backend: Box<dyn OAuthStateBackend>,
620}
621
622pub struct OAuthState {
623 pub provider: String,
624 pub expires_at: u64,
625}
626
627impl Default for OAuthStateStore {
628 fn default() -> Self {
629 Self::new()
630 }
631}
632
633impl OAuthStateStore {
634 pub fn new() -> Self {
635 Self {
636 backend: Box::new(InMemoryOAuthBackend::new()),
637 }
638 }
639
640 pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
641 Self { backend }
642 }
643
644 pub fn create(&self, provider: &str) -> String {
646 use std::time::{SystemTime, UNIX_EPOCH};
647 let token = generate_token();
648 let now = SystemTime::now()
649 .duration_since(UNIX_EPOCH)
650 .unwrap_or_default()
651 .as_secs();
652 self.backend.put(&token, provider, now + 600);
653 token
654 }
655
656 pub fn validate(&self, state: &str, expected_provider: &str) -> bool {
660 use std::time::{SystemTime, UNIX_EPOCH};
661 let now = SystemTime::now()
662 .duration_since(UNIX_EPOCH)
663 .unwrap_or_default()
664 .as_secs();
665 match self.backend.take(state, now) {
666 Some(provider) => provider == expected_provider,
667 None => false,
668 }
669 }
670}
671
/// In-memory store of outstanding email sign-in ("magic") codes, one per email address.
pub struct MagicCodeStore {
    codes: Mutex<HashMap<String, MagicCode>>,
}

/// An issued sign-in code and its verification state.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Address the code was issued for (also its key in `MagicCodeStore`).
    pub email: String,
    /// The code itself, as produced by `generate_magic_code` (six digits).
    pub code: String,
    /// Expiry in Unix seconds.
    pub expires_at: u64,
    /// Number of failed verification attempts so far.
    pub attempts: u32,
}

/// Failed verifications after which a code is locked (`MagicCodeError::TooManyAttempts`).
const MAX_ATTEMPTS: u32 = 5;

/// Minimum age, in seconds, a still-valid code must reach before another code can be
/// issued for the same email.
const CREATE_COOLDOWN_SECS: u64 = 60;

/// Why issuing or verifying a magic code failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is outstanding for this email.
    NotFound,
    /// The code is locked after `MAX_ATTEMPTS` wrong guesses.
    TooManyAttempts,
    /// The supplied code did not match.
    BadCode,
    /// The code had expired; it has been removed.
    Expired,
    /// A code was issued too recently; retry after `retry_after_secs`.
    Throttled { retry_after_secs: u64 },
}
713
714impl Default for MagicCodeStore {
715 fn default() -> Self {
716 Self::new()
717 }
718}
719
720impl MagicCodeStore {
721 pub fn new() -> Self {
722 Self {
723 codes: Mutex::new(HashMap::new()),
724 }
725 }
726
727 pub fn create(&self, email: &str) -> String {
730 self.try_create(email).unwrap_or_else(|_| String::new())
733 }
734
735 pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
738 let now = now_secs();
739
740 let mut codes = self.codes.lock().unwrap();
741
742 if let Some(existing) = codes.get(email) {
746 if existing.expires_at > now {
747 let created_at = existing.expires_at.saturating_sub(600);
748 let age = now.saturating_sub(created_at);
749 if age < CREATE_COOLDOWN_SECS {
750 return Err(MagicCodeError::Throttled {
751 retry_after_secs: CREATE_COOLDOWN_SECS - age,
752 });
753 }
754 }
755 }
756
757 let code = generate_magic_code();
758 let mc = MagicCode {
759 email: email.to_string(),
760 code: code.clone(),
761 expires_at: now + 600, attempts: 0,
763 };
764 codes.insert(email.to_string(), mc);
765 Ok(code)
766 }
767
768 pub fn verify(&self, email: &str, code: &str) -> bool {
772 matches!(self.try_verify(email, code), Ok(()))
773 }
774
775 pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
779 let now = now_secs();
780 let mut codes = self.codes.lock().unwrap();
781
782 let mc = match codes.get_mut(email) {
783 Some(m) => m,
784 None => return Err(MagicCodeError::NotFound),
785 };
786
787 if mc.attempts >= MAX_ATTEMPTS {
788 return Err(MagicCodeError::TooManyAttempts);
789 }
790 if mc.expires_at <= now {
791 codes.remove(email);
792 return Err(MagicCodeError::Expired);
793 }
794
795 let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
796 if !ok {
797 mc.attempts += 1;
798 if mc.attempts >= MAX_ATTEMPTS {
800 return Err(MagicCodeError::TooManyAttempts);
801 }
802 return Err(MagicCodeError::BadCode);
803 }
804
805 codes.remove(email);
807 Ok(())
808 }
809}
810
/// Lowercase hexadecimal encoding of `bytes`: two characters per byte, no separators.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    // Push digits straight into one pre-sized buffer; the previous `format!` per byte
    // allocated a temporary String for every input byte.
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    out
}
818
819fn generate_magic_code() -> String {
821 use rand::Rng;
822 let mut rng = rand::thread_rng();
823 let code: u32 = rng.gen_range(0..1_000_000);
824 format!("{:06}", code)
825}
826
827fn generate_token() -> String {
829 use rand::Rng;
830 let mut rng = rand::thread_rng();
831 let bytes: [u8; 32] = rng.gen();
832 format!("pylon_{}", hex_encode(&bytes))
833}
834
835use std::collections::HashMap;
840use std::sync::Mutex;
841
/// Persistence hook for [`SessionStore`]. The store keeps its own in-memory map and
/// calls `save`/`remove` here as sessions are created, changed or removed.
pub trait SessionBackend: Send + Sync {
    /// Returns all stored sessions; called by `SessionStore::with_backend`.
    fn load_all(&self) -> Vec<Session>;
    /// Persists `session`. Called for new sessions and for updates to existing ones
    /// (device, tenant, user upgrade), so implementations should upsert by token.
    fn save(&self, session: &Session);
    /// Deletes the session stored under `token`, if any.
    fn remove(&self, token: &str);
}

/// Session store keyed by token, held in memory with an optional persistence backend.
pub struct SessionStore {
    sessions: Mutex<HashMap<String, Session>>,
    backend: Option<Box<dyn SessionBackend>>,
}
862
863impl Default for SessionStore {
864 fn default() -> Self {
865 Self::new()
866 }
867}
868
869impl SessionStore {
870 pub fn new() -> Self {
871 Self {
872 sessions: Mutex::new(HashMap::new()),
873 backend: None,
874 }
875 }
876
877 pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
881 let mut map = HashMap::new();
882 for s in backend.load_all() {
883 if !s.is_expired() {
884 map.insert(s.token.clone(), s);
885 }
886 }
887 Self {
888 sessions: Mutex::new(map),
889 backend: Some(backend),
890 }
891 }
892
893 pub fn create(&self, user_id: String) -> Session {
895 let session = Session::new(user_id);
896 let mut sessions = self.sessions.lock().unwrap();
897 sessions.insert(session.token.clone(), session.clone());
898 if let Some(b) = &self.backend {
899 b.save(&session);
900 }
901 session
902 }
903
904 pub fn get(&self, token: &str) -> Option<Session> {
906 let mut sessions = self.sessions.lock().unwrap();
907 match sessions.get(token) {
908 Some(s) if s.is_expired() => {
909 sessions.remove(token);
910 None
911 }
912 Some(s) => Some(s.clone()),
913 None => None,
914 }
915 }
916
917 pub fn resolve(&self, token: Option<&str>) -> AuthContext {
920 match token {
921 Some(t) => match self.get(t) {
922 Some(session) => session.to_auth_context(),
923 None => AuthContext::anonymous(),
924 },
925 None => AuthContext::anonymous(),
926 }
927 }
928
929 pub fn refresh(&self, old_token: &str) -> Option<Session> {
933 let mut sessions = self.sessions.lock().unwrap();
934 let old = sessions.remove(old_token)?;
935 if let Some(b) = &self.backend {
936 b.remove(old_token);
937 }
938 if old.is_expired() {
939 return None;
940 }
941 let mut new = Session::new(old.user_id.clone());
942 new.device = old.device.clone();
943 sessions.insert(new.token.clone(), new.clone());
944 if let Some(b) = &self.backend {
945 b.save(&new);
946 }
947 Some(new)
948 }
949
950 pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
952 let sessions = self.sessions.lock().unwrap();
953 sessions
954 .values()
955 .filter(|s| s.user_id == user_id && !s.is_expired())
956 .cloned()
957 .collect()
958 }
959
960 pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
962 let mut sessions = self.sessions.lock().unwrap();
963 let tokens: Vec<String> = sessions
964 .iter()
965 .filter_map(|(t, s)| {
966 if s.user_id == user_id {
967 Some(t.clone())
968 } else {
969 None
970 }
971 })
972 .collect();
973 let n = tokens.len();
974 for t in &tokens {
975 sessions.remove(t);
976 if let Some(b) = &self.backend {
977 b.remove(t);
978 }
979 }
980 n
981 }
982
983 pub fn sweep_expired(&self) -> usize {
985 let mut sessions = self.sessions.lock().unwrap();
986 let expired: Vec<String> = sessions
987 .iter()
988 .filter_map(|(t, s)| {
989 if s.is_expired() {
990 Some(t.clone())
991 } else {
992 None
993 }
994 })
995 .collect();
996 let n = expired.len();
997 for t in &expired {
998 sessions.remove(t);
999 if let Some(b) = &self.backend {
1000 b.remove(t);
1001 }
1002 }
1003 n
1004 }
1005
1006 pub fn set_device(&self, token: &str, device: String) -> bool {
1008 let mut sessions = self.sessions.lock().unwrap();
1009 if let Some(s) = sessions.get_mut(token) {
1010 s.device = Some(device);
1011 if let Some(b) = &self.backend {
1012 b.save(s);
1013 }
1014 true
1015 } else {
1016 false
1017 }
1018 }
1019
1020 pub fn create_guest(&self) -> Session {
1022 use rand::Rng;
1023 let mut rng = rand::thread_rng();
1024 let bytes: [u8; 16] = rng.gen();
1025 let guest_id = format!("guest_{}", hex_encode(&bytes));
1026 self.create(guest_id)
1027 }
1028
1029 pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1031 let mut sessions = self.sessions.lock().unwrap();
1032 if let Some(session) = sessions.get_mut(token) {
1033 session.user_id = real_user_id;
1034 if let Some(b) = &self.backend {
1035 b.save(session);
1036 }
1037 true
1038 } else {
1039 false
1040 }
1041 }
1042
1043 pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1048 let mut sessions = self.sessions.lock().unwrap();
1049 if let Some(session) = sessions.get_mut(token) {
1050 session.tenant_id = tenant_id;
1051 if let Some(b) = &self.backend {
1052 b.save(session);
1053 }
1054 true
1055 } else {
1056 false
1057 }
1058 }
1059
1060 pub fn revoke(&self, token: &str) -> bool {
1062 let mut sessions = self.sessions.lock().unwrap();
1063 let removed = sessions.remove(token).is_some();
1064 if removed {
1065 if let Some(b) = &self.backend {
1066 b.remove(token);
1067 }
1068 }
1069 removed
1070 }
1071}
1072
/// Unit tests for contexts, auth modes, sessions, magic codes, tokens and OAuth helpers.
/// None of these tests perform network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    // --- AuthContext / AuthMode ---

    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }

    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }

    #[test]
    fn auth_mode_public_allows_anonymous() {
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        // "admin" is a role, not an AuthMode.
        assert_eq!(AuthMode::from_str("admin"), None);
    }

    // --- SessionStore ---

    #[test]
    fn session_store_create_and_get() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(!session.token.is_empty());
        assert!(session.token.starts_with("pylon_"));

        let retrieved = store.get(&session.token).unwrap();
        assert_eq!(retrieved.user_id, "user-1");
    }

    #[test]
    fn session_store_resolve() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));

        let anon = store.resolve(None);
        assert!(!anon.is_authenticated());

        let bad = store.resolve(Some("invalid-token"));
        assert!(!bad.is_authenticated());
    }

    #[test]
    fn session_store_revoke() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        assert!(store.revoke(&session.token));
        assert!(store.get(&session.token).is_none());
        // A second revoke reports that nothing was removed.
        assert!(!store.revoke(&session.token));
    }

    #[test]
    fn session_to_auth_context() {
        let session = Session::new("user-42".into());
        let ctx = session.to_auth_context();
        assert_eq!(ctx.user_id, Some("user-42".into()));
    }

    // --- Admin flag ---

    #[test]
    fn admin_context() {
        let ctx = AuthContext::admin();
        assert!(ctx.is_admin);
        assert!(ctx.is_authenticated());
    }

    #[test]
    fn anonymous_not_admin() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_admin);
    }

    #[test]
    fn authenticated_not_admin() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_admin);
    }

    // --- MagicCodeStore ---

    #[test]
    fn magic_code_create_and_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert_eq!(code.len(), 6);
        assert!(store.verify("test@example.com", &code));
    }

    // NOTE(review): flaky with probability 1e-6 — the random code can itself be "000000".
    #[test]
    fn magic_code_wrong_code_rejected() {
        let store = MagicCodeStore::new();
        store.create("test@example.com");
        assert!(!store.verify("test@example.com", "000000"));
    }

    #[test]
    fn magic_code_wrong_email_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(!store.verify("other@example.com", &code));
    }

    #[test]
    fn magic_code_consumed_after_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(store.verify("test@example.com", &code));
        // Codes are single use.
        assert!(!store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_different_emails_independent() {
        let store = MagicCodeStore::new();
        let code1 = store.create("alice@example.com");
        let code2 = store.create("bob@example.com");
        assert!(store.verify("alice@example.com", &code1));
        // Consuming alice's code leaves bob's untouched.
        assert!(store.verify("bob@example.com", &code2));
    }

    // --- constant_time_eq ---

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"a", b"b"));
    }

    // --- Token generation ---

    #[test]
    fn generated_tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert!(t1.starts_with("pylon_"));
        assert!(t2.starts_with("pylon_"));
        // "pylon_" prefix plus 32 bytes hex-encoded.
        assert_eq!(t1.len(), 6 + 64);
    }

    // --- OAuth registry ---

    #[test]
    fn oauth_registry_empty() {
        let reg = OAuthRegistry::new();
        assert!(reg.get("google").is_none());
    }

    #[test]
    fn oauth_registry_register_and_get() {
        let mut reg = OAuthRegistry::new();
        reg.register(OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/callback".into(),
        });
        let config = reg.get("google").unwrap();
        assert_eq!(config.client_id, "test-id");
        assert!(config.auth_url().contains("accounts.google.com"));
    }

    // --- Guest sessions ---

    // Pinned behaviour: a store-created guest session resolves to an *authenticated*
    // context (unlike `AuthContext::guest`, tested in `guest_context`).
    #[test]
    fn guest_session() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));
        assert!(!session.token.is_empty());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert!(ctx.user_id.unwrap().starts_with("guest_"));
    }

    #[test]
    fn upgrade_guest_to_real_user() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));

        let upgraded = store.upgrade(&session.token, "real-user-123".into());
        assert!(upgraded);

        // The same token now resolves to the real user.
        let ctx = store.resolve(Some(&session.token));
        assert_eq!(ctx.user_id, Some("real-user-123".into()));
    }

    #[test]
    fn upgrade_invalid_token_fails() {
        let store = SessionStore::new();
        let upgraded = store.upgrade("nonexistent-token", "user".into());
        assert!(!upgraded);
    }

    #[test]
    fn guest_context() {
        let ctx = AuthContext::guest("guest_123".into());
        // Guests carry an id but are not authenticated.
        assert!(!ctx.is_authenticated());
        assert!(ctx.is_guest);
        assert!(!ctx.is_admin);
        assert_eq!(ctx.user_id, Some("guest_123".into()));
        assert!(!AuthMode::User.check(&ctx));
        assert!(AuthMode::Public.check(&ctx));
    }

    // --- OAuth URLs ---

    #[test]
    fn oauth_token_urls() {
        let google = OAuthConfig {
            provider: "google".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
        let github = OAuthConfig {
            provider: "github".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(
            github.token_url(),
            "https://github.com/login/oauth/access_token"
        );
        let unknown = OAuthConfig {
            provider: "unknown".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        // Unknown providers yield empty URLs rather than errors.
        assert_eq!(unknown.token_url(), "");
        assert!(unknown.auth_url().is_empty());
    }

    #[test]
    fn oauth_auth_url_github() {
        let config = OAuthConfig {
            provider: "github".into(),
            client_id: "gh-id".into(),
            client_secret: "gh-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        assert!(config.auth_url().contains("github.com"));
        assert!(config.auth_url().contains("gh-id"));
    }

    #[test]
    fn oauth_auth_url_with_state() {
        let config = OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        let url = config.auth_url_with_state("random_state_123");
        assert!(url.contains("&state=random_state_123"));
    }

    // --- OAuth state store ---

    #[test]
    fn oauth_state_store_create_and_validate() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(store.validate(&state, "google"));
        // State tokens are single use.
        assert!(!store.validate(&state, "google"));
    }

    #[test]
    fn oauth_state_store_wrong_provider_rejected() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(!store.validate(&state, "github"));
    }

    #[test]
    fn oauth_state_store_invalid_state_rejected() {
        let store = OAuthStateStore::new();
        assert!(!store.validate("nonexistent", "google"));
    }
}