1pub mod api_key;
2pub mod apple_jwt;
3pub mod captcha;
4pub mod cookie;
5pub mod email;
6pub mod jwt;
7pub mod oidc_provider;
8pub mod org;
9pub mod password;
10pub mod phone;
11pub mod provider;
12pub mod scim;
13pub mod siwe;
14pub mod stripe;
15pub mod totp;
16pub mod webauthn;
17
18pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
19
20use serde::{Deserialize, Serialize};
21
/// Identity and authorization facts about the caller of a request.
///
/// Constructed through the associated functions on `AuthContext`
/// (`anonymous`, `authenticated`, `from_api_key`, `guest`, `admin`, `user`)
/// and refined with `with_tenant` / `with_roles`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
    /// Id of the acting user; `None` for anonymous requests. Guests also carry
    /// an id here, so `is_some()` alone does not mean "authenticated".
    pub user_id: Option<String>,
    /// Admin contexts pass every `has_role` / `has_any_role` check.
    pub is_admin: bool,
    /// Guest contexts have a `user_id` but are not considered authenticated.
    /// Omitted from serialized output when `false`.
    // NOTE(review): `default` only affects deserialization and this type only
    // derives `Serialize`, so it appears to have no effect — confirm.
    #[serde(default, skip_serializing_if = "is_false")]
    pub is_guest: bool,
    /// Role names granted to the caller.
    pub roles: Vec<String>,
    /// Tenant scope for the request; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// Id of the API key used to authenticate; `is_api_key_auth` reports
    /// `true` whenever this is set. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key_id: Option<String>,
    /// Scopes attached to the API key (format not defined here — TODO confirm
    /// against the api_key module). Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key_scopes: Option<String>,
}
67
/// Serde `skip_serializing_if` predicate: `true` when the flag is unset, so
/// `false` booleans are left out of the serialized form.
fn is_false(b: &bool) -> bool {
    !*b
}
71
72impl AuthContext {
73 pub fn anonymous() -> Self {
75 Self {
76 user_id: None,
77 is_admin: false,
78 is_guest: false,
79 roles: Vec::new(),
80 tenant_id: None,
81 api_key_id: None,
82 api_key_scopes: None,
83 }
84 }
85
86 pub fn authenticated(user_id: String) -> Self {
88 Self {
89 user_id: Some(user_id),
90 is_admin: false,
91 is_guest: false,
92 roles: Vec::new(),
93 tenant_id: None,
94 api_key_id: None,
95 api_key_scopes: None,
96 }
97 }
98
99 pub fn from_api_key(user_id: String, key_id: String, scopes: Option<String>) -> Self {
102 Self {
103 user_id: Some(user_id),
104 is_admin: false,
105 is_guest: false,
106 roles: Vec::new(),
107 tenant_id: None,
108 api_key_id: Some(key_id),
109 api_key_scopes: scopes,
110 }
111 }
112
113 pub fn is_api_key_auth(&self) -> bool {
116 self.api_key_id.is_some()
117 }
118
119 pub fn guest(guest_id: String) -> Self {
124 Self {
125 user_id: Some(guest_id),
126 is_admin: false,
127 is_guest: true,
128 roles: Vec::new(),
129 tenant_id: None,
130 api_key_id: None,
131 api_key_scopes: None,
132 }
133 }
134
135 pub fn admin() -> Self {
137 Self {
138 user_id: Some("__admin__".into()),
139 is_admin: true,
140 is_guest: false,
141 roles: vec!["admin".into()],
142 tenant_id: None,
143 api_key_id: None,
144 api_key_scopes: None,
145 }
146 }
147
148 pub fn user(user_id: String) -> Self {
150 Self::authenticated(user_id)
151 }
152
153 pub fn tenant_id(&self) -> Option<&str> {
155 self.tenant_id.as_deref()
156 }
157
158 pub fn with_tenant(mut self, tenant_id: String) -> Self {
160 self.tenant_id = Some(tenant_id);
161 self
162 }
163
164 pub fn is_authenticated(&self) -> bool {
168 self.user_id.is_some() && !self.is_guest
169 }
170
171 pub fn has_role(&self, role: &str) -> bool {
173 self.is_admin || self.roles.iter().any(|r| r == role)
174 }
175
176 pub fn has_any_role(&self, roles: &[&str]) -> bool {
178 self.is_admin || roles.iter().any(|r| self.has_role(r))
179 }
180
181 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
183 self.roles = roles;
184 self
185 }
186}
187
/// Compares two byte strings without short-circuiting on the first differing
/// byte, so timing does not reveal where a secret mismatches.
///
/// Lengths are compared up front and unequal lengths return `false`
/// immediately; only the length, not the content, can leak via timing.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR-accumulate every byte difference; any mismatch leaves a set bit.
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
207
/// Access requirement for an operation, checked with `AuthMode::check`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// Anyone may call, including anonymous and guest callers.
    Public,
    /// Requires `AuthContext::is_authenticated()` (a user id and not a guest).
    User,
}
220
221impl AuthMode {
222 #[allow(clippy::should_implement_trait)]
224 pub fn from_str(s: &str) -> Option<Self> {
225 match s {
226 "public" => Some(AuthMode::Public),
227 "user" => Some(AuthMode::User),
228 _ => None,
229 }
230 }
231
232 pub fn check(&self, ctx: &AuthContext) -> bool {
234 match self {
235 AuthMode::Public => true,
236 AuthMode::User => ctx.is_authenticated(),
237 }
238 }
239}
240
/// A persisted login session identified by an opaque `token`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Session token, produced by `generate_token()` in the constructors.
    pub token: String,
    /// Id of the user this session belongs to.
    pub user_id: String,
    /// Unix-seconds expiry; `0` means "never expires" (see `is_expired`).
    /// Defaults to `0` when absent from serialized data, so such records never
    /// expire.
    #[serde(default)]
    pub expires_at: u64,
    /// Optional device label; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Unix-seconds creation time; defaults to `0` when absent.
    #[serde(default)]
    pub created_at: u64,
    /// Tenant the session is scoped to; carried into `to_auth_context`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
265
266impl Session {
267 pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
269
270 pub fn new(user_id: String) -> Self {
272 let now = now_secs();
273 Self {
274 token: generate_token(),
275 user_id,
276 expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
277 device: None,
278 created_at: now,
279 tenant_id: None,
280 }
281 }
282
283 pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
285 let now = now_secs();
286 Self {
287 token: generate_token(),
288 user_id,
289 expires_at: if lifetime_secs == 0 {
290 0
291 } else {
292 now.saturating_add(lifetime_secs)
293 },
294 device: None,
295 created_at: now,
296 tenant_id: None,
297 }
298 }
299
300 pub fn to_auth_context(&self) -> AuthContext {
303 let ctx = AuthContext::authenticated(self.user_id.clone());
304 match &self.tenant_id {
305 Some(t) => ctx.with_tenant(t.clone()),
306 None => ctx,
307 }
308 }
309
310 pub fn is_expired(&self) -> bool {
314 self.expires_at != 0 && now_secs() >= self.expires_at
315 }
316}
317
/// Current wall-clock time in whole seconds since the Unix epoch.
/// Yields `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
325
326#[derive(Debug, Clone, Default, Serialize, Deserialize)]
331pub struct OAuthConfig {
332 pub provider: String,
333 pub client_id: String,
334 pub client_secret: String,
335 pub redirect_uri: String,
336 #[serde(default, skip_serializing_if = "Option::is_none")]
341 pub scopes_override: Option<String>,
342 #[serde(default, skip_serializing_if = "Option::is_none")]
346 pub tenant: Option<String>,
347 #[serde(default, skip_serializing_if = "Option::is_none")]
350 pub apple: Option<provider::AppleConfig>,
351 #[serde(default, skip_serializing_if = "Option::is_none")]
357 pub oidc_issuer: Option<String>,
358}
359
impl OAuthConfig {
    /// Resolves the provider spec that drives authorization-URL building,
    /// token exchange and userinfo parsing.
    ///
    /// If `oidc_issuer` is set, resolution is delegated to
    /// `provider::oidc_cache::resolve`; otherwise the built-in spec registered
    /// under `self.provider` is used.
    ///
    /// # Errors
    /// Propagates the OIDC resolver's error, or returns
    /// `"unknown OAuth provider: <id>"` when no built-in spec matches.
    fn resolved_spec(&self) -> Result<provider::ResolvedSpec, String> {
        // An explicit issuer always takes precedence over the provider id.
        if let Some(issuer) = self.oidc_issuer.as_deref() {
            return provider::oidc_cache::resolve(issuer);
        }
        provider::find_spec(&self.provider)
            .map(provider::ResolvedSpec::Static)
            .ok_or_else(|| format!("unknown OAuth provider: {}", self.provider))
    }

    /// Copies every field into the `provider::ProviderConfig` that
    /// `provider::resolve_endpoint` consumes.
    fn provider_cfg(&self) -> provider::ProviderConfig {
        provider::ProviderConfig {
            provider: self.provider.clone(),
            client_id: self.client_id.clone(),
            client_secret: self.client_secret.clone(),
            redirect_uri: self.redirect_uri.clone(),
            scopes_override: self.scopes_override.clone(),
            tenant: self.tenant.clone(),
            apple: self.apple.clone(),
            oidc_issuer: self.oidc_issuer.clone(),
        }
    }

    /// Authorization URL without `state` or PKCE parameters.
    ///
    /// Returns an empty string when the spec cannot be resolved or has no
    /// authorization endpoint. The underlying error is discarded; use
    /// `auth_url_with_pkce` to observe it.
    pub fn auth_url(&self) -> String {
        match self.build_auth_url(None) {
            Ok(u) => u,
            Err(_) => String::new(),
        }
    }

    /// `auth_url` with a URL-encoded `&state=` parameter appended.
    ///
    /// Returns an empty string whenever `auth_url` does. Unlike
    /// `auth_url_with_pkce`, the parameter is appended even when `state` is
    /// empty.
    pub fn auth_url_with_state(&self, state: &str) -> String {
        let base = self.auth_url();
        // Keep the empty "failed" sentinel rather than returning "&state=...".
        if base.is_empty() {
            return base;
        }
        format!("{}&state={}", base, url_encode(state))
    }

    /// Authorization URL with an S256 PKCE challenge when the spec requires
    /// PKCE, plus a URL-encoded `state` when `state` is non-empty.
    ///
    /// Returns `(url, verifier)`. `verifier` is `Some` only when PKCE was
    /// generated; the caller must keep it and hand it back to
    /// `exchange_code_full_pkce`.
    ///
    /// # Errors
    /// Spec resolution failure, or a spec with no authorization endpoint.
    // NOTE(review): the spec is resolved here and again inside
    // `build_auth_url`; correct, but duplicate work for OIDC issuers.
    pub fn auth_url_with_pkce(&self, state: &str) -> Result<(String, Option<String>), String> {
        let spec = self.resolved_spec()?;
        let pkce = if spec.requires_pkce() {
            Some(generate_pkce())
        } else {
            None
        };
        let challenge = pkce.as_ref().map(|p| p.code_challenge.as_str());
        let mut url = self.build_auth_url(challenge)?;
        if !state.is_empty() {
            url.push_str(&format!("&state={}", url_encode(state)));
        }
        Ok((url, pkce.map(|p| p.code_verifier)))
    }

    /// Assembles the authorization URL, in this order:
    /// 1. the resolved endpoint;
    /// 2. the client id, under the spec's parameter name;
    /// 3. the redirect URI and `response_type=code`;
    /// 4. the scopes;
    /// 5. any spec-specific extra query;
    /// 6. optionally, the PKCE challenge.
    ///
    /// # Errors
    /// Spec resolution failure, or an empty resolved authorization endpoint.
    fn build_auth_url(&self, pkce_challenge: Option<&str>) -> Result<String, String> {
        let spec = self.resolved_spec()?;
        let cfg = self.provider_cfg();
        let auth = provider::resolve_endpoint(spec.auth_url(), &cfg);
        if auth.is_empty() {
            return Err(format!(
                "provider {} has no authorization endpoint",
                self.provider
            ));
        }
        let scopes_default = spec.scopes().to_string();
        let scopes_raw = self.scopes_override.as_deref().unwrap_or(&scopes_default);
        // Scopes are configured whitespace-separated; re-join them with the
        // provider's own separator before percent-encoding.
        let scopes_joined = scopes_raw
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(spec.scope_separator());

        let mut url = format!(
            "{auth}?{cid_param}={cid}&redirect_uri={ruri}&response_type=code&scope={scope}",
            cid_param = spec.client_id_param(),
            cid = url_encode(&self.client_id),
            ruri = url_encode(&self.redirect_uri),
            scope = url_encode(&scopes_joined),
        );
        // Appended verbatim — presumably already query-encoded by the spec
        // (TODO confirm in the provider module).
        if !spec.auth_query_extra().is_empty() {
            url.push('&');
            url.push_str(spec.auth_query_extra());
        }
        // Not percent-encoded: the challenge comes from `generate_pkce`,
        // presumably base64url and therefore URL-safe.
        if let Some(challenge) = pkce_challenge {
            url.push_str("&code_challenge=");
            url.push_str(challenge);
            url.push_str("&code_challenge_method=S256");
        }
        Ok(url)
    }

    /// Resolved token endpoint, or an empty string if the spec cannot be
    /// resolved.
    pub fn token_url(&self) -> String {
        match self.resolved_spec() {
            Ok(spec) => provider::resolve_endpoint(spec.token_url(), &self.provider_cfg()),
            Err(_) => String::new(),
        }
    }

    /// Resolved userinfo endpoint, or an empty string if the spec cannot be
    /// resolved or has no userinfo endpoint.
    pub fn userinfo_url(&self) -> String {
        match self.resolved_spec() {
            Ok(spec) => match spec.userinfo_url() {
                Some(u) => provider::resolve_endpoint(u, &self.provider_cfg()),
                None => String::new(),
            },
            Err(_) => String::new(),
        }
    }

    /// Exchanges an authorization code for tokens without a PKCE verifier.
    pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
        self.exchange_code_full_pkce(code, None)
    }

    /// Exchanges an authorization `code` (plus an optional PKCE
    /// `code_verifier`) at the token endpoint, using the spec's exchange shape:
    ///
    /// - `Standard`: form body with client credentials in the body.
    /// - `AppleJwt`: form body whose `client_secret` is a JWT minted from `apple`.
    /// - `BasicAuth`: form body; credentials sent as HTTP Basic.
    /// - `JsonBody`: JSON body with credentials inside.
    /// - `BasicAuthJsonBody`: JSON body; credentials sent as HTTP Basic.
    ///
    /// # Errors
    /// Spec resolution failure, missing `apple` config for Apple, JWT minting
    /// failure, HTTP failure (passed through `sanitize_token_error` so
    /// secrets and codes are redacted), or an unparseable token response.
    pub fn exchange_code_full_pkce(
        &self,
        code: &str,
        code_verifier: Option<&str>,
    ) -> Result<TokenSet, String> {
        let spec = self.resolved_spec()?;
        let cfg = self.provider_cfg();
        let token_url = provider::resolve_endpoint(spec.token_url(), &cfg);
        // Pre-encoded "&code_verifier=..." suffix for the form-body shapes.
        let pkce_field = code_verifier
            .map(|v| format!("&code_verifier={}", url_encode(v)))
            .unwrap_or_default();

        let out = match spec.token_exchange() {
            provider::TokenExchangeShape::Standard => {
                let body = format!(
                    "code={code}&{cid_param}={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
                    code = url_encode(code),
                    cid_param = spec.client_id_param(),
                    cid = url_encode(&self.client_id),
                    secret = url_encode(&self.client_secret),
                    ruri = url_encode(&self.redirect_uri),
                    pkce = pkce_field,
                );
                http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
            }
            provider::TokenExchangeShape::AppleJwt => {
                let apple = self.apple.as_ref().ok_or(
                    "apple provider requires `apple` config (team_id, key_id, private_key_pem)",
                )?;
                // The client secret is minted per request from the Apple key.
                let signed_secret = apple_jwt::mint_client_secret(apple, &self.client_id)?;
                let body = format!(
                    "code={code}&client_id={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
                    code = url_encode(code),
                    cid = url_encode(&self.client_id),
                    secret = url_encode(&signed_secret),
                    ruri = url_encode(&self.redirect_uri),
                    pkce = pkce_field,
                );
                http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
            }
            provider::TokenExchangeShape::BasicAuth => {
                // Credentials travel in the Authorization header, not the body.
                let body = format!(
                    "code={code}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
                    code = url_encode(code),
                    ruri = url_encode(&self.redirect_uri),
                    pkce = pkce_field,
                );
                http_post_form_basic(&token_url, &body, &self.client_id, &self.client_secret)
                    .map_err(sanitize_token_error)?
            }
            provider::TokenExchangeShape::JsonBody => {
                let mut json = serde_json::Map::new();
                json.insert("grant_type".into(), "authorization_code".into());
                json.insert("code".into(), code.into());
                json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
                json.insert("client_id".into(), self.client_id.clone().into());
                json.insert("client_secret".into(), self.client_secret.clone().into());
                if let Some(v) = code_verifier {
                    json.insert("code_verifier".into(), v.to_string().into());
                }
                let body = serde_json::Value::Object(json).to_string();
                http_post_json(&token_url, &body, None).map_err(sanitize_token_error)?
            }
            provider::TokenExchangeShape::BasicAuthJsonBody => {
                let mut json = serde_json::Map::new();
                json.insert("grant_type".into(), "authorization_code".into());
                json.insert("code".into(), code.into());
                json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
                if let Some(v) = code_verifier {
                    json.insert("code_verifier".into(), v.to_string().into());
                }
                let body = serde_json::Value::Object(json).to_string();
                http_post_json(
                    &token_url,
                    &body,
                    Some((&self.client_id, &self.client_secret)),
                )
                .map_err(sanitize_token_error)?
            }
        };
        parse_token_response(&out)
    }

    /// Exchanges `code` (without PKCE) and returns only the access token.
    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
        Ok(self.exchange_code_full(code)?.access_token)
    }

    /// Fetches the user's `(email, name)`; see `fetch_userinfo_full`.
    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
        let info = self.fetch_userinfo_full(access_token)?;
        Ok((info.email, info.name))
    }

    /// Fetches user info without an id_token. Providers whose parser is
    /// `AppleIdToken` always fail through this path, because they need the
    /// id_token.
    pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
        self.fetch_userinfo_with_id_token(access_token, None)
    }

    /// Resolves the signed-in user's identity with the spec's userinfo parser:
    ///
    /// - `AppleIdToken`: decodes claims from `id_token`; no HTTP call is made.
    /// - `LinearGraphql`: queries Linear's GraphQL `viewer`.
    /// - `Oidc`: requires `email` and `sub`; `name` is optional.
    /// - `GitHub`: falls back to `login` for the name, and to
    ///   `/user/emails` (primary + verified) when `email` is absent.
    /// - `Custom`: JSON-pointer paths for id, email and name. For `twitter`
    ///   and `reddit`, a value without `@` is turned into a placeholder
    ///   `<value>@<provider>.invalid`; other providers get an error.
    ///
    /// # Errors
    /// Spec resolution, missing id_token, HTTP failure (sanitized),
    /// non-JSON body, or missing required fields.
    // NOTE(review): the `Oidc` branch does not check `email_verified`; if
    // callers link accounts by email this may allow takeover via providers
    // that return unverified emails — confirm how `UserInfo.email` is used.
    pub fn fetch_userinfo_with_id_token(
        &self,
        access_token: &str,
        id_token: Option<&str>,
    ) -> Result<UserInfo, String> {
        let spec = self.resolved_spec()?;
        let cfg = self.provider_cfg();

        if matches!(spec.userinfo_parser(), provider::UserinfoParser::AppleIdToken) {
            let token = id_token
                .ok_or("apple login requires the id_token from the token response")?;
            return parse_apple_id_token(token, &self.provider);
        }

        if matches!(spec.userinfo_parser(), provider::UserinfoParser::LinearGraphql) {
            return fetch_linear_userinfo(&self.provider, access_token);
        }

        let url = match spec.userinfo_url() {
            Some(u) => provider::resolve_endpoint(u, &cfg),
            None => return Err(format!("provider {} has no userinfo endpoint", self.provider)),
        };
        let out = match spec.userinfo_method() {
            provider::UserinfoMethod::Get => http_get_bearer(&url, access_token),
            provider::UserinfoMethod::Post => http_post_bearer(&url, access_token),
        }
        .map_err(sanitize_token_error)?;
        let parsed: serde_json::Value =
            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;

        match spec.userinfo_parser() {
            provider::UserinfoParser::Oidc => {
                let email = parsed
                    .get("email")
                    .and_then(|v| v.as_str())
                    .ok_or("no email in userinfo")?
                    .to_string();
                let name = parsed
                    .get("name")
                    .and_then(|v| v.as_str())
                    .map(String::from);
                let provider_account_id = parsed
                    .get("sub")
                    .and_then(|v| v.as_str())
                    .ok_or("no sub in userinfo")?
                    .to_string();
                Ok(UserInfo {
                    provider: self.provider.clone(),
                    provider_account_id,
                    email,
                    name,
                })
            }
            provider::UserinfoParser::GitHub => {
                let name = parsed
                    .get("name")
                    .and_then(|v| v.as_str())
                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
                    .map(String::from);
                let email = parsed
                    .get("email")
                    .and_then(|v| v.as_str())
                    .map(String::from);
                // A private profile email comes back null; ask /user/emails.
                let email = email
                    .or_else(|| fetch_github_primary_email(access_token).ok())
                    .ok_or("no accessible email on GitHub account")?;
                // GitHub ids are numeric; a string id is accepted as a fallback.
                let provider_account_id = parsed
                    .get("id")
                    .map(|v| {
                        v.as_i64()
                            .map(|n| n.to_string())
                            .or_else(|| v.as_str().map(String::from))
                            .unwrap_or_default()
                    })
                    .filter(|s| !s.is_empty())
                    .ok_or("no id in userinfo")?;
                Ok(UserInfo {
                    provider: self.provider.clone(),
                    provider_account_id,
                    email,
                    name,
                })
            }
            provider::UserinfoParser::Custom {
                id_path,
                email_path,
                name_path,
            } => {
                let provider_account_id = json_pointer_string(&parsed, id_path)
                    .ok_or_else(|| format!("no id at {id_path} in userinfo"))?;
                let raw_email = json_pointer_string(&parsed, email_path)
                    .ok_or_else(|| format!("no email at {email_path} in userinfo"))?;
                // Some providers expose no email; for an explicit allow-list
                // only, synthesize an address on a reserved `.invalid` domain.
                let email = if !raw_email.contains('@') {
                    let domain = match self.provider.as_str() {
                        "twitter" => "x.invalid",
                        "reddit" => "reddit.invalid",
                        other => return Err(format!(
                            "{other}: userinfo `email` field is not an email address (got {raw_email:?}); refusing to synthesize",
                        )),
                    };
                    format!("{raw_email}@{domain}")
                } else {
                    raw_email
                };
                let name = name_path.and_then(|p| json_pointer_string(&parsed, p));
                Ok(UserInfo {
                    provider: self.provider.clone(),
                    provider_account_id,
                    email,
                    name,
                })
            }
            provider::UserinfoParser::AppleIdToken => unreachable!("handled above"),
            provider::UserinfoParser::LinearGraphql => unreachable!("handled above"),
        }
    }
}
735
/// A PKCE verifier and its derived challenge, as produced by `generate_pkce`.
struct PkcePair {
    /// Secret kept by the client and sent with the token exchange.
    code_verifier: String,
    /// SHA-256 of `code_verifier`, encoded with `apple_jwt::base64_url`; sent
    /// in the authorization URL.
    code_challenge: String,
}
742
743fn generate_pkce() -> PkcePair {
747 use rand::RngCore;
748 let mut bytes = [0u8; 32];
749 rand::thread_rng().fill_bytes(&mut bytes);
750 let code_verifier = apple_jwt::base64_url(bytes);
751 use sha2::{Digest, Sha256};
752 let mut hasher = Sha256::new();
753 hasher.update(code_verifier.as_bytes());
754 let code_challenge = apple_jwt::base64_url(hasher.finalize());
755 PkcePair {
756 code_verifier,
757 code_challenge,
758 }
759}
760
761fn parse_apple_id_token(id_token: &str, provider: &str) -> Result<UserInfo, String> {
784 let mut parts = id_token.split('.');
785 let _header = parts.next().ok_or("apple id_token: missing header")?;
786 let claims_b64 = parts.next().ok_or("apple id_token: missing claims")?;
787 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
788 let claims_bytes = URL_SAFE_NO_PAD
789 .decode(claims_b64)
790 .map_err(|e| format!("apple id_token claims not base64: {e}"))?;
791 let claims: serde_json::Value = serde_json::from_slice(&claims_bytes)
792 .map_err(|e| format!("apple id_token claims not JSON: {e}"))?;
793 let provider_account_id = claims
794 .get("sub")
795 .and_then(|v| v.as_str())
796 .ok_or("apple id_token: missing sub")?
797 .to_string();
798 let email = claims
799 .get("email")
800 .and_then(|v| v.as_str())
801 .ok_or("apple id_token: missing email (was the `email` scope requested?)")?
802 .to_string();
803 Ok(UserInfo {
804 provider: provider.to_string(),
805 provider_account_id,
806 email,
807 name: None, })
809}
810
811fn sanitize_token_error(err: String) -> String {
823 const SENSITIVE: &[&str] = &[
824 "client_secret",
825 "code_verifier",
826 "client_assertion",
827 "refresh_token",
828 "access_token",
829 "id_token",
830 "code",
835 ];
836 let mut out = err;
837 for key in SENSITIVE {
838 out = redact_param_form(&out, key);
839 out = redact_param_json(&out, key);
840 }
841 out
842}
843
/// Replaces the value of every `key=value` occurrence in `input` with `***`.
///
/// A value runs until the next `&`, newline, `"`, space or `'`, or to the end
/// of the input. `key=` is matched anywhere, including as the tail of a longer
/// name (e.g. `error_code=` for key `code`).
fn redact_param_form(input: &str, key: &str) -> String {
    let needle = format!("{key}=");
    let is_terminator = |c: char| matches!(c, '&' | '\n' | '"' | ' ' | '\'');

    let mut redacted = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(pos) = rest.find(needle.as_str()) {
        // Keep everything up to and including `key=`, then mask the value.
        let value_start = pos + needle.len();
        redacted.push_str(&rest[..value_start]);
        redacted.push_str("***");
        let value = &rest[value_start..];
        let value_len = value.find(is_terminator).unwrap_or(value.len());
        rest = &value[value_len..];
    }
    redacted.push_str(rest);
    redacted
}
874
/// Replaces the value of every `"key": "value"` JSON string member in `input`
/// with `***`, keeping the key, colon, whitespace and quotes.
///
/// Only string values are redacted. A `"key"` not followed by `:`, or whose
/// value is not a string (number, `null`, object…), is copied through
/// unchanged. Backslash escapes inside the value are honoured when looking
/// for the closing quote. An unterminated value is masked to the end of the
/// input.
fn redact_param_json(input: &str, key: &str) -> String {
    let needle = format!("\"{key}\"");
    let mut out = String::with_capacity(input.len());
    let mut i = 0;
    while i < input.len() {
        // Not at a quoted key: copy one char (UTF-8 aware) and move on.
        if !input[i..].starts_with(&needle) {
            let (_, ch) = input[i..].char_indices().next().expect("non-empty");
            out.push(ch);
            i += ch.len_utf8();
            continue;
        }
        // Skip whitespace between the key and the colon.
        let mut j = i + needle.len();
        while let Some((_, ch)) = input[j..].char_indices().next() {
            if !ch.is_whitespace() {
                break;
            }
            j += ch.len_utf8();
        }
        // Not a member key (e.g. the string appeared as a value): copy as-is.
        if !input[j..].starts_with(':') {
            out.push_str(&input[i..j]);
            i = j;
            continue;
        }
        j += 1;
        // Skip whitespace between the colon and the value.
        while let Some((_, ch)) = input[j..].char_indices().next() {
            if !ch.is_whitespace() {
                break;
            }
            j += ch.len_utf8();
        }
        // Only string values are redacted.
        if !input[j..].starts_with('"') {
            out.push_str(&input[i..j]);
            i = j;
            continue;
        }
        let value_start = j + 1;
        let mut k = value_start;
        // Tracks whether the previous char was an unescaped backslash, so that
        // `\"` does not end the string but `\\"` does.
        let mut prev_backslash = false;
        let mut closing: Option<usize> = None;
        while k < input.len() {
            let (_, ch) = input[k..].char_indices().next().expect("non-empty");
            if ch == '"' && !prev_backslash {
                closing = Some(k);
                break;
            }
            prev_backslash = ch == '\\' && !prev_backslash;
            k += ch.len_utf8();
        }
        match closing {
            Some(end) => {
                out.push_str(&input[i..value_start]);
                out.push_str("***");
                out.push('"');
                i = end + 1;
            }
            None => {
                // Unterminated (e.g. truncated body): mask through the end.
                out.push_str(&input[i..value_start]);
                out.push_str("***");
                i = input.len();
            }
        }
    }
    out
}
949
950fn fetch_linear_userinfo(provider: &str, access_token: &str) -> Result<UserInfo, String> {
955 let body = r#"{"query":"query { viewer { id email name } }"}"#;
956 let agent = ureq_agent();
957 let resp = agent
958 .post("https://api.linear.app/graphql")
959 .set("Authorization", &format!("Bearer {access_token}"))
960 .set("Content-Type", "application/json")
961 .set("Accept", "application/json")
962 .send_string(body)
963 .map_err(|e| format!("linear graphql: {e}"))?;
964 let out = resp.into_string().map_err(|e| format!("read body: {e}"))?;
965 let parsed: serde_json::Value = serde_json::from_str(&out)
966 .map_err(|e| format!("linear graphql not JSON: {e}"))?;
967 let viewer = parsed
968 .pointer("/data/viewer")
969 .ok_or("linear graphql: no /data/viewer")?;
970 let provider_account_id = viewer
971 .get("id")
972 .and_then(|v| v.as_str())
973 .ok_or("linear graphql: no id")?
974 .to_string();
975 let email = viewer
976 .get("email")
977 .and_then(|v| v.as_str())
978 .ok_or("linear graphql: no email")?
979 .to_string();
980 let name = viewer.get("name").and_then(|v| v.as_str()).map(String::from);
981 Ok(UserInfo {
982 provider: provider.to_string(),
983 provider_account_id,
984 email,
985 name,
986 })
987}
988
989fn json_pointer_string(v: &serde_json::Value, path: &str) -> Option<String> {
993 let node = v.pointer(path)?;
994 if let Some(s) = node.as_str() {
995 return Some(s.to_string());
996 }
997 if let Some(n) = node.as_i64() {
998 return Some(n.to_string());
999 }
1000 if let Some(n) = node.as_u64() {
1001 return Some(n.to_string());
1002 }
1003 None
1004}
1005
/// Identity of a user as reported by an OAuth provider.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Provider id the identity came from (copied from `OAuthConfig::provider`).
    pub provider: String,
    /// The provider's stable account id (e.g. OIDC `sub`, GitHub numeric `id`).
    pub provider_account_id: String,
    /// Email address. For `twitter`/`reddit` custom parsers this may be a
    /// synthesized placeholder on a `.invalid` domain.
    pub email: String,
    /// Display name, if the provider supplied one.
    pub name: Option<String>,
}
1017
/// Tokens returned by a provider's token endpoint.
///
/// `Debug` is implemented by hand so that bearer credentials (access, refresh
/// and id tokens) are never written to logs or panic messages.
#[derive(Clone, PartialEq, Eq)]
pub struct TokenSet {
    /// Bearer access token.
    pub access_token: String,
    /// Refresh token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// OIDC id_token (a JWT), if present.
    pub id_token: Option<String>,
    /// Absolute expiry in Unix seconds, computed from `expires_in` when present.
    pub expires_at: Option<u64>,
    /// Granted scopes as reported by the provider.
    pub scope: Option<String>,
}

impl std::fmt::Debug for TokenSet {
    /// Shows expiry and scope; token values appear only as `"<redacted>"`.
    /// Optional tokens reveal only whether they are present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn redact(token: &Option<String>) -> Option<&'static str> {
            token.as_ref().map(|_| "<redacted>")
        }
        f.debug_struct("TokenSet")
            .field("access_token", &"<redacted>")
            .field("refresh_token", &redact(&self.refresh_token))
            .field("id_token", &redact(&self.id_token))
            .field("expires_at", &self.expires_at)
            .field("scope", &self.scope)
            .finish()
    }
}
1032
1033fn parse_token_response(body: &str) -> Result<TokenSet, String> {
1034 let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
1037 let mut map = serde_json::Map::new();
1039 for pair in body.split('&') {
1040 if let Some((k, v)) = pair.split_once('=') {
1041 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
1042 }
1043 }
1044 serde_json::Value::Object(map)
1045 });
1046
1047 let access_token = json
1048 .get("access_token")
1049 .and_then(|v| v.as_str())
1050 .ok_or_else(|| format!("no access_token in token response: {body}"))?
1051 .to_string();
1052 let refresh_token = json
1053 .get("refresh_token")
1054 .and_then(|v| v.as_str())
1055 .map(String::from);
1056 let id_token = json
1057 .get("id_token")
1058 .and_then(|v| v.as_str())
1059 .map(String::from);
1060 let expires_at = json
1061 .get("expires_in")
1062 .and_then(|v| {
1063 v.as_u64()
1064 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
1065 })
1066 .map(|secs| now_secs().saturating_add(secs));
1067 let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
1068 Ok(TokenSet {
1069 access_token,
1070 refresh_token,
1071 id_token,
1072 expires_at,
1073 scope,
1074 })
1075}
1076
/// Percent-encodes `s` for use in a URL query or form body.
///
/// The RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through.
/// Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex,
/// so a space is encoded as `%20`, not `+`.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes().fold(String::with_capacity(s.len()), |mut acc, byte| {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            acc.push(char::from(byte));
        } else {
            acc.push('%');
            acc.push(char::from(HEX[usize::from(byte >> 4)]));
            acc.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
        acc
    })
}
1089
/// Per-phase timeout (connect, read and write, each configured separately in
/// `ureq_agent`) for every outbound OAuth HTTP call.
const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
1094
1095fn ureq_agent() -> ureq::Agent {
1096 ureq::AgentBuilder::new()
1097 .timeout_connect(HTTP_TIMEOUT)
1098 .timeout_read(HTTP_TIMEOUT)
1099 .timeout_write(HTTP_TIMEOUT)
1100 .user_agent("pylon/0.1")
1101 .build()
1102}
1103
1104fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
1105 let agent = ureq_agent();
1106 let mut req = agent
1107 .post(url)
1108 .set("Content-Type", "application/x-www-form-urlencoded");
1109 if accept_json {
1110 req = req.set("Accept", "application/json");
1111 }
1112 match req.send_string(body) {
1113 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1114 Err(ureq::Error::Status(code, resp)) => {
1115 let body = resp.into_string().unwrap_or_default();
1116 Err(format!("HTTP {code}: {body}"))
1117 }
1118 Err(e) => Err(format!("HTTP error: {e}")),
1119 }
1120}
1121
1122fn http_post_form_basic(
1126 url: &str,
1127 body: &str,
1128 client_id: &str,
1129 client_secret: &str,
1130) -> Result<String, String> {
1131 use base64::{engine::general_purpose::STANDARD, Engine};
1132 let creds = format!("{client_id}:{client_secret}");
1133 let basic = STANDARD.encode(creds.as_bytes());
1134 let agent = ureq_agent();
1135 match agent
1136 .post(url)
1137 .set("Content-Type", "application/x-www-form-urlencoded")
1138 .set("Accept", "application/json")
1139 .set("Authorization", &format!("Basic {basic}"))
1140 .send_string(body)
1141 {
1142 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1143 Err(ureq::Error::Status(code, resp)) => {
1144 let body = resp.into_string().unwrap_or_default();
1145 Err(format!("HTTP {code}: {body}"))
1146 }
1147 Err(e) => Err(format!("HTTP error: {e}")),
1148 }
1149}
1150
1151fn http_post_json(
1155 url: &str,
1156 body: &str,
1157 basic_creds: Option<(&str, &str)>,
1158) -> Result<String, String> {
1159 let agent = ureq_agent();
1160 let mut req = agent
1161 .post(url)
1162 .set("Content-Type", "application/json")
1163 .set("Accept", "application/json");
1164 if let Some((id, secret)) = basic_creds {
1165 use base64::{engine::general_purpose::STANDARD, Engine};
1166 let creds = STANDARD.encode(format!("{id}:{secret}").as_bytes());
1167 req = req.set("Authorization", &format!("Basic {creds}"));
1168 }
1169 req = req.set("Notion-Version", "2022-06-28");
1172 match req.send_string(body) {
1173 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1174 Err(ureq::Error::Status(code, resp)) => {
1175 let body = resp.into_string().unwrap_or_default();
1176 Err(format!("HTTP {code}: {body}"))
1177 }
1178 Err(e) => Err(format!("HTTP error: {e}")),
1179 }
1180}
1181
1182fn http_post_bearer(url: &str, token: &str) -> Result<String, String> {
1185 let agent = ureq_agent();
1186 match agent
1187 .post(url)
1188 .set("Authorization", &format!("Bearer {token}"))
1189 .set("Accept", "application/json")
1190 .call()
1191 {
1192 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1193 Err(ureq::Error::Status(code, resp)) => {
1194 let body = resp.into_string().unwrap_or_default();
1195 Err(format!("HTTP {code}: {body}"))
1196 }
1197 Err(e) => Err(format!("HTTP error: {e}")),
1198 }
1199}
1200
1201fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
1202 let agent = ureq_agent();
1203 match agent
1204 .get(url)
1205 .set("Authorization", &format!("Bearer {token}"))
1206 .set("Accept", "application/json")
1207 .call()
1208 {
1209 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1210 Err(ureq::Error::Status(code, resp)) => {
1211 let body = resp.into_string().unwrap_or_default();
1212 Err(format!("HTTP {code}: {body}"))
1213 }
1214 Err(e) => Err(format!("HTTP error: {e}")),
1215 }
1216}
1217
1218fn fetch_github_primary_email(token: &str) -> Result<String, String> {
1219 let out = http_get_bearer("https://api.github.com/user/emails", token)?;
1220 let emails: serde_json::Value =
1221 serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
1222 emails
1223 .as_array()
1224 .and_then(|arr| {
1225 arr.iter()
1226 .find(|e| {
1227 e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
1228 && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
1229 })
1230 .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
1231 })
1232 .ok_or_else(|| "no primary verified email on GitHub".into())
1233}
1234
/// Set of configured OAuth providers, keyed by provider id.
pub struct OAuthRegistry {
    /// Provider id → configuration (see `register`).
    providers: std::collections::HashMap<String, OAuthConfig>,
}
1239
1240impl Default for OAuthRegistry {
1241 fn default() -> Self {
1242 Self::new()
1243 }
1244}
1245
1246impl OAuthRegistry {
1247 pub fn new() -> Self {
1248 Self {
1249 providers: std::collections::HashMap::new(),
1250 }
1251 }
1252
1253 pub fn register(&mut self, config: OAuthConfig) {
1254 self.providers.insert(config.provider.clone(), config);
1255 }
1256
1257 pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
1258 self.providers.get(provider)
1259 }
1260
1261 pub fn from_env() -> Self {
1274 let mut reg = Self::new();
1275
1276 for spec in provider::builtin::all() {
1277 let upper = spec.id.to_ascii_uppercase();
1278 let prefix = format!("PYLON_OAUTH_{upper}");
1279 let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1280 Ok(v) => v,
1281 Err(_) => continue,
1282 };
1283 let secret = match std::env::var(format!("{prefix}_CLIENT_SECRET")) {
1284 Ok(v) => v,
1285 Err(_) if spec.id == "apple" => String::new(),
1287 Err(_) => continue,
1288 };
1289 let redirect_uri = std::env::var(format!("{prefix}_REDIRECT")).unwrap_or_else(|_| {
1290 format!("http://localhost:3000/api/auth/callback/{}", spec.id)
1291 });
1292 let scopes_override = std::env::var(format!("{prefix}_SCOPES")).ok();
1293 let tenant = std::env::var(format!("{prefix}_TENANT")).ok();
1294
1295 let apple = if spec.id == "apple" {
1296 match (
1297 std::env::var(format!("{prefix}_TEAM_ID")),
1298 std::env::var(format!("{prefix}_KEY_ID")),
1299 std::env::var(format!("{prefix}_PRIVATE_KEY")),
1300 ) {
1301 (Ok(team_id), Ok(key_id), Ok(private_key_pem)) => Some(provider::AppleConfig {
1302 team_id,
1303 key_id,
1304 private_key_pem,
1305 }),
1306 _ => continue, }
1308 } else {
1309 None
1310 };
1311
1312 reg.register(OAuthConfig {
1313 provider: spec.id.to_string(),
1314 client_id: id,
1315 client_secret: secret,
1316 redirect_uri,
1317 scopes_override,
1318 tenant,
1319 apple,
1320 oidc_issuer: None,
1321 });
1322 }
1323
1324 for (key, issuer) in std::env::vars() {
1326 let Some(rest) = key.strip_prefix("PYLON_OAUTH_") else {
1327 continue;
1328 };
1329 let Some(name_upper) = rest.strip_suffix("_OIDC_ISSUER") else {
1330 continue;
1331 };
1332 let name = name_upper.to_ascii_lowercase();
1333 if provider::find_spec(&name).is_some() {
1334 continue; }
1336 let prefix = format!("PYLON_OAUTH_{name_upper}");
1337 let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1338 Ok(v) => v,
1339 Err(_) => continue,
1340 };
1341 let secret = std::env::var(format!("{prefix}_CLIENT_SECRET")).unwrap_or_default();
1342 let redirect_uri = std::env::var(format!("{prefix}_REDIRECT"))
1343 .unwrap_or_else(|_| format!("http://localhost:3000/api/auth/callback/{name}"));
1344 reg.register(OAuthConfig {
1345 provider: name,
1346 client_id: id,
1347 client_secret: secret,
1348 redirect_uri,
1349 scopes_override: std::env::var(format!("{prefix}_SCOPES")).ok(),
1350 tenant: None,
1351 apple: None,
1352 oidc_issuer: Some(issuer),
1353 });
1354 }
1355
1356 reg
1357 }
1358
1359 pub fn ids(&self) -> impl Iterator<Item = &str> {
1363 self.providers.keys().map(|s| s.as_str())
1364 }
1365
1366 pub fn shared() -> &'static OAuthRegistry {
1373 static CELL: std::sync::OnceLock<OAuthRegistry> = std::sync::OnceLock::new();
1374 CELL.get_or_init(Self::from_env)
1375 }
1376}
1377
/// Server-side record of one in-flight OAuth authorization, stored under the
/// opaque `state` token returned by `OAuthStateStore::create_with_pkce`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
    /// Provider id the flow was started for; `OAuthStateStore::validate`
    /// returns `None` when the callback names a different provider.
    pub provider: String,
    /// The `callback_url` passed when the state was created (presumably the
    /// post-login redirect target — confirm with callers).
    pub callback_url: String,
    /// The `error_callback_url` passed when the state was created.
    pub error_callback_url: String,
    /// PKCE code verifier; `Some` only when `create_with_pkce` was given one.
    ///
    /// NOTE(review): the derived `Debug` prints this value verbatim.
    pub pkce_verifier: Option<String>,
    /// Unix time in seconds; `OAuthStateStore` sets it to creation time + 600,
    /// and the backend rejects the state once the current time reaches it.
    pub expires_at: u64,
}
1405
1406pub trait OAuthStateBackend: Send + Sync {
1411 fn put(&self, token: &str, state: &OAuthState);
1413 fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
1418}
1419
1420pub struct InMemoryOAuthBackend {
1422 states: Mutex<HashMap<String, OAuthState>>,
1423}
1424
1425impl InMemoryOAuthBackend {
1426 pub fn new() -> Self {
1427 Self {
1428 states: Mutex::new(HashMap::new()),
1429 }
1430 }
1431}
1432
1433impl Default for InMemoryOAuthBackend {
1434 fn default() -> Self {
1435 Self::new()
1436 }
1437}
1438
1439impl OAuthStateBackend for InMemoryOAuthBackend {
1440 fn put(&self, token: &str, state: &OAuthState) {
1441 self.states
1442 .lock()
1443 .unwrap()
1444 .insert(token.to_string(), state.clone());
1445 }
1446 fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
1447 let mut s = self.states.lock().unwrap();
1448 let entry = s.remove(token)?;
1449 if entry.expires_at <= now_unix_secs {
1450 return None;
1451 }
1452 Some(entry)
1453 }
1454}
1455
1456pub struct OAuthStateStore {
1463 backend: Box<dyn OAuthStateBackend>,
1464}
1465
1466impl Default for OAuthStateStore {
1467 fn default() -> Self {
1468 Self::new()
1469 }
1470}
1471
1472impl OAuthStateStore {
1473 pub fn new() -> Self {
1474 Self {
1475 backend: Box::new(InMemoryOAuthBackend::new()),
1476 }
1477 }
1478
1479 pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
1480 Self { backend }
1481 }
1482
1483 pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
1491 self.create_with_pkce(provider, callback_url, error_callback_url, None)
1492 }
1493
1494 pub fn create_with_pkce(
1498 &self,
1499 provider: &str,
1500 callback_url: &str,
1501 error_callback_url: &str,
1502 pkce_verifier: Option<String>,
1503 ) -> String {
1504 use std::time::{SystemTime, UNIX_EPOCH};
1505 let token = generate_token();
1506 let now = SystemTime::now()
1507 .duration_since(UNIX_EPOCH)
1508 .unwrap_or_default()
1509 .as_secs();
1510 let state = OAuthState {
1511 provider: provider.to_string(),
1512 callback_url: callback_url.to_string(),
1513 error_callback_url: error_callback_url.to_string(),
1514 pkce_verifier,
1515 expires_at: now + 600,
1516 };
1517 self.backend.put(&token, &state);
1518 token
1519 }
1520
1521 pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
1526 use std::time::{SystemTime, UNIX_EPOCH};
1527 let now = SystemTime::now()
1528 .duration_since(UNIX_EPOCH)
1529 .unwrap_or_default()
1530 .as_secs();
1531 let entry = self.backend.take(state, now)?;
1532 if entry.provider != expected_provider {
1533 return None;
1534 }
1535 Some(entry)
1536 }
1537}
1538
1539pub fn validate_trusted_redirect(
1556 url: &str,
1557 trusted_origins: &[String],
1558) -> Result<(), TrustedOriginError> {
1559 if url.is_empty() {
1560 return Err(TrustedOriginError::Empty);
1561 }
1562 if !url.starts_with("http://") && !url.starts_with("https://") {
1565 return Err(TrustedOriginError::NotHttp);
1566 }
1567 let url_origin = origin_of(url);
1568 if trusted_origins.iter().any(|t| t == &url_origin) {
1569 Ok(())
1570 } else {
1571 Err(TrustedOriginError::NotTrusted { origin: url_origin })
1572 }
1573}
1574
/// Why [`validate_trusted_redirect`] rejected a redirect URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The URL was the empty string.
    Empty,
    /// The URL did not start with `http://` or `https://`.
    NotHttp,
    /// The URL's origin is not in the trusted list.
    NotTrusted {
        /// Origin extracted from the rejected URL.
        origin: String,
    },
}

impl std::fmt::Display for TrustedOriginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Empty => f.write_str("redirect URL is empty"),
            Self::NotHttp => f.write_str("redirect URL must use http:// or https:// scheme"),
            Self::NotTrusted { origin } => {
                write!(f, "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS")
            }
        }
    }
}
1597
/// Returns the `scheme://authority` prefix of `url`.
///
/// Everything from the first `/`, `?` or `#` after `://` is dropped. If `url`
/// contains no `://`, it is returned with trailing slashes trimmed. No case
/// folding or default-port normalisation is performed.
pub fn origin_of(url: &str) -> String {
    let Some(authority_start) = url.find("://").map(|i| i + 3) else {
        return url.trim_end_matches('/').to_string();
    };
    let authority_len = url[authority_start..]
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(url.len() - authority_start);
    url[..authority_start + authority_len].to_string()
}
1613
1614pub trait MagicCodeBackend: Send + Sync {
1627 fn put(&self, email: &str, code: &MagicCode);
1629 fn get(&self, email: &str) -> Option<MagicCode>;
1631 fn remove(&self, email: &str);
1634 fn bump_attempts(&self, email: &str);
1638 fn load_all(&self) -> Vec<MagicCode>;
1641}
1642
1643pub struct InMemoryMagicCodeBackend {
1646 codes: Mutex<HashMap<String, MagicCode>>,
1647}
1648
1649impl InMemoryMagicCodeBackend {
1650 pub fn new() -> Self {
1651 Self {
1652 codes: Mutex::new(HashMap::new()),
1653 }
1654 }
1655}
1656
1657impl Default for InMemoryMagicCodeBackend {
1658 fn default() -> Self {
1659 Self::new()
1660 }
1661}
1662
1663impl MagicCodeBackend for InMemoryMagicCodeBackend {
1664 fn put(&self, email: &str, code: &MagicCode) {
1665 self.codes
1666 .lock()
1667 .unwrap()
1668 .insert(email.to_string(), code.clone());
1669 }
1670 fn get(&self, email: &str) -> Option<MagicCode> {
1671 self.codes.lock().unwrap().get(email).cloned()
1672 }
1673 fn remove(&self, email: &str) {
1674 self.codes.lock().unwrap().remove(email);
1675 }
1676 fn bump_attempts(&self, email: &str) {
1677 if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
1678 c.attempts = c.attempts.saturating_add(1);
1679 }
1680 }
1681 fn load_all(&self) -> Vec<MagicCode> {
1682 self.codes.lock().unwrap().values().cloned().collect()
1683 }
1684}
1685
/// Issues and verifies six-digit email login codes ("magic codes").
///
/// Reads are served from `cache`; each mutation is also written to `backend`.
pub struct MagicCodeStore {
    /// Pending codes keyed by email; seeded with the backend's unexpired
    /// codes in `with_backend`.
    cache: Mutex<HashMap<String, MagicCode>>,
    /// Persistence layer that mirrors cache writes.
    backend: Box<dyn MagicCodeBackend>,
}

/// One pending login code.
///
/// NOTE(review): the derived `Debug` includes the plaintext `code`; avoid
/// logging values of this type.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Email address the code was issued to (also the store key).
    pub email: String,
    /// The code itself, stored in plaintext.
    pub code: String,
    /// Expiry on the `now_secs()` clock; the code is rejected once
    /// `now_secs() >= expires_at`.
    pub expires_at: u64,
    /// Number of wrong guesses made against this code so far.
    pub attempts: u32,
}

/// Number of wrong guesses after which a code is locked (`TooManyAttempts`).
const MAX_ATTEMPTS: u32 = 5;

/// Minimum seconds between issuing two codes to the same email while the
/// earlier code is still unexpired.
const CREATE_COOLDOWN_SECS: u64 = 60;

/// Failure modes of `MagicCodeStore::try_create` and `try_verify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is pending for the email.
    NotFound,
    /// `MAX_ATTEMPTS` wrong guesses have been made against the pending code.
    TooManyAttempts,
    /// The submitted code did not match.
    BadCode,
    /// The pending code has expired and was discarded.
    Expired,
    /// A code was issued too recently; retry after `retry_after_secs` seconds.
    Throttled { retry_after_secs: u64 },
}

impl Default for MagicCodeStore {
    /// Same as [`MagicCodeStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
1733
1734impl MagicCodeStore {
1735 pub fn new() -> Self {
1736 Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1737 }
1738
1739 pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1744 let now = now_secs();
1745 let mut cache = HashMap::new();
1746 for c in backend.load_all() {
1747 if c.expires_at > now {
1748 cache.insert(c.email.clone(), c);
1749 }
1750 }
1751 Self {
1752 cache: Mutex::new(cache),
1753 backend,
1754 }
1755 }
1756
1757 pub fn create(&self, email: &str) -> String {
1760 self.try_create(email).unwrap_or_else(|_| String::new())
1763 }
1764
1765 pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1768 let now = now_secs();
1769
1770 let mut codes = self.cache.lock().unwrap();
1771
1772 if let Some(existing) = codes.get(email) {
1776 if existing.expires_at > now {
1777 let created_at = existing.expires_at.saturating_sub(600);
1778 let age = now.saturating_sub(created_at);
1779 if age < CREATE_COOLDOWN_SECS {
1780 return Err(MagicCodeError::Throttled {
1781 retry_after_secs: CREATE_COOLDOWN_SECS - age,
1782 });
1783 }
1784 }
1785 }
1786
1787 let code = generate_magic_code();
1788 let mc = MagicCode {
1789 email: email.to_string(),
1790 code: code.clone(),
1791 expires_at: now + 600, attempts: 0,
1793 };
1794 codes.insert(email.to_string(), mc.clone());
1795 self.backend.put(email, &mc);
1799 Ok(code)
1800 }
1801
1802 pub fn verify(&self, email: &str, code: &str) -> bool {
1806 matches!(self.try_verify(email, code), Ok(()))
1807 }
1808
1809 pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1816 self.cache
1817 .lock()
1818 .map(|m| m.values().cloned().collect())
1819 .unwrap_or_default()
1820 }
1821
1822 pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1823 let now = now_secs();
1824 let mut codes = self.cache.lock().unwrap();
1825
1826 let mc = match codes.get_mut(email) {
1827 Some(m) => m,
1828 None => return Err(MagicCodeError::NotFound),
1829 };
1830
1831 if mc.attempts >= MAX_ATTEMPTS {
1832 return Err(MagicCodeError::TooManyAttempts);
1833 }
1834 if mc.expires_at <= now {
1835 codes.remove(email);
1836 self.backend.remove(email);
1837 return Err(MagicCodeError::Expired);
1838 }
1839
1840 let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1841 if !ok {
1842 mc.attempts += 1;
1843 self.backend.bump_attempts(email);
1844 if mc.attempts >= MAX_ATTEMPTS {
1846 return Err(MagicCodeError::TooManyAttempts);
1847 }
1848 return Err(MagicCodeError::BadCode);
1849 }
1850
1851 codes.remove(email);
1853 self.backend.remove(email);
1854 Ok(())
1855 }
1856}
1857
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
///
/// Writes into a single preallocated buffer. The previous version built a
/// temporary `String` with `format!` for every byte.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
1865
1866fn generate_magic_code() -> String {
1868 use rand::Rng;
1869 let mut rng = rand::thread_rng();
1870 let code: u32 = rng.gen_range(0..1_000_000);
1871 format!("{:06}", code)
1872}
1873
1874fn generate_token() -> String {
1876 use rand::Rng;
1877 let mut rng = rand::thread_rng();
1878 let bytes: [u8; 32] = rng.gen();
1879 format!("pylon_{}", hex_encode(&bytes))
1880}
1881
1882use std::collections::HashMap;
1887use std::sync::Mutex;
1888
1889pub trait SessionBackend: Send + Sync {
1893 fn load_all(&self) -> Vec<Session>;
1894 fn save(&self, session: &Session);
1895 fn remove(&self, token: &str);
1896}
1897
/// Token-keyed registry of login sessions with optional write-through
/// persistence.
pub struct SessionStore {
    /// Sessions keyed by bearer token. Expired entries can remain here until
    /// `get`, `refresh` or `sweep_expired` removes them.
    sessions: Mutex<HashMap<String, Session>>,
    /// Optional persistence layer; `None` for a purely in-memory store.
    backend: Option<Box<dyn SessionBackend>>,
    /// Lifetime, in seconds, of sessions this store creates or refreshes.
    default_lifetime_secs: u64,
}

impl Default for SessionStore {
    /// Same as [`SessionStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
1919
1920impl SessionStore {
1921 pub fn new() -> Self {
1922 Self {
1923 sessions: Mutex::new(HashMap::new()),
1924 backend: None,
1925 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1926 }
1927 }
1928
1929 pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
1932 self.default_lifetime_secs = lifetime_secs;
1933 self
1934 }
1935
1936 pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1940 let mut map = HashMap::new();
1941 for s in backend.load_all() {
1942 if !s.is_expired() {
1943 map.insert(s.token.clone(), s);
1944 }
1945 }
1946 Self {
1947 sessions: Mutex::new(map),
1948 backend: Some(backend),
1949 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1950 }
1951 }
1952
1953 pub fn create(&self, user_id: String) -> Session {
1957 let session = Session::with_lifetime(user_id, self.default_lifetime_secs);
1958 let mut sessions = self.sessions.lock().unwrap();
1959 sessions.insert(session.token.clone(), session.clone());
1960 if let Some(b) = &self.backend {
1961 b.save(&session);
1962 }
1963 session
1964 }
1965
1966 pub fn get(&self, token: &str) -> Option<Session> {
1968 let mut sessions = self.sessions.lock().unwrap();
1969 match sessions.get(token) {
1970 Some(s) if s.is_expired() => {
1971 sessions.remove(token);
1972 None
1973 }
1974 Some(s) => Some(s.clone()),
1975 None => None,
1976 }
1977 }
1978
1979 pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1982 match token {
1983 Some(t) => match self.get(t) {
1984 Some(session) => session.to_auth_context(),
1985 None => AuthContext::anonymous(),
1986 },
1987 None => AuthContext::anonymous(),
1988 }
1989 }
1990
1991 pub fn refresh(&self, old_token: &str) -> Option<Session> {
1995 let mut sessions = self.sessions.lock().unwrap();
1996 let old = sessions.remove(old_token)?;
1997 if let Some(b) = &self.backend {
1998 b.remove(old_token);
1999 }
2000 if old.is_expired() {
2001 return None;
2002 }
2003 let mut new = Session::with_lifetime(old.user_id.clone(), self.default_lifetime_secs);
2009 new.device = old.device.clone();
2010 sessions.insert(new.token.clone(), new.clone());
2011 if let Some(b) = &self.backend {
2012 b.save(&new);
2013 }
2014 Some(new)
2015 }
2016
2017 pub fn list_all_unfiltered(&self) -> Vec<Session> {
2022 self.sessions
2023 .lock()
2024 .map(|m| m.values().cloned().collect())
2025 .unwrap_or_default()
2026 }
2027
2028 pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
2030 let sessions = self.sessions.lock().unwrap();
2031 sessions
2032 .values()
2033 .filter(|s| s.user_id == user_id && !s.is_expired())
2034 .cloned()
2035 .collect()
2036 }
2037
2038 pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
2040 let mut sessions = self.sessions.lock().unwrap();
2041 let tokens: Vec<String> = sessions
2042 .iter()
2043 .filter_map(|(t, s)| {
2044 if s.user_id == user_id {
2045 Some(t.clone())
2046 } else {
2047 None
2048 }
2049 })
2050 .collect();
2051 let n = tokens.len();
2052 for t in &tokens {
2053 sessions.remove(t);
2054 if let Some(b) = &self.backend {
2055 b.remove(t);
2056 }
2057 }
2058 n
2059 }
2060
2061 pub fn sweep_expired(&self) -> usize {
2063 let mut sessions = self.sessions.lock().unwrap();
2064 let expired: Vec<String> = sessions
2065 .iter()
2066 .filter_map(|(t, s)| {
2067 if s.is_expired() {
2068 Some(t.clone())
2069 } else {
2070 None
2071 }
2072 })
2073 .collect();
2074 let n = expired.len();
2075 for t in &expired {
2076 sessions.remove(t);
2077 if let Some(b) = &self.backend {
2078 b.remove(t);
2079 }
2080 }
2081 n
2082 }
2083
2084 pub fn set_device(&self, token: &str, device: String) -> bool {
2086 let mut sessions = self.sessions.lock().unwrap();
2087 if let Some(s) = sessions.get_mut(token) {
2088 s.device = Some(device);
2089 if let Some(b) = &self.backend {
2090 b.save(s);
2091 }
2092 true
2093 } else {
2094 false
2095 }
2096 }
2097
2098 pub fn create_guest(&self) -> Session {
2100 use rand::Rng;
2101 let mut rng = rand::thread_rng();
2102 let bytes: [u8; 16] = rng.gen();
2103 let guest_id = format!("guest_{}", hex_encode(&bytes));
2104 self.create(guest_id)
2105 }
2106
2107 pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
2109 let mut sessions = self.sessions.lock().unwrap();
2110 if let Some(session) = sessions.get_mut(token) {
2111 session.user_id = real_user_id;
2112 if let Some(b) = &self.backend {
2113 b.save(session);
2114 }
2115 true
2116 } else {
2117 false
2118 }
2119 }
2120
2121 pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
2126 let mut sessions = self.sessions.lock().unwrap();
2127 if let Some(session) = sessions.get_mut(token) {
2128 session.tenant_id = tenant_id;
2129 if let Some(b) = &self.backend {
2130 b.save(session);
2131 }
2132 true
2133 } else {
2134 false
2135 }
2136 }
2137
2138 pub fn revoke(&self, token: &str) -> bool {
2140 let mut sessions = self.sessions.lock().unwrap();
2141 let removed = sessions.remove(token).is_some();
2142 if removed {
2143 if let Some(b) = &self.backend {
2144 b.remove(token);
2145 }
2146 }
2147 removed
2148 }
2149}
2150
/// A login method linked to a local user, such as an external provider
/// identity built by [`Account::new`].
///
/// NOTE(review): the derived `Debug` prints the tokens and `password` fields
/// verbatim; avoid logging values of this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Account {
    /// Record id; `Account::new` fills it with a `pylon_`-prefixed random token.
    pub id: String,
    /// Local user this account is linked to.
    pub user_id: String,
    /// Provider id (e.g. `"github"`); together with `account_id` it forms the
    /// key used by `InMemoryAccountBackend`.
    pub provider_id: String,
    /// The user's id at the provider (`UserInfo::provider_account_id`).
    pub account_id: String,
    /// Access token from the provider's token response.
    pub access_token: Option<String>,
    /// Refresh token, when the provider issued one.
    pub refresh_token: Option<String>,
    /// ID token from the token response, when present.
    pub id_token: Option<String>,
    /// Access-token expiry on the `now_secs()` clock. `None` means unknown,
    /// which `access_token_expired` treats as not expired.
    pub access_token_expires_at: Option<u64>,
    /// Refresh-token expiry; `Account::new` always leaves it `None`.
    pub refresh_token_expires_at: Option<u64>,
    /// Scope string copied from the token response.
    pub scope: Option<String>,
    /// Always `None` for accounts built by `Account::new`.
    /// NOTE(review): presumably a password hash for credential accounts —
    /// confirm before relying on it.
    pub password: Option<String>,
    /// Creation time (`now_secs()`).
    pub created_at: u64,
    /// Last update time (`now_secs()`); equal to `created_at` on creation.
    pub updated_at: u64,
}
2206
2207impl Account {
2208 pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
2212 let now = now_secs();
2213 Self {
2214 id: generate_token(),
2215 user_id,
2216 provider_id: info.provider.clone(),
2217 account_id: info.provider_account_id.clone(),
2218 access_token: Some(tokens.access_token.clone()),
2219 refresh_token: tokens.refresh_token.clone(),
2220 id_token: tokens.id_token.clone(),
2221 access_token_expires_at: tokens.expires_at,
2222 refresh_token_expires_at: None,
2223 scope: tokens.scope.clone(),
2224 password: None,
2225 created_at: now,
2226 updated_at: now,
2227 }
2228 }
2229
2230 pub fn access_token_expired(&self) -> bool {
2235 match self.access_token_expires_at {
2236 Some(ts) => now_secs() >= ts,
2237 None => false,
2238 }
2239 }
2240}
2241
2242pub trait AccountBackend: Send + Sync {
2245 fn upsert(&self, account: &Account);
2249 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
2252 fn find_for_user(&self, user_id: &str) -> Vec<Account>;
2257 fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
2259 fn delete_for_user(&self, user_id: &str) -> usize {
2264 let accounts = self.find_for_user(user_id);
2265 let n = accounts.len();
2266 for a in accounts {
2267 self.unlink(&a.provider_id, &a.account_id);
2268 }
2269 n
2270 }
2271 fn list_all(&self) -> Vec<Account>;
2276}
2277
2278pub struct InMemoryAccountBackend {
2282 accounts: Mutex<HashMap<(String, String), Account>>,
2286}
2287
2288impl InMemoryAccountBackend {
2289 pub fn new() -> Self {
2290 Self {
2291 accounts: Mutex::new(HashMap::new()),
2292 }
2293 }
2294}
2295
2296impl Default for InMemoryAccountBackend {
2297 fn default() -> Self {
2298 Self::new()
2299 }
2300}
2301
2302impl AccountBackend for InMemoryAccountBackend {
2303 fn upsert(&self, account: &Account) {
2304 let key = (account.provider_id.clone(), account.account_id.clone());
2305 self.accounts.lock().unwrap().insert(key, account.clone());
2306 }
2307 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2308 self.accounts
2309 .lock()
2310 .unwrap()
2311 .get(&(provider_id.to_string(), account_id.to_string()))
2312 .cloned()
2313 }
2314 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2315 self.accounts
2316 .lock()
2317 .unwrap()
2318 .values()
2319 .filter(|a| a.user_id == user_id)
2320 .cloned()
2321 .collect()
2322 }
2323 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2324 self.accounts
2325 .lock()
2326 .unwrap()
2327 .remove(&(provider_id.to_string(), account_id.to_string()))
2328 .is_some()
2329 }
2330 fn list_all(&self) -> Vec<Account> {
2331 self.accounts.lock().unwrap().values().cloned().collect()
2332 }
2333}
2334
2335pub struct AccountStore {
2338 backend: Box<dyn AccountBackend>,
2339}
2340
2341impl Default for AccountStore {
2342 fn default() -> Self {
2343 Self::new()
2344 }
2345}
2346
2347impl AccountStore {
2348 pub fn new() -> Self {
2349 Self {
2350 backend: Box::new(InMemoryAccountBackend::new()),
2351 }
2352 }
2353 pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
2354 Self { backend }
2355 }
2356 pub fn upsert(&self, account: &Account) {
2357 self.backend.upsert(account);
2358 }
2359 pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2360 self.backend.find_by_provider(provider_id, account_id)
2361 }
2362 pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2363 self.backend.find_for_user(user_id)
2364 }
2365 pub fn delete_for_user(&self, user_id: &str) -> usize {
2366 self.backend.delete_for_user(user_id)
2367 }
2368
2369 pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2370 self.backend.unlink(provider_id, account_id)
2371 }
2372
2373 pub fn list_all_unfiltered(&self) -> Vec<Account> {
2387 self.backend.list_all()
2388 }
2389}
2390
2391#[cfg(test)]
2396mod tests {
2397 use super::*;
2398
2399 #[test]
2400 fn anonymous_context() {
2401 let ctx = AuthContext::anonymous();
2402 assert!(!ctx.is_authenticated());
2403 assert!(ctx.user_id.is_none());
2404 }
2405
2406 #[test]
2407 fn authenticated_context() {
2408 let ctx = AuthContext::authenticated("user-1".into());
2409 assert!(ctx.is_authenticated());
2410 assert_eq!(ctx.user_id, Some("user-1".into()));
2411 }
2412
2413 #[test]
2414 fn from_api_key_carries_scope_metadata() {
2415 let ctx = AuthContext::from_api_key(
2416 "user-1".into(),
2417 "key_abc".into(),
2418 Some("read,write".into()),
2419 );
2420 assert!(ctx.is_authenticated());
2421 assert!(ctx.is_api_key_auth());
2422 assert_eq!(ctx.user_id.as_deref(), Some("user-1"));
2423 assert_eq!(ctx.api_key_id.as_deref(), Some("key_abc"));
2424 assert_eq!(ctx.api_key_scopes.as_deref(), Some("read,write"));
2425 }
2426
2427 #[test]
2428 fn session_auth_is_not_api_key_auth() {
2429 let ctx = AuthContext::authenticated("user-1".into());
2430 assert!(!ctx.is_api_key_auth());
2431 assert!(ctx.api_key_id.is_none());
2432 }
2433
2434 #[test]
2435 fn auth_mode_public_allows_anonymous() {
2436 let mode = AuthMode::Public;
2437 assert!(mode.check(&AuthContext::anonymous()));
2438 assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2439 }
2440
2441 #[test]
2442 fn auth_mode_user_requires_authenticated() {
2443 let mode = AuthMode::User;
2444 assert!(!mode.check(&AuthContext::anonymous()));
2445 assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2446 }
2447
2448 #[test]
2449 fn auth_mode_from_str() {
2450 assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
2451 assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
2452 assert_eq!(AuthMode::from_str("admin"), None);
2453 }
2454
2455 #[test]
2456 fn session_store_create_and_get() {
2457 let store = SessionStore::new();
2458 let session = store.create("user-1".into());
2459 assert!(!session.token.is_empty());
2460 assert!(session.token.starts_with("pylon_"));
2461
2462 let retrieved = store.get(&session.token).unwrap();
2463 assert_eq!(retrieved.user_id, "user-1");
2464 }
2465
2466 #[test]
2467 fn session_store_resolve() {
2468 let store = SessionStore::new();
2469 let session = store.create("user-1".into());
2470
2471 let ctx = store.resolve(Some(&session.token));
2472 assert!(ctx.is_authenticated());
2473 assert_eq!(ctx.user_id, Some("user-1".into()));
2474
2475 let anon = store.resolve(None);
2476 assert!(!anon.is_authenticated());
2477
2478 let bad = store.resolve(Some("invalid-token"));
2479 assert!(!bad.is_authenticated());
2480 }
2481
2482 #[test]
2483 fn session_store_revoke() {
2484 let store = SessionStore::new();
2485 let session = store.create("user-1".into());
2486
2487 assert!(store.revoke(&session.token));
2488 assert!(store.get(&session.token).is_none());
2489 assert!(!store.revoke(&session.token)); }
2491
2492 #[test]
2493 fn session_to_auth_context() {
2494 let session = Session::new("user-42".into());
2495 let ctx = session.to_auth_context();
2496 assert_eq!(ctx.user_id, Some("user-42".into()));
2497 }
2498
2499 #[test]
2502 fn admin_context() {
2503 let ctx = AuthContext::admin();
2504 assert!(ctx.is_admin);
2505 assert!(ctx.is_authenticated());
2506 }
2507
2508 #[test]
2509 fn anonymous_not_admin() {
2510 let ctx = AuthContext::anonymous();
2511 assert!(!ctx.is_admin);
2512 }
2513
2514 #[test]
2515 fn authenticated_not_admin() {
2516 let ctx = AuthContext::authenticated("user-1".into());
2517 assert!(!ctx.is_admin);
2518 }
2519
2520 #[test]
2523 fn magic_code_create_and_verify() {
2524 let store = MagicCodeStore::new();
2525 let code = store.create("test@example.com");
2526 assert_eq!(code.len(), 6);
2527 assert!(store.verify("test@example.com", &code));
2528 }
2529
2530 #[test]
2531 fn magic_code_wrong_code_rejected() {
2532 let store = MagicCodeStore::new();
2533 store.create("test@example.com");
2534 assert!(!store.verify("test@example.com", "000000"));
2535 }
2536
2537 #[test]
2538 fn magic_code_wrong_email_rejected() {
2539 let store = MagicCodeStore::new();
2540 let code = store.create("test@example.com");
2541 assert!(!store.verify("other@example.com", &code));
2542 }
2543
2544 #[test]
2545 fn magic_code_consumed_after_verify() {
2546 let store = MagicCodeStore::new();
2547 let code = store.create("test@example.com");
2548 assert!(store.verify("test@example.com", &code));
2549 assert!(!store.verify("test@example.com", &code));
2551 }
2552
2553 #[test]
2554 fn magic_code_different_emails_independent() {
2555 let store = MagicCodeStore::new();
2556 let code1 = store.create("alice@example.com");
2557 let code2 = store.create("bob@example.com");
2558 assert!(store.verify("alice@example.com", &code1));
2560 assert!(store.verify("bob@example.com", &code2));
2561 }
2562
2563 #[test]
2566 fn constant_time_eq_equal() {
2567 assert!(constant_time_eq(b"hello", b"hello"));
2568 assert!(constant_time_eq(b"", b""));
2569 }
2570
2571 #[test]
2572 fn constant_time_eq_not_equal() {
2573 assert!(!constant_time_eq(b"hello", b"world"));
2574 assert!(!constant_time_eq(b"hello", b"hell"));
2575 assert!(!constant_time_eq(b"a", b"b"));
2576 }
2577
2578 #[test]
2581 fn generated_tokens_are_unique() {
2582 let t1 = generate_token();
2583 let t2 = generate_token();
2584 assert_ne!(t1, t2);
2585 assert!(t1.starts_with("pylon_"));
2586 assert!(t2.starts_with("pylon_"));
2587 assert_eq!(t1.len(), 6 + 64);
2589 }
2590
2591 #[test]
2594 fn oauth_registry_empty() {
2595 let reg = OAuthRegistry::new();
2596 assert!(reg.get("google").is_none());
2597 }
2598
2599 #[test]
2600 fn oauth_registry_register_and_get() {
2601 let mut reg = OAuthRegistry::new();
2602 reg.register(OAuthConfig {
2603 provider: "google".into(),
2604 client_id: "test-id".into(),
2605 client_secret: "test-secret".into(),
2606 redirect_uri: "http://localhost/callback".into(),
2607 ..Default::default()
2608 });
2609 let config = reg.get("google").unwrap();
2610 assert_eq!(config.client_id, "test-id");
2611 assert!(config.auth_url().contains("accounts.google.com"));
2612 }
2613
2614 #[test]
2622 fn every_builtin_provider_routes_through_oauth_config() {
2623 for spec in provider::builtin::all() {
2624 let cfg = OAuthConfig {
2625 provider: spec.id.into(),
2626 client_id: "cid".into(),
2627 client_secret: "csecret".into(),
2628 redirect_uri: "https://app/cb".into(),
2629 tenant: if spec.id == "microsoft" {
2630 Some("contoso".into())
2631 } else {
2632 None
2633 },
2634 apple: if spec.id == "apple" {
2635 Some(provider::AppleConfig {
2636 team_id: "T".into(),
2637 key_id: "K".into(),
2638 private_key_pem: "no".into(),
2639 })
2640 } else {
2641 None
2642 },
2643 ..Default::default()
2644 };
2645 let auth = cfg.auth_url();
2646 assert!(!auth.is_empty(), "{}: empty auth_url", spec.id);
2647 let expected_param = format!("{}=cid", spec.client_id_param);
2649 assert!(
2650 auth.contains(&expected_param),
2651 "{}: missing {}; got auth_url: {}",
2652 spec.id,
2653 expected_param,
2654 auth,
2655 );
2656 assert!(!cfg.token_url().is_empty(), "{}: empty token_url", spec.id);
2657 if spec.id == "apple" {
2659 assert!(
2660 auth.contains("response_mode=form_post"),
2661 "apple auth_url must include response_mode=form_post; got {auth}"
2662 );
2663 }
2664 }
2665 }
2666
2667 #[test]
2670 fn microsoft_tenant_placeholder_resolves() {
2671 let cfg = OAuthConfig {
2672 provider: "microsoft".into(),
2673 client_id: "id".into(),
2674 client_secret: "secret".into(),
2675 redirect_uri: "https://app/cb".into(),
2676 tenant: Some("contoso.onmicrosoft.com".into()),
2677 ..Default::default()
2678 };
2679 assert!(cfg.auth_url().contains("/contoso.onmicrosoft.com/"));
2680 assert!(cfg.token_url().contains("/contoso.onmicrosoft.com/"));
2681 }
2682
2683 #[test]
2685 fn microsoft_default_tenant_common() {
2686 let cfg = OAuthConfig {
2687 provider: "microsoft".into(),
2688 client_id: "id".into(),
2689 client_secret: "secret".into(),
2690 redirect_uri: "https://app/cb".into(),
2691 ..Default::default()
2692 };
2693 assert!(cfg.auth_url().contains("/common/"));
2694 assert!(cfg.token_url().contains("/common/"));
2695 }
2696
2697 #[test]
2700 fn scopes_override_replaces_spec_default() {
2701 let cfg = OAuthConfig {
2702 provider: "github".into(),
2703 client_id: "id".into(),
2704 client_secret: "secret".into(),
2705 redirect_uri: "https://app/cb".into(),
2706 scopes_override: Some("repo user:email".into()),
2707 ..Default::default()
2708 };
2709 let auth = cfg.auth_url();
2710 assert!(auth.contains("scope=repo%20user%3Aemail"), "got: {auth}");
2712 }
2713
2714 #[test]
2719 fn apple_exchange_requires_apple_config() {
2720 let cfg = OAuthConfig {
2721 provider: "apple".into(),
2722 client_id: "com.example.app".into(),
2723 client_secret: String::new(),
2724 redirect_uri: "https://app/cb".into(),
2725 apple: None, ..Default::default()
2727 };
2728 let err = cfg.exchange_code_full("x").unwrap_err();
2729 assert!(err.contains("apple provider requires"), "got: {err}");
2730 }
2731
2732 #[test]
2737 fn oidc_issuer_uses_discovered_endpoints() {
2738 let issuer = "https://acme.test.invalid";
2739 provider::oidc_cache::insert_for_test(
2740 issuer,
2741 provider::DiscoveredSpec {
2742 auth_url: "https://acme.test.invalid/authorize".into(),
2743 token_url: "https://acme.test.invalid/oauth/token".into(),
2744 userinfo_url: Some("https://acme.test.invalid/userinfo".into()),
2745 scopes: "openid email profile".into(),
2746 userinfo_parser: provider::UserinfoParser::Oidc,
2747 token_exchange: provider::TokenExchangeShape::Standard,
2748 },
2749 );
2750 let cfg = OAuthConfig {
2751 provider: "auth0".into(), client_id: "id".into(),
2753 client_secret: "secret".into(),
2754 redirect_uri: "https://app/cb".into(),
2755 oidc_issuer: Some(issuer.into()),
2756 ..Default::default()
2757 };
2758 assert!(cfg.auth_url().starts_with("https://acme.test.invalid/authorize?"));
2759 assert_eq!(cfg.token_url(), "https://acme.test.invalid/oauth/token");
2760 assert_eq!(cfg.userinfo_url(), "https://acme.test.invalid/userinfo");
2761 }
2762
2763 #[test]
2769 fn apple_auth_url_includes_form_post() {
2770 let cfg = OAuthConfig {
2771 provider: "apple".into(),
2772 client_id: "com.example.app".into(),
2773 client_secret: String::new(),
2774 redirect_uri: "https://app/cb".into(),
2775 apple: Some(provider::AppleConfig {
2776 team_id: "T".into(),
2777 key_id: "K".into(),
2778 private_key_pem: "no".into(),
2779 }),
2780 ..Default::default()
2781 };
2782 let auth = cfg.auth_url();
2783 assert!(auth.contains("response_mode=form_post"), "got: {auth}");
2784 assert_eq!(cfg.userinfo_url(), "");
2786 }
2787
2788 #[test]
2793 fn apple_id_token_decode_extracts_identity() {
2794 let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"{\"alg\":\"none\"}");
2796 use base64::Engine;
2797 let claims = serde_json::json!({
2798 "iss": "https://appleid.apple.com",
2799 "sub": "001234.abc.def",
2800 "aud": "com.example.app",
2801 "email": "user@privaterelay.appleid.com",
2802 "email_verified": "true",
2803 });
2804 let claims_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
2805 .encode(claims.to_string().as_bytes());
2806 let id_token = format!("{header}.{claims_b64}.signature_ignored");
2807
2808 let cfg = OAuthConfig {
2809 provider: "apple".into(),
2810 client_id: "com.example.app".into(),
2811 client_secret: String::new(),
2812 redirect_uri: "https://app/cb".into(),
2813 apple: Some(provider::AppleConfig {
2814 team_id: "T".into(),
2815 key_id: "K".into(),
2816 private_key_pem: "no".into(),
2817 }),
2818 ..Default::default()
2819 };
2820 let info = cfg
2821 .fetch_userinfo_with_id_token("ignored", Some(&id_token))
2822 .expect("apple id_token decode");
2823 assert_eq!(info.provider_account_id, "001234.abc.def");
2824 assert_eq!(info.email, "user@privaterelay.appleid.com");
2825
2826 let err = cfg.fetch_userinfo_full("token").unwrap_err();
2829 assert!(err.contains("apple login requires"), "got: {err}");
2830 }
2831
2832 #[test]
2836 fn twitter_auth_url_includes_pkce() {
2837 let cfg = OAuthConfig {
2838 provider: "twitter".into(),
2839 client_id: "tw_client".into(),
2840 client_secret: "tw_secret".into(),
2841 redirect_uri: "https://app/cb".into(),
2842 ..Default::default()
2843 };
2844 let (url, verifier) = cfg.auth_url_with_pkce("state123").expect("twitter pkce");
2845 let v = verifier.expect("twitter must produce verifier");
2846 assert!(v.len() >= 43, "PKCE verifier must be 43+ chars: got {v}");
2847 assert!(url.contains("code_challenge="), "got: {url}");
2848 assert!(url.contains("code_challenge_method=S256"), "got: {url}");
2849
2850 let google = OAuthConfig {
2852 provider: "google".into(),
2853 client_id: "g".into(),
2854 client_secret: "g".into(),
2855 redirect_uri: "https://app/cb".into(),
2856 ..Default::default()
2857 };
2858 let (gurl, gverifier) = google.auth_url_with_pkce("st").expect("google");
2859 assert!(gverifier.is_none(), "google should not add PKCE");
2860 assert!(!gurl.contains("code_challenge"), "got: {gurl}");
2861 }
2862
2863 #[test]
2866 fn tiktok_uses_client_key_and_comma_scopes() {
2867 let cfg = OAuthConfig {
2868 provider: "tiktok".into(),
2869 client_id: "tk_client".into(),
2870 client_secret: "tk_secret".into(),
2871 redirect_uri: "https://app/cb".into(),
2872 scopes_override: Some("user.info.basic video.list".into()),
2873 ..Default::default()
2874 };
2875 let auth = cfg.auth_url();
2876 assert!(auth.contains("client_key=tk_client"), "got: {auth}");
2877 assert!(auth.contains("user.info.basic%2Cvideo.list"), "got: {auth}");
2879 assert!(!auth.contains("user.info.basic%20video.list"), "got: {auth}");
2881 }
2882
2883 #[test]
2887 fn token_exchange_url_encodes_code() {
2888 let raw = "code+with/special=chars";
2894 let encoded = url_encode(raw);
2895 assert!(!encoded.contains('+'));
2896 assert!(!encoded.contains('/'));
2897 assert!(!encoded.contains('='));
2898 assert!(encoded.contains("%2B"));
2899 assert!(encoded.contains("%2F"));
2900 assert!(encoded.contains("%3D"));
2901 }
2902
2903 #[test]
2907 fn sanitize_token_error_redacts_secrets() {
2908 let raw = "HTTP 400: error=invalid_grant&client_secret=sk_real_secret_value&code_verifier=verifierxyz&hint=check%20your%20code";
2909 let scrubbed = sanitize_token_error(raw.into());
2910 assert!(!scrubbed.contains("sk_real_secret_value"));
2911 assert!(!scrubbed.contains("verifierxyz"));
2912 assert!(scrubbed.contains("client_secret=***"));
2913 assert!(scrubbed.contains("code_verifier=***"));
2914 assert!(scrubbed.contains("invalid_grant"));
2916 assert!(scrubbed.contains("hint=check%20your%20code"));
2917 }
2918
2919 #[test]
2922 fn sanitize_token_error_redacts_json_secrets() {
2923 let raw = r#"HTTP 400: {"error":"invalid_grant","client_secret":"sk_jsonleak","refresh_token":"rt_abcxyz","id_token":"ey.payload.sig"}"#;
2924 let scrubbed = sanitize_token_error(raw.into());
2925 assert!(!scrubbed.contains("sk_jsonleak"), "got: {scrubbed}");
2926 assert!(!scrubbed.contains("rt_abcxyz"), "got: {scrubbed}");
2927 assert!(!scrubbed.contains("ey.payload.sig"), "got: {scrubbed}");
2928 assert!(scrubbed.contains(r#""client_secret":"***""#), "got: {scrubbed}");
2929 assert!(scrubbed.contains(r#""refresh_token":"***""#), "got: {scrubbed}");
2930 assert!(scrubbed.contains(r#""id_token":"***""#), "got: {scrubbed}");
2931 assert!(scrubbed.contains("invalid_grant"));
2932 }
2933
2934 #[test]
2939 fn sanitize_token_error_handles_utf8() {
2940 let raw = "HTTP 400: ⚠️ provider says the secret is wrong: client_secret=sk_x";
2941 let scrubbed = sanitize_token_error(raw.into());
2942 assert!(scrubbed.contains("⚠️"), "non-ASCII chars must survive: {scrubbed}");
2943 assert!(!scrubbed.contains("sk_x"));
2944 assert!(scrubbed.contains("client_secret=***"));
2945 }
2946
2947 #[test]
2952 fn oidc_discovery_picks_token_auth_method() {
2953 let json_post = r#"{
2954 "issuer": "https://acme.test/",
2955 "authorization_endpoint": "https://acme.test/auth",
2956 "token_endpoint": "https://acme.test/token",
2957 "token_endpoint_auth_methods_supported": ["client_secret_post"]
2958 }"#;
2959 let spec = provider::OidcDiscoveryDoc::parse(json_post).unwrap().into_spec();
2960 assert!(matches!(
2961 spec.token_exchange,
2962 provider::TokenExchangeShape::Standard
2963 ));
2964
2965 let json_default = r#"{
2967 "issuer": "https://acme.test/",
2968 "authorization_endpoint": "https://acme.test/auth",
2969 "token_endpoint": "https://acme.test/token"
2970 }"#;
2971 let spec = provider::OidcDiscoveryDoc::parse(json_default)
2972 .unwrap()
2973 .into_spec();
2974 assert!(matches!(
2975 spec.token_exchange,
2976 provider::TokenExchangeShape::BasicAuth
2977 ));
2978 }
2979
2980 #[test]
2983 fn oidc_discovery_rejects_incomplete_doc() {
2984 let json = r#"{
2986 "issuer": "https://acme.test/",
2987 "authorization_endpoint": "https://acme.test/auth"
2988 }"#;
2989 let err = provider::OidcDiscoveryDoc::parse(json).unwrap_err();
2990 assert!(err.contains("token_endpoint"), "got: {err}");
2991 }
2992
2993 #[test]
2997 fn from_env_picks_up_discord() {
2998 let key_id = "PYLON_OAUTH_DISCORD_CLIENT_ID";
3001 let key_secret = "PYLON_OAUTH_DISCORD_CLIENT_SECRET";
3002 std::env::set_var(key_id, "discord-test-id");
3006 std::env::set_var(key_secret, "discord-test-secret");
3007
3008 let reg = OAuthRegistry::from_env();
3009 let discord = reg.get("discord").expect("discord registered");
3010 assert_eq!(discord.client_id, "discord-test-id");
3011 assert!(discord.auth_url().contains("discord.com"));
3012
3013 std::env::remove_var(key_id);
3014 std::env::remove_var(key_secret);
3015 }
3016
3017 #[test]
3020 fn guest_session() {
3021 let store = SessionStore::new();
3022 let session = store.create_guest();
3023 assert!(session.user_id.starts_with("guest_"));
3024 assert!(!session.token.is_empty());
3025
3026 let ctx = store.resolve(Some(&session.token));
3027 assert!(ctx.is_authenticated());
3028 assert!(ctx.user_id.unwrap().starts_with("guest_"));
3029 }
3030
3031 #[test]
3032 fn upgrade_guest_to_real_user() {
3033 let store = SessionStore::new();
3034 let session = store.create_guest();
3035 assert!(session.user_id.starts_with("guest_"));
3036
3037 let upgraded = store.upgrade(&session.token, "real-user-123".into());
3038 assert!(upgraded);
3039
3040 let ctx = store.resolve(Some(&session.token));
3041 assert_eq!(ctx.user_id, Some("real-user-123".into()));
3042 }
3043
3044 #[test]
3045 fn upgrade_invalid_token_fails() {
3046 let store = SessionStore::new();
3047 let upgraded = store.upgrade("nonexistent-token", "user".into());
3048 assert!(!upgraded);
3049 }
3050
3051 #[test]
3052 fn guest_context() {
3053 let ctx = AuthContext::guest("guest_123".into());
3054 assert!(!ctx.is_authenticated());
3057 assert!(ctx.is_guest);
3058 assert!(!ctx.is_admin);
3059 assert_eq!(ctx.user_id, Some("guest_123".into()));
3060 assert!(!AuthMode::User.check(&ctx));
3061 assert!(AuthMode::Public.check(&ctx));
3062 }
3063
3064 #[test]
3065 fn oauth_token_urls() {
3066 let google = OAuthConfig {
3067 provider: "google".into(),
3068 client_id: "x".into(),
3069 client_secret: "x".into(),
3070 redirect_uri: "x".into(),
3071 ..Default::default()
3072 };
3073 assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
3074 let github = OAuthConfig {
3075 provider: "github".into(),
3076 client_id: "x".into(),
3077 client_secret: "x".into(),
3078 redirect_uri: "x".into(),
3079 ..Default::default()
3080 };
3081 assert_eq!(
3082 github.token_url(),
3083 "https://github.com/login/oauth/access_token"
3084 );
3085 let unknown = OAuthConfig {
3086 provider: "unknown".into(),
3087 client_id: "x".into(),
3088 client_secret: "x".into(),
3089 redirect_uri: "x".into(),
3090 ..Default::default()
3091 };
3092 assert_eq!(unknown.token_url(), "");
3093 assert!(unknown.auth_url().is_empty());
3094 }
3095
3096 #[test]
3097 fn oauth_auth_url_github() {
3098 let config = OAuthConfig {
3099 provider: "github".into(),
3100 client_id: "gh-id".into(),
3101 client_secret: "gh-secret".into(),
3102 redirect_uri: "http://localhost/cb".into(),
3103 ..Default::default()
3104 };
3105 assert!(config.auth_url().contains("github.com"));
3106 assert!(config.auth_url().contains("gh-id"));
3107 }
3108
3109 #[test]
3110 fn oauth_auth_url_with_state() {
3111 let config = OAuthConfig {
3112 provider: "google".into(),
3113 client_id: "test-id".into(),
3114 client_secret: "test-secret".into(),
3115 redirect_uri: "http://localhost/cb".into(),
3116 ..Default::default()
3117 };
3118 let url = config.auth_url_with_state("random_state_123");
3119 assert!(url.contains("&state=random_state_123"));
3120 }
3121
3122 #[test]
3123 fn oauth_state_store_create_and_validate() {
3124 let store = OAuthStateStore::new();
3125 let token = store.create("google", "https://app/cb", "https://app/login");
3126 let rec = store.validate(&token, "google").expect("valid first time");
3127 assert_eq!(rec.callback_url, "https://app/cb");
3128 assert_eq!(rec.error_callback_url, "https://app/login");
3129 assert!(store.validate(&token, "google").is_none());
3131 }
3132
3133 #[test]
3134 fn oauth_state_store_wrong_provider_rejected() {
3135 let store = OAuthStateStore::new();
3136 let token = store.create("google", "https://app/cb", "https://app/cb");
3137 assert!(store.validate(&token, "github").is_none());
3138 }
3139
3140 #[test]
3141 fn oauth_state_store_invalid_state_rejected() {
3142 let store = OAuthStateStore::new();
3143 assert!(store.validate("nonexistent", "google").is_none());
3144 }
3145
3146 #[test]
3147 fn validate_trusted_redirect_basics() {
3148 let trusted = vec!["http://localhost:3000".to_string()];
3149 assert!(validate_trusted_redirect("http://localhost:3000/dashboard", &trusted).is_ok());
3150 assert!(validate_trusted_redirect("http://localhost:3000", &trusted).is_ok());
3151 assert!(validate_trusted_redirect("http://localhost:3000/x?y=1", &trusted).is_ok());
3152
3153 assert!(matches!(
3155 validate_trusted_redirect("http://localhost:4321/dashboard", &trusted),
3156 Err(TrustedOriginError::NotTrusted { .. })
3157 ));
3158 assert!(matches!(
3161 validate_trusted_redirect("javascript:alert(1)", &trusted),
3162 Err(TrustedOriginError::NotHttp)
3163 ));
3164 assert!(matches!(
3165 validate_trusted_redirect("", &trusted),
3166 Err(TrustedOriginError::Empty)
3167 ));
3168 }
3169}