1pub mod cookie;
2pub mod email;
3pub mod password;
4
5pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
6
7use serde::{Deserialize, Serialize};
8
/// Identity and authorization facts for the caller of a request.
///
/// Built with the constructors on `AuthContext` (`anonymous`, `authenticated`,
/// `guest`, `admin`, `user`) or from a `Session` via `Session::to_auth_context`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
    /// The caller's user id; `None` for anonymous callers.
    pub user_id: Option<String>,
    /// Admins pass every `has_role` / `has_any_role` check.
    pub is_admin: bool,
    /// Guests carry a `user_id` but are NOT counted by `is_authenticated`.
    /// Omitted from serialized output when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub is_guest: bool,
    /// Role names granted to this caller.
    pub roles: Vec<String>,
    /// Currently selected tenant; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
44
/// Serde `skip_serializing_if` predicate: `true` when the flag is unset,
/// so `false` booleans are left out of serialized output.
fn is_false(flag: &bool) -> bool {
    !*flag
}
48
49impl AuthContext {
50 pub fn anonymous() -> Self {
52 Self {
53 user_id: None,
54 is_admin: false,
55 is_guest: false,
56 roles: Vec::new(),
57 tenant_id: None,
58 }
59 }
60
61 pub fn authenticated(user_id: String) -> Self {
63 Self {
64 user_id: Some(user_id),
65 is_admin: false,
66 is_guest: false,
67 roles: Vec::new(),
68 tenant_id: None,
69 }
70 }
71
72 pub fn guest(guest_id: String) -> Self {
77 Self {
78 user_id: Some(guest_id),
79 is_admin: false,
80 is_guest: true,
81 roles: Vec::new(),
82 tenant_id: None,
83 }
84 }
85
86 pub fn admin() -> Self {
88 Self {
89 user_id: Some("__admin__".into()),
90 is_admin: true,
91 is_guest: false,
92 roles: vec!["admin".into()],
93 tenant_id: None,
94 }
95 }
96
97 pub fn user(user_id: String) -> Self {
99 Self::authenticated(user_id)
100 }
101
102 pub fn tenant_id(&self) -> Option<&str> {
104 self.tenant_id.as_deref()
105 }
106
107 pub fn with_tenant(mut self, tenant_id: String) -> Self {
109 self.tenant_id = Some(tenant_id);
110 self
111 }
112
113 pub fn is_authenticated(&self) -> bool {
117 self.user_id.is_some() && !self.is_guest
118 }
119
120 pub fn has_role(&self, role: &str) -> bool {
122 self.is_admin || self.roles.iter().any(|r| r == role)
123 }
124
125 pub fn has_any_role(&self, roles: &[&str]) -> bool {
127 self.is_admin || roles.iter().any(|r| self.has_role(r))
128 }
129
130 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
132 self.roles = roles;
133 self
134 }
135}
136
/// Compares two byte strings without short-circuiting on the first differing byte.
///
/// Slices of different lengths are rejected immediately, so the length is not
/// hidden. Equal-length inputs are fully scanned: every byte pair is XOR-ed
/// into one accumulator that is tested only at the end.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
156
157#[derive(Debug, Clone, PartialEq, Eq)]
163pub enum AuthMode {
164 Public,
166 User,
168}
169
170impl AuthMode {
171 #[allow(clippy::should_implement_trait)]
173 pub fn from_str(s: &str) -> Option<Self> {
174 match s {
175 "public" => Some(AuthMode::Public),
176 "user" => Some(AuthMode::User),
177 _ => None,
178 }
179 }
180
181 pub fn check(&self, ctx: &AuthContext) -> bool {
183 match self {
184 AuthMode::Public => true,
185 AuthMode::User => ctx.is_authenticated(),
186 }
187 }
188}
189
/// A server-side login session identified by an opaque bearer token.
// NOTE(review): the derived `Debug` prints `token`; avoid logging sessions with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Opaque session token (`generate_token` format when created in this module).
    pub token: String,
    /// Owner of the session.
    pub user_id: String,
    /// Expiry in Unix seconds; `0` means "never expires" (see `is_expired`).
    /// Deserializes to `0` when the field is absent.
    #[serde(default)]
    pub expires_at: u64,
    /// Optional device label, set via `SessionStore::set_device`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Creation time in Unix seconds; `0` when absent from serialized data.
    #[serde(default)]
    pub created_at: u64,
    /// Selected tenant, copied into the context built by `to_auth_context`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
214
215impl Session {
216 pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
218
219 pub fn new(user_id: String) -> Self {
221 let now = now_secs();
222 Self {
223 token: generate_token(),
224 user_id,
225 expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
226 device: None,
227 created_at: now,
228 tenant_id: None,
229 }
230 }
231
232 pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
234 let now = now_secs();
235 Self {
236 token: generate_token(),
237 user_id,
238 expires_at: if lifetime_secs == 0 {
239 0
240 } else {
241 now.saturating_add(lifetime_secs)
242 },
243 device: None,
244 created_at: now,
245 tenant_id: None,
246 }
247 }
248
249 pub fn to_auth_context(&self) -> AuthContext {
252 let ctx = AuthContext::authenticated(self.user_id.clone());
253 match &self.tenant_id {
254 Some(t) => ctx.with_tenant(t.clone()),
255 None => ctx,
256 }
257 }
258
259 pub fn is_expired(&self) -> bool {
263 self.expires_at != 0 && now_secs() >= self.expires_at
264 }
265}
266
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct OAuthConfig {
281 pub provider: String,
282 pub client_id: String,
283 pub client_secret: String,
284 pub redirect_uri: String,
285}
286
287impl OAuthConfig {
288 pub fn auth_url(&self) -> String {
294 match self.provider.as_str() {
295 "google" => format!(
296 "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
297 self.client_id, self.redirect_uri
298 ),
299 "github" => format!(
300 "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
301 self.client_id, self.redirect_uri
302 ),
303 _ => String::new(),
304 }
305 }
306
307 pub fn auth_url_with_state(&self, state: &str) -> String {
309 let base = self.auth_url();
310 if base.is_empty() {
311 return base;
312 }
313 format!("{}&state={}", base, state)
314 }
315
316 pub fn token_url(&self) -> &str {
318 match self.provider.as_str() {
319 "google" => "https://oauth2.googleapis.com/token",
320 "github" => "https://github.com/login/oauth/access_token",
321 _ => "",
322 }
323 }
324
325 pub fn userinfo_url(&self) -> &str {
327 match self.provider.as_str() {
328 "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
329 "github" => "https://api.github.com/user",
330 _ => "",
331 }
332 }
333
334 pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
341 let body = match self.provider.as_str() {
342 "google" => format!(
343 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
344 url_encode(&self.client_id),
345 url_encode(&self.client_secret),
346 url_encode(&self.redirect_uri)
347 ),
348 "github" => format!(
349 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
350 url_encode(&self.client_id),
351 url_encode(&self.client_secret),
352 url_encode(&self.redirect_uri)
353 ),
354 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
355 };
356
357 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
358 parse_token_response(&out)
359 }
360
361 pub fn exchange_code(&self, code: &str) -> Result<String, String> {
365 let body = match self.provider.as_str() {
366 "google" => format!(
367 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
368 url_encode(&self.client_id),
369 url_encode(&self.client_secret),
370 url_encode(&self.redirect_uri)
371 ),
372 "github" => format!(
373 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
374 url_encode(&self.client_id),
375 url_encode(&self.client_secret),
376 url_encode(&self.redirect_uri)
377 ),
378 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
379 };
380
381 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
382 extract_access_token(&out)
383 }
384
385 pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
391 let info = self.fetch_userinfo_full(access_token)?;
392 Ok((info.email, info.name))
393 }
394
395 pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
401 let out = http_get_bearer(self.userinfo_url(), access_token)?;
402 let parsed: serde_json::Value =
403 serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
404 match self.provider.as_str() {
405 "google" => {
406 let email = parsed
407 .get("email")
408 .and_then(|v| v.as_str())
409 .ok_or("no email in userinfo")?
410 .to_string();
411 let name = parsed
412 .get("name")
413 .and_then(|v| v.as_str())
414 .map(String::from);
415 let provider_account_id = parsed
416 .get("sub")
417 .and_then(|v| v.as_str())
418 .ok_or("no sub in userinfo")?
419 .to_string();
420 Ok(UserInfo {
421 provider: self.provider.clone(),
422 provider_account_id,
423 email,
424 name,
425 })
426 }
427 "github" => {
428 let name = parsed
429 .get("name")
430 .and_then(|v| v.as_str())
431 .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
432 .map(String::from);
433 let email = parsed
434 .get("email")
435 .and_then(|v| v.as_str())
436 .map(String::from);
437 let email = email
440 .or_else(|| fetch_github_primary_email(access_token).ok())
441 .ok_or("no accessible email on GitHub account")?;
442 let provider_account_id = parsed
445 .get("id")
446 .map(|v| {
447 v.as_i64()
448 .map(|n| n.to_string())
449 .or_else(|| v.as_str().map(String::from))
450 .unwrap_or_default()
451 })
452 .filter(|s| !s.is_empty())
453 .ok_or("no id in userinfo")?;
454 Ok(UserInfo {
455 provider: self.provider.clone(),
456 provider_account_id,
457 email,
458 name,
459 })
460 }
461 _ => Err(format!("unknown provider: {}", self.provider)),
462 }
463 }
464}
465
/// Normalised identity produced by `OAuthConfig::fetch_userinfo_full`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Provider key, e.g. `"google"` or `"github"`.
    pub provider: String,
    /// Stable account id at the provider (Google `sub`, GitHub `id`).
    pub provider_account_id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Display name if available (GitHub falls back to `login`).
    pub name: Option<String>,
}
477
/// Tokens returned by a provider's token endpoint (built by `parse_token_response`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSet {
    /// Bearer token for provider API calls.
    pub access_token: String,
    /// Refresh token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// OpenID Connect ID token, if present.
    pub id_token: Option<String>,
    /// Absolute access-token expiry in Unix seconds, computed as now + `expires_in`.
    pub expires_at: Option<u64>,
    /// Granted scope string, as returned by the provider.
    pub scope: Option<String>,
}
492
493fn parse_token_response(body: &str) -> Result<TokenSet, String> {
494 let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
497 let mut map = serde_json::Map::new();
499 for pair in body.split('&') {
500 if let Some((k, v)) = pair.split_once('=') {
501 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
502 }
503 }
504 serde_json::Value::Object(map)
505 });
506
507 let access_token = json
508 .get("access_token")
509 .and_then(|v| v.as_str())
510 .ok_or_else(|| format!("no access_token in token response: {body}"))?
511 .to_string();
512 let refresh_token = json
513 .get("refresh_token")
514 .and_then(|v| v.as_str())
515 .map(String::from);
516 let id_token = json
517 .get("id_token")
518 .and_then(|v| v.as_str())
519 .map(String::from);
520 let expires_at = json
521 .get("expires_in")
522 .and_then(|v| {
523 v.as_u64()
524 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
525 })
526 .map(|secs| now_secs().saturating_add(secs));
527 let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
528 Ok(TokenSet {
529 access_token,
530 refresh_token,
531 id_token,
532 expires_at,
533 scope,
534 })
535}
536
/// Percent-encodes `s` for use in a URL query or a form body.
///
/// Only RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through.
/// Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
549
550const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
554
555fn ureq_agent() -> ureq::Agent {
556 ureq::AgentBuilder::new()
557 .timeout_connect(HTTP_TIMEOUT)
558 .timeout_read(HTTP_TIMEOUT)
559 .timeout_write(HTTP_TIMEOUT)
560 .user_agent("pylon/0.1")
561 .build()
562}
563
564fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
565 let agent = ureq_agent();
566 let mut req = agent
567 .post(url)
568 .set("Content-Type", "application/x-www-form-urlencoded");
569 if accept_json {
570 req = req.set("Accept", "application/json");
571 }
572 match req.send_string(body) {
573 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
574 Err(ureq::Error::Status(code, resp)) => {
575 let body = resp.into_string().unwrap_or_default();
576 Err(format!("HTTP {code}: {body}"))
577 }
578 Err(e) => Err(format!("HTTP error: {e}")),
579 }
580}
581
582fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
583 let agent = ureq_agent();
584 match agent
585 .get(url)
586 .set("Authorization", &format!("Bearer {token}"))
587 .set("Accept", "application/json")
588 .call()
589 {
590 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
591 Err(ureq::Error::Status(code, resp)) => {
592 let body = resp.into_string().unwrap_or_default();
593 Err(format!("HTTP {code}: {body}"))
594 }
595 Err(e) => Err(format!("HTTP error: {e}")),
596 }
597}
598
599fn fetch_github_primary_email(token: &str) -> Result<String, String> {
600 let out = http_get_bearer("https://api.github.com/user/emails", token)?;
601 let emails: serde_json::Value =
602 serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
603 emails
604 .as_array()
605 .and_then(|arr| {
606 arr.iter()
607 .find(|e| {
608 e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
609 && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
610 })
611 .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
612 })
613 .ok_or_else(|| "no primary verified email on GitHub".into())
614}
615
616fn extract_access_token(body: &str) -> Result<String, String> {
617 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
618 if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
619 return Ok(t.to_string());
620 }
621 }
622 for pair in body.split('&') {
624 if let Some(val) = pair.strip_prefix("access_token=") {
625 return Ok(val.to_string());
626 }
627 }
628 Err(format!("no access_token in token response: {body}"))
629}
630
631pub struct OAuthRegistry {
633 providers: std::collections::HashMap<String, OAuthConfig>,
634}
635
636impl Default for OAuthRegistry {
637 fn default() -> Self {
638 Self::new()
639 }
640}
641
642impl OAuthRegistry {
643 pub fn new() -> Self {
644 Self {
645 providers: std::collections::HashMap::new(),
646 }
647 }
648
649 pub fn register(&mut self, config: OAuthConfig) {
650 self.providers.insert(config.provider.clone(), config);
651 }
652
653 pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
654 self.providers.get(provider)
655 }
656
657 pub fn from_env() -> Self {
660 let mut reg = Self::new();
661
662 if let (Ok(id), Ok(secret)) = (
664 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
665 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
666 ) {
667 reg.register(OAuthConfig {
668 provider: "google".into(),
669 client_id: id,
670 client_secret: secret,
671 redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
672 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
673 });
674 }
675
676 if let (Ok(id), Ok(secret)) = (
678 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
679 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
680 ) {
681 reg.register(OAuthConfig {
682 provider: "github".into(),
683 client_id: id,
684 client_secret: secret,
685 redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
686 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
687 });
688 }
689
690 reg
691 }
692}
693
/// Server-side record behind an OAuth `state` parameter.
///
/// Created by `OAuthStateStore::create` and consumed once by `OAuthStateStore::validate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
    /// Provider the flow was started for; `validate` rejects a mismatch.
    pub provider: String,
    /// URL supplied at `create` time for the success path (presumably the
    /// post-login redirect; confirm with the callback handler).
    pub callback_url: String,
    /// URL supplied at `create` time for the failure path.
    pub error_callback_url: String,
    /// Expiry in Unix seconds (`create` sets it 600 s ahead).
    pub expires_at: u64,
}
716
/// Storage for pending OAuth states, keyed by the opaque state token.
pub trait OAuthStateBackend: Send + Sync {
    /// Stores `state` under `token`.
    fn put(&self, token: &str, state: &OAuthState);
    /// Removes and returns the state for `token`. Must return `None` if it is
    /// missing or its `expires_at` is at or before `now_unix_secs`; the
    /// in-memory backend removes the entry in both cases.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
}
730
731pub struct InMemoryOAuthBackend {
733 states: Mutex<HashMap<String, OAuthState>>,
734}
735
736impl InMemoryOAuthBackend {
737 pub fn new() -> Self {
738 Self {
739 states: Mutex::new(HashMap::new()),
740 }
741 }
742}
743
744impl Default for InMemoryOAuthBackend {
745 fn default() -> Self {
746 Self::new()
747 }
748}
749
750impl OAuthStateBackend for InMemoryOAuthBackend {
751 fn put(&self, token: &str, state: &OAuthState) {
752 self.states
753 .lock()
754 .unwrap()
755 .insert(token.to_string(), state.clone());
756 }
757 fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
758 let mut s = self.states.lock().unwrap();
759 let entry = s.remove(token)?;
760 if entry.expires_at <= now_unix_secs {
761 return None;
762 }
763 Some(entry)
764 }
765}
766
767pub struct OAuthStateStore {
774 backend: Box<dyn OAuthStateBackend>,
775}
776
777impl Default for OAuthStateStore {
778 fn default() -> Self {
779 Self::new()
780 }
781}
782
783impl OAuthStateStore {
784 pub fn new() -> Self {
785 Self {
786 backend: Box::new(InMemoryOAuthBackend::new()),
787 }
788 }
789
790 pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
791 Self { backend }
792 }
793
794 pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
802 use std::time::{SystemTime, UNIX_EPOCH};
803 let token = generate_token();
804 let now = SystemTime::now()
805 .duration_since(UNIX_EPOCH)
806 .unwrap_or_default()
807 .as_secs();
808 let state = OAuthState {
809 provider: provider.to_string(),
810 callback_url: callback_url.to_string(),
811 error_callback_url: error_callback_url.to_string(),
812 expires_at: now + 600,
813 };
814 self.backend.put(&token, &state);
815 token
816 }
817
818 pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
823 use std::time::{SystemTime, UNIX_EPOCH};
824 let now = SystemTime::now()
825 .duration_since(UNIX_EPOCH)
826 .unwrap_or_default()
827 .as_secs();
828 let entry = self.backend.take(state, now)?;
829 if entry.provider != expected_provider {
830 return None;
831 }
832 Some(entry)
833 }
834}
835
/// Accepts `url` only if it is an absolute http(s) URL whose origin appears in
/// `trusted_origins`.
///
/// The origin is computed by `origin_of` and compared as an exact string.
///
/// # Errors
/// - [`TrustedOriginError::Empty`]: `url` is `""`.
/// - [`TrustedOriginError::NotHttp`]: no `http://` / `https://` prefix.
/// - [`TrustedOriginError::NotTrusted`]: the origin is not listed.
pub fn validate_trusted_redirect(
    url: &str,
    trusted_origins: &[String],
) -> Result<(), TrustedOriginError> {
    if url.is_empty() {
        return Err(TrustedOriginError::Empty);
    }
    let has_http_scheme = ["http://", "https://"]
        .iter()
        .any(|scheme| url.starts_with(*scheme));
    if !has_http_scheme {
        return Err(TrustedOriginError::NotHttp);
    }
    let origin = origin_of(url);
    if trusted_origins.contains(&origin) {
        Ok(())
    } else {
        Err(TrustedOriginError::NotTrusted { origin })
    }
}

/// Reason `validate_trusted_redirect` rejected a URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The URL was empty.
    Empty,
    /// The URL did not start with `http://` or `https://`.
    NotHttp,
    /// The URL's origin (carried here) is not in the trusted list.
    NotTrusted { origin: String },
}

impl std::fmt::Display for TrustedOriginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Empty => f.write_str("redirect URL is empty"),
            Self::NotHttp => f.write_str("redirect URL must use http:// or https:// scheme"),
            Self::NotTrusted { origin } => write!(
                f,
                "redirect origin {:?} is not in PYLON_TRUSTED_ORIGINS",
                origin
            ),
        }
    }
}

/// Returns the `scheme://authority` prefix of `url`.
///
/// The authority ends at the first `/`, `?` or `#` after `://`. Input with no
/// `://` is returned with trailing `/` characters removed.
pub fn origin_of(url: &str) -> String {
    match url.find("://") {
        None => url.trim_end_matches('/').to_string(),
        Some(idx) => {
            let authority_start = idx + "://".len();
            let rest = &url[authority_start..];
            // `split` always yields at least one item; the first is the authority.
            let authority = rest.split(&['/', '?', '#'][..]).next().unwrap_or(rest);
            format!("{}{}", &url[..authority_start], authority)
        }
    }
}
910
/// Persistence hook for magic sign-in codes, keyed by email address.
///
/// `MagicCodeStore` serves reads from its own cache and mirrors every mutation here.
pub trait MagicCodeBackend: Send + Sync {
    /// Stores or replaces the code for `email`.
    fn put(&self, email: &str, code: &MagicCode);
    /// Returns the stored code for `email`, if any.
    fn get(&self, email: &str) -> Option<MagicCode>;
    /// Deletes the code for `email`.
    fn remove(&self, email: &str);
    /// Records one failed verification attempt for `email`.
    fn bump_attempts(&self, email: &str);
    /// Returns every stored code; used by `MagicCodeStore::with_backend` to warm its cache.
    fn load_all(&self) -> Vec<MagicCode>;
}
939
940pub struct InMemoryMagicCodeBackend {
943 codes: Mutex<HashMap<String, MagicCode>>,
944}
945
946impl InMemoryMagicCodeBackend {
947 pub fn new() -> Self {
948 Self {
949 codes: Mutex::new(HashMap::new()),
950 }
951 }
952}
953
954impl Default for InMemoryMagicCodeBackend {
955 fn default() -> Self {
956 Self::new()
957 }
958}
959
960impl MagicCodeBackend for InMemoryMagicCodeBackend {
961 fn put(&self, email: &str, code: &MagicCode) {
962 self.codes
963 .lock()
964 .unwrap()
965 .insert(email.to_string(), code.clone());
966 }
967 fn get(&self, email: &str) -> Option<MagicCode> {
968 self.codes.lock().unwrap().get(email).cloned()
969 }
970 fn remove(&self, email: &str) {
971 self.codes.lock().unwrap().remove(email);
972 }
973 fn bump_attempts(&self, email: &str) {
974 if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
975 c.attempts = c.attempts.saturating_add(1);
976 }
977 }
978 fn load_all(&self) -> Vec<MagicCode> {
979 self.codes.lock().unwrap().values().cloned().collect()
980 }
981}
982
/// Issues and verifies short numeric sign-in codes, at most one per email.
pub struct MagicCodeStore {
    /// In-process view consulted by `try_create` / `try_verify`; keyed by email.
    cache: Mutex<HashMap<String, MagicCode>>,
    /// Persistence hook; every cache mutation is mirrored to it.
    backend: Box<dyn MagicCodeBackend>,
}
991
/// A pending sign-in code for one email address.
// NOTE(review): the derived `Debug` prints the code itself; avoid logging with `{:?}`.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Email the code was issued for (also the store key).
    pub email: String,
    /// Six-digit, zero-padded numeric code.
    pub code: String,
    /// Expiry in Unix seconds (`try_create` sets it 600 s ahead).
    pub expires_at: u64,
    /// Failed verification attempts so far.
    pub attempts: u32,
}
1001
/// Failed verifications after which a code is locked (`MagicCodeError::TooManyAttempts`).
const MAX_ATTEMPTS: u32 = 5;

/// Minimum age, in seconds, of a still-live code before another may be issued for the same email.
const CREATE_COOLDOWN_SECS: u64 = 60;
1010
/// Failure reasons from `MagicCodeStore::try_create` / `MagicCodeStore::try_verify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is pending for this email.
    NotFound,
    /// The code has accumulated `MAX_ATTEMPTS` failed verifications.
    TooManyAttempts,
    /// The submitted code did not match.
    BadCode,
    /// The code had expired; it has been discarded.
    Expired,
    /// A live code was issued less than `CREATE_COOLDOWN_SECS` ago.
    Throttled { retry_after_secs: u64 },
}
1024
1025impl Default for MagicCodeStore {
1026 fn default() -> Self {
1027 Self::new()
1028 }
1029}
1030
1031impl MagicCodeStore {
1032 pub fn new() -> Self {
1033 Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1034 }
1035
1036 pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1041 let now = now_secs();
1042 let mut cache = HashMap::new();
1043 for c in backend.load_all() {
1044 if c.expires_at > now {
1045 cache.insert(c.email.clone(), c);
1046 }
1047 }
1048 Self {
1049 cache: Mutex::new(cache),
1050 backend,
1051 }
1052 }
1053
1054 pub fn create(&self, email: &str) -> String {
1057 self.try_create(email).unwrap_or_else(|_| String::new())
1060 }
1061
1062 pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1065 let now = now_secs();
1066
1067 let mut codes = self.cache.lock().unwrap();
1068
1069 if let Some(existing) = codes.get(email) {
1073 if existing.expires_at > now {
1074 let created_at = existing.expires_at.saturating_sub(600);
1075 let age = now.saturating_sub(created_at);
1076 if age < CREATE_COOLDOWN_SECS {
1077 return Err(MagicCodeError::Throttled {
1078 retry_after_secs: CREATE_COOLDOWN_SECS - age,
1079 });
1080 }
1081 }
1082 }
1083
1084 let code = generate_magic_code();
1085 let mc = MagicCode {
1086 email: email.to_string(),
1087 code: code.clone(),
1088 expires_at: now + 600, attempts: 0,
1090 };
1091 codes.insert(email.to_string(), mc.clone());
1092 self.backend.put(email, &mc);
1096 Ok(code)
1097 }
1098
1099 pub fn verify(&self, email: &str, code: &str) -> bool {
1103 matches!(self.try_verify(email, code), Ok(()))
1104 }
1105
1106 pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1113 self.cache
1114 .lock()
1115 .map(|m| m.values().cloned().collect())
1116 .unwrap_or_default()
1117 }
1118
1119 pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1120 let now = now_secs();
1121 let mut codes = self.cache.lock().unwrap();
1122
1123 let mc = match codes.get_mut(email) {
1124 Some(m) => m,
1125 None => return Err(MagicCodeError::NotFound),
1126 };
1127
1128 if mc.attempts >= MAX_ATTEMPTS {
1129 return Err(MagicCodeError::TooManyAttempts);
1130 }
1131 if mc.expires_at <= now {
1132 codes.remove(email);
1133 self.backend.remove(email);
1134 return Err(MagicCodeError::Expired);
1135 }
1136
1137 let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1138 if !ok {
1139 mc.attempts += 1;
1140 self.backend.bump_attempts(email);
1141 if mc.attempts >= MAX_ATTEMPTS {
1143 return Err(MagicCodeError::TooManyAttempts);
1144 }
1145 return Err(MagicCodeError::BadCode);
1146 }
1147
1148 codes.remove(email);
1150 self.backend.remove(email);
1151 Ok(())
1152 }
1153}
1154
/// Lowercase hexadecimal encoding of `bytes`, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    out
}
1162
1163fn generate_magic_code() -> String {
1165 use rand::Rng;
1166 let mut rng = rand::thread_rng();
1167 let code: u32 = rng.gen_range(0..1_000_000);
1168 format!("{:06}", code)
1169}
1170
1171fn generate_token() -> String {
1173 use rand::Rng;
1174 let mut rng = rand::thread_rng();
1175 let bytes: [u8; 32] = rng.gen();
1176 format!("pylon_{}", hex_encode(&bytes))
1177}
1178
1179use std::collections::HashMap;
1184use std::sync::Mutex;
1185
/// Persistence hook for `SessionStore`.
///
/// The store keeps its own in-memory map and mirrors changes here.
pub trait SessionBackend: Send + Sync {
    /// All persisted sessions; `SessionStore::with_backend` keeps the unexpired ones.
    fn load_all(&self) -> Vec<Session>;
    /// Inserts or overwrites a session (on create, refresh and field updates).
    fn save(&self, session: &Session);
    /// Deletes the session with this token.
    fn remove(&self, token: &str);
}
1194
/// Token → `Session` map with optional write-through persistence.
pub struct SessionStore {
    /// Live sessions keyed by token.
    sessions: Mutex<HashMap<String, Session>>,
    /// When set, mutations are mirrored to this backend.
    backend: Option<Box<dyn SessionBackend>>,
    /// Lifetime in seconds for sessions this store creates or refreshes; `0` = never expire.
    default_lifetime_secs: u64,
}
1210
1211impl Default for SessionStore {
1212 fn default() -> Self {
1213 Self::new()
1214 }
1215}
1216
1217impl SessionStore {
1218 pub fn new() -> Self {
1219 Self {
1220 sessions: Mutex::new(HashMap::new()),
1221 backend: None,
1222 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1223 }
1224 }
1225
1226 pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
1229 self.default_lifetime_secs = lifetime_secs;
1230 self
1231 }
1232
1233 pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1237 let mut map = HashMap::new();
1238 for s in backend.load_all() {
1239 if !s.is_expired() {
1240 map.insert(s.token.clone(), s);
1241 }
1242 }
1243 Self {
1244 sessions: Mutex::new(map),
1245 backend: Some(backend),
1246 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1247 }
1248 }
1249
1250 pub fn create(&self, user_id: String) -> Session {
1254 let session = Session::with_lifetime(user_id, self.default_lifetime_secs);
1255 let mut sessions = self.sessions.lock().unwrap();
1256 sessions.insert(session.token.clone(), session.clone());
1257 if let Some(b) = &self.backend {
1258 b.save(&session);
1259 }
1260 session
1261 }
1262
1263 pub fn get(&self, token: &str) -> Option<Session> {
1265 let mut sessions = self.sessions.lock().unwrap();
1266 match sessions.get(token) {
1267 Some(s) if s.is_expired() => {
1268 sessions.remove(token);
1269 None
1270 }
1271 Some(s) => Some(s.clone()),
1272 None => None,
1273 }
1274 }
1275
1276 pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1279 match token {
1280 Some(t) => match self.get(t) {
1281 Some(session) => session.to_auth_context(),
1282 None => AuthContext::anonymous(),
1283 },
1284 None => AuthContext::anonymous(),
1285 }
1286 }
1287
1288 pub fn refresh(&self, old_token: &str) -> Option<Session> {
1292 let mut sessions = self.sessions.lock().unwrap();
1293 let old = sessions.remove(old_token)?;
1294 if let Some(b) = &self.backend {
1295 b.remove(old_token);
1296 }
1297 if old.is_expired() {
1298 return None;
1299 }
1300 let mut new = Session::with_lifetime(old.user_id.clone(), self.default_lifetime_secs);
1306 new.device = old.device.clone();
1307 sessions.insert(new.token.clone(), new.clone());
1308 if let Some(b) = &self.backend {
1309 b.save(&new);
1310 }
1311 Some(new)
1312 }
1313
1314 pub fn list_all_unfiltered(&self) -> Vec<Session> {
1319 self.sessions
1320 .lock()
1321 .map(|m| m.values().cloned().collect())
1322 .unwrap_or_default()
1323 }
1324
1325 pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
1327 let sessions = self.sessions.lock().unwrap();
1328 sessions
1329 .values()
1330 .filter(|s| s.user_id == user_id && !s.is_expired())
1331 .cloned()
1332 .collect()
1333 }
1334
1335 pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
1337 let mut sessions = self.sessions.lock().unwrap();
1338 let tokens: Vec<String> = sessions
1339 .iter()
1340 .filter_map(|(t, s)| {
1341 if s.user_id == user_id {
1342 Some(t.clone())
1343 } else {
1344 None
1345 }
1346 })
1347 .collect();
1348 let n = tokens.len();
1349 for t in &tokens {
1350 sessions.remove(t);
1351 if let Some(b) = &self.backend {
1352 b.remove(t);
1353 }
1354 }
1355 n
1356 }
1357
1358 pub fn sweep_expired(&self) -> usize {
1360 let mut sessions = self.sessions.lock().unwrap();
1361 let expired: Vec<String> = sessions
1362 .iter()
1363 .filter_map(|(t, s)| {
1364 if s.is_expired() {
1365 Some(t.clone())
1366 } else {
1367 None
1368 }
1369 })
1370 .collect();
1371 let n = expired.len();
1372 for t in &expired {
1373 sessions.remove(t);
1374 if let Some(b) = &self.backend {
1375 b.remove(t);
1376 }
1377 }
1378 n
1379 }
1380
1381 pub fn set_device(&self, token: &str, device: String) -> bool {
1383 let mut sessions = self.sessions.lock().unwrap();
1384 if let Some(s) = sessions.get_mut(token) {
1385 s.device = Some(device);
1386 if let Some(b) = &self.backend {
1387 b.save(s);
1388 }
1389 true
1390 } else {
1391 false
1392 }
1393 }
1394
1395 pub fn create_guest(&self) -> Session {
1397 use rand::Rng;
1398 let mut rng = rand::thread_rng();
1399 let bytes: [u8; 16] = rng.gen();
1400 let guest_id = format!("guest_{}", hex_encode(&bytes));
1401 self.create(guest_id)
1402 }
1403
1404 pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1406 let mut sessions = self.sessions.lock().unwrap();
1407 if let Some(session) = sessions.get_mut(token) {
1408 session.user_id = real_user_id;
1409 if let Some(b) = &self.backend {
1410 b.save(session);
1411 }
1412 true
1413 } else {
1414 false
1415 }
1416 }
1417
1418 pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1423 let mut sessions = self.sessions.lock().unwrap();
1424 if let Some(session) = sessions.get_mut(token) {
1425 session.tenant_id = tenant_id;
1426 if let Some(b) = &self.backend {
1427 b.save(session);
1428 }
1429 true
1430 } else {
1431 false
1432 }
1433 }
1434
1435 pub fn revoke(&self, token: &str) -> bool {
1437 let mut sessions = self.sessions.lock().unwrap();
1438 let removed = sessions.remove(token).is_some();
1439 if removed {
1440 if let Some(b) = &self.backend {
1441 b.remove(token);
1442 }
1443 }
1444 removed
1445 }
1446}
1447
/// A credential linked to a local user.
///
/// `Account::new` builds the OAuth form from provider identity plus tokens.
/// The `password` field suggests a credentials variant too (confirm against callers).
// NOTE(review): the derived `Debug` prints tokens and `password`; avoid logging with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Account {
    /// Opaque row id (`generate_token` format when built by `Account::new`).
    pub id: String,
    /// Local user this account belongs to.
    pub user_id: String,
    /// Provider key, e.g. `"google"` / `"github"`.
    pub provider_id: String,
    /// The user's id at the provider. Together with `provider_id` it is the
    /// in-memory backend's lookup key.
    pub account_id: String,
    /// Latest provider access token.
    pub access_token: Option<String>,
    /// Provider refresh token, if issued.
    pub refresh_token: Option<String>,
    /// OpenID Connect ID token, if issued.
    pub id_token: Option<String>,
    /// Access-token expiry in Unix seconds. `None` is treated as non-expiring
    /// by `access_token_expired`.
    pub access_token_expires_at: Option<u64>,
    /// Refresh-token expiry in Unix seconds; `Account::new` leaves it `None`.
    pub refresh_token_expires_at: Option<u64>,
    /// Granted scope string.
    pub scope: Option<String>,
    /// NOTE(review): semantics not visible here (possibly a password hash); confirm.
    pub password: Option<String>,
    /// Creation time in Unix seconds.
    pub created_at: u64,
    /// Last-update time in Unix seconds.
    pub updated_at: u64,
}
1503
1504impl Account {
1505 pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
1509 let now = now_secs();
1510 Self {
1511 id: generate_token(),
1512 user_id,
1513 provider_id: info.provider.clone(),
1514 account_id: info.provider_account_id.clone(),
1515 access_token: Some(tokens.access_token.clone()),
1516 refresh_token: tokens.refresh_token.clone(),
1517 id_token: tokens.id_token.clone(),
1518 access_token_expires_at: tokens.expires_at,
1519 refresh_token_expires_at: None,
1520 scope: tokens.scope.clone(),
1521 password: None,
1522 created_at: now,
1523 updated_at: now,
1524 }
1525 }
1526
1527 pub fn access_token_expired(&self) -> bool {
1532 match self.access_token_expires_at {
1533 Some(ts) => now_secs() >= ts,
1534 None => false,
1535 }
1536 }
1537}
1538
/// Persistence for linked `Account`s.
pub trait AccountBackend: Send + Sync {
    /// Inserts or replaces the account (the in-memory backend keys by
    /// `(provider_id, account_id)`).
    fn upsert(&self, account: &Account);
    /// Looks up an account by provider and provider-side account id.
    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
    /// All accounts linked to `user_id`.
    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
    /// Removes the link; the in-memory backend returns whether one existed.
    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
    /// Every stored account.
    fn list_all(&self) -> Vec<Account>;
}
1562
1563pub struct InMemoryAccountBackend {
1567 accounts: Mutex<HashMap<(String, String), Account>>,
1571}
1572
1573impl InMemoryAccountBackend {
1574 pub fn new() -> Self {
1575 Self {
1576 accounts: Mutex::new(HashMap::new()),
1577 }
1578 }
1579}
1580
1581impl Default for InMemoryAccountBackend {
1582 fn default() -> Self {
1583 Self::new()
1584 }
1585}
1586
1587impl AccountBackend for InMemoryAccountBackend {
1588 fn upsert(&self, account: &Account) {
1589 let key = (account.provider_id.clone(), account.account_id.clone());
1590 self.accounts.lock().unwrap().insert(key, account.clone());
1591 }
1592 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1593 self.accounts
1594 .lock()
1595 .unwrap()
1596 .get(&(provider_id.to_string(), account_id.to_string()))
1597 .cloned()
1598 }
1599 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1600 self.accounts
1601 .lock()
1602 .unwrap()
1603 .values()
1604 .filter(|a| a.user_id == user_id)
1605 .cloned()
1606 .collect()
1607 }
1608 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1609 self.accounts
1610 .lock()
1611 .unwrap()
1612 .remove(&(provider_id.to_string(), account_id.to_string()))
1613 .is_some()
1614 }
1615 fn list_all(&self) -> Vec<Account> {
1616 self.accounts.lock().unwrap().values().cloned().collect()
1617 }
1618}
1619
/// Facade over a pluggable [`AccountBackend`]. Every public method delegates
/// to the boxed backend. `new()` and `Default` use [`InMemoryAccountBackend`].
pub struct AccountStore {
    backend: Box<dyn AccountBackend>,
}
1625
1626impl Default for AccountStore {
1627 fn default() -> Self {
1628 Self::new()
1629 }
1630}
1631
1632impl AccountStore {
1633 pub fn new() -> Self {
1634 Self {
1635 backend: Box::new(InMemoryAccountBackend::new()),
1636 }
1637 }
1638 pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
1639 Self { backend }
1640 }
1641 pub fn upsert(&self, account: &Account) {
1642 self.backend.upsert(account);
1643 }
1644 pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1645 self.backend.find_by_provider(provider_id, account_id)
1646 }
1647 pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1648 self.backend.find_for_user(user_id)
1649 }
1650 pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1651 self.backend.unlink(provider_id, account_id)
1652 }
1653
1654 pub fn list_all_unfiltered(&self) -> Vec<Account> {
1668 self.backend.list_all()
1669 }
1670}
1671
#[cfg(test)]
mod tests {
    use super::*;

    // ---- AuthContext / AuthMode basics ----

    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }

    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }

    #[test]
    fn auth_mode_public_allows_anonymous() {
        // Public mode accepts both anonymous and authenticated callers.
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        // "admin" is not a recognized mode string.
        assert_eq!(AuthMode::from_str("admin"), None);
    }

    // ---- SessionStore ----

    #[test]
    fn session_store_create_and_get() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(!session.token.is_empty());
        // Session tokens are expected to carry the "pylon_" prefix.
        assert!(session.token.starts_with("pylon_"));

        let retrieved = store.get(&session.token).unwrap();
        assert_eq!(retrieved.user_id, "user-1");
    }

    #[test]
    fn session_store_resolve() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));

        // No token at all resolves to an unauthenticated context.
        let anon = store.resolve(None);
        assert!(!anon.is_authenticated());

        // An unknown token also resolves to an unauthenticated context
        // rather than failing.
        let bad = store.resolve(Some("invalid-token"));
        assert!(!bad.is_authenticated());
    }

    #[test]
    fn session_store_revoke() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        assert!(store.revoke(&session.token));
        assert!(store.get(&session.token).is_none());
        // Revoking an already-revoked token reports false.
        assert!(!store.revoke(&session.token));
    }

    #[test]
    fn session_to_auth_context() {
        let session = Session::new("user-42".into());
        let ctx = session.to_auth_context();
        assert_eq!(ctx.user_id, Some("user-42".into()));
    }

    // ---- Admin flag ----

    #[test]
    fn admin_context() {
        let ctx = AuthContext::admin();
        assert!(ctx.is_admin);
        assert!(ctx.is_authenticated());
    }

    #[test]
    fn anonymous_not_admin() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_admin);
    }

    #[test]
    fn authenticated_not_admin() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_admin);
    }

    // ---- MagicCodeStore ----

    #[test]
    fn magic_code_create_and_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        // Codes are six characters long.
        assert_eq!(code.len(), 6);
        assert!(store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_wrong_code_rejected() {
        let store = MagicCodeStore::new();
        store.create("test@example.com");
        // NOTE(review): this can only fail spuriously if the generated code
        // happens to be "000000" — presumably 1-in-10^6 if codes are uniform
        // digits; confirm.
        assert!(!store.verify("test@example.com", "000000"));
    }

    #[test]
    fn magic_code_wrong_email_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(!store.verify("other@example.com", &code));
    }

    #[test]
    fn magic_code_consumed_after_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(store.verify("test@example.com", &code));
        // A code is single-use: a second verify with the same code fails.
        assert!(!store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_different_emails_independent() {
        let store = MagicCodeStore::new();
        let code1 = store.create("alice@example.com");
        let code2 = store.create("bob@example.com");
        // Consuming alice's code must not affect bob's.
        assert!(store.verify("alice@example.com", &code1));
        assert!(store.verify("bob@example.com", &code2));
    }

    // ---- Helpers ----

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"));
        // Inputs of different length are unequal.
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"a", b"b"));
    }

    #[test]
    fn generated_tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert!(t1.starts_with("pylon_"));
        assert!(t2.starts_with("pylon_"));
        // 6-char "pylon_" prefix followed by 64 more characters.
        assert_eq!(t1.len(), 6 + 64);
    }

    // ---- OAuth configuration ----

    #[test]
    fn oauth_registry_empty() {
        let reg = OAuthRegistry::new();
        assert!(reg.get("google").is_none());
    }

    #[test]
    fn oauth_registry_register_and_get() {
        let mut reg = OAuthRegistry::new();
        reg.register(OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/callback".into(),
        });
        let config = reg.get("google").unwrap();
        assert_eq!(config.client_id, "test-id");
        assert!(config.auth_url().contains("accounts.google.com"));
    }

    // ---- Guest sessions ----

    #[test]
    fn guest_session() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));
        assert!(!session.token.is_empty());

        let ctx = store.resolve(Some(&session.token));
        // NOTE(review): this expects a resolved guest session to count as
        // authenticated. `guest_context` below expects `AuthContext::guest`
        // to be unauthenticated and rejected by `AuthMode::User`. If
        // `resolve` is meant to flag guests via `is_guest`, one of these two
        // tests encodes the wrong expectation. Confirm the intended semantics
        // before relying on `AuthMode::User` to exclude guests.
        assert!(ctx.is_authenticated());
        assert!(ctx.user_id.unwrap().starts_with("guest_"));
    }

    #[test]
    fn upgrade_guest_to_real_user() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));

        let upgraded = store.upgrade(&session.token, "real-user-123".into());
        assert!(upgraded);

        // The same token now resolves to the upgraded user id.
        let ctx = store.resolve(Some(&session.token));
        assert_eq!(ctx.user_id, Some("real-user-123".into()));
    }

    #[test]
    fn upgrade_invalid_token_fails() {
        let store = SessionStore::new();
        let upgraded = store.upgrade("nonexistent-token", "user".into());
        assert!(!upgraded);
    }

    #[test]
    fn guest_context() {
        let ctx = AuthContext::guest("guest_123".into());
        // Guests carry a user id but are not "authenticated".
        assert!(!ctx.is_authenticated());
        assert!(ctx.is_guest);
        assert!(!ctx.is_admin);
        assert_eq!(ctx.user_id, Some("guest_123".into()));
        assert!(!AuthMode::User.check(&ctx));
        assert!(AuthMode::Public.check(&ctx));
    }

    // ---- OAuth URLs ----

    #[test]
    fn oauth_token_urls() {
        let google = OAuthConfig {
            provider: "google".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
        let github = OAuthConfig {
            provider: "github".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(
            github.token_url(),
            "https://github.com/login/oauth/access_token"
        );
        // Unknown providers yield empty token and auth URLs.
        let unknown = OAuthConfig {
            provider: "unknown".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
        };
        assert_eq!(unknown.token_url(), "");
        assert!(unknown.auth_url().is_empty());
    }

    #[test]
    fn oauth_auth_url_github() {
        let config = OAuthConfig {
            provider: "github".into(),
            client_id: "gh-id".into(),
            client_secret: "gh-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        assert!(config.auth_url().contains("github.com"));
        assert!(config.auth_url().contains("gh-id"));
    }

    #[test]
    fn oauth_auth_url_with_state() {
        let config = OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
        };
        let url = config.auth_url_with_state("random_state_123");
        assert!(url.contains("&state=random_state_123"));
    }

    // ---- OAuthStateStore ----

    #[test]
    fn oauth_state_store_create_and_validate() {
        let store = OAuthStateStore::new();
        let token = store.create("google", "https://app/cb", "https://app/login");
        let rec = store.validate(&token, "google").expect("valid first time");
        assert_eq!(rec.callback_url, "https://app/cb");
        assert_eq!(rec.error_callback_url, "https://app/login");
        // State tokens are single-use: a second validate fails.
        assert!(store.validate(&token, "google").is_none());
    }

    #[test]
    fn oauth_state_store_wrong_provider_rejected() {
        let store = OAuthStateStore::new();
        let token = store.create("google", "https://app/cb", "https://app/cb");
        assert!(store.validate(&token, "github").is_none());
    }

    #[test]
    fn oauth_state_store_invalid_state_rejected() {
        let store = OAuthStateStore::new();
        assert!(store.validate("nonexistent", "google").is_none());
    }

    // ---- Redirect validation ----

    #[test]
    fn validate_trusted_redirect_basics() {
        let trusted = vec!["http://localhost:3000".to_string()];
        // Any path or query under a trusted origin is accepted.
        assert!(validate_trusted_redirect("http://localhost:3000/dashboard", &trusted).is_ok());
        assert!(validate_trusted_redirect("http://localhost:3000", &trusted).is_ok());
        assert!(validate_trusted_redirect("http://localhost:3000/x?y=1", &trusted).is_ok());

        // Same host on a different port is a different origin.
        assert!(matches!(
            validate_trusted_redirect("http://localhost:4321/dashboard", &trusted),
            Err(TrustedOriginError::NotTrusted { .. })
        ));
        // Non-http(s) schemes such as javascript: are rejected outright.
        assert!(matches!(
            validate_trusted_redirect("javascript:alert(1)", &trusted),
            Err(TrustedOriginError::NotHttp)
        ));
        assert!(matches!(
            validate_trusted_redirect("", &trusted),
            Err(TrustedOriginError::Empty)
        ));
    }
}