pub mod types;
use axum::{
body::Body,
extract::{FromRequestParts, Request, State},
http::{self, StatusCode},
middleware::Next,
response::Response,
};
use axum_extra::extract::CookieJar;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use serde::Deserialize;
use service_error::ServiceError;
use crate::types::{Jwks, JwksTrait};
/// The `realm_access` object of a decoded token payload.
///
/// Holds the list of role names granted at realm level (presumably
/// Keycloak-style realm roles — confirm against the token issuer).
#[derive(Clone, Debug, Deserialize)]
pub struct RealmAccess {
    /// Role names carried in `realm_access.roles`.
    pub roles: Vec<String>,
}
/// One entry of the `resource_access` map of a decoded token payload.
///
/// Holds the role names granted for a single resource/client key of that map
/// (presumably Keycloak client roles — confirm against the token issuer).
#[derive(Clone, Debug, Deserialize)]
pub struct ResourceRoles {
    /// Role names carried in `resource_access.<key>.roles`.
    pub roles: Vec<String>,
}
/// Payload of an access token, as decoded by `validate_token`.
///
/// Every field is an `Option`, so a claim missing from the payload deserializes
/// to `None` instead of failing the whole decode. Unknown payload keys are
/// ignored (no `deny_unknown_fields`).
///
/// NOTE(review): RFC 7519 allows `aud` to be either a single string or an array
/// of strings. With `aud: Option<String>`, an array-valued `aud` would fail to
/// deserialize, and the token would be rejected. Confirm that the issuer always
/// sends a single audience string.
#[derive(Clone, Debug, Deserialize)]
pub struct Claims {
    /// Registered `sub` (subject) claim.
    pub sub: Option<String>,
    /// Registered `exp` (expiration time) claim, NumericDate seconds.
    pub exp: Option<u64>,
    /// Registered `iat` (issued at) claim, NumericDate seconds.
    pub iat: Option<u64>,
    /// Registered `iss` (issuer) claim.
    pub iss: Option<String>,
    /// Registered `aud` (audience) claim; see the NOTE above about arrays.
    pub aud: Option<String>,
    /// `typ` claim (token type, issuer-specific).
    pub typ: Option<String>,
    /// `azp` (authorized party) claim.
    pub azp: Option<String>,
    /// `name` profile claim.
    pub name: Option<String>,
    /// `preferred_username` profile claim.
    pub preferred_username: Option<String>,
    /// `given_name` profile claim.
    pub given_name: Option<String>,
    /// `family_name` profile claim.
    pub family_name: Option<String>,
    /// `email` claim.
    pub email: Option<String>,
    /// `email_verified` claim.
    pub email_verified: Option<bool>,
    /// Realm-level roles (`realm_access` object).
    pub realm_access: Option<RealmAccess>,
    /// Per-resource roles (`resource_access` object), keyed by resource name.
    pub resource_access: Option<::std::collections::HashMap<String, ResourceRoles>>,
}
/// Extractor that hands a handler the [`Claims`] stored in the request
/// extensions. [`validate_auth_cookie`] stores them there after validating
/// the token.
///
/// Derives `Clone` like the [`Claims`] it wraps. Handlers can then duplicate
/// the extracted value without unwrapping it first.
#[derive(Debug, Clone)]
pub struct ClaimsExtension(pub Claims);
impl<S> FromRequestParts<S> for ClaimsExtension
where
S: Send + Sync,
{
type Rejection = ServiceError;
async fn from_request_parts(
parts: &mut http::request::Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
let ext = &parts.extensions;
let claims = ext.get::<Claims>();
match claims {
Some(claims) => Ok(ClaimsExtension(claims.clone())),
None => Err(ServiceError::Internal(
"Claims not found in request extensions".to_string(),
)),
}
}
}
/// Axum middleware that authenticates a request from its `access_token` cookie.
///
/// The cookie value is validated against the key set returned by
/// `state.jwks()`. On success, the decoded [`Claims`] are inserted into the
/// request extensions, where [`ClaimsExtension`] can extract them. The request
/// is then forwarded to `next`.
///
/// # Errors
///
/// Returns `StatusCode::UNAUTHORIZED` when the cookie is missing or the token
/// fails validation. The downstream handler is not called in that case.
pub async fn validate_auth_cookie<T>(
    State(state): State<std::sync::Arc<T>>,
    cookies: CookieJar,
    mut req: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode>
where
    T: JwksTrait,
{
    let token = get_auth_token(&cookies)?;
    // Validate in its own statement so the `state.jwks()` temporary is dropped
    // here, instead of living across the downstream `.await` below.
    let claims = validate_token(token, &state.jwks())?;
    req.extensions_mut().insert(claims);
    Ok(next.run(req).await)
}
/// Returns the raw value of the `access_token` cookie, borrowed from `cookies`.
///
/// # Errors
///
/// Returns `StatusCode::UNAUTHORIZED` when the cookie is absent.
fn get_auth_token(cookies: &CookieJar) -> Result<&str, StatusCode> {
    let Some(cookie) = cookies.get("access_token") else {
        return Err(StatusCode::UNAUTHORIZED);
    };
    Ok(cookie.value())
}
/// Verifies `token` with a key from `jwks` and returns its decoded [`Claims`].
///
/// The key is selected by the `kid` in the token header and built from that
/// key's RSA components (`n`, `e`). Validation is configured with
/// `Validation::new(Algorithm::RS256)`: the algorithm is pinned by this code
/// rather than taken from the header. All other checks are that
/// configuration's defaults.
///
/// NOTE(review): no expected issuer or audience is set on the `Validation`.
/// Depending on the jsonwebtoken version, a token that carries `aud` may be
/// rejected when no audience is configured. Confirm against the crate version
/// in use.
///
/// # Errors
///
/// Returns `StatusCode::UNAUTHORIZED` in any of these cases:
/// - the header cannot be parsed or has no `kid`;
/// - no key in `jwks` has that `kid`;
/// - the key components are invalid;
/// - decoding/validation fails.
///
/// The underlying error is discarded.
fn validate_token(token: &str, jwks: &Jwks) -> Result<Claims, StatusCode> {
    const REJECT: StatusCode = StatusCode::UNAUTHORIZED;

    // The header is read only to choose the key; the signature is checked by `decode`.
    let Some(kid) = decode_header(token).map_err(|_| REJECT)?.kid else {
        return Err(REJECT);
    };
    let Some(jwk) = jwks.keys.iter().find(|k| k.kid == kid) else {
        return Err(REJECT);
    };
    let key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e).map_err(|_| REJECT)?;

    decode::<Claims>(token, &key, &Validation::new(Algorithm::RS256))
        .map(|data| data.claims)
        .map_err(|_| REJECT)
}