1pub use oauth2;
4
5use std::borrow::Cow;
7use oauth2::{
9 AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, EndpointNotSet, EndpointSet,
10 HttpClientError, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
11 TokenResponse, TokenUrl,
12 basic::{BasicClient, BasicErrorResponse, BasicRequestTokenError},
13};
14#[cfg(all(test, feature = "reqwest"))] use crate::http::ReqwestHttpClient;
16use crate::{
17 _prelude::*,
18 auth::{ScopeSet, TokenFamily, TokenRecord},
19 error::{ConfigError, TransientError, TransportError},
20 http::{ResponseMetadata, ResponseMetadataSlot, TokenHttpClient},
21 provider::{
22 ClientAuthMethod, GrantType, ProviderDescriptor, ProviderErrorContext, ProviderErrorKind,
23 ProviderStrategy,
24 },
25};
26
27type ConfiguredBasicClient =
28 BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;
29type FacadeTokenResponse = oauth2::basic::BasicTokenResponse;
30type FacadeFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T>> + 'a + Send>>;
31
32pub trait TransportErrorMapper<E>
34where
35 Self: 'static + Send + Sync,
36 E: 'static + Send + Sync + StdError,
37{
38 fn map_transport_error(
40 &self,
41 strategy: &dyn ProviderStrategy,
42 grant: GrantType,
43 metadata: Option<&ResponseMetadata>,
44 error: HttpClientError<E>,
45 ) -> Error;
46}
47
48pub(crate) trait OAuth2Facade {
49 fn exchange_client_credentials<'a, 'strategy, 'scopes, 'params>(
50 &'a self,
51 strategy: &'strategy dyn ProviderStrategy,
52 family: TokenFamily,
53 scopes: &'scopes [&'scopes str],
54 extra_params: &'params [(String, String)],
55 ) -> FacadeFuture<'a, TokenRecord>
56 where
57 'strategy: 'a,
58 'scopes: 'a,
59 'params: 'a;
60
61 fn refresh_token<'a, 'strategy, 'refresh, 'scope>(
62 &'a self,
63 strategy: &'strategy dyn ProviderStrategy,
64 family: TokenFamily,
65 refresh_token: &'refresh str,
66 requested_scope: &'scope ScopeSet,
67 ) -> FacadeFuture<'a, (TokenRecord, Option<String>)>
68 where
69 'strategy: 'a,
70 'refresh: 'a,
71 'scope: 'a;
72
73 fn exchange_authorization_code<'a, 'strategy, 'code, 'pkce, 'scope, 'redirect>(
74 &'a self,
75 strategy: &'strategy dyn ProviderStrategy,
76 family: TokenFamily,
77 code: &'code str,
78 pkce_verifier: &'pkce str,
79 requested_scope: &'scope ScopeSet,
80 redirect_uri: &'redirect Url,
81 ) -> FacadeFuture<'a, TokenRecord>
82 where
83 'strategy: 'a,
84 'code: 'a,
85 'pkce: 'a,
86 'scope: 'a,
87 'redirect: 'a;
88}
89
/// [`TransportErrorMapper`] for `HttpClientError<ReqwestError>` failures.
///
/// Request-builder problems become configuration errors, timeouts become transient
/// token-endpoint errors, and other reqwest / I/O failures become transport errors.
#[cfg(feature = "reqwest")]
#[derive(Clone, Debug, Default)]
pub struct ReqwestTransportErrorMapper;

#[cfg(feature = "reqwest")]
impl TransportErrorMapper<ReqwestError> for ReqwestTransportErrorMapper {
    fn map_transport_error(
        &self,
        strategy: &dyn ProviderStrategy,
        grant: GrantType,
        meta: Option<&ResponseMetadata>,
        err: HttpClientError<ReqwestError>,
    ) -> Error {
        match err {
            HttpClientError::Reqwest(boxed) => map_reqwest_error(strategy, grant, meta, *boxed),
            HttpClientError::Http(http_error) => ConfigError::from(http_error).into(),
            HttpClientError::Io(io_error) => TransportError::Io(io_error).into(),
            HttpClientError::Other(text) => map_generic_transport_error(meta, text),
            // Any variant not listed above falls back to a generic transient error.
            _ => map_unknown_transport_error(meta),
        }
    }
}
112
113pub(crate) struct BasicFacade<C, M>
114where
115 C: ?Sized + TokenHttpClient,
116 M: ?Sized + TransportErrorMapper<C::TransportError>,
117{
118 oauth_client: ConfiguredBasicClient,
119 http_client: Arc<C>,
120 error_mapper: Arc<M>,
121}
122impl<C, M> BasicFacade<C, M>
123where
124 C: ?Sized + TokenHttpClient,
125 M: ?Sized + TransportErrorMapper<C::TransportError>,
126{
127 pub(super) fn new(
128 oauth_client: ConfiguredBasicClient,
129 http_client: impl Into<Arc<C>>,
130 error_mapper: impl Into<Arc<M>>,
131 ) -> Self {
132 Self { oauth_client, http_client: http_client.into(), error_mapper: error_mapper.into() }
133 }
134
135 pub(crate) fn from_descriptor(
136 descriptor: &ProviderDescriptor,
137 client_id: &str,
138 client_secret: Option<&str>,
139 redirect_uri: Option<&Url>,
140 http_client: impl Into<Arc<C>>,
141 error_mapper: impl Into<Arc<M>>,
142 ) -> Result<Self> {
143 let auth_url = AuthUrl::new(descriptor.endpoints.authorization.to_string())
144 .map_err(|source| ConfigError::InvalidDescriptor { source })?;
145 let token_url = TokenUrl::new(descriptor.endpoints.token.to_string())
146 .map_err(|source| ConfigError::InvalidDescriptor { source })?;
147 let secret =
148 if matches!(descriptor.preferred_client_auth_method, ClientAuthMethod::NoneWithPkce) {
149 None
150 } else {
151 client_secret.map(|value| ClientSecret::new(value.to_owned()))
152 };
153 let mut oauth_client = BasicClient::new(ClientId::new(client_id.to_owned()))
154 .set_auth_uri(auth_url)
155 .set_token_uri(token_url);
156
157 if let Some(secret) = secret {
158 oauth_client = oauth_client.set_client_secret(secret);
159 }
160 if let Some(redirect) = redirect_uri {
161 let redirect_url = RedirectUrl::new(redirect.to_string())
162 .map_err(|source| ConfigError::InvalidDescriptor { source })?;
163
164 oauth_client = oauth_client.set_redirect_uri(redirect_url);
165 }
166
167 if matches!(descriptor.preferred_client_auth_method, ClientAuthMethod::ClientSecretPost) {
168 oauth_client = oauth_client.set_auth_type(AuthType::RequestBody);
169 }
170
171 Ok(Self::new(oauth_client, http_client, error_mapper))
172 }
173}
174impl<C, M> OAuth2Facade for BasicFacade<C, M>
175where
176 C: ?Sized + TokenHttpClient,
177 M: ?Sized + TransportErrorMapper<C::TransportError>,
178{
179 fn exchange_client_credentials<'a, 'strategy, 'scopes, 'params>(
180 &'a self,
181 strategy: &'strategy dyn ProviderStrategy,
182 family: TokenFamily,
183 scopes: &'scopes [&'scopes str],
184 extra_params: &'params [(String, String)],
185 ) -> FacadeFuture<'a, TokenRecord>
186 where
187 'strategy: 'a,
188 'scopes: 'a,
189 'params: 'a,
190 {
191 let meta = ResponseMetadataSlot::default();
192
193 Box::pin(async move {
194 let instrumented = self.http_client.with_metadata(meta.clone());
195 let requested_scope =
196 ScopeSet::new(scopes.iter().copied()).map_err(ConfigError::from)?;
197 let mut request = self.oauth_client.exchange_client_credentials();
198
199 for scope in scopes {
200 request = request.add_scope(Scope::new((*scope).to_owned()));
201 }
202 for (key, value) in extra_params {
203 request = request.add_extra_param(key, value);
204 }
205
206 let response = request.request_async(&instrumented).await.map_err(|err| {
207 map_request_error(
208 strategy,
209 GrantType::ClientCredentials,
210 meta.take(),
211 err,
212 self.error_mapper.as_ref(),
213 )
214 })?;
215
216 map_standard_token_response(family, requested_scope, response)
217 })
218 }
219
220 fn refresh_token<'a, 'strategy, 'refresh, 'scope>(
221 &'a self,
222 strategy: &'strategy dyn ProviderStrategy,
223 family: TokenFamily,
224 refresh_token: &'refresh str,
225 requested_scope: &'scope ScopeSet,
226 ) -> FacadeFuture<'a, (TokenRecord, Option<String>)>
227 where
228 'strategy: 'a,
229 'refresh: 'a,
230 'scope: 'a,
231 {
232 let meta = ResponseMetadataSlot::default();
233
234 Box::pin(async move {
235 let instrumented = self.http_client.with_metadata(meta.clone());
236 let refresh_secret = RefreshToken::new(refresh_token.to_owned());
237 let mut request = self.oauth_client.exchange_refresh_token(&refresh_secret);
238
239 if !requested_scope.is_empty() {
240 for scope in requested_scope.iter() {
241 request = request.add_scope(Scope::new(scope.to_owned()));
242 }
243 }
244
245 let response = request.request_async(&instrumented).await.map_err(|err| {
246 map_request_error(
247 strategy,
248 GrantType::RefreshToken,
249 meta.take(),
250 err,
251 self.error_mapper.as_ref(),
252 )
253 })?;
254
255 map_refresh_token_response(family, requested_scope, response)
256 })
257 }
258
259 fn exchange_authorization_code<'a, 'strategy, 'code, 'pkce, 'scope, 'redirect>(
260 &'a self,
261 strategy: &'strategy dyn ProviderStrategy,
262 family: TokenFamily,
263 code: &'code str,
264 pkce_verifier: &'pkce str,
265 requested_scope: &'scope ScopeSet,
266 redirect_uri: &'redirect Url,
267 ) -> FacadeFuture<'a, TokenRecord>
268 where
269 'strategy: 'a,
270 'code: 'a,
271 'pkce: 'a,
272 'scope: 'a,
273 'redirect: 'a,
274 {
275 let meta = ResponseMetadataSlot::default();
276
277 Box::pin(async move {
278 let instrumented = self.http_client.with_metadata(meta.clone());
279 let mut request = self
280 .oauth_client
281 .exchange_code(AuthorizationCode::new(code.to_owned()))
282 .set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_owned()));
283
284 if !requested_scope.is_empty() {
285 request = request.add_extra_param("scope", requested_scope.normalized());
286 }
287
288 let redirect_url = RedirectUrl::new(redirect_uri.to_string())
289 .map_err(|err| ConfigError::InvalidRedirect { source: err })?;
290
291 request = request.set_redirect_uri(Cow::Owned(redirect_url));
292
293 let response = request.request_async(&instrumented).await.map_err(|err| {
294 map_request_error(
295 strategy,
296 GrantType::AuthorizationCode,
297 meta.take(),
298 err,
299 self.error_mapper.as_ref(),
300 )
301 })?;
302 let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
303 let expires_in =
304 i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
305
306 if expires_in <= 0 {
307 return Err(ConfigError::NonPositiveExpiresIn.into());
308 }
309
310 if let Some(scopes) = response.scopes() {
311 let returned = ScopeSet::new(scopes.iter().map(|scope| scope.as_ref()))
312 .map_err(ConfigError::from)?;
313 if returned != *requested_scope {
314 return Err(ConfigError::ScopesChanged { grant: "authorization_code" }.into());
315 }
316 }
317
318 let issued_at = OffsetDateTime::now_utc();
319 let mut builder = TokenRecord::builder(family, requested_scope.clone())
320 .access_token(response.access_token().secret().to_owned())
321 .issued_at(issued_at)
322 .expires_in(Duration::seconds(expires_in));
323
324 if let Some(refresh) = response.refresh_token() {
325 builder = builder.refresh_token(refresh.secret().to_owned());
326 }
327
328 builder.build().map_err(|e| ConfigError::from(e).into())
329 })
330 }
331}
332
333fn map_standard_token_response(
334 family: TokenFamily,
335 scope: ScopeSet,
336 response: FacadeTokenResponse,
337) -> Result<TokenRecord> {
338 let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
339 let expires_in = i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
340
341 if expires_in <= 0 {
342 return Err(ConfigError::NonPositiveExpiresIn.into());
343 }
344
345 if let Some(scopes) = response.scopes() {
346 let returned =
347 ScopeSet::new(scopes.iter().map(|scope| scope.as_ref())).map_err(ConfigError::from)?;
348 if returned != scope {
349 return Err(ConfigError::ScopesChanged { grant: "client_credentials" }.into());
350 }
351 }
352
353 let issued_at = OffsetDateTime::now_utc();
354
355 TokenRecord::builder(family, scope)
356 .access_token(response.access_token().secret().to_owned())
357 .issued_at(issued_at)
358 .expires_in(Duration::seconds(expires_in))
359 .build()
360 .map_err(|err| ConfigError::from(err).into())
361}
362
363fn map_refresh_token_response(
364 family: TokenFamily,
365 requested_scope: &ScopeSet,
366 response: FacadeTokenResponse,
367) -> Result<(TokenRecord, Option<String>)> {
368 let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
369 let expires_in = i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
370
371 if expires_in <= 0 {
372 return Err(ConfigError::NonPositiveExpiresIn.into());
373 }
374
375 if let Some(scopes) = response.scopes() {
376 let returned =
377 ScopeSet::new(scopes.iter().map(|scope| scope.as_ref())).map_err(ConfigError::from)?;
378 if returned != *requested_scope {
379 return Err(ConfigError::ScopesChanged { grant: "refresh_token" }.into());
380 }
381 }
382
383 let issued_at = OffsetDateTime::now_utc();
384 let mut builder = TokenRecord::builder(family, requested_scope.clone())
385 .access_token(response.access_token().secret().to_owned())
386 .issued_at(issued_at)
387 .expires_in(Duration::seconds(expires_in));
388 let new_refresh = response.refresh_token().map(|token| token.secret().to_owned());
389
390 if let Some(secret) = &new_refresh {
391 builder = builder.refresh_token(secret.clone());
392 }
393
394 let record = builder.build().map_err(ConfigError::from)?;
395
396 Ok((record, new_refresh))
397}
398
399fn map_request_error<E, M>(
400 strategy: &dyn ProviderStrategy,
401 grant: GrantType,
402 meta: Option<ResponseMetadata>,
403 err: BasicRequestTokenError<HttpClientError<E>>,
404 mapper: &M,
405) -> Error
406where
407 E: 'static + Send + Sync + StdError,
408 M: ?Sized + TransportErrorMapper<E>,
409{
410 let meta_ref = meta.as_ref();
411
412 match err {
413 RequestTokenError::ServerResponse(response) =>
414 map_server_response_error(strategy, grant, response, meta_ref),
415 RequestTokenError::Request(error) =>
416 map_transport_error(strategy, grant, meta_ref, error, mapper),
417 RequestTokenError::Parse(error, _body) =>
418 TransientError::TokenResponseParse { source: error, status: meta_status(meta_ref) }
419 .into(),
420 RequestTokenError::Other(message) => TransientError::TokenEndpoint {
421 message: format!("Token endpoint returned an unexpected response: {message}."),
422 status: meta_status(meta_ref),
423 retry_after: meta_retry_after(meta_ref),
424 }
425 .into(),
426 }
427}
428
429fn map_server_response_error(
430 strategy: &dyn ProviderStrategy,
431 grant: GrantType,
432 response: BasicErrorResponse,
433 meta: Option<&ResponseMetadata>,
434) -> Error {
435 let mut ctx =
436 ProviderErrorContext::new(grant).with_oauth_error(response.error().as_ref().to_string());
437 if let Some(description) = response.error_description() {
438 ctx = ctx.with_error_description(description.clone());
439 }
440
441 if let Some(status) = meta_status(meta) {
442 ctx = ctx.with_http_status(status);
443 }
444
445 let classification = strategy.classify_token_error(&ctx);
446 let message = if let Some(description) = response.error_description() {
447 format!("Token endpoint returned an OAuth error: {description}.")
448 } else {
449 format!("Token endpoint returned an OAuth error: {}.", response.error().as_ref())
450 };
451
452 match classification {
453 ProviderErrorKind::InvalidGrant => Error::InvalidGrant { reason: message },
454 ProviderErrorKind::InvalidClient => Error::InvalidClient { reason: message },
455 ProviderErrorKind::InsufficientScope => Error::InsufficientScope { reason: message },
456 ProviderErrorKind::Transient => TransientError::TokenEndpoint {
457 message,
458 status: meta_status(meta),
459 retry_after: meta_retry_after(meta),
460 }
461 .into(),
462 }
463}
464
465fn map_transport_error<E, M>(
466 strategy: &dyn ProviderStrategy,
467 grant: GrantType,
468 meta: Option<&ResponseMetadata>,
469 err: HttpClientError<E>,
470 mapper: &M,
471) -> Error
472where
473 E: 'static + Send + Sync + StdError,
474 M: ?Sized + TransportErrorMapper<E>,
475{
476 mapper.map_transport_error(strategy, grant, meta, err)
477}
478
/// Maps a reqwest failure to a crate error.
///
/// - Builder errors become configuration errors.
/// - Timeouts become transient token-endpoint errors. The status comes from the captured
///   metadata, or from the reqwest error itself.
/// - Everything else becomes a transport error.
#[cfg(feature = "reqwest")]
fn map_reqwest_error(
    strategy: &dyn ProviderStrategy,
    grant: GrantType,
    meta: Option<&ResponseMetadata>,
    err: ReqwestError,
) -> Error {
    // `strategy` and `grant` are accepted but not consulted for reqwest failures.
    let _ = (strategy, grant);

    if err.is_builder() {
        ConfigError::from(err).into()
    } else if err.is_timeout() {
        let status = meta_status(meta).or_else(|| reqwest_status(&err));
        let retry_after = meta_retry_after(meta);
        TransientError::TokenEndpoint {
            message: "Request timed out while calling the token endpoint.".into(),
            status,
            retry_after,
        }
        .into()
    } else {
        TransportError::from(err).into()
    }
}
503
504fn map_generic_transport_error(meta: Option<&ResponseMetadata>, message: impl Display) -> Error {
505 TransientError::TokenEndpoint {
506 message: format!("HTTP client error occurred while calling the token endpoint: {message}."),
507 status: meta_status(meta),
508 retry_after: meta_retry_after(meta),
509 }
510 .into()
511}
512
513fn map_unknown_transport_error(meta: Option<&ResponseMetadata>) -> Error {
514 TransientError::TokenEndpoint {
515 message: "HTTP client error occurred while calling the token endpoint.".into(),
516 status: meta_status(meta),
517 retry_after: meta_retry_after(meta),
518 }
519 .into()
520}
521
522fn meta_status(meta: Option<&ResponseMetadata>) -> Option<u16> {
523 meta.and_then(|value| value.status)
524}
525
526fn meta_retry_after(meta: Option<&ResponseMetadata>) -> Option<Duration> {
527 meta.and_then(|value| value.retry_after)
528}
529
/// Returns the numeric HTTP status carried by a reqwest error, when it has one.
#[cfg(feature = "reqwest")]
fn reqwest_status(err: &ReqwestError) -> Option<u16> {
    let status = err.status()?;
    Some(status.as_u16())
}
534
#[cfg(all(test, feature = "reqwest"))]
mod tests {
    use super::*;
    use crate::auth::ProviderId;

    /// Facade specialised to the reqwest client and mapper used throughout these tests.
    type ReqwestFacade = BasicFacade<ReqwestHttpClient, ReqwestTransportErrorMapper>;

    /// Minimal authorization-code descriptor that prefers `method` for client authentication.
    fn descriptor(method: ClientAuthMethod) -> ProviderDescriptor {
        let provider_id =
            ProviderId::new("test-provider").expect("Failed to construct provider identifier.");
        let authorization = Url::parse("https://example.com/oauth2/authorize")
            .expect("Failed to parse authorization endpoint URL.");
        let token = Url::parse("https://example.com/oauth2/token")
            .expect("Failed to parse token endpoint URL.");

        ProviderDescriptor::builder(provider_id)
            .authorization_endpoint(authorization)
            .token_endpoint(token)
            .support_grant(GrantType::AuthorizationCode)
            .preferred_client_auth_method(method)
            .build()
            .expect("Failed to build provider descriptor.")
    }

    /// Runs `from_descriptor` for `method` with fresh reqwest client / mapper instances.
    fn build(
        method: ClientAuthMethod,
        client_id: &str,
        client_secret: Option<&str>,
        redirect_uri: Option<&Url>,
    ) -> Result<ReqwestFacade> {
        ReqwestFacade::from_descriptor(
            &descriptor(method),
            client_id,
            client_secret,
            redirect_uri,
            Arc::new(ReqwestHttpClient::default()),
            Arc::new(ReqwestTransportErrorMapper),
        )
    }

    #[test]
    fn builds_basic_auth_client() {
        let redirect =
            Url::parse("https://example.com/callback").expect("Failed to parse redirect URI.");
        let built = build(
            ClientAuthMethod::ClientSecretBasic,
            "client-id",
            Some("secret"),
            Some(&redirect),
        );

        assert!(built.is_ok());
    }

    #[test]
    fn builds_post_auth_client() {
        let built = build(ClientAuthMethod::ClientSecretPost, "client-id", Some("secret"), None);

        assert!(built.is_ok());
    }

    #[test]
    fn builds_pkce_client_without_secret() {
        let built =
            build(ClientAuthMethod::NoneWithPkce, "public-client", Some("ignored-secret"), None);

        assert!(built.is_ok());
    }
}