oauth2_broker/
oauth.rs

1//! Internal OAuth client facade abstractions.
2
3pub use oauth2;
4
5// std
6use std::borrow::Cow;
7// crates.io
8use oauth2::{
9	AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, EndpointNotSet, EndpointSet,
10	HttpClientError, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
11	TokenResponse, TokenUrl,
12	basic::{BasicClient, BasicErrorResponse, BasicRequestTokenError},
13};
14// self
15use crate::{
16	_prelude::*,
17	auth::{ScopeSet, TokenFamily, TokenRecord},
18	error::{ConfigError, TransientError, TransportError},
19	http::{ReqwestHttpClient, ResponseMetadata, ResponseMetadataSlot, TokenHttpClient},
20	provider::{
21		ClientAuthMethod, GrantType, ProviderDescriptor, ProviderErrorContext, ProviderErrorKind,
22		ProviderStrategy,
23	},
24};
25
26type ConfiguredBasicClient =
27	BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;
28type FacadeTokenResponse = oauth2::basic::BasicTokenResponse;
29type FacadeFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T>> + 'a + Send>>;
30
31/// Maps HTTP transport failures into broker [`Error`] values.
32pub trait TransportErrorMapper<E>
33where
34	Self: 'static + Send + Sync,
35	E: 'static + Send + Sync + StdError,
36{
37	/// Converts an [`HttpClientError`] emitted by the transport into a broker error.
38	fn map_transport_error(
39		&self,
40		strategy: &dyn ProviderStrategy,
41		grant: GrantType,
42		metadata: Option<&ResponseMetadata>,
43		error: HttpClientError<E>,
44	) -> Error;
45}
46
47/// Default mapper for reqwest-backed transports.
48#[derive(Clone, Debug, Default)]
49pub struct ReqwestTransportErrorMapper;
50impl TransportErrorMapper<ReqwestError> for ReqwestTransportErrorMapper {
51	fn map_transport_error(
52		&self,
53		strategy: &dyn ProviderStrategy,
54		grant: GrantType,
55		meta: Option<&ResponseMetadata>,
56		err: HttpClientError<ReqwestError>,
57	) -> Error {
58		match err {
59			HttpClientError::Reqwest(inner) => map_reqwest_error(strategy, grant, meta, *inner),
60			HttpClientError::Http(inner) => ConfigError::from(inner).into(),
61			HttpClientError::Io(inner) => TransportError::Io(inner).into(),
62			HttpClientError::Other(message) => map_generic_transport_error(meta, message),
63			_ => map_unknown_transport_error(meta),
64		}
65	}
66}
67
68pub(crate) trait OAuth2Facade {
69	fn exchange_client_credentials<'a, 'strategy, 'scopes, 'params>(
70		&'a self,
71		strategy: &'strategy dyn ProviderStrategy,
72		family: TokenFamily,
73		scopes: &'scopes [&'scopes str],
74		extra_params: &'params [(String, String)],
75	) -> FacadeFuture<'a, TokenRecord>
76	where
77		'strategy: 'a,
78		'scopes: 'a,
79		'params: 'a;
80
81	fn refresh_token<'a, 'strategy, 'refresh, 'scope>(
82		&'a self,
83		strategy: &'strategy dyn ProviderStrategy,
84		family: TokenFamily,
85		refresh_token: &'refresh str,
86		requested_scope: &'scope ScopeSet,
87	) -> FacadeFuture<'a, (TokenRecord, Option<String>)>
88	where
89		'strategy: 'a,
90		'refresh: 'a,
91		'scope: 'a;
92
93	fn exchange_authorization_code<'a, 'strategy, 'code, 'pkce, 'scope, 'redirect>(
94		&'a self,
95		strategy: &'strategy dyn ProviderStrategy,
96		family: TokenFamily,
97		code: &'code str,
98		pkce_verifier: &'pkce str,
99		requested_scope: &'scope ScopeSet,
100		redirect_uri: &'redirect Url,
101	) -> FacadeFuture<'a, TokenRecord>
102	where
103		'strategy: 'a,
104		'code: 'a,
105		'pkce: 'a,
106		'scope: 'a,
107		'redirect: 'a;
108}
109
110pub(crate) struct BasicFacade<C = ReqwestHttpClient, M = ReqwestTransportErrorMapper>
111where
112	C: ?Sized + TokenHttpClient,
113	M: ?Sized + TransportErrorMapper<C::TransportError>,
114{
115	oauth_client: ConfiguredBasicClient,
116	http_client: Arc<C>,
117	error_mapper: Arc<M>,
118}
119impl<C, M> BasicFacade<C, M>
120where
121	C: ?Sized + TokenHttpClient,
122	M: ?Sized + TransportErrorMapper<C::TransportError>,
123{
124	pub(super) fn new(
125		oauth_client: ConfiguredBasicClient,
126		http_client: impl Into<Arc<C>>,
127		error_mapper: impl Into<Arc<M>>,
128	) -> Self {
129		Self { oauth_client, http_client: http_client.into(), error_mapper: error_mapper.into() }
130	}
131
132	pub(crate) fn from_descriptor(
133		descriptor: &ProviderDescriptor,
134		client_id: &str,
135		client_secret: Option<&str>,
136		redirect_uri: Option<&Url>,
137		http_client: impl Into<Arc<C>>,
138		error_mapper: impl Into<Arc<M>>,
139	) -> Result<Self> {
140		let auth_url = AuthUrl::new(descriptor.endpoints.authorization.to_string())
141			.map_err(|source| ConfigError::InvalidDescriptor { source })?;
142		let token_url = TokenUrl::new(descriptor.endpoints.token.to_string())
143			.map_err(|source| ConfigError::InvalidDescriptor { source })?;
144		let secret =
145			if matches!(descriptor.preferred_client_auth_method, ClientAuthMethod::NoneWithPkce) {
146				None
147			} else {
148				client_secret.map(|value| ClientSecret::new(value.to_owned()))
149			};
150		let mut oauth_client = BasicClient::new(ClientId::new(client_id.to_owned()))
151			.set_auth_uri(auth_url)
152			.set_token_uri(token_url);
153
154		if let Some(secret) = secret {
155			oauth_client = oauth_client.set_client_secret(secret);
156		}
157		if let Some(redirect) = redirect_uri {
158			let redirect_url = RedirectUrl::new(redirect.to_string())
159				.map_err(|source| ConfigError::InvalidDescriptor { source })?;
160
161			oauth_client = oauth_client.set_redirect_uri(redirect_url);
162		}
163
164		if matches!(descriptor.preferred_client_auth_method, ClientAuthMethod::ClientSecretPost) {
165			oauth_client = oauth_client.set_auth_type(AuthType::RequestBody);
166		}
167
168		Ok(Self::new(oauth_client, http_client, error_mapper))
169	}
170}
171impl<C, M> OAuth2Facade for BasicFacade<C, M>
172where
173	C: ?Sized + TokenHttpClient,
174	M: ?Sized + TransportErrorMapper<C::TransportError>,
175{
176	fn exchange_client_credentials<'a, 'strategy, 'scopes, 'params>(
177		&'a self,
178		strategy: &'strategy dyn ProviderStrategy,
179		family: TokenFamily,
180		scopes: &'scopes [&'scopes str],
181		extra_params: &'params [(String, String)],
182	) -> FacadeFuture<'a, TokenRecord>
183	where
184		'strategy: 'a,
185		'scopes: 'a,
186		'params: 'a,
187	{
188		let meta = ResponseMetadataSlot::default();
189
190		Box::pin(async move {
191			let instrumented = self.http_client.with_metadata(meta.clone());
192			let requested_scope =
193				ScopeSet::new(scopes.iter().copied()).map_err(ConfigError::from)?;
194			let mut request = self.oauth_client.exchange_client_credentials();
195
196			for scope in scopes {
197				request = request.add_scope(Scope::new((*scope).to_owned()));
198			}
199			for (key, value) in extra_params {
200				request = request.add_extra_param(key, value);
201			}
202
203			let response = request.request_async(&instrumented).await.map_err(|err| {
204				map_request_error(
205					strategy,
206					GrantType::ClientCredentials,
207					meta.take(),
208					err,
209					self.error_mapper.as_ref(),
210				)
211			})?;
212
213			map_standard_token_response(family, requested_scope, response)
214		})
215	}
216
217	fn refresh_token<'a, 'strategy, 'refresh, 'scope>(
218		&'a self,
219		strategy: &'strategy dyn ProviderStrategy,
220		family: TokenFamily,
221		refresh_token: &'refresh str,
222		requested_scope: &'scope ScopeSet,
223	) -> FacadeFuture<'a, (TokenRecord, Option<String>)>
224	where
225		'strategy: 'a,
226		'refresh: 'a,
227		'scope: 'a,
228	{
229		let meta = ResponseMetadataSlot::default();
230
231		Box::pin(async move {
232			let instrumented = self.http_client.with_metadata(meta.clone());
233			let refresh_secret = RefreshToken::new(refresh_token.to_owned());
234			let mut request = self.oauth_client.exchange_refresh_token(&refresh_secret);
235
236			if !requested_scope.is_empty() {
237				for scope in requested_scope.iter() {
238					request = request.add_scope(Scope::new(scope.to_owned()));
239				}
240			}
241
242			let response = request.request_async(&instrumented).await.map_err(|err| {
243				map_request_error(
244					strategy,
245					GrantType::RefreshToken,
246					meta.take(),
247					err,
248					self.error_mapper.as_ref(),
249				)
250			})?;
251
252			map_refresh_token_response(family, requested_scope, response)
253		})
254	}
255
256	fn exchange_authorization_code<'a, 'strategy, 'code, 'pkce, 'scope, 'redirect>(
257		&'a self,
258		strategy: &'strategy dyn ProviderStrategy,
259		family: TokenFamily,
260		code: &'code str,
261		pkce_verifier: &'pkce str,
262		requested_scope: &'scope ScopeSet,
263		redirect_uri: &'redirect Url,
264	) -> FacadeFuture<'a, TokenRecord>
265	where
266		'strategy: 'a,
267		'code: 'a,
268		'pkce: 'a,
269		'scope: 'a,
270		'redirect: 'a,
271	{
272		let meta = ResponseMetadataSlot::default();
273
274		Box::pin(async move {
275			let instrumented = self.http_client.with_metadata(meta.clone());
276			let mut request = self
277				.oauth_client
278				.exchange_code(AuthorizationCode::new(code.to_owned()))
279				.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_owned()));
280
281			if !requested_scope.is_empty() {
282				request = request.add_extra_param("scope", requested_scope.normalized());
283			}
284
285			let redirect_url = RedirectUrl::new(redirect_uri.to_string())
286				.map_err(|err| ConfigError::InvalidRedirect { source: err })?;
287
288			request = request.set_redirect_uri(Cow::Owned(redirect_url));
289
290			let response = request.request_async(&instrumented).await.map_err(|err| {
291				map_request_error(
292					strategy,
293					GrantType::AuthorizationCode,
294					meta.take(),
295					err,
296					self.error_mapper.as_ref(),
297				)
298			})?;
299			let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
300			let expires_in =
301				i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
302
303			if expires_in <= 0 {
304				return Err(ConfigError::NonPositiveExpiresIn.into());
305			}
306
307			if let Some(scopes) = response.scopes() {
308				let returned = ScopeSet::new(scopes.iter().map(|scope| scope.as_ref()))
309					.map_err(ConfigError::from)?;
310				if returned != *requested_scope {
311					return Err(ConfigError::ScopesChanged { grant: "authorization_code" }.into());
312				}
313			}
314
315			let issued_at = OffsetDateTime::now_utc();
316			let mut builder = TokenRecord::builder(family, requested_scope.clone())
317				.access_token(response.access_token().secret().to_owned())
318				.issued_at(issued_at)
319				.expires_in(Duration::seconds(expires_in));
320
321			if let Some(refresh) = response.refresh_token() {
322				builder = builder.refresh_token(refresh.secret().to_owned());
323			}
324
325			builder.build().map_err(|e| ConfigError::from(e).into())
326		})
327	}
328}
329
330fn map_standard_token_response(
331	family: TokenFamily,
332	scope: ScopeSet,
333	response: FacadeTokenResponse,
334) -> Result<TokenRecord> {
335	let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
336	let expires_in = i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
337
338	if expires_in <= 0 {
339		return Err(ConfigError::NonPositiveExpiresIn.into());
340	}
341
342	if let Some(scopes) = response.scopes() {
343		let returned =
344			ScopeSet::new(scopes.iter().map(|scope| scope.as_ref())).map_err(ConfigError::from)?;
345		if returned != scope {
346			return Err(ConfigError::ScopesChanged { grant: "client_credentials" }.into());
347		}
348	}
349
350	let issued_at = OffsetDateTime::now_utc();
351
352	TokenRecord::builder(family, scope)
353		.access_token(response.access_token().secret().to_owned())
354		.issued_at(issued_at)
355		.expires_in(Duration::seconds(expires_in))
356		.build()
357		.map_err(|err| ConfigError::from(err).into())
358}
359
360fn map_refresh_token_response(
361	family: TokenFamily,
362	requested_scope: &ScopeSet,
363	response: FacadeTokenResponse,
364) -> Result<(TokenRecord, Option<String>)> {
365	let expires_in = response.expires_in().ok_or(ConfigError::MissingExpiresIn)?.as_secs();
366	let expires_in = i64::try_from(expires_in).map_err(|_| ConfigError::ExpiresInOutOfRange)?;
367
368	if expires_in <= 0 {
369		return Err(ConfigError::NonPositiveExpiresIn.into());
370	}
371
372	if let Some(scopes) = response.scopes() {
373		let returned =
374			ScopeSet::new(scopes.iter().map(|scope| scope.as_ref())).map_err(ConfigError::from)?;
375		if returned != *requested_scope {
376			return Err(ConfigError::ScopesChanged { grant: "refresh_token" }.into());
377		}
378	}
379
380	let issued_at = OffsetDateTime::now_utc();
381	let mut builder = TokenRecord::builder(family, requested_scope.clone())
382		.access_token(response.access_token().secret().to_owned())
383		.issued_at(issued_at)
384		.expires_in(Duration::seconds(expires_in));
385	let new_refresh = response.refresh_token().map(|token| token.secret().to_owned());
386
387	if let Some(secret) = &new_refresh {
388		builder = builder.refresh_token(secret.clone());
389	}
390
391	let record = builder.build().map_err(ConfigError::from)?;
392
393	Ok((record, new_refresh))
394}
395
396fn map_request_error<E, M>(
397	strategy: &dyn ProviderStrategy,
398	grant: GrantType,
399	meta: Option<ResponseMetadata>,
400	err: BasicRequestTokenError<HttpClientError<E>>,
401	mapper: &M,
402) -> Error
403where
404	E: 'static + Send + Sync + StdError,
405	M: ?Sized + TransportErrorMapper<E>,
406{
407	let meta_ref = meta.as_ref();
408
409	match err {
410		RequestTokenError::ServerResponse(response) =>
411			map_server_response_error(strategy, grant, response, meta_ref),
412		RequestTokenError::Request(error) =>
413			map_transport_error(strategy, grant, meta_ref, error, mapper),
414		RequestTokenError::Parse(error, _body) =>
415			TransientError::TokenResponseParse { source: error, status: meta_status(meta_ref) }
416				.into(),
417		RequestTokenError::Other(message) => TransientError::TokenEndpoint {
418			message: format!("Token endpoint returned an unexpected response: {message}."),
419			status: meta_status(meta_ref),
420			retry_after: meta_retry_after(meta_ref),
421		}
422		.into(),
423	}
424}
425
426fn map_server_response_error(
427	strategy: &dyn ProviderStrategy,
428	grant: GrantType,
429	response: BasicErrorResponse,
430	meta: Option<&ResponseMetadata>,
431) -> Error {
432	let mut ctx =
433		ProviderErrorContext::new(grant).with_oauth_error(response.error().as_ref().to_string());
434	if let Some(description) = response.error_description() {
435		ctx = ctx.with_error_description(description.clone());
436	}
437
438	if let Some(status) = meta_status(meta) {
439		ctx = ctx.with_http_status(status);
440	}
441
442	let classification = strategy.classify_token_error(&ctx);
443	let message = if let Some(description) = response.error_description() {
444		format!("Token endpoint returned an OAuth error: {description}.")
445	} else {
446		format!("Token endpoint returned an OAuth error: {}.", response.error().as_ref())
447	};
448
449	match classification {
450		ProviderErrorKind::InvalidGrant => Error::InvalidGrant { reason: message },
451		ProviderErrorKind::InvalidClient => Error::InvalidClient { reason: message },
452		ProviderErrorKind::InsufficientScope => Error::InsufficientScope { reason: message },
453		ProviderErrorKind::Transient => TransientError::TokenEndpoint {
454			message,
455			status: meta_status(meta),
456			retry_after: meta_retry_after(meta),
457		}
458		.into(),
459	}
460}
461
462fn map_transport_error<E, M>(
463	strategy: &dyn ProviderStrategy,
464	grant: GrantType,
465	meta: Option<&ResponseMetadata>,
466	err: HttpClientError<E>,
467	mapper: &M,
468) -> Error
469where
470	E: 'static + Send + Sync + StdError,
471	M: ?Sized + TransportErrorMapper<E>,
472{
473	mapper.map_transport_error(strategy, grant, meta, err)
474}
475
476fn map_reqwest_error(
477	strategy: &dyn ProviderStrategy,
478	grant: GrantType,
479	meta: Option<&ResponseMetadata>,
480	err: ReqwestError,
481) -> Error {
482	// Strategy reserved for future use.
483	let _ = (strategy, grant);
484
485	if err.is_builder() {
486		return ConfigError::from(err).into();
487	}
488	if err.is_timeout() {
489		return TransientError::TokenEndpoint {
490			message: "Request timed out while calling the token endpoint.".into(),
491			status: meta_status(meta).or_else(|| reqwest_status(&err)),
492			retry_after: meta_retry_after(meta),
493		}
494		.into();
495	}
496
497	TransportError::from(err).into()
498}
499
500fn map_generic_transport_error(meta: Option<&ResponseMetadata>, message: impl Display) -> Error {
501	TransientError::TokenEndpoint {
502		message: format!("HTTP client error occurred while calling the token endpoint: {message}."),
503		status: meta_status(meta),
504		retry_after: meta_retry_after(meta),
505	}
506	.into()
507}
508
509fn map_unknown_transport_error(meta: Option<&ResponseMetadata>) -> Error {
510	TransientError::TokenEndpoint {
511		message: "HTTP client error occurred while calling the token endpoint.".into(),
512		status: meta_status(meta),
513		retry_after: meta_retry_after(meta),
514	}
515	.into()
516}
517
518fn meta_status(meta: Option<&ResponseMetadata>) -> Option<u16> {
519	meta.and_then(|value| value.status)
520}
521
522fn meta_retry_after(meta: Option<&ResponseMetadata>) -> Option<Duration> {
523	meta.and_then(|value| value.retry_after)
524}
525
526fn reqwest_status(err: &ReqwestError) -> Option<u16> {
527	err.status().map(|code| code.as_u16())
528}
529
#[cfg(test)]
mod tests {
	// self
	use super::*;
	use crate::auth::ProviderId;

	/// Facade flavour exercised by every test in this module.
	type DefaultFacade = BasicFacade<ReqwestHttpClient, ReqwestTransportErrorMapper>;

	/// Builds a descriptor for `example.com` endpoints that prefers the given auth `method`.
	fn descriptor(method: ClientAuthMethod) -> ProviderDescriptor {
		let provider_id =
			ProviderId::new("test-provider").expect("Failed to construct provider identifier.");
		let builder = ProviderDescriptor::builder(provider_id);
		let authorization = Url::parse("https://example.com/oauth2/authorize")
			.expect("Failed to parse authorization endpoint URL.");
		let builder = builder.authorization_endpoint(authorization);
		let token = Url::parse("https://example.com/oauth2/token")
			.expect("Failed to parse token endpoint URL.");

		builder
			.token_endpoint(token)
			.support_grant(GrantType::AuthorizationCode)
			.preferred_client_auth_method(method)
			.build()
			.expect("Failed to build provider descriptor.")
	}

	/// Reports whether a facade can be constructed for the given auth method and credentials.
	fn facade_builds(
		method: ClientAuthMethod,
		client_id: &str,
		client_secret: &str,
		redirect_uri: Option<&Url>,
	) -> bool {
		DefaultFacade::from_descriptor(
			&descriptor(method),
			client_id,
			Some(client_secret),
			redirect_uri,
			Arc::new(ReqwestHttpClient::default()),
			Arc::new(ReqwestTransportErrorMapper),
		)
		.is_ok()
	}

	#[test]
	fn builds_basic_auth_client() {
		let redirect =
			Url::parse("https://example.com/callback").expect("Failed to parse redirect URI.");

		assert!(facade_builds(
			ClientAuthMethod::ClientSecretBasic,
			"client-id",
			"secret",
			Some(&redirect),
		));
	}

	#[test]
	fn builds_post_auth_client() {
		assert!(facade_builds(ClientAuthMethod::ClientSecretPost, "client-id", "secret", None));
	}

	#[test]
	fn builds_pkce_client_without_secret() {
		// The secret is supplied on purpose: construction must still succeed for a PKCE client.
		assert!(facade_builds(
			ClientAuthMethod::NoneWithPkce,
			"public-client",
			"ignored-secret",
			None,
		));
	}
}