use std::{collections::HashSet, num::NonZeroU32};
use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Utc};
use http::header::CONTENT_TYPE;
use language_tags::LanguageTag;
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod};
use mas_jose::claims::{self, TokenHash};
use oauth2_types::{
pkce,
prelude::CodeChallengeMethodExt,
requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest,
Display, Prompt, PushedAuthorizationResponse,
},
scope::Scope,
};
use rand::{
distributions::{Alphanumeric, DistString},
Rng,
};
use serde::Serialize;
use serde_with::skip_serializing_none;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{
AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError,
},
http_service::HttpService,
requests::{jose::verify_id_token, token::request_access_token},
types::{
client_credentials::ClientCredentials,
scope::{ScopeExt, ScopeToken},
IdToken,
},
utils::{http_all_error_status_codes, http_error_mapper},
};
/// The data needed to build an authorization request for the authorization
/// code flow.
///
/// Only `client_id`, `scope` and `redirect_uri` are mandatory; the other
/// parameters are optional and are forwarded unchanged to the authorization
/// request when set.
// NOTE(review): `id_token_hint` may carry a previously issued ID token and is
// included in the derived `Debug` output — confirm this is acceptable for logs.
#[derive(Debug, Clone)]
pub struct AuthorizationRequestData {
    /// The client ID, sent as the `client_id` of the authorization request.
    pub client_id: String,
    /// The scope to request.
    ///
    /// The `openid` scope token is always added when the request is built, so
    /// it does not have to be present here.
    pub scope: Scope,
    /// The URI the user should be redirected back to.
    ///
    /// It is also kept in the validation data so it can be re-sent with the
    /// token request.
    pub redirect_uri: Url,
    /// The PKCE code challenge methods supported by the authorization server.
    ///
    /// PKCE is only used when this list is set and contains `S256`.
    pub code_challenge_methods_supported: Option<Vec<PkceCodeChallengeMethod>>,
    /// Forwarded as the `display` parameter of the authorization request.
    pub display: Option<Display>,
    /// Forwarded as the `prompt` parameter of the authorization request.
    pub prompt: Option<Vec<Prompt>>,
    /// Forwarded as the `max_age` parameter of the authorization request.
    pub max_age: Option<NonZeroU32>,
    /// Forwarded as the `ui_locales` parameter of the authorization request.
    pub ui_locales: Option<Vec<LanguageTag>>,
    /// Forwarded as the `id_token_hint` parameter of the authorization request.
    pub id_token_hint: Option<String>,
    /// Forwarded as the `login_hint` parameter of the authorization request.
    pub login_hint: Option<String>,
    /// Forwarded as the `acr_values` parameter of the authorization request.
    pub acr_values: Option<HashSet<String>>,
}
impl AuthorizationRequestData {
    /// Creates the request data from its three mandatory values.
    ///
    /// Every optional parameter starts out unset; fill them in with the
    /// `with_*` builder methods.
    #[must_use]
    pub fn new(client_id: String, scope: Scope, redirect_uri: Url) -> Self {
        Self {
            client_id,
            scope,
            redirect_uri,
            code_challenge_methods_supported: None,
            display: None,
            prompt: None,
            max_age: None,
            ui_locales: None,
            id_token_hint: None,
            login_hint: None,
            acr_values: None,
        }
    }

    /// Returns a copy with the server's supported PKCE methods set.
    ///
    /// PKCE is only applied when this list contains `S256`.
    #[must_use]
    pub fn with_code_challenge_methods_supported(
        self,
        code_challenge_methods_supported: Vec<PkceCodeChallengeMethod>,
    ) -> Self {
        Self {
            code_challenge_methods_supported: Some(code_challenge_methods_supported),
            ..self
        }
    }

    /// Returns a copy with the `display` parameter set.
    #[must_use]
    pub fn with_display(self, display: Display) -> Self {
        Self {
            display: Some(display),
            ..self
        }
    }

    /// Returns a copy with the `prompt` parameter set.
    #[must_use]
    pub fn with_prompt(self, prompt: Vec<Prompt>) -> Self {
        Self {
            prompt: Some(prompt),
            ..self
        }
    }

    /// Returns a copy with the `max_age` parameter set.
    #[must_use]
    pub fn with_max_age(self, max_age: NonZeroU32) -> Self {
        Self {
            max_age: Some(max_age),
            ..self
        }
    }

    /// Returns a copy with the `ui_locales` parameter set.
    #[must_use]
    pub fn with_ui_locales(self, ui_locales: Vec<LanguageTag>) -> Self {
        Self {
            ui_locales: Some(ui_locales),
            ..self
        }
    }

    /// Returns a copy with the `id_token_hint` parameter set.
    #[must_use]
    pub fn with_id_token_hint(self, id_token_hint: String) -> Self {
        Self {
            id_token_hint: Some(id_token_hint),
            ..self
        }
    }

    /// Returns a copy with the `login_hint` parameter set.
    #[must_use]
    pub fn with_login_hint(self, login_hint: String) -> Self {
        Self {
            login_hint: Some(login_hint),
            ..self
        }
    }

    /// Returns a copy with the `acr_values` parameter set.
    #[must_use]
    pub fn with_acr_values(self, acr_values: HashSet<String>) -> Self {
        Self {
            acr_values: Some(acr_values),
            ..self
        }
    }
}
/// The values generated while building an authorization request.
///
/// They are needed later to validate the authorization response and to
/// exchange the authorization code for tokens.
// NOTE(review): presumably the caller stores this (e.g. in a session) between
// the redirect to the authorization server and the callback — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthorizationValidationData {
    /// The `state` sent in the authorization request (16 random alphanumeric
    /// characters).
    pub state: String,
    /// The `nonce` sent in the authorization request (16 random alphanumeric
    /// characters).
    ///
    /// The `nonce` claim of the ID token must match this value.
    pub nonce: String,
    /// The redirect URI sent in the authorization request.
    ///
    /// It is re-sent in the token request.
    pub redirect_uri: Url,
    /// The PKCE code verifier, if PKCE was used for this request.
    // NOTE(review): this verifier is secret but appears in the derived `Debug`
    // output — consider redacting before logging.
    pub code_challenge_verifier: Option<String>,
}
/// An authorization request combined with its optional PKCE parameters.
///
/// Both parts are `#[serde(flatten)]`ed, so they serialize as one flat set of
/// parameters. That set is used both as a URL query and as the
/// form-encoded PAR body.
#[skip_serializing_none]
#[derive(Clone, Serialize)]
struct FullAuthorizationRequest {
    /// The standard authorization request parameters.
    #[serde(flatten)]
    inner: AuthorizationRequest,
    /// The PKCE code challenge and method, when PKCE is used.
    ///
    /// `skip_serializing_none` omits it entirely when it is `None`.
    #[serde(flatten)]
    pkce: Option<pkce::AuthorizationRequest>,
}
/// Builds the authorization request parameters and the data needed to
/// validate the response.
///
/// Random `state` and `nonce` values (16 alphanumeric characters each) are
/// drawn from `rng`. If the server advertises the `S256` PKCE method, a random
/// 32-byte verifier is drawn and base64url-encoded without padding, and its
/// challenge is attached to the request. The `openid` scope token is always
/// added to the requested scope.
///
/// # Errors
///
/// Returns an error if the PKCE code challenge cannot be computed.
fn build_authorization_request(
    authorization_data: AuthorizationRequestData,
    rng: &mut impl Rng,
) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> {
    let AuthorizationRequestData {
        client_id,
        mut scope,
        redirect_uri,
        code_challenge_methods_supported,
        display,
        prompt,
        max_age,
        ui_locales,
        id_token_hint,
        login_hint,
        acr_values,
    } = authorization_data;

    // Keep the RNG draw order fixed (state, nonce, then PKCE verifier) so a
    // seeded RNG always produces the same request.
    let state = Alphanumeric.sample_string(rng, 16);
    let nonce = Alphanumeric.sample_string(rng, 16);

    let supports_s256 = matches!(
        code_challenge_methods_supported.as_deref(),
        Some(methods) if methods.contains(&PkceCodeChallengeMethod::S256)
    );

    let (pkce, code_challenge_verifier) = if supports_s256 {
        let mut raw_verifier = [0u8; 32];
        rng.fill(&mut raw_verifier);
        let verifier = Base64UrlUnpadded::encode_string(&raw_verifier);

        let method = PkceCodeChallengeMethod::S256;
        let code_challenge = method.compute_challenge(&verifier)?.into();
        let challenge = pkce::AuthorizationRequest {
            code_challenge_method: method,
            code_challenge,
        };

        (Some(challenge), Some(verifier))
    } else {
        (None, None)
    };

    scope.insert_token(ScopeToken::Openid);

    let validation_data = AuthorizationValidationData {
        state: state.clone(),
        nonce: nonce.clone(),
        redirect_uri: redirect_uri.clone(),
        code_challenge_verifier,
    };

    let request = FullAuthorizationRequest {
        inner: AuthorizationRequest {
            response_type: OAuthAuthorizationEndpointResponseType::Code.into(),
            client_id,
            redirect_uri: Some(redirect_uri),
            scope,
            state: Some(state),
            response_mode: None,
            nonce: Some(nonce),
            display,
            prompt,
            max_age,
            ui_locales,
            id_token_hint,
            login_hint,
            acr_values,
            request: None,
            request_uri: None,
            registration: None,
        },
        pkce,
    };

    Ok((request, validation_data))
}
/// Builds the URL of an authorization request, to be opened in the user's
/// browser.
///
/// The request parameters are serialized as a query string and appended to
/// any query already present on `authorization_endpoint`.
///
/// # Arguments
///
/// * `authorization_endpoint` - The authorization endpoint of the provider.
/// * `authorization_data` - The parameters of the authorization request.
/// * `rng` - The source of randomness for `state`, `nonce` and the PKCE
///   verifier.
///
/// # Returns
///
/// The full authorization URL, plus the validation data to keep until the
/// user is redirected back.
///
/// # Errors
///
/// Returns an error if the PKCE challenge cannot be computed or if the request
/// cannot be serialized as a query string.
pub fn build_authorization_url(
    authorization_endpoint: Url,
    authorization_data: AuthorizationRequestData,
    rng: &mut impl Rng,
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
    tracing::debug!(
        scope = ?authorization_data.scope,
        "Authorizing..."
    );

    let (authorization_request, validation_data) =
        build_authorization_request(authorization_data, rng)?;

    let authorization_query = serde_urlencoded::to_string(authorization_request)?;

    let mut authorization_url = authorization_endpoint;
    let mut full_query = authorization_url
        .query()
        .map(ToOwned::to_owned)
        .unwrap_or_default();

    // Separate from existing parameters, but don't create an empty `&&` pair
    // when the endpoint's query already ends with a separator.
    if !full_query.is_empty() && !full_query.ends_with('&') {
        full_query.push('&');
    }
    full_query.push_str(&authorization_query);
    authorization_url.set_query(Some(&full_query));

    Ok((authorization_url, validation_data))
}
/// Pushes an authorization request to the server's PAR endpoint and builds
/// the URL of the matching authorization request.
///
/// The full request is sent as a form-encoded body to `par_endpoint`,
/// authenticated with `client_credentials`. The returned URL only carries the
/// `request_uri` from the PAR response and the client ID, appended to any
/// query already present on `authorization_endpoint`.
///
/// # Arguments
///
/// * `http_service` - The service used to send the PAR request.
/// * `client_credentials` - The credentials used to authenticate the client.
/// * `par_endpoint` - The pushed authorization request endpoint.
/// * `authorization_endpoint` - The authorization endpoint of the provider.
/// * `authorization_data` - The parameters of the authorization request.
/// * `now` - The current time, passed to the client authentication.
/// * `rng` - The source of randomness for the request and the client
///   authentication.
///
/// # Returns
///
/// The authorization URL, plus the validation data to keep until the user is
/// redirected back.
///
/// # Errors
///
/// Returns an error if the request cannot be built, authenticated or sent, if
/// the server responds with an error status, or if the final query string
/// cannot be serialized.
// Record the endpoint on the span: a bare `fields(par_endpoint)` only declares
// an empty field that is never filled in.
#[tracing::instrument(skip_all, fields(par_endpoint = %par_endpoint))]
pub async fn build_par_authorization_url(
    http_service: &HttpService,
    client_credentials: ClientCredentials,
    par_endpoint: &Url,
    authorization_endpoint: Url,
    authorization_data: AuthorizationRequestData,
    now: DateTime<Utc>,
    rng: &mut impl Rng,
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
    tracing::debug!(
        scope = ?authorization_data.scope,
        "Authorizing with a PAR..."
    );

    // Grab the client ID before the credentials are consumed by
    // `apply_to_request`.
    let client_id = client_credentials.client_id().to_owned();

    let (authorization_request, validation_data) =
        build_authorization_request(authorization_data, rng)?;

    let par_request = http::Request::post(par_endpoint.as_str())
        .header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())
        .body(authorization_request)
        .map_err(PushedAuthorizationError::from)?;

    let par_request = client_credentials
        .apply_to_request(par_request, now, rng)
        .map_err(PushedAuthorizationError::from)?;

    let service = (
        FormUrlencodedRequestLayer::default(),
        JsonResponseLayer::<PushedAuthorizationResponse>::default(),
        CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
    )
        .layer(http_service.clone());

    let par_response = service
        .ready_oneshot()
        .await
        .map_err(PushedAuthorizationError::from)?
        .call(par_request)
        .await
        .map_err(PushedAuthorizationError::from)?
        .into_body();

    let authorization_query = serde_urlencoded::to_string([
        ("request_uri", par_response.request_uri.as_str()),
        ("client_id", &client_id),
    ])?;

    let mut authorization_url = authorization_endpoint;
    let mut full_query = authorization_url
        .query()
        .map(ToOwned::to_owned)
        .unwrap_or_default();

    // Separate from existing parameters, but don't create an empty `&&` pair
    // when the endpoint's query already ends with a separator.
    if !full_query.is_empty() && !full_query.ends_with('&') {
        full_query.push('&');
    }
    full_query.push_str(&authorization_query);
    authorization_url.set_query(Some(&full_query));

    Ok((authorization_url, validation_data))
}
/// Exchanges an authorization code for an access token.
///
/// The token request re-sends the redirect URI and, if PKCE was used, the
/// code verifier from `validation_data`.
///
/// When `id_token_verification_data` is provided, the token response must
/// contain an ID token. That token is verified, and then:
/// * its `at_hash` claim, if present, must match the returned access token;
/// * its `c_hash` claim, if present, must match `code`;
/// * its `nonce` claim is required and must equal `validation_data.nonce`.
///
/// # Arguments
///
/// * `http_service` - The service used to send the token request.
/// * `client_credentials` - The credentials used to authenticate the client.
/// * `token_endpoint` - The token endpoint of the provider.
/// * `code` - The authorization code returned by the authorization endpoint.
/// * `validation_data` - The data returned when the authorization URL was
///   built.
/// * `id_token_verification_data` - The data needed to verify the ID token,
///   or `None` to skip ID token handling.
/// * `now` - The current time.
/// * `rng` - The source of randomness for the client authentication.
///
/// # Returns
///
/// The token response, plus the verified ID token when verification data was
/// provided.
///
/// # Errors
///
/// Returns an error if the token request fails, or if ID token verification
/// was requested and the ID token is missing, invalid, or has a mismatching
/// `at_hash`, `c_hash` or `nonce`.
// The flow needs every one of these inputs; bundling them would only move the
// argument list into a single-use struct.
#[allow(clippy::too_many_arguments)]
// Record the endpoint on the span: a bare `fields(token_endpoint)` only
// declares an empty field that is never filled in.
#[tracing::instrument(skip_all, fields(token_endpoint = %token_endpoint))]
pub async fn access_token_with_authorization_code(
    http_service: &HttpService,
    client_credentials: ClientCredentials,
    token_endpoint: &Url,
    code: String,
    validation_data: AuthorizationValidationData,
    id_token_verification_data: Option<JwtVerificationData<'_>>,
    now: DateTime<Utc>,
    rng: &mut impl Rng,
) -> Result<(AccessTokenResponse, Option<IdToken<'static>>), TokenAuthorizationCodeError> {
    tracing::debug!("Exchanging authorization code for access token...");

    let token_response = request_access_token(
        http_service,
        client_credentials,
        token_endpoint,
        AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
            // `code` is still needed below to check the `c_hash` claim.
            code: code.clone(),
            redirect_uri: Some(validation_data.redirect_uri),
            code_verifier: validation_data.code_challenge_verifier,
        }),
        now,
        rng,
    )
    .await?;

    let id_token = if let Some(verification_data) = id_token_verification_data {
        // Read before `verification_data` is moved into `verify_id_token`.
        let signing_alg = verification_data.signing_algorithm;

        let id_token = token_response
            .id_token
            .as_deref()
            .ok_or(IdTokenError::MissingIdToken)?;

        let id_token = verify_id_token(id_token, verification_data, None, now)?;

        // Check the flow-specific claims on a copy of the payload so the
        // returned token keeps all of its claims.
        let mut claims = id_token.payload().clone();

        claims::AT_HASH
            .extract_optional_with_options(
                &mut claims,
                TokenHash::new(signing_alg, &token_response.access_token),
            )
            .map_err(IdTokenError::from)?;

        claims::C_HASH
            .extract_optional_with_options(&mut claims, TokenHash::new(signing_alg, &code))
            .map_err(IdTokenError::from)?;

        claims::NONCE
            .extract_required_with_options(&mut claims, validation_data.nonce.as_str())
            .map_err(IdTokenError::from)?;

        Some(id_token.into_owned())
    } else {
        None
    };

    Ok((token_response, id_token))
}