1pub mod api_key;
2pub mod apple_jwt;
3pub mod audit;
4pub mod captcha;
5pub mod cookie;
6pub mod device;
7pub mod email;
8pub mod email_templates;
9pub mod jwt;
10pub mod oidc_provider;
11pub mod org;
12pub mod password;
13pub mod phone;
14pub mod provider;
15pub mod rate_limit;
16pub mod scim;
17pub mod siwe;
18pub mod stripe;
19pub mod totp;
20pub mod verification;
21pub mod webauthn;
22
23pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
24
25use serde::{Deserialize, Serialize};
26
/// Identity and authorization facts for the current request.
///
/// Built with the constructors in `impl AuthContext` (`anonymous`, `authenticated`,
/// `guest`, `admin`, `from_api_key`) and refined with the `with_*` builders.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
    /// User (or guest) id; `None` for anonymous requests.
    pub user_id: Option<String>,
    /// Admins pass every `has_role` / `has_any_role` check.
    pub is_admin: bool,
    /// Guest contexts carry a `user_id` but are not `is_authenticated()`.
    /// Omitted from serialized output when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub is_guest: bool,
    /// Role names consulted by `has_role`.
    pub roles: Vec<String>,
    /// Tenant scope, if any. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// Id of the API key used to authenticate; `Some` exactly when `is_api_key_auth()`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key_id: Option<String>,
    /// Scopes attached to the API key, kept as a raw string
    /// (format not defined in this file — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key_scopes: Option<String>,
}
72
/// serde `skip_serializing_if` predicate: true when the flag is unset.
fn is_false(b: &bool) -> bool {
    !*b
}
76
77impl AuthContext {
78 pub fn anonymous() -> Self {
80 Self {
81 user_id: None,
82 is_admin: false,
83 is_guest: false,
84 roles: Vec::new(),
85 tenant_id: None,
86 api_key_id: None,
87 api_key_scopes: None,
88 }
89 }
90
91 pub fn authenticated(user_id: String) -> Self {
93 Self {
94 user_id: Some(user_id),
95 is_admin: false,
96 is_guest: false,
97 roles: Vec::new(),
98 tenant_id: None,
99 api_key_id: None,
100 api_key_scopes: None,
101 }
102 }
103
104 pub fn from_api_key(user_id: String, key_id: String, scopes: Option<String>) -> Self {
107 Self {
108 user_id: Some(user_id),
109 is_admin: false,
110 is_guest: false,
111 roles: Vec::new(),
112 tenant_id: None,
113 api_key_id: Some(key_id),
114 api_key_scopes: scopes,
115 }
116 }
117
118 pub fn is_api_key_auth(&self) -> bool {
121 self.api_key_id.is_some()
122 }
123
124 pub fn guest(guest_id: String) -> Self {
129 Self {
130 user_id: Some(guest_id),
131 is_admin: false,
132 is_guest: true,
133 roles: Vec::new(),
134 tenant_id: None,
135 api_key_id: None,
136 api_key_scopes: None,
137 }
138 }
139
140 pub fn admin() -> Self {
142 Self {
143 user_id: Some("__admin__".into()),
144 is_admin: true,
145 is_guest: false,
146 roles: vec!["admin".into()],
147 tenant_id: None,
148 api_key_id: None,
149 api_key_scopes: None,
150 }
151 }
152
153 pub fn user(user_id: String) -> Self {
155 Self::authenticated(user_id)
156 }
157
158 pub fn tenant_id(&self) -> Option<&str> {
160 self.tenant_id.as_deref()
161 }
162
163 pub fn with_tenant(mut self, tenant_id: String) -> Self {
165 self.tenant_id = Some(tenant_id);
166 self
167 }
168
169 pub fn is_authenticated(&self) -> bool {
173 self.user_id.is_some() && !self.is_guest
174 }
175
176 pub fn has_role(&self, role: &str) -> bool {
178 self.is_admin || self.roles.iter().any(|r| r == role)
179 }
180
181 pub fn has_any_role(&self, roles: &[&str]) -> bool {
183 self.is_admin || roles.iter().any(|r| self.has_role(r))
184 }
185
186 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
188 self.roles = roles;
189 self
190 }
191}
192
/// Compares two byte strings without short-circuiting on the first difference.
///
/// Slices of different length compare unequal immediately (the length is not
/// treated as secret); equal-length inputs are always scanned in full.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
212
/// Access policy for an endpoint, parsed from `"public"` / `"user"` by `AuthMode::from_str`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// Every request passes, including anonymous ones.
    Public,
    /// Requires `AuthContext::is_authenticated()`, so guests are rejected.
    User,
}
225
226impl AuthMode {
227 #[allow(clippy::should_implement_trait)]
229 pub fn from_str(s: &str) -> Option<Self> {
230 match s {
231 "public" => Some(AuthMode::Public),
232 "user" => Some(AuthMode::User),
233 _ => None,
234 }
235 }
236
237 pub fn check(&self, ctx: &AuthContext) -> bool {
239 match self {
240 AuthMode::Public => true,
241 AuthMode::User => ctx.is_authenticated(),
242 }
243 }
244}
245
/// A persisted login session.
///
/// NOTE(review): the derived `Debug` prints `token` verbatim; avoid `{:?}` in logs.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
    /// Session token, produced by `generate_token` in the constructors.
    pub token: String,
    /// Owner of the session.
    pub user_id: String,
    /// Expiry in unix seconds; `0` means "never expires" (see `is_expired`).
    /// Deserializes to `0` when absent.
    #[serde(default)]
    pub expires_at: u64,
    /// Optional device label — its contents are not defined in this file (TODO confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Creation time in unix seconds; deserializes to `0` when absent.
    #[serde(default)]
    pub created_at: u64,
    /// Tenant scope, copied into the `AuthContext` by `to_auth_context`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
}
270
271impl Session {
272 pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
274
275 pub fn new(user_id: String) -> Self {
277 let now = now_secs();
278 Self {
279 token: generate_token(),
280 user_id,
281 expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
282 device: None,
283 created_at: now,
284 tenant_id: None,
285 }
286 }
287
288 pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
290 let now = now_secs();
291 Self {
292 token: generate_token(),
293 user_id,
294 expires_at: if lifetime_secs == 0 {
295 0
296 } else {
297 now.saturating_add(lifetime_secs)
298 },
299 device: None,
300 created_at: now,
301 tenant_id: None,
302 }
303 }
304
305 pub fn to_auth_context(&self) -> AuthContext {
308 let ctx = AuthContext::authenticated(self.user_id.clone());
309 match &self.tenant_id {
310 Some(t) => ctx.with_tenant(t.clone()),
311 None => ctx,
312 }
313 }
314
315 pub fn is_expired(&self) -> bool {
319 self.expires_at != 0 && now_secs() >= self.expires_at
320 }
321}
322
/// Current unix time in whole seconds; `0` if the clock is before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
330
331#[derive(Debug, Clone, Default, Serialize, Deserialize)]
336pub struct OAuthConfig {
337 pub provider: String,
338 pub client_id: String,
339 pub client_secret: String,
340 pub redirect_uri: String,
341 #[serde(default, skip_serializing_if = "Option::is_none")]
346 pub scopes_override: Option<String>,
347 #[serde(default, skip_serializing_if = "Option::is_none")]
351 pub tenant: Option<String>,
352 #[serde(default, skip_serializing_if = "Option::is_none")]
355 pub apple: Option<provider::AppleConfig>,
356 #[serde(default, skip_serializing_if = "Option::is_none")]
362 pub oidc_issuer: Option<String>,
363}
364
365impl OAuthConfig {
366 fn resolved_spec(&self) -> Result<provider::ResolvedSpec, String> {
371 if let Some(issuer) = self.oidc_issuer.as_deref() {
372 return provider::oidc_cache::resolve(issuer);
373 }
374 provider::find_spec(&self.provider)
375 .map(provider::ResolvedSpec::Static)
376 .ok_or_else(|| format!("unknown OAuth provider: {}", self.provider))
377 }
378
379 fn provider_cfg(&self) -> provider::ProviderConfig {
382 provider::ProviderConfig {
383 provider: self.provider.clone(),
384 client_id: self.client_id.clone(),
385 client_secret: self.client_secret.clone(),
386 redirect_uri: self.redirect_uri.clone(),
387 scopes_override: self.scopes_override.clone(),
388 tenant: self.tenant.clone(),
389 apple: self.apple.clone(),
390 oidc_issuer: self.oidc_issuer.clone(),
391 }
392 }
393
394 pub fn auth_url(&self) -> String {
404 match self.build_auth_url(None) {
405 Ok(u) => u,
406 Err(_) => String::new(),
407 }
408 }
409
410 pub fn auth_url_with_state(&self, state: &str) -> String {
412 let base = self.auth_url();
413 if base.is_empty() {
414 return base;
415 }
416 format!("{}&state={}", base, url_encode(state))
417 }
418
419 pub fn auth_url_with_pkce(&self, state: &str) -> Result<(String, Option<String>), String> {
424 let spec = self.resolved_spec()?;
425 let pkce = if spec.requires_pkce() {
426 Some(generate_pkce())
427 } else {
428 None
429 };
430 let challenge = pkce.as_ref().map(|p| p.code_challenge.as_str());
431 let mut url = self.build_auth_url(challenge)?;
432 if !state.is_empty() {
433 url.push_str(&format!("&state={}", url_encode(state)));
434 }
435 Ok((url, pkce.map(|p| p.code_verifier)))
436 }
437
438 fn build_auth_url(&self, pkce_challenge: Option<&str>) -> Result<String, String> {
439 let spec = self.resolved_spec()?;
440 let cfg = self.provider_cfg();
441 let auth = provider::resolve_endpoint(spec.auth_url(), &cfg);
442 if auth.is_empty() {
443 return Err(format!(
444 "provider {} has no authorization endpoint",
445 self.provider
446 ));
447 }
448 let scopes_default = spec.scopes().to_string();
449 let scopes_raw = self.scopes_override.as_deref().unwrap_or(&scopes_default);
450 let scopes_joined = scopes_raw
454 .split_whitespace()
455 .collect::<Vec<_>>()
456 .join(spec.scope_separator());
457
458 let mut url = format!(
459 "{auth}?{cid_param}={cid}&redirect_uri={ruri}&response_type=code&scope={scope}",
460 cid_param = spec.client_id_param(),
461 cid = url_encode(&self.client_id),
462 ruri = url_encode(&self.redirect_uri),
463 scope = url_encode(&scopes_joined),
464 );
465 if !spec.auth_query_extra().is_empty() {
466 url.push('&');
467 url.push_str(spec.auth_query_extra());
468 }
469 if let Some(challenge) = pkce_challenge {
470 url.push_str("&code_challenge=");
471 url.push_str(challenge);
472 url.push_str("&code_challenge_method=S256");
473 }
474 Ok(url)
475 }
476
477 pub fn token_url(&self) -> String {
479 match self.resolved_spec() {
480 Ok(spec) => provider::resolve_endpoint(spec.token_url(), &self.provider_cfg()),
481 Err(_) => String::new(),
482 }
483 }
484
485 pub fn userinfo_url(&self) -> String {
487 match self.resolved_spec() {
488 Ok(spec) => match spec.userinfo_url() {
489 Some(u) => provider::resolve_endpoint(u, &self.provider_cfg()),
490 None => String::new(),
491 },
492 Err(_) => String::new(),
493 }
494 }
495
496 pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
502 self.exchange_code_full_pkce(code, None)
503 }
504
505 pub fn exchange_code_full_pkce(
506 &self,
507 code: &str,
508 code_verifier: Option<&str>,
509 ) -> Result<TokenSet, String> {
510 let spec = self.resolved_spec()?;
511 let cfg = self.provider_cfg();
512 let token_url = provider::resolve_endpoint(spec.token_url(), &cfg);
513 let pkce_field = code_verifier
514 .map(|v| format!("&code_verifier={}", url_encode(v)))
515 .unwrap_or_default();
516
517 let out = match spec.token_exchange() {
518 provider::TokenExchangeShape::Standard => {
519 let body = format!(
520 "code={code}&{cid_param}={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
521 code = url_encode(code),
522 cid_param = spec.client_id_param(),
523 cid = url_encode(&self.client_id),
524 secret = url_encode(&self.client_secret),
525 ruri = url_encode(&self.redirect_uri),
526 pkce = pkce_field,
527 );
528 http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
529 }
530 provider::TokenExchangeShape::AppleJwt => {
531 let apple = self.apple.as_ref().ok_or(
532 "apple provider requires `apple` config (team_id, key_id, private_key_pem)",
533 )?;
534 let signed_secret = apple_jwt::mint_client_secret(apple, &self.client_id)?;
535 let body = format!(
536 "code={code}&client_id={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
537 code = url_encode(code),
538 cid = url_encode(&self.client_id),
539 secret = url_encode(&signed_secret),
540 ruri = url_encode(&self.redirect_uri),
541 pkce = pkce_field,
542 );
543 http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
544 }
545 provider::TokenExchangeShape::BasicAuth => {
546 let body = format!(
547 "code={code}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
548 code = url_encode(code),
549 ruri = url_encode(&self.redirect_uri),
550 pkce = pkce_field,
551 );
552 http_post_form_basic(&token_url, &body, &self.client_id, &self.client_secret)
553 .map_err(sanitize_token_error)?
554 }
555 provider::TokenExchangeShape::JsonBody => {
556 let mut json = serde_json::Map::new();
557 json.insert("grant_type".into(), "authorization_code".into());
558 json.insert("code".into(), code.into());
559 json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
560 json.insert("client_id".into(), self.client_id.clone().into());
561 json.insert("client_secret".into(), self.client_secret.clone().into());
562 if let Some(v) = code_verifier {
563 json.insert("code_verifier".into(), v.to_string().into());
564 }
565 let body = serde_json::Value::Object(json).to_string();
566 http_post_json(&token_url, &body, None).map_err(sanitize_token_error)?
567 }
568 provider::TokenExchangeShape::BasicAuthJsonBody => {
569 let mut json = serde_json::Map::new();
570 json.insert("grant_type".into(), "authorization_code".into());
571 json.insert("code".into(), code.into());
572 json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
573 if let Some(v) = code_verifier {
574 json.insert("code_verifier".into(), v.to_string().into());
575 }
576 let body = serde_json::Value::Object(json).to_string();
577 http_post_json(
578 &token_url,
579 &body,
580 Some((&self.client_id, &self.client_secret)),
581 )
582 .map_err(sanitize_token_error)?
583 }
584 };
585 parse_token_response(&out)
586 }
587
588 pub fn exchange_code(&self, code: &str) -> Result<String, String> {
592 Ok(self.exchange_code_full(code)?.access_token)
593 }
594
595 pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
597 let info = self.fetch_userinfo_full(access_token)?;
598 Ok((info.email, info.name))
599 }
600
601 pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
606 self.fetch_userinfo_with_id_token(access_token, None)
610 }
611
612 pub fn fetch_userinfo_with_id_token(
617 &self,
618 access_token: &str,
619 id_token: Option<&str>,
620 ) -> Result<UserInfo, String> {
621 let spec = self.resolved_spec()?;
622 let cfg = self.provider_cfg();
623
624 if matches!(
626 spec.userinfo_parser(),
627 provider::UserinfoParser::AppleIdToken
628 ) {
629 let token =
630 id_token.ok_or("apple login requires the id_token from the token response")?;
631 return parse_apple_id_token(token, &self.provider);
632 }
633
634 if matches!(
637 spec.userinfo_parser(),
638 provider::UserinfoParser::LinearGraphql
639 ) {
640 return fetch_linear_userinfo(&self.provider, access_token);
641 }
642
643 let url = match spec.userinfo_url() {
644 Some(u) => provider::resolve_endpoint(u, &cfg),
645 None => {
646 return Err(format!(
647 "provider {} has no userinfo endpoint",
648 self.provider
649 ))
650 }
651 };
652 let out = match spec.userinfo_method() {
653 provider::UserinfoMethod::Get => http_get_bearer(&url, access_token),
654 provider::UserinfoMethod::Post => http_post_bearer(&url, access_token),
655 }
656 .map_err(sanitize_token_error)?;
657 let parsed: serde_json::Value =
658 serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
659
660 match spec.userinfo_parser() {
661 provider::UserinfoParser::Oidc => {
662 let email = parsed
663 .get("email")
664 .and_then(|v| v.as_str())
665 .ok_or("no email in userinfo")?
666 .to_string();
667 let name = parsed
668 .get("name")
669 .and_then(|v| v.as_str())
670 .map(String::from);
671 let provider_account_id = parsed
672 .get("sub")
673 .and_then(|v| v.as_str())
674 .ok_or("no sub in userinfo")?
675 .to_string();
676 Ok(UserInfo {
677 provider: self.provider.clone(),
678 provider_account_id,
679 email,
680 name,
681 })
682 }
683 provider::UserinfoParser::GitHub => {
684 let name = parsed
685 .get("name")
686 .and_then(|v| v.as_str())
687 .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
688 .map(String::from);
689 let email = parsed
690 .get("email")
691 .and_then(|v| v.as_str())
692 .map(String::from);
693 let email = email
694 .or_else(|| fetch_github_primary_email(access_token).ok())
695 .ok_or("no accessible email on GitHub account")?;
696 let provider_account_id = parsed
697 .get("id")
698 .map(|v| {
699 v.as_i64()
700 .map(|n| n.to_string())
701 .or_else(|| v.as_str().map(String::from))
702 .unwrap_or_default()
703 })
704 .filter(|s| !s.is_empty())
705 .ok_or("no id in userinfo")?;
706 Ok(UserInfo {
707 provider: self.provider.clone(),
708 provider_account_id,
709 email,
710 name,
711 })
712 }
713 provider::UserinfoParser::Custom {
714 id_path,
715 email_path,
716 name_path,
717 } => {
718 let provider_account_id = json_pointer_string(&parsed, id_path)
719 .ok_or_else(|| format!("no id at {id_path} in userinfo"))?;
720 let raw_email = json_pointer_string(&parsed, email_path)
721 .ok_or_else(|| format!("no email at {email_path} in userinfo"))?;
722 let email = if !raw_email.contains('@') {
727 let domain = match self.provider.as_str() {
728 "twitter" => "x.invalid",
729 "reddit" => "reddit.invalid",
730 other => return Err(format!(
731 "{other}: userinfo `email` field is not an email address (got {raw_email:?}); refusing to synthesize",
732 )),
733 };
734 format!("{raw_email}@{domain}")
735 } else {
736 raw_email
737 };
738 let name = name_path.and_then(|p| json_pointer_string(&parsed, p));
739 Ok(UserInfo {
740 provider: self.provider.clone(),
741 provider_account_id,
742 email,
743 name,
744 })
745 }
746 provider::UserinfoParser::AppleIdToken => unreachable!("handled above"),
747 provider::UserinfoParser::LinearGraphql => unreachable!("handled above"),
748 }
749 }
750}
751
/// PKCE verifier/challenge pair produced by `generate_pkce` (S256 method).
struct PkcePair {
    /// Random secret kept by the caller and sent as `code_verifier` at token exchange.
    code_verifier: String,
    /// `apple_jwt::base64_url(SHA-256(code_verifier))`, sent in the authorization URL.
    code_challenge: String,
}
758
759fn generate_pkce() -> PkcePair {
763 use rand::RngCore;
764 let mut bytes = [0u8; 32];
765 rand::thread_rng().fill_bytes(&mut bytes);
766 let code_verifier = apple_jwt::base64_url(bytes);
767 use sha2::{Digest, Sha256};
768 let mut hasher = Sha256::new();
769 hasher.update(code_verifier.as_bytes());
770 let code_challenge = apple_jwt::base64_url(hasher.finalize());
771 PkcePair {
772 code_verifier,
773 code_challenge,
774 }
775}
776
777fn parse_apple_id_token(id_token: &str, provider: &str) -> Result<UserInfo, String> {
800 let mut parts = id_token.split('.');
801 let _header = parts.next().ok_or("apple id_token: missing header")?;
802 let claims_b64 = parts.next().ok_or("apple id_token: missing claims")?;
803 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
804 let claims_bytes = URL_SAFE_NO_PAD
805 .decode(claims_b64)
806 .map_err(|e| format!("apple id_token claims not base64: {e}"))?;
807 let claims: serde_json::Value = serde_json::from_slice(&claims_bytes)
808 .map_err(|e| format!("apple id_token claims not JSON: {e}"))?;
809 let provider_account_id = claims
810 .get("sub")
811 .and_then(|v| v.as_str())
812 .ok_or("apple id_token: missing sub")?
813 .to_string();
814 let email = claims
815 .get("email")
816 .and_then(|v| v.as_str())
817 .ok_or("apple id_token: missing email (was the `email` scope requested?)")?
818 .to_string();
819 Ok(UserInfo {
820 provider: provider.to_string(),
821 provider_account_id,
822 email,
823 name: None, })
825}
826
/// Masks secret-bearing parameters in an HTTP error message.
///
/// For each sensitive key, both `key=value` (form style) and `"key": "value"`
/// (JSON string style) occurrences have their value replaced with `***`.
fn sanitize_token_error(err: String) -> String {
    const SENSITIVE: &[&str] = &[
        "client_secret",
        "code_verifier",
        "client_assertion",
        "refresh_token",
        "access_token",
        "id_token",
        "code",
    ];
    SENSITIVE.iter().fold(err, |text, key| {
        let text = redact_param_form(&text, key);
        redact_param_json(&text, key)
    })
}

/// Replaces the value after every `key=` with `***`.
///
/// The value ends at `&`, newline, `"`, space, `'`, or end of input; the
/// delimiter itself is kept. Matching is by substring (so `error_code=` also
/// matches `code`), erring on the side of over-redaction.
fn redact_param_form(input: &str, key: &str) -> String {
    let needle = format!("{key}=");
    let mut redacted = String::with_capacity(input.len());
    let mut rest = input;
    while !rest.is_empty() {
        match rest.strip_prefix(needle.as_str()) {
            Some(value_and_tail) => {
                redacted.push_str(&needle);
                redacted.push_str("***");
                let value_len = value_and_tail
                    .find(|c: char| matches!(c, '&' | '\n' | '"' | ' ' | '\''))
                    .unwrap_or(value_and_tail.len());
                rest = &value_and_tail[value_len..];
            }
            None => {
                let mut chars = rest.chars();
                if let Some(c) = chars.next() {
                    redacted.push(c);
                }
                rest = chars.as_str();
            }
        }
    }
    redacted
}

/// Replaces the string value of every JSON member `"key": "..."` with `***`.
///
/// Whitespace around `:` is tolerated, escaped quotes inside the value are
/// honoured, non-string values are left untouched, and an unterminated string
/// is redacted through to the end of input.
fn redact_param_json(input: &str, key: &str) -> String {
    let needle = format!("\"{key}\"");
    let mut redacted = String::with_capacity(input.len());
    let mut rest = input;
    while !rest.is_empty() {
        let Some(after_key) = rest.strip_prefix(needle.as_str()) else {
            let mut chars = rest.chars();
            if let Some(c) = chars.next() {
                redacted.push(c);
            }
            rest = chars.as_str();
            continue;
        };

        // A quoted `"key"` only counts as a member name when a `:` follows.
        let before_colon = after_key.trim_start();
        let Some(after_colon) = before_colon.strip_prefix(':') else {
            let consumed = rest.len() - before_colon.len();
            redacted.push_str(&rest[..consumed]);
            rest = before_colon;
            continue;
        };

        // Only string values are redacted; numbers, bools, objects pass through.
        let before_value = after_colon.trim_start();
        let Some(value_and_tail) = before_value.strip_prefix('"') else {
            let consumed = rest.len() - before_value.len();
            redacted.push_str(&rest[..consumed]);
            rest = before_value;
            continue;
        };

        // Emit everything up to and including the opening quote, then the mask.
        let prefix_len = rest.len() - value_and_tail.len();
        redacted.push_str(&rest[..prefix_len]);
        redacted.push_str("***");

        let mut escaped = false;
        let closing = value_and_tail.char_indices().find_map(|(idx, c)| {
            if c == '"' && !escaped {
                return Some(idx);
            }
            escaped = c == '\\' && !escaped;
            None
        });
        rest = match closing {
            Some(idx) => {
                redacted.push('"');
                &value_and_tail[idx + 1..]
            }
            None => "",
        };
    }
    redacted
}
965
966fn fetch_linear_userinfo(provider: &str, access_token: &str) -> Result<UserInfo, String> {
971 let body = r#"{"query":"query { viewer { id email name } }"}"#;
972 let agent = ureq_agent();
973 let resp = agent
974 .post("https://api.linear.app/graphql")
975 .set("Authorization", &format!("Bearer {access_token}"))
976 .set("Content-Type", "application/json")
977 .set("Accept", "application/json")
978 .send_string(body)
979 .map_err(|e| format!("linear graphql: {e}"))?;
980 let out = resp.into_string().map_err(|e| format!("read body: {e}"))?;
981 let parsed: serde_json::Value =
982 serde_json::from_str(&out).map_err(|e| format!("linear graphql not JSON: {e}"))?;
983 let viewer = parsed
984 .pointer("/data/viewer")
985 .ok_or("linear graphql: no /data/viewer")?;
986 let provider_account_id = viewer
987 .get("id")
988 .and_then(|v| v.as_str())
989 .ok_or("linear graphql: no id")?
990 .to_string();
991 let email = viewer
992 .get("email")
993 .and_then(|v| v.as_str())
994 .ok_or("linear graphql: no email")?
995 .to_string();
996 let name = viewer
997 .get("name")
998 .and_then(|v| v.as_str())
999 .map(String::from);
1000 Ok(UserInfo {
1001 provider: provider.to_string(),
1002 provider_account_id,
1003 email,
1004 name,
1005 })
1006}
1007
1008fn json_pointer_string(v: &serde_json::Value, path: &str) -> Option<String> {
1012 let node = v.pointer(path)?;
1013 if let Some(s) = node.as_str() {
1014 return Some(s.to_string());
1015 }
1016 if let Some(n) = node.as_i64() {
1017 return Some(n.to_string());
1018 }
1019 if let Some(n) = node.as_u64() {
1020 return Some(n.to_string());
1021 }
1022 None
1023}
1024
/// Normalized identity produced by the userinfo helpers on `OAuthConfig`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Provider id the identity came from (the config's `provider`).
    pub provider: String,
    /// The provider's account id (`sub`, GitHub `id`, or the configured JSON pointer).
    pub provider_account_id: String,
    /// Email address; for twitter/reddit a synthesized `…@*.invalid` address when
    /// the provider returns no real email.
    pub email: String,
    /// Display name, when the provider supplies one.
    pub name: Option<String>,
}
1036
/// Tokens returned by a provider's token endpoint (built by `parse_token_response`).
///
/// `Debug` is implemented by hand so the access, refresh and ID tokens print as
/// `<redacted>` instead of leaking into logs; the other fields print normally.
#[derive(Clone, PartialEq, Eq)]
pub struct TokenSet {
    /// Bearer token for API and userinfo calls.
    pub access_token: String,
    /// Refresh token, when the provider issues one.
    pub refresh_token: Option<String>,
    /// OIDC ID token, when present (Apple sign-in depends on it).
    pub id_token: Option<String>,
    /// Absolute expiry in unix seconds, derived from `expires_in` at parse time.
    pub expires_at: Option<u64>,
    /// Granted scopes as returned by the provider.
    pub scope: Option<String>,
}

impl std::fmt::Debug for TokenSet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenSet")
            .field("access_token", &"<redacted>")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("id_token", &self.id_token.as_ref().map(|_| "<redacted>"))
            .field("expires_at", &self.expires_at)
            .field("scope", &self.scope)
            .finish()
    }
}
1051
1052fn parse_token_response(body: &str) -> Result<TokenSet, String> {
1053 let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
1056 let mut map = serde_json::Map::new();
1058 for pair in body.split('&') {
1059 if let Some((k, v)) = pair.split_once('=') {
1060 map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
1061 }
1062 }
1063 serde_json::Value::Object(map)
1064 });
1065
1066 let access_token = json
1067 .get("access_token")
1068 .and_then(|v| v.as_str())
1069 .ok_or_else(|| format!("no access_token in token response: {body}"))?
1070 .to_string();
1071 let refresh_token = json
1072 .get("refresh_token")
1073 .and_then(|v| v.as_str())
1074 .map(String::from);
1075 let id_token = json
1076 .get("id_token")
1077 .and_then(|v| v.as_str())
1078 .map(String::from);
1079 let expires_at = json
1080 .get("expires_in")
1081 .and_then(|v| {
1082 v.as_u64()
1083 .or_else(|| v.as_str().and_then(|s| s.parse().ok()))
1084 })
1085 .map(|secs| now_secs().saturating_add(secs));
1086 let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
1087 Ok(TokenSet {
1088 access_token,
1089 refresh_token,
1090 id_token,
1091 expires_at,
1092 scope,
1093 })
1094}
1095
/// Percent-encodes every byte of `s` except RFC 3986 unreserved characters
/// (`A-Z a-z 0-9 - _ . ~`); hex digits are uppercase.
fn url_encode(s: &str) -> String {
    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push_str(&format!("%{byte:02X}"));
            }
            encoded
        })
}
1108
1109const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
1113
1114fn ureq_agent() -> ureq::Agent {
1115 ureq::AgentBuilder::new()
1116 .timeout_connect(HTTP_TIMEOUT)
1117 .timeout_read(HTTP_TIMEOUT)
1118 .timeout_write(HTTP_TIMEOUT)
1119 .user_agent("pylon/0.1")
1120 .build()
1121}
1122
1123fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
1124 let agent = ureq_agent();
1125 let mut req = agent
1126 .post(url)
1127 .set("Content-Type", "application/x-www-form-urlencoded");
1128 if accept_json {
1129 req = req.set("Accept", "application/json");
1130 }
1131 match req.send_string(body) {
1132 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1133 Err(ureq::Error::Status(code, resp)) => {
1134 let body = resp.into_string().unwrap_or_default();
1135 Err(format!("HTTP {code}: {body}"))
1136 }
1137 Err(e) => Err(format!("HTTP error: {e}")),
1138 }
1139}
1140
1141fn http_post_form_basic(
1145 url: &str,
1146 body: &str,
1147 client_id: &str,
1148 client_secret: &str,
1149) -> Result<String, String> {
1150 use base64::{engine::general_purpose::STANDARD, Engine};
1151 let creds = format!("{client_id}:{client_secret}");
1152 let basic = STANDARD.encode(creds.as_bytes());
1153 let agent = ureq_agent();
1154 match agent
1155 .post(url)
1156 .set("Content-Type", "application/x-www-form-urlencoded")
1157 .set("Accept", "application/json")
1158 .set("Authorization", &format!("Basic {basic}"))
1159 .send_string(body)
1160 {
1161 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1162 Err(ureq::Error::Status(code, resp)) => {
1163 let body = resp.into_string().unwrap_or_default();
1164 Err(format!("HTTP {code}: {body}"))
1165 }
1166 Err(e) => Err(format!("HTTP error: {e}")),
1167 }
1168}
1169
1170fn http_post_json(
1174 url: &str,
1175 body: &str,
1176 basic_creds: Option<(&str, &str)>,
1177) -> Result<String, String> {
1178 let agent = ureq_agent();
1179 let mut req = agent
1180 .post(url)
1181 .set("Content-Type", "application/json")
1182 .set("Accept", "application/json");
1183 if let Some((id, secret)) = basic_creds {
1184 use base64::{engine::general_purpose::STANDARD, Engine};
1185 let creds = STANDARD.encode(format!("{id}:{secret}").as_bytes());
1186 req = req.set("Authorization", &format!("Basic {creds}"));
1187 }
1188 req = req.set("Notion-Version", "2022-06-28");
1191 match req.send_string(body) {
1192 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1193 Err(ureq::Error::Status(code, resp)) => {
1194 let body = resp.into_string().unwrap_or_default();
1195 Err(format!("HTTP {code}: {body}"))
1196 }
1197 Err(e) => Err(format!("HTTP error: {e}")),
1198 }
1199}
1200
1201fn http_post_bearer(url: &str, token: &str) -> Result<String, String> {
1204 let agent = ureq_agent();
1205 match agent
1206 .post(url)
1207 .set("Authorization", &format!("Bearer {token}"))
1208 .set("Accept", "application/json")
1209 .call()
1210 {
1211 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1212 Err(ureq::Error::Status(code, resp)) => {
1213 let body = resp.into_string().unwrap_or_default();
1214 Err(format!("HTTP {code}: {body}"))
1215 }
1216 Err(e) => Err(format!("HTTP error: {e}")),
1217 }
1218}
1219
1220fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
1221 let agent = ureq_agent();
1222 match agent
1223 .get(url)
1224 .set("Authorization", &format!("Bearer {token}"))
1225 .set("Accept", "application/json")
1226 .call()
1227 {
1228 Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
1229 Err(ureq::Error::Status(code, resp)) => {
1230 let body = resp.into_string().unwrap_or_default();
1231 Err(format!("HTTP {code}: {body}"))
1232 }
1233 Err(e) => Err(format!("HTTP error: {e}")),
1234 }
1235}
1236
1237fn fetch_github_primary_email(token: &str) -> Result<String, String> {
1238 let out = http_get_bearer("https://api.github.com/user/emails", token)?;
1239 let emails: serde_json::Value =
1240 serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;
1241 emails
1242 .as_array()
1243 .and_then(|arr| {
1244 arr.iter()
1245 .find(|e| {
1246 e.get("primary").and_then(|v| v.as_bool()).unwrap_or(false)
1247 && e.get("verified").and_then(|v| v.as_bool()).unwrap_or(false)
1248 })
1249 .and_then(|e| e.get("email").and_then(|v| v.as_str()).map(String::from))
1250 })
1251 .ok_or_else(|| "no primary verified email on GitHub".into())
1252}
1253
1254pub struct OAuthRegistry {
1256 providers: std::collections::HashMap<String, OAuthConfig>,
1257}
1258
1259impl Default for OAuthRegistry {
1260 fn default() -> Self {
1261 Self::new()
1262 }
1263}
1264
1265impl OAuthRegistry {
1266 pub fn new() -> Self {
1267 Self {
1268 providers: std::collections::HashMap::new(),
1269 }
1270 }
1271
1272 pub fn register(&mut self, config: OAuthConfig) {
1273 self.providers.insert(config.provider.clone(), config);
1274 }
1275
1276 pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
1277 self.providers.get(provider)
1278 }
1279
1280 pub fn from_env() -> Self {
1293 let mut reg = Self::new();
1294
1295 for spec in provider::builtin::all() {
1296 let upper = spec.id.to_ascii_uppercase();
1297 let prefix = format!("PYLON_OAUTH_{upper}");
1298 let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1299 Ok(v) => v,
1300 Err(_) => continue,
1301 };
1302 let secret = match std::env::var(format!("{prefix}_CLIENT_SECRET")) {
1303 Ok(v) => v,
1304 Err(_) if spec.id == "apple" => String::new(),
1306 Err(_) => continue,
1307 };
1308 let redirect_uri = std::env::var(format!("{prefix}_REDIRECT"))
1309 .unwrap_or_else(|_| format!("http://localhost:3000/api/auth/callback/{}", spec.id));
1310 let scopes_override = std::env::var(format!("{prefix}_SCOPES")).ok();
1311 let tenant = std::env::var(format!("{prefix}_TENANT")).ok();
1312
1313 let apple = if spec.id == "apple" {
1314 match (
1315 std::env::var(format!("{prefix}_TEAM_ID")),
1316 std::env::var(format!("{prefix}_KEY_ID")),
1317 std::env::var(format!("{prefix}_PRIVATE_KEY")),
1318 ) {
1319 (Ok(team_id), Ok(key_id), Ok(private_key_pem)) => Some(provider::AppleConfig {
1320 team_id,
1321 key_id,
1322 private_key_pem,
1323 }),
1324 _ => continue, }
1326 } else {
1327 None
1328 };
1329
1330 reg.register(OAuthConfig {
1331 provider: spec.id.to_string(),
1332 client_id: id,
1333 client_secret: secret,
1334 redirect_uri,
1335 scopes_override,
1336 tenant,
1337 apple,
1338 oidc_issuer: None,
1339 });
1340 }
1341
1342 for (key, issuer) in std::env::vars() {
1344 let Some(rest) = key.strip_prefix("PYLON_OAUTH_") else {
1345 continue;
1346 };
1347 let Some(name_upper) = rest.strip_suffix("_OIDC_ISSUER") else {
1348 continue;
1349 };
1350 let name = name_upper.to_ascii_lowercase();
1351 if provider::find_spec(&name).is_some() {
1352 continue; }
1354 let prefix = format!("PYLON_OAUTH_{name_upper}");
1355 let id = match std::env::var(format!("{prefix}_CLIENT_ID")) {
1356 Ok(v) => v,
1357 Err(_) => continue,
1358 };
1359 let secret = std::env::var(format!("{prefix}_CLIENT_SECRET")).unwrap_or_default();
1360 let redirect_uri = std::env::var(format!("{prefix}_REDIRECT"))
1361 .unwrap_or_else(|_| format!("http://localhost:3000/api/auth/callback/{name}"));
1362 reg.register(OAuthConfig {
1363 provider: name,
1364 client_id: id,
1365 client_secret: secret,
1366 redirect_uri,
1367 scopes_override: std::env::var(format!("{prefix}_SCOPES")).ok(),
1368 tenant: None,
1369 apple: None,
1370 oidc_issuer: Some(issuer),
1371 });
1372 }
1373
1374 reg
1375 }
1376
1377 pub fn ids(&self) -> impl Iterator<Item = &str> {
1381 self.providers.keys().map(|s| s.as_str())
1382 }
1383
1384 pub fn shared() -> &'static OAuthRegistry {
1391 static CELL: std::sync::OnceLock<OAuthRegistry> = std::sync::OnceLock::new();
1392 CELL.get_or_init(Self::from_env)
1393 }
1394}
1395
/// One in-flight OAuth authorization request, stored server-side under the
/// opaque `state` token handed to the provider.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
    /// Provider id the flow was started for; `OAuthStateStore::validate`
    /// compares it against the provider named in the callback.
    pub provider: String,
    /// Callback URL supplied when the flow was created.
    pub callback_url: String,
    /// Error callback URL supplied when the flow was created.
    pub error_callback_url: String,
    /// PKCE code verifier, if the flow was started with one.
    pub pkce_verifier: Option<String>,
    /// Unix time (seconds) at or after which the state is no longer accepted.
    pub expires_at: u64,
}

/// Storage for pending [`OAuthState`] records, keyed by state token.
pub trait OAuthStateBackend: Send + Sync {
    /// Stores `state` under `token`.
    fn put(&self, token: &str, state: &OAuthState);
    /// Consumes the state stored under `token`.
    ///
    /// Implementations should return `None` when the token is unknown or the
    /// state has expired as of `now_unix_secs`.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
}

/// Process-local [`OAuthStateBackend`] backed by a mutex-guarded map.
///
/// States live only in this process's memory.
pub struct InMemoryOAuthBackend {
    states: Mutex<HashMap<String, OAuthState>>,
}

impl InMemoryOAuthBackend {
    /// Creates an empty backend.
    pub fn new() -> Self {
        Self {
            states: Mutex::new(HashMap::new()),
        }
    }
}

impl Default for InMemoryOAuthBackend {
    fn default() -> Self {
        Self::new()
    }
}

impl OAuthStateBackend for InMemoryOAuthBackend {
    /// Stores `state` under `token`, first evicting every already-expired entry.
    ///
    /// `take` is the only other place entries leave the map. Without this
    /// sweep, every flow that is started but never completed would stay in
    /// memory forever.
    fn put(&self, token: &str, state: &OAuthState) {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let mut states = self.states.lock().unwrap();
        // Same cutoff `take` applies: anything with `expires_at <= now` is unusable.
        states.retain(|_, existing| existing.expires_at > now);
        states.insert(token.to_string(), state.clone());
    }

    /// Removes the entry for `token` and returns it only if it has not expired.
    ///
    /// The entry is removed even when expired, so a token is redeemable at most once.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
        let entry = self.states.lock().unwrap().remove(token)?;
        (entry.expires_at > now_unix_secs).then_some(entry)
    }
}
1473
1474pub struct OAuthStateStore {
1481 backend: Box<dyn OAuthStateBackend>,
1482}
1483
1484impl Default for OAuthStateStore {
1485 fn default() -> Self {
1486 Self::new()
1487 }
1488}
1489
1490impl OAuthStateStore {
1491 pub fn new() -> Self {
1492 Self {
1493 backend: Box::new(InMemoryOAuthBackend::new()),
1494 }
1495 }
1496
1497 pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
1498 Self { backend }
1499 }
1500
1501 pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
1509 self.create_with_pkce(provider, callback_url, error_callback_url, None)
1510 }
1511
1512 pub fn create_with_pkce(
1516 &self,
1517 provider: &str,
1518 callback_url: &str,
1519 error_callback_url: &str,
1520 pkce_verifier: Option<String>,
1521 ) -> String {
1522 use std::time::{SystemTime, UNIX_EPOCH};
1523 let token = generate_token();
1524 let now = SystemTime::now()
1525 .duration_since(UNIX_EPOCH)
1526 .unwrap_or_default()
1527 .as_secs();
1528 let state = OAuthState {
1529 provider: provider.to_string(),
1530 callback_url: callback_url.to_string(),
1531 error_callback_url: error_callback_url.to_string(),
1532 pkce_verifier,
1533 expires_at: now + 600,
1534 };
1535 self.backend.put(&token, &state);
1536 token
1537 }
1538
1539 pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
1544 use std::time::{SystemTime, UNIX_EPOCH};
1545 let now = SystemTime::now()
1546 .duration_since(UNIX_EPOCH)
1547 .unwrap_or_default()
1548 .as_secs();
1549 let entry = self.backend.take(state, now)?;
1550 if entry.provider != expected_provider {
1551 return None;
1552 }
1553 Some(entry)
1554 }
1555}
1556
/// Accepts `url` only if it is an absolute `http://` or `https://` URL whose
/// origin (see [`origin_of`]) exactly equals one of `trusted_origins`.
///
/// Comparison is plain string equality, so differences in letter case,
/// explicit default ports or trailing slashes cause rejection.
/// NOTE(review): a configured trusted entry with a trailing slash can never
/// match; confirm callers normalise PYLON_TRUSTED_ORIGINS.
///
/// # Errors
/// [`TrustedOriginError::Empty`], [`TrustedOriginError::NotHttp`] or
/// [`TrustedOriginError::NotTrusted`] (carrying the computed origin).
pub fn validate_trusted_redirect(
    url: &str,
    trusted_origins: &[String],
) -> Result<(), TrustedOriginError> {
    if url.is_empty() {
        return Err(TrustedOriginError::Empty);
    }
    let has_http_scheme = ["http://", "https://"]
        .iter()
        .any(|scheme| url.starts_with(scheme));
    if !has_http_scheme {
        return Err(TrustedOriginError::NotHttp);
    }
    let origin = origin_of(url);
    if trusted_origins.contains(&origin) {
        Ok(())
    } else {
        Err(TrustedOriginError::NotTrusted { origin })
    }
}

/// Reason a redirect URL was rejected by [`validate_trusted_redirect`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The URL was the empty string.
    Empty,
    /// The URL did not start with `http://` or `https://`.
    NotHttp,
    /// The URL's origin is not in the trusted list.
    NotTrusted { origin: String },
}

impl std::fmt::Display for TrustedOriginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Empty => f.write_str("redirect URL is empty"),
            Self::NotHttp => f.write_str("redirect URL must use http:// or https:// scheme"),
            Self::NotTrusted { origin } => {
                write!(f, "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS")
            }
        }
    }
}

/// Returns the `scheme://authority` prefix of `url`.
///
/// The prefix is cut at the first `/`, `?` or `#` after `://`. Input without
/// `://` is returned unchanged apart from stripping trailing slashes.
pub fn origin_of(url: &str) -> String {
    const SEPARATOR: &str = "://";
    let Some(authority_start) = url.find(SEPARATOR).map(|idx| idx + SEPARATOR.len()) else {
        return url.trim_end_matches('/').to_string();
    };
    let authority_len = url[authority_start..]
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(url.len() - authority_start);
    url[..authority_start + authority_len].to_string()
}
1631
/// Persistence hook for pending magic sign-in codes, keyed by email.
pub trait MagicCodeBackend: Send + Sync {
    /// Stores `code` for `email`.
    fn put(&self, email: &str, code: &MagicCode);
    /// Returns the stored code for `email`, if any.
    fn get(&self, email: &str) -> Option<MagicCode>;
    /// Deletes the stored code for `email`.
    fn remove(&self, email: &str);
    /// Records one more failed verification attempt against `email`'s code.
    fn bump_attempts(&self, email: &str);
    /// Returns every stored code; `MagicCodeStore::with_backend` uses this to
    /// seed its cache.
    fn load_all(&self) -> Vec<MagicCode>;
}

/// Process-local [`MagicCodeBackend`] backed by a mutex-guarded map.
pub struct InMemoryMagicCodeBackend {
    codes: Mutex<HashMap<String, MagicCode>>,
}

impl InMemoryMagicCodeBackend {
    /// Creates an empty backend.
    pub fn new() -> Self {
        InMemoryMagicCodeBackend {
            codes: Mutex::new(HashMap::new()),
        }
    }
}

impl Default for InMemoryMagicCodeBackend {
    fn default() -> Self {
        InMemoryMagicCodeBackend::new()
    }
}

impl MagicCodeBackend for InMemoryMagicCodeBackend {
    fn put(&self, email: &str, code: &MagicCode) {
        let mut table = self.codes.lock().unwrap();
        table.insert(email.to_owned(), code.clone());
    }

    fn get(&self, email: &str) -> Option<MagicCode> {
        let table = self.codes.lock().unwrap();
        table.get(email).cloned()
    }

    fn remove(&self, email: &str) {
        let mut table = self.codes.lock().unwrap();
        table.remove(email);
    }

    fn bump_attempts(&self, email: &str) {
        let mut table = self.codes.lock().unwrap();
        // Unknown emails are ignored; the saturating add keeps the counter from wrapping.
        if let Some(entry) = table.get_mut(email) {
            entry.attempts = entry.attempts.saturating_add(1);
        }
    }

    fn load_all(&self) -> Vec<MagicCode> {
        let table = self.codes.lock().unwrap();
        table.values().cloned().collect()
    }
}

/// Magic-code store: an in-process cache in front of a [`MagicCodeBackend`].
///
/// Verification reads the cache; the backend is written through on changes.
pub struct MagicCodeStore {
    cache: Mutex<HashMap<String, MagicCode>>,
    backend: Box<dyn MagicCodeBackend>,
}

/// A pending one-time sign-in code for one email address.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Address the code was issued for.
    pub email: String,
    /// The code itself, as produced by `generate_magic_code`.
    pub code: String,
    /// Unix time (seconds) at or after which the code is rejected.
    pub expires_at: u64,
    /// Failed verification attempts so far.
    pub attempts: u32,
}
1722
/// Failed verification attempts after which a magic code is locked out
/// (`MagicCodeStore::try_verify` then returns `TooManyAttempts`).
const MAX_ATTEMPTS: u32 = 5;

/// Minimum age, in seconds, an unexpired code must reach before
/// `MagicCodeStore::try_create` will issue a replacement for the same email.
const CREATE_COOLDOWN_SECS: u64 = 60;
1731
/// Why a magic-code operation was refused.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is pending for the email.
    NotFound,
    /// The code has accumulated `MAX_ATTEMPTS` failed guesses.
    TooManyAttempts,
    /// The submitted code did not match.
    BadCode,
    /// The pending code is past its `expires_at`.
    Expired,
    /// A new code was requested during the resend cooldown; retry after the
    /// given number of seconds.
    Throttled { retry_after_secs: u64 },
}
1745
1746impl Default for MagicCodeStore {
1747 fn default() -> Self {
1748 Self::new()
1749 }
1750}
1751
1752impl MagicCodeStore {
1753 pub fn new() -> Self {
1754 Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
1755 }
1756
1757 pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
1762 let now = now_secs();
1763 let mut cache = HashMap::new();
1764 for c in backend.load_all() {
1765 if c.expires_at > now {
1766 cache.insert(c.email.clone(), c);
1767 }
1768 }
1769 Self {
1770 cache: Mutex::new(cache),
1771 backend,
1772 }
1773 }
1774
1775 pub fn create(&self, email: &str) -> String {
1778 self.try_create(email).unwrap_or_else(|_| String::new())
1781 }
1782
1783 pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
1786 let now = now_secs();
1787
1788 let mut codes = self.cache.lock().unwrap();
1789
1790 if let Some(existing) = codes.get(email) {
1794 if existing.expires_at > now {
1795 let created_at = existing.expires_at.saturating_sub(600);
1796 let age = now.saturating_sub(created_at);
1797 if age < CREATE_COOLDOWN_SECS {
1798 return Err(MagicCodeError::Throttled {
1799 retry_after_secs: CREATE_COOLDOWN_SECS - age,
1800 });
1801 }
1802 }
1803 }
1804
1805 let code = generate_magic_code();
1806 let mc = MagicCode {
1807 email: email.to_string(),
1808 code: code.clone(),
1809 expires_at: now + 600, attempts: 0,
1811 };
1812 codes.insert(email.to_string(), mc.clone());
1813 self.backend.put(email, &mc);
1817 Ok(code)
1818 }
1819
1820 pub fn verify(&self, email: &str, code: &str) -> bool {
1824 matches!(self.try_verify(email, code), Ok(()))
1825 }
1826
1827 pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
1834 self.cache
1835 .lock()
1836 .map(|m| m.values().cloned().collect())
1837 .unwrap_or_default()
1838 }
1839
1840 pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
1841 let now = now_secs();
1842 let mut codes = self.cache.lock().unwrap();
1843
1844 let mc = match codes.get_mut(email) {
1845 Some(m) => m,
1846 None => return Err(MagicCodeError::NotFound),
1847 };
1848
1849 if mc.attempts >= MAX_ATTEMPTS {
1850 return Err(MagicCodeError::TooManyAttempts);
1851 }
1852 if mc.expires_at <= now {
1853 codes.remove(email);
1854 self.backend.remove(email);
1855 return Err(MagicCodeError::Expired);
1856 }
1857
1858 let ok = constant_time_eq(mc.code.as_bytes(), code.as_bytes());
1859 if !ok {
1860 mc.attempts += 1;
1861 self.backend.bump_attempts(email);
1862 if mc.attempts >= MAX_ATTEMPTS {
1864 return Err(MagicCodeError::TooManyAttempts);
1865 }
1866 return Err(MagicCodeError::BadCode);
1867 }
1868
1869 codes.remove(email);
1871 self.backend.remove(email);
1872 Ok(())
1873 }
1874}
1875
/// Lower-case hexadecimal encoding of `bytes`, two characters per byte.
///
/// Uses a digit table and a preallocated buffer. The previous version ran
/// `format!` once per byte, allocating a temporary `String` for every byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
1883
1884fn generate_magic_code() -> String {
1886 use rand::Rng;
1887 let mut rng = rand::thread_rng();
1888 let code: u32 = rng.gen_range(0..1_000_000);
1889 format!("{:06}", code)
1890}
1891
1892fn generate_token() -> String {
1894 use rand::Rng;
1895 let mut rng = rand::thread_rng();
1896 let bytes: [u8; 32] = rng.gen();
1897 format!("pylon_{}", hex_encode(&bytes))
1898}
1899
1900use std::collections::HashMap;
1905use std::sync::Mutex;
1906
/// Persistence hook for sessions; `SessionStore` calls it when sessions are
/// loaded, created, updated or removed.
pub trait SessionBackend: Send + Sync {
    /// Returns every persisted session (`SessionStore::with_backend` drops
    /// expired ones on load).
    fn load_all(&self) -> Vec<Session>;
    /// Persists `session`; called both for new sessions and for updates to
    /// existing ones (device, user, tenant changes).
    fn save(&self, session: &Session);
    /// Deletes the session identified by `token`.
    fn remove(&self, token: &str);
}
1915
1916pub struct SessionStore {
1924 sessions: Mutex<HashMap<String, Session>>,
1925 backend: Option<Box<dyn SessionBackend>>,
1926 default_lifetime_secs: u64,
1930}
1931
1932impl Default for SessionStore {
1933 fn default() -> Self {
1934 Self::new()
1935 }
1936}
1937
1938impl SessionStore {
1939 pub fn new() -> Self {
1940 Self {
1941 sessions: Mutex::new(HashMap::new()),
1942 backend: None,
1943 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1944 }
1945 }
1946
1947 pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
1950 self.default_lifetime_secs = lifetime_secs;
1951 self
1952 }
1953
1954 pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
1958 let mut map = HashMap::new();
1959 for s in backend.load_all() {
1960 if !s.is_expired() {
1961 map.insert(s.token.clone(), s);
1962 }
1963 }
1964 Self {
1965 sessions: Mutex::new(map),
1966 backend: Some(backend),
1967 default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
1968 }
1969 }
1970
1971 pub fn create(&self, user_id: String) -> Session {
1975 self.create_with_device(user_id, None)
1976 }
1977
1978 pub fn create_with_device(&self, user_id: String, device: Option<String>) -> Session {
1984 let mut session = Session::with_lifetime(user_id, self.default_lifetime_secs);
1985 session.device = device;
1986 let mut sessions = self.sessions.lock().unwrap();
1987 sessions.insert(session.token.clone(), session.clone());
1988 if let Some(b) = &self.backend {
1989 b.save(&session);
1990 }
1991 session
1992 }
1993
1994 pub fn get(&self, token: &str) -> Option<Session> {
1996 let mut sessions = self.sessions.lock().unwrap();
1997 match sessions.get(token) {
1998 Some(s) if s.is_expired() => {
1999 sessions.remove(token);
2000 None
2001 }
2002 Some(s) => Some(s.clone()),
2003 None => None,
2004 }
2005 }
2006
2007 pub fn resolve(&self, token: Option<&str>) -> AuthContext {
2010 match token {
2011 Some(t) => match self.get(t) {
2012 Some(session) => session.to_auth_context(),
2013 None => AuthContext::anonymous(),
2014 },
2015 None => AuthContext::anonymous(),
2016 }
2017 }
2018
2019 pub fn refresh(&self, old_token: &str) -> Option<Session> {
2023 let mut sessions = self.sessions.lock().unwrap();
2024 let old = sessions.remove(old_token)?;
2025 if let Some(b) = &self.backend {
2026 b.remove(old_token);
2027 }
2028 if old.is_expired() {
2029 return None;
2030 }
2031 let mut new = Session::with_lifetime(old.user_id.clone(), self.default_lifetime_secs);
2037 new.device = old.device.clone();
2038 sessions.insert(new.token.clone(), new.clone());
2039 if let Some(b) = &self.backend {
2040 b.save(&new);
2041 }
2042 Some(new)
2043 }
2044
2045 pub fn list_all_unfiltered(&self) -> Vec<Session> {
2050 self.sessions
2051 .lock()
2052 .map(|m| m.values().cloned().collect())
2053 .unwrap_or_default()
2054 }
2055
2056 pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
2058 let sessions = self.sessions.lock().unwrap();
2059 sessions
2060 .values()
2061 .filter(|s| s.user_id == user_id && !s.is_expired())
2062 .cloned()
2063 .collect()
2064 }
2065
2066 pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
2068 let mut sessions = self.sessions.lock().unwrap();
2069 let tokens: Vec<String> = sessions
2070 .iter()
2071 .filter_map(|(t, s)| {
2072 if s.user_id == user_id {
2073 Some(t.clone())
2074 } else {
2075 None
2076 }
2077 })
2078 .collect();
2079 let n = tokens.len();
2080 for t in &tokens {
2081 sessions.remove(t);
2082 if let Some(b) = &self.backend {
2083 b.remove(t);
2084 }
2085 }
2086 n
2087 }
2088
2089 pub fn sweep_expired(&self) -> usize {
2091 let mut sessions = self.sessions.lock().unwrap();
2092 let expired: Vec<String> = sessions
2093 .iter()
2094 .filter_map(|(t, s)| {
2095 if s.is_expired() {
2096 Some(t.clone())
2097 } else {
2098 None
2099 }
2100 })
2101 .collect();
2102 let n = expired.len();
2103 for t in &expired {
2104 sessions.remove(t);
2105 if let Some(b) = &self.backend {
2106 b.remove(t);
2107 }
2108 }
2109 n
2110 }
2111
2112 pub fn set_device(&self, token: &str, device: String) -> bool {
2114 let mut sessions = self.sessions.lock().unwrap();
2115 if let Some(s) = sessions.get_mut(token) {
2116 s.device = Some(device);
2117 if let Some(b) = &self.backend {
2118 b.save(s);
2119 }
2120 true
2121 } else {
2122 false
2123 }
2124 }
2125
2126 pub fn create_guest(&self) -> Session {
2128 use rand::Rng;
2129 let mut rng = rand::thread_rng();
2130 let bytes: [u8; 16] = rng.gen();
2131 let guest_id = format!("guest_{}", hex_encode(&bytes));
2132 self.create(guest_id)
2133 }
2134
2135 pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
2137 let mut sessions = self.sessions.lock().unwrap();
2138 if let Some(session) = sessions.get_mut(token) {
2139 session.user_id = real_user_id;
2140 if let Some(b) = &self.backend {
2141 b.save(session);
2142 }
2143 true
2144 } else {
2145 false
2146 }
2147 }
2148
2149 pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
2154 let mut sessions = self.sessions.lock().unwrap();
2155 if let Some(session) = sessions.get_mut(token) {
2156 session.tenant_id = tenant_id;
2157 if let Some(b) = &self.backend {
2158 b.save(session);
2159 }
2160 true
2161 } else {
2162 false
2163 }
2164 }
2165
2166 pub fn revoke(&self, token: &str) -> bool {
2168 let mut sessions = self.sessions.lock().unwrap();
2169 let removed = sessions.remove(token).is_some();
2170 if removed {
2171 if let Some(b) = &self.backend {
2172 b.remove(token);
2173 }
2174 }
2175 removed
2176 }
2177}
2178
/// A user's link to one sign-in identity.
///
/// The identity is keyed by `(provider_id, account_id)`. The struct also
/// holds whatever tokens and password material came with it.
///
/// `Debug` is implemented by hand (not derived) so that `{:?}` output — logs,
/// panic messages, `assert_eq!` failures — never contains tokens or password data.
#[derive(Clone, PartialEq, Eq)]
pub struct Account {
    /// Opaque row id.
    pub id: String,
    /// Owning user.
    pub user_id: String,
    /// Provider this identity belongs to.
    pub provider_id: String,
    /// The identity's id at that provider.
    pub account_id: String,
    pub access_token: Option<String>,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
    /// Unix time (seconds) at which `access_token` expires, if known.
    pub access_token_expires_at: Option<u64>,
    /// Unix time (seconds) at which `refresh_token` expires, if known.
    pub refresh_token_expires_at: Option<u64>,
    pub scope: Option<String>,
    pub password: Option<String>,
    /// Unix time (seconds) the row was created.
    pub created_at: u64,
    /// Unix time (seconds) the row was last updated.
    pub updated_at: u64,
}

impl std::fmt::Debug for Account {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show credentials only as present/absent.
        fn redacted(secret: Option<&String>) -> Option<&'static str> {
            secret.map(|_| "<redacted>")
        }
        f.debug_struct("Account")
            .field("id", &self.id)
            .field("user_id", &self.user_id)
            .field("provider_id", &self.provider_id)
            .field("account_id", &self.account_id)
            .field("access_token", &redacted(self.access_token.as_ref()))
            .field("refresh_token", &redacted(self.refresh_token.as_ref()))
            .field("id_token", &redacted(self.id_token.as_ref()))
            .field("access_token_expires_at", &self.access_token_expires_at)
            .field("refresh_token_expires_at", &self.refresh_token_expires_at)
            .field("scope", &self.scope)
            .field("password", &redacted(self.password.as_ref()))
            .field("created_at", &self.created_at)
            .field("updated_at", &self.updated_at)
            .finish()
    }
}
2234
2235impl Account {
2236 pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
2240 let now = now_secs();
2241 Self {
2242 id: generate_token(),
2243 user_id,
2244 provider_id: info.provider.clone(),
2245 account_id: info.provider_account_id.clone(),
2246 access_token: Some(tokens.access_token.clone()),
2247 refresh_token: tokens.refresh_token.clone(),
2248 id_token: tokens.id_token.clone(),
2249 access_token_expires_at: tokens.expires_at,
2250 refresh_token_expires_at: None,
2251 scope: tokens.scope.clone(),
2252 password: None,
2253 created_at: now,
2254 updated_at: now,
2255 }
2256 }
2257
2258 pub fn access_token_expired(&self) -> bool {
2263 match self.access_token_expires_at {
2264 Some(ts) => now_secs() >= ts,
2265 None => false,
2266 }
2267 }
2268}
2269
2270pub trait AccountBackend: Send + Sync {
2273 fn upsert(&self, account: &Account);
2277 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
2280 fn find_for_user(&self, user_id: &str) -> Vec<Account>;
2285 fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
2287 fn delete_for_user(&self, user_id: &str) -> usize {
2292 let accounts = self.find_for_user(user_id);
2293 let n = accounts.len();
2294 for a in accounts {
2295 self.unlink(&a.provider_id, &a.account_id);
2296 }
2297 n
2298 }
2299 fn list_all(&self) -> Vec<Account>;
2304}
2305
2306pub struct InMemoryAccountBackend {
2310 accounts: Mutex<HashMap<(String, String), Account>>,
2314}
2315
2316impl InMemoryAccountBackend {
2317 pub fn new() -> Self {
2318 Self {
2319 accounts: Mutex::new(HashMap::new()),
2320 }
2321 }
2322}
2323
2324impl Default for InMemoryAccountBackend {
2325 fn default() -> Self {
2326 Self::new()
2327 }
2328}
2329
2330impl AccountBackend for InMemoryAccountBackend {
2331 fn upsert(&self, account: &Account) {
2332 let key = (account.provider_id.clone(), account.account_id.clone());
2333 self.accounts.lock().unwrap().insert(key, account.clone());
2334 }
2335 fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2336 self.accounts
2337 .lock()
2338 .unwrap()
2339 .get(&(provider_id.to_string(), account_id.to_string()))
2340 .cloned()
2341 }
2342 fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2343 self.accounts
2344 .lock()
2345 .unwrap()
2346 .values()
2347 .filter(|a| a.user_id == user_id)
2348 .cloned()
2349 .collect()
2350 }
2351 fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2352 self.accounts
2353 .lock()
2354 .unwrap()
2355 .remove(&(provider_id.to_string(), account_id.to_string()))
2356 .is_some()
2357 }
2358 fn list_all(&self) -> Vec<Account> {
2359 self.accounts.lock().unwrap().values().cloned().collect()
2360 }
2361}
2362
2363pub struct AccountStore {
2366 backend: Box<dyn AccountBackend>,
2367}
2368
2369impl Default for AccountStore {
2370 fn default() -> Self {
2371 Self::new()
2372 }
2373}
2374
2375impl AccountStore {
2376 pub fn new() -> Self {
2377 Self {
2378 backend: Box::new(InMemoryAccountBackend::new()),
2379 }
2380 }
2381 pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
2382 Self { backend }
2383 }
2384 pub fn upsert(&self, account: &Account) {
2385 self.backend.upsert(account);
2386 }
2387 pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
2388 self.backend.find_by_provider(provider_id, account_id)
2389 }
2390 pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
2391 self.backend.find_for_user(user_id)
2392 }
2393 pub fn delete_for_user(&self, user_id: &str) -> usize {
2394 self.backend.delete_for_user(user_id)
2395 }
2396
2397 pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
2398 self.backend.unlink(provider_id, account_id)
2399 }
2400
2401 pub fn list_all_unfiltered(&self) -> Vec<Account> {
2415 self.backend.list_all()
2416 }
2417}
2418
2419#[cfg(test)]
2424mod tests {
2425 use super::*;
2426
    // An anonymous context has no user id and is not authenticated.
    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }

    // `authenticated` records the user id and marks the context authenticated.
    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }

    // API-key auth keeps the key id and the raw scope string on the context.
    #[test]
    fn from_api_key_carries_scope_metadata() {
        let ctx =
            AuthContext::from_api_key("user-1".into(), "key_abc".into(), Some("read,write".into()));
        assert!(ctx.is_authenticated());
        assert!(ctx.is_api_key_auth());
        assert_eq!(ctx.user_id.as_deref(), Some("user-1"));
        assert_eq!(ctx.api_key_id.as_deref(), Some("key_abc"));
        assert_eq!(ctx.api_key_scopes.as_deref(), Some("read,write"));
    }

    // A session-style context carries no API-key id.
    #[test]
    fn session_auth_is_not_api_key_auth() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_api_key_auth());
        assert!(ctx.api_key_id.is_none());
    }

    // `Public` admits both anonymous and authenticated callers.
    #[test]
    fn auth_mode_public_allows_anonymous() {
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    // `User` rejects anonymous callers and admits authenticated ones.
    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    // Only known mode names parse; "admin" yields `None`.
    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        assert_eq!(AuthMode::from_str("admin"), None);
    }
2479
    // New session tokens carry the `pylon_` prefix and round-trip through `get`.
    #[test]
    fn session_store_create_and_get() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(!session.token.is_empty());
        assert!(session.token.starts_with("pylon_"));

        let retrieved = store.get(&session.token).unwrap();
        assert_eq!(retrieved.user_id, "user-1");
    }

    // `resolve` maps a valid token to its user; a missing or unknown token
    // yields an unauthenticated context.
    #[test]
    fn session_store_resolve() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));

        let anon = store.resolve(None);
        assert!(!anon.is_authenticated());

        let bad = store.resolve(Some("invalid-token"));
        assert!(!bad.is_authenticated());
    }

    // Revocation removes the session; revoking the same token again reports false.
    #[test]
    fn session_store_revoke() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());

        assert!(store.revoke(&session.token));
        assert!(store.get(&session.token).is_none());
        // Second revoke of an already-removed token.
        assert!(!store.revoke(&session.token));
    }

    // A session converts to a context for the same user.
    #[test]
    fn session_to_auth_context() {
        let session = Session::new("user-42".into());
        let ctx = session.to_auth_context();
        assert_eq!(ctx.user_id, Some("user-42".into()));
    }
2523
    // The admin context is flagged admin and counts as authenticated.
    #[test]
    fn admin_context() {
        let ctx = AuthContext::admin();
        assert!(ctx.is_admin);
        assert!(ctx.is_authenticated());
    }

    // Anonymous contexts never carry the admin flag.
    #[test]
    fn anonymous_not_admin() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_admin);
    }

    // Plain authenticated contexts are not admin either.
    #[test]
    fn authenticated_not_admin() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_admin);
    }
2544
    // Issued codes are six characters and verify for the same email.
    #[test]
    fn magic_code_create_and_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert_eq!(code.len(), 6);
        assert!(store.verify("test@example.com", &code));
    }

    // A guessed code is rejected. This relies on the issued code not
    // happening to be "000000" (1-in-a-million chance).
    #[test]
    fn magic_code_wrong_code_rejected() {
        let store = MagicCodeStore::new();
        store.create("test@example.com");
        assert!(!store.verify("test@example.com", "000000"));
    }

    // A code is bound to the email it was issued for.
    #[test]
    fn magic_code_wrong_email_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(!store.verify("other@example.com", &code));
    }

    // Codes are single-use: the second verify fails.
    #[test]
    fn magic_code_consumed_after_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(store.verify("test@example.com", &code));
        assert!(!store.verify("test@example.com", &code));
    }

    // Verifying one email's code leaves another email's pending code intact.
    #[test]
    fn magic_code_different_emails_independent() {
        let store = MagicCodeStore::new();
        let code1 = store.create("alice@example.com");
        let code2 = store.create("bob@example.com");
        assert!(store.verify("alice@example.com", &code1));
        assert!(store.verify("bob@example.com", &code2));
    }
2587
    // Equal inputs (including two empty slices) compare equal.
    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(constant_time_eq(b"", b""));
    }

    // Differing content or differing length compares unequal.
    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"a", b"b"));
    }

    // Tokens are distinct, `pylon_`-prefixed, and 6 + 64 chars long
    // (32 random bytes as hex).
    #[test]
    fn generated_tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert!(t1.starts_with("pylon_"));
        assert!(t2.starts_with("pylon_"));
        assert_eq!(t1.len(), 6 + 64);
    }
2615
    // A fresh registry knows no providers.
    #[test]
    fn oauth_registry_empty() {
        let reg = OAuthRegistry::new();
        assert!(reg.get("google").is_none());
    }

    // A registered config is retrievable by provider id and resolves the
    // provider's authorization endpoint.
    #[test]
    fn oauth_registry_register_and_get() {
        let mut reg = OAuthRegistry::new();
        reg.register(OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/callback".into(),
            ..Default::default()
        });
        let config = reg.get("google").unwrap();
        assert_eq!(config.client_id, "test-id");
        assert!(config.auth_url().contains("accounts.google.com"));
    }
2638
    // Smoke test over every built-in provider spec. Each must produce non-empty
    // auth and token URLs, and must put the client id under the spec's own
    // `client_id_param`. Microsoft gets a tenant and Apple gets signing config
    // so their URL builders have what they need.
    #[test]
    fn every_builtin_provider_routes_through_oauth_config() {
        for spec in provider::builtin::all() {
            let cfg = OAuthConfig {
                provider: spec.id.into(),
                client_id: "cid".into(),
                client_secret: "csecret".into(),
                redirect_uri: "https://app/cb".into(),
                tenant: if spec.id == "microsoft" {
                    Some("contoso".into())
                } else {
                    None
                },
                apple: if spec.id == "apple" {
                    Some(provider::AppleConfig {
                        team_id: "T".into(),
                        key_id: "K".into(),
                        private_key_pem: "no".into(),
                    })
                } else {
                    None
                },
                ..Default::default()
            };
            let auth = cfg.auth_url();
            assert!(!auth.is_empty(), "{}: empty auth_url", spec.id);
            // Some providers name the client id parameter differently (e.g. TikTok's `client_key`).
            let expected_param = format!("{}=cid", spec.client_id_param);
            assert!(
                auth.contains(&expected_param),
                "{}: missing {}; got auth_url: {}",
                spec.id,
                expected_param,
                auth,
            );
            assert!(!cfg.token_url().is_empty(), "{}: empty token_url", spec.id);
            if spec.id == "apple" {
                assert!(
                    auth.contains("response_mode=form_post"),
                    "apple auth_url must include response_mode=form_post; got {auth}"
                );
            }
        }
    }
2691
    // A configured Microsoft tenant is substituted into both endpoint paths.
    #[test]
    fn microsoft_tenant_placeholder_resolves() {
        let cfg = OAuthConfig {
            provider: "microsoft".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            redirect_uri: "https://app/cb".into(),
            tenant: Some("contoso.onmicrosoft.com".into()),
            ..Default::default()
        };
        assert!(cfg.auth_url().contains("/contoso.onmicrosoft.com/"));
        assert!(cfg.token_url().contains("/contoso.onmicrosoft.com/"));
    }

    // Without a tenant, Microsoft endpoints fall back to `common`.
    #[test]
    fn microsoft_default_tenant_common() {
        let cfg = OAuthConfig {
            provider: "microsoft".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            redirect_uri: "https://app/cb".into(),
            ..Default::default()
        };
        assert!(cfg.auth_url().contains("/common/"));
        assert!(cfg.token_url().contains("/common/"));
    }

    // `scopes_override` replaces the spec's scopes and is percent-encoded
    // (space -> %20, ':' -> %3A).
    #[test]
    fn scopes_override_replaces_spec_default() {
        let cfg = OAuthConfig {
            provider: "github".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            redirect_uri: "https://app/cb".into(),
            scopes_override: Some("repo user:email".into()),
            ..Default::default()
        };
        let auth = cfg.auth_url();
        assert!(auth.contains("scope=repo%20user%3Aemail"), "got: {auth}");
    }
2738
    // Apple code exchange without Apple signing config fails with an explanatory error.
    #[test]
    fn apple_exchange_requires_apple_config() {
        let cfg = OAuthConfig {
            provider: "apple".into(),
            client_id: "com.example.app".into(),
            client_secret: String::new(),
            redirect_uri: "https://app/cb".into(),
            // No team/key/private-key configured.
            apple: None,
            ..Default::default()
        };
        let err = cfg.exchange_code_full("x").unwrap_err();
        assert!(err.contains("apple provider requires"), "got: {err}");
    }

    // With `oidc_issuer` set, the endpoints come from the (test-seeded)
    // discovery cache rather than a built-in spec.
    #[test]
    fn oidc_issuer_uses_discovered_endpoints() {
        let issuer = "https://acme.test.invalid";
        provider::oidc_cache::insert_for_test(
            issuer,
            provider::DiscoveredSpec {
                auth_url: "https://acme.test.invalid/authorize".into(),
                token_url: "https://acme.test.invalid/oauth/token".into(),
                userinfo_url: Some("https://acme.test.invalid/userinfo".into()),
                scopes: "openid email profile".into(),
                userinfo_parser: provider::UserinfoParser::Oidc,
                token_exchange: provider::TokenExchangeShape::Standard,
            },
        );
        let cfg = OAuthConfig {
            provider: "auth0".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            redirect_uri: "https://app/cb".into(),
            oidc_issuer: Some(issuer.into()),
            ..Default::default()
        };
        assert!(cfg
            .auth_url()
            .starts_with("https://acme.test.invalid/authorize?"));
        assert_eq!(cfg.token_url(), "https://acme.test.invalid/oauth/token");
        assert_eq!(cfg.userinfo_url(), "https://acme.test.invalid/userinfo");
    }
2789
    // Apple's auth URL requests `response_mode=form_post`, and Apple has no
    // userinfo endpoint (empty URL).
    #[test]
    fn apple_auth_url_includes_form_post() {
        let cfg = OAuthConfig {
            provider: "apple".into(),
            client_id: "com.example.app".into(),
            client_secret: String::new(),
            redirect_uri: "https://app/cb".into(),
            apple: Some(provider::AppleConfig {
                team_id: "T".into(),
                key_id: "K".into(),
                private_key_pem: "no".into(),
            }),
            ..Default::default()
        };
        let auth = cfg.auth_url();
        assert!(auth.contains("response_mode=form_post"), "got: {auth}");
        assert_eq!(cfg.userinfo_url(), "");
    }

    // Apple identity is read from the id_token claims. The test builds an
    // unsigned token (header.claims.placeholder) and checks `sub`/`email`
    // are extracted. Without an id_token, userinfo lookup must fail.
    #[test]
    fn apple_id_token_decode_extracts_identity() {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"{\"alg\":\"none\"}");
        // A `use` applies to its whole block, so the `encode` call above resolves too.
        use base64::Engine;
        let claims = serde_json::json!({
            "iss": "https://appleid.apple.com",
            "sub": "001234.abc.def",
            "aud": "com.example.app",
            "email": "user@privaterelay.appleid.com",
            "email_verified": "true",
        });
        let claims_b64 =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(claims.to_string().as_bytes());
        let id_token = format!("{header}.{claims_b64}.signature_ignored");

        let cfg = OAuthConfig {
            provider: "apple".into(),
            client_id: "com.example.app".into(),
            client_secret: String::new(),
            redirect_uri: "https://app/cb".into(),
            apple: Some(provider::AppleConfig {
                team_id: "T".into(),
                key_id: "K".into(),
                private_key_pem: "no".into(),
            }),
            ..Default::default()
        };
        let info = cfg
            .fetch_userinfo_with_id_token("ignored", Some(&id_token))
            .expect("apple id_token decode");
        assert_eq!(info.provider_account_id, "001234.abc.def");
        assert_eq!(info.email, "user@privaterelay.appleid.com");

        // The access-token-only path is not supported for Apple.
        let err = cfg.fetch_userinfo_full("token").unwrap_err();
        assert!(err.contains("apple login requires"), "got: {err}");
    }
2858
/// Twitter's authorize URL must carry an S256 PKCE challenge and hand back a
/// spec-conformant code_verifier; providers without PKCE (Google) must not.
#[test]
fn twitter_auth_url_includes_pkce() {
    let cfg = OAuthConfig {
        provider: "twitter".into(),
        client_id: "tw_client".into(),
        client_secret: "tw_secret".into(),
        redirect_uri: "https://app/cb".into(),
        ..Default::default()
    };
    let (url, verifier) = cfg.auth_url_with_pkce("state123").expect("twitter pkce");
    let v = verifier.expect("twitter must produce verifier");

    // RFC 7636 §4.1: code_verifier is 43..=128 characters drawn only from the
    // unreserved set ALPHA / DIGIT / "-" / "." / "_" / "~".
    assert!(v.len() >= 43, "PKCE verifier must be 43+ chars: got {v}");
    assert!(v.len() <= 128, "PKCE verifier must be at most 128 chars: got {v}");
    assert!(
        v.chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '.' | '_' | '~')),
        "PKCE verifier must use only unreserved characters: got {v}"
    );

    assert!(url.contains("code_challenge="), "got: {url}");
    assert!(url.contains("code_challenge_method=S256"), "got: {url}");
    // With S256 only the hash of the verifier goes on the front channel; the
    // verifier itself is the secret redeemed at the token endpoint and must
    // never appear in the authorize URL.
    assert!(!url.contains(&v), "verifier leaked into authorize URL: {url}");

    let google = OAuthConfig {
        provider: "google".into(),
        client_id: "g".into(),
        client_secret: "g".into(),
        redirect_uri: "https://app/cb".into(),
        ..Default::default()
    };
    let (gurl, gverifier) = google.auth_url_with_pkce("st").expect("google");
    assert!(gverifier.is_none(), "google should not add PKCE");
    assert!(!gurl.contains("code_challenge"), "got: {gurl}");
}
2889
#[test]
fn tiktok_uses_client_key_and_comma_scopes() {
    let tiktok = OAuthConfig {
        provider: "tiktok".into(),
        client_id: "tk_client".into(),
        client_secret: "tk_secret".into(),
        redirect_uri: "https://app/cb".into(),
        scopes_override: Some("user.info.basic video.list".into()),
        ..Default::default()
    };
    let url = tiktok.auth_url();

    // The configured client id is sent under the `client_key` parameter name.
    assert!(url.contains("client_key=tk_client"), "got: {url}");

    // A space-separated scope override is re-joined with commas (%2C) rather
    // than passed through as URL-encoded spaces (%20).
    let comma_joined = url.contains("user.info.basic%2Cvideo.list");
    let space_joined = url.contains("user.info.basic%20video.list");
    assert!(comma_joined, "got: {url}");
    assert!(!space_joined, "got: {url}");
}
2912
#[test]
fn token_exchange_url_encodes_code() {
    // Each reserved character must be percent-escaped, never passed through raw.
    let encoded = url_encode("code+with/special=chars");
    for (reserved, escaped) in [('+', "%2B"), ('/', "%2F"), ('=', "%3D")] {
        assert!(!encoded.contains(reserved));
        assert!(encoded.contains(escaped));
    }
}
2932
#[test]
fn sanitize_token_error_redacts_secrets() {
    let raw = "HTTP 400: error=invalid_grant&client_secret=sk_real_secret_value&code_verifier=verifierxyz&hint=check%20your%20code";
    let scrubbed = sanitize_token_error(raw.into());

    // Secret values from the form-encoded body must be gone...
    for leaked in ["sk_real_secret_value", "verifierxyz"] {
        assert!(!scrubbed.contains(leaked));
    }
    // ...replaced by `***` masks, while non-secret diagnostics survive intact.
    for kept in [
        "client_secret=***",
        "code_verifier=***",
        "invalid_grant",
        "hint=check%20your%20code",
    ] {
        assert!(scrubbed.contains(kept));
    }
}
2948
#[test]
fn sanitize_token_error_redacts_json_secrets() {
    let raw = r#"HTTP 400: {"error":"invalid_grant","client_secret":"sk_jsonleak","refresh_token":"rt_abcxyz","id_token":"ey.payload.sig"}"#;
    let scrubbed = sanitize_token_error(raw.into());

    // JSON-shaped bodies: every secret value is removed...
    for secret in ["sk_jsonleak", "rt_abcxyz", "ey.payload.sig"] {
        assert!(!scrubbed.contains(secret), "got: {scrubbed}");
    }
    // ...and replaced in place by a quoted mask, keeping the key visible.
    for masked in [
        r#""client_secret":"***""#,
        r#""refresh_token":"***""#,
        r#""id_token":"***""#,
    ] {
        assert!(scrubbed.contains(masked), "got: {scrubbed}");
    }
    assert!(scrubbed.contains("invalid_grant"));
}
2969
#[test]
fn sanitize_token_error_handles_utf8() {
    // Multi-byte characters ahead of the secret must not break the scrubber.
    let input = "HTTP 400: ⚠️ provider says the secret is wrong: client_secret=sk_x";
    let output = sanitize_token_error(input.into());

    assert!(
        output.contains("⚠️"),
        "non-ASCII chars must survive: {output}"
    );
    assert!(!output.contains("sk_x"));
    assert!(output.contains("client_secret=***"));
}
2985
#[test]
fn oidc_discovery_picks_token_auth_method() {
    // Advertises only `client_secret_post`: credentials go in the form body.
    const POST_ONLY: &str = r#"{
        "issuer": "https://acme.test/",
        "authorization_endpoint": "https://acme.test/auth",
        "token_endpoint": "https://acme.test/token",
        "token_endpoint_auth_methods_supported": ["client_secret_post"]
    }"#;
    // No auth methods advertised: falls back to HTTP Basic, matching the
    // OpenID Connect Discovery default of `client_secret_basic`.
    const UNSPECIFIED: &str = r#"{
        "issuer": "https://acme.test/",
        "authorization_endpoint": "https://acme.test/auth",
        "token_endpoint": "https://acme.test/token"
    }"#;

    let post_spec = provider::OidcDiscoveryDoc::parse(POST_ONLY)
        .unwrap()
        .into_spec();
    assert!(matches!(
        post_spec.token_exchange,
        provider::TokenExchangeShape::Standard
    ));

    let fallback_spec = provider::OidcDiscoveryDoc::parse(UNSPECIFIED)
        .unwrap()
        .into_spec();
    assert!(matches!(
        fallback_spec.token_exchange,
        provider::TokenExchangeShape::BasicAuth
    ));
}
3020
#[test]
fn oidc_discovery_rejects_incomplete_doc() {
    // A discovery document without `token_endpoint` is unusable, and the
    // parse error must name the missing field.
    const MISSING_TOKEN_ENDPOINT: &str = r#"{
        "issuer": "https://acme.test/",
        "authorization_endpoint": "https://acme.test/auth"
    }"#;
    let message = provider::OidcDiscoveryDoc::parse(MISSING_TOKEN_ENDPOINT).unwrap_err();
    assert!(message.contains("token_endpoint"), "got: {message}");
}
3033
/// `OAuthRegistry::from_env` registers Discord when its client id/secret
/// environment variables are present.
///
/// The variables are set through an RAII guard. The original test removed
/// them only after its assertions passed, so a failing assertion leaked them
/// into every later test in the process. It also clobbered any value the
/// developer had already exported. The guard restores the previous state on
/// drop, including during a panic unwind.
///
/// NOTE(review): mutating the process environment is still racy if another
/// test reads these variables concurrently. Consider serializing env-touching
/// tests if that happens.
#[test]
fn from_env_picks_up_discord() {
    /// Sets an environment variable and restores its prior value (or removes
    /// it if it was unset) when dropped.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    let _id_guard = EnvVarGuard::set("PYLON_OAUTH_DISCORD_CLIENT_ID", "discord-test-id");
    let _secret_guard =
        EnvVarGuard::set("PYLON_OAUTH_DISCORD_CLIENT_SECRET", "discord-test-secret");

    let reg = OAuthRegistry::from_env();
    let discord = reg.get("discord").expect("discord registered");
    assert_eq!(discord.client_id, "discord-test-id");
    assert!(discord.auth_url().contains("discord.com"));
    // Guards drop here, in reverse order, restoring the environment.
}
3057
#[test]
fn guest_session() {
    let store = SessionStore::new();
    let guest = store.create_guest();
    assert!(guest.user_id.starts_with("guest_"));
    assert!(!guest.token.is_empty());

    // NOTE(review): `guest_context` asserts that `AuthContext::guest(..)` is NOT
    // authenticated, yet a guest session resolved through the store is expected
    // to be authenticated here. Confirm that `SessionStore::resolve` is meant to
    // return a non-guest context for guest sessions. If it is not, guests may
    // pass `AuthMode::User` checks.
    let resolved = store.resolve(Some(&guest.token));
    assert!(resolved.is_authenticated());
    let resolved_id = resolved.user_id.unwrap();
    assert!(resolved_id.starts_with("guest_"));
}
3071
#[test]
fn upgrade_guest_to_real_user() {
    let store = SessionStore::new();
    let guest = store.create_guest();
    assert!(guest.user_id.starts_with("guest_"));

    // Upgrading keeps the same token but rebinds it to the real user id.
    assert!(store.upgrade(&guest.token, "real-user-123".into()));
    assert_eq!(
        store.resolve(Some(&guest.token)).user_id,
        Some("real-user-123".into())
    );
}
3084
#[test]
fn upgrade_invalid_token_fails() {
    // Upgrading a token the store never issued reports failure.
    let sessions = SessionStore::new();
    let accepted = sessions.upgrade("nonexistent-token", "user".into());
    assert!(!accepted);
}
3091
#[test]
fn guest_context() {
    // A guest context carries an id but is flagged as a guest, is not an admin,
    // and is not considered authenticated.
    let guest = AuthContext::guest("guest_123".into());
    assert!(guest.is_guest);
    assert!(!guest.is_admin);
    assert!(!guest.is_authenticated());
    assert_eq!(guest.user_id.as_deref(), Some("guest_123"));

    // Consequently it is admitted by public routes but refused by user routes.
    assert!(AuthMode::Public.check(&guest));
    assert!(!AuthMode::User.check(&guest));
}
3104
#[test]
fn oauth_token_urls() {
    // Only the provider name matters for endpoint lookup; credentials are dummies.
    let cfg_for = |provider: &'static str| OAuthConfig {
        provider: provider.into(),
        client_id: "x".into(),
        client_secret: "x".into(),
        redirect_uri: "x".into(),
        ..Default::default()
    };

    assert_eq!(
        cfg_for("google").token_url(),
        "https://oauth2.googleapis.com/token"
    );
    assert_eq!(
        cfg_for("github").token_url(),
        "https://github.com/login/oauth/access_token"
    );

    // Unrecognised providers yield empty endpoints rather than a guess.
    let unknown = cfg_for("unknown");
    assert_eq!(unknown.token_url(), "");
    assert!(unknown.auth_url().is_empty());
}
3136
#[test]
fn oauth_auth_url_github() {
    let github = OAuthConfig {
        provider: "github".into(),
        client_id: "gh-id".into(),
        client_secret: "gh-secret".into(),
        redirect_uri: "http://localhost/cb".into(),
        ..Default::default()
    };
    // The authorize URL points at GitHub and embeds the client id.
    let authorize_url = github.auth_url();
    assert!(authorize_url.contains("github.com"));
    assert!(authorize_url.contains("gh-id"));
}
3149
#[test]
fn oauth_auth_url_with_state() {
    const STATE: &str = "random_state_123";
    let google = OAuthConfig {
        provider: "google".into(),
        client_id: "test-id".into(),
        client_secret: "test-secret".into(),
        redirect_uri: "http://localhost/cb".into(),
        ..Default::default()
    };
    // The CSRF state is appended as its own query parameter.
    let authorize_url = google.auth_url_with_state(STATE);
    assert!(authorize_url.contains(&format!("&state={STATE}")));
}
3162
#[test]
fn oauth_state_store_create_and_validate() {
    let states = OAuthStateStore::new();
    let state = states.create("google", "https://app/cb", "https://app/login");

    // First validation succeeds and returns both stored redirect targets.
    let record = states.validate(&state, "google").expect("valid first time");
    assert_eq!(record.callback_url, "https://app/cb");
    assert_eq!(record.error_callback_url, "https://app/login");

    // A second validation of the same state is rejected (single use).
    let replay = states.validate(&state, "google");
    assert!(replay.is_none());
}
3173
#[test]
fn oauth_state_store_wrong_provider_rejected() {
    // A state issued for one provider cannot be redeemed on another provider's callback.
    let states = OAuthStateStore::new();
    let google_state = states.create("google", "https://app/cb", "https://app/cb");
    let cross_provider = states.validate(&google_state, "github");
    assert!(cross_provider.is_none());
}
3180
#[test]
fn oauth_state_store_invalid_state_rejected() {
    // A state value the store never issued is rejected outright.
    let states = OAuthStateStore::new();
    let lookup = states.validate("nonexistent", "google");
    assert!(lookup.is_none());
}
3186
#[test]
fn validate_trusted_redirect_basics() {
    let trusted = vec!["http://localhost:3000".to_string()];

    // Any path or query under a trusted origin, or the bare origin, is accepted.
    for allowed in [
        "http://localhost:3000/dashboard",
        "http://localhost:3000",
        "http://localhost:3000/x?y=1",
    ] {
        assert!(validate_trusted_redirect(allowed, &trusted).is_ok());
    }

    // Same host on a different port is a different origin, so it is not trusted.
    let other_port = validate_trusted_redirect("http://localhost:4321/dashboard", &trusted);
    assert!(matches!(other_port, Err(TrustedOriginError::NotTrusted { .. })));

    // Non-HTTP schemes are refused before any origin comparison.
    let script = validate_trusted_redirect("javascript:alert(1)", &trusted);
    assert!(matches!(script, Err(TrustedOriginError::NotHttp)));

    // An empty target is its own error.
    let empty = validate_trusted_redirect("", &trusted);
    assert!(matches!(empty, Err(TrustedOriginError::Empty)));
}
3210}