1use std::sync::RwLock;
23use std::time::{Duration, Instant};
24
25use serde::Deserialize;
26use url::Url;
27
28use crate::auth_client::AuthFuture;
29use crate::error::AuthError;
30use crate::social_providers::{ProviderType, SocialProvider, SocialProviderConfig, SocialUserInfo};
31
/// Default cache lifetime (one hour) assigned to each provider's `ttl`; a cached
/// discovery document older than this is refetched by `current_endpoints`.
pub const DISCOVERY_TTL: Duration = Duration::from_secs(60 * 60);

/// The three endpoints this provider talks to, either pinned from config or
/// taken from a discovery document.
#[derive(Debug, Clone)]
pub struct DiscoveryDoc {
    /// Browser-facing authorization endpoint; `authorize_url` appends the query to it.
    pub authorize_url: String,
    /// Endpoint that `exchange_code` POSTs the authorization code to.
    pub token_url: String,
    /// Endpoint that `fetch_user_info` GETs with the bearer access token.
    pub userinfo_url: String,
}
53
/// Wire shape of the three endpoint fields of a discovery document, using the
/// JSON key names. Converted into [`DiscoveryDoc`] through its `From` impl.
// NOTE(review): `fetch_discovery` reads these keys by hand from a
// `serde_json::Value` instead of deserializing into this type, so within this
// file it appears to be used only by the `From` impl and its test.
#[derive(Deserialize)]
struct DiscoveryDocRaw {
    authorization_endpoint: String,
    token_endpoint: String,
    userinfo_endpoint: String,
}
63
64impl From<DiscoveryDocRaw> for DiscoveryDoc {
65 fn from(raw: DiscoveryDocRaw) -> Self {
66 Self {
67 authorize_url: raw.authorization_endpoint,
68 token_url: raw.token_endpoint,
69 userinfo_url: raw.userinfo_endpoint,
70 }
71 }
72}
73
/// Cached discovery result held behind the provider's `RwLock`.
#[derive(Debug)]
struct DiscoveryCache {
    /// Endpoints served to callers.
    doc: DiscoveryDoc,
    /// Monotonic timestamp compared against the provider's `ttl` to decide
    /// when the document should be refetched.
    refreshed_at: Instant,
}
85
/// Outcome of parsing a `custom_oidc` config object: where the endpoints come from.
#[derive(Debug)]
enum CustomOidcEndpoints {
    /// Endpoints must be fetched from this discovery-document URL.
    Discovery(String),
    /// All three endpoints were supplied explicitly; no network lookup needed.
    Pinned(DiscoveryDoc),
}
96
97fn parse_custom_oidc_config(
98 value: Option<&serde_json::Value>,
99) -> Result<CustomOidcEndpoints, AuthError> {
100 let val = value
101 .ok_or_else(|| AuthError::Validation("custom_oidc requires a config object".into()))?;
102
103 let obj = val
104 .as_object()
105 .ok_or_else(|| AuthError::Validation("custom_oidc config must be a JSON object".into()))?;
106
107 let discovery_url = get_opt_str(obj, "discovery_url")?;
108 let authorize_url = get_opt_str(obj, "authorize_url")?;
109 let token_url = get_opt_str(obj, "token_url")?;
110 let userinfo_url = get_opt_str(obj, "userinfo_url")?;
111
112 let has_explicit = authorize_url.is_some() || token_url.is_some() || userinfo_url.is_some();
113
114 if has_explicit {
115 let authorize = authorize_url.ok_or_else(|| {
117 AuthError::Validation("missing one of authorize_url/token_url/userinfo_url".into())
118 })?;
119 let token = token_url.ok_or_else(|| {
120 AuthError::Validation("missing one of authorize_url/token_url/userinfo_url".into())
121 })?;
122 let userinfo = userinfo_url.ok_or_else(|| {
123 AuthError::Validation("missing one of authorize_url/token_url/userinfo_url".into())
124 })?;
125
126 validate_url(authorize, "authorize_url")?;
127 validate_url(token, "token_url")?;
128 validate_url(userinfo, "userinfo_url")?;
129
130 if discovery_url.is_some() {
131 tracing::debug!("custom_oidc: explicit endpoints override discovery_url");
132 }
133
134 return Ok(CustomOidcEndpoints::Pinned(DiscoveryDoc {
135 authorize_url: authorize.to_owned(),
136 token_url: token.to_owned(),
137 userinfo_url: userinfo.to_owned(),
138 }));
139 }
140
141 let discovery = discovery_url.ok_or_else(|| {
143 AuthError::Validation(
144 "custom_oidc config requires either discovery_url or all of \
145 authorize_url/token_url/userinfo_url"
146 .into(),
147 )
148 })?;
149
150 validate_url(discovery, "discovery_url")?;
151
152 Ok(CustomOidcEndpoints::Discovery(discovery.to_owned()))
153}
154
155fn get_opt_str<'a>(
156 obj: &'a serde_json::Map<String, serde_json::Value>,
157 field: &str,
158) -> Result<Option<&'a str>, AuthError> {
159 match obj.get(field) {
160 None => Ok(None),
161 Some(v) => v
162 .as_str()
163 .ok_or_else(|| AuthError::Validation(format!("{field} must be a string")))
164 .map(Some),
165 }
166}
167
168fn validate_url(url: &str, field: &str) -> Result<(), AuthError> {
169 Url::parse(url)
170 .map(|_| ())
171 .map_err(|_| AuthError::Validation(format!("{field} is not a valid URL")))
172}
173
174#[derive(Deserialize)]
182struct UserInfoClaims {
183 sub: String,
184 email: Option<String>,
185 email_verified: Option<bool>,
186 name: Option<String>,
187 preferred_username: Option<String>,
188 nickname: Option<String>,
189 picture: Option<String>,
190 avatar_url: Option<String>,
191}
192
193fn map_user_info(claims: UserInfoClaims) -> Result<SocialUserInfo, AuthError> {
194 let email = claims
195 .email
196 .ok_or_else(|| AuthError::OAuthUserInfoFetch("userinfo missing email".into()))?;
197 let name = claims
198 .name
199 .or(claims.preferred_username)
200 .or(claims.nickname);
201 let avatar_url = claims.picture.or(claims.avatar_url);
202 Ok(SocialUserInfo {
203 provider_user_id: claims.sub,
204 email,
205 email_verified: claims.email_verified.unwrap_or(false),
206 name,
207 avatar_url,
208 })
209}
210
/// [`SocialProvider`] for an arbitrary OpenID Connect identity provider whose
/// endpoints are either configured explicitly or read from a discovery document.
///
/// Built by [`CustomOidcSocialProvider::new`], which populates exactly one of
/// `discovery` and `pinned_endpoints`.
pub struct CustomOidcSocialProvider {
    /// OAuth client identifier sent to the IdP.
    client_id: String,
    /// OAuth client secret; sent only in the token request and redacted from `Debug`.
    client_secret: String,
    /// Scopes requested in the authorization URL (space-joined).
    scopes: Vec<String>,
    /// HTTP client (15 s timeout) used for discovery, token and userinfo requests.
    http: reqwest::Client,

    /// Cached discovery document; `Some` only when built from a `discovery_url`.
    discovery: Option<RwLock<DiscoveryCache>>,

    /// Explicitly configured endpoints; `Some` only when all three URLs were given.
    pinned_endpoints: Option<DiscoveryDoc>,

    /// Discovery document URL kept for refreshes; `Some` exactly when `discovery` is.
    discovery_url: Option<String>,

    /// Age after which the cached discovery document is refetched; set to
    /// `DISCOVERY_TTL` by `new`, overridable in tests via `set_discovery_ttl`.
    ttl: Duration,
}
247
248impl std::fmt::Debug for CustomOidcSocialProvider {
251 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 f.debug_struct("CustomOidcSocialProvider")
253 .field("client_id", &self.client_id)
254 .field("client_secret", &"[redacted]")
255 .field("scopes", &self.scopes)
256 .field("discovery_url", &self.discovery_url)
257 .field("ttl", &self.ttl)
258 .finish_non_exhaustive()
259 }
260}
261
262impl CustomOidcSocialProvider {
265 pub async fn new(config: SocialProviderConfig) -> Result<Self, AuthError> {
274 if config.provider_type != ProviderType::CustomOidc {
275 return Err(AuthError::Validation(
276 "provider_type mismatch: expected CustomOidc".into(),
277 ));
278 }
279 if config.scopes.is_empty() {
280 return Err(AuthError::Validation("scopes must not be empty".into()));
281 }
282
283 let endpoints = parse_custom_oidc_config(config.config.as_ref())?;
284
285 let http = reqwest::Client::builder()
286 .user_agent("allowthem-oauth")
287 .timeout(Duration::from_secs(15))
288 .build()
289 .map_err(|e| AuthError::Validation(format!("reqwest client build failed: {e}")))?;
290
291 match endpoints {
292 CustomOidcEndpoints::Discovery(url) => {
293 let doc = fetch_discovery(&http, &url).await?;
294 let cache = RwLock::new(DiscoveryCache {
295 doc,
296 refreshed_at: Instant::now(),
297 });
298 Ok(Self {
299 client_id: config.client_id,
300 client_secret: config.client_secret,
301 scopes: config.scopes,
302 http,
303 discovery: Some(cache),
304 pinned_endpoints: None,
305 discovery_url: Some(url),
306 ttl: DISCOVERY_TTL,
307 })
308 }
309 CustomOidcEndpoints::Pinned(doc) => Ok(Self {
310 client_id: config.client_id,
311 client_secret: config.client_secret,
312 scopes: config.scopes,
313 http,
314 discovery: None,
315 pinned_endpoints: Some(doc),
316 discovery_url: None,
317 ttl: DISCOVERY_TTL,
318 }),
319 }
320 }
321
322 async fn current_endpoints(&self) -> Result<DiscoveryDoc, AuthError> {
329 if let Some(ref doc) = self.pinned_endpoints {
330 return Ok(doc.clone());
331 }
332
333 let lock = self
334 .discovery
335 .as_ref()
336 .expect("discovery cache populated by new()");
337
338 let (doc, refreshed_at) = {
340 let g = lock.read().unwrap();
341 (g.doc.clone(), g.refreshed_at)
342 };
343
344 if refreshed_at.elapsed() > self.ttl {
345 let url = self
346 .discovery_url
347 .as_deref()
348 .expect("discovery_url set whenever discovery cache is set");
349 match fetch_discovery(&self.http, url).await {
350 Ok(new_doc) => {
351 let mut g = lock.write().unwrap();
352 *g = DiscoveryCache {
353 doc: new_doc.clone(),
354 refreshed_at: Instant::now(),
355 };
356 return Ok(new_doc);
357 }
358 Err(e) => {
359 tracing::warn!(
360 "custom_oidc: discovery refresh failed: {e}; serving stale endpoints"
361 );
362 }
363 }
364 }
365
366 Ok(doc)
367 }
368
369 fn cached_endpoints_or_fail(&self) -> DiscoveryDoc {
380 if let Some(ref doc) = self.pinned_endpoints {
381 return doc.clone();
382 }
383 self.discovery
384 .as_ref()
385 .expect("discovery cache populated by new()")
386 .read()
387 .unwrap()
388 .doc
389 .clone()
390 }
391
392 #[cfg(test)]
394 pub(crate) fn set_discovery_ttl(&mut self, ttl: Duration) {
395 self.ttl = ttl;
396 }
397}
398
399async fn fetch_discovery(http: &reqwest::Client, url: &str) -> Result<DiscoveryDoc, AuthError> {
402 let resp = http
403 .get(url)
404 .header("Accept", "application/json")
405 .send()
406 .await
407 .map_err(|e| AuthError::OAuthHttp(format!("discovery fetch failed: {e}")))?;
408
409 let status = resp.status();
410 if !status.is_success() {
411 return Err(AuthError::OAuthHttp(format!("discovery fetch {status}")));
412 }
413
414 let json: serde_json::Value = resp
417 .json()
418 .await
419 .map_err(|e| AuthError::OAuthHttp(format!("discovery fetch JSON parse failed: {e}")))?;
420
421 let obj = json
422 .as_object()
423 .ok_or_else(|| AuthError::OAuthHttp("discovery document is not a JSON object".into()))?;
424
425 let authorize_url = obj
426 .get("authorization_endpoint")
427 .and_then(|v| v.as_str())
428 .ok_or_else(|| {
429 AuthError::Validation("discovery doc is missing authorization_endpoint".into())
430 })?
431 .to_owned();
432
433 let token_url = obj
434 .get("token_endpoint")
435 .and_then(|v| v.as_str())
436 .ok_or_else(|| AuthError::Validation("discovery doc is missing token_endpoint".into()))?
437 .to_owned();
438
439 let userinfo_url = obj
440 .get("userinfo_endpoint")
441 .and_then(|v| v.as_str())
442 .ok_or_else(|| AuthError::Validation("discovery doc is missing userinfo_endpoint".into()))?
443 .to_owned();
444
445 Ok(DiscoveryDoc {
446 authorize_url,
447 token_url,
448 userinfo_url,
449 })
450}
451
452impl SocialProvider for CustomOidcSocialProvider {
455 fn provider_type(&self) -> ProviderType {
456 ProviderType::CustomOidc
457 }
458
459 fn authorize_url(&self, redirect_uri: &str, state: &str, pkce_challenge: &str) -> String {
466 let endpoints = self.cached_endpoints_or_fail();
467 let mut url = Url::parse(&endpoints.authorize_url)
468 .expect("authorize_url is a valid URL — validated in new()");
469 url.query_pairs_mut()
470 .append_pair("client_id", &self.client_id)
471 .append_pair("redirect_uri", redirect_uri)
472 .append_pair("response_type", "code")
473 .append_pair("scope", &self.scopes.join(" "))
474 .append_pair("state", state)
475 .append_pair("code_challenge", pkce_challenge)
476 .append_pair("code_challenge_method", "S256");
477 url.into()
478 }
479
480 fn exchange_code<'a>(
481 &'a self,
482 code: &'a str,
483 redirect_uri: &'a str,
484 pkce_verifier: &'a str,
485 ) -> AuthFuture<'a, String> {
486 Box::pin(async move {
487 let endpoints = self.current_endpoints().await?;
488 let resp = self
489 .http
490 .post(&endpoints.token_url)
491 .header("Accept", "application/json")
492 .form(&[
493 ("code", code),
494 ("client_id", self.client_id.as_str()),
495 ("client_secret", self.client_secret.as_str()),
496 ("redirect_uri", redirect_uri),
497 ("grant_type", "authorization_code"),
498 ("code_verifier", pkce_verifier),
499 ])
500 .send()
501 .await
502 .map_err(|e| AuthError::OAuthHttp(format!("{e}")))?;
503
504 let status = resp.status();
505 if !status.is_success() {
506 let body = resp.text().await.unwrap_or_default();
507 return Err(AuthError::OAuthTokenExchange(format!("{status}: {body}")));
508 }
509
510 let json: serde_json::Value = resp
511 .json()
512 .await
513 .map_err(|e| AuthError::OAuthHttp(format!("{e}")))?;
514
515 json.get("access_token")
516 .and_then(|v| v.as_str())
517 .map(|s| s.to_owned())
518 .ok_or_else(|| {
519 AuthError::OAuthTokenExchange("token response missing access_token".into())
520 })
521 })
522 }
523
524 fn fetch_user_info<'a>(&'a self, access_token: &'a str) -> AuthFuture<'a, SocialUserInfo> {
525 Box::pin(async move {
526 let endpoints = self.current_endpoints().await?;
527 let resp = self
528 .http
529 .get(&endpoints.userinfo_url)
530 .bearer_auth(access_token)
531 .send()
532 .await
533 .map_err(|e| AuthError::OAuthHttp(format!("{e}")))?;
534
535 let status = resp.status();
536 if !status.is_success() {
537 let body = resp.text().await.unwrap_or_default();
538 return Err(AuthError::OAuthUserInfoFetch(format!("{status}: {body}")));
539 }
540
541 let claims: UserInfoClaims = resp.json().await.map_err(|e| {
542 AuthError::OAuthUserInfoFetch(format!("userinfo parse failed: {e}"))
543 })?;
544
545 map_user_info(claims)
546 })
547 }
548}
549
550#[cfg(test)]
553mod tests {
554 use super::*;
555 use crate::types::SocialProviderId;
556
557 fn pinned_config() -> SocialProviderConfig {
560 SocialProviderConfig {
561 id: SocialProviderId::new(),
562 provider_type: ProviderType::CustomOidc,
563 display_name: "Test OIDC".into(),
564 client_id: "test-client-id".into(),
565 client_secret: "test-client-secret".into(),
566 scopes: vec!["openid".into(), "email".into()],
567 enabled: true,
568 priority: 0,
569 config: Some(serde_json::json!({
570 "authorize_url": "https://idp.example.com/authorize",
571 "token_url": "https://idp.example.com/token",
572 "userinfo_url": "https://idp.example.com/userinfo",
573 })),
574 }
575 }
576
577 fn discovery_config(discovery_url: &str) -> SocialProviderConfig {
578 SocialProviderConfig {
579 id: SocialProviderId::new(),
580 provider_type: ProviderType::CustomOidc,
581 display_name: "Test OIDC".into(),
582 client_id: "test-client-id".into(),
583 client_secret: "test-client-secret".into(),
584 scopes: vec!["openid".into(), "email".into()],
585 enabled: true,
586 priority: 0,
587 config: Some(serde_json::json!({ "discovery_url": discovery_url })),
588 }
589 }
590
591 async fn mount_discovery_doc(server: &wiremock::MockServer) {
592 use wiremock::matchers::{method, path};
593 use wiremock::{Mock, ResponseTemplate};
594
595 let base = server.uri();
596 Mock::given(method("GET"))
597 .and(path("/.well-known/openid-configuration"))
598 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
599 "authorization_endpoint": format!("{base}/authorize"),
600 "token_endpoint": format!("{base}/token"),
601 "userinfo_endpoint": format!("{base}/userinfo"),
602 })))
603 .mount(server)
604 .await;
605 }
606
607 #[test]
610 fn discovery_doc_from_raw_renames_endpoints() {
611 let raw = DiscoveryDocRaw {
612 authorization_endpoint: "https://idp.example.com/authorize".to_owned(),
613 token_endpoint: "https://idp.example.com/token".to_owned(),
614 userinfo_endpoint: "https://idp.example.com/userinfo".to_owned(),
615 };
616 let doc = DiscoveryDoc::from(raw);
617 assert_eq!(doc.authorize_url, "https://idp.example.com/authorize");
618 assert_eq!(doc.token_url, "https://idp.example.com/token");
619 assert_eq!(doc.userinfo_url, "https://idp.example.com/userinfo");
620 }
621
622 #[tokio::test]
625 async fn new_rejects_provider_type_mismatch() {
626 let mut cfg = pinned_config();
627 cfg.provider_type = ProviderType::Google;
628 let err = CustomOidcSocialProvider::new(cfg).await.unwrap_err();
629 assert!(matches!(err, AuthError::Validation(_)));
630 }
631
632 #[tokio::test]
633 async fn new_rejects_empty_scopes() {
634 let mut cfg = pinned_config();
635 cfg.scopes = vec![];
636 let err = CustomOidcSocialProvider::new(cfg).await.unwrap_err();
637 assert!(matches!(err, AuthError::Validation(_)));
638 }
639
640 #[test]
643 fn parse_config_rejects_none() {
644 let err = parse_custom_oidc_config(None).unwrap_err();
645 assert!(matches!(err, AuthError::Validation(ref m) if m.contains("requires a config")));
646 }
647
648 #[test]
649 fn parse_config_rejects_empty_object() {
650 let v = serde_json::json!({});
651 let err = parse_custom_oidc_config(Some(&v)).unwrap_err();
652 assert!(matches!(err, AuthError::Validation(_)));
653 }
654
655 #[test]
656 fn parse_config_rejects_partial_explicit_endpoints() {
657 let v = serde_json::json!({
658 "authorize_url": "https://idp.example.com/authorize",
659 "token_url": "https://idp.example.com/token",
660 });
662 let err = parse_custom_oidc_config(Some(&v)).unwrap_err();
663 assert!(matches!(err, AuthError::Validation(ref m) if m.contains("missing one of")));
664 }
665
666 #[test]
667 fn parse_config_rejects_invalid_url() {
668 let v = serde_json::json!({
669 "authorize_url": "not-a-url",
670 "token_url": "https://idp.example.com/token",
671 "userinfo_url": "https://idp.example.com/userinfo",
672 });
673 let err = parse_custom_oidc_config(Some(&v)).unwrap_err();
674 assert!(matches!(err, AuthError::Validation(ref m) if m.contains("not a valid URL")));
675 }
676
677 #[test]
678 fn parse_config_pinned_endpoints_path() {
679 let v = serde_json::json!({
680 "authorize_url": "https://idp.example.com/authorize",
681 "token_url": "https://idp.example.com/token",
682 "userinfo_url": "https://idp.example.com/userinfo",
683 });
684 let result = parse_custom_oidc_config(Some(&v)).unwrap();
685 assert!(matches!(result, CustomOidcEndpoints::Pinned(_)));
686 }
687
688 #[test]
689 fn parse_config_discovery_path() {
690 let v = serde_json::json!({
691 "discovery_url": "https://idp.example.com/.well-known/openid-configuration",
692 });
693 let result = parse_custom_oidc_config(Some(&v)).unwrap();
694 assert!(matches!(result, CustomOidcEndpoints::Discovery(_)));
695 }
696
697 #[test]
698 fn parse_config_explicit_overrides_discovery() {
699 let v = serde_json::json!({
700 "discovery_url": "https://idp.example.com/.well-known/openid-configuration",
701 "authorize_url": "https://idp.example.com/authorize",
702 "token_url": "https://idp.example.com/token",
703 "userinfo_url": "https://idp.example.com/userinfo",
704 });
705 let result = parse_custom_oidc_config(Some(&v)).unwrap();
707 assert!(matches!(result, CustomOidcEndpoints::Pinned(_)));
708 }
709
710 #[test]
711 fn parse_config_partial_explicit_with_discovery_url_still_errors() {
712 let v = serde_json::json!({
715 "discovery_url": "https://idp.example.com/.well-known/openid-configuration",
716 "authorize_url": "https://idp.example.com/authorize",
717 "token_url": "https://idp.example.com/token",
718 });
720 let err = parse_custom_oidc_config(Some(&v)).unwrap_err();
721 assert!(matches!(err, AuthError::Validation(ref m) if m.contains("missing one of")));
722 }
723
724 #[tokio::test]
727 async fn authorize_url_includes_required_params() {
728 let provider = CustomOidcSocialProvider::new(pinned_config())
729 .await
730 .unwrap();
731 let url =
732 provider.authorize_url("https://app.example.com/callback", "mystate", "mychallenge");
733 assert!(url.contains("client_id=test-client-id"), "url: {url}");
734 assert!(url.contains("redirect_uri="), "url: {url}");
735 assert!(url.contains("response_type=code"), "url: {url}");
736 assert!(url.contains("state=mystate"), "url: {url}");
737 assert!(url.contains("code_challenge=mychallenge"), "url: {url}");
738 assert!(url.contains("code_challenge_method=S256"), "url: {url}");
739 }
740
741 #[tokio::test]
742 async fn authorize_url_uses_config_scopes_joined_by_space() {
743 let provider = CustomOidcSocialProvider::new(pinned_config())
744 .await
745 .unwrap();
746 let url = provider.authorize_url("https://app.example.com/callback", "s", "c");
747 assert!(
749 url.contains("scope=openid+email") || url.contains("scope=openid%20email"),
750 "url: {url}"
751 );
752 }
753
754 #[test]
757 fn userinfo_claim_mapping_uses_name_when_present() {
758 let claims = UserInfoClaims {
759 sub: "u1".into(),
760 email: Some("u@example.com".into()),
761 email_verified: Some(true),
762 name: Some("Alice Smith".into()),
763 preferred_username: Some("alice".into()),
764 nickname: Some("al".into()),
765 picture: None,
766 avatar_url: None,
767 };
768 let info = map_user_info(claims).unwrap();
769 assert_eq!(info.name.as_deref(), Some("Alice Smith"));
770 }
771
772 #[test]
773 fn userinfo_claim_mapping_falls_back_to_preferred_username() {
774 let claims = UserInfoClaims {
775 sub: "u1".into(),
776 email: Some("u@example.com".into()),
777 email_verified: None,
778 name: None,
779 preferred_username: Some("alice".into()),
780 nickname: Some("al".into()),
781 picture: None,
782 avatar_url: None,
783 };
784 let info = map_user_info(claims).unwrap();
785 assert_eq!(info.name.as_deref(), Some("alice"));
786 }
787
788 #[test]
789 fn userinfo_claim_mapping_falls_back_to_nickname() {
790 let claims = UserInfoClaims {
791 sub: "u1".into(),
792 email: Some("u@example.com".into()),
793 email_verified: None,
794 name: None,
795 preferred_username: None,
796 nickname: Some("al".into()),
797 picture: None,
798 avatar_url: None,
799 };
800 let info = map_user_info(claims).unwrap();
801 assert_eq!(info.name.as_deref(), Some("al"));
802 }
803
804 #[test]
805 fn userinfo_claim_mapping_avatar_falls_back_to_avatar_url() {
806 let claims = UserInfoClaims {
807 sub: "u1".into(),
808 email: Some("u@example.com".into()),
809 email_verified: None,
810 name: None,
811 preferred_username: None,
812 nickname: None,
813 picture: None,
814 avatar_url: Some("https://cdn.example.com/u1.png".into()),
815 };
816 let info = map_user_info(claims).unwrap();
817 assert_eq!(
818 info.avatar_url.as_deref(),
819 Some("https://cdn.example.com/u1.png")
820 );
821 }
822
823 #[test]
824 fn userinfo_email_verified_defaults_to_false_when_absent() {
825 let claims = UserInfoClaims {
826 sub: "u1".into(),
827 email: Some("u@example.com".into()),
828 email_verified: None,
829 name: None,
830 preferred_username: None,
831 nickname: None,
832 picture: None,
833 avatar_url: None,
834 };
835 let info = map_user_info(claims).unwrap();
836 assert!(!info.email_verified);
837 }
838
839 #[test]
840 fn userinfo_missing_email_returns_error() {
841 let claims = UserInfoClaims {
842 sub: "u1".into(),
843 email: None,
844 email_verified: None,
845 name: None,
846 preferred_username: None,
847 nickname: None,
848 picture: None,
849 avatar_url: None,
850 };
851 let err = map_user_info(claims).unwrap_err();
852 assert!(matches!(err, AuthError::OAuthUserInfoFetch(ref m) if m.contains("missing email")));
853 }
854
855 #[tokio::test]
858 async fn new_fetches_discovery_doc_on_construct() {
859 let server = wiremock::MockServer::start().await;
860 mount_discovery_doc(&server).await;
861 let base = server.uri();
862 let discovery_url = format!("{base}/.well-known/openid-configuration");
863
864 let provider = CustomOidcSocialProvider::new(discovery_config(&discovery_url))
865 .await
866 .unwrap();
867 let doc = provider.current_endpoints().await.unwrap();
868 assert_eq!(doc.authorize_url, format!("{base}/authorize"));
869 assert_eq!(doc.token_url, format!("{base}/token"));
870 assert_eq!(doc.userinfo_url, format!("{base}/userinfo"));
871 }
872
873 #[tokio::test]
874 async fn new_returns_oauth_http_on_discovery_404() {
875 use wiremock::matchers::{method, path};
876 use wiremock::{Mock, ResponseTemplate};
877
878 let server = wiremock::MockServer::start().await;
879 Mock::given(method("GET"))
880 .and(path("/.well-known/openid-configuration"))
881 .respond_with(ResponseTemplate::new(404))
882 .mount(&server)
883 .await;
884
885 let url = format!("{}/.well-known/openid-configuration", server.uri());
886 let err = CustomOidcSocialProvider::new(discovery_config(&url))
887 .await
888 .unwrap_err();
889 assert!(matches!(err, AuthError::OAuthHttp(_)));
890 }
891
892 #[tokio::test]
893 async fn new_returns_oauth_http_on_discovery_malformed_json() {
894 use wiremock::matchers::{method, path};
895 use wiremock::{Mock, ResponseTemplate};
896
897 let server = wiremock::MockServer::start().await;
898 Mock::given(method("GET"))
899 .and(path("/.well-known/openid-configuration"))
900 .respond_with(ResponseTemplate::new(200).set_body_string("not json at all {{"))
901 .mount(&server)
902 .await;
903
904 let url = format!("{}/.well-known/openid-configuration", server.uri());
905 let err = CustomOidcSocialProvider::new(discovery_config(&url))
906 .await
907 .unwrap_err();
908 assert!(matches!(err, AuthError::OAuthHttp(_)));
909 }
910
911 #[tokio::test]
912 async fn new_returns_validation_on_discovery_doc_missing_endpoint() {
913 use wiremock::matchers::{method, path};
914 use wiremock::{Mock, ResponseTemplate};
915
916 let server = wiremock::MockServer::start().await;
917 Mock::given(method("GET"))
919 .and(path("/.well-known/openid-configuration"))
920 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
921 "authorization_endpoint": "https://idp.example.com/authorize",
922 "token_endpoint": "https://idp.example.com/token",
923 })))
924 .mount(&server)
925 .await;
926
927 let url = format!("{}/.well-known/openid-configuration", server.uri());
928 let err = CustomOidcSocialProvider::new(discovery_config(&url))
929 .await
930 .unwrap_err();
931 assert!(
932 matches!(&err, AuthError::Validation(m) if m.contains("userinfo_endpoint")),
933 "got: {err:?}"
934 );
935 }
936
937 #[tokio::test]
938 async fn discovery_cache_does_not_refetch_within_ttl() {
939 use wiremock::matchers::{method, path};
940 use wiremock::{Mock, ResponseTemplate};
941
942 let server = wiremock::MockServer::start().await;
943 let base = server.uri();
944 Mock::given(method("GET"))
945 .and(path("/.well-known/openid-configuration"))
946 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
947 "authorization_endpoint": format!("{base}/authorize"),
948 "token_endpoint": format!("{base}/token"),
949 "userinfo_endpoint": format!("{base}/userinfo"),
950 })))
951 .expect(1) .mount(&server)
953 .await;
954
955 let url = format!("{base}/.well-known/openid-configuration");
956 let provider = CustomOidcSocialProvider::new(discovery_config(&url))
957 .await
958 .unwrap();
959 let _ = provider.current_endpoints().await.unwrap();
961 let _ = provider.current_endpoints().await.unwrap();
962 }
964
965 #[tokio::test]
966 async fn discovery_cache_refreshes_after_ttl() {
967 use wiremock::matchers::{method, path};
968 use wiremock::{Mock, ResponseTemplate};
969
970 let server = wiremock::MockServer::start().await;
971 let base = server.uri();
972 Mock::given(method("GET"))
973 .and(path("/.well-known/openid-configuration"))
974 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
975 "authorization_endpoint": format!("{base}/authorize"),
976 "token_endpoint": format!("{base}/token"),
977 "userinfo_endpoint": format!("{base}/userinfo"),
978 })))
979 .expect(2) .mount(&server)
981 .await;
982
983 let url = format!("{base}/.well-known/openid-configuration");
984 let mut provider = CustomOidcSocialProvider::new(discovery_config(&url))
985 .await
986 .unwrap();
987 provider.set_discovery_ttl(Duration::from_millis(1));
988
989 tokio::time::sleep(Duration::from_millis(50)).await;
991
992 let _ = provider.current_endpoints().await.unwrap();
993 }
995
996 #[tokio::test]
997 async fn discovery_cache_keeps_stale_doc_on_refresh_failure() {
998 use wiremock::matchers::{method, path};
999 use wiremock::{Mock, ResponseTemplate};
1000
1001 let server = wiremock::MockServer::start().await;
1002 let base = server.uri();
1003
1004 Mock::given(method("GET"))
1007 .and(path("/.well-known/openid-configuration"))
1008 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1009 "authorization_endpoint": format!("{base}/authorize"),
1010 "token_endpoint": format!("{base}/token"),
1011 "userinfo_endpoint": format!("{base}/userinfo"),
1012 })))
1013 .up_to_n_times(1)
1014 .mount(&server)
1015 .await;
1016
1017 Mock::given(method("GET"))
1019 .and(path("/.well-known/openid-configuration"))
1020 .respond_with(ResponseTemplate::new(500))
1021 .mount(&server)
1022 .await;
1023
1024 let url = format!("{base}/.well-known/openid-configuration");
1025 let mut provider = CustomOidcSocialProvider::new(discovery_config(&url))
1026 .await
1027 .unwrap();
1028 provider.set_discovery_ttl(Duration::from_millis(1));
1029
1030 tokio::time::sleep(Duration::from_millis(50)).await;
1031
1032 let doc = provider.current_endpoints().await.unwrap();
1034 assert_eq!(doc.token_url, format!("{base}/token"));
1035 }
1036
1037 #[tokio::test]
1038 async fn pinned_endpoints_never_make_http_requests() {
1039 use wiremock::matchers::any;
1040 use wiremock::{Mock, ResponseTemplate};
1041
1042 let server = wiremock::MockServer::start().await;
1043 Mock::given(any())
1045 .respond_with(ResponseTemplate::new(500))
1046 .expect(0)
1047 .mount(&server)
1048 .await;
1049
1050 let provider = CustomOidcSocialProvider::new(pinned_config())
1052 .await
1053 .unwrap();
1054 let _ = provider.current_endpoints().await.unwrap();
1055 let _ = provider.current_endpoints().await.unwrap();
1056 }
1058
1059 #[tokio::test]
1060 async fn exchange_code_posts_correct_form_and_returns_access_token() {
1061 use wiremock::matchers::{method, path};
1062 use wiremock::{Mock, ResponseTemplate};
1063
1064 let server = wiremock::MockServer::start().await;
1065 Mock::given(method("POST"))
1066 .and(path("/token"))
1067 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1068 "access_token": "oidc-access-token",
1069 "token_type": "Bearer",
1070 })))
1071 .mount(&server)
1072 .await;
1073
1074 mount_discovery_doc(&server).await;
1075 let base = server.uri();
1076 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1077 "{base}/.well-known/openid-configuration"
1078 )))
1079 .await
1080 .unwrap();
1081
1082 let token = provider
1083 .exchange_code("mycode", "https://app.example.com/cb", "verifier")
1084 .await
1085 .unwrap();
1086 assert_eq!(token, "oidc-access-token");
1087 }
1088
1089 #[tokio::test]
1090 async fn exchange_code_returns_error_on_4xx() {
1091 use wiremock::matchers::{method, path};
1092 use wiremock::{Mock, ResponseTemplate};
1093
1094 let server = wiremock::MockServer::start().await;
1095 Mock::given(method("POST"))
1096 .and(path("/token"))
1097 .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
1098 "error": "invalid_grant"
1099 })))
1100 .mount(&server)
1101 .await;
1102
1103 mount_discovery_doc(&server).await;
1104 let base = server.uri();
1105 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1106 "{base}/.well-known/openid-configuration"
1107 )))
1108 .await
1109 .unwrap();
1110
1111 let err = provider
1112 .exchange_code("badcode", "https://app.example.com/cb", "v")
1113 .await
1114 .unwrap_err();
1115 assert!(matches!(err, AuthError::OAuthTokenExchange(_)));
1116 }
1117
1118 #[tokio::test]
1119 async fn exchange_code_returns_error_on_missing_access_token() {
1120 use wiremock::matchers::{method, path};
1121 use wiremock::{Mock, ResponseTemplate};
1122
1123 let server = wiremock::MockServer::start().await;
1124 Mock::given(method("POST"))
1125 .and(path("/token"))
1126 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1127 "token_type": "Bearer"
1128 })))
1130 .mount(&server)
1131 .await;
1132
1133 mount_discovery_doc(&server).await;
1134 let base = server.uri();
1135 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1136 "{base}/.well-known/openid-configuration"
1137 )))
1138 .await
1139 .unwrap();
1140
1141 let err = provider
1142 .exchange_code("code", "https://app.example.com/cb", "v")
1143 .await
1144 .unwrap_err();
1145 assert!(
1146 matches!(&err, AuthError::OAuthTokenExchange(m) if m.contains("missing access_token")),
1147 "got: {err:?}"
1148 );
1149 }
1150
1151 #[tokio::test]
1152 async fn fetch_user_info_maps_standard_claims() {
1153 use wiremock::matchers::{method, path};
1154 use wiremock::{Mock, ResponseTemplate};
1155
1156 let server = wiremock::MockServer::start().await;
1157 Mock::given(method("GET"))
1158 .and(path("/userinfo"))
1159 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1160 "sub": "user-sub-123",
1161 "email": "alice@example.com",
1162 "email_verified": true,
1163 "name": "Alice Smith",
1164 "picture": "https://cdn.example.com/alice.jpg",
1165 })))
1166 .mount(&server)
1167 .await;
1168
1169 mount_discovery_doc(&server).await;
1170 let base = server.uri();
1171 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1172 "{base}/.well-known/openid-configuration"
1173 )))
1174 .await
1175 .unwrap();
1176
1177 let info = provider.fetch_user_info("my-access-token").await.unwrap();
1178 assert_eq!(info.provider_user_id, "user-sub-123");
1179 assert_eq!(info.email, "alice@example.com");
1180 assert!(info.email_verified);
1181 assert_eq!(info.name.as_deref(), Some("Alice Smith"));
1182 assert_eq!(
1183 info.avatar_url.as_deref(),
1184 Some("https://cdn.example.com/alice.jpg")
1185 );
1186 }
1187
1188 #[tokio::test]
1189 async fn fetch_user_info_maps_non_standard_name_and_avatar() {
1190 use wiremock::matchers::{method, path};
1191 use wiremock::{Mock, ResponseTemplate};
1192
1193 let server = wiremock::MockServer::start().await;
1194 Mock::given(method("GET"))
1195 .and(path("/userinfo"))
1196 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1197 "sub": "u2",
1198 "email": "bob@example.com",
1199 "preferred_username": "bob42",
1201 "avatar_url": "https://cdn.example.com/bob.png",
1203 })))
1204 .mount(&server)
1205 .await;
1206
1207 mount_discovery_doc(&server).await;
1208 let base = server.uri();
1209 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1210 "{base}/.well-known/openid-configuration"
1211 )))
1212 .await
1213 .unwrap();
1214
1215 let info = provider.fetch_user_info("token").await.unwrap();
1216 assert_eq!(info.name.as_deref(), Some("bob42"));
1217 assert_eq!(
1218 info.avatar_url.as_deref(),
1219 Some("https://cdn.example.com/bob.png")
1220 );
1221 }
1222
1223 #[tokio::test]
1224 async fn fetch_user_info_returns_error_on_missing_email() {
1225 use wiremock::matchers::{method, path};
1226 use wiremock::{Mock, ResponseTemplate};
1227
1228 let server = wiremock::MockServer::start().await;
1229 Mock::given(method("GET"))
1230 .and(path("/userinfo"))
1231 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1232 "sub": "u3",
1233 })))
1235 .mount(&server)
1236 .await;
1237
1238 mount_discovery_doc(&server).await;
1239 let base = server.uri();
1240 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1241 "{base}/.well-known/openid-configuration"
1242 )))
1243 .await
1244 .unwrap();
1245
1246 let err = provider.fetch_user_info("token").await.unwrap_err();
1247 assert!(matches!(err, AuthError::OAuthUserInfoFetch(ref m) if m.contains("missing email")));
1248 }
1249
1250 #[tokio::test]
1251 async fn fetch_user_info_returns_error_on_4xx() {
1252 use wiremock::matchers::{method, path};
1253 use wiremock::{Mock, ResponseTemplate};
1254
1255 let server = wiremock::MockServer::start().await;
1256 Mock::given(method("GET"))
1257 .and(path("/userinfo"))
1258 .respond_with(ResponseTemplate::new(401).set_body_string("Unauthorized"))
1259 .mount(&server)
1260 .await;
1261
1262 mount_discovery_doc(&server).await;
1263 let base = server.uri();
1264 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1265 "{base}/.well-known/openid-configuration"
1266 )))
1267 .await
1268 .unwrap();
1269
1270 let err = provider.fetch_user_info("bad-token").await.unwrap_err();
1271 assert!(matches!(err, AuthError::OAuthUserInfoFetch(_)));
1272 }
1273
1274 #[tokio::test]
1277 async fn exchange_code_posts_oauth_form_fields_and_pkce_verifier() {
1278 use wiremock::matchers::{method, path};
1284 use wiremock::{Mock, ResponseTemplate};
1285
1286 let server = wiremock::MockServer::start().await;
1287 Mock::given(method("POST"))
1288 .and(path("/token"))
1289 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1290 "access_token": "tok",
1291 "token_type": "Bearer",
1292 })))
1293 .mount(&server)
1294 .await;
1295
1296 mount_discovery_doc(&server).await;
1297 let base = server.uri();
1298 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1299 "{base}/.well-known/openid-configuration"
1300 )))
1301 .await
1302 .unwrap();
1303
1304 provider
1305 .exchange_code("the-code", "https://app.example.com/cb", "the-verifier")
1306 .await
1307 .unwrap();
1308
1309 let reqs = server.received_requests().await.unwrap();
1310 let token_req = reqs
1311 .iter()
1312 .find(|r| r.url.path() == "/token")
1313 .expect("token POST must reach the IdP");
1314 let body = std::str::from_utf8(&token_req.body).expect("form body utf-8");
1315 for expected in &[
1316 "code=the-code",
1317 "client_id=test-client-id",
1318 "client_secret=test-client-secret",
1319 "redirect_uri=https%3A%2F%2Fapp.example.com%2Fcb",
1320 "grant_type=authorization_code",
1321 "code_verifier=the-verifier",
1322 ] {
1323 assert!(
1324 body.contains(expected),
1325 "token POST form body missing `{expected}`: {body}"
1326 );
1327 }
1328 }
1329
1330 #[tokio::test]
1331 async fn fetch_user_info_sends_bearer_authorization_header() {
1332 use wiremock::matchers::{method, path};
1337 use wiremock::{Mock, ResponseTemplate};
1338
1339 let server = wiremock::MockServer::start().await;
1340 Mock::given(method("GET"))
1341 .and(path("/userinfo"))
1342 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1343 "sub": "u",
1344 "email": "u@example.com",
1345 })))
1346 .mount(&server)
1347 .await;
1348
1349 mount_discovery_doc(&server).await;
1350 let base = server.uri();
1351 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1352 "{base}/.well-known/openid-configuration"
1353 )))
1354 .await
1355 .unwrap();
1356
1357 provider.fetch_user_info("the-access-token").await.unwrap();
1358
1359 let reqs = server.received_requests().await.unwrap();
1360 let userinfo_req = reqs
1361 .iter()
1362 .find(|r| r.url.path() == "/userinfo")
1363 .expect("userinfo GET must reach the IdP");
1364 let auth = userinfo_req
1365 .headers
1366 .get("authorization")
1367 .expect("Authorization header must be present")
1368 .to_str()
1369 .unwrap();
1370 assert_eq!(auth, "Bearer the-access-token");
1371 }
1372
1373 #[tokio::test]
1374 async fn authorize_url_uses_discovered_authorize_endpoint() {
1375 let server = wiremock::MockServer::start().await;
1381 mount_discovery_doc(&server).await;
1382 let base = server.uri();
1383 let provider = CustomOidcSocialProvider::new(discovery_config(&format!(
1384 "{base}/.well-known/openid-configuration"
1385 )))
1386 .await
1387 .unwrap();
1388
1389 let url =
1390 provider.authorize_url("https://app.example.com/cb", "state-xyz", "challenge-abc");
1391 let expected_prefix = format!("{base}/authorize?");
1392 assert!(
1393 url.starts_with(&expected_prefix),
1394 "authorize_url must use discovered endpoint; got {url}"
1395 );
1396 assert!(url.contains("client_id=test-client-id"));
1397 assert!(url.contains("state=state-xyz"));
1398 assert!(url.contains("code_challenge=challenge-abc"));
1399 assert!(url.contains("code_challenge_method=S256"));
1400 }
1401}