use axum::{body::HttpBody, extract::Request};
use futures_core::future::BoxFuture;
use http::{Response, StatusCode, header::AUTHORIZATION};
use openidconnect::{
AccessToken, ClientId, ClientSecret, IntrospectionUrl, IssuerUrl,
TokenIntrospectionResponse as _, core::CoreClient,
};
use serde::{Deserialize, Serialize};
use tower_http::auth::AsyncAuthorizeRequest;
use url::Url;
/// OIDC configuration used to authorize incoming requests by introspecting their bearer
/// tokens at the provider identified by `issuer_url`.
#[derive(Debug, Clone)]
pub struct OidcAuth {
    /// Issuer URL of the OIDC provider; used as the base for provider discovery.
    issuer_url: IssuerUrl,
    /// Client ID of this service, passed to the OIDC client that performs introspection.
    client_id: ClientId,
    /// Client secret of this service, passed to the OIDC client that performs introspection.
    client_secret: ClientSecret,
}
impl OidcAuth {
    /// Creates an authorizer for the provider at `issuer_url`, authenticating as the
    /// client identified by `client_id` / `client_secret`.
    ///
    /// The arguments are only stored; no network traffic happens here.
    pub fn new(issuer_url: IssuerUrl, client_id: ClientId, client_secret: ClientSecret) -> Self {
        Self { issuer_url, client_id, client_secret }
    }
}
impl<B: HttpBody + Send + 'static> AsyncAuthorizeRequest<B> for OidcAuth {
type RequestBody = B;
type ResponseBody = axum::body::Body;
type Future =
BoxFuture<'static, Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, request: Request<B>) -> Self::Future {
let issuer_url = self.issuer_url.clone();
let client_id = self.client_id.clone();
let client_secret = self.client_secret.clone();
Box::pin(async move {
let unauthorized_response = Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(axum::body::Body::empty())
.unwrap();
let Some(header) = request.headers().get(AUTHORIZATION).cloned() else {
return Err(unauthorized_response);
};
let Some(user) = check_auth(issuer_url, client_id, client_secret, header).await else {
return Err(unauthorized_response);
};
println!("USER AUTHENTICATED: {user:?}");
Ok(request)
})
}
}
/// Validates an `Authorization` header by introspecting its bearer token at the OIDC provider.
///
/// Runs OIDC discovery against `issuer_url` to find the provider's `introspection_endpoint`.
/// It then introspects the token with an OIDC client built from `client_id` / `client_secret`.
///
/// Returns `Some(UserId)` when the provider reports the token as active. Returns `None` when:
/// - the header is not a valid `&str`;
/// - the header does not use the `Bearer` scheme (matched ASCII-case-insensitively);
/// - the token is empty;
/// - the provider does not advertise an introspection endpoint;
/// - building the HTTP client, discovery, or introspection fails;
/// - the token is inactive.
///
/// Failures are reported on stderr instead of panicking. A malformed request or an
/// unreachable provider therefore leads to a rejection rather than an aborted task.
///
/// NOTE(review): the HTTP client and provider discovery are rebuilt on every call. Consider
/// caching the discovered client if this runs on a hot path.
async fn check_auth(
    issuer_url: IssuerUrl,
    client_id: ClientId,
    client_secret: ClientSecret,
    authorization_header: http::HeaderValue,
) -> Option<UserId> {
    // Compared ASCII-case-insensitively so "Bearer ", "bearer ", "BEARER " all match.
    const HEADER_IDENTIFIER: &str = "bearer ";

    // The raw header is deliberately not logged: it contains the bearer credential.
    let authorization_header = authorization_header.to_str().ok()?;
    let scheme = authorization_header.get(..HEADER_IDENTIFIER.len())?;
    if !scheme.eq_ignore_ascii_case(HEADER_IDENTIFIER) {
        return None;
    }
    // The successful `get` above proves this index is a char boundary, so slicing cannot panic.
    let token = &authorization_header[HEADER_IDENTIFIER.len()..];
    if token.is_empty() {
        // Nothing to introspect; skip the round-trip to the provider.
        return None;
    }

    // Redirects are disabled so discovery/introspection requests are not followed elsewhere.
    let http_client = match reqwest::ClientBuilder::new()
        .redirect(reqwest::redirect::Policy::none())
        .build()
    {
        Ok(client) => client,
        Err(err) => {
            eprintln!("failed to build HTTP client for OIDC: {err:?}");
            return None;
        }
    };

    // Provider metadata extended with the optional `introspection_endpoint` field,
    // which the core metadata type does not expose.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    struct AdditionalProviderMetadata {
        pub introspection_endpoint: Option<Url>,
    }
    impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {}
    type ProviderMetadata = openidconnect::ProviderMetadata<
        AdditionalProviderMetadata,
        openidconnect::core::CoreAuthDisplay,
        openidconnect::core::CoreClientAuthMethod,
        openidconnect::core::CoreClaimName,
        openidconnect::core::CoreClaimType,
        openidconnect::core::CoreGrantType,
        openidconnect::core::CoreJweContentEncryptionAlgorithm,
        openidconnect::core::CoreJweKeyManagementAlgorithm,
        openidconnect::core::CoreJsonWebKey,
        openidconnect::core::CoreResponseMode,
        openidconnect::core::CoreResponseType,
        openidconnect::core::CoreSubjectIdentifierType,
    >;

    let provider_metadata = match ProviderMetadata::discover_async(issuer_url, &http_client).await
    {
        Ok(metadata) => metadata,
        Err(err) => {
            eprintln!("OIDC discovery failed: {err:?}");
            return None;
        }
    };
    let Some(introspection_url) = provider_metadata
        .additional_metadata()
        .introspection_endpoint
        .clone()
    else {
        // Provider misconfiguration: fail closed instead of panicking on every request.
        eprintln!("OIDC service does not support token introspection");
        return None;
    };

    let client =
        CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
            .set_introspection_url(IntrospectionUrl::from_url(introspection_url));
    let token = AccessToken::new(token.to_string());
    let introspection = match client.introspect(&token).request_async(&http_client).await {
        Ok(response) => response,
        Err(err) => {
            eprintln!("OIDC token introspection failed: {err:?}");
            return None;
        }
    };
    println!("INTROSPECTION: {introspection:?}");
    if !introspection.active() {
        return None;
    }
    // TODO(review): placeholder identity. Derive the user from the introspection response
    // (e.g. its subject) instead of returning a fixed name for every active token.
    Some(UserId("Moritz".to_string()))
}
/// Identifier of a user whose bearer token was accepted by `check_auth`.
///
/// NOTE(review): `check_auth` currently always fills this with a fixed placeholder name
/// rather than a value from the provider — confirm before relying on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserId(String);