1pub mod cookie;
2pub mod email;
3pub mod password;
4
5pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
6
7use serde::{Deserialize, Serialize};
8
/// Identity and authorization data describing the caller of a request.
///
/// Values are produced by the constructors on this type (`anonymous`,
/// `authenticated`, `guest`, `admin`) and by `Session::to_auth_context`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
    /// Identifier of the caller; `None` for an anonymous context.
    pub user_id: Option<String>,
    /// When `true`, `has_role` and `has_any_role` succeed for every role.
    pub is_admin: bool,
    /// Marks a guest identity; `is_authenticated` returns `false` for guests.
    /// Left out of serialized output when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub is_guest: bool,
    /// Role names consulted by `has_role`.
    pub roles: Vec<String>,
    /// Tenant attached to this context, if any. Left out of serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
44
/// Predicate for serde's `skip_serializing_if`: `true` exactly when the
/// referenced flag is `false`.
fn is_false(b: &bool) -> bool {
    let flag = *b;
    !flag
}
48
49impl AuthContext {
50 pub fn anonymous() -> Self {
52 Self {
53 user_id: None,
54 is_admin: false,
55 is_guest: false,
56 roles: Vec::new(),
57 tenant_id: None,
58 }
59 }
60
61 pub fn authenticated(user_id: String) -> Self {
63 Self {
64 user_id: Some(user_id),
65 is_admin: false,
66 is_guest: false,
67 roles: Vec::new(),
68 tenant_id: None,
69 }
70 }
71
72 pub fn guest(guest_id: String) -> Self {
77 Self {
78 user_id: Some(guest_id),
79 is_admin: false,
80 is_guest: true,
81 roles: Vec::new(),
82 tenant_id: None,
83 }
84 }
85
86 pub fn admin() -> Self {
88 Self {
89 user_id: Some("__admin__".into()),
90 is_admin: true,
91 is_guest: false,
92 roles: vec!["admin".into()],
93 tenant_id: None,
94 }
95 }
96
97 pub fn user(user_id: String) -> Self {
99 Self::authenticated(user_id)
100 }
101
102 pub fn tenant_id(&self) -> Option<&str> {
104 self.tenant_id.as_deref()
105 }
106
107 pub fn with_tenant(mut self, tenant_id: String) -> Self {
109 self.tenant_id = Some(tenant_id);
110 self
111 }
112
113 pub fn is_authenticated(&self) -> bool {
117 self.user_id.is_some() && !self.is_guest
118 }
119
120 pub fn has_role(&self, role: &str) -> bool {
122 self.is_admin || self.roles.iter().any(|r| r == role)
123 }
124
125 pub fn has_any_role(&self, roles: &[&str]) -> bool {
127 self.is_admin || roles.iter().any(|r| self.has_role(r))
128 }
129
130 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
132 self.roles = roles;
133 self
134 }
135}
136
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length return `false` immediately. For equal-length
/// input every byte pair is XOR-ed and the differences are OR-ed together,
/// so the loop always visits the whole input.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let accumulated_diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    accumulated_diff == 0
}
156
/// Access requirement that `AuthMode::check` evaluates against an
/// `AuthContext`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// Every caller passes, including anonymous ones.
    Public,
    /// Only callers for which `AuthContext::is_authenticated` is `true` pass.
    User,
}
169
170impl AuthMode {
171 #[allow(clippy::should_implement_trait)]
173 pub fn from_str(s: &str) -> Option<Self> {
174 match s {
175 "public" => Some(AuthMode::Public),
176 "user" => Some(AuthMode::User),
177 _ => None,
178 }
179 }
180
181 pub fn check(&self, ctx: &AuthContext) -> bool {
183 match self {
184 AuthMode::Public => true,
185 AuthMode::User => ctx.is_authenticated(),
186 }
187 }
188}
189
/// A login session identified by an opaque token.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Opaque session token; `SessionStore` uses it as the lookup key.
    pub token: String,
    /// Identifier of the user the session belongs to.
    pub user_id: String,
    /// Expiry as Unix seconds. `0` means the session never expires
    /// (see `is_expired`). Defaults to `0` when absent from input.
    #[serde(default)]
    pub expires_at: u64,
    /// Optional device label; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Creation time as Unix seconds. Defaults to `0` when absent from input.
    #[serde(default)]
    pub created_at: u64,
    /// Optional tenant attached to the session; omitted from serialized
    /// output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
214
215impl Session {
216 pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
218
219 pub fn new(user_id: String) -> Self {
221 let now = now_secs();
222 Self {
223 token: generate_token(),
224 user_id,
225 expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
226 device: None,
227 created_at: now,
228 tenant_id: None,
229 }
230 }
231
232 pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
234 let now = now_secs();
235 Self {
236 token: generate_token(),
237 user_id,
238 expires_at: if lifetime_secs == 0 {
239 0
240 } else {
241 now.saturating_add(lifetime_secs)
242 },
243 device: None,
244 created_at: now,
245 tenant_id: None,
246 }
247 }
248
249 pub fn to_auth_context(&self) -> AuthContext {
252 let ctx = AuthContext::authenticated(self.user_id.clone());
253 match &self.tenant_id {
254 Some(t) => ctx.with_tenant(t.clone()),
255 None => ctx,
256 }
257 }
258
259 pub fn is_expired(&self) -> bool {
263 self.expires_at != 0 && now_secs() >= self.expires_at
264 }
265}
266
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
274
/// Client configuration for one OAuth provider.
///
/// NOTE(review): the derived `Debug` and `Serialize` impls include
/// `client_secret` verbatim — confirm values of this type are never logged
/// or serialized to an untrusted sink.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
    /// Provider name; the methods on this type recognise `"google"` and
    /// `"github"`.
    pub provider: String,
    /// OAuth client identifier issued by the provider.
    pub client_id: String,
    /// OAuth client secret issued by the provider.
    pub client_secret: String,
    /// Callback URL sent to the provider in authorization and token requests.
    pub redirect_uri: String,
}
286
287impl OAuthConfig {
288 pub fn auth_url(&self) -> String {
294 match self.provider.as_str() {
295 "google" => format!(
296 "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
297 self.client_id, self.redirect_uri
298 ),
299 "github" => format!(
300 "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
301 self.client_id, self.redirect_uri
302 ),
303 _ => String::new(),
304 }
305 }
306
307 pub fn auth_url_with_state(&self, state: &str) -> String {
309 let base = self.auth_url();
310 if base.is_empty() {
311 return base;
312 }
313 format!("{}&state={}", base, state)
314 }
315
316 pub fn token_url(&self) -> &str {
318 match self.provider.as_str() {
319 "google" => "https://oauth2.googleapis.com/token",
320 "github" => "https://github.com/login/oauth/access_token",
321 _ => "",
322 }
323 }
324
325 pub fn userinfo_url(&self) -> &str {
327 match self.provider.as_str() {
328 "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
329 "github" => "https://api.github.com/user",
330 _ => "",
331 }
332 }
333
334 pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
341 let body = match self.provider.as_str() {
342 "google" => format!(
343 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
344 url_encode(&self.client_id),
345 url_encode(&self.client_secret),
346 url_encode(&self.redirect_uri)
347 ),
348 "github" => format!(
349 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
350 url_encode(&self.client_id),
351 url_encode(&self.client_secret),
352 url_encode(&self.redirect_uri)
353 ),
354 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
355 };
356
357 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
358 parse_token_response(&out)
359 }
360
361 pub fn exchange_code(&self, code: &str) -> Result<String, String> {
365 let body = match self.provider.as_str() {
366 "google" => format!(
367 "code={code}&client_id={}&client_secret={}&redirect_uri={}&grant_type=authorization_code",
368 url_encode(&self.client_id),
369 url_encode(&self.client_secret),
370 url_encode(&self.redirect_uri)
371 ),
372 "github" => format!(
373 "code={code}&client_id={}&client_secret={}&redirect_uri={}",
374 url_encode(&self.client_id),
375 url_encode(&self.client_secret),
376 url_encode(&self.redirect_uri)
377 ),
378 _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
379 };
380
381 let out = http_post_form(self.token_url(), &body, self.provider.as_str() == "github")?;
382 extract_access_token(&out)
383 }
384
385 pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
391 let info = self.fetch_userinfo_full(access_token)?;
392 Ok((info.email, info.name))
393 }
394
395 pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
401 let out = http_get_bearer(self.userinfo_url(), access_token)?;
402 let parsed: serde_json::Value =
403 serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
404 match self.provider.as_str() {
405 "google" => {
406 let email = parsed
407 .get("email")
408 .and_then(|v| v.as_str())
409 .ok_or("no email in userinfo")?
410 .to_string();
411 let name = parsed
412 .get("name")
413 .and_then(|v| v.as_str())
414 .map(String::from);
415 let provider_account_id = parsed
416 .get("sub")
417 .and_then(|v| v.as_str())
418 .ok_or("no sub in userinfo")?
419 .to_string();
420 Ok(UserInfo {
421 provider: self.provider.clone(),
422 provider_account_id,
423 email,
424 name,
425 })
426 }
427 "github" => {
428 let name = parsed
429 .get("name")
430 .and_then(|v| v.as_str())
431 .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
432 .map(String::from);
433 let email = parsed
434 .get("email")
435 .and_then(|v| v.as_str())
436 .map(String::from);
437 let email = email
440 .or_else(|| fetch_github_primary_email(access_token).ok())
441 .ok_or("no accessible email on GitHub account")?;
442 let provider_account_id = parsed
445 .get("id")
446 .map(|v| {
447 v.as_i64()
448 .map(|n| n.to_string())
449 .or_else(|| v.as_str().map(String::from))
450 .unwrap_or_default()
451 })
452 .filter(|s| !s.is_empty())
453 .ok_or("no id in userinfo")?;
454 Ok(UserInfo {
455 provider: self.provider.clone(),
456 provider_account_id,
457 email,
458 name,
459 })
460 }
461 _ => Err(format!("unknown provider: {}", self.provider)),
462 }
463 }
464}
465
/// Profile data returned by `OAuthConfig::fetch_userinfo_full`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Provider name copied from the `OAuthConfig` that fetched the profile.
    pub provider: String,
    /// Provider-side account identifier (`sub` for Google, `id` for GitHub).
    pub provider_account_id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Display name, when the provider supplied one.
    pub name: Option<String>,
}
477
/// Tokens extracted from a provider's token-endpoint response by
/// `parse_token_response`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSet {
    /// The `access_token` field; always present.
    pub access_token: String,
    /// The `refresh_token` field, when the response contained one.
    pub refresh_token: Option<String>,
    /// The `id_token` field, when the response contained one.
    pub id_token: Option<String>,
    /// Absolute expiry in Unix seconds, computed as "now + `expires_in`"
    /// when the response contained `expires_in`.
    pub expires_at: Option<u64>,
    /// The `scope` field, when the response contained one.
    pub scope: Option<String>,
}
492
493fn parse_token_response(body: &str) -> Result<TokenSet, String> {
494 let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
497 let mut map = serde_json::Map::new();
499 for pair in body.split('&') {
500 if let Some((k, v)) = pair.split_once('=') {
501 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
502 }
503 }
504 serde_json::Value::Object(map)
505 });
506
507 let access_token = json
508 .get("access_token")
509 .and_then(|v| v.as_str())
510 .ok_or_else(|| format!("no access_token in token response: {body}"))?
511 .to_string();
512 let refresh_token = json
513 .get("refresh_token")
514 .and_then(|v| v.as_str())
515 .map(String::from);
516 let id_token = json
517 .get("id_token")
518 .and_then(|v| v.as_str())
519 .map(String::from);
520 let expires_at = json
521 .get("expires_in")
522 .and_then(|v| {
523 v.as_u64()
524 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
525 })
526 .map(|secs| now_secs().saturating_add(secs));
527 let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
528 Ok(TokenSet {
529 access_token,
530 refresh_token,
531 id_token,
532 expires_at,
533 scope,
534 })
535}
536
/// Percent-encodes `s` byte by byte.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through unchanged;
/// every other byte becomes `%XX` with uppercase hexadecimal digits.
fn url_encode(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
549
/// Timeout of 10 seconds that `ureq_agent` applies separately to the
/// connect, read and write phases of each outbound HTTP request.
const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
554
555fn ureq_agent() -> ureq::Agent {
556 ureq::AgentBuilder::new()
557 .timeout_connect(HTTP_TIMEOUT)
558 .timeout_read(HTTP_TIMEOUT)
559 .timeout_write(HTTP_TIMEOUT)
560 .user_agent("pylon/0.1")
561 .build()
562}
563
564fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
565 let agent = ureq_agent();
566 let mut req = agent
567 .post(url)
568 .set("Content-Type", "application/x-www-form-urlencoded");
569 if accept_json {
570 req = req.set("Accept", "application/json");
571 }
572 match req.send_string(body) {
573 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
574 Err(ureq::Error::Status(code, resp)) => {
575 let body = resp.into_string().unwrap_or_default();
576 Err(format!("HTTP {code}: {body}"))
577 }
578 Err(e) => Err(format!("HTTP error: {e}")),
579 }
580}
581
582fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
583 let agent = ureq_agent();
584 match agent
585 .get(url)
586 .set("Authorization", &format!("Bearer {token}"))
587 .set("Accept", "application/json")
588 .call()
589 {
590 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
591 Err(ureq::Error::Status(code, resp)) => {
592 let body = resp.into_string().unwrap_or_default();
593 Err(format!("HTTP {code}: {body}"))
594 }
595 Err(e) => Err(format!("HTTP error: {e}")),
596 }
597}
598
599fn fetch_github_primary_email(token: &str) -> Result<String, String> {
600 let out = http_get_bearer("https://api.github.com/user/emails", token)?;
601 let emails: serde_json::Value =
602 serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
603 emails
604 .as_array()
605 .and_then(|arr| {
606 arr.iter()
607 .find(|e| {
608 e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
609 && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
610 })
611 .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
612 })
613 .ok_or_else(|| "no primary verified email on GitHub".into())
614}
615
616fn extract_access_token(body: &str) -> Result<String, String> {
617 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
618 if let Some(t) = json.get("access_token").and_then(|v| v.as_str()) {
619 return Ok(t.to_string());
620 }
621 }
622 for pair in body.split('&') {
624 if let Some(val) = pair.strip_prefix("access_token=") {
625 return Ok(val.to_string());
626 }
627 }
628 Err(format!("no access_token in token response: {body}"))
629}
630
/// Collection of OAuth provider configurations, looked up by provider name.
pub struct OAuthRegistry {
    /// Registered configurations keyed by `OAuthConfig::provider`.
    providers: std::collections::HashMap<String, OAuthConfig>,
}

impl Default for OAuthRegistry {
    /// Equivalent to `OAuthRegistry::new()`: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
641
642impl OAuthRegistry {
643 pub fn new() -> Self {
644 Self {
645 providers: std::collections::HashMap::new(),
646 }
647 }
648
649 pub fn register(&mut self, config: OAuthConfig) {
650 self.providers.insert(config.provider.clone(), config);
651 }
652
653 pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
654 self.providers.get(provider)
655 }
656
657 pub fn from_env() -> Self {
660 let mut reg = Self::new();
661
662 if let (Ok(id), Ok(secret)) = (
664 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_ID"),
665 std::env::var("PYLON_OAUTH_GOOGLE_CLIENT_SECRET"),
666 ) {
667 reg.register(OAuthConfig {
668 provider: "google".into(),
669 client_id: id,
670 client_secret: secret,
671 redirect_uri: std::env::var("PYLON_OAUTH_GOOGLE_REDIRECT")
672 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/google".into()),
673 });
674 }
675
676 if let (Ok(id), Ok(secret)) = (
678 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_ID"),
679 std::env::var("PYLON_OAUTH_GITHUB_CLIENT_SECRET"),
680 ) {
681 reg.register(OAuthConfig {
682 provider: "github".into(),
683 client_id: id,
684 client_secret: secret,
685 redirect_uri: std::env::var("PYLON_OAUTH_GITHUB_REDIRECT")
686 .unwrap_or_else(|_| "http://localhost:3000/api/auth/callback/github".into()),
687 });
688 }
689
690 reg
691 }
692}
693
/// Data remembered for one in-flight OAuth authorization, stored under the
/// `state` token handed to the provider.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
    /// Provider the flow was started for; `OAuthStateStore::validate`
    /// rejects a state whose provider differs from the expected one.
    pub provider: String,
    /// Callback URL supplied when the state was created.
    pub callback_url: String,
    /// Error callback URL supplied when the state was created.
    pub error_callback_url: String,
    /// Expiry as Unix seconds; a state is no longer returned once the
    /// current time reaches this value.
    pub expires_at: u64,
}
716
/// Storage for pending OAuth states, keyed by state token.
pub trait OAuthStateBackend: Send + Sync {
    /// Stores `state` under `token`.
    fn put(&self, token: &str, state: &OAuthState);
    /// Removes and returns the state stored under `token`.
    ///
    /// The in-memory implementation returns `None` when the token is unknown
    /// or when `expires_at <= now_unix_secs`; in both cases the entry is
    /// gone afterwards, so a token can be taken at most once.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
}
730
/// `OAuthStateBackend` that keeps pending states in a mutex-guarded map.
pub struct InMemoryOAuthBackend {
    /// Pending states keyed by state token.
    states: Mutex<HashMap<String, OAuthState>>,
}

impl InMemoryOAuthBackend {
    /// Creates an empty backend.
    pub fn new() -> Self {
        Self {
            states: Mutex::new(HashMap::new()),
        }
    }
}

impl Default for InMemoryOAuthBackend {
    /// Equivalent to `InMemoryOAuthBackend::new()`.
    fn default() -> Self {
        Self::new()
    }
}
749
750impl OAuthStateBackend for InMemoryOAuthBackend {
751 fn put(&self, token: &str, state: &OAuthState) {
752 self.states
753 .lock()
754 .unwrap()
755 .insert(token.to_string(), state.clone());
756 }
757 fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
758 let mut s = self.states.lock().unwrap();
759 let entry = s.remove(token)?;
760 if entry.expires_at <= now_unix_secs {
761 return None;
762 }
763 Some(entry)
764 }
765}
766
/// Issues and validates OAuth `state` tokens on top of an
/// `OAuthStateBackend`.
pub struct OAuthStateStore {
    /// Storage for pending states.
    backend: Box<dyn OAuthStateBackend>,
}

impl Default for OAuthStateStore {
    /// Equivalent to `OAuthStateStore::new()`: backed by
    /// `InMemoryOAuthBackend`.
    fn default() -> Self {
        Self::new()
    }
}
782
783impl OAuthStateStore {
784 pub fn new() -> Self {
785 Self {
786 backend: Box::new(InMemoryOAuthBackend::new()),
787 }
788 }
789
790 pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
791 Self { backend }
792 }
793
794 pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
802 use std::time::{SystemTime, UNIX_EPOCH};
803 let token = generate_token();
804 let now = SystemTime::now()
805 .duration_since(UNIX_EPOCH)
806 .unwrap_or_default()
807 .as_secs();
808 let state = OAuthState {
809 provider: provider.to_string(),
810 callback_url: callback_url.to_string(),
811 error_callback_url: error_callback_url.to_string(),
812 expires_at: now + 600,
813 };
814 self.backend.put(&token, &state);
815 token
816 }
817
818 pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
823 use std::time::{SystemTime, UNIX_EPOCH};
824 let now = SystemTime::now()
825 .duration_since(UNIX_EPOCH)
826 .unwrap_or_default()
827 .as_secs();
828 let entry = self.backend.take(state, now)?;
829 if entry.provider != expected_provider {
830 return None;
831 }
832 Some(entry)
833 }
834}
835
836pub fn validate_trusted_redirect(
853 url: &str,
854 trusted_origins: &[String],
855) -> Result<(), TrustedOriginError> {
856 if url.is_empty() {
857 return Err(TrustedOriginError::Empty);
858 }
859 if !url.starts_with("http://") && !url.starts_with("https://") {
862 return Err(TrustedOriginError::NotHttp);
863 }
864 let url_origin = origin_of(url);
865 if trusted_origins.iter().any(|t| t == &url_origin) {
866 Ok(())
867 } else {
868 Err(TrustedOriginError::NotTrusted { origin: url_origin })
869 }
870}
871
/// Reasons a redirect URL is rejected by `validate_trusted_redirect`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The URL was the empty string.
    Empty,
    /// The URL did not start with `http://` or `https://`.
    NotHttp,
    /// The URL's origin is not in the trusted list.
    NotTrusted { origin: String },
}

impl std::fmt::Display for TrustedOriginError {
    /// Writes a human-readable description; the rejected origin is shown
    /// quoted via its `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fixed_message = match self {
            Self::NotTrusted { origin } => {
                return write!(
                    f,
                    "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS"
                );
            }
            Self::Empty => "redirect URL is empty",
            Self::NotHttp => "redirect URL must use http:// or https:// scheme",
        };
        f.write_str(fixed_message)
    }
}
894
/// Extracts the `scheme://authority` prefix of `url`.
///
/// Everything from the first `/`, `?` or `#` after the `://` separator is
/// dropped. Input without `://` is returned with trailing slashes trimmed.
pub fn origin_of(url: &str) -> String {
    match url.split_once("://") {
        None => url.trim_end_matches('/').to_string(),
        Some((scheme, rest)) => {
            // `split` always yields at least one (possibly empty) piece.
            let authority = rest
                .split(|c: char| matches!(c, '/' | '?' | '#'))
                .next()
                .unwrap_or(rest);
            format!("{scheme}://{authority}")
        }
    }
}
910
/// Storage for pending magic codes, keyed by email address.
///
/// `MagicCodeStore` keeps its own cache and mirrors every change into the
/// backend through these methods.
pub trait MagicCodeBackend: Send + Sync {
    /// Stores `code` for `email`, replacing any existing entry.
    fn put(&self, email: &str, code: &MagicCode);
    /// Returns the code stored for `email`, if any.
    fn get(&self, email: &str) -> Option<MagicCode>;
    /// Deletes the code stored for `email`.
    fn remove(&self, email: &str);
    /// Records one more failed verification attempt for `email`.
    /// `MagicCodeStore::try_verify` calls this after a wrong code.
    fn bump_attempts(&self, email: &str);
    /// Returns every stored code. `MagicCodeStore::with_backend` uses this
    /// to fill its cache and discards entries that have already expired.
    fn load_all(&self) -> Vec<MagicCode>;
}
939
/// `MagicCodeBackend` that keeps codes in a mutex-guarded map.
pub struct InMemoryMagicCodeBackend {
    /// Stored codes keyed by email address.
    codes: Mutex<HashMap<String, MagicCode>>,
}

impl InMemoryMagicCodeBackend {
    /// Creates an empty backend.
    pub fn new() -> Self {
        Self {
            codes: Mutex::new(HashMap::new()),
        }
    }
}

impl Default for InMemoryMagicCodeBackend {
    /// Equivalent to `InMemoryMagicCodeBackend::new()`.
    fn default() -> Self {
        Self::new()
    }
}
959
960impl MagicCodeBackend for InMemoryMagicCodeBackend {
961 fn put(&self, email: &str, code: &MagicCode) {
962 self.codes
963 .lock()
964 .unwrap()
965 .insert(email.to_string(), code.clone());
966 }
967 fn get(&self, email: &str) -> Option<MagicCode> {
968 self.codes.lock().unwrap().get(email).cloned()
969 }
970 fn remove(&self, email: &str) {
971 self.codes.lock().unwrap().remove(email);
972 }
973 fn bump_attempts(&self, email: &str) {
974 if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
975 c.attempts = c.attempts.saturating_add(1);
976 }
977 }
978 fn load_all(&self) -> Vec<MagicCode> {
979 self.codes.lock().unwrap().values().cloned().collect()
980 }
981}
982
/// Issues and verifies emailed one-time codes.
///
/// Reads are served from `cache`; every mutation is also written through to
/// `backend`.
pub struct MagicCodeStore {
    /// Pending codes keyed by email address.
    cache: Mutex<HashMap<String, MagicCode>>,
    /// Backing storage mirrored on every change.
    backend: Box<dyn MagicCodeBackend>,
}
991
/// One pending one-time code.
///
/// NOTE(review): the derived `Debug` prints `code` in clear text — confirm
/// values of this type are not logged.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Email address the code was issued for.
    pub email: String,
    /// The code itself; `generate_magic_code` produces six decimal digits.
    pub code: String,
    /// Expiry as Unix seconds.
    pub expires_at: u64,
    /// Number of failed verification attempts so far.
    pub attempts: u32,
}
1001
/// Number of failed attempts after which `MagicCodeStore::try_verify`
/// refuses a code with `MagicCodeError::TooManyAttempts`.
const MAX_ATTEMPTS: u32 = 5;

/// Minimum age in seconds of a still-valid code before
/// `MagicCodeStore::try_create` will issue a replacement for the same email.
const CREATE_COOLDOWN_SECS: u64 = 60;
1010
/// Failure reasons reported by `MagicCodeStore::try_create` and
/// `MagicCodeStore::try_verify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is pending for the email.
    NotFound,
    /// The code has reached `MAX_ATTEMPTS` failed attempts.
    TooManyAttempts,
    /// The submitted code did not match.
    BadCode,
    /// The code's expiry time has passed; the entry has been removed.
    Expired,
    /// A code was issued for this email less than `CREATE_COOLDOWN_SECS`
    /// ago; retry after the given number of seconds.
    Throttled { retry_after_secs: u64 },
}
1024
impl Default for MagicCodeStore {
    /// Equivalent to `MagicCodeStore::new()`: backed by
    /// `InMemoryMagicCodeBackend`.
    fn default() -> Self {
        Self::new()
    }
}
1030
1031impl MagicCodeStore {
1032 pub fn new() -> Self {
1033 Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1034 }
1035
1036 pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1041 let now = now_secs();
1042 let mut cache = HashMap::new();
1043 for c in backend.load_all() {
1044 if c.expires_at > now {
1045 cache.insert(c.email.clone(), c);
1046 }
1047 }
1048 Self {
1049 cache: Mutex::new(cache),
1050 backend,
1051 }
1052 }
1053
1054 pub fn create(&self, email: &str) -> String {
1057 self.try_create(email).unwrap_or_else(|_| String::new())
1060 }
1061
1062 pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1065 let now = now_secs();
1066
1067 let mut codes = self.cache.lock().unwrap();
1068
1069 if let Some(existing) = codes.get(email) {
1073 if existing.expires_at > now {
1074 let created_at = existing.expires_at.saturating_sub(600);
1075 let age = now.saturating_sub(created_at);
1076 if age < CREATE_COOLDOWN_SECS {
1077 return Err(MagicCodeError::Throttled {
1078 retry_after_secs: CREATE_COOLDOWN_SECS - age,
1079 });
1080 }
1081 }
1082 }
1083
1084 let code = generate_magic_code();
1085 let mc = MagicCode {
1086 email: email.to_string(),
1087 code: code.clone(),
1088 expires_at: now + 600, attempts: 0,
1090 };
1091 codes.insert(email.to_string(), mc.clone());
1092 self.backend.put(email, &mc);
1096 Ok(code)
1097 }
1098
1099 pub fn verify(&self, email: &str, code: &str) -> bool {
1103 matches!(self.try_verify(email, code), Ok(()))
1104 }
1105
1106 pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1113 self.cache
1114 .lock()
1115 .map(|m| m.values().cloned().collect())
1116 .unwrap_or_default()
1117 }
1118
1119 pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1120 let now = now_secs();
1121 let mut codes = self.cache.lock().unwrap();
1122
1123 let mc = match codes.get_mut(email) {
1124 Some(m) => m,
1125 None => return Err(MagicCodeError::NotFound),
1126 };
1127
1128 if mc.attempts >= MAX_ATTEMPTS {
1129 return Err(MagicCodeError::TooManyAttempts);
1130 }
1131 if mc.expires_at <= now {
1132 codes.remove(email);
1133 self.backend.remove(email);
1134 return Err(MagicCodeError::Expired);
1135 }
1136
1137 let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1138 if !ok {
1139 mc.attempts += 1;
1140 self.backend.bump_attempts(email);
1141 if mc.attempts >= MAX_ATTEMPTS {
1143 return Err(MagicCodeError::TooManyAttempts);
1144 }
1145 return Err(MagicCodeError::BadCode);
1146 }
1147
1148 codes.remove(email);
1150 self.backend.remove(email);
1151 Ok(())
1152 }
1153}
1154
/// Renders `bytes` as lowercase hexadecimal, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
1162
1163fn generate_magic_code() -> String {
1165 use rand::Rng;
1166 let mut rng = rand::thread_rng();
1167 let code: u32 = rng.gen_range(0..1_000_000);
1168 format!("{:06}", code)
1169}
1170
1171fn generate_token() -> String {
1173 use rand::Rng;
1174 let mut rng = rand::thread_rng();
1175 let bytes: [u8; 32] = rng.gen();
1176 format!("pylon_{}", hex_encode(&bytes))
1177}
1178
1179use std::collections::HashMap;
1184use std::sync::Mutex;
1185
/// Persistent storage that `SessionStore` mirrors its sessions into.
pub trait SessionBackend: Send + Sync {
    /// Returns every stored session. `SessionStore::with_backend` uses this
    /// to fill its in-memory map and discards expired entries.
    fn load_all(&self) -> Vec<Session>;
    /// Stores `session`; called for new sessions and again after a session
    /// is modified.
    fn save(&self, session: &Session);
    /// Deletes the session stored under `token`.
    fn remove(&self, token: &str);
}
1194
/// In-memory session registry with optional write-through persistence.
pub struct SessionStore {
    /// Live sessions keyed by token.
    sessions: Mutex<HashMap<String, Session>>,
    /// Optional persistent backend; when `None` sessions exist only in memory.
    backend: Option<Box<dyn SessionBackend>>,
    /// Lifetime in seconds given to sessions made by `create`; `0` produces
    /// sessions that never expire.
    default_lifetime_secs: u64,
}

impl Default for SessionStore {
    /// Equivalent to `SessionStore::new()`.
    fn default() -> Self {
        Self::new()
    }
}
1216
1217impl SessionStore {
1218 pub fn new() -> Self {
1219 Self {
1220 sessions: Mutex::new(HashMap::new()),
1221 backend: None,
1222 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1223 }
1224 }
1225
1226 pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
1229 self.default_lifetime_secs = lifetime_secs;
1230 self
1231 }
1232
1233 pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1237 let mut map = HashMap::new();
1238 for s in backend.load_all() {
1239 if !s.is_expired() {
1240 map.insert(s.token.clone(), s);
1241 }
1242 }
1243 Self {
1244 sessions: Mutex::new(map),
1245 backend: Some(backend),
1246 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1247 }
1248 }
1249
1250 pub fn create(&self, user_id: String) -> Session {
1254 let session = Session::with_lifetime(user_id, self.default_lifetime_secs);
1255 let mut sessions = self.sessions.lock().unwrap();
1256 sessions.insert(session.token.clone(), session.clone());
1257 if let Some(b) = &self.backend {
1258 b.save(&session);
1259 }
1260 session
1261 }
1262
1263 pub fn get(&self, token: &str) -> Option<Session> {
1265 let mut sessions = self.sessions.lock().unwrap();
1266 match sessions.get(token) {
1267 Some(s) if s.is_expired() => {
1268 sessions.remove(token);
1269 None
1270 }
1271 Some(s) => Some(s.clone()),
1272 None => None,
1273 }
1274 }
1275
1276 pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1279 match token {
1280 Some(t) => match self.get(t) {
1281 Some(session) => session.to_auth_context(),
1282 None => AuthContext::anonymous(),
1283 },
1284 None => AuthContext::anonymous(),
1285 }
1286 }
1287
1288 pub fn refresh(&self, old_token: &str) -> Option<Session> {
1292 let mut sessions = self.sessions.lock().unwrap();
1293 let old = sessions.remove(old_token)?;
1294 if let Some(b) = &self.backend {
1295 b.remove(old_token);
1296 }
1297 if old.is_expired() {
1298 return None;
1299 }
1300 let mut new = Session::new(old.user_id.clone());
1301 new.device = old.device.clone();
1302 sessions.insert(new.token.clone(), new.clone());
1303 if let Some(b) = &self.backend {
1304 b.save(&new);
1305 }
1306 Some(new)
1307 }
1308
1309 pub fn list_all_unfiltered(&self) -> Vec<Session> {
1314 self.sessions
1315 .lock()
1316 .map(|m| m.values().cloned().collect())
1317 .unwrap_or_default()
1318 }
1319
1320 pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
1322 let sessions = self.sessions.lock().unwrap();
1323 sessions
1324 .values()
1325 .filter(|s| s.user_id == user_id && !s.is_expired())
1326 .cloned()
1327 .collect()
1328 }
1329
1330 pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
1332 let mut sessions = self.sessions.lock().unwrap();
1333 let tokens: Vec<String> = sessions
1334 .iter()
1335 .filter_map(|(t, s)| {
1336 if s.user_id == user_id {
1337 Some(t.clone())
1338 } else {
1339 None
1340 }
1341 })
1342 .collect();
1343 let n = tokens.len();
1344 for t in &tokens {
1345 sessions.remove(t);
1346 if let Some(b) = &self.backend {
1347 b.remove(t);
1348 }
1349 }
1350 n
1351 }
1352
1353 pub fn sweep_expired(&self) -> usize {
1355 let mut sessions = self.sessions.lock().unwrap();
1356 let expired: Vec<String> = sessions
1357 .iter()
1358 .filter_map(|(t, s)| {
1359 if s.is_expired() {
1360 Some(t.clone())
1361 } else {
1362 None
1363 }
1364 })
1365 .collect();
1366 let n = expired.len();
1367 for t in &expired {
1368 sessions.remove(t);
1369 if let Some(b) = &self.backend {
1370 b.remove(t);
1371 }
1372 }
1373 n
1374 }
1375
1376 pub fn set_device(&self, token: &str, device: String) -> bool {
1378 let mut sessions = self.sessions.lock().unwrap();
1379 if let Some(s) = sessions.get_mut(token) {
1380 s.device = Some(device);
1381 if let Some(b) = &self.backend {
1382 b.save(s);
1383 }
1384 true
1385 } else {
1386 false
1387 }
1388 }
1389
1390 pub fn create_guest(&self) -> Session {
1392 use rand::Rng;
1393 let mut rng = rand::thread_rng();
1394 let bytes: [u8; 16] = rng.gen();
1395 let guest_id = format!("guest_{}", hex_encode(&bytes));
1396 self.create(guest_id)
1397 }
1398
1399 pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
1401 let mut sessions = self.sessions.lock().unwrap();
1402 if let Some(session) = sessions.get_mut(token) {
1403 session.user_id = real_user_id;
1404 if let Some(b) = &self.backend {
1405 b.save(session);
1406 }
1407 true
1408 } else {
1409 false
1410 }
1411 }
1412
1413 pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
1418 let mut sessions = self.sessions.lock().unwrap();
1419 if let Some(session) = sessions.get_mut(token) {
1420 session.tenant_id = tenant_id;
1421 if let Some(b) = &self.backend {
1422 b.save(session);
1423 }
1424 true
1425 } else {
1426 false
1427 }
1428 }
1429
1430 pub fn revoke(&self, token: &str) -> bool {
1432 let mut sessions = self.sessions.lock().unwrap();
1433 let removed = sessions.remove(token).is_some();
1434 if removed {
1435 if let Some(b) = &self.backend {
1436 b.remove(token);
1437 }
1438 }
1439 removed
1440 }
1441}
1442
/// Link between a local user and an account at an external provider,
/// together with the tokens obtained for it.
///
/// NOTE(review): the derived `Debug` prints tokens and `password` in clear
/// text — confirm values of this type are not logged.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Account {
    /// Identifier of this link record; `Account::new` fills it from
    /// `generate_token`.
    pub id: String,
    /// Local user the account is linked to.
    pub user_id: String,
    /// Provider name (copied from `UserInfo::provider`).
    pub provider_id: String,
    /// Account identifier at the provider (copied from
    /// `UserInfo::provider_account_id`).
    pub account_id: String,
    /// Provider access token, if stored.
    pub access_token: Option<String>,
    /// Provider refresh token, if one was issued.
    pub refresh_token: Option<String>,
    /// Provider ID token, if one was issued.
    pub id_token: Option<String>,
    /// Access-token expiry as Unix seconds; `None` is treated as
    /// "does not expire" by `access_token_expired`.
    pub access_token_expires_at: Option<u64>,
    /// Refresh-token expiry as Unix seconds; `Account::new` leaves it `None`.
    pub refresh_token_expires_at: Option<u64>,
    /// Scope string granted by the provider, if reported.
    pub scope: Option<String>,
    /// `Account::new` leaves this `None`; presumably used for
    /// credential-based accounts — TODO confirm.
    pub password: Option<String>,
    /// Creation time as Unix seconds.
    pub created_at: u64,
    /// Last-update time as Unix seconds; equals `created_at` for a record
    /// built by `Account::new`.
    pub updated_at: u64,
}
1498
1499impl Account {
1500 pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
1504 let now = now_secs();
1505 Self {
1506 id: generate_token(),
1507 user_id,
1508 provider_id: info.provider.clone(),
1509 account_id: info.provider_account_id.clone(),
1510 access_token: Some(tokens.access_token.clone()),
1511 refresh_token: tokens.refresh_token.clone(),
1512 id_token: tokens.id_token.clone(),
1513 access_token_expires_at: tokens.expires_at,
1514 refresh_token_expires_at: None,
1515 scope: tokens.scope.clone(),
1516 password: None,
1517 created_at: now,
1518 updated_at: now,
1519 }
1520 }
1521
1522 pub fn access_token_expired(&self) -> bool {
1527 match self.access_token_expires_at {
1528 Some(ts) => now_secs() >= ts,
1529 None => false,
1530 }
1531 }
1532}
1533
/// Storage for linked provider accounts.
pub trait AccountBackend: Send + Sync {
    /// Inserts `account`, or replaces the stored record for the same
    /// provider/account pair (the in-memory backend keys on
    /// `(provider_id, account_id)`).
    fn upsert(&self, account: &Account);
    /// Looks up the record for a provider/account pair.
    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
    /// Returns every record linked to `user_id`.
    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
    /// Deletes the record for a provider/account pair; returns whether one
    /// existed.
    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
    /// Returns every stored record.
    fn list_all(&self) -> Vec<Account>;
}
1557
/// `AccountBackend` that keeps records in a mutex-guarded map.
pub struct InMemoryAccountBackend {
    /// Records keyed by `(provider_id, account_id)`.
    accounts: Mutex<HashMap<(String, String), Account>>,
}

impl InMemoryAccountBackend {
    /// Creates an empty backend.
    pub fn new() -> Self {
        Self {
            accounts: Mutex::new(HashMap::new()),
        }
    }
}

impl Default for InMemoryAccountBackend {
    /// Equivalent to `InMemoryAccountBackend::new()`.
    fn default() -> Self {
        Self::new()
    }
}
1581
1582impl AccountBackend for InMemoryAccountBackend {
1583 fn upsert(&self, account: &Account) {
1584 let key = (account.provider_id.clone(), account.account_id.clone());
1585 self.accounts.lock().unwrap().insert(key, account.clone());
1586 }
1587 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1588 self.accounts
1589 .lock()
1590 .unwrap()
1591 .get(&(provider_id.to_string(), account_id.to_string()))
1592 .cloned()
1593 }
1594 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1595 self.accounts
1596 .lock()
1597 .unwrap()
1598 .values()
1599 .filter(|a| a.user_id == user_id)
1600 .cloned()
1601 .collect()
1602 }
1603 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1604 self.accounts
1605 .lock()
1606 .unwrap()
1607 .remove(&(provider_id.to_string(), account_id.to_string()))
1608 .is_some()
1609 }
1610 fn list_all(&self) -> Vec<Account> {
1611 self.accounts.lock().unwrap().values().cloned().collect()
1612 }
1613}
1614
1615pub struct AccountStore {
1618 backend: Box<dyn AccountBackend>,
1619}
1620
1621impl Default for AccountStore {
1622 fn default() -> Self {
1623 Self::new()
1624 }
1625}
1626
1627impl AccountStore {
1628 pub fn new() -> Self {
1629 Self {
1630 backend: Box::new(InMemoryAccountBackend::new()),
1631 }
1632 }
1633 pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
1634 Self { backend }
1635 }
1636 pub fn upsert(&self, account: &Account) {
1637 self.backend.upsert(account);
1638 }
1639 pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
1640 self.backend.find_by_provider(provider_id, account_id)
1641 }
1642 pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
1643 self.backend.find_for_user(user_id)
1644 }
1645 pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
1646 self.backend.unlink(provider_id, account_id)
1647 }
1648
1649 pub fn list_all_unfiltered(&self) -> Vec<Account> {
1663 self.backend.list_all()
1664 }
1665}
1666
#[cfg(test)]
mod tests {
    use super::*;

    /// Address shared by the single-recipient magic-code tests.
    const TEST_EMAIL: &str = "test@example.com";
    /// Prefix the tests expect on every generated token.
    const EXPECTED_TOKEN_PREFIX: &str = "pylon_";

    /// Builds an [`OAuthConfig`] from literals so the OAuth tests stay compact.
    fn make_oauth_config(
        provider: &'static str,
        client_id: &'static str,
        client_secret: &'static str,
        redirect_uri: &'static str,
    ) -> OAuthConfig {
        OAuthConfig {
            provider: provider.into(),
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            redirect_uri: redirect_uri.into(),
        }
    }

    // ---- AuthContext / AuthMode -------------------------------------------------

    /// An anonymous context has no user id and is not authenticated.
    #[test]
    fn anonymous_context() {
        let anon = AuthContext::anonymous();
        assert!(!anon.is_authenticated(), "anonymous must not be authenticated");
        assert_eq!(anon.user_id, None);
    }

    /// An authenticated context carries the supplied user id.
    #[test]
    fn authenticated_context() {
        let signed_in = AuthContext::authenticated("user-1".into());
        assert!(signed_in.is_authenticated());
        assert_eq!(signed_in.user_id.as_deref(), Some("user-1"));
    }

    /// `AuthMode::Public` admits both anonymous and signed-in callers.
    #[test]
    fn auth_mode_public_allows_anonymous() {
        let anon = AuthContext::anonymous();
        let signed_in = AuthContext::authenticated("user-1".into());
        assert!(AuthMode::Public.check(&anon));
        assert!(AuthMode::Public.check(&signed_in));
    }

    /// `AuthMode::User` rejects anonymous callers and admits signed-in ones.
    #[test]
    fn auth_mode_user_requires_authenticated() {
        let anon = AuthContext::anonymous();
        let signed_in = AuthContext::authenticated("user-1".into());
        assert!(!AuthMode::User.check(&anon));
        assert!(AuthMode::User.check(&signed_in));
    }

    /// Only the known mode names parse; anything else yields `None`.
    #[test]
    fn auth_mode_from_str() {
        let table = [
            ("public", Some(AuthMode::Public)),
            ("user", Some(AuthMode::User)),
            ("admin", None),
        ];
        for (raw, expected) in table {
            assert_eq!(AuthMode::from_str(raw), expected, "parsing {:?}", raw);
        }
    }

    // ---- SessionStore -----------------------------------------------------------

    /// A created session has a prefixed token and can be fetched back by it.
    #[test]
    fn session_store_create_and_get() {
        let sessions = SessionStore::new();
        let created = sessions.create("user-1".into());
        assert!(!created.token.is_empty());
        assert!(created.token.starts_with(EXPECTED_TOKEN_PREFIX));

        let fetched = sessions
            .get(&created.token)
            .expect("freshly created session should be retrievable");
        assert_eq!(fetched.user_id, "user-1");
    }

    /// `resolve` maps a valid token to its user; a missing or unknown token is
    /// not authenticated.
    #[test]
    fn session_store_resolve() {
        let sessions = SessionStore::new();
        let created = sessions.create("user-1".into());

        let resolved = sessions.resolve(Some(&created.token));
        assert!(resolved.is_authenticated());
        assert_eq!(resolved.user_id.as_deref(), Some("user-1"));

        let without_token = sessions.resolve(None);
        assert!(!without_token.is_authenticated());

        let unknown_token = sessions.resolve(Some("invalid-token"));
        assert!(!unknown_token.is_authenticated());
    }

    /// Revoking removes the session; a second revoke reports nothing was removed.
    #[test]
    fn session_store_revoke() {
        let sessions = SessionStore::new();
        let created = sessions.create("user-1".into());

        assert!(sessions.revoke(&created.token), "first revoke should succeed");
        assert!(sessions.get(&created.token).is_none());
        assert!(!sessions.revoke(&created.token), "second revoke should be a no-op");
    }

    /// A session converts to a context carrying the same user id.
    #[test]
    fn session_to_auth_context() {
        let ctx = Session::new("user-42".into()).to_auth_context();
        assert_eq!(ctx.user_id.as_deref(), Some("user-42"));
    }

    // ---- Admin flag -------------------------------------------------------------

    /// The admin context is both admin and authenticated.
    #[test]
    fn admin_context() {
        let admin = AuthContext::admin();
        assert!(admin.is_authenticated());
        assert!(admin.is_admin);
    }

    /// Anonymous callers never carry the admin flag.
    #[test]
    fn anonymous_not_admin() {
        assert!(!AuthContext::anonymous().is_admin);
    }

    /// Ordinary authenticated callers do not carry the admin flag.
    #[test]
    fn authenticated_not_admin() {
        assert!(!AuthContext::authenticated("user-1".into()).is_admin);
    }

    // ---- Magic codes ------------------------------------------------------------

    /// A freshly issued code is six characters long and verifies for its email.
    #[test]
    fn magic_code_create_and_verify() {
        let codes = MagicCodeStore::new();
        let issued = codes.create(TEST_EMAIL);
        assert_eq!(issued.len(), 6);
        assert!(codes.verify(TEST_EMAIL, &issued));
    }

    /// A code other than the issued one is rejected.
    #[test]
    fn magic_code_wrong_code_rejected() {
        let codes = MagicCodeStore::new();
        codes.create(TEST_EMAIL);
        assert!(!codes.verify(TEST_EMAIL, "000000"));
    }

    /// A valid code presented for a different email is rejected.
    #[test]
    fn magic_code_wrong_email_rejected() {
        let codes = MagicCodeStore::new();
        let issued = codes.create(TEST_EMAIL);
        assert!(!codes.verify("other@example.com", &issued));
    }

    /// A code can be verified only once.
    #[test]
    fn magic_code_consumed_after_verify() {
        let codes = MagicCodeStore::new();
        let issued = codes.create(TEST_EMAIL);
        assert!(codes.verify(TEST_EMAIL, &issued), "first use should succeed");
        assert!(!codes.verify(TEST_EMAIL, &issued), "second use should fail");
    }

    /// Codes issued to different emails do not interfere with each other.
    #[test]
    fn magic_code_different_emails_independent() {
        let codes = MagicCodeStore::new();
        let for_alice = codes.create("alice@example.com");
        let for_bob = codes.create("bob@example.com");
        assert!(codes.verify("alice@example.com", &for_alice));
        assert!(codes.verify("bob@example.com", &for_bob));
    }

    // ---- constant_time_eq -------------------------------------------------------

    /// Identical inputs, including two empty ones, compare equal.
    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"), "same bytes");
        assert!(constant_time_eq(b"", b""), "both empty");
    }

    /// Different content or different length compares unequal.
    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"), "same length, different bytes");
        assert!(!constant_time_eq(b"hello", b"hell"), "different lengths");
        assert!(!constant_time_eq(b"a", b"b"), "single differing byte");
    }

    // ---- Token generation -------------------------------------------------------

    /// Two generated tokens differ, share the expected prefix, and the first is
    /// the prefix plus 64 characters long.
    #[test]
    fn generated_tokens_are_unique() {
        let first = generate_token();
        let second = generate_token();
        assert_ne!(first, second);
        for token in [&first, &second] {
            assert!(token.starts_with(EXPECTED_TOKEN_PREFIX));
        }
        assert_eq!(first.len(), EXPECTED_TOKEN_PREFIX.len() + 64);
    }

    // ---- OAuth registry ---------------------------------------------------------

    /// A new registry knows no providers.
    #[test]
    fn oauth_registry_empty() {
        let registry = OAuthRegistry::new();
        assert!(registry.get("google").is_none());
    }

    /// A registered provider can be looked up and produces a Google auth URL.
    #[test]
    fn oauth_registry_register_and_get() {
        let mut registry = OAuthRegistry::new();
        registry.register(make_oauth_config(
            "google",
            "test-id",
            "test-secret",
            "http://localhost/callback",
        ));

        let stored = registry.get("google").expect("google was just registered");
        assert_eq!(stored.client_id, "test-id");
        assert!(stored.auth_url().contains("accounts.google.com"));
    }

    // ---- Guest sessions ---------------------------------------------------------

    /// A guest session gets a `guest_`-prefixed user id that survives `resolve`.
    ///
    /// NOTE(review): the resolved context is asserted to be authenticated, while
    /// `guest_context` asserts that `AuthContext::guest` is not. Presumably
    /// `SessionStore::resolve` does not mark guest sessions as guests; confirm
    /// that this is intended.
    #[test]
    fn guest_session() {
        let sessions = SessionStore::new();
        let guest = sessions.create_guest();
        assert!(guest.user_id.starts_with("guest_"));
        assert!(!guest.token.is_empty());

        let resolved = sessions.resolve(Some(&guest.token));
        assert!(resolved.is_authenticated());
        let resolved_id = resolved
            .user_id
            .expect("resolved guest session should carry a user id");
        assert!(resolved_id.starts_with("guest_"));
    }

    /// Upgrading a guest session rebinds the same token to the real user id.
    #[test]
    fn upgrade_guest_to_real_user() {
        let sessions = SessionStore::new();
        let guest = sessions.create_guest();
        assert!(guest.user_id.starts_with("guest_"));

        assert!(sessions.upgrade(&guest.token, "real-user-123".into()));

        let resolved = sessions.resolve(Some(&guest.token));
        assert_eq!(resolved.user_id.as_deref(), Some("real-user-123"));
    }

    /// Upgrading an unknown token reports failure.
    #[test]
    fn upgrade_invalid_token_fails() {
        let sessions = SessionStore::new();
        assert!(!sessions.upgrade("nonexistent-token", "user".into()));
    }

    /// A guest context keeps its id but is neither authenticated nor admin, so
    /// it passes `Public` and fails `User`.
    #[test]
    fn guest_context() {
        let guest = AuthContext::guest("guest_123".into());
        assert_eq!(guest.user_id.as_deref(), Some("guest_123"));
        assert!(guest.is_guest);
        assert!(!guest.is_admin);
        assert!(!guest.is_authenticated());
        assert!(AuthMode::Public.check(&guest));
        assert!(!AuthMode::User.check(&guest));
    }

    // ---- OAuth URLs -------------------------------------------------------------

    /// Known providers map to their token endpoints; unknown ones yield empty URLs.
    #[test]
    fn oauth_token_urls() {
        let google = make_oauth_config("google", "x", "x", "x");
        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");

        let github = make_oauth_config("github", "x", "x", "x");
        assert_eq!(
            github.token_url(),
            "https://github.com/login/oauth/access_token"
        );

        let unknown = make_oauth_config("unknown", "x", "x", "x");
        assert_eq!(unknown.token_url(), "");
        assert!(unknown.auth_url().is_empty());
    }

    /// The GitHub auth URL points at github.com and embeds the client id.
    #[test]
    fn oauth_auth_url_github() {
        let github = make_oauth_config("github", "gh-id", "gh-secret", "http://localhost/cb");
        assert!(github.auth_url().contains("github.com"));
        assert!(github.auth_url().contains("gh-id"));
    }

    /// The state value is appended to the auth URL as a query parameter.
    #[test]
    fn oauth_auth_url_with_state() {
        let google = make_oauth_config("google", "test-id", "test-secret", "http://localhost/cb");
        let with_state = google.auth_url_with_state("random_state_123");
        assert!(with_state.contains("&state=random_state_123"));
    }

    // ---- OAuth state store ------------------------------------------------------

    /// A state token validates once, returning both stored URLs, and then is gone.
    #[test]
    fn oauth_state_store_create_and_validate() {
        let states = OAuthStateStore::new();
        let state = states.create("google", "https://app/cb", "https://app/login");

        let record = states.validate(&state, "google").expect("valid first time");
        assert_eq!(record.callback_url, "https://app/cb");
        assert_eq!(record.error_callback_url, "https://app/login");

        assert!(
            states.validate(&state, "google").is_none(),
            "state tokens are single-use"
        );
    }

    /// A state token issued for one provider does not validate for another.
    #[test]
    fn oauth_state_store_wrong_provider_rejected() {
        let states = OAuthStateStore::new();
        let state = states.create("google", "https://app/cb", "https://app/cb");
        assert!(states.validate(&state, "github").is_none());
    }

    /// A token that was never issued does not validate.
    #[test]
    fn oauth_state_store_invalid_state_rejected() {
        let states = OAuthStateStore::new();
        assert!(states.validate("nonexistent", "google").is_none());
    }

    // ---- Trusted redirects ------------------------------------------------------

    /// URLs on the trusted origin pass; a different port, a non-HTTP scheme and
    /// an empty string each fail with their own error variant.
    #[test]
    fn validate_trusted_redirect_basics() {
        let trusted = vec!["http://localhost:3000".to_string()];

        for allowed in [
            "http://localhost:3000/dashboard",
            "http://localhost:3000",
            "http://localhost:3000/x?y=1",
        ] {
            assert!(
                validate_trusted_redirect(allowed, &trusted).is_ok(),
                "{} should be accepted",
                allowed
            );
        }

        let other_port = validate_trusted_redirect("http://localhost:4321/dashboard", &trusted);
        assert!(matches!(
            other_port,
            Err(TrustedOriginError::NotTrusted { .. })
        ));

        let script_scheme = validate_trusted_redirect("javascript:alert(1)", &trusted);
        assert!(matches!(script_scheme, Err(TrustedOriginError::NotHttp)));

        let empty = validate_trusted_redirect("", &trusted);
        assert!(matches!(empty, Err(TrustedOriginError::Empty)));
    }
}