1pub use oauth2;
4
5use std::borrow::Cow;
7use oauth2::{
9 AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, EndpointNotSet, EndpointSet,
10 HttpClientError, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
11 TokenResponse, TokenUrl,
12 basic::{BasicClient, BasicErrorResponse, BasicRequestTokenError},
13};
14use crate::{
16 _prelude::*,
17 auth::{ScopeSet, TokenFamily, TokenRecord},
18 error::{ConfigError, TransientError, TransportError},
19 http::{ReqwestHttpClient, ResponseMetadata, ResponseMetadataSlot, TokenHttpClient},
20 provider::{
21 ClientAuthMethod, GrantType, ProviderDescriptor, ProviderErrorContext, ProviderErrorKind,
22 ProviderStrategy,
23 },
24};
25
26type ConfiguredBasicClient =
27 BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;
28type FacadeTokenResponse = oauth2::basic::BasicTokenResponse;
29type FacadeFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T>> + 'a + Send>>;
30
31pub trait TransportErrorMapper<E>
33where
34 Self: 'static + Send + Sync,
35 E: 'static + Send + Sync + StdError,
36{
37 fn map_transport_error(
39 &self,
40 strategy: &dyn ProviderStrategy,
41 grant: GrantType,
42 metadata: Option<&ResponseMetadata>,
43 error: HttpClientError<E>,
44 ) -> Error;
45}
46
47#[derive(Clone, Debug, Default)]
49pub struct ReqwestTransportErrorMapper;
50impl TransportErrorMapper<ReqwestError> for ReqwestTransportErrorMapper {
51 fn map_transport_error(
52 &self,
53 strategy: &dyn ProviderStrategy,
54 grant: GrantType,
55 meta: Option<&ResponseMetadata>,
56 err: HttpClientError<ReqwestError>,
57 ) -> Error {
58 match err {
59 HttpClientError::Reqwest(inner) => map_reqwest_error(strategy, grant, meta, *inner),
60 HttpClientError::Http(inner) => ConfigError::from(inner).into(),
61 HttpClientError::Io(inner) => TransportError::Io(inner).into(),
62 HttpClientError::Other(message) => map_generic_transport_error(meta, message),
63 _ => map_unknown_transport_error(meta),
64 }
65 }
66}
67
68pub(crate) trait OAuth2Facade {
69 fn exchange_client_credentials<'a, 'strategy, 'scopes, 'params>(
70 &'a self,
71 strategy: &'strategy dyn ProviderStrategy,
72 family: TokenFamily,
73 scopes: &'scopes [&'scopes str],
74 extra_params: &'params [(String, String)],
75 ) -> FacadeFuture<'a, TokenRecord>
76 where
77 'strategy: 'a,
78 'scopes: 'a,
79 'params: 'a;
80
81 fn refresh_token<'a, 'strategy, 'refresh, 'scope>(
82 &'a self,
83 strategy: &'strategy dyn ProviderStrategy,
84 family: TokenFamily,
85 refresh_token: &'refresh str,
86 requested_scope: &'scope ScopeSet,
87 ) -> FacadeFuture<'a, (TokenRecord, Option<String>)>
88 where
89 'strategy: 'a,
90 'refresh: 'a,
91 'scope: 'a;
92
93 fn exchange_authorization_code<'a, 'strategy, 'code, 'pkce, 'scope, 'redirect>(
94 &'a self,
95 strategy: &'strategy dyn ProviderStrategy,
96 family: TokenFamily,
97 code: &'code str,
98 pkce_verifier: &'pkce str,
99 requested_scope: &'scope ScopeSet,
100 redirect_uri: &'redirect Url,
101 ) -> FacadeFuture<'a, TokenRecord>
102 where
103 'strategy: 'a,
104 'code: 'a,
105 'pkce: 'a,
106 'scope: 'a,
107 'redirect: 'a;
108}
109
110pub(crate) struct BasicFacade<C = ReqwestHttpClient, M = ReqwestTransportErrorMapper>
111where
112 C: ?Sized + TokenHttpClient,
113 M: ?Sized + TransportErrorMapper<C::TransportError>,
114{
115 oauth_client: ConfiguredBasicClient,
116 http_client: Arc<C>,
117 error_mapper: Arc<M>,
118}
119impl<C, M> BasicFacade<C, M>
120where
121 C: ?Sized + TokenHttpClient,
122 M: ?Sized + TransportErrorMapper<C::TransportError>,
123{
124 pub(super) fn new(
125 oauth_client: ConfiguredBasicClient,
126 http_client: impl Into<Arc<C>>,
127 error_mapper: impl Into<Arc<M>>,
128 ) -> Self {
129 Self { oauth_client, http_client: http_client.into(), error_mapper: error_mapper.into() }
130 }
131
132 pub(crate) fn from_descriptor(
133 descriptor: &ProviderDescriptor,
134 client_id: &str,
135 client_secret: Option<&str>,
136 redirect_uri: Option<&Url>,
137 http_client: impl Into<Arc<C>>,
138 error_mapper: impl Into<Arc<M>>,
139 ) -> Result<Self> {
140 let auth_url = AuthUrl::new(descriptor.endpoints.authorization.to_string())
141 .map_err(|source| ConfigError::InvalidDescriptor { source })?;
142 let token_url = TokenUrl::new(descriptor.endpoints.token.to_string())
143 .map_err(|source| ConfigError::InvalidDescriptor { source })?;
144 let secret =
145 if matches!(descriptor.preferred_client_auth_method, ClientAuthMethod::NoneWithPkce) {
146 None
147 } else {
148 client_secret.map(|value| ClientSecret::new(value.to_owned()))
149 };
150 let mut oauth_client = BasicClient::new(ClientId::new(client_id.to_owned()))
151 .set_auth_uri(auth_url)
152 .set_token_uri(token_url);
153
154 if let Some(secret) = secret {
155 oauth_client = oauth_client.set_client_secret(secret);
156 }
157 if let Some(redirect) = redirect_uri {
158 let redirect_url = RedirectUrl::new(redirect.to_string())
159 .map_err(|source| ConfigError::InvalidDescriptor { source })?;
160
161 oauth_client = oauth_client.set_redirect_uri(redirect_url);
162 }
163
164 if matches!(descriptor.preferred_client_auth_method, ClientAuthMethod::ClientSecretPost) {
165 oauth_client = oauth_client.set_auth_type(AuthType::RequestBody);
166 }
167
168 Ok(Self::new(oauth_client, http_client, error_mapper))
169 }
170}
171impl<C, M> OAuth2Facade for BasicFacade<C, M>
172where
173 C: ?Sized + TokenHttpClient,
174 M: ?Sized + TransportErrorMapper<C::TransportError>,
175{
176 fn exchange_client_credentials<'a, 'strategy, 'scopes, 'params>(
177 &'a self,
178 strategy: &'strategy dyn ProviderStrategy,
179 family: TokenFamily,
180 scopes: &'scopes [&'scopes str],
181 extra_params: &'params [(String, String)],
182 ) -> FacadeFuture<'a, TokenRecord>
183 where
184 'strategy: 'a,
185 'scopes: 'a,
186 'params: 'a,
187 {
188 let meta = ResponseMetadataSlot::default();
189
190 Box::pin(async move {
191 let instrumented = self.http_client.with_metadata(meta.clone());
192 let requested_scope =
193 ScopeSet::new(scopes.iter().copied()).map_err(ConfigError::from)?;
194 let mut request = self.oauth_client.exchange_client_credentials();
195
196 for scope in scopes {
197 request = request.add_scope(Scope::new((*scope).to_owned()));
198 }
199 for (key, value) in extra_params {
200 request = request.add_extra_param(key, value);
201 }
202
203 let response = request.request_async(&instrumented).await.map_err(|err| {
204 map_request_error(
205 strategy,
206 GrantType::ClientCredentials,
207 meta.take(),
208 err,
209 self.error_mapper.as_ref(),
210 )
211 })?;
212
213 map_standard_token_response(family, requested_scope, response)
214 })
215 }
216
217 fn refresh_token<'a, 'strategy, 'refresh, 'scope>(
218 &'a self,
219 strategy: &'strategy dyn ProviderStrategy,
220 family: TokenFamily,
221 refresh_token: &'refresh str,
222 requested_scope: &'scope ScopeSet,
223 ) -> FacadeFuture<'a, (TokenRecord, Option<String>)>
224 where
225 'strategy: 'a,
226 'refresh: 'a,
227 'scope: 'a,
228 {
229 let meta = ResponseMetadataSlot::default();
230
231 Box::pin(async move {
232 let instrumented = self.http_client.with_metadata(meta.clone());
233 let refresh_secret = RefreshToken::new(refresh_token.to_owned());
234 let mut request = self.oauth_client.exchange_refresh_token(&refresh_secret);
235
236 if !requested_scope.is_empty() {
237 for scope in requested_scope.iter() {
238 request = request.add_scope(Scope::new(scope.to_owned()));
239 }
240 }
241
242 let response = request.request_async(&instrumented).await.map_err(|err| {
243 map_request_error(
244 strategy,
245 GrantType::RefreshToken,
246 meta.take(),
247 err,
248 self.error_mapper.as_ref(),
249 )
250 })?;
251
252 map_refresh_token_response(family, requested_scope, response)
253 })
254 }
255
256 fn exchange_authorization_code<'a, 'strategy, 'code, 'pkce, 'scope, 'redirect>(
257 &'a self,
258 strategy: &'strategy dyn ProviderStrategy,
259 family: TokenFamily,
260 code: &'code str,
261 pkce_verifier: &'pkce str,
262 requested_scope: &'scope ScopeSet,
263 redirect_uri: &'redirect Url,
264 ) -> FacadeFuture<'a, TokenRecord>
265 where
266 'strategy: 'a,
267 'code: 'a,
268 'pkce: 'a,
269 'scope: 'a,
270 'redirect: 'a,
271 {
272 let meta = ResponseMetadataSlot::default();
273
274 Box::pin(async move {
275 let instrumented = self.http_client.with_metadata(meta.clone());
276 let mut request = self
277 .oauth_client
278 .exchange_code(AuthorizationCode::new(code.to_owned()))
279 .set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_owned()));
280
281 if !requested_scope.is_empty() {
282 request = request.add_extra_param("scope", requested_scope.normalized());
283 }
284
285 let redirect_url = RedirectUrl::new(redirect_uri.to_string())
286 .map_err(|err| ConfigError::InvalidRedirect { source: err })?;
287
288 request = request.set_redirect_uri(Cow::Owned(redirect_url));
289
290 let response = request.request_async(&instrumented).await.map_err(|err| {
291 map_request_error(
292 strategy,
293 GrantType::AuthorizationCode,
294 meta.take(),
295 err,
296 self.error_mapper.as_ref(),
297 )
298 })?;
299 let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
300 let expires_in =
301 i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
302
303 if expires_in <= 0 {
304 return Err(ConfigError::NonPositiveExpiresIn.into());
305 }
306
307 if let Some(scopes) = response.scopes() {
308 let returned = ScopeSet::new(scopes.iter().map(|scope| scope.as_ref()))
309 .map_err(ConfigError::from)?;
310 if returned != *requested_scope {
311 return Err(ConfigError::ScopesChanged { grant: "authorization_code" }.into());
312 }
313 }
314
315 let issued_at = OffsetDateTime::now_utc();
316 let mut builder = TokenRecord::builder(family, requested_scope.clone())
317 .access_token(response.access_token().secret().to_owned())
318 .issued_at(issued_at)
319 .expires_in(Duration::seconds(expires_in));
320
321 if let Some(refresh) = response.refresh_token() {
322 builder = builder.refresh_token(refresh.secret().to_owned());
323 }
324
325 builder.build().map_err(|e| ConfigError::from(e).into())
326 })
327 }
328}
329
330fn map_standard_token_response(
331 family: TokenFamily,
332 scope: ScopeSet,
333 response: FacadeTokenResponse,
334) -> Result<TokenRecord> {
335 let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
336 let expires_in = i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
337
338 if expires_in <= 0 {
339 return Err(ConfigError::NonPositiveExpiresIn.into());
340 }
341
342 if let Some(scopes) = response.scopes() {
343 let returned =
344 ScopeSet::new(scopes.iter().map(|scope| scope.as_ref())).map_err(ConfigError::from)?;
345 if returned != scope {
346 return Err(ConfigError::ScopesChanged { grant: "client_credentials" }.into());
347 }
348 }
349
350 let issued_at = OffsetDateTime::now_utc();
351
352 TokenRecord::builder(family, scope)
353 .access_token(response.access_token().secret().to_owned())
354 .issued_at(issued_at)
355 .expires_in(Duration::seconds(expires_in))
356 .build()
357 .map_err(|err| ConfigError::from(err).into())
358}
359
360fn map_refresh_token_response(
361 family: TokenFamily,
362 requested_scope: &ScopeSet,
363 response: FacadeTokenResponse,
364) -> Result<(TokenRecord, Option<String>)> {
365 let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
366 let expires_in = i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
367
368 if expires_in <= 0 {
369 return Err(ConfigError::NonPositiveExpiresIn.into());
370 }
371
372 if let Some(scopes) = response.scopes() {
373 let returned =
374 ScopeSet::new(scopes.iter().map(|scope| scope.as_ref())).map_err(ConfigError::from)?;
375 if returned != *requested_scope {
376 return Err(ConfigError::ScopesChanged { grant: "refresh_token" }.into());
377 }
378 }
379
380 let issued_at = OffsetDateTime::now_utc();
381 let mut builder = TokenRecord::builder(family, requested_scope.clone())
382 .access_token(response.access_token().secret().to_owned())
383 .issued_at(issued_at)
384 .expires_in(Duration::seconds(expires_in));
385 let new_refresh = response.refresh_token().map(|token| token.secret().to_owned());
386
387 if let Some(secret) = &new_refresh {
388 builder = builder.refresh_token(secret.clone());
389 }
390
391 let record = builder.build().map_err(ConfigError::from)?;
392
393 Ok((record, new_refresh))
394}
395
396fn map_request_error<E, M>(
397 strategy: &dyn ProviderStrategy,
398 grant: GrantType,
399 meta: Option<ResponseMetadata>,
400 err: BasicRequestTokenError<HttpClientError<E>>,
401 mapper: &M,
402) -> Error
403where
404 E: 'static + Send + Sync + StdError,
405 M: ?Sized + TransportErrorMapper<E>,
406{
407 let meta_ref = meta.as_ref();
408
409 match err {
410 RequestTokenError::ServerResponse(response) =>
411 map_server_response_error(strategy, grant, response, meta_ref),
412 RequestTokenError::Request(error) =>
413 map_transport_error(strategy, grant, meta_ref, error, mapper),
414 RequestTokenError::Parse(error, _body) =>
415 TransientError::TokenResponseParse { source: error, status: meta_status(meta_ref) }
416 .into(),
417 RequestTokenError::Other(message) => TransientError::TokenEndpoint {
418 message: format!("Token endpoint returned an unexpected response: {message}."),
419 status: meta_status(meta_ref),
420 retry_after: meta_retry_after(meta_ref),
421 }
422 .into(),
423 }
424}
425
426fn map_server_response_error(
427 strategy: &dyn ProviderStrategy,
428 grant: GrantType,
429 response: BasicErrorResponse,
430 meta: Option<&ResponseMetadata>,
431) -> Error {
432 let mut ctx =
433 ProviderErrorContext::new(grant).with_oauth_error(response.error().as_ref().to_string());
434 if let Some(description) = response.error_description() {
435 ctx = ctx.with_error_description(description.clone());
436 }
437
438 if let Some(status) = meta_status(meta) {
439 ctx = ctx.with_http_status(status);
440 }
441
442 let classification = strategy.classify_token_error(&ctx);
443 let message = if let Some(description) = response.error_description() {
444 format!("Token endpoint returned an OAuth error: {description}.")
445 } else {
446 format!("Token endpoint returned an OAuth error: {}.", response.error().as_ref())
447 };
448
449 match classification {
450 ProviderErrorKind::InvalidGrant => Error::InvalidGrant { reason: message },
451 ProviderErrorKind::InvalidClient => Error::InvalidClient { reason: message },
452 ProviderErrorKind::InsufficientScope => Error::InsufficientScope { reason: message },
453 ProviderErrorKind::Transient => TransientError::TokenEndpoint {
454 message,
455 status: meta_status(meta),
456 retry_after: meta_retry_after(meta),
457 }
458 .into(),
459 }
460}
461
462fn map_transport_error<E, M>(
463 strategy: &dyn ProviderStrategy,
464 grant: GrantType,
465 meta: Option<&ResponseMetadata>,
466 err: HttpClientError<E>,
467 mapper: &M,
468) -> Error
469where
470 E: 'static + Send + Sync + StdError,
471 M: ?Sized + TransportErrorMapper<E>,
472{
473 mapper.map_transport_error(strategy, grant, meta, err)
474}
475
476fn map_reqwest_error(
477 strategy: &dyn ProviderStrategy,
478 grant: GrantType,
479 meta: Option<&ResponseMetadata>,
480 err: ReqwestError,
481) -> Error {
482 let _ = (strategy, grant);
484
485 if err.is_builder() {
486 return ConfigError::from(err).into();
487 }
488 if err.is_timeout() {
489 return TransientError::TokenEndpoint {
490 message: "Request timed out while calling the token endpoint.".into(),
491 status: meta_status(meta).or_else(|| reqwest_status(&err)),
492 retry_after: meta_retry_after(meta),
493 }
494 .into();
495 }
496
497 TransportError::from(err).into()
498}
499
500fn map_generic_transport_error(meta: Option<&ResponseMetadata>, message: impl Display) -> Error {
501 TransientError::TokenEndpoint {
502 message: format!("HTTP client error occurred while calling the token endpoint: {message}."),
503 status: meta_status(meta),
504 retry_after: meta_retry_after(meta),
505 }
506 .into()
507}
508
509fn map_unknown_transport_error(meta: Option<&ResponseMetadata>) -> Error {
510 TransientError::TokenEndpoint {
511 message: "HTTP client error occurred while calling the token endpoint.".into(),
512 status: meta_status(meta),
513 retry_after: meta_retry_after(meta),
514 }
515 .into()
516}
517
518fn meta_status(meta: Option<&ResponseMetadata>) -> Option<u16> {
519 meta.and_then(|value| value.status)
520}
521
522fn meta_retry_after(meta: Option<&ResponseMetadata>) -> Option<Duration> {
523 meta.and_then(|value| value.retry_after)
524}
525
526fn reqwest_status(err: &ReqwestError) -> Option<u16> {
527 err.status().map(|code| code.as_u16())
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ProviderId;

    /// Facade type exercised by every test in this module.
    type DefaultFacade = BasicFacade<ReqwestHttpClient, ReqwestTransportErrorMapper>;

    /// Builds a minimal descriptor that prefers the given client authentication `method`.
    fn descriptor(method: ClientAuthMethod) -> ProviderDescriptor {
        let provider_id =
            ProviderId::new("test-provider").expect("Failed to construct provider identifier.");
        let authorization = Url::parse("https://example.com/oauth2/authorize")
            .expect("Failed to parse authorization endpoint URL.");
        let token = Url::parse("https://example.com/oauth2/token")
            .expect("Failed to parse token endpoint URL.");

        ProviderDescriptor::builder(provider_id)
            .authorization_endpoint(authorization)
            .token_endpoint(token)
            .support_grant(GrantType::AuthorizationCode)
            .preferred_client_auth_method(method)
            .build()
            .expect("Failed to build provider descriptor.")
    }

    /// Runs `from_descriptor` with default HTTP client and error mapper instances.
    fn build_facade(
        method: ClientAuthMethod,
        client_id: &str,
        client_secret: Option<&str>,
        redirect_uri: Option<&Url>,
    ) -> Result<DefaultFacade> {
        let descriptor = descriptor(method);

        DefaultFacade::from_descriptor(
            &descriptor,
            client_id,
            client_secret,
            redirect_uri,
            Arc::new(ReqwestHttpClient::default()),
            Arc::new(ReqwestTransportErrorMapper),
        )
    }

    #[test]
    fn builds_basic_auth_client() {
        let redirect =
            Url::parse("https://example.com/callback").expect("Failed to parse redirect URI.");
        let result = build_facade(
            ClientAuthMethod::ClientSecretBasic,
            "client-id",
            Some("secret"),
            Some(&redirect),
        );

        assert!(result.is_ok());
    }

    #[test]
    fn builds_post_auth_client() {
        let result =
            build_facade(ClientAuthMethod::ClientSecretPost, "client-id", Some("secret"), None);

        assert!(result.is_ok());
    }

    #[test]
    fn builds_pkce_client_without_secret() {
        let result = build_facade(
            ClientAuthMethod::NoneWithPkce,
            "public-client",
            Some("ignored-secret"),
            None,
        );

        assert!(result.is_ok());
    }
}