use std::collections::HashMap;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use mas_http::JsonResponseLayer;
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{
claims::{self, TimeOptions},
jwk::PublicJsonWebKeySet,
jwt::Jwt,
};
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::{IdTokenError, JwksError, JwtVerificationError},
http_service::HttpService,
types::IdToken,
};
/// Fetches the JSON Web Key Set published at `jwks_uri`.
///
/// A `GET` request with an empty body is sent through a clone of
/// `http_service` wrapped in a [`JsonResponseLayer`], and the body of the
/// response is returned as a [`PublicJsonWebKeySet`].
///
/// # Errors
///
/// Returns a [`JwksError`] if the request cannot be built, if the service
/// fails to become ready, or if the call itself fails.
#[tracing::instrument(skip_all, fields(jwks_uri))]
pub async fn fetch_jwks(
    http_service: &HttpService,
    jwks_uri: &Url,
) -> Result<PublicJsonWebKeySet, JwksError> {
    tracing::debug!("Fetching JWKS...");

    let request = http::Request::get(jwks_uri.as_str()).body(Bytes::new())?;

    // Wrap the shared HTTP service so the response body comes back typed as a
    // key set.
    let json_layer = JsonResponseLayer::<PublicJsonWebKeySet>::default();
    let client = json_layer.layer(http_service.clone());

    // Wait for readiness first, then dispatch the request.
    let mut ready_client = client.ready_oneshot().await?;
    let response = ready_client.call(request).await?;

    Ok(response.into_body())
}
/// The borrowed data needed to verify a signed JWT.
///
/// All fields are references, so the struct is cheap to copy and is passed
/// by value to [`verify_signed_jwt`] and [`verify_id_token`].
#[derive(Clone, Copy)]
pub struct JwtVerificationData<'a> {
/// The issuer that the JWT's issuer claim is validated against.
pub issuer: &'a str,
/// The key set used to verify the signature of the JWT.
pub jwks: &'a PublicJsonWebKeySet,
/// The client ID that the JWT's audience claim is validated against.
pub client_id: &'a String,
/// The algorithm that the JWT header must declare.
pub signing_algorithm: &'a JsonWebSignatureAlg,
}
/// Decodes a signed JWT and checks it against `verification_data`.
///
/// The checks run in this order, and the first failure is returned:
///
/// 1. the signature must verify against `verification_data.jwks`;
/// 2. the issuer claim is validated against `verification_data.issuer`;
/// 3. the audience claim is validated against `verification_data.client_id`;
/// 4. the algorithm in the header must equal
///    `verification_data.signing_algorithm`.
///
/// On success the decoded JWT is returned with its payload untouched.
///
/// # Errors
///
/// Returns a [`JwtVerificationError`] if the string cannot be converted into
/// a JWT, if the signature does not verify, or if the issuer or audience
/// claim is rejected. Returns [`JwtVerificationError::WrongSignatureAlg`] if
/// the header algorithm differs from the expected one.
pub fn verify_signed_jwt<'a>(
    jwt: &'a str,
    verification_data: JwtVerificationData<'_>,
) -> Result<Jwt<'a, HashMap<String, Value>>, JwtVerificationError> {
    tracing::debug!("Validating JWT...");

    let decoded: Jwt<HashMap<String, Value>> = jwt.try_into()?;
    decoded.verify_with_jwks(verification_data.jwks)?;

    // Claims are extracted from a copy, so the JWT handed back to the caller
    // still carries its full payload.
    let (header, mut payload) = decoded.clone().into_parts();
    claims::ISS.extract_required_with_options(&mut payload, verification_data.issuer)?;
    claims::AUD.extract_required_with_options(&mut payload, verification_data.client_id)?;

    if header.alg() == verification_data.signing_algorithm {
        Ok(decoded)
    } else {
        Err(JwtVerificationError::WrongSignatureAlg)
    }
}
/// Verifies an ID token, optionally against a previously received one.
///
/// The token is first checked with [`verify_signed_jwt`]. Then, working on a
/// copy of its payload:
///
/// * the expiration and issued-at claims must be present and are validated
///   with [`TimeOptions`] built from `now`;
/// * the subject claim must be present.
///
/// When `auth_id_token` is provided, two more checks apply:
///
/// * its subject must equal the subject of the new token;
/// * if the new token carries an authentication time, the previous token
///   must carry one too and both must be equal.
///
/// # Errors
///
/// Returns an [`IdTokenError`] if the JWT verification fails or if a claim is
/// missing or rejected. Returns [`IdTokenError::WrongSubjectIdentifier`] or
/// [`IdTokenError::WrongAuthTime`] when the token disagrees with
/// `auth_id_token`.
pub fn verify_id_token<'a>(
    id_token: &'a str,
    verification_data: JwtVerificationData<'_>,
    auth_id_token: Option<&IdToken<'_>>,
    now: DateTime<Utc>,
) -> Result<IdToken<'a>, IdTokenError> {
    let verified = verify_signed_jwt(id_token, verification_data)?;

    // Work on a copy of the payload; the verified token is returned as-is.
    let mut payload = verified.payload().clone();

    let time_options = TimeOptions::new(now);
    claims::EXP.extract_required_with_options(&mut payload, &time_options)?;
    claims::IAT.extract_required_with_options(&mut payload, time_options)?;

    let subject = claims::SUB.extract_required(&mut payload)?;

    if let Some(previous) = auth_id_token {
        let mut previous_payload = previous.payload().clone();

        // Both tokens must be about the same subject.
        let previous_subject = claims::SUB.extract_required(&mut previous_payload)?;
        if subject != previous_subject {
            return Err(IdTokenError::WrongSubjectIdentifier);
        }

        // The previous authentication time is only required when the new
        // token carries one.
        let current_auth_time = claims::AUTH_TIME.extract_optional(&mut payload)?;
        if let Some(current_auth_time) = current_auth_time {
            let previous_auth_time =
                claims::AUTH_TIME.extract_required(&mut previous_payload)?;
            if previous_auth_time != current_auth_time {
                return Err(IdTokenError::WrongAuthTime);
            }
        }
    }

    Ok(verified)
}