1use std::{borrow::Cow, cmp::min, sync::Arc, time::Duration};
2
3use base64::Engine;
4use chrono::Utc;
5use openidconnect::{
6 AccessToken, AuthType, AuthenticationFlow, AuthorizationCode, Client, ClientId, ClientSecret,
7 CsrfToken, DeviceAuthorizationUrl, DeviceCodeErrorResponse, DeviceCodeErrorResponseType,
8 EndpointMaybeSet, EndpointNotSet, EndpointSet, IntrospectionUrl, Nonce, OAuth2TokenResponse,
9 PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RevocationUrl, Scope,
10 StandardErrorResponse, StandardTokenResponse, SubjectIdentifier, TokenResponse,
11 core::{
12 CoreAuthDisplay, CoreAuthPrompt, CoreClientAuthMethod, CoreDeviceAuthorizationResponse,
13 CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
14 CoreJwsSigningAlgorithm, CoreRevocableToken, CoreRevocationErrorResponse,
15 CoreTokenIntrospectionResponse, CoreTokenType,
16 },
17 reqwest,
18};
19use securitydept_oauth_provider::{OAuthProviderRuntime, ProviderMetadataWithExtra};
20use securitydept_utils::observability::{
21 AuthFlowDiagnosis, AuthFlowDiagnosisField, AuthFlowDiagnosisOutcome, AuthFlowOperation,
22 DiagnosedResult,
23};
24use url::Url;
25
26#[cfg(not(feature = "claims-script"))]
27use crate::claims::DefaultClaimsChecker;
28#[cfg(feature = "claims-script")]
29use crate::claims::ScriptClaimsChecker;
30use crate::{
31 ClaimsCheckResult, ExtraOidcClaims, IdTokenClaimsWithExtra, OidcCodeCallbackSearchParams,
32 OidcCodeExchangeResult, OidcCodeFlowAuthorizationRequest, OidcDeviceAuthorizationResult,
33 OidcDeviceTokenPollResult, OidcDeviceTokenResult, OidcRevocableToken, PendingOauthStore,
34 PendingOauthStoreConfig, UserInfoClaimsWithExtra, UserInfoExchangeResult,
35 claims::ClaimsChecker,
36 config::OidcClientConfig,
37 error::{OidcError, OidcResult},
38 models::{IdTokenFieldsWithExtra, OidcCodeCallbackResult, OidcRefreshTokenResult},
39};
40
/// Token-endpoint response type used throughout this module: the standard OAuth2 token
/// response whose extra fields are `IdTokenFieldsWithExtra` (i.e. it may carry an ID token
/// with this crate's extra claims).
pub type TokenResponseWithExtra = StandardTokenResponse<IdTokenFieldsWithExtra, CoreTokenType>;
42
/// `openidconnect::Client` specialised to this crate's extra claims (`ExtraOidcClaims`) and
/// token response (`TokenResponseWithExtra`), using the `Core*` types for everything else.
///
/// The six trailing parameters are the endpoint typestates of the underlying client; each
/// defaults to `EndpointNotSet` so concrete aliases only spell out what they set.
pub type ClientWithExtra<
    HasAuthUrl = EndpointNotSet,
    HasDeviceAuthUrl = EndpointNotSet,
    HasIntrospectionUrl = EndpointNotSet,
    HasRevocationUrl = EndpointNotSet,
    HasTokenUrl = EndpointNotSet,
    HasUserInfoUrl = EndpointNotSet,
> = Client<
    ExtraOidcClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    TokenResponseWithExtra,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    HasAuthUrl,
    HasDeviceAuthUrl,
    HasIntrospectionUrl,
    HasRevocationUrl,
    HasTokenUrl,
    HasUserInfoUrl,
>;
69
/// Client shape used for the general flows: authorization URL set, token and userinfo URLs
/// "maybe set" (presumably as produced from provider metadata — confirm in `build_client`),
/// and device-authorization / introspection / revocation URLs not set.
pub type DiscoveredClientWithExtra = ClientWithExtra<
    EndpointSet,      // auth URL
    EndpointNotSet,   // device authorization URL
    EndpointNotSet,   // introspection URL
    EndpointNotSet,   // revocation URL
    EndpointMaybeSet, // token URL
    EndpointMaybeSet, // userinfo URL
>;
78
/// `DiscoveredClientWithExtra` with the device authorization URL set, as needed to start a
/// device authorization request.
type DeviceAuthorizationClientWithExtra = ClientWithExtra<
    EndpointSet,      // auth URL
    EndpointSet,      // device authorization URL
    EndpointNotSet,   // introspection URL
    EndpointNotSet,   // revocation URL
    EndpointMaybeSet, // token URL
    EndpointMaybeSet, // userinfo URL
>;
87
/// `DiscoveredClientWithExtra` with the revocation URL set, as needed to revoke tokens.
type RevocationClientWithExtra = ClientWithExtra<
    EndpointSet,      // auth URL
    EndpointNotSet,   // device authorization URL
    EndpointNotSet,   // introspection URL
    EndpointSet,      // revocation URL
    EndpointMaybeSet, // token URL
    EndpointMaybeSet, // userinfo URL
>;
96
97struct OptionalClientEndpoints {
98 _introspection_endpoint: Option<IntrospectionUrl>,
99 revocation_endpoint: Option<RevocationUrl>,
100 device_authorization_endpoint: Option<DeviceAuthorizationUrl>,
101}
102
103struct BuiltClientWithExtra {
104 client: DiscoveredClientWithExtra,
105 optional_endpoints: OptionalClientEndpoints,
106}
107
108pub struct OidcClient<PS>
114where
115 PS: PendingOauthStore,
116{
117 config: OidcClientConfig<PS::Config>,
118 provider: Arc<OAuthProviderRuntime>,
119 base_client: DiscoveredClientWithExtra,
120 #[cfg(feature = "claims-script")]
121 claims_checker: ScriptClaimsChecker,
122 #[cfg(not(feature = "claims-script"))]
123 claims_checker: DefaultClaimsChecker,
124 scopes: Vec<String>,
125 pkce_enabled: bool,
126 pending_oauth_store: PS,
127}
128
129impl<PS> OidcClient<PS>
130where
131 PS: PendingOauthStore,
132{
133 pub async fn from_config(config: OidcClientConfig<PS::Config>) -> OidcResult<Self> {
134 config.validate()?;
135 let provider = Arc::new(OAuthProviderRuntime::from_config(config.provider_config()).await?);
136 Self::from_provider(provider, config).await
137 }
138
139 pub async fn from_provider(
140 provider: Arc<OAuthProviderRuntime>,
141 config: OidcClientConfig<PS::Config>,
142 ) -> OidcResult<Self> {
143 config.validate()?;
144
145 let built_client = build_client(&config, provider.oidc_provider_metadata().await?)
146 .map_err(|e| OidcError::Metadata {
147 message: format!("Failed to build OIDC client from provider metadata: {e}"),
148 })?;
149
150 #[cfg(feature = "claims-script")]
151 let claims_checker =
152 ScriptClaimsChecker::from_file(config.claims_check_script.as_deref()).await?;
153 #[cfg(not(feature = "claims-script"))]
154 let claims_checker = DefaultClaimsChecker;
155
156 Ok(Self {
157 pending_oauth_store: PS::from_config_opt(config.pending_store.as_ref()),
158 config,
159 provider,
160 base_client: built_client.client,
161 claims_checker,
162 scopes: vec![],
163 pkce_enabled: false,
164 }
165 .with_runtime_flags())
166 }
167
168 pub fn provider(&self) -> &Arc<OAuthProviderRuntime> {
169 &self.provider
170 }
171
172 pub async fn handle_code_authorize(
173 &self,
174 external_base_url: &Url,
175 ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
176 self.handle_code_authorize_with_redirect_override(external_base_url, None)
177 .await
178 }
179
180 pub async fn handle_code_authorize_with_redirect_override(
181 &self,
182 external_base_url: &Url,
183 redirect_url_override: Option<&str>,
184 ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
185 self.handle_code_authorize_with_redirect_override_and_extra_data(
186 external_base_url,
187 redirect_url_override,
188 None,
189 )
190 .await
191 }
192
193 pub async fn handle_code_authorize_with_redirect_override_and_extra_data(
194 &self,
195 external_base_url: &Url,
196 redirect_url_override: Option<&str>,
197 extra_data: Option<serde_json::Value>,
198 ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
199 let authorization_request =
200 self.authorize_url_with_redirect_override(external_base_url, redirect_url_override)?;
201 self.pending_oauth_store
202 .insert(
203 authorization_request.csrf_token.secret().to_string(),
204 authorization_request.nonce.secret().to_string(),
205 authorization_request.pkce_verifier_secret.clone(),
206 extra_data,
207 )
208 .await?;
209 Ok(authorization_request)
210 }
211
212 pub async fn handle_device_authorize(&self) -> OidcResult<OidcDeviceAuthorizationResult> {
213 let client = self.fresh_device_authorization_client().await?;
214 let mut request = client.exchange_device_code();
215
216 for scope in &self.scopes {
217 request = request.add_scope(Scope::new(scope.clone()));
218 }
219
220 let details: CoreDeviceAuthorizationResponse = request
221 .request_async(self.provider.http_client())
222 .await
223 .map_err(|e| OidcError::DeviceAuthorization {
224 message: format!("Device authorization request failed: {e}"),
225 })?;
226
227 Ok(OidcDeviceAuthorizationResult {
228 device_code: details.device_code().secret().to_string(),
229 user_code: details.user_code().secret().to_string(),
230 verification_uri: details.verification_uri().to_string(),
231 verification_uri_complete: details
232 .verification_uri_complete()
233 .map(|value| value.secret().to_string()),
234 expires_in: details.expires_in(),
235 interval: Some(details.interval()),
236 })
237 }
238
239 pub async fn handle_device_token_poll(
240 &self,
241 device_authorization: &OidcDeviceAuthorizationResult,
242 current_interval: Option<Duration>,
243 ) -> OidcResult<OidcDeviceTokenPollResult> {
244 let current_interval = current_interval.unwrap_or_else(|| {
245 device_authorization.poll_interval(self.config.device_poll_interval)
246 });
247
248 match self.request_device_token_once(device_authorization).await? {
249 DeviceTokenPollResponse::Complete(token_response) => {
250 let token_result = self.build_device_token_result(*token_response).await?;
251 Ok(OidcDeviceTokenPollResult::Complete {
252 token_result: Box::new(token_result),
253 })
254 }
255 DeviceTokenPollResponse::Pending => Ok(OidcDeviceTokenPollResult::Pending {
256 interval: current_interval,
257 }),
258 DeviceTokenPollResponse::SlowDown => Ok(OidcDeviceTokenPollResult::SlowDown {
259 interval: current_interval.saturating_add(Duration::from_secs(5)),
260 }),
261 DeviceTokenPollResponse::Denied { error_description } => {
262 Ok(OidcDeviceTokenPollResult::Denied { error_description })
263 }
264 DeviceTokenPollResponse::Expired { error_description } => {
265 Ok(OidcDeviceTokenPollResult::Expired { error_description })
266 }
267 }
268 }
269
270 pub async fn handle_device_token_poll_until_complete(
271 &self,
272 device_authorization: &OidcDeviceAuthorizationResult,
273 timeout: Option<Duration>,
274 ) -> OidcResult<OidcDeviceTokenResult> {
275 let started_at = std::time::Instant::now();
276 let mut interval = device_authorization.poll_interval(self.config.device_poll_interval);
277
278 const MIN_POLL_INTERVAL: Duration = Duration::from_secs(1);
281
282 loop {
283 if let Some(timeout) = timeout {
284 let elapsed = started_at.elapsed();
285 if elapsed >= timeout {
286 return Err(OidcError::DeviceTokenPoll {
287 message: format!(
288 "Device token polling timed out after {} seconds",
289 timeout.as_secs()
290 ),
291 });
292 }
293 }
294
295 match self
296 .handle_device_token_poll(device_authorization, Some(interval))
297 .await?
298 {
299 OidcDeviceTokenPollResult::Complete { token_result } => return Ok(*token_result),
300 OidcDeviceTokenPollResult::Pending {
301 interval: next_interval,
302 }
303 | OidcDeviceTokenPollResult::SlowDown {
304 interval: next_interval,
305 } => {
306 interval = next_interval.max(MIN_POLL_INTERVAL);
307 let sleep_duration = if let Some(timeout) = timeout {
308 let remaining = timeout.saturating_sub(started_at.elapsed());
309 min(interval, remaining)
310 } else {
311 interval
312 };
313 tokio::time::sleep(sleep_duration).await;
314 }
315 OidcDeviceTokenPollResult::Denied { error_description } => {
316 return Err(OidcError::DeviceTokenPoll {
317 message: format_device_token_terminal_message(
318 "access_denied",
319 error_description.as_deref(),
320 ),
321 });
322 }
323 OidcDeviceTokenPollResult::Expired { error_description } => {
324 return Err(OidcError::DeviceTokenPoll {
325 message: format_device_token_terminal_message(
326 "expired_token",
327 error_description.as_deref(),
328 ),
329 });
330 }
331 }
332 }
333 }
334
335 pub async fn handle_code_callback(
336 &self,
337 search_params: OidcCodeCallbackSearchParams,
338 external_base_url: &Url,
339 ) -> OidcResult<OidcCodeCallbackResult> {
340 self.handle_code_callback_with_redirect_override_diagnosed(
341 search_params,
342 external_base_url,
343 None,
344 )
345 .await
346 .into_result()
347 }
348
349 pub async fn handle_code_callback_with_redirect_override(
350 &self,
351 search_params: OidcCodeCallbackSearchParams,
352 external_base_url: &Url,
353 redirect_url_override: Option<&str>,
354 ) -> OidcResult<OidcCodeCallbackResult> {
355 self.handle_code_callback_with_redirect_override_diagnosed(
356 search_params,
357 external_base_url,
358 redirect_url_override,
359 )
360 .await
361 .into_result()
362 }
363
364 pub async fn handle_code_callback_with_redirect_override_diagnosed(
365 &self,
366 search_params: OidcCodeCallbackSearchParams,
367 external_base_url: &Url,
368 redirect_url_override: Option<&str>,
369 ) -> DiagnosedResult<OidcCodeCallbackResult, OidcError> {
370 let diagnosis = AuthFlowDiagnosis::started(AuthFlowOperation::OIDC_CALLBACK)
371 .field("redirect_override", redirect_url_override)
372 .field(
373 AuthFlowDiagnosisField::EXTERNAL_BASE_URL,
374 external_base_url.as_str(),
375 )
376 .field("pkce_enabled", self.pkce_enabled)
377 .field(
378 AuthFlowDiagnosisField::HAS_STATE,
379 search_params.state.is_some(),
380 )
381 .field(
382 AuthFlowDiagnosisField::HAS_CODE,
383 !search_params.code.is_empty(),
384 );
385
386 let code = &search_params.code;
387 let state = search_params
388 .state
389 .as_ref()
390 .ok_or_else(|| OidcError::CSRFValidation {
391 message: "Missing state parameter in callback (required for CSRF validation)"
392 .to_string(),
393 });
394
395 let state = match state {
396 Ok(state) => state,
397 Err(error) => {
398 return DiagnosedResult::failure(
399 diagnosis
400 .with_outcome(AuthFlowDiagnosisOutcome::Rejected)
401 .field(AuthFlowDiagnosisField::FAILURE_STAGE, "csrf_validation"),
402 error,
403 );
404 }
405 };
406
407 let pending = match self.pending_oauth_store.take(state).await {
408 Ok(pending) => pending.ok_or_else(|| OidcError::PendingOauth {
409 source: "Invalid or expired state (reuse or unknown); try logging in again"
410 .to_string()
411 .into(),
412 }),
413 Err(error) => {
414 return DiagnosedResult::failure(
415 diagnosis
416 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
417 .field(AuthFlowDiagnosisField::FAILURE_STAGE, "pending_oauth_store"),
418 error,
419 );
420 }
421 };
422
423 let pending = match pending {
424 Ok(pending) => pending,
425 Err(error) => {
426 return DiagnosedResult::failure(
427 diagnosis
428 .with_outcome(AuthFlowDiagnosisOutcome::Rejected)
429 .field(AuthFlowDiagnosisField::FAILURE_STAGE, "pending_oauth_state"),
430 error,
431 );
432 }
433 };
434
435 let nonce = openidconnect::Nonce::new(pending.nonce.clone());
436 let code_verifier = pending.code_verifier;
437
438 let code_exchange = self
439 .exchange_code_with_redirect_override(
440 external_base_url,
441 code,
442 &nonce,
443 code_verifier.as_deref(),
444 redirect_url_override,
445 )
446 .await;
447
448 let code_exchange = match code_exchange {
449 Ok(code_exchange) => code_exchange,
450 Err(error) => {
451 return DiagnosedResult::failure(
452 diagnosis
453 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
454 .field(AuthFlowDiagnosisField::FAILURE_STAGE, "token_exchange"),
455 error,
456 );
457 }
458 };
459
460 let claims_check_result = self
461 .check_claims(
462 &code_exchange.id_token_claims,
463 code_exchange.user_info_claims.as_ref(),
464 )
465 .await;
466
467 let claims_check_result = match claims_check_result {
468 Ok(claims_check_result) => claims_check_result,
469 Err(error) => {
470 return DiagnosedResult::failure(
471 diagnosis
472 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
473 .field(AuthFlowDiagnosisField::FAILURE_STAGE, "claims_check"),
474 error,
475 );
476 }
477 };
478
479 let result = OidcCodeCallbackResult {
480 code: search_params.code,
481 pkce_verifier_secret: code_verifier,
482 state: search_params.state,
483 nonce: pending.nonce,
484 pending_extra_data: pending.extra_data,
485 access_token: code_exchange.access_token,
486 access_token_expiration: code_exchange.access_token_expiration,
487 id_token: code_exchange.id_token,
488 refresh_token: code_exchange.refresh_token,
489 id_token_claims: code_exchange.id_token_claims,
490 user_info_claims: code_exchange.user_info_claims,
491 claims_check_result,
492 };
493
494 DiagnosedResult::success(
495 diagnosis
496 .with_outcome(AuthFlowDiagnosisOutcome::Succeeded)
497 .field(
498 AuthFlowDiagnosisField::SUBJECT,
499 result.id_token_claims.subject().to_string(),
500 )
501 .field("has_refresh_token", result.refresh_token.is_some())
502 .field("has_user_info_claims", result.user_info_claims.is_some()),
503 result,
504 )
505 }
506
507 pub async fn handle_token_refresh(
508 &self,
509 refresh_token: String,
510 id_token: Option<String>,
512 ) -> OidcResult<OidcRefreshTokenResult> {
513 self.handle_token_refresh_diagnosed(refresh_token, id_token)
514 .await
515 .into_result()
516 }
517
518 pub async fn handle_token_refresh_diagnosed(
519 &self,
520 refresh_token: String,
521 id_token: Option<String>,
522 ) -> DiagnosedResult<OidcRefreshTokenResult, OidcError> {
523 let diagnosis = AuthFlowDiagnosis::started(AuthFlowOperation::OIDC_TOKEN_REFRESH)
524 .field("has_previous_id_token", id_token.is_some())
525 .field("pkce_enabled", self.pkce_enabled);
526
527 let client = match self.fresh_client().await {
528 Ok(client) => client,
529 Err(error) => {
530 return DiagnosedResult::failure(
531 diagnosis
532 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
533 .field(
534 AuthFlowDiagnosisField::FAILURE_STAGE,
535 "client_metadata_refresh",
536 ),
537 error,
538 );
539 }
540 };
541 let refresh_token = RefreshToken::new(refresh_token);
542 let now = Utc::now();
543
544 let token_request =
545 client
546 .exchange_refresh_token(&refresh_token)
547 .map_err(|e| OidcError::TokenRefresh {
548 message: format!("Token endpoint not set or config error: {e}"),
549 });
550
551 let token_request = match token_request {
552 Ok(token_request) => token_request,
553 Err(error) => {
554 return DiagnosedResult::failure(
555 diagnosis
556 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
557 .field(
558 AuthFlowDiagnosisField::FAILURE_STAGE,
559 "token_refresh_request_build",
560 ),
561 error,
562 );
563 }
564 };
565
566 let token_response = token_request
567 .request_async(self.provider.http_client())
568 .await
569 .map_err(|e| OidcError::TokenRefresh {
570 message: format!("Refresh token request failed: {e}"),
571 });
572
573 let token_response = match token_response {
574 Ok(token_response) => token_response,
575 Err(error) => {
576 return DiagnosedResult::failure(
577 diagnosis
578 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
579 .field(AuthFlowDiagnosisField::FAILURE_STAGE, "token_refresh"),
580 error,
581 );
582 }
583 };
584
585 let access_token = token_response.access_token().secret().clone();
586 let access_token_expiration = token_response
587 .expires_in()
588 .map(|expires_in| now + expires_in);
589 let refresh_token = token_response
590 .refresh_token()
591 .map(|value| value.secret().clone());
592 let id_token = token_response
593 .id_token()
594 .map(|value| value.to_string())
595 .or(id_token);
596
597 let mut result = OidcRefreshTokenResult {
598 access_token,
599 access_token_expiration,
600 refresh_token,
601 id_token,
602 user_info_claims: None,
603 claims_check_result: None,
604 id_token_claims: None,
605 };
606
607 if let Err(error) = self.check_required_scopes(token_response.scopes()) {
609 return DiagnosedResult::failure(
610 diagnosis
611 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
612 .field("failure_stage", "scope_validation"),
613 error,
614 );
615 }
616
617 if let Some(next_id_token) = token_response.extra_fields().id_token() {
618 let id_token_verifier = client.id_token_verifier();
619 let id_token_claims = next_id_token
620 .claims(&id_token_verifier, |_nonce: Option<&Nonce>| Ok(()))
621 .map_err(|e| OidcError::TokenRefresh {
622 message: format!("Failed to verify refreshed ID token: {e}"),
623 });
624 let id_token_claims = match id_token_claims {
625 Ok(id_token_claims) => id_token_claims,
626 Err(error) => {
627 return DiagnosedResult::failure(
628 diagnosis
629 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
630 .field("failure_stage", "id_token_verification"),
631 error,
632 );
633 }
634 };
635 let user_info_claims = if client.user_info_url().is_some() {
636 match self
637 .request_userinfo(
638 &client,
639 self.provider.http_client(),
640 token_response.access_token().clone(),
641 Some(id_token_claims.subject().clone()),
642 )
643 .await
644 {
645 Ok(user_info_claims) => Some(user_info_claims),
646 Err(error) => {
647 return DiagnosedResult::failure(
648 diagnosis
649 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
650 .field("failure_stage", "userinfo_exchange"),
651 error,
652 );
653 }
654 }
655 } else {
656 None
657 };
658 let claims_check_result = self
659 .check_claims(id_token_claims, user_info_claims.as_ref())
660 .await;
661 let claims_check_result = match claims_check_result {
662 Ok(claims_check_result) => claims_check_result,
663 Err(error) => {
664 return DiagnosedResult::failure(
665 diagnosis
666 .with_outcome(AuthFlowDiagnosisOutcome::Failed)
667 .field("failure_stage", "claims_check"),
668 error,
669 );
670 }
671 };
672 result.id_token = Some(next_id_token.to_string());
673 result.id_token_claims = Some(id_token_claims.clone());
674 result.user_info_claims = user_info_claims;
675 result.claims_check_result = Some(claims_check_result);
676 }
677
678 DiagnosedResult::success(
679 diagnosis
680 .with_outcome(AuthFlowDiagnosisOutcome::Succeeded)
681 .field("has_refresh_token", result.refresh_token.is_some())
682 .field("has_new_id_token", result.id_token.is_some())
683 .field(
684 "has_claims_check_result",
685 result.claims_check_result.is_some(),
686 ),
687 result,
688 )
689 }
690
691 pub async fn handle_token_revoke(&self, token: OidcRevocableToken) -> OidcResult<()> {
692 let client = self.fresh_revocation_client().await?;
693 let token: CoreRevocableToken = match token {
694 OidcRevocableToken::AccessToken(token) => AccessToken::new(token).into(),
695 OidcRevocableToken::RefreshToken(token) => RefreshToken::new(token).into(),
696 };
697
698 client
699 .revoke_token(token)
700 .map_err(|e| OidcError::TokenRevocation {
701 message: format!("Revocation endpoint not set or config error: {e}"),
702 })?
703 .request_async(self.provider.http_client())
704 .await
705 .map_err(|e| OidcError::TokenRevocation {
706 message: format!("Token revocation request failed: {e}"),
707 })
708 }
709
710 pub async fn handle_user_info_exchange(
722 &self,
723 id_token_raw: &str,
724 access_token: &str,
725 ) -> OidcResult<UserInfoExchangeResult> {
726 let client = self.fresh_client().await?;
727 let id_token_verifier = client.id_token_verifier();
728
729 let id_token: openidconnect::IdToken<
731 ExtraOidcClaims,
732 CoreGenderClaim,
733 CoreJweContentEncryptionAlgorithm,
734 CoreJwsSigningAlgorithm,
735 > = serde_json::from_value(serde_json::Value::String(id_token_raw.to_string())).map_err(
736 |e| OidcError::Claims {
737 message: format!("Failed to parse ID token string in user_info exchange: {e}"),
738 },
739 )?;
740
741 let id_token_claims = id_token
743 .claims(&id_token_verifier, |_nonce: Option<&Nonce>| Ok(()))
744 .map_err(|e| OidcError::Claims {
745 message: format!("Failed to verify ID token in user_info exchange: {e}"),
746 })?;
747
748 let access_token_obj = AccessToken::new(access_token.to_string());
749
750 let user_info_claims = if client.user_info_url().is_some() {
751 Some(
752 self.request_userinfo(
753 &client,
754 self.provider.http_client(),
755 access_token_obj,
756 Some(id_token_claims.subject().clone()),
757 )
758 .await?,
759 )
760 } else {
761 None
762 };
763
764 let claims_check_result = self
765 .check_claims(id_token_claims, user_info_claims.as_ref())
766 .await?;
767
768 let issuer = id_token_claims.issuer().url().to_string();
769
770 Ok(UserInfoExchangeResult {
771 subject: id_token_claims.subject().to_string(),
772 display_name: claims_check_result.display_name,
773 picture: claims_check_result.picture,
774 issuer: Some(issuer),
775 claims: Some(claims_check_result.claims),
776 })
777 }
778
779 async fn request_userinfo(
780 &self,
781 client: &DiscoveredClientWithExtra,
782 http_client: &reqwest::Client,
783 access_token: openidconnect::AccessToken,
784 expected_subject: Option<SubjectIdentifier>,
785 ) -> OidcResult<UserInfoClaimsWithExtra> {
786 client
787 .user_info(access_token, expected_subject)
788 .map_err(|e| OidcError::Claims {
789 message: format!("UserInfo request configuration failed: {e}"),
790 })?
791 .request_async(http_client)
792 .await
793 .map_err(|e| OidcError::Claims {
794 message: format!("UserInfo request failed: {e}"),
795 })
796 }
797
798 async fn check_claims(
799 &self,
800 id_token_claims: &IdTokenClaimsWithExtra,
801 user_info_claims: Option<&UserInfoClaimsWithExtra>,
802 ) -> OidcResult<ClaimsCheckResult> {
803 self.claims_checker
804 .check_claims(id_token_claims, user_info_claims)
805 .await
806 }
807
808 fn resolve_redirect_url(
809 &self,
810 external_base_url: &Url,
811 redirect_url_override: Option<&str>,
812 ) -> OidcResult<Url> {
813 external_base_url
814 .join(redirect_url_override.unwrap_or(&self.config.redirect_url))
815 .map_err(|e| OidcError::RedirectUrl { source: e })
816 }
817
818 fn client_with_redirect_override(
819 &self,
820 external_base_url: &Url,
821 redirect_url_override: Option<&str>,
822 ) -> OidcResult<DiscoveredClientWithExtra> {
823 let redirect_url = self.resolve_redirect_url(external_base_url, redirect_url_override)?;
824 Ok(self
825 .base_client
826 .clone()
827 .set_redirect_uri(RedirectUrl::from_url(redirect_url)))
828 }
829
830 async fn fresh_client(&self) -> OidcResult<DiscoveredClientWithExtra> {
831 Ok(self.fresh_client_parts().await?.client)
832 }
833
834 async fn fresh_client_parts(&self) -> OidcResult<BuiltClientWithExtra> {
835 build_client(&self.config, self.provider.oidc_provider_metadata().await?).map_err(|e| {
836 OidcError::Metadata {
837 message: format!("Failed to rebuild OIDC client from provider metadata: {e}"),
838 }
839 })
840 }
841
842 async fn fresh_device_authorization_client(
843 &self,
844 ) -> OidcResult<DeviceAuthorizationClientWithExtra> {
845 let built_client = self.fresh_client_parts().await?;
846 let device_authorization_endpoint = built_client
847 .optional_endpoints
848 .device_authorization_endpoint
849 .ok_or_else(|| OidcError::DeviceAuthorization {
850 message: "Device authorization endpoint not set or config error: device \
851 authorization endpoint URL is not set"
852 .to_string(),
853 })?;
854
855 Ok(built_client
856 .client
857 .set_device_authorization_url(device_authorization_endpoint))
858 }
859
860 async fn fresh_revocation_client(&self) -> OidcResult<RevocationClientWithExtra> {
861 let built_client = self.fresh_client_parts().await?;
862 let revocation_endpoint = built_client
863 .optional_endpoints
864 .revocation_endpoint
865 .ok_or_else(|| OidcError::TokenRevocation {
866 message: "Revocation endpoint not set or config error: revocation endpoint URL is \
867 not set"
868 .to_string(),
869 })?;
870
871 Ok(built_client.client.set_revocation_url(revocation_endpoint))
872 }
873
874 async fn fresh_client_with_redirect_override(
875 &self,
876 external_base_url: &Url,
877 redirect_url_override: Option<&str>,
878 ) -> OidcResult<DiscoveredClientWithExtra> {
879 let redirect_url = self.resolve_redirect_url(external_base_url, redirect_url_override)?;
880 Ok(self
881 .fresh_client()
882 .await?
883 .set_redirect_uri(RedirectUrl::from_url(redirect_url)))
884 }
885
886 pub fn authorize_url(
887 &self,
888 external_base_url: &Url,
889 ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
890 self.authorize_url_with_redirect_override(external_base_url, None)
891 }
892
893 pub fn authorize_url_with_redirect_override(
894 &self,
895 external_base_url: &Url,
896 redirect_url_override: Option<&str>,
897 ) -> OidcResult<OidcCodeFlowAuthorizationRequest> {
898 let client =
899 self.client_with_redirect_override(external_base_url, redirect_url_override)?;
900
901 let mut req = client.authorize_url(
902 AuthenticationFlow::<openidconnect::core::CoreResponseType>::AuthorizationCode,
903 CsrfToken::new_random,
904 Nonce::new_random,
905 );
906
907 let pkce_verifier_secret = if self.pkce_enabled {
908 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
909 req = req.set_pkce_challenge(pkce_challenge);
910 Some(pkce_verifier.into_secret())
911 } else {
912 None
913 };
914
915 for scope in &self.scopes {
916 req = req.add_scope(Scope::new(scope.clone()));
917 }
918
919 let (authorization_url, csrf_token, nonce) = req.url();
920 Ok(OidcCodeFlowAuthorizationRequest {
921 authorization_url,
922 csrf_token,
923 nonce,
924 pkce_verifier_secret,
925 })
926 }
927
928 pub async fn exchange_code(
929 &self,
930 external_base_url: &Url,
931 code: &str,
932 nonce: &Nonce,
933 pkce_verifier_secret: Option<&str>,
934 ) -> OidcResult<OidcCodeExchangeResult> {
935 self.exchange_code_with_redirect_override(
936 external_base_url,
937 code,
938 nonce,
939 pkce_verifier_secret,
940 None,
941 )
942 .await
943 }
944
945 pub async fn exchange_code_with_redirect_override(
946 &self,
947 external_base_url: &Url,
948 code: &str,
949 nonce: &Nonce,
950 pkce_verifier_secret: Option<&str>,
951 redirect_url_override: Option<&str>,
952 ) -> OidcResult<OidcCodeExchangeResult> {
953 let client = self
954 .fresh_client_with_redirect_override(external_base_url, redirect_url_override)
955 .await?;
956
957 let mut token_request = client
958 .exchange_code(AuthorizationCode::new(code.to_string()))
959 .map_err(|e| OidcError::TokenExchange {
960 message: format!("Token endpoint not set or config error: {e}"),
961 })?;
962
963 if let Some(secret) = pkce_verifier_secret {
964 token_request =
965 token_request.set_pkce_verifier(PkceCodeVerifier::new(secret.to_string()));
966 }
967
968 let token_response = token_request
969 .request_async(self.provider.http_client())
970 .await
971 .map_err(|e| OidcError::TokenExchange {
972 message: format!("Token exchange request failed: {e}"),
973 })?;
974
975 let id_token_verifier = client.id_token_verifier();
976 let id_token =
977 token_response
978 .extra_fields()
979 .id_token()
980 .ok_or_else(|| OidcError::TokenExchange {
981 message: "Missing ID token in token response".to_string(),
982 })?;
983
984 let id_token_claims =
985 id_token
986 .claims(&id_token_verifier, nonce)
987 .map_err(|e| OidcError::TokenExchange {
988 message: format!("Failed to verify ID token: {e}"),
989 })?;
990
991 let now = Utc::now();
992 let id_token = id_token.to_string();
993 let access_token = token_response.access_token().secret().clone();
994 let access_token_expiration = token_response
995 .expires_in()
996 .map(|expires_in| now + expires_in);
997 let refresh_token = token_response
998 .refresh_token()
999 .map(|value| value.secret().clone());
1000
1001 let user_info_claims = if client.user_info_url().is_some() {
1002 Some(
1003 self.request_userinfo(
1004 &client,
1005 self.provider.http_client(),
1006 token_response.access_token().clone(),
1007 Some(id_token_claims.subject().clone()),
1008 )
1009 .await?,
1010 )
1011 } else {
1012 None
1013 };
1014
1015 self.check_required_scopes(token_response.scopes())?;
1017
1018 Ok(OidcCodeExchangeResult {
1019 id_token,
1020 id_token_claims: id_token_claims.to_owned(),
1021 refresh_token,
1022 access_token,
1023 access_token_expiration,
1024 user_info_claims,
1025 })
1026 }
1027
1028 fn check_required_scopes(
1036 &self,
1037 response_scopes: Option<&Vec<openidconnect::Scope>>,
1038 ) -> OidcResult<()> {
1039 if self.config.required_scopes.is_empty() {
1040 return Ok(());
1041 }
1042 let granted = match response_scopes {
1043 Some(scopes) => scopes,
1044 None => return Ok(()),
1046 };
1047 let granted_strs: Vec<&str> = granted.iter().map(|s| s.as_str()).collect();
1048 let missing: Vec<String> = self
1049 .config
1050 .required_scopes
1051 .iter()
1052 .filter(|req| !granted_strs.contains(&req.as_str()))
1053 .cloned()
1054 .collect();
1055 if missing.is_empty() {
1056 Ok(())
1057 } else {
1058 Err(OidcError::ScopeValidation { missing })
1059 }
1060 }
1061
1062 fn with_runtime_flags(mut self) -> Self {
1063 self.scopes = self.config.scopes.clone();
1064 self.pkce_enabled = self.config.pkce_enabled;
1065 self
1066 }
1067
1068 async fn request_device_token_once(
1069 &self,
1070 device_authorization: &OidcDeviceAuthorizationResult,
1071 ) -> OidcResult<DeviceTokenPollResponse> {
1072 let client = self.fresh_client().await?;
1073 let token_url = client
1074 .token_uri()
1075 .cloned()
1076 .ok_or_else(|| OidcError::DeviceTokenPoll {
1077 message: "Token endpoint not set for device token polling".to_string(),
1078 })?;
1079
1080 let auth_type = self.resolve_token_endpoint_auth_type().await?;
1081 let mut params = vec![
1082 (
1083 Cow::Borrowed("grant_type"),
1084 Cow::Borrowed("urn:ietf:params:oauth:grant-type:device_code"),
1085 ),
1086 (
1087 Cow::Borrowed("device_code"),
1088 Cow::Owned(device_authorization.device_code.clone()),
1089 ),
1090 ];
1091
1092 if matches!(auth_type, AuthType::RequestBody) {
1093 params.push((
1094 Cow::Borrowed("client_id"),
1095 Cow::Owned(self.config.client_id.clone()),
1096 ));
1097 if let Some(client_secret) = self.config.client_secret.as_ref() {
1098 params.push((
1099 Cow::Borrowed("client_secret"),
1100 Cow::Owned(client_secret.clone()),
1101 ));
1102 }
1103 }
1104
1105 let mut request = self
1106 .provider
1107 .http_client()
1108 .post(token_url.url().clone())
1109 .header(reqwest::header::ACCEPT, "application/json")
1110 .form(¶ms);
1111
1112 if matches!(auth_type, AuthType::BasicAuth) {
1113 let client_secret =
1114 self.config
1115 .client_secret
1116 .as_ref()
1117 .ok_or_else(|| OidcError::DeviceTokenPoll {
1118 message: "client_secret is required for basic token endpoint auth"
1119 .to_string(),
1120 })?;
1121 let credentials = format!(
1122 "{}:{}",
1123 form_urlencode(&self.config.client_id),
1124 form_urlencode(client_secret)
1125 );
1126 let header_value = format!(
1127 "Basic {}",
1128 base64::engine::general_purpose::STANDARD.encode(credentials)
1129 );
1130 request = request.header(reqwest::header::AUTHORIZATION, header_value);
1131 }
1132
1133 let response = request
1134 .send()
1135 .await
1136 .map_err(|e| OidcError::DeviceTokenPoll {
1137 message: format!("Device token poll request failed: {e}"),
1138 })?;
1139 let status = response.status();
1140 let body = response
1141 .bytes()
1142 .await
1143 .map_err(|e| OidcError::DeviceTokenPoll {
1144 message: format!("Failed to read device token poll response: {e}"),
1145 })?;
1146
1147 if status.is_success() {
1148 let token_response =
1149 serde_json::from_slice::<TokenResponseWithExtra>(&body).map_err(|e| {
1150 OidcError::DeviceTokenPoll {
1151 message: format!(
1152 "Failed to parse device token response: {e}; body: {}",
1153 String::from_utf8_lossy(&body)
1154 ),
1155 }
1156 })?;
1157 return Ok(DeviceTokenPollResponse::Complete(Box::new(token_response)));
1158 }
1159
1160 let error_response =
1161 serde_json::from_slice::<DeviceCodeErrorResponse>(&body).map_err(|e| {
1162 OidcError::DeviceTokenPoll {
1163 message: format!(
1164 "Device token poll failed with HTTP {} and an unparseable body: {e}; \
1165 body: {}",
1166 status,
1167 String::from_utf8_lossy(&body)
1168 ),
1169 }
1170 })?;
1171
1172 match error_response.error() {
1173 DeviceCodeErrorResponseType::AuthorizationPending => {
1174 Ok(DeviceTokenPollResponse::Pending)
1175 }
1176 DeviceCodeErrorResponseType::SlowDown => Ok(DeviceTokenPollResponse::SlowDown),
1177 DeviceCodeErrorResponseType::AccessDenied => Ok(DeviceTokenPollResponse::Denied {
1178 error_description: error_response.error_description().cloned(),
1179 }),
1180 DeviceCodeErrorResponseType::ExpiredToken => Ok(DeviceTokenPollResponse::Expired {
1181 error_description: error_response.error_description().cloned(),
1182 }),
1183 other => Err(OidcError::DeviceTokenPoll {
1184 message: format!("Device token poll returned terminal error: {other}"),
1185 }),
1186 }
1187 }
1188
1189 async fn build_device_token_result(
1190 &self,
1191 token_response: TokenResponseWithExtra,
1192 ) -> OidcResult<OidcDeviceTokenResult> {
1193 let client = self.fresh_client().await?;
1194 let id_token_verifier = client.id_token_verifier();
1195 let id_token =
1196 token_response
1197 .extra_fields()
1198 .id_token()
1199 .ok_or_else(|| OidcError::DeviceTokenPoll {
1200 message: "Missing ID token in device token response".to_string(),
1201 })?;
1202 let id_token_claims = id_token
1203 .claims(&id_token_verifier, |_nonce: Option<&Nonce>| Ok(()))
1204 .map_err(|e| OidcError::DeviceTokenPoll {
1205 message: format!("Failed to verify device-flow ID token: {e}"),
1206 })?;
1207
1208 let now = Utc::now();
1209 let access_token = token_response.access_token().secret().clone();
1210 let access_token_expiration = token_response
1211 .expires_in()
1212 .map(|expires_in| now + expires_in);
1213 let refresh_token = token_response
1214 .refresh_token()
1215 .map(|value| value.secret().clone());
1216
1217 let user_info_claims = if client.user_info_url().is_some() {
1218 Some(
1219 self.request_userinfo(
1220 &client,
1221 self.provider.http_client(),
1222 token_response.access_token().clone(),
1223 Some(id_token_claims.subject().clone()),
1224 )
1225 .await?,
1226 )
1227 } else {
1228 None
1229 };
1230 let claims_check_result = self
1231 .check_claims(id_token_claims, user_info_claims.as_ref())
1232 .await?;
1233
1234 Ok(OidcDeviceTokenResult {
1235 access_token,
1236 access_token_expiration,
1237 id_token: id_token.to_string(),
1238 refresh_token,
1239 id_token_claims: id_token_claims.to_owned(),
1240 user_info_claims,
1241 claims_check_result,
1242 })
1243 }
1244
1245 async fn resolve_token_endpoint_auth_type(&self) -> OidcResult<AuthType> {
1246 let metadata = self.provider.oidc_provider_metadata().await?;
1247 let supported = metadata.token_endpoint_auth_methods_supported();
1248
1249 if self.config.client_secret.is_none() {
1250 return Ok(AuthType::RequestBody);
1251 }
1252
1253 let supports_basic = supported
1254 .is_none_or(|methods| methods.contains(&CoreClientAuthMethod::ClientSecretBasic));
1255 if supports_basic {
1256 return Ok(AuthType::BasicAuth);
1257 }
1258
1259 let supports_request_body = supported.is_some_and(|methods| {
1260 methods.contains(&CoreClientAuthMethod::ClientSecretPost)
1261 || methods.contains(&CoreClientAuthMethod::None)
1262 });
1263 if supports_request_body {
1264 return Ok(AuthType::RequestBody);
1265 }
1266
1267 Err(OidcError::DeviceTokenPoll {
1268 message: "The provider only advertises unsupported token endpoint auth methods for \
1269 device polling"
1270 .to_string(),
1271 })
1272 }
1273}
1274
1275enum DeviceTokenPollResponse {
1276 Pending,
1277 SlowDown,
1278 Denied { error_description: Option<String> },
1279 Expired { error_description: Option<String> },
1280 Complete(Box<TokenResponseWithExtra>),
1283}
1284
1285fn form_urlencode(value: &str) -> String {
1286 url::form_urlencoded::byte_serialize(value.as_bytes()).collect()
1287}
1288
/// Builds the human-readable message for a device-token poll that ended with a
/// terminal error code.
///
/// Appends `": {error_description}"` when a description is supplied.
fn format_device_token_terminal_message(
    error_code: &str,
    error_description: Option<&str>,
) -> String {
    let headline = format!("Device token polling stopped with {error_code}");
    let Some(detail) = error_description else {
        return headline;
    };
    format!("{headline}: {detail}")
}
1300
1301fn build_client(
1302 config: &OidcClientConfig<impl PendingOauthStoreConfig>,
1303 metadata: ProviderMetadataWithExtra,
1304) -> Result<BuiltClientWithExtra, String> {
1305 let client_id = ClientId::new(config.client_id.clone());
1306 let client_secret = config
1307 .client_secret
1308 .as_ref()
1309 .map(|value| ClientSecret::new(value.clone()));
1310
1311 let introspection_endpoint = metadata
1312 .additional_metadata()
1313 .introspection_endpoint
1314 .as_ref()
1315 .map(|value| IntrospectionUrl::new(value.clone()))
1316 .transpose()
1317 .map_err(|e| format!("Invalid introspection_endpoint: {e}"))?;
1318 let revocation_endpoint = metadata
1319 .additional_metadata()
1320 .revocation_endpoint
1321 .as_ref()
1322 .map(|value| RevocationUrl::new(value.clone()))
1323 .transpose()
1324 .map_err(|e| format!("Invalid revocation_endpoint: {e}"))?;
1325 let device_authorization_endpoint = metadata
1326 .additional_metadata()
1327 .device_authorization_endpoint
1328 .as_ref()
1329 .map(|value| DeviceAuthorizationUrl::new(value.clone()))
1330 .transpose()
1331 .map_err(|e| format!("Invalid device_authorization_endpoint: {e}"))?;
1332
1333 Ok(BuiltClientWithExtra {
1334 client: ClientWithExtra::from_provider_metadata(metadata, client_id, client_secret),
1335 optional_endpoints: OptionalClientEndpoints {
1336 _introspection_endpoint: introspection_endpoint,
1337 revocation_endpoint,
1338 device_authorization_endpoint,
1339 },
1340 })
1341}