oauth2_broker/
oauth.rs

1//! Internal OAuth client facade abstractions.
2
3pub use oauth2;
4
5// std
6use std::borrow::Cow;
7// crates.io
8use oauth2::{
9	AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, EndpointNotSet, EndpointSet,
10	HttpClientError, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
11	TokenResponse, TokenUrl,
12	basic::{BasicClient, BasicErrorResponse, BasicRequestTokenError},
13};
14// self
15#[cfg(all(test, feature = "reqwest"))] use crate::http::ReqwestHttpClient;
16use crate::{
17	_prelude::*,
18	auth::{ScopeSet, TokenFamily, TokenRecord},
19	error::{ConfigError, TransientError, TransportError},
20	http::{ResponseMetadata, ResponseMetadataSlot, TokenHttpClient},
21	provider::{
22		ClientAuthMethod, GrantType, ProviderDescriptor, ProviderErrorContext, ProviderErrorKind,
23		ProviderStrategy,
24	},
25};
26
27type ConfiguredBasicClient =
28	BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;
29type FacadeTokenResponse = oauth2::basic::BasicTokenResponse;
30type FacadeFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T>> + 'a + Send>>;
31
32/// Maps HTTP transport failures into broker [`Error`] values.
33pub trait TransportErrorMapper<E>
34where
35	Self: 'static + Send + Sync,
36	E: 'static + Send + Sync + StdError,
37{
38	/// Converts an [`HttpClientError`] emitted by the transport into a broker error.
39	fn map_transport_error(
40		&self,
41		strategy: &dyn ProviderStrategy,
42		grant: GrantType,
43		metadata: Option<&ResponseMetadata>,
44		error: HttpClientError<E>,
45	) -> Error;
46}
47
48pub(crate) trait OAuth2Facade {
49	fn exchange_client_credentials<'a, 'strategy, 'scopes, 'params>(
50		&'a self,
51		strategy: &'strategy dyn ProviderStrategy,
52		family: TokenFamily,
53		scopes: &'scopes [&'scopes str],
54		extra_params: &'params [(String, String)],
55	) -> FacadeFuture<'a, TokenRecord>
56	where
57		'strategy: 'a,
58		'scopes: 'a,
59		'params: 'a;
60
61	fn refresh_token<'a, 'strategy, 'refresh, 'scope>(
62		&'a self,
63		strategy: &'strategy dyn ProviderStrategy,
64		family: TokenFamily,
65		refresh_token: &'refresh str,
66		requested_scope: &'scope ScopeSet,
67	) -> FacadeFuture<'a, (TokenRecord, Option<String>)>
68	where
69		'strategy: 'a,
70		'refresh: 'a,
71		'scope: 'a;
72
73	fn exchange_authorization_code<'a, 'strategy, 'code, 'pkce, 'scope, 'redirect>(
74		&'a self,
75		strategy: &'strategy dyn ProviderStrategy,
76		family: TokenFamily,
77		code: &'code str,
78		pkce_verifier: &'pkce str,
79		requested_scope: &'scope ScopeSet,
80		redirect_uri: &'redirect Url,
81	) -> FacadeFuture<'a, TokenRecord>
82	where
83		'strategy: 'a,
84		'code: 'a,
85		'pkce: 'a,
86		'scope: 'a,
87		'redirect: 'a;
88}
89
/// Default mapper for reqwest-backed transports.
#[cfg(feature = "reqwest")]
#[derive(Clone, Debug, Default)]
pub struct ReqwestTransportErrorMapper;

#[cfg(feature = "reqwest")]
impl TransportErrorMapper<ReqwestError> for ReqwestTransportErrorMapper {
	/// Dispatches on the [`HttpClientError`] variant: reqwest failures go through
	/// `map_reqwest_error`, HTTP construction failures become configuration errors,
	/// I/O failures become transport errors, and everything else is reported as a
	/// transient token-endpoint error.
	fn map_transport_error(
		&self,
		strategy: &dyn ProviderStrategy,
		grant: GrantType,
		metadata: Option<&ResponseMetadata>,
		error: HttpClientError<ReqwestError>,
	) -> Error {
		match error {
			HttpClientError::Io(io) => TransportError::Io(io).into(),
			HttpClientError::Http(http) => ConfigError::from(http).into(),
			HttpClientError::Other(text) => map_generic_transport_error(metadata, text),
			HttpClientError::Reqwest(boxed) =>
				map_reqwest_error(strategy, grant, metadata, *boxed),
			// Any other variant carries nothing this mapper inspects.
			_ => map_unknown_transport_error(metadata),
		}
	}
}
112
113pub(crate) struct BasicFacade<C, M>
114where
115	C: ?Sized + TokenHttpClient,
116	M: ?Sized + TransportErrorMapper<C::TransportError>,
117{
118	oauth_client: ConfiguredBasicClient,
119	http_client: Arc<C>,
120	error_mapper: Arc<M>,
121}
122impl<C, M> BasicFacade<C, M>
123where
124	C: ?Sized + TokenHttpClient,
125	M: ?Sized + TransportErrorMapper<C::TransportError>,
126{
127	pub(super) fn new(
128		oauth_client: ConfiguredBasicClient,
129		http_client: impl Into<Arc<C>>,
130		error_mapper: impl Into<Arc<M>>,
131	) -> Self {
132		Self { oauth_client, http_client: http_client.into(), error_mapper: error_mapper.into() }
133	}
134
135	pub(crate) fn from_descriptor(
136		descriptor: &ProviderDescriptor,
137		client_id: &str,
138		client_secret: Option<&str>,
139		redirect_uri: Option<&Url>,
140		http_client: impl Into<Arc<C>>,
141		error_mapper: impl Into<Arc<M>>,
142	) -> Result<Self> {
143		let auth_url = AuthUrl::new(descriptor.endpoints.authorization.to_string())
144			.map_err(|source| ConfigError::InvalidDescriptor { source })?;
145		let token_url = TokenUrl::new(descriptor.endpoints.token.to_string())
146			.map_err(|source| ConfigError::InvalidDescriptor { source })?;
147		let secret =
148			if matches!(descriptor.preferred_client_auth_method, ClientAuthMethod::NoneWithPkce) {
149				None
150			} else {
151				client_secret.map(|value| ClientSecret::new(value.to_owned()))
152			};
153		let mut oauth_client = BasicClient::new(ClientId::new(client_id.to_owned()))
154			.set_auth_uri(auth_url)
155			.set_token_uri(token_url);
156
157		if let Some(secret) = secret {
158			oauth_client = oauth_client.set_client_secret(secret);
159		}
160		if let Some(redirect) = redirect_uri {
161			let redirect_url = RedirectUrl::new(redirect.to_string())
162				.map_err(|source| ConfigError::InvalidDescriptor { source })?;
163
164			oauth_client = oauth_client.set_redirect_uri(redirect_url);
165		}
166
167		if matches!(descriptor.preferred_client_auth_method, ClientAuthMethod::ClientSecretPost) {
168			oauth_client = oauth_client.set_auth_type(AuthType::RequestBody);
169		}
170
171		Ok(Self::new(oauth_client, http_client, error_mapper))
172	}
173}
174impl<C, M> OAuth2Facade for BasicFacade<C, M>
175where
176	C: ?Sized + TokenHttpClient,
177	M: ?Sized + TransportErrorMapper<C::TransportError>,
178{
179	fn exchange_client_credentials<'a, 'strategy, 'scopes, 'params>(
180		&'a self,
181		strategy: &'strategy dyn ProviderStrategy,
182		family: TokenFamily,
183		scopes: &'scopes [&'scopes str],
184		extra_params: &'params [(String, String)],
185	) -> FacadeFuture<'a, TokenRecord>
186	where
187		'strategy: 'a,
188		'scopes: 'a,
189		'params: 'a,
190	{
191		let meta = ResponseMetadataSlot::default();
192
193		Box::pin(async move {
194			let instrumented = self.http_client.with_metadata(meta.clone());
195			let requested_scope =
196				ScopeSet::new(scopes.iter().copied()).map_err(ConfigError::from)?;
197			let mut request = self.oauth_client.exchange_client_credentials();
198
199			for scope in scopes {
200				request = request.add_scope(Scope::new((*scope).to_owned()));
201			}
202			for (key, value) in extra_params {
203				request = request.add_extra_param(key, value);
204			}
205
206			let response = request.request_async(&instrumented).await.map_err(|err| {
207				map_request_error(
208					strategy,
209					GrantType::ClientCredentials,
210					meta.take(),
211					err,
212					self.error_mapper.as_ref(),
213				)
214			})?;
215
216			map_standard_token_response(family, requested_scope, response)
217		})
218	}
219
220	fn refresh_token<'a, 'strategy, 'refresh, 'scope>(
221		&'a self,
222		strategy: &'strategy dyn ProviderStrategy,
223		family: TokenFamily,
224		refresh_token: &'refresh str,
225		requested_scope: &'scope ScopeSet,
226	) -> FacadeFuture<'a, (TokenRecord, Option<String>)>
227	where
228		'strategy: 'a,
229		'refresh: 'a,
230		'scope: 'a,
231	{
232		let meta = ResponseMetadataSlot::default();
233
234		Box::pin(async move {
235			let instrumented = self.http_client.with_metadata(meta.clone());
236			let refresh_secret = RefreshToken::new(refresh_token.to_owned());
237			let mut request = self.oauth_client.exchange_refresh_token(&refresh_secret);
238
239			if !requested_scope.is_empty() {
240				for scope in requested_scope.iter() {
241					request = request.add_scope(Scope::new(scope.to_owned()));
242				}
243			}
244
245			let response = request.request_async(&instrumented).await.map_err(|err| {
246				map_request_error(
247					strategy,
248					GrantType::RefreshToken,
249					meta.take(),
250					err,
251					self.error_mapper.as_ref(),
252				)
253			})?;
254
255			map_refresh_token_response(family, requested_scope, response)
256		})
257	}
258
259	fn exchange_authorization_code<'a, 'strategy, 'code, 'pkce, 'scope, 'redirect>(
260		&'a self,
261		strategy: &'strategy dyn ProviderStrategy,
262		family: TokenFamily,
263		code: &'code str,
264		pkce_verifier: &'pkce str,
265		requested_scope: &'scope ScopeSet,
266		redirect_uri: &'redirect Url,
267	) -> FacadeFuture<'a, TokenRecord>
268	where
269		'strategy: 'a,
270		'code: 'a,
271		'pkce: 'a,
272		'scope: 'a,
273		'redirect: 'a,
274	{
275		let meta = ResponseMetadataSlot::default();
276
277		Box::pin(async move {
278			let instrumented = self.http_client.with_metadata(meta.clone());
279			let mut request = self
280				.oauth_client
281				.exchange_code(AuthorizationCode::new(code.to_owned()))
282				.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_owned()));
283
284			if !requested_scope.is_empty() {
285				request = request.add_extra_param("scope", requested_scope.normalized());
286			}
287
288			let redirect_url = RedirectUrl::new(redirect_uri.to_string())
289				.map_err(|err| ConfigError::InvalidRedirect { source: err })?;
290
291			request = request.set_redirect_uri(Cow::Owned(redirect_url));
292
293			let response = request.request_async(&instrumented).await.map_err(|err| {
294				map_request_error(
295					strategy,
296					GrantType::AuthorizationCode,
297					meta.take(),
298					err,
299					self.error_mapper.as_ref(),
300				)
301			})?;
302			let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
303			let expires_in =
304				i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
305
306			if expires_in <= 0 {
307				return Err(ConfigError::NonPositiveExpiresIn.into());
308			}
309
310			if let Some(scopes) = response.scopes() {
311				let returned = ScopeSet::new(scopes.iter().map(|scope| scope.as_ref()))
312					.map_err(ConfigError::from)?;
313				if returned != *requested_scope {
314					return Err(ConfigError::ScopesChanged { grant: "authorization_code" }.into());
315				}
316			}
317
318			let issued_at = OffsetDateTime::now_utc();
319			let mut builder = TokenRecord::builder(family, requested_scope.clone())
320				.access_token(response.access_token().secret().to_owned())
321				.issued_at(issued_at)
322				.expires_in(Duration::seconds(expires_in));
323
324			if let Some(refresh) = response.refresh_token() {
325				builder = builder.refresh_token(refresh.secret().to_owned());
326			}
327
328			builder.build().map_err(|e| ConfigError::from(e).into())
329		})
330	}
331}
332
333fn map_standard_token_response(
334	family: TokenFamily,
335	scope: ScopeSet,
336	response: FacadeTokenResponse,
337) -> Result<TokenRecord> {
338	let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
339	let expires_in = i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
340
341	if expires_in <= 0 {
342		return Err(ConfigError::NonPositiveExpiresIn.into());
343	}
344
345	if let Some(scopes) = response.scopes() {
346		let returned =
347			ScopeSet::new(scopes.iter().map(|scope| scope.as_ref())).map_err(ConfigError::from)?;
348		if returned != scope {
349			return Err(ConfigError::ScopesChanged { grant: "client_credentials" }.into());
350		}
351	}
352
353	let issued_at = OffsetDateTime::now_utc();
354
355	TokenRecord::builder(family, scope)
356		.access_token(response.access_token().secret().to_owned())
357		.issued_at(issued_at)
358		.expires_in(Duration::seconds(expires_in))
359		.build()
360		.map_err(|err| ConfigError::from(err).into())
361}
362
363fn map_refresh_token_response(
364	family: TokenFamily,
365	requested_scope: &ScopeSet,
366	response: FacadeTokenResponse,
367) -> Result<(TokenRecord, Option<String>)> {
368	let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
369	let expires_in = i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
370
371	if expires_in <= 0 {
372		return Err(ConfigError::NonPositiveExpiresIn.into());
373	}
374
375	if let Some(scopes) = response.scopes() {
376		let returned =
377			ScopeSet::new(scopes.iter().map(|scope| scope.as_ref())).map_err(ConfigError::from)?;
378		if returned != *requested_scope {
379			return Err(ConfigError::ScopesChanged { grant: "refresh_token" }.into());
380		}
381	}
382
383	let issued_at = OffsetDateTime::now_utc();
384	let mut builder = TokenRecord::builder(family, requested_scope.clone())
385		.access_token(response.access_token().secret().to_owned())
386		.issued_at(issued_at)
387		.expires_in(Duration::seconds(expires_in));
388	let new_refresh = response.refresh_token().map(|token| token.secret().to_owned());
389
390	if let Some(secret) = &new_refresh {
391		builder = builder.refresh_token(secret.clone());
392	}
393
394	let record = builder.build().map_err(ConfigError::from)?;
395
396	Ok((record, new_refresh))
397}
398
399fn map_request_error<E, M>(
400	strategy: &dyn ProviderStrategy,
401	grant: GrantType,
402	meta: Option<ResponseMetadata>,
403	err: BasicRequestTokenError<HttpClientError<E>>,
404	mapper: &M,
405) -> Error
406where
407	E: 'static + Send + Sync + StdError,
408	M: ?Sized + TransportErrorMapper<E>,
409{
410	let meta_ref = meta.as_ref();
411
412	match err {
413		RequestTokenError::ServerResponse(response) =>
414			map_server_response_error(strategy, grant, response, meta_ref),
415		RequestTokenError::Request(error) =>
416			map_transport_error(strategy, grant, meta_ref, error, mapper),
417		RequestTokenError::Parse(error, _body) =>
418			TransientError::TokenResponseParse { source: error, status: meta_status(meta_ref) }
419				.into(),
420		RequestTokenError::Other(message) => TransientError::TokenEndpoint {
421			message: format!("Token endpoint returned an unexpected response: {message}."),
422			status: meta_status(meta_ref),
423			retry_after: meta_retry_after(meta_ref),
424		}
425		.into(),
426	}
427}
428
429fn map_server_response_error(
430	strategy: &dyn ProviderStrategy,
431	grant: GrantType,
432	response: BasicErrorResponse,
433	meta: Option<&ResponseMetadata>,
434) -> Error {
435	let mut ctx =
436		ProviderErrorContext::new(grant).with_oauth_error(response.error().as_ref().to_string());
437	if let Some(description) = response.error_description() {
438		ctx = ctx.with_error_description(description.clone());
439	}
440
441	if let Some(status) = meta_status(meta) {
442		ctx = ctx.with_http_status(status);
443	}
444
445	let classification = strategy.classify_token_error(&ctx);
446	let message = if let Some(description) = response.error_description() {
447		format!("Token endpoint returned an OAuth error: {description}.")
448	} else {
449		format!("Token endpoint returned an OAuth error: {}.", response.error().as_ref())
450	};
451
452	match classification {
453		ProviderErrorKind::InvalidGrant => Error::InvalidGrant { reason: message },
454		ProviderErrorKind::InvalidClient => Error::InvalidClient { reason: message },
455		ProviderErrorKind::InsufficientScope => Error::InsufficientScope { reason: message },
456		ProviderErrorKind::Transient => TransientError::TokenEndpoint {
457			message,
458			status: meta_status(meta),
459			retry_after: meta_retry_after(meta),
460		}
461		.into(),
462	}
463}
464
465fn map_transport_error<E, M>(
466	strategy: &dyn ProviderStrategy,
467	grant: GrantType,
468	meta: Option<&ResponseMetadata>,
469	err: HttpClientError<E>,
470	mapper: &M,
471) -> Error
472where
473	E: 'static + Send + Sync + StdError,
474	M: ?Sized + TransportErrorMapper<E>,
475{
476	mapper.map_transport_error(strategy, grant, meta, err)
477}
478
/// Maps a reqwest failure into a broker [`Error`].
///
/// Builder errors are reported as configuration errors and timeouts as transient
/// token-endpoint errors. Everything else is reported as a transport error.
#[cfg(feature = "reqwest")]
fn map_reqwest_error(
	strategy: &dyn ProviderStrategy,
	grant: GrantType,
	meta: Option<&ResponseMetadata>,
	err: ReqwestError,
) -> Error {
	// Strategy reserved for future use.
	let _ = (strategy, grant);

	if err.is_builder() {
		return ConfigError::from(err).into();
	}
	if !err.is_timeout() {
		return TransportError::from(err).into();
	}

	// Prefer the status captured by the transport and fall back to the one on the error.
	let status = meta_status(meta).or_else(|| reqwest_status(&err));

	TransientError::TokenEndpoint {
		message: "Request timed out while calling the token endpoint.".into(),
		status,
		retry_after: meta_retry_after(meta),
	}
	.into()
}
503
504fn map_generic_transport_error(meta: Option<&ResponseMetadata>, message: impl Display) -> Error {
505	TransientError::TokenEndpoint {
506		message: format!("HTTP client error occurred while calling the token endpoint: {message}."),
507		status: meta_status(meta),
508		retry_after: meta_retry_after(meta),
509	}
510	.into()
511}
512
513fn map_unknown_transport_error(meta: Option<&ResponseMetadata>) -> Error {
514	TransientError::TokenEndpoint {
515		message: "HTTP client error occurred while calling the token endpoint.".into(),
516		status: meta_status(meta),
517		retry_after: meta_retry_after(meta),
518	}
519	.into()
520}
521
522fn meta_status(meta: Option<&ResponseMetadata>) -> Option<u16> {
523	meta.and_then(|value| value.status)
524}
525
526fn meta_retry_after(meta: Option<&ResponseMetadata>) -> Option<Duration> {
527	meta.and_then(|value| value.retry_after)
528}
529
/// Returns the HTTP status attached to a reqwest error as a bare `u16`, if any.
#[cfg(feature = "reqwest")]
fn reqwest_status(err: &ReqwestError) -> Option<u16> {
	let code = err.status()?;

	Some(code.as_u16())
}
534
#[cfg(all(test, feature = "reqwest"))]
mod tests {
	// self
	use super::*;
	use crate::auth::ProviderId;

	/// Facade flavour exercised by every test in this module.
	type TestFacade = BasicFacade<ReqwestHttpClient, ReqwestTransportErrorMapper>;

	/// Builds a descriptor for `example.com` endpoints with the given client
	/// authentication method.
	fn descriptor(method: ClientAuthMethod) -> ProviderDescriptor {
		let provider_id =
			ProviderId::new("test-provider").expect("Failed to construct provider identifier.");
		let authorization = Url::parse("https://example.com/oauth2/authorize")
			.expect("Failed to parse authorization endpoint URL.");
		let token = Url::parse("https://example.com/oauth2/token")
			.expect("Failed to parse token endpoint URL.");

		ProviderDescriptor::builder(provider_id)
			.authorization_endpoint(authorization)
			.token_endpoint(token)
			.support_grant(GrantType::AuthorizationCode)
			.preferred_client_auth_method(method)
			.build()
			.expect("Failed to build provider descriptor.")
	}

	/// Runs `from_descriptor` with default reqwest transport and mapper instances.
	fn build_facade(
		method: ClientAuthMethod,
		client_id: &str,
		client_secret: Option<&str>,
		redirect_uri: Option<&Url>,
	) -> Result<TestFacade> {
		let descriptor = descriptor(method);

		TestFacade::from_descriptor(
			&descriptor,
			client_id,
			client_secret,
			redirect_uri,
			Arc::new(ReqwestHttpClient::default()),
			Arc::new(ReqwestTransportErrorMapper),
		)
	}

	#[test]
	fn builds_basic_auth_client() {
		let redirect =
			Url::parse("https://example.com/callback").expect("Failed to parse redirect URI.");
		let result = build_facade(
			ClientAuthMethod::ClientSecretBasic,
			"client-id",
			Some("secret"),
			Some(&redirect),
		);

		assert!(result.is_ok());
	}

	#[test]
	fn builds_post_auth_client() {
		let result =
			build_facade(ClientAuthMethod::ClientSecretPost, "client-id", Some("secret"), None);

		assert!(result.is_ok());
	}

	#[test]
	fn builds_pkce_client_without_secret() {
		let result = build_facade(
			ClientAuthMethod::NoneWithPkce,
			"public-client",
			Some("ignored-secret"),
			None,
		);

		assert!(result.is_ok());
	}
}