1pub mod api_key;
2pub mod apple_jwt;
3pub mod captcha;
4pub mod cookie;
5pub mod email;
6pub mod jwt;
7pub mod oidc_provider;
8pub mod org;
9pub mod password;
10pub mod phone;
11pub mod provider;
12pub mod rate_limit;
13pub mod scim;
14pub mod siwe;
15pub mod stripe;
16pub mod totp;
17pub mod verification;
18pub mod webauthn;
19
20pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
21
22use serde::{Deserialize, Serialize};
23
/// Identity and authorization facts about the caller of a request.
///
/// Normally built through the constructors on this type
/// ([`AuthContext::anonymous`], [`AuthContext::authenticated`],
/// [`AuthContext::guest`], [`AuthContext::admin`], [`AuthContext::from_api_key`]).
/// Only `Serialize` is derived. `is_guest` is omitted from the output when
/// `false`, and the optional fields are omitted when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
    /// Id of the acting user; `None` for anonymous callers. Guests also carry an id.
    pub user_id: Option<String>,
    /// Admin override: when set, `has_role` / `has_any_role` return `true` for any role.
    pub is_admin: bool,
    /// Guest flag. A guest has a `user_id` but `is_authenticated()` is `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub is_guest: bool,
    /// Role names granted to the caller, matched exactly by `has_role`.
    pub roles: Vec<String>,
    /// Tenant the request is scoped to, if any (set via `with_tenant`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// Id of the API key used to authenticate, if any (see `is_api_key_auth`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key_id: Option<String>,
    /// Scopes attached to that API key, kept as one opaque string (not parsed here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key_scopes: Option<String>,
}
69
/// serde `skip_serializing_if` predicate: `true` when the flag is unset.
fn is_false(b: &bool) -> bool {
    !*b
}
73
74impl AuthContext {
75 pub fn anonymous() -> Self {
77 Self {
78 user_id: None,
79 is_admin: false,
80 is_guest: false,
81 roles: Vec::new(),
82 tenant_id: None,
83 api_key_id: None,
84 api_key_scopes: None,
85 }
86 }
87
88 pub fn authenticated(user_id: String) -> Self {
90 Self {
91 user_id: Some(user_id),
92 is_admin: false,
93 is_guest: false,
94 roles: Vec::new(),
95 tenant_id: None,
96 api_key_id: None,
97 api_key_scopes: None,
98 }
99 }
100
101 pub fn from_api_key(user_id: String, key_id: String, scopes: Option<String>) -> Self {
104 Self {
105 user_id: Some(user_id),
106 is_admin: false,
107 is_guest: false,
108 roles: Vec::new(),
109 tenant_id: None,
110 api_key_id: Some(key_id),
111 api_key_scopes: scopes,
112 }
113 }
114
115 pub fn is_api_key_auth(&self) -> bool {
118 self.api_key_id.is_some()
119 }
120
121 pub fn guest(guest_id: String) -> Self {
126 Self {
127 user_id: Some(guest_id),
128 is_admin: false,
129 is_guest: true,
130 roles: Vec::new(),
131 tenant_id: None,
132 api_key_id: None,
133 api_key_scopes: None,
134 }
135 }
136
137 pub fn admin() -> Self {
139 Self {
140 user_id: Some("__admin__".into()),
141 is_admin: true,
142 is_guest: false,
143 roles: vec!["admin".into()],
144 tenant_id: None,
145 api_key_id: None,
146 api_key_scopes: None,
147 }
148 }
149
150 pub fn user(user_id: String) -> Self {
152 Self::authenticated(user_id)
153 }
154
155 pub fn tenant_id(&self) -> Option<&str> {
157 self.tenant_id.as_deref()
158 }
159
160 pub fn with_tenant(mut self, tenant_id: String) -> Self {
162 self.tenant_id = Some(tenant_id);
163 self
164 }
165
166 pub fn is_authenticated(&self) -> bool {
170 self.user_id.is_some() && !self.is_guest
171 }
172
173 pub fn has_role(&self, role: &str) -> bool {
175 self.is_admin || self.roles.iter().any(|r| r == role)
176 }
177
178 pub fn has_any_role(&self, roles: &[&str]) -> bool {
180 self.is_admin || roles.iter().any(|r| self.has_role(r))
181 }
182
183 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
185 self.roles = roles;
186 self
187 }
188}
189
/// Compare two byte slices without stopping at the first differing byte.
///
/// Every byte pair is XOR-accumulated, so for equal-length inputs the work
/// does not depend on where they differ. A length mismatch returns `false`
/// immediately, so lengths are not hidden. This is best-effort: nothing here
/// prevents the optimizer from changing the generated code.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
209
/// Access requirement for an endpoint, checked against an [`AuthContext`]
/// by `AuthMode::check`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// No authentication required; `check` always passes.
    Public,
    /// Requires a caller for which `AuthContext::is_authenticated()` holds
    /// (has an identity and is not a guest).
    User,
}
222
223impl AuthMode {
224 #[allow(clippy::should_implement_trait)]
226 pub fn from_str(s: &str) -> Option<Self> {
227 match s {
228 "public" => Some(AuthMode::Public),
229 "user" => Some(AuthMode::User),
230 _ => None,
231 }
232 }
233
234 pub fn check(&self, ctx: &AuthContext) -> bool {
236 match self {
237 AuthMode::Public => true,
238 AuthMode::User => ctx.is_authenticated(),
239 }
240 }
241}
242
/// A login session that can be persisted through serde.
///
/// NOTE(review): because `expires_at` defaults to `0` and `is_expired` treats
/// `0` as "never expires", a serialized session that lacks `expires_at`
/// deserializes as non-expiring. Confirm this is intended, e.g. for old
/// records.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Session token (the constructors fill this from `generate_token()`).
    pub token: String,
    /// Id of the user the session belongs to.
    pub user_id: String,
    /// Expiry in Unix seconds; `0` means no expiry. Defaults to `0` when absent.
    #[serde(default)]
    pub expires_at: u64,
    /// Optional device label; free-form, not interpreted in this module.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Creation time in Unix seconds; defaults to `0` when absent.
    #[serde(default)]
    pub created_at: u64,
    /// Tenant the session is bound to; copied into the context by `to_auth_context`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
267
268impl Session {
269 pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
271
272 pub fn new(user_id: String) -> Self {
274 let now = now_secs();
275 Self {
276 token: generate_token(),
277 user_id,
278 expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
279 device: None,
280 created_at: now,
281 tenant_id: None,
282 }
283 }
284
285 pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
287 let now = now_secs();
288 Self {
289 token: generate_token(),
290 user_id,
291 expires_at: if lifetime_secs == 0 {
292 0
293 } else {
294 now.saturating_add(lifetime_secs)
295 },
296 device: None,
297 created_at: now,
298 tenant_id: None,
299 }
300 }
301
302 pub fn to_auth_context(&self) -> AuthContext {
305 let ctx = AuthContext::authenticated(self.user_id.clone());
306 match &self.tenant_id {
307 Some(t) => ctx.with_tenant(t.clone()),
308 None => ctx,
309 }
310 }
311
312 pub fn is_expired(&self) -> bool {
316 self.expires_at != 0 && now_secs() >= self.expires_at
317 }
318}
319
/// Current Unix time in whole seconds, or `0` if the system clock is before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
327
328#[derive(Debug, Clone, Default, Serialize, Deserialize)]
333pub struct OAuthConfig {
334 pub provider: String,
335 pub client_id: String,
336 pub client_secret: String,
337 pub redirect_uri: String,
338 #[serde(default, skip_serializing_if = "Option::is_none")]
343 pub scopes_override: Option<String>,
344 #[serde(default, skip_serializing_if = "Option::is_none")]
348 pub tenant: Option<String>,
349 #[serde(default, skip_serializing_if = "Option::is_none")]
352 pub apple: Option<provider::AppleConfig>,
353 #[serde(default, skip_serializing_if = "Option::is_none")]
359 pub oidc_issuer: Option<String>,
360}
361
362impl OAuthConfig {
363 fn resolved_spec(&self) -> Result<provider::ResolvedSpec, String> {
368 if let Some(issuer) = self.oidc_issuer.as_deref() {
369 return provider::oidc_cache::resolve(issuer);
370 }
371 provider::find_spec(&self.provider)
372 .map(provider::ResolvedSpec::Static)
373 .ok_or_else(|| format!("unknown OAuth provider: {}", self.provider))
374 }
375
376 fn provider_cfg(&self) -> provider::ProviderConfig {
379 provider::ProviderConfig {
380 provider: self.provider.clone(),
381 client_id: self.client_id.clone(),
382 client_secret: self.client_secret.clone(),
383 redirect_uri: self.redirect_uri.clone(),
384 scopes_override: self.scopes_override.clone(),
385 tenant: self.tenant.clone(),
386 apple: self.apple.clone(),
387 oidc_issuer: self.oidc_issuer.clone(),
388 }
389 }
390
391 pub fn auth_url(&self) -> String {
401 match self.build_auth_url(None) {
402 Ok(u) => u,
403 Err(_) => String::new(),
404 }
405 }
406
407 pub fn auth_url_with_state(&self, state: &str) -> String {
409 let base = self.auth_url();
410 if base.is_empty() {
411 return base;
412 }
413 format!("{}&state={}", base, url_encode(state))
414 }
415
416 pub fn auth_url_with_pkce(&self, state: &str) -> Result<(String, Option<String>), String> {
421 let spec = self.resolved_spec()?;
422 let pkce = if spec.requires_pkce() {
423 Some(generate_pkce())
424 } else {
425 None
426 };
427 let challenge = pkce.as_ref().map(|p| p.code_challenge.as_str());
428 let mut url = self.build_auth_url(challenge)?;
429 if !state.is_empty() {
430 url.push_str(&format!("&state={}", url_encode(state)));
431 }
432 Ok((url, pkce.map(|p| p.code_verifier)))
433 }
434
435 fn build_auth_url(&self, pkce_challenge: Option<&str>) -> Result<String, String> {
436 let spec = self.resolved_spec()?;
437 let cfg = self.provider_cfg();
438 let auth = provider::resolve_endpoint(spec.auth_url(), &cfg);
439 if auth.is_empty() {
440 return Err(format!(
441 "provider {} has no authorization endpoint",
442 self.provider
443 ));
444 }
445 let scopes_default = spec.scopes().to_string();
446 let scopes_raw = self.scopes_override.as_deref().unwrap_or(&scopes_default);
447 let scopes_joined = scopes_raw
451 .split_whitespace()
452 .collect::<Vec<_>>()
453 .join(spec.scope_separator());
454
455 let mut url = format!(
456 "{auth}?{cid_param}={cid}&redirect_uri={ruri}&response_type=code&scope={scope}",
457 cid_param = spec.client_id_param(),
458 cid = url_encode(&self.client_id),
459 ruri = url_encode(&self.redirect_uri),
460 scope = url_encode(&scopes_joined),
461 );
462 if !spec.auth_query_extra().is_empty() {
463 url.push('&');
464 url.push_str(spec.auth_query_extra());
465 }
466 if let Some(challenge) = pkce_challenge {
467 url.push_str("&code_challenge=");
468 url.push_str(challenge);
469 url.push_str("&code_challenge_method=S256");
470 }
471 Ok(url)
472 }
473
474 pub fn token_url(&self) -> String {
476 match self.resolved_spec() {
477 Ok(spec) => provider::resolve_endpoint(spec.token_url(), &self.provider_cfg()),
478 Err(_) => String::new(),
479 }
480 }
481
482 pub fn userinfo_url(&self) -> String {
484 match self.resolved_spec() {
485 Ok(spec) => match spec.userinfo_url() {
486 Some(u) => provider::resolve_endpoint(u, &self.provider_cfg()),
487 None => String::new(),
488 },
489 Err(_) => String::new(),
490 }
491 }
492
493 pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
499 self.exchange_code_full_pkce(code, None)
500 }
501
502 pub fn exchange_code_full_pkce(
503 &self,
504 code: &str,
505 code_verifier: Option<&str>,
506 ) -> Result<TokenSet, String> {
507 let spec = self.resolved_spec()?;
508 let cfg = self.provider_cfg();
509 let token_url = provider::resolve_endpoint(spec.token_url(), &cfg);
510 let pkce_field = code_verifier
511 .map(|v| format!("&code_verifier={}", url_encode(v)))
512 .unwrap_or_default();
513
514 let out = match spec.token_exchange() {
515 provider::TokenExchangeShape::Standard => {
516 let body = format!(
517 "code={code}&{cid_param}={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
518 code = url_encode(code),
519 cid_param = spec.client_id_param(),
520 cid = url_encode(&self.client_id),
521 secret = url_encode(&self.client_secret),
522 ruri = url_encode(&self.redirect_uri),
523 pkce = pkce_field,
524 );
525 http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
526 }
527 provider::TokenExchangeShape::AppleJwt => {
528 let apple = self.apple.as_ref().ok_or(
529 "apple provider requires `apple` config (team_id, key_id, private_key_pem)",
530 )?;
531 let signed_secret = apple_jwt::mint_client_secret(apple, &self.client_id)?;
532 let body = format!(
533 "code={code}&client_id={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
534 code = url_encode(code),
535 cid = url_encode(&self.client_id),
536 secret = url_encode(&signed_secret),
537 ruri = url_encode(&self.redirect_uri),
538 pkce = pkce_field,
539 );
540 http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
541 }
542 provider::TokenExchangeShape::BasicAuth => {
543 let body = format!(
544 "code={code}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
545 code = url_encode(code),
546 ruri = url_encode(&self.redirect_uri),
547 pkce = pkce_field,
548 );
549 http_post_form_basic(&token_url, &body, &self.client_id, &self.client_secret)
550 .map_err(sanitize_token_error)?
551 }
552 provider::TokenExchangeShape::JsonBody => {
553 let mut json = serde_json::Map::new();
554 json.insert("grant_type".into(), "authorization_code".into());
555 json.insert("code".into(), code.into());
556 json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
557 json.insert("client_id".into(), self.client_id.clone().into());
558 json.insert("client_secret".into(), self.client_secret.clone().into());
559 if let Some(v) = code_verifier {
560 json.insert("code_verifier".into(), v.to_string().into());
561 }
562 let body = serde_json::Value::Object(json).to_string();
563 http_post_json(&token_url, &body, None).map_err(sanitize_token_error)?
564 }
565 provider::TokenExchangeShape::BasicAuthJsonBody => {
566 let mut json = serde_json::Map::new();
567 json.insert("grant_type".into(), "authorization_code".into());
568 json.insert("code".into(), code.into());
569 json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
570 if let Some(v) = code_verifier {
571 json.insert("code_verifier".into(), v.to_string().into());
572 }
573 let body = serde_json::Value::Object(json).to_string();
574 http_post_json(
575 &token_url,
576 &body,
577 Some((&self.client_id, &self.client_secret)),
578 )
579 .map_err(sanitize_token_error)?
580 }
581 };
582 parse_token_response(&out)
583 }
584
585 pub fn exchange_code(&self, code: &str) -> Result<String, String> {
589 Ok(self.exchange_code_full(code)?.access_token)
590 }
591
592 pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
594 let info = self.fetch_userinfo_full(access_token)?;
595 Ok((info.email, info.name))
596 }
597
598 pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
603 self.fetch_userinfo_with_id_token(access_token, None)
607 }
608
609 pub fn fetch_userinfo_with_id_token(
614 &self,
615 access_token: &str,
616 id_token: Option<&str>,
617 ) -> Result<UserInfo, String> {
618 let spec = self.resolved_spec()?;
619 let cfg = self.provider_cfg();
620
621 if matches!(spec.userinfo_parser(), provider::UserinfoParser::AppleIdToken) {
623 let token = id_token
624 .ok_or("apple login requires the id_token from the token response")?;
625 return parse_apple_id_token(token, &self.provider);
626 }
627
628 if matches!(spec.userinfo_parser(), provider::UserinfoParser::LinearGraphql) {
631 return fetch_linear_userinfo(&self.provider, access_token);
632 }
633
634 let url = match spec.userinfo_url() {
635 Some(u) => provider::resolve_endpoint(u, &cfg),
636 None => return Err(format!("provider {} has no userinfo endpoint", self.provider)),
637 };
638 let out = match spec.userinfo_method() {
639 provider::UserinfoMethod::Get => http_get_bearer(&url, access_token),
640 provider::UserinfoMethod::Post => http_post_bearer(&url, access_token),
641 }
642 .map_err(sanitize_token_error)?;
643 let parsed: serde_json::Value =
644 serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
645
646 match spec.userinfo_parser() {
647 provider::UserinfoParser::Oidc => {
648 let email = parsed
649 .get("email")
650 .and_then(|v| v.as_str())
651 .ok_or("no email in userinfo")?
652 .to_string();
653 let name = parsed
654 .get("name")
655 .and_then(|v| v.as_str())
656 .map(String::from);
657 let provider_account_id = parsed
658 .get("sub")
659 .and_then(|v| v.as_str())
660 .ok_or("no sub in userinfo")?
661 .to_string();
662 Ok(UserInfo {
663 provider: self.provider.clone(),
664 provider_account_id,
665 email,
666 name,
667 })
668 }
669 provider::UserinfoParser::GitHub => {
670 let name = parsed
671 .get("name")
672 .and_then(|v| v.as_str())
673 .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
674 .map(String::from);
675 let email = parsed
676 .get("email")
677 .and_then(|v| v.as_str())
678 .map(String::from);
679 let email = email
680 .or_else(|| fetch_github_primary_email(access_token).ok())
681 .ok_or("no accessible email on GitHub account")?;
682 let provider_account_id = parsed
683 .get("id")
684 .map(|v| {
685 v.as_i64()
686 .map(|n| n.to_string())
687 .or_else(|| v.as_str().map(String::from))
688 .unwrap_or_default()
689 })
690 .filter(|s| !s.is_empty())
691 .ok_or("no id in userinfo")?;
692 Ok(UserInfo {
693 provider: self.provider.clone(),
694 provider_account_id,
695 email,
696 name,
697 })
698 }
699 provider::UserinfoParser::Custom {
700 id_path,
701 email_path,
702 name_path,
703 } => {
704 let provider_account_id = json_pointer_string(&parsed, id_path)
705 .ok_or_else(|| format!("no id at {id_path} in userinfo"))?;
706 let raw_email = json_pointer_string(&parsed, email_path)
707 .ok_or_else(|| format!("no email at {email_path} in userinfo"))?;
708 let email = if !raw_email.contains('@') {
713 let domain = match self.provider.as_str() {
714 "twitter" => "x.invalid",
715 "reddit" => "reddit.invalid",
716 other => return Err(format!(
717 "{other}: userinfo `email` field is not an email address (got {raw_email:?}); refusing to synthesize",
718 )),
719 };
720 format!("{raw_email}@{domain}")
721 } else {
722 raw_email
723 };
724 let name = name_path.and_then(|p| json_pointer_string(&parsed, p));
725 Ok(UserInfo {
726 provider: self.provider.clone(),
727 provider_account_id,
728 email,
729 name,
730 })
731 }
732 provider::UserinfoParser::AppleIdToken => unreachable!("handled above"),
733 provider::UserinfoParser::LinearGraphql => unreachable!("handled above"),
734 }
735 }
736}
737
/// A PKCE verifier/challenge pair produced by `generate_pkce`.
struct PkcePair {
    /// Random secret; sent to the token endpoint as `code_verifier`.
    code_verifier: String,
    /// Base64url SHA-256 of the verifier; sent in the authorization URL with
    /// `code_challenge_method=S256`.
    code_challenge: String,
}
744
745fn generate_pkce() -> PkcePair {
749 use rand::RngCore;
750 let mut bytes = [0u8; 32];
751 rand::thread_rng().fill_bytes(&mut bytes);
752 let code_verifier = apple_jwt::base64_url(bytes);
753 use sha2::{Digest, Sha256};
754 let mut hasher = Sha256::new();
755 hasher.update(code_verifier.as_bytes());
756 let code_challenge = apple_jwt::base64_url(hasher.finalize());
757 PkcePair {
758 code_verifier,
759 code_challenge,
760 }
761}
762
763fn parse_apple_id_token(id_token: &str, provider: &str) -> Result<UserInfo, String> {
786 let mut parts = id_token.split('.');
787 let _header = parts.next().ok_or("apple id_token: missing header")?;
788 let claims_b64 = parts.next().ok_or("apple id_token: missing claims")?;
789 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
790 let claims_bytes = URL_SAFE_NO_PAD
791 .decode(claims_b64)
792 .map_err(|e| format!("apple id_token claims not base64: {e}"))?;
793 let claims: serde_json::Value = serde_json::from_slice(&claims_bytes)
794 .map_err(|e| format!("apple id_token claims not JSON: {e}"))?;
795 let provider_account_id = claims
796 .get("sub")
797 .and_then(|v| v.as_str())
798 .ok_or("apple id_token: missing sub")?
799 .to_string();
800 let email = claims
801 .get("email")
802 .and_then(|v| v.as_str())
803 .ok_or("apple id_token: missing email (was the `email` scope requested?)")?
804 .to_string();
805 Ok(UserInfo {
806 provider: provider.to_string(),
807 provider_account_id,
808 email,
809 name: None, })
811}
812
/// Redact credential values from an error string before it is returned or logged.
///
/// Used on token-endpoint and userinfo failures, whose messages embed the
/// provider's response body. For every sensitive key, in the order listed,
/// `key=value` (form/query) and `"key": "value"` (JSON string) occurrences have
/// their value replaced by `***`. Matching is plain substring search, so e.g.
/// `auth_code=` is also caught by the `code` entry.
fn sanitize_token_error(err: String) -> String {
    const SENSITIVE: &[&str] = &[
        "client_secret",
        "code_verifier",
        "client_assertion",
        "refresh_token",
        "access_token",
        "id_token",
        "code",
    ];
    SENSITIVE.iter().fold(err, |text, key| {
        let text = redact_param_form(&text, key);
        redact_param_json(&text, key)
    })
}

/// Replace the value after every `key=` in `input` with `***`.
///
/// A value runs until the next `&`, newline, `"`, space or `'` (kept in the
/// output), or to the end of the input.
fn redact_param_form(input: &str, key: &str) -> String {
    let needle = format!("{key}=");
    let mut out = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(hit) = rest.find(needle.as_str()) {
        let value_at = hit + needle.len();
        // Text before the key plus the `key=` itself are kept verbatim.
        out.push_str(&rest[..value_at]);
        out.push_str("***");
        let value = &rest[value_at..];
        let value_len = value
            .find(|c: char| matches!(c, '&' | '\n' | '"' | ' ' | '\''))
            .unwrap_or(value.len());
        rest = &value[value_len..];
    }
    out.push_str(rest);
    out
}

/// Replace the string value of every `"key": "..."` member in `input` with `***`.
///
/// Whitespace is allowed on both sides of the colon. An occurrence of `"key"`
/// not followed by `:` and a string value is left alone, so numbers and
/// objects are not redacted. Backslash escapes are honoured when looking for
/// the closing quote. An unterminated value is redacted through the end of the
/// input.
fn redact_param_json(input: &str, key: &str) -> String {
    /// Byte index of the first `"` in `s` that is not backslash-escaped.
    fn unescaped_quote(s: &str) -> Option<usize> {
        let mut escaped = false;
        for (idx, byte) in s.bytes().enumerate() {
            if byte == b'"' && !escaped {
                return Some(idx);
            }
            escaped = byte == b'\\' && !escaped;
        }
        None
    }

    let needle = format!("\"{key}\"");
    let mut out = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(hit) = rest.find(needle.as_str()) {
        out.push_str(&rest[..hit]);
        let tail = &rest[hit..];

        // Offset within `tail` just past the quoted key and any whitespace.
        let mut pos = tail.len() - tail[needle.len()..].trim_start().len();
        if !tail[pos..].starts_with(':') {
            out.push_str(&tail[..pos]);
            rest = &tail[pos..];
            continue;
        }

        // Skip the colon and the whitespace that follows it.
        pos = tail.len() - tail[pos + 1..].trim_start().len();
        if !tail[pos..].starts_with('"') {
            out.push_str(&tail[..pos]);
            rest = &tail[pos..];
            continue;
        }

        let value_start = pos + 1;
        out.push_str(&tail[..value_start]);
        out.push_str("***");
        rest = match unescaped_quote(&tail[value_start..]) {
            Some(end) => {
                out.push('"');
                &tail[value_start + end + 1..]
            }
            None => "",
        };
    }
    out.push_str(rest);
    out
}
951
952fn fetch_linear_userinfo(provider: &str, access_token: &str) -> Result<UserInfo, String> {
957 let body = r#"{"query":"query { viewer { id email name } }"}"#;
958 let agent = ureq_agent();
959 let resp = agent
960 .post("https://api.linear.app/graphql")
961 .set("Authorization", &format!("Bearer {access_token}"))
962 .set("Content-Type", "application/json")
963 .set("Accept", "application/json")
964 .send_string(body)
965 .map_err(|e| format!("linear graphql: {e}"))?;
966 let out = resp.into_string().map_err(|e| format!("read body: {e}"))?;
967 let parsed: serde_json::Value = serde_json::from_str(&out)
968 .map_err(|e| format!("linear graphql not JSON: {e}"))?;
969 let viewer = parsed
970 .pointer("/data/viewer")
971 .ok_or("linear graphql: no /data/viewer")?;
972 let provider_account_id = viewer
973 .get("id")
974 .and_then(|v| v.as_str())
975 .ok_or("linear graphql: no id")?
976 .to_string();
977 let email = viewer
978 .get("email")
979 .and_then(|v| v.as_str())
980 .ok_or("linear graphql: no email")?
981 .to_string();
982 let name = viewer.get("name").and_then(|v| v.as_str()).map(String::from);
983 Ok(UserInfo {
984 provider: provider.to_string(),
985 provider_account_id,
986 email,
987 name,
988 })
989}
990
991fn json_pointer_string(v: &serde_json::Value, path: &str) -> Option<String> {
995 let node = v.pointer(path)?;
996 if let Some(s) = node.as_str() {
997 return Some(s.to_string());
998 }
999 if let Some(n) = node.as_i64() {
1000 return Some(n.to_string());
1001 }
1002 if let Some(n) = node.as_u64() {
1003 return Some(n.to_string());
1004 }
1005 None
1006}
1007
/// Normalized identity produced by the userinfo helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Provider id the identity came from (copied from the config's `provider`).
    pub provider: String,
    /// Stable account id at the provider, e.g. OIDC `sub` or GitHub `id`.
    pub provider_account_id: String,
    /// Email address. For twitter/reddit without a real address, this is a
    /// synthesized `<id>@*.invalid` value.
    pub email: String,
    /// Display name, when the provider returns one.
    pub name: Option<String>,
}
1019
/// Tokens returned by a provider's token endpoint (built by `parse_token_response`).
///
/// NOTE(review): the derived `Debug` prints every token in clear text; avoid
/// `{:?}`-logging this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSet {
    /// Bearer token for provider API calls.
    pub access_token: String,
    /// Refresh token, when the provider issued one.
    pub refresh_token: Option<String>,
    /// Raw ID token, when issued (required for the Apple userinfo path).
    pub id_token: Option<String>,
    /// Absolute expiry in Unix seconds, computed from `expires_in` at parse time.
    pub expires_at: Option<u64>,
    /// Granted scopes as returned by the provider (format is provider-specific).
    pub scope: Option<String>,
}
1034
1035fn parse_token_response(body: &str) -> Result<TokenSet, String> {
1036 let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
1039 let mut map = serde_json::Map::new();
1041 for pair in body.split('&') {
1042 if let Some((k, v)) = pair.split_once('=') {
1043 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
1044 }
1045 }
1046 serde_json::Value::Object(map)
1047 });
1048
1049 let access_token = json
1050 .get("access_token")
1051 .and_then(|v| v.as_str())
1052 .ok_or_else(|| format!("no access_token in token response: {body}"))?
1053 .to_string();
1054 let refresh_token = json
1055 .get("refresh_token")
1056 .and_then(|v| v.as_str())
1057 .map(String::from);
1058 let id_token = json
1059 .get("id_token")
1060 .and_then(|v| v.as_str())
1061 .map(String::from);
1062 let expires_at = json
1063 .get("expires_in")
1064 .and_then(|v| {
1065 v.as_u64()
1066 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
1067 })
1068 .map(|secs| now_secs().saturating_add(secs));
1069 let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
1070 Ok(TokenSet {
1071 access_token,
1072 refresh_token,
1073 id_token,
1074 expires_at,
1075 scope,
1076 })
1077}
1078
/// Percent-encode `s` byte by byte, leaving only RFC 3986 unreserved
/// characters (`A-Z a-z 0-9 - _ . ~`) as-is.
///
/// Other bytes, including each byte of a multi-byte UTF-8 character, become
/// `%XX` with uppercase hex.
fn url_encode(s: &str) -> String {
    s.bytes().fold(String::with_capacity(s.len()), |mut acc, byte| {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            acc.push(char::from(byte));
        } else {
            acc.push_str(&format!("%{byte:02X}"));
        }
        acc
    })
}
1091
1092const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
1096
1097fn ureq_agent() -> ureq::Agent {
1098 ureq::AgentBuilder::new()
1099 .timeout_connect(HTTP_TIMEOUT)
1100 .timeout_read(HTTP_TIMEOUT)
1101 .timeout_write(HTTP_TIMEOUT)
1102 .user_agent("pylon/0.1")
1103 .build()
1104}
1105
1106fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
1107 let agent = ureq_agent();
1108 let mut req = agent
1109 .post(url)
1110 .set("Content-Type", "application/x-www-form-urlencoded");
1111 if accept_json {
1112 req = req.set("Accept", "application/json");
1113 }
1114 match req.send_string(body) {
1115 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1116 Err(ureq::Error::Status(code, resp)) => {
1117 let body = resp.into_string().unwrap_or_default();
1118 Err(format!("HTTP {code}: {body}"))
1119 }
1120 Err(e) => Err(format!("HTTP error: {e}")),
1121 }
1122}
1123
1124fn http_post_form_basic(
1128 url: &str,
1129 body: &str,
1130 client_id: &str,
1131 client_secret: &str,
1132) -> Result<String, String> {
1133 use base64::{engine::general_purpose::STANDARD, Engine};
1134 let creds = format!("{client_id}:{client_secret}");
1135 let basic = STANDARD.encode(creds.as_bytes());
1136 let agent = ureq_agent();
1137 match agent
1138 .post(url)
1139 .set("Content-Type", "application/x-www-form-urlencoded")
1140 .set("Accept", "application/json")
1141 .set("Authorization", &format!("Basic {basic}"))
1142 .send_string(body)
1143 {
1144 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1145 Err(ureq::Error::Status(code, resp)) => {
1146 let body = resp.into_string().unwrap_or_default();
1147 Err(format!("HTTP {code}: {body}"))
1148 }
1149 Err(e) => Err(format!("HTTP error: {e}")),
1150 }
1151}
1152
1153fn http_post_json(
1157 url: &str,
1158 body: &str,
1159 basic_creds: Option<(&str, &str)>,
1160) -> Result<String, String> {
1161 let agent = ureq_agent();
1162 let mut req = agent
1163 .post(url)
1164 .set("Content-Type", "application/json")
1165 .set("Accept", "application/json");
1166 if let Some((id, secret)) = basic_creds {
1167 use base64::{engine::general_purpose::STANDARD, Engine};
1168 let creds = STANDARD.encode(format!("{id}:{secret}").as_bytes());
1169 req = req.set("Authorization", &format!("Basic {creds}"));
1170 }
1171 req = req.set("Notion-Version", "2022-06-28");
1174 match req.send_string(body) {
1175 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1176 Err(ureq::Error::Status(code, resp)) => {
1177 let body = resp.into_string().unwrap_or_default();
1178 Err(format!("HTTP {code}: {body}"))
1179 }
1180 Err(e) => Err(format!("HTTP error: {e}")),
1181 }
1182}
1183
1184fn http_post_bearer(url: &str, token: &str) -> Result<String, String> {
1187 let agent = ureq_agent();
1188 match agent
1189 .post(url)
1190 .set("Authorization", &format!("Bearer {token}"))
1191 .set("Accept", "application/json")
1192 .call()
1193 {
1194 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1195 Err(ureq::Error::Status(code, resp)) => {
1196 let body = resp.into_string().unwrap_or_default();
1197 Err(format!("HTTP {code}: {body}"))
1198 }
1199 Err(e) => Err(format!("HTTP error: {e}")),
1200 }
1201}
1202
1203fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
1204 let agent = ureq_agent();
1205 match agent
1206 .get(url)
1207 .set("Authorization", &format!("Bearer {token}"))
1208 .set("Accept", "application/json")
1209 .call()
1210 {
1211 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1212 Err(ureq::Error::Status(code, resp)) => {
1213 let body = resp.into_string().unwrap_or_default();
1214 Err(format!("HTTP {code}: {body}"))
1215 }
1216 Err(e) => Err(format!("HTTP error: {e}")),
1217 }
1218}
1219
1220fn fetch_github_primary_email(token: &str) -> Result<String, String> {
1221 let out = http_get_bearer("https://api.github.com/user/emails", token)?;
1222 let emails: serde_json::Value =
1223 serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
1224 emails
1225 .as_array()
1226 .and_then(|arr| {
1227 arr.iter()
1228 .find(|e| {
1229 e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
1230 && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
1231 })
1232 .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
1233 })
1234 .ok_or_else(|| "no primary verified email on GitHub".into())
1235}
1236
/// The set of configured OAuth providers, keyed by provider id.
pub struct OAuthRegistry {
    /// Provider id → configuration; one entry per id (see `OAuthRegistry::register`).
    providers: std::collections::HashMap<String, OAuthConfig>,
}
1241
1242impl Default for OAuthRegistry {
1243 fn default() -> Self {
1244 Self::new()
1245 }
1246}
1247
1248impl OAuthRegistry {
1249 pub fn new() -> Self {
1250 Self {
1251 providers: std::collections::HashMap::new(),
1252 }
1253 }
1254
1255 pub fn register(&mut self, config: OAuthConfig) {
1256 self.providers.insert(config.provider.clone(), config);
1257 }
1258
1259 pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
1260 self.providers.get(provider)
1261 }
1262
1263 pub fn from_env() -> Self {
1276 let mut reg = Self::new();
1277
1278 for spec in provider::builtin::all() {
1279 let upper = spec.id.to_ascii_uppercase();
1280 let prefix = format!("PYLON_OAUTH_{upper}");
1281 let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1282 Ok(v) => v,
1283 Err(_) => continue,
1284 };
1285 let secret = match std::env::var(format!("{prefix}_CLIENT_SECRET")) {
1286 Ok(v) => v,
1287 Err(_) if spec.id == "apple" => String::new(),
1289 Err(_) => continue,
1290 };
1291 let redirect_uri = std::env::var(format!("{prefix}_REDIRECT")).unwrap_or_else(|_| {
1292 format!("http://localhost:3000/api/auth/callback/{}", spec.id)
1293 });
1294 let scopes_override = std::env::var(format!("{prefix}_SCOPES")).ok();
1295 let tenant = std::env::var(format!("{prefix}_TENANT")).ok();
1296
1297 let apple = if spec.id == "apple" {
1298 match (
1299 std::env::var(format!("{prefix}_TEAM_ID")),
1300 std::env::var(format!("{prefix}_KEY_ID")),
1301 std::env::var(format!("{prefix}_PRIVATE_KEY")),
1302 ) {
1303 (Ok(team_id), Ok(key_id), Ok(private_key_pem)) => Some(provider::AppleConfig {
1304 team_id,
1305 key_id,
1306 private_key_pem,
1307 }),
1308 _ => continue, }
1310 } else {
1311 None
1312 };
1313
1314 reg.register(OAuthConfig {
1315 provider: spec.id.to_string(),
1316 client_id: id,
1317 client_secret: secret,
1318 redirect_uri,
1319 scopes_override,
1320 tenant,
1321 apple,
1322 oidc_issuer: None,
1323 });
1324 }
1325
1326 for (key, issuer) in std::env::vars() {
1328 let Some(rest) = key.strip_prefix("PYLON_OAUTH_") else {
1329 continue;
1330 };
1331 let Some(name_upper) = rest.strip_suffix("_OIDC_ISSUER") else {
1332 continue;
1333 };
1334 let name = name_upper.to_ascii_lowercase();
1335 if provider::find_spec(&name).is_some() {
1336 continue; }
1338 let prefix = format!("PYLON_OAUTH_{name_upper}");
1339 let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1340 Ok(v) => v,
1341 Err(_) => continue,
1342 };
1343 let secret = std::env::var(format!("{prefix}_CLIENT_SECRET")).unwrap_or_default();
1344 let redirect_uri = std::env::var(format!("{prefix}_REDIRECT"))
1345 .unwrap_or_else(|_| format!("http://localhost:3000/api/auth/callback/{name}"));
1346 reg.register(OAuthConfig {
1347 provider: name,
1348 client_id: id,
1349 client_secret: secret,
1350 redirect_uri,
1351 scopes_override: std::env::var(format!("{prefix}_SCOPES")).ok(),
1352 tenant: None,
1353 apple: None,
1354 oidc_issuer: Some(issuer),
1355 });
1356 }
1357
1358 reg
1359 }
1360
1361 pub fn ids(&self) -> impl Iterator<Item = &str> {
1365 self.providers.keys().map(|s| s.as_str())
1366 }
1367
1368 pub fn shared() -> &'static OAuthRegistry {
1375 static CELL: std::sync::OnceLock<OAuthRegistry> = std::sync::OnceLock::new();
1376 CELL.get_or_init(Self::from_env)
1377 }
1378}
1379
/// Server-side record of an in-flight OAuth authorization request, stored
/// under the opaque `state` token that is round-tripped through the provider.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
    /// Provider id the flow was started for; checked again on callback.
    pub provider: String,
    /// Redirect target recorded for the success path.
    pub callback_url: String,
    /// Redirect target recorded for the failure path.
    pub error_callback_url: String,
    /// PKCE code verifier, present only for flows that use PKCE.
    pub pkce_verifier: Option<String>,
    /// Unix timestamp (seconds); the state is rejected once `now >= expires_at`.
    pub expires_at: u64,
}

/// Storage for pending OAuth states. Implementations are shared across
/// threads (`Send + Sync`).
pub trait OAuthStateBackend: Send + Sync {
    /// Stores `state` under `token`.
    fn put(&self, token: &str, state: &OAuthState);
    /// Removes the entry for `token` and returns it if it had not expired as
    /// of `now_unix_secs`. Intended to be single-use.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
}

/// Process-local [`OAuthStateBackend`] backed by a mutex-guarded map.
pub struct InMemoryOAuthBackend {
    states: Mutex<HashMap<String, OAuthState>>,
}

impl InMemoryOAuthBackend {
    /// Creates an empty backend.
    pub fn new() -> Self {
        Self {
            states: Mutex::new(HashMap::new()),
        }
    }
}

impl Default for InMemoryOAuthBackend {
    fn default() -> Self {
        Self::new()
    }
}

impl OAuthStateBackend for InMemoryOAuthBackend {
    /// Inserts `state` under `token`, replacing any existing entry.
    fn put(&self, token: &str, state: &OAuthState) {
        self.states
            .lock()
            .unwrap()
            .insert(token.to_string(), state.clone());
    }

    /// Removes `token`'s entry and returns it if it is still live.
    ///
    /// Every other entry that has expired as of `now_unix_secs` is dropped
    /// in the same pass. Without this, states from abandoned flows (the user
    /// never returns from the provider) would stay in the map for the life of
    /// the process. The sweep is O(n) in the number of pending states, which
    /// is bounded by the number of flows started within one TTL window.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
        let mut states = self.states.lock().unwrap();
        let entry = states.remove(token);
        states.retain(|_, s| s.expires_at > now_unix_secs);
        entry.filter(|s| s.expires_at > now_unix_secs)
    }
}
1457
1458pub struct OAuthStateStore {
1465 backend: Box<dyn OAuthStateBackend>,
1466}
1467
1468impl Default for OAuthStateStore {
1469 fn default() -> Self {
1470 Self::new()
1471 }
1472}
1473
1474impl OAuthStateStore {
1475 pub fn new() -> Self {
1476 Self {
1477 backend: Box::new(InMemoryOAuthBackend::new()),
1478 }
1479 }
1480
1481 pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
1482 Self { backend }
1483 }
1484
1485 pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
1493 self.create_with_pkce(provider, callback_url, error_callback_url, None)
1494 }
1495
1496 pub fn create_with_pkce(
1500 &self,
1501 provider: &str,
1502 callback_url: &str,
1503 error_callback_url: &str,
1504 pkce_verifier: Option<String>,
1505 ) -> String {
1506 use std::time::{SystemTime, UNIX_EPOCH};
1507 let token = generate_token();
1508 let now = SystemTime::now()
1509 .duration_since(UNIX_EPOCH)
1510 .unwrap_or_default()
1511 .as_secs();
1512 let state = OAuthState {
1513 provider: provider.to_string(),
1514 callback_url: callback_url.to_string(),
1515 error_callback_url: error_callback_url.to_string(),
1516 pkce_verifier,
1517 expires_at: now + 600,
1518 };
1519 self.backend.put(&token, &state);
1520 token
1521 }
1522
1523 pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
1528 use std::time::{SystemTime, UNIX_EPOCH};
1529 let now = SystemTime::now()
1530 .duration_since(UNIX_EPOCH)
1531 .unwrap_or_default()
1532 .as_secs();
1533 let entry = self.backend.take(state, now)?;
1534 if entry.provider != expected_provider {
1535 return None;
1536 }
1537 Some(entry)
1538 }
1539}
1540
/// Checks that `url` is an absolute http(s) URL whose origin exactly matches
/// one of `trusted_origins`.
///
/// Matching is a byte-for-byte comparison of [`origin_of`]'s output against
/// each entry. There is no case folding, default-port or trailing-slash
/// normalization, so a trusted entry must be written exactly as
/// `scheme://host[:port]`.
///
/// # Errors
/// - [`TrustedOriginError::Empty`] for an empty string.
/// - [`TrustedOriginError::NotHttp`] unless the URL starts with the lowercase
///   prefix `http://` or `https://`.
/// - [`TrustedOriginError::NotTrusted`] carrying the computed origin otherwise.
pub fn validate_trusted_redirect(
    url: &str,
    trusted_origins: &[String],
) -> Result<(), TrustedOriginError> {
    if url.is_empty() {
        return Err(TrustedOriginError::Empty);
    }
    let has_http_scheme = url.starts_with("http://") || url.starts_with("https://");
    if !has_http_scheme {
        return Err(TrustedOriginError::NotHttp);
    }
    let origin = origin_of(url);
    if trusted_origins.contains(&origin) {
        Ok(())
    } else {
        Err(TrustedOriginError::NotTrusted { origin })
    }
}

/// Reasons [`validate_trusted_redirect`] rejects a redirect target.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The URL was the empty string.
    Empty,
    /// The URL did not start with `http://` or `https://`.
    NotHttp,
    /// The URL's origin is not in the trusted list.
    NotTrusted { origin: String },
}

impl std::fmt::Display for TrustedOriginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Empty => f.write_str("redirect URL is empty"),
            Self::NotHttp => f.write_str("redirect URL must use http:// or https:// scheme"),
            Self::NotTrusted { origin } => {
                write!(f, "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS")
            }
        }
    }
}

/// Returns the `scheme://authority` prefix of `url`.
///
/// The authority ends at the first `/`, `?` or `#` after `://`. Input without
/// `://` is returned with trailing slashes trimmed.
pub fn origin_of(url: &str) -> String {
    let Some(authority_start) = url.find("://").map(|i| i + 3) else {
        return url.trim_end_matches('/').to_string();
    };
    let remainder = &url[authority_start..];
    let authority_len = remainder
        .find(&['/', '?', '#'][..])
        .unwrap_or(remainder.len());
    url[..authority_start + authority_len].to_string()
}
1615
1616pub trait MagicCodeBackend: Send + Sync {
1629 fn put(&self, email: &str, code: &MagicCode);
1631 fn get(&self, email: &str) -> Option<MagicCode>;
1633 fn remove(&self, email: &str);
1636 fn bump_attempts(&self, email: &str);
1640 fn load_all(&self) -> Vec<MagicCode>;
1643}
1644
1645pub struct InMemoryMagicCodeBackend {
1648 codes: Mutex<HashMap<String, MagicCode>>,
1649}
1650
1651impl InMemoryMagicCodeBackend {
1652 pub fn new() -> Self {
1653 Self {
1654 codes: Mutex::new(HashMap::new()),
1655 }
1656 }
1657}
1658
1659impl Default for InMemoryMagicCodeBackend {
1660 fn default() -> Self {
1661 Self::new()
1662 }
1663}
1664
1665impl MagicCodeBackend for InMemoryMagicCodeBackend {
1666 fn put(&self, email: &str, code: &MagicCode) {
1667 self.codes
1668 .lock()
1669 .unwrap()
1670 .insert(email.to_string(), code.clone());
1671 }
1672 fn get(&self, email: &str) -> Option<MagicCode> {
1673 self.codes.lock().unwrap().get(email).cloned()
1674 }
1675 fn remove(&self, email: &str) {
1676 self.codes.lock().unwrap().remove(email);
1677 }
1678 fn bump_attempts(&self, email: &str) {
1679 if let Some(c) = self.codes.lock().unwrap().get_mut(email) {
1680 c.attempts = c.attempts.saturating_add(1);
1681 }
1682 }
1683 fn load_all(&self) -> Vec<MagicCode> {
1684 self.codes.lock().unwrap().values().cloned().collect()
1685 }
1686}
1687
/// Issues and verifies one-time email sign-in codes.
///
/// `cache` is the copy consulted by `try_create` and `try_verify`. Every
/// mutation made there is also written through to `backend`.
pub struct MagicCodeStore {
    // Email -> current code. Only unexpired codes are loaded at construction.
    cache: Mutex<HashMap<String, MagicCode>>,
    backend: Box<dyn MagicCodeBackend>,
}

/// A single issued code and its verification state.
#[derive(Debug, Clone)]
pub struct MagicCode {
    pub email: String,
    // Six-digit, zero-padded numeric string when produced by `generate_magic_code`.
    pub code: String,
    // Timestamp in `now_secs()` units; the code is rejected once `now >= expires_at`.
    pub expires_at: u64,
    // Failed verification attempts so far.
    pub attempts: u32,
}

// Failed attempts after which `try_verify` reports `TooManyAttempts`.
const MAX_ATTEMPTS: u32 = 5;

// Minimum age, in seconds, of the current unexpired code before `try_create`
// will issue a replacement for the same email.
const CREATE_COOLDOWN_SECS: u64 = 60;

/// Failure modes of `MagicCodeStore::try_create` / `try_verify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    // No code is on file for the email.
    NotFound,
    // The attempt limit (`MAX_ATTEMPTS`) has been reached for this code.
    TooManyAttempts,
    // The submitted code did not match.
    BadCode,
    // The code's `expires_at` has passed; it has been removed.
    Expired,
    // A new code was requested too soon; retry after this many seconds.
    Throttled { retry_after_secs: u64 },
}
1729
1730impl Default for MagicCodeStore {
1731 fn default() -> Self {
1732 Self::new()
1733 }
1734}
1735
1736impl MagicCodeStore {
1737 pub fn new() -> Self {
1738 Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1739 }
1740
1741 pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1746 let now = now_secs();
1747 let mut cache = HashMap::new();
1748 for c in backend.load_all() {
1749 if c.expires_at > now {
1750 cache.insert(c.email.clone(), c);
1751 }
1752 }
1753 Self {
1754 cache: Mutex::new(cache),
1755 backend,
1756 }
1757 }
1758
1759 pub fn create(&self, email: &str) -> String {
1762 self.try_create(email).unwrap_or_else(|_| String::new())
1765 }
1766
1767 pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1770 let now = now_secs();
1771
1772 let mut codes = self.cache.lock().unwrap();
1773
1774 if let Some(existing) = codes.get(email) {
1778 if existing.expires_at > now {
1779 let created_at = existing.expires_at.saturating_sub(600);
1780 let age = now.saturating_sub(created_at);
1781 if age < CREATE_COOLDOWN_SECS {
1782 return Err(MagicCodeError::Throttled {
1783 retry_after_secs: CREATE_COOLDOWN_SECS - age,
1784 });
1785 }
1786 }
1787 }
1788
1789 let code = generate_magic_code();
1790 let mc = MagicCode {
1791 email: email.to_string(),
1792 code: code.clone(),
1793 expires_at: now + 600, attempts: 0,
1795 };
1796 codes.insert(email.to_string(), mc.clone());
1797 self.backend.put(email, &mc);
1801 Ok(code)
1802 }
1803
1804 pub fn verify(&self, email: &str, code: &str) -> bool {
1808 matches!(self.try_verify(email, code), Ok(()))
1809 }
1810
1811 pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1818 self.cache
1819 .lock()
1820 .map(|m| m.values().cloned().collect())
1821 .unwrap_or_default()
1822 }
1823
1824 pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1825 let now = now_secs();
1826 let mut codes = self.cache.lock().unwrap();
1827
1828 let mc = match codes.get_mut(email) {
1829 Some(m) => m,
1830 None => return Err(MagicCodeError::NotFound),
1831 };
1832
1833 if mc.attempts >= MAX_ATTEMPTS {
1834 return Err(MagicCodeError::TooManyAttempts);
1835 }
1836 if mc.expires_at <= now {
1837 codes.remove(email);
1838 self.backend.remove(email);
1839 return Err(MagicCodeError::Expired);
1840 }
1841
1842 let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1843 if !ok {
1844 mc.attempts += 1;
1845 self.backend.bump_attempts(email);
1846 if mc.attempts >= MAX_ATTEMPTS {
1848 return Err(MagicCodeError::TooManyAttempts);
1849 }
1850 return Err(MagicCodeError::BadCode);
1851 }
1852
1853 codes.remove(email);
1855 self.backend.remove(email);
1856 Ok(())
1857 }
1858}
1859
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
///
/// Writes into a single preallocated buffer instead of formatting (and
/// allocating) one `String` per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    out
}
1867
1868fn generate_magic_code() -> String {
1870 use rand::Rng;
1871 let mut rng = rand::thread_rng();
1872 let code: u32 = rng.gen_range(0..1_000_000);
1873 format!("{:06}", code)
1874}
1875
1876fn generate_token() -> String {
1878 use rand::Rng;
1879 let mut rng = rand::thread_rng();
1880 let bytes: [u8; 32] = rng.gen();
1881 format!("pylon_{}", hex_encode(&bytes))
1882}
1883
1884use std::collections::HashMap;
1889use std::sync::Mutex;
1890
/// Persistence hook for sessions, keyed by session token.
pub trait SessionBackend: Send + Sync {
    // Every stored session; `SessionStore::with_backend` keeps only the unexpired ones.
    fn load_all(&self) -> Vec<Session>;
    // Insert or overwrite `session` (called on create, refresh and every in-place edit).
    fn save(&self, session: &Session);
    // Delete the session stored under `token`.
    fn remove(&self, token: &str);
}

/// Token -> session map with optional write-through persistence.
///
/// Reads are served from `sessions`; mutations are mirrored to `backend`
/// when one is configured.
pub struct SessionStore {
    // Session token -> session.
    sessions: Mutex<HashMap<String, Session>>,
    // `None` for a purely in-memory store (`SessionStore::new`).
    backend: Option<Box<dyn SessionBackend>>,
    // Lifetime passed to `Session::with_lifetime` when creating or refreshing.
    default_lifetime_secs: u64,
}

impl Default for SessionStore {
    fn default() -> Self {
        Self::new()
    }
}
1921
1922impl SessionStore {
1923 pub fn new() -> Self {
1924 Self {
1925 sessions: Mutex::new(HashMap::new()),
1926 backend: None,
1927 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1928 }
1929 }
1930
1931 pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
1934 self.default_lifetime_secs = lifetime_secs;
1935 self
1936 }
1937
1938 pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1942 let mut map = HashMap::new();
1943 for s in backend.load_all() {
1944 if !s.is_expired() {
1945 map.insert(s.token.clone(), s);
1946 }
1947 }
1948 Self {
1949 sessions: Mutex::new(map),
1950 backend: Some(backend),
1951 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1952 }
1953 }
1954
1955 pub fn create(&self, user_id: String) -> Session {
1959 let session = Session::with_lifetime(user_id, self.default_lifetime_secs);
1960 let mut sessions = self.sessions.lock().unwrap();
1961 sessions.insert(session.token.clone(), session.clone());
1962 if let Some(b) = &self.backend {
1963 b.save(&session);
1964 }
1965 session
1966 }
1967
1968 pub fn get(&self, token: &str) -> Option<Session> {
1970 let mut sessions = self.sessions.lock().unwrap();
1971 match sessions.get(token) {
1972 Some(s) if s.is_expired() => {
1973 sessions.remove(token);
1974 None
1975 }
1976 Some(s) => Some(s.clone()),
1977 None => None,
1978 }
1979 }
1980
1981 pub fn resolve(&self, token: Option<&str>) -> AuthContext {
1984 match token {
1985 Some(t) => match self.get(t) {
1986 Some(session) => session.to_auth_context(),
1987 None => AuthContext::anonymous(),
1988 },
1989 None => AuthContext::anonymous(),
1990 }
1991 }
1992
1993 pub fn refresh(&self, old_token: &str) -> Option<Session> {
1997 let mut sessions = self.sessions.lock().unwrap();
1998 let old = sessions.remove(old_token)?;
1999 if let Some(b) = &self.backend {
2000 b.remove(old_token);
2001 }
2002 if old.is_expired() {
2003 return None;
2004 }
2005 let mut new = Session::with_lifetime(old.user_id.clone(), self.default_lifetime_secs);
2011 new.device = old.device.clone();
2012 sessions.insert(new.token.clone(), new.clone());
2013 if let Some(b) = &self.backend {
2014 b.save(&new);
2015 }
2016 Some(new)
2017 }
2018
2019 pub fn list_all_unfiltered(&self) -> Vec<Session> {
2024 self.sessions
2025 .lock()
2026 .map(|m| m.values().cloned().collect())
2027 .unwrap_or_default()
2028 }
2029
2030 pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
2032 let sessions = self.sessions.lock().unwrap();
2033 sessions
2034 .values()
2035 .filter(|s| s.user_id == user_id && !s.is_expired())
2036 .cloned()
2037 .collect()
2038 }
2039
2040 pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
2042 let mut sessions = self.sessions.lock().unwrap();
2043 let tokens: Vec<String> = sessions
2044 .iter()
2045 .filter_map(|(t, s)| {
2046 if s.user_id == user_id {
2047 Some(t.clone())
2048 } else {
2049 None
2050 }
2051 })
2052 .collect();
2053 let n = tokens.len();
2054 for t in &tokens {
2055 sessions.remove(t);
2056 if let Some(b) = &self.backend {
2057 b.remove(t);
2058 }
2059 }
2060 n
2061 }
2062
2063 pub fn sweep_expired(&self) -> usize {
2065 let mut sessions = self.sessions.lock().unwrap();
2066 let expired: Vec<String> = sessions
2067 .iter()
2068 .filter_map(|(t, s)| {
2069 if s.is_expired() {
2070 Some(t.clone())
2071 } else {
2072 None
2073 }
2074 })
2075 .collect();
2076 let n = expired.len();
2077 for t in &expired {
2078 sessions.remove(t);
2079 if let Some(b) = &self.backend {
2080 b.remove(t);
2081 }
2082 }
2083 n
2084 }
2085
2086 pub fn set_device(&self, token: &str, device: String) -> bool {
2088 let mut sessions = self.sessions.lock().unwrap();
2089 if let Some(s) = sessions.get_mut(token) {
2090 s.device = Some(device);
2091 if let Some(b) = &self.backend {
2092 b.save(s);
2093 }
2094 true
2095 } else {
2096 false
2097 }
2098 }
2099
2100 pub fn create_guest(&self) -> Session {
2102 use rand::Rng;
2103 let mut rng = rand::thread_rng();
2104 let bytes: [u8; 16] = rng.gen();
2105 let guest_id = format!("guest_{}", hex_encode(&bytes));
2106 self.create(guest_id)
2107 }
2108
2109 pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
2111 let mut sessions = self.sessions.lock().unwrap();
2112 if let Some(session) = sessions.get_mut(token) {
2113 session.user_id = real_user_id;
2114 if let Some(b) = &self.backend {
2115 b.save(session);
2116 }
2117 true
2118 } else {
2119 false
2120 }
2121 }
2122
2123 pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
2128 let mut sessions = self.sessions.lock().unwrap();
2129 if let Some(session) = sessions.get_mut(token) {
2130 session.tenant_id = tenant_id;
2131 if let Some(b) = &self.backend {
2132 b.save(session);
2133 }
2134 true
2135 } else {
2136 false
2137 }
2138 }
2139
2140 pub fn revoke(&self, token: &str) -> bool {
2142 let mut sessions = self.sessions.lock().unwrap();
2143 let removed = sessions.remove(token).is_some();
2144 if removed {
2145 if let Some(b) = &self.backend {
2146 b.remove(token);
2147 }
2148 }
2149 removed
2150 }
2151}
2152
2153#[derive(Debug, Clone, PartialEq, Eq)]
2179pub struct Account {
2180 pub id: String,
2181 pub user_id: String,
2182 pub provider_id: String,
2185 pub account_id: String,
2188 pub access_token: Option<String>,
2189 pub refresh_token: Option<String>,
2190 pub id_token: Option<String>,
2191 pub access_token_expires_at: Option<u64>,
2194 pub refresh_token_expires_at: Option<u64>,
2198 pub scope: Option<String>,
2199 pub password: Option<String>,
2203 pub created_at: u64,
2205 pub updated_at: u64,
2207}
2208
2209impl Account {
2210 pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
2214 let now = now_secs();
2215 Self {
2216 id: generate_token(),
2217 user_id,
2218 provider_id: info.provider.clone(),
2219 account_id: info.provider_account_id.clone(),
2220 access_token: Some(tokens.access_token.clone()),
2221 refresh_token: tokens.refresh_token.clone(),
2222 id_token: tokens.id_token.clone(),
2223 access_token_expires_at: tokens.expires_at,
2224 refresh_token_expires_at: None,
2225 scope: tokens.scope.clone(),
2226 password: None,
2227 created_at: now,
2228 updated_at: now,
2229 }
2230 }
2231
2232 pub fn access_token_expired(&self) -> bool {
2237 match self.access_token_expires_at {
2238 Some(ts) => now_secs() >= ts,
2239 None => false,
2240 }
2241 }
2242}
2243
2244pub trait AccountBackend: Send + Sync {
2247 fn upsert(&self, account: &Account);
2251 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
2254 fn find_for_user(&self, user_id: &str) -> Vec<Account>;
2259 fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
2261 fn delete_for_user(&self, user_id: &str) -> usize {
2266 let accounts = self.find_for_user(user_id);
2267 let n = accounts.len();
2268 for a in accounts {
2269 self.unlink(&a.provider_id, &a.account_id);
2270 }
2271 n
2272 }
2273 fn list_all(&self) -> Vec<Account>;
2278}
2279
2280pub struct InMemoryAccountBackend {
2284 accounts: Mutex<HashMap<(String, String), Account>>,
2288}
2289
2290impl InMemoryAccountBackend {
2291 pub fn new() -> Self {
2292 Self {
2293 accounts: Mutex::new(HashMap::new()),
2294 }
2295 }
2296}
2297
2298impl Default for InMemoryAccountBackend {
2299 fn default() -> Self {
2300 Self::new()
2301 }
2302}
2303
2304impl AccountBackend for InMemoryAccountBackend {
2305 fn upsert(&self, account: &Account) {
2306 let key = (account.provider_id.clone(), account.account_id.clone());
2307 self.accounts.lock().unwrap().insert(key, account.clone());
2308 }
2309 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2310 self.accounts
2311 .lock()
2312 .unwrap()
2313 .get(&(provider_id.to_string(), account_id.to_string()))
2314 .cloned()
2315 }
2316 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2317 self.accounts
2318 .lock()
2319 .unwrap()
2320 .values()
2321 .filter(|a| a.user_id == user_id)
2322 .cloned()
2323 .collect()
2324 }
2325 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2326 self.accounts
2327 .lock()
2328 .unwrap()
2329 .remove(&(provider_id.to_string(), account_id.to_string()))
2330 .is_some()
2331 }
2332 fn list_all(&self) -> Vec<Account> {
2333 self.accounts.lock().unwrap().values().cloned().collect()
2334 }
2335}
2336
2337pub struct AccountStore {
2340 backend: Box<dyn AccountBackend>,
2341}
2342
2343impl Default for AccountStore {
2344 fn default() -> Self {
2345 Self::new()
2346 }
2347}
2348
2349impl AccountStore {
2350 pub fn new() -> Self {
2351 Self {
2352 backend: Box::new(InMemoryAccountBackend::new()),
2353 }
2354 }
2355 pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
2356 Self { backend }
2357 }
2358 pub fn upsert(&self, account: &Account) {
2359 self.backend.upsert(account);
2360 }
2361 pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2362 self.backend.find_by_provider(provider_id, account_id)
2363 }
2364 pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2365 self.backend.find_for_user(user_id)
2366 }
2367 pub fn delete_for_user(&self, user_id: &str) -> usize {
2368 self.backend.delete_for_user(user_id)
2369 }
2370
2371 pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2372 self.backend.unlink(provider_id, account_id)
2373 }
2374
2375 pub fn list_all_unfiltered(&self) -> Vec<Account> {
2389 self.backend.list_all()
2390 }
2391}
2392
2393#[cfg(test)]
2398mod tests {
2399 use super::*;
2400
2401 #[test]
2402 fn anonymous_context() {
2403 let ctx = AuthContext::anonymous();
2404 assert!(!ctx.is_authenticated());
2405 assert!(ctx.user_id.is_none());
2406 }
2407
2408 #[test]
2409 fn authenticated_context() {
2410 let ctx = AuthContext::authenticated("user-1".into());
2411 assert!(ctx.is_authenticated());
2412 assert_eq!(ctx.user_id, Some("user-1".into()));
2413 }
2414
2415 #[test]
2416 fn from_api_key_carries_scope_metadata() {
2417 let ctx = AuthContext::from_api_key(
2418 "user-1".into(),
2419 "key_abc".into(),
2420 Some("read,write".into()),
2421 );
2422 assert!(ctx.is_authenticated());
2423 assert!(ctx.is_api_key_auth());
2424 assert_eq!(ctx.user_id.as_deref(), Some("user-1"));
2425 assert_eq!(ctx.api_key_id.as_deref(), Some("key_abc"));
2426 assert_eq!(ctx.api_key_scopes.as_deref(), Some("read,write"));
2427 }
2428
2429 #[test]
2430 fn session_auth_is_not_api_key_auth() {
2431 let ctx = AuthContext::authenticated("user-1".into());
2432 assert!(!ctx.is_api_key_auth());
2433 assert!(ctx.api_key_id.is_none());
2434 }
2435
2436 #[test]
2437 fn auth_mode_public_allows_anonymous() {
2438 let mode = AuthMode::Public;
2439 assert!(mode.check(&AuthContext::anonymous()));
2440 assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2441 }
2442
2443 #[test]
2444 fn auth_mode_user_requires_authenticated() {
2445 let mode = AuthMode::User;
2446 assert!(!mode.check(&AuthContext::anonymous()));
2447 assert!(mode.check(&AuthContext::authenticated("user-1".into())));
2448 }
2449
2450 #[test]
2451 fn auth_mode_from_str() {
2452 assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
2453 assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
2454 assert_eq!(AuthMode::from_str("admin"), None);
2455 }
2456
2457 #[test]
2458 fn session_store_create_and_get() {
2459 let store = SessionStore::new();
2460 let session = store.create("user-1".into());
2461 assert!(!session.token.is_empty());
2462 assert!(session.token.starts_with("pylon_"));
2463
2464 let retrieved = store.get(&session.token).unwrap();
2465 assert_eq!(retrieved.user_id, "user-1");
2466 }
2467
2468 #[test]
2469 fn session_store_resolve() {
2470 let store = SessionStore::new();
2471 let session = store.create("user-1".into());
2472
2473 let ctx = store.resolve(Some(&session.token));
2474 assert!(ctx.is_authenticated());
2475 assert_eq!(ctx.user_id, Some("user-1".into()));
2476
2477 let anon = store.resolve(None);
2478 assert!(!anon.is_authenticated());
2479
2480 let bad = store.resolve(Some("invalid-token"));
2481 assert!(!bad.is_authenticated());
2482 }
2483
2484 #[test]
2485 fn session_store_revoke() {
2486 let store = SessionStore::new();
2487 let session = store.create("user-1".into());
2488
2489 assert!(store.revoke(&session.token));
2490 assert!(store.get(&session.token).is_none());
2491 assert!(!store.revoke(&session.token)); }
2493
2494 #[test]
2495 fn session_to_auth_context() {
2496 let session = Session::new("user-42".into());
2497 let ctx = session.to_auth_context();
2498 assert_eq!(ctx.user_id, Some("user-42".into()));
2499 }
2500
2501 #[test]
2504 fn admin_context() {
2505 let ctx = AuthContext::admin();
2506 assert!(ctx.is_admin);
2507 assert!(ctx.is_authenticated());
2508 }
2509
2510 #[test]
2511 fn anonymous_not_admin() {
2512 let ctx = AuthContext::anonymous();
2513 assert!(!ctx.is_admin);
2514 }
2515
2516 #[test]
2517 fn authenticated_not_admin() {
2518 let ctx = AuthContext::authenticated("user-1".into());
2519 assert!(!ctx.is_admin);
2520 }
2521
2522 #[test]
2525 fn magic_code_create_and_verify() {
2526 let store = MagicCodeStore::new();
2527 let code = store.create("test@example.com");
2528 assert_eq!(code.len(), 6);
2529 assert!(store.verify("test@example.com", &code));
2530 }
2531
2532 #[test]
2533 fn magic_code_wrong_code_rejected() {
2534 let store = MagicCodeStore::new();
2535 store.create("test@example.com");
2536 assert!(!store.verify("test@example.com", "000000"));
2537 }
2538
2539 #[test]
2540 fn magic_code_wrong_email_rejected() {
2541 let store = MagicCodeStore::new();
2542 let code = store.create("test@example.com");
2543 assert!(!store.verify("other@example.com", &code));
2544 }
2545
2546 #[test]
2547 fn magic_code_consumed_after_verify() {
2548 let store = MagicCodeStore::new();
2549 let code = store.create("test@example.com");
2550 assert!(store.verify("test@example.com", &code));
2551 assert!(!store.verify("test@example.com", &code));
2553 }
2554
2555 #[test]
2556 fn magic_code_different_emails_independent() {
2557 let store = MagicCodeStore::new();
2558 let code1 = store.create("alice@example.com");
2559 let code2 = store.create("bob@example.com");
2560 assert!(store.verify("alice@example.com", &code1));
2562 assert!(store.verify("bob@example.com", &code2));
2563 }
2564
2565 #[test]
2568 fn constant_time_eq_equal() {
2569 assert!(constant_time_eq(b"hello", b"hello"));
2570 assert!(constant_time_eq(b"", b""));
2571 }
2572
2573 #[test]
2574 fn constant_time_eq_not_equal() {
2575 assert!(!constant_time_eq(b"hello", b"world"));
2576 assert!(!constant_time_eq(b"hello", b"hell"));
2577 assert!(!constant_time_eq(b"a", b"b"));
2578 }
2579
2580 #[test]
2583 fn generated_tokens_are_unique() {
2584 let t1 = generate_token();
2585 let t2 = generate_token();
2586 assert_ne!(t1, t2);
2587 assert!(t1.starts_with("pylon_"));
2588 assert!(t2.starts_with("pylon_"));
2589 assert_eq!(t1.len(), 6 + 64);
2591 }
2592
2593 #[test]
2596 fn oauth_registry_empty() {
2597 let reg = OAuthRegistry::new();
2598 assert!(reg.get("google").is_none());
2599 }
2600
2601 #[test]
2602 fn oauth_registry_register_and_get() {
2603 let mut reg = OAuthRegistry::new();
2604 reg.register(OAuthConfig {
2605 provider: "google".into(),
2606 client_id: "test-id".into(),
2607 client_secret: "test-secret".into(),
2608 redirect_uri: "http://localhost/callback".into(),
2609 ..Default::default()
2610 });
2611 let config = reg.get("google").unwrap();
2612 assert_eq!(config.client_id, "test-id");
2613 assert!(config.auth_url().contains("accounts.google.com"));
2614 }
2615
2616 #[test]
2624 fn every_builtin_provider_routes_through_oauth_config() {
2625 for spec in provider::builtin::all() {
2626 let cfg = OAuthConfig {
2627 provider: spec.id.into(),
2628 client_id: "cid".into(),
2629 client_secret: "csecret".into(),
2630 redirect_uri: "https://app/cb".into(),
2631 tenant: if spec.id == "microsoft" {
2632 Some("contoso".into())
2633 } else {
2634 None
2635 },
2636 apple: if spec.id == "apple" {
2637 Some(provider::AppleConfig {
2638 team_id: "T".into(),
2639 key_id: "K".into(),
2640 private_key_pem: "no".into(),
2641 })
2642 } else {
2643 None
2644 },
2645 ..Default::default()
2646 };
2647 let auth = cfg.auth_url();
2648 assert!(!auth.is_empty(), "{}: empty auth_url", spec.id);
2649 let expected_param = format!("{}=cid", spec.client_id_param);
2651 assert!(
2652 auth.contains(&expected_param),
2653 "{}: missing {}; got auth_url: {}",
2654 spec.id,
2655 expected_param,
2656 auth,
2657 );
2658 assert!(!cfg.token_url().is_empty(), "{}: empty token_url", spec.id);
2659 if spec.id == "apple" {
2661 assert!(
2662 auth.contains("response_mode=form_post"),
2663 "apple auth_url must include response_mode=form_post; got {auth}"
2664 );
2665 }
2666 }
2667 }
2668
2669 #[test]
2672 fn microsoft_tenant_placeholder_resolves() {
2673 let cfg = OAuthConfig {
2674 provider: "microsoft".into(),
2675 client_id: "id".into(),
2676 client_secret: "secret".into(),
2677 redirect_uri: "https://app/cb".into(),
2678 tenant: Some("contoso.onmicrosoft.com".into()),
2679 ..Default::default()
2680 };
2681 assert!(cfg.auth_url().contains("/contoso.onmicrosoft.com/"));
2682 assert!(cfg.token_url().contains("/contoso.onmicrosoft.com/"));
2683 }
2684
2685 #[test]
2687 fn microsoft_default_tenant_common() {
2688 let cfg = OAuthConfig {
2689 provider: "microsoft".into(),
2690 client_id: "id".into(),
2691 client_secret: "secret".into(),
2692 redirect_uri: "https://app/cb".into(),
2693 ..Default::default()
2694 };
2695 assert!(cfg.auth_url().contains("/common/"));
2696 assert!(cfg.token_url().contains("/common/"));
2697 }
2698
2699 #[test]
2702 fn scopes_override_replaces_spec_default() {
2703 let cfg = OAuthConfig {
2704 provider: "github".into(),
2705 client_id: "id".into(),
2706 client_secret: "secret".into(),
2707 redirect_uri: "https://app/cb".into(),
2708 scopes_override: Some("repo user:email".into()),
2709 ..Default::default()
2710 };
2711 let auth = cfg.auth_url();
2712 assert!(auth.contains("scope=repo%20user%3Aemail"), "got: {auth}");
2714 }
2715
2716 #[test]
2721 fn apple_exchange_requires_apple_config() {
2722 let cfg = OAuthConfig {
2723 provider: "apple".into(),
2724 client_id: "com.example.app".into(),
2725 client_secret: String::new(),
2726 redirect_uri: "https://app/cb".into(),
2727 apple: None, ..Default::default()
2729 };
2730 let err = cfg.exchange_code_full("x").unwrap_err();
2731 assert!(err.contains("apple provider requires"), "got: {err}");
2732 }
2733
2734 #[test]
2739 fn oidc_issuer_uses_discovered_endpoints() {
2740 let issuer = "https://acme.test.invalid";
2741 provider::oidc_cache::insert_for_test(
2742 issuer,
2743 provider::DiscoveredSpec {
2744 auth_url: "https://acme.test.invalid/authorize".into(),
2745 token_url: "https://acme.test.invalid/oauth/token".into(),
2746 userinfo_url: Some("https://acme.test.invalid/userinfo".into()),
2747 scopes: "openid email profile".into(),
2748 userinfo_parser: provider::UserinfoParser::Oidc,
2749 token_exchange: provider::TokenExchangeShape::Standard,
2750 },
2751 );
2752 let cfg = OAuthConfig {
2753 provider: "auth0".into(), client_id: "id".into(),
2755 client_secret: "secret".into(),
2756 redirect_uri: "https://app/cb".into(),
2757 oidc_issuer: Some(issuer.into()),
2758 ..Default::default()
2759 };
2760 assert!(cfg.auth_url().starts_with("https://acme.test.invalid/authorize?"));
2761 assert_eq!(cfg.token_url(), "https://acme.test.invalid/oauth/token");
2762 assert_eq!(cfg.userinfo_url(), "https://acme.test.invalid/userinfo");
2763 }
2764
2765 #[test]
2771 fn apple_auth_url_includes_form_post() {
2772 let cfg = OAuthConfig {
2773 provider: "apple".into(),
2774 client_id: "com.example.app".into(),
2775 client_secret: String::new(),
2776 redirect_uri: "https://app/cb".into(),
2777 apple: Some(provider::AppleConfig {
2778 team_id: "T".into(),
2779 key_id: "K".into(),
2780 private_key_pem: "no".into(),
2781 }),
2782 ..Default::default()
2783 };
2784 let auth = cfg.auth_url();
2785 assert!(auth.contains("response_mode=form_post"), "got: {auth}");
2786 assert_eq!(cfg.userinfo_url(), "");
2788 }
2789
2790 #[test]
2795 fn apple_id_token_decode_extracts_identity() {
2796 let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"{\"alg\":\"none\"}");
2798 use base64::Engine;
2799 let claims = serde_json::json!({
2800 "iss": "https://appleid.apple.com",
2801 "sub": "001234.abc.def",
2802 "aud": "com.example.app",
2803 "email": "user@privaterelay.appleid.com",
2804 "email_verified": "true",
2805 });
2806 let claims_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
2807 .encode(claims.to_string().as_bytes());
2808 let id_token = format!("{header}.{claims_b64}.signature_ignored");
2809
2810 let cfg = OAuthConfig {
2811 provider: "apple".into(),
2812 client_id: "com.example.app".into(),
2813 client_secret: String::new(),
2814 redirect_uri: "https://app/cb".into(),
2815 apple: Some(provider::AppleConfig {
2816 team_id: "T".into(),
2817 key_id: "K".into(),
2818 private_key_pem: "no".into(),
2819 }),
2820 ..Default::default()
2821 };
2822 let info = cfg
2823 .fetch_userinfo_with_id_token("ignored", Some(&id_token))
2824 .expect("apple id_token decode");
2825 assert_eq!(info.provider_account_id, "001234.abc.def");
2826 assert_eq!(info.email, "user@privaterelay.appleid.com");
2827
2828 let err = cfg.fetch_userinfo_full("token").unwrap_err();
2831 assert!(err.contains("apple login requires"), "got: {err}");
2832 }
2833
#[test]
// Twitter requires PKCE: `auth_url_with_pkce` must return a verifier and put an
// S256 `code_challenge` in the authorize URL. The verifier is checked against
// RFC 7636 §4.1 (43..=128 chars from the unreserved set
// `[A-Za-z0-9-._~]`). It must never appear in the URL itself, because with
// S256 only its hash is sent. Google does not get PKCE.
fn twitter_auth_url_includes_pkce() {
    let cfg = OAuthConfig {
        provider: "twitter".into(),
        client_id: "tw_client".into(),
        client_secret: "tw_secret".into(),
        redirect_uri: "https://app/cb".into(),
        ..Default::default()
    };
    let (url, verifier) = cfg.auth_url_with_pkce("state123").expect("twitter pkce");
    let v = verifier.expect("twitter must produce verifier");
    assert!(v.len() >= 43, "PKCE verifier must be 43+ chars: got {v}");
    assert!(v.len() <= 128, "PKCE verifier must be at most 128 chars: got {v}");
    assert!(
        v.bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~')),
        "PKCE verifier must use only unreserved characters: got {v}"
    );
    assert!(url.contains("code_challenge="), "got: {url}");
    assert!(url.contains("code_challenge_method=S256"), "got: {url}");
    // The verifier is a secret held back until the token exchange.
    assert!(!url.contains(v.as_str()), "verifier leaked into URL: {url}");

    let google = OAuthConfig {
        provider: "google".into(),
        client_id: "g".into(),
        client_secret: "g".into(),
        redirect_uri: "https://app/cb".into(),
        ..Default::default()
    };
    let (gurl, gverifier) = google.auth_url_with_pkce("st").expect("google");
    assert!(gverifier.is_none(), "google should not add PKCE");
    assert!(!gurl.contains("code_challenge"), "got: {gurl}");
}
2864
#[test]
// TikTok's authorize URL names the client id `client_key`. Scopes must be
// joined with commas (`%2C`), never spaces (`%20`).
fn tiktok_uses_client_key_and_comma_scopes() {
    let cfg = OAuthConfig {
        provider: "tiktok".into(),
        client_id: "tk_client".into(),
        client_secret: "tk_secret".into(),
        redirect_uri: "https://app/cb".into(),
        scopes_override: Some("user.info.basic video.list".into()),
        ..Default::default()
    };
    let auth = cfg.auth_url();

    let comma_joined = "user.info.basic%2Cvideo.list";
    let space_joined = "user.info.basic%20video.list";
    assert!(auth.contains("client_key=tk_client"), "got: {auth}");
    assert!(auth.contains(comma_joined), "got: {auth}");
    assert!(!auth.contains(space_joined), "got: {auth}");
}
2884
#[test]
// `url_encode` must percent-escape `+`, `/` and `=` (as uppercase `%XX`),
// so an authorization code containing them is sent intact.
fn token_exchange_url_encodes_code() {
    let encoded = url_encode("code+with/special=chars");
    for (reserved, escape) in [('+', "%2B"), ('/', "%2F"), ('=', "%3D")] {
        assert!(!encoded.contains(reserved));
        assert!(encoded.contains(escape));
    }
}
2904
#[test]
// In form-encoded error bodies, secret-bearing parameters are masked as
// `***`. Other context (the error code, percent-encoded hints) stays untouched.
fn sanitize_token_error_redacts_secrets() {
    let scrubbed = sanitize_token_error(
        "HTTP 400: error=invalid_grant&client_secret=sk_real_secret_value&code_verifier=verifierxyz&hint=check%20your%20code"
            .into(),
    );
    for secret in ["sk_real_secret_value", "verifierxyz"] {
        assert!(!scrubbed.contains(secret));
    }
    for expected in [
        "client_secret=***",
        "code_verifier=***",
        "invalid_grant",
        "hint=check%20your%20code",
    ] {
        assert!(scrubbed.contains(expected));
    }
}
2920
#[test]
// In JSON error bodies, the string values of secret-bearing keys become
// "***". The non-secret `error` value is preserved.
fn sanitize_token_error_redacts_json_secrets() {
    let raw = r#"HTTP 400: {"error":"invalid_grant","client_secret":"sk_jsonleak","refresh_token":"rt_abcxyz","id_token":"ey.payload.sig"}"#;
    let scrubbed = sanitize_token_error(raw.into());

    let leaked_values = ["sk_jsonleak", "rt_abcxyz", "ey.payload.sig"];
    for value in leaked_values {
        assert!(!scrubbed.contains(value), "got: {scrubbed}");
    }
    let masked_fields = [
        r#""client_secret":"***""#,
        r#""refresh_token":"***""#,
        r#""id_token":"***""#,
    ];
    for field in masked_fields {
        assert!(scrubbed.contains(field), "got: {scrubbed}");
    }
    assert!(scrubbed.contains("invalid_grant"));
}
2935
#[test]
// Non-ASCII text in the error (here an emoji) must survive redaction
// unchanged. The secret that follows it must still be masked.
fn sanitize_token_error_handles_utf8() {
    let scrubbed = sanitize_token_error(
        "HTTP 400: ⚠️ provider says the secret is wrong: client_secret=sk_x".into(),
    );
    assert!(scrubbed.contains("⚠️"), "non-ASCII chars must survive: {scrubbed}");
    assert!(!scrubbed.contains("sk_x"));
    assert!(scrubbed.contains("client_secret=***"));
}
2948
#[test]
// The token-exchange shape follows `token_endpoint_auth_methods_supported`:
// - advertising only `client_secret_post` selects `Standard`;
// - omitting the field selects `BasicAuth`, matching the OpenID Connect
//   Discovery default of `client_secret_basic`.
fn oidc_discovery_picks_token_auth_method() {
    let post_only_doc = r#"{
        "issuer": "https://acme.test/",
        "authorization_endpoint": "https://acme.test/auth",
        "token_endpoint": "https://acme.test/token",
        "token_endpoint_auth_methods_supported": ["client_secret_post"]
    }"#;
    let no_methods_doc = r#"{
        "issuer": "https://acme.test/",
        "authorization_endpoint": "https://acme.test/auth",
        "token_endpoint": "https://acme.test/token"
    }"#;

    let post_only = provider::OidcDiscoveryDoc::parse(post_only_doc)
        .unwrap()
        .into_spec();
    assert!(matches!(
        post_only.token_exchange,
        provider::TokenExchangeShape::Standard
    ));

    let defaulted = provider::OidcDiscoveryDoc::parse(no_methods_doc)
        .unwrap()
        .into_spec();
    assert!(matches!(
        defaulted.token_exchange,
        provider::TokenExchangeShape::BasicAuth
    ));
}
2981
#[test]
// A discovery document with no `token_endpoint` must be rejected, and the
// error message must name the missing field.
fn oidc_discovery_rejects_incomplete_doc() {
    let without_token_endpoint = r#"{
        "issuer": "https://acme.test/",
        "authorization_endpoint": "https://acme.test/auth"
    }"#;
    let err = provider::OidcDiscoveryDoc::parse(without_token_endpoint).unwrap_err();
    assert!(err.contains("token_endpoint"), "got: {err}");
}
2994
#[test]
// `OAuthRegistry::from_env` must register Discord when its client id/secret
// variables are set. Environment variables are process-global, so a drop
// guard removes them. Cleanup therefore runs even when an `expect`/`assert`
// below panics, instead of leaking the variables into later tests in this
// process.
// NOTE(review): libtest runs tests on parallel threads by default, so a
// sibling test reading these variables at the same moment could still see
// them. Confirm that no other test relies on their absence.
fn from_env_picks_up_discord() {
    // Removes each listed variable when dropped, including while unwinding.
    struct EnvCleanup([&'static str; 2]);
    impl Drop for EnvCleanup {
        fn drop(&mut self) {
            for &key in self.0.iter() {
                std::env::remove_var(key);
            }
        }
    }

    const KEY_ID: &str = "PYLON_OAUTH_DISCORD_CLIENT_ID";
    const KEY_SECRET: &str = "PYLON_OAUTH_DISCORD_CLIENT_SECRET";

    let _cleanup = EnvCleanup([KEY_ID, KEY_SECRET]);
    std::env::set_var(KEY_ID, "discord-test-id");
    std::env::set_var(KEY_SECRET, "discord-test-secret");

    let reg = OAuthRegistry::from_env();
    let discord = reg.get("discord").expect("discord registered");
    assert_eq!(discord.client_id, "discord-test-id");
    assert!(discord.auth_url().contains("discord.com"));
}
3018
#[test]
// A new guest session gets a `guest_`-prefixed user id and a non-empty token.
// Resolving that token yields a context that counts as authenticated and has
// a `guest_`-prefixed user id.
fn guest_session() {
    let store = SessionStore::new();
    let guest = store.create_guest();
    assert!(guest.user_id.starts_with("guest_"));
    assert!(!guest.token.is_empty());

    let resolved = store.resolve(Some(&guest.token));
    assert!(resolved.is_authenticated());
    let resolved_id = resolved.user_id.unwrap();
    assert!(resolved_id.starts_with("guest_"));
}
3032
#[test]
// Upgrading a guest session must succeed, after which the same token resolves
// to the real user id.
fn upgrade_guest_to_real_user() {
    let store = SessionStore::new();
    let guest = store.create_guest();
    assert!(guest.user_id.starts_with("guest_"));

    assert!(store.upgrade(&guest.token, "real-user-123".into()));

    let ctx = store.resolve(Some(&guest.token));
    assert_eq!(ctx.user_id.as_deref(), Some("real-user-123"));
}
3045
#[test]
// Upgrading must be refused for a token the store never issued.
fn upgrade_invalid_token_fails() {
    let store = SessionStore::new();
    assert!(!store.upgrade("nonexistent-token", "user".into()));
}
3052
#[test]
// `AuthContext::guest` keeps the given id and flags the context as a guest.
// The context is neither authenticated nor admin. It fails the `User` check
// and passes the `Public` one.
fn guest_context() {
    let ctx = AuthContext::guest("guest_123".into());
    assert_eq!(ctx.user_id.as_deref(), Some("guest_123"));
    assert!(ctx.is_guest);
    assert!(!ctx.is_admin);
    assert!(!ctx.is_authenticated());
    assert!(AuthMode::Public.check(&ctx));
    assert!(!AuthMode::User.check(&ctx));
}
3065
#[test]
// Known providers map to fixed token endpoints. An unrecognised provider
// yields an empty token URL and an empty authorize URL.
fn oauth_token_urls() {
    let config_for = |name: &'static str| OAuthConfig {
        provider: name.into(),
        client_id: "x".into(),
        client_secret: "x".into(),
        redirect_uri: "x".into(),
        ..Default::default()
    };

    assert_eq!(
        config_for("google").token_url(),
        "https://oauth2.googleapis.com/token"
    );
    assert_eq!(
        config_for("github").token_url(),
        "https://github.com/login/oauth/access_token"
    );

    let unknown = config_for("unknown");
    assert_eq!(unknown.token_url(), "");
    assert!(unknown.auth_url().is_empty());
}
3097
#[test]
// The GitHub authorize URL must point at github.com and carry the client id.
fn oauth_auth_url_github() {
    let config = OAuthConfig {
        provider: "github".into(),
        client_id: "gh-id".into(),
        client_secret: "gh-secret".into(),
        redirect_uri: "http://localhost/cb".into(),
        ..Default::default()
    };
    let url = config.auth_url();
    for needle in ["github.com", "gh-id"] {
        assert!(url.contains(needle));
    }
}
3110
#[test]
// `auth_url_with_state` must include the caller's state as `&state=<value>`.
fn oauth_auth_url_with_state() {
    const STATE: &str = "random_state_123";
    let config = OAuthConfig {
        provider: "google".into(),
        client_id: "test-id".into(),
        client_secret: "test-secret".into(),
        redirect_uri: "http://localhost/cb".into(),
        ..Default::default()
    };
    let url = config.auth_url_with_state(STATE);
    let expected_param = format!("&state={STATE}");
    assert!(url.contains(expected_param.as_str()));
}
3123
#[test]
// A state token validates once and returns the callback URLs it was created
// with. A second validation of the same token must fail (no replay).
fn oauth_state_store_create_and_validate() {
    let store = OAuthStateStore::new();
    let token = store.create("google", "https://app/cb", "https://app/login");

    let record = store.validate(&token, "google").expect("valid first time");
    assert_eq!(record.callback_url, "https://app/cb");
    assert_eq!(record.error_callback_url, "https://app/login");

    let replay = store.validate(&token, "google");
    assert!(replay.is_none());
}
3134
#[test]
// A state token issued for one provider must not validate for another.
fn oauth_state_store_wrong_provider_rejected() {
    let store = OAuthStateStore::new();
    let google_state = store.create("google", "https://app/cb", "https://app/cb");
    let as_github = store.validate(&google_state, "github");
    assert!(as_github.is_none());
}
3141
#[test]
// A state token the store never issued must not validate.
fn oauth_state_store_invalid_state_rejected() {
    let empty_store = OAuthStateStore::new();
    let outcome = empty_store.validate("nonexistent", "google");
    assert!(outcome.is_none());
}
3147
#[test]
// `validate_trusted_redirect` accepts URLs on a trusted origin, with or
// without a path or query. It rejects:
// - the same host on another port (`NotTrusted`);
// - non-HTTP schemes such as `javascript:` (`NotHttp`);
// - the empty string (`Empty`).
// It must also resist prefix confusion, where the trusted origin text is a
// prefix of a URL that actually targets a different authority.
fn validate_trusted_redirect_basics() {
    let trusted = vec!["http://localhost:3000".to_string()];

    for accepted in [
        "http://localhost:3000/dashboard",
        "http://localhost:3000",
        "http://localhost:3000/x?y=1",
    ] {
        assert!(
            validate_trusted_redirect(accepted, &trusted).is_ok(),
            "should accept {accepted}"
        );
    }

    // Same host, different port: a different origin.
    assert!(matches!(
        validate_trusted_redirect("http://localhost:4321/dashboard", &trusted),
        Err(TrustedOriginError::NotTrusted { .. })
    ));
    assert!(matches!(
        validate_trusted_redirect("javascript:alert(1)", &trusted),
        Err(TrustedOriginError::NotHttp)
    ));
    assert!(matches!(
        validate_trusted_redirect("", &trusted),
        Err(TrustedOriginError::Empty)
    ));

    // Open-redirect bypasses that a plain `starts_with(origin)` check would
    // accept:
    // - the userinfo trick, where the real host is `evil.test`;
    // - a suffixed authority that is not `localhost:3000`.
    for hostile in [
        "http://localhost:3000@evil.test/",
        "http://localhost:3000.evil.test/",
    ] {
        assert!(
            validate_trusted_redirect(hostile, &trusted).is_err(),
            "must reject {hostile}"
        );
    }
}
3171}