1pub mod email;
2pub mod password;
3
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
21pub struct AuthContext {
22 pub user_id: Option<String>,
24 pub is_admin: bool,
26 pub roles: Vec<String>,
28 #[serde(skip_serializing_if = "Option::is_none")]
31 pub tenant_id: Option<String>,
32}
33
34impl AuthContext {
35 pub fn anonymous() -> Self {
37 Self {
38 user_id: None,
39 is_admin: false,
40 roles: Vec::new(),
41 tenant_id: None,
42 }
43 }
44
45 pub fn authenticated(user_id: String) -> Self {
47 Self {
48 user_id: Some(user_id),
49 is_admin: false,
50 roles: Vec::new(),
51 tenant_id: None,
52 }
53 }
54
55 pub fn guest(guest_id: String) -> Self {
57 Self {
58 user_id: Some(guest_id),
59 is_admin: false,
60 roles: Vec::new(),
61 tenant_id: None,
62 }
63 }
64
65 pub fn admin() -> Self {
67 Self {
68 user_id: Some("__admin__".into()),
69 is_admin: true,
70 roles: vec!["admin".into()],
71 tenant_id: None,
72 }
73 }
74
75 pub fn user(user_id: String) -> Self {
77 Self::authenticated(user_id)
78 }
79
80 pub fn tenant_id(&self) -> Option<&str> {
82 self.tenant_id.as_deref()
83 }
84
85 pub fn with_tenant(mut self, tenant_id: String) -> Self {
87 self.tenant_id = Some(tenant_id);
88 self
89 }
90
91 pub fn is_authenticated(&self) -> bool {
93 self.user_id.is_some()
94 }
95
96 pub fn has_role(&self, role: &str) -> bool {
98 self.is_admin || self.roles.iter().any(|r| r == role)
99 }
100
101 pub fn has_any_role(&self, roles: &[&str]) -> bool {
103 self.is_admin || roles.iter().any(|r| self.has_role(r))
104 }
105
106 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
108 self.roles = roles;
109 self
110 }
111}
112
/// Compare two byte slices without stopping at the first differing byte.
///
/// Slices of different length compare unequal immediately, so only the length,
/// not the position of a mismatch, can be observed through timing.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR every XOR difference together; any surviving bit means a mismatch.
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    diff == 0
}
132
133#[derive(Debug, Clone, PartialEq, Eq)]
139pub enum AuthMode {
140 Public,
142 User,
144}
145
146impl AuthMode {
147 #[allow(clippy::should_implement_trait)]
149 pub fn from_str(s: &str) -> Option<Self> {
150 match s {
151 "public" => Some(AuthMode::Public),
152 "user" => Some(AuthMode::User),
153 _ => None,
154 }
155 }
156
157 pub fn check(&self, ctx: &AuthContext) -> bool {
159 match self {
160 AuthMode::Public => true,
161 AuthMode::User => ctx.is_authenticated(),
162 }
163 }
164}
165
166#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
172pub struct Session {
173 pub token: String,
174 pub user_id: String,
175 #[serde(default)]
177 pub expires_at: u64,
178 #[serde(default, skip_serializing_if = "Option::is_none")]
180 pub device: Option<String>,
181 #[serde(default)]
183 pub created_at: u64,
184 #[serde(default, skip_serializing_if = "Option::is_none")]
188 pub tenant_id: Option<String>,
189}
190
191impl Session {
192 pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
194
195 pub fn new(user_id: String) -> Self {
197 let now = now_secs();
198 Self {
199 token: generate_token(),
200 user_id,
201 expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
202 device: None,
203 created_at: now,
204 tenant_id: None,
205 }
206 }
207
208 pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
210 let now = now_secs();
211 Self {
212 token: generate_token(),
213 user_id,
214 expires_at: if lifetime_secs == 0 {
215 0
216 } else {
217 now.saturating_add(lifetime_secs)
218 },
219 device: None,
220 created_at: now,
221 tenant_id: None,
222 }
223 }
224
225 pub fn to_auth_context(&self) -> AuthContext {
228 let ctx = AuthContext::authenticated(self.user_id.clone());
229 match &self.tenant_id {
230 Some(t) => ctx.with_tenant(t.clone()),
231 None => ctx,
232 }
233 }
234
235 pub fn is_expired(&self) -> bool {
237 self.expires_at != 0 && now_secs() > self.expires_at
238 }
239}
240
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
248
249#[derive(Debug, Clone, Serialize, Deserialize)]
254pub struct OAuthConfig {
255 pub provider: String,
256 pub client_id: String,
257 pub client_secret: String,
258 pub redirect_uri: String,
259}
260
261impl OAuthConfig {
262 pub fn auth_url(&self) -> String {
268 match self.provider.as_str() {
269 "google" => format!(
270 "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
271 self.client_id, self.redirect_uri
272 ),
273 "github" => format!(
274 "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
275 self.client_id, self.redirect_uri
276 ),
277 _ => String::new(),
278 }
279 }
280
281 pub fn auth_url_with_state(&self, state: &str) -> String {
283 let base = self.auth_url();
284 if base.is_empty() {
285 return base;
286 }
287 format!("{}&state={}", base, state)
288 }
289
290 pub fn token_url(&self) -> &str {
292 match self.provider.as_str() {
293 "google" => "https://oauth2.googleapis.com/token",
294 "github" => "https://github.com/login/oauth/access_token",
295 _ => "",
296 }
297 }
298
299 pub fn userinfo_url(&self) -> &str {
301 match self.provider.as_str() {
302 "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
303 "github" => "https://api.github.com/user",
304 _ => "",
305 }
306 }
307
308 pub fn exchange_code(&self, code: &str) -> Result<String, String> {
314 let body = match self.provider.as_str() {
315 "google" => format!(
316 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
317 url_encode(&self.client_id),
318 url_encode(&self.client_secret),
319 url_encode(&self.redirect_uri)
320 ),
321 "github" => format!(
322 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
323 url_encode(&self.client_id),
324 url_encode(&self.client_secret),
325 url_encode(&self.redirect_uri)
326 ),
327 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
328 };
329
330 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
331 extract_access_token(&out)
332 }
333
334 pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
337 let out = http_get_bearer(self.userinfo_url(), access_token)?;
338 let parsed: serde_json::Value =
339 serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
340 match self.provider.as_str() {
341 "google" => {
342 let email = parsed
343 .get("email")
344 .and_then(|v| v.as_str())
345 .ok_or("no email in userinfo")?
346 .to_string();
347 let name = parsed
348 .get("name")
349 .and_then(|v| v.as_str())
350 .map(String::from);
351 Ok((email, name))
352 }
353 "github" => {
354 let name = parsed
355 .get("name")
356 .and_then(|v| v.as_str())
357 .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
358 .map(String::from);
359 let email = parsed
360 .get("email")
361 .and_then(|v| v.as_str())
362 .map(String::from);
363 let email = email
366 .or_else(|| fetch_github_primary_email(access_token).ok())
367 .ok_or("no accessible email on GitHub account")?;
368 Ok((email, name))
369 }
370 _ => Err(format!("unknown provider: {}", self.provider)),
371 }
372 }
373}
374
/// Percent-encode `s` for a URL query string or form body.
///
/// Bytes `A-Z a-z 0-9 - _ . ~` pass through unchanged. Every other byte,
/// including each byte of a multi-byte UTF-8 character, becomes `%XX` with
/// uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
387
388const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
392
393fn ureq_agent() -> ureq::Agent {
394 ureq::AgentBuilder::new()
395 .timeout_connect(HTTP_TIMEOUT)
396 .timeout_read(HTTP_TIMEOUT)
397 .timeout_write(HTTP_TIMEOUT)
398 .user_agent("pylon/0.1")
399 .build()
400}
401
402fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
403 let agent = ureq_agent();
404 let mut req = agent
405 .post(url)
406 .set("Content-Type", "application/x-www-form-urlencoded");
407 if accept_json {
408 req = req.set("Accept", "application/json");
409 }
410 match req.send_string(body) {
411 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
412 Err(ureq::Error::Status(code, resp)) => {
413 let body = resp.into_string().unwrap_or_default();
414 Err(format!("HTTP {code}: {body}"))
415 }
416 Err(e) => Err(format!("HTTP error: {e}")),
417 }
418}
419
420fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
421 let agent = ureq_agent();
422 match agent
423 .get(url)
424 .set("Authorization", &format!("Bearer {token}"))
425 .set("Accept", "application/json")
426 .call()
427 {
428 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
429 Err(ureq::Error::Status(code, resp)) => {
430 let body = resp.into_string().unwrap_or_default();
431 Err(format!("HTTP {code}: {body}"))
432 }
433 Err(e) => Err(format!("HTTP error: {e}")),
434 }
435}
436
437fn fetch_github_primary_email(token: &str) -> Result<String, String> {
438 let out = http_get_bearer("https://api.github.com/user/emails", token)?;
439 let emails: serde_json::Value =
440 serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
441 emails
442 .as_array()
443 .and_then(|arr| {
444 arr.iter()
445 .find(|e| {
446 e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
447 && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
448 })
449 .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
450 })
451 .ok_or_else(|| "no primary verified email on GitHub".into())
452}
453
454fn extract_access_token(body: &str) -> Result<String, String> {
455 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
456 if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
457 return Ok(t.to_string());
458 }
459 }
460 for pair in body.split('&') {
462 if let Some(val) = pair.strip_prefix("access_token=") {
463 return Ok(val.to_string());
464 }
465 }
466 Err(format!("no access_token in token response: {body}"))
467}
468
469pub struct OAuthRegistry {
471 providers: std::collections::HashMap<String, OAuthConfig>,
472}
473
474impl Default for OAuthRegistry {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
480impl OAuthRegistry {
481 pub fn new() -> Self {
482 Self {
483 providers: std::collections::HashMap::new(),
484 }
485 }
486
487 pub fn register(&mut self, config: OAuthConfig) {
488 self.providers.insert(config.provider.clone(), config);
489 }
490
491 pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
492 self.providers.get(provider)
493 }
494
495 pub fn from_env() -> Self {
498 let mut reg = Self::new();
499
500 if let (Ok(id), Ok(secret)) = (
502 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
503 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
504 ) {
505 reg.register(OAuthConfig {
506 provider: "google".into(),
507 client_id: id,
508 client_secret: secret,
509 redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
510 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
511 });
512 }
513
514 if let (Ok(id), Ok(secret)) = (
516 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
517 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
518 ) {
519 reg.register(OAuthConfig {
520 provider: "github".into(),
521 client_id: id,
522 client_secret: secret,
523 redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
524 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
525 });
526 }
527
528 reg
529 }
530}
531
532pub trait OAuthStateBackend: Send + Sync {
541 fn put(&self, token: &str, provider: &str, expires_at: u64);
542 fn take(&self, token: &str, now_unix_secs: u64) -> Option<String>;
546}
547
548pub struct InMemoryOAuthBackend {
550 states: Mutex<HashMap<String, OAuthState>>,
551}
552
553impl InMemoryOAuthBackend {
554 pub fn new() -> Self {
555 Self {
556 states: Mutex::new(HashMap::new()),
557 }
558 }
559}
560
561impl Default for InMemoryOAuthBackend {
562 fn default() -> Self {
563 Self::new()
564 }
565}
566
567impl OAuthStateBackend for InMemoryOAuthBackend {
568 fn put(&self, token: &str, provider: &str, expires_at: u64) {
569 self.states.lock().unwrap().insert(
570 token.to_string(),
571 OAuthState {
572 provider: provider.to_string(),
573 expires_at,
574 },
575 );
576 }
577 fn take(&self, token: &str, now_unix_secs: u64) -> Option<String> {
578 let mut s = self.states.lock().unwrap();
579 let entry = s.remove(token)?;
580 if entry.expires_at <= now_unix_secs {
581 return None;
582 }
583 Some(entry.provider)
584 }
585}
586
587pub struct OAuthStateStore {
593 backend: Box<dyn OAuthStateBackend>,
594}
595
596pub struct OAuthState {
597 pub provider: String,
598 pub expires_at: u64,
599}
600
601impl Default for OAuthStateStore {
602 fn default() -> Self {
603 Self::new()
604 }
605}
606
607impl OAuthStateStore {
608 pub fn new() -> Self {
609 Self {
610 backend: Box::new(InMemoryOAuthBackend::new()),
611 }
612 }
613
614 pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
615 Self { backend }
616 }
617
618 pub fn create(&self, provider: &str) -> String {
620 use std::time::{SystemTime, UNIX_EPOCH};
621 let token = generate_token();
622 let now = SystemTime::now()
623 .duration_since(UNIX_EPOCH)
624 .unwrap_or_default()
625 .as_secs();
626 self.backend.put(&token, provider, now + 600);
627 token
628 }
629
630 pub fn validate(&self, state: &str, expected_provider: &str) -> bool {
634 use std::time::{SystemTime, UNIX_EPOCH};
635 let now = SystemTime::now()
636 .duration_since(UNIX_EPOCH)
637 .unwrap_or_default()
638 .as_secs();
639 match self.backend.take(state, now) {
640 Some(provider) => provider == expected_provider,
641 None => false,
642 }
643 }
644}
645
/// In-memory store of one-time email login codes, keyed by email address.
pub struct MagicCodeStore {
    codes: Mutex<HashMap<String, MagicCode>>,
}

/// One issued code and its verification state.
// NOTE(review): the derived `Debug` prints `code` in clear text; avoid logging
// this value with `{:?}`.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Address the code was issued to (also the map key in `MagicCodeStore`).
    pub email: String,
    /// The code itself: six zero-padded decimal digits.
    pub code: String,
    /// Unix seconds at which the code stops being accepted.
    pub expires_at: u64,
    /// Number of failed verification attempts so far.
    pub attempts: u32,
}

/// Failed verifications after which a code is locked until it expires.
const MAX_ATTEMPTS: u32 = 5;

/// Minimum seconds between issuing two codes for the same address while the
/// earlier code is still unexpired.
const CREATE_COOLDOWN_SECS: u64 = 60;

/// Why a magic-code operation was refused.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No pending code exists for this address.
    NotFound,
    /// `MAX_ATTEMPTS` wrong guesses have been made; the code is locked.
    TooManyAttempts,
    /// The submitted code did not match.
    BadCode,
    /// The code's expiry time has passed; the entry is removed when reported.
    Expired,
    /// A code was issued too recently; retry after `retry_after_secs` seconds.
    Throttled { retry_after_secs: u64 },
}

impl Default for MagicCodeStore {
    fn default() -> Self {
        Self::new()
    }
}
693
694impl MagicCodeStore {
695 pub fn new() -> Self {
696 Self {
697 codes: Mutex::new(HashMap::new()),
698 }
699 }
700
701 pub fn create(&self, email: &str) -> String {
704 self.try_create(email).unwrap_or_else(|_| String::new())
707 }
708
709 pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
712 let now = now_secs();
713
714 let mut codes = self.codes.lock().unwrap();
715
716 if let Some(existing) = codes.get(email) {
720 if existing.expires_at > now {
721 let created_at = existing.expires_at.saturating_sub(600);
722 let age = now.saturating_sub(created_at);
723 if age < CREATE_COOLDOWN_SECS {
724 return Err(MagicCodeError::Throttled {
725 retry_after_secs: CREATE_COOLDOWN_SECS - age,
726 });
727 }
728 }
729 }
730
731 let code = generate_magic_code();
732 let mc = MagicCode {
733 email: email.to_string(),
734 code: code.clone(),
735 expires_at: now + 600, attempts: 0,
737 };
738 codes.insert(email.to_string(), mc);
739 Ok(code)
740 }
741
742 pub fn verify(&self, email: &str, code: &str) -> bool {
746 matches!(self.try_verify(email, code), Ok(()))
747 }
748
749 pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
753 let now = now_secs();
754 let mut codes = self.codes.lock().unwrap();
755
756 let mc = match codes.get_mut(email) {
757 Some(m) => m,
758 None => return Err(MagicCodeError::NotFound),
759 };
760
761 if mc.attempts >= MAX_ATTEMPTS {
762 return Err(MagicCodeError::TooManyAttempts);
763 }
764 if mc.expires_at <= now {
765 codes.remove(email);
766 return Err(MagicCodeError::Expired);
767 }
768
769 let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
770 if !ok {
771 mc.attempts += 1;
772 if mc.attempts >= MAX_ATTEMPTS {
774 return Err(MagicCodeError::TooManyAttempts);
775 }
776 return Err(MagicCodeError::BadCode);
777 }
778
779 codes.remove(email);
781 Ok(())
782 }
783}
784
/// Lowercase hex encoding of `bytes`, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
792
793fn generate_magic_code() -> String {
795 use rand::Rng;
796 let mut rng = rand::thread_rng();
797 let code: u32 = rng.gen_range(0..1_000_000);
798 format!("{:06}", code)
799}
800
801fn generate_token() -> String {
803 use rand::Rng;
804 let mut rng = rand::thread_rng();
805 let bytes: [u8; 32] = rng.gen();
806 format!("pylon_{}", hex_encode(&bytes))
807}
808
809use std::collections::HashMap;
814use std::sync::Mutex;
815
/// Persistence hook for [`SessionStore`].
///
/// In this file, `load_all` is read only by `SessionStore::with_backend`.
/// Afterwards lookups are served from the in-memory map, and writes and
/// deletes are mirrored here.
pub trait SessionBackend: Send + Sync {
    /// Return every stored session.
    fn load_all(&self) -> Vec<Session>;
    /// Persist `session`. It is called again after in-place edits, so
    /// presumably this must upsert by token (confirm with implementations).
    fn save(&self, session: &Session);
    /// Delete the stored session with this token.
    fn remove(&self, token: &str);
}

/// Session registry: a mutex-guarded map from token to [`Session`],
/// optionally mirrored to a [`SessionBackend`].
pub struct SessionStore {
    sessions: Mutex<HashMap<String, Session>>,
    backend: Option<Box<dyn SessionBackend>>,
}

impl Default for SessionStore {
    fn default() -> Self {
        Self::new()
    }
}
842
843impl SessionStore {
844 pub fn new() -> Self {
845 Self {
846 sessions: Mutex::new(HashMap::new()),
847 backend: None,
848 }
849 }
850
851 pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
855 let mut map = HashMap::new();
856 for s in backend.load_all() {
857 if !s.is_expired() {
858 map.insert(s.token.clone(), s);
859 }
860 }
861 Self {
862 sessions: Mutex::new(map),
863 backend: Some(backend),
864 }
865 }
866
867 pub fn create(&self, user_id: String) -> Session {
869 let session = Session::new(user_id);
870 let mut sessions = self.sessions.lock().unwrap();
871 sessions.insert(session.token.clone(), session.clone());
872 if let Some(b) = &self.backend {
873 b.save(&session);
874 }
875 session
876 }
877
878 pub fn get(&self, token: &str) -> Option<Session> {
880 let mut sessions = self.sessions.lock().unwrap();
881 match sessions.get(token) {
882 Some(s) if s.is_expired() => {
883 sessions.remove(token);
884 None
885 }
886 Some(s) => Some(s.clone()),
887 None => None,
888 }
889 }
890
891 pub fn resolve(&self, token: Option<&str>) -> AuthContext {
894 match token {
895 Some(t) => match self.get(t) {
896 Some(session) => session.to_auth_context(),
897 None => AuthContext::anonymous(),
898 },
899 None => AuthContext::anonymous(),
900 }
901 }
902
903 pub fn refresh(&self, old_token: &str) -> Option<Session> {
907 let mut sessions = self.sessions.lock().unwrap();
908 let old = sessions.remove(old_token)?;
909 if let Some(b) = &self.backend {
910 b.remove(old_token);
911 }
912 if old.is_expired() {
913 return None;
914 }
915 let mut new = Session::new(old.user_id.clone());
916 new.device = old.device.clone();
917 sessions.insert(new.token.clone(), new.clone());
918 if let Some(b) = &self.backend {
919 b.save(&new);
920 }
921 Some(new)
922 }
923
924 pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
926 let sessions = self.sessions.lock().unwrap();
927 sessions
928 .values()
929 .filter(|s| s.user_id == user_id && !s.is_expired())
930 .cloned()
931 .collect()
932 }
933
934 pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
936 let mut sessions = self.sessions.lock().unwrap();
937 let tokens: Vec<String> = sessions
938 .iter()
939 .filter_map(|(t, s)| {
940 if s.user_id == user_id {
941 Some(t.clone())
942 } else {
943 None
944 }
945 })
946 .collect();
947 let n = tokens.len();
948 for t in &tokens {
949 sessions.remove(t);
950 if let Some(b) = &self.backend {
951 b.remove(t);
952 }
953 }
954 n
955 }
956
957 pub fn sweep_expired(&self) -> usize {
959 let mut sessions = self.sessions.lock().unwrap();
960 let expired: Vec<String> = sessions
961 .iter()
962 .filter_map(|(t, s)| {
963 if s.is_expired() {
964 Some(t.clone())
965 } else {
966 None
967 }
968 })
969 .collect();
970 let n = expired.len();
971 for t in &expired {
972 sessions.remove(t);
973 if let Some(b) = &self.backend {
974 b.remove(t);
975 }
976 }
977 n
978 }
979
980 pub fn set_device(&self, token: &str, device: String) -> bool {
982 let mut sessions = self.sessions.lock().unwrap();
983 if let Some(s) = sessions.get_mut(token) {
984 s.device = Some(device);
985 if let Some(b) = &self.backend {
986 b.save(s);
987 }
988 true
989 } else {
990 false
991 }
992 }
993
994 pub fn create_guest(&self) -> Session {
996 use rand::Rng;
997 let mut rng = rand::thread_rng();
998 let bytes: [u8; 16] = rng.gen();
999 let guest_id = format!("guest_{}", hex_encode(&bytes));
1000 self.create(guest_id)
1001 }
1002
1003 pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1005 let mut sessions = self.sessions.lock().unwrap();
1006 if let Some(session) = sessions.get_mut(token) {
1007 session.user_id = real_user_id;
1008 if let Some(b) = &self.backend {
1009 b.save(session);
1010 }
1011 true
1012 } else {
1013 false
1014 }
1015 }
1016
1017 pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1022 let mut sessions = self.sessions.lock().unwrap();
1023 if let Some(session) = sessions.get_mut(token) {
1024 session.tenant_id = tenant_id;
1025 if let Some(b) = &self.backend {
1026 b.save(session);
1027 }
1028 true
1029 } else {
1030 false
1031 }
1032 }
1033
1034 pub fn revoke(&self, token: &str) -> bool {
1036 let mut sessions = self.sessions.lock().unwrap();
1037 let removed = sessions.remove(token).is_some();
1038 if removed {
1039 if let Some(b) = &self.backend {
1040 b.remove(token);
1041 }
1042 }
1043 removed
1044 }
1045}
1046
// Unit tests for contexts, auth modes, sessions, magic codes, token generation
// and OAuth helpers. None of them performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    // --- AuthContext construction ---

    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }

    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }

    // --- AuthMode ---

    #[test]
    fn auth_mode_public_allows_anonymous() {
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        // "admin" is not a recognised mode string.
        assert_eq!(AuthMode::from_str("admin"), None);
    }

    // --- SessionStore basics ---

    #[test]
    fn session_store_create_and_get() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(!session.token.is_empty());
        assert!(session.token.starts_with("pylon_"));

        let retrieved = store.get(&session.token).unwrap();
        assert_eq!(retrieved.user_id, "user-1");
    }

    #[test]
    fn session_store_resolve() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));

        // No token and an unknown token both resolve to anonymous.
        let anon = store.resolve(None);
        assert!(!anon.is_authenticated());

        let bad = store.resolve(Some("invalid-token"));
        assert!(!bad.is_authenticated());
    }

    #[test]
    fn session_store_revoke() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        assert!(store.revoke(&session.token));
        assert!(store.get(&session.token).is_none());
        // A second revoke of the same token reports that nothing was removed.
        assert!(!store.revoke(&session.token));
    }

    #[test]
    fn session_to_auth_context() {
        let session = Session::new("user-42".into());
        let ctx = session.to_auth_context();
        assert_eq!(ctx.user_id, Some("user-42".into()));
    }

    // --- Admin flag ---

    #[test]
    fn admin_context() {
        let ctx = AuthContext::admin();
        assert!(ctx.is_admin);
        assert!(ctx.is_authenticated());
    }

    #[test]
    fn anonymous_not_admin() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_admin);
    }

    #[test]
    fn authenticated_not_admin() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_admin);
    }

    // --- Magic codes ---

    #[test]
    fn magic_code_create_and_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert_eq!(code.len(), 6);
        assert!(store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_wrong_code_rejected() {
        let store = MagicCodeStore::new();
        store.create("test@example.com");
        // NOTE(review): the generated code could itself be "000000"
        // (1 in 1,000,000 chance), making this test flaky.
        assert!(!store.verify("test@example.com", "000000"));
    }

    #[test]
    fn magic_code_wrong_email_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(!store.verify("other@example.com", &code));
    }

    #[test]
    fn magic_code_consumed_after_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(store.verify("test@example.com", &code));
        // Codes are single use.
        assert!(!store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_different_emails_independent() {
        let store = MagicCodeStore::new();
        let code1 = store.create("alice@example.com");
        let code2 = store.create("bob@example.com");
        assert!(store.verify("alice@example.com", &code1));
        // Consuming Alice's code must not affect Bob's.
        assert!(store.verify("bob@example.com", &code2));
    }

    // --- constant_time_eq ---

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"a", b"b"));
    }

    // --- Token generation ---

    #[test]
    fn generated_tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert!(t1.starts_with("pylon_"));
        assert!(t2.starts_with("pylon_"));
        // "pylon_" prefix (6 chars) + 32 random bytes as hex (64 chars).
        assert_eq!(t1.len(), 6 + 64);
    }

    // --- OAuth registry ---

    #[test]
    fn oauth_registry_empty() {
        let reg = OAuthRegistry::new();
        assert!(reg.get("google").is_none());
    }

    #[test]
    fn oauth_registry_register_and_get() {
        let mut reg = OAuthRegistry::new();
        reg.register(OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/callback".into(),
        });
        let config = reg.get("google").unwrap();
        assert_eq!(config.client_id, "test-id");
        assert!(config.auth_url().contains("accounts.google.com"));
    }

    // --- Guest sessions ---

    #[test]
    fn guest_session() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));
        assert!(!session.token.is_empty());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert!(ctx.user_id.unwrap().starts_with("guest_"));
    }

    #[test]
    fn upgrade_guest_to_real_user() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));

        let upgraded = store.upgrade(&session.token, "real-user-123".into());
        assert!(upgraded);

        // The same token now resolves to the real user.
        let ctx = store.resolve(Some(&session.token));
        assert_eq!(ctx.user_id, Some("real-user-123".into()));
    }

    #[test]
    fn upgrade_invalid_token_fails() {
        let store = SessionStore::new();
        let upgraded = store.upgrade("nonexistent-token", "user".into());
        assert!(!upgraded);
    }

    #[test]
    fn guest_context() {
        let ctx = AuthContext::guest("guest_123".into());
        assert!(ctx.is_authenticated());
        assert!(!ctx.is_admin);
        assert_eq!(ctx.user_id, Some("guest_123".into()));
    }

    // --- OAuth URLs ---

    #[test]
    fn oauth_token_urls() {
        let google = OAuthConfig {
            provider: "google".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
        let github = OAuthConfig {
            provider: "github".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(
            github.token_url(),
            "https://github.com/login/oauth/access_token"
        );
        // Unknown providers yield empty URLs rather than an error.
        let unknown = OAuthConfig {
            provider: "unknown".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(unknown.token_url(), "");
        assert!(unknown.auth_url().is_empty());
    }

    #[test]
    fn oauth_auth_url_github() {
        let config = OAuthConfig {
            provider: "github".into(),
            client_id: "gh-id".into(),
            client_secret: "gh-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        assert!(config.auth_url().contains("github.com"));
        assert!(config.auth_url().contains("gh-id"));
    }

    #[test]
    fn oauth_auth_url_with_state() {
        let config = OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        let url = config.auth_url_with_state("random_state_123");
        assert!(url.contains("&state=random_state_123"));
    }

    // --- OAuth state store ---

    #[test]
    fn oauth_state_store_create_and_validate() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(store.validate(&state, "google"));
        // State tokens are single use.
        assert!(!store.validate(&state, "google"));
    }

    #[test]
    fn oauth_state_store_wrong_provider_rejected() {
        let store = OAuthStateStore::new();
        let state = store.create("google");
        assert!(!store.validate(&state, "github"));
    }

    #[test]
    fn oauth_state_store_invalid_state_rejected() {
        let store = OAuthStateStore::new();
        assert!(!store.validate("nonexistent", "google"));
    }
}