use std::{ops::Deref, pin::Pin, sync::Arc};
use axum::{body::HttpBody, extract::Request};
use futures_core::future::BoxFuture;
use http::{Response, StatusCode, header::AUTHORIZATION};
use log::{debug, warn};
use openidconnect::{
AccessToken, AsyncHttpClient, ClientId, ClientSecret, HttpClientError, HttpRequest,
HttpResponse, IntrospectionUrl, IssuerUrl, TokenIntrospectionResponse as _, core::CoreClient,
};
use pib_service_inventory::{InventoryProvider, NewUser, UpdateUser};
use serde::{Deserialize, Serialize};
use tower_http::auth::AsyncAuthorizeRequest;
use url::Url;
/// Authorization layer state that checks bearer tokens against an OpenID Connect provider
/// (via token introspection) and records authenticated users in the inventory.
#[derive(Debug, Clone)]
pub struct OidcAuth {
    /// Issuer URL used for provider metadata discovery.
    issuer_url: IssuerUrl,
    /// Client id used to build the OIDC client that performs introspection.
    client_id: ClientId,
    /// Client secret paired with `client_id`.
    client_secret: ClientSecret,
    /// Source of the inventory that authenticated users are synchronised into.
    inventory_provider: Arc<dyn InventoryProvider>,
}

impl OidcAuth {
    /// Creates a new authorizer from the provider settings and an inventory provider.
    pub fn new(
        issuer_url: IssuerUrl,
        client_id: ClientId,
        client_secret: ClientSecret,
        inventory_provider: Arc<dyn InventoryProvider>,
    ) -> Self {
        OidcAuth {
            inventory_provider,
            client_secret,
            client_id,
            issuer_url,
        }
    }
}
impl<B: HttpBody + Send + 'static> AsyncAuthorizeRequest<B> for OidcAuth {
type RequestBody = B;
type ResponseBody = axum::body::Body;
type Future =
BoxFuture<'static, Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
let issuer_url = self.issuer_url.clone();
let client_id = self.client_id.clone();
let client_secret = self.client_secret.clone();
let unauthorized_response = Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(axum::body::Body::empty())
.unwrap();
let Some(header) = request.headers().get(AUTHORIZATION).cloned() else {
return Box::pin(async move { Err(unauthorized_response) });
};
let inventory_provider = self.inventory_provider.clone();
Box::pin(async move {
let mut inventory = inventory_provider.get_inventory().await.unwrap();
let user = match check_auth(issuer_url, client_id, client_secret, header).await {
Ok(Some(user)) => user,
Ok(None) => return Err(unauthorized_response),
Err(e) => {
return Err(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(format!("Error while checking authentication: {e}").into())
.unwrap());
}
};
let sub = user.sub.to_string();
let display_name = user.display_name.clone();
match inventory.get_user_by_sub(sub.clone()).await {
Ok(user) if user.display_name.is_none() => {
if let Some(display_name) = display_name {
inventory
.update_user(
UpdateUser::new(user.id)
.display_name(display_name)
.sub(user.sub),
)
.await
.unwrap();
}
}
Ok(_user) => {}
Err(e) if e.is_not_found() => {
inventory
.create_user(NewUser { sub, display_name })
.await
.unwrap();
}
Err(e) => {
warn!("Error reading user information from inventory: {e}");
}
}
request.extensions_mut().insert(user);
Ok(request)
})
}
}
async fn check_auth(
issuer_url: IssuerUrl,
client_id: ClientId,
client_secret: ClientSecret,
authorization_header: http::HeaderValue,
) -> anyhow::Result<Option<UserInfo>> {
const HEADER_IDENTIFIER: &str = "bearer ";
let authorization_header = authorization_header.to_str().unwrap();
if !authorization_header
.to_lowercase()
.starts_with(HEADER_IDENTIFIER)
{
panic!("Authorization header must start with \"bearer\" (lower- or upper-case)");
}
let (_, token) = authorization_header.split_at(HEADER_IDENTIFIER.len());
let http_client = ClientWrapper(
reqwest::ClientBuilder::new()
.connection_verbose(true)
.redirect(reqwest::redirect::Policy::none())
.build()?,
);
#[derive(Debug, Clone, Deserialize, Serialize)]
struct AdditionalProviderMetadata {
pub introspection_endpoint: Option<Url>,
}
impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {}
type ProviderMetadata = openidconnect::ProviderMetadata<
AdditionalProviderMetadata,
openidconnect::core::CoreAuthDisplay,
openidconnect::core::CoreClientAuthMethod,
openidconnect::core::CoreClaimName,
openidconnect::core::CoreClaimType,
openidconnect::core::CoreGrantType,
openidconnect::core::CoreJweContentEncryptionAlgorithm,
openidconnect::core::CoreJweKeyManagementAlgorithm,
openidconnect::core::CoreJsonWebKey,
openidconnect::core::CoreResponseMode,
openidconnect::core::CoreResponseType,
openidconnect::core::CoreSubjectIdentifierType,
>;
let provider_metadata = ProviderMetadata::discover_async(issuer_url, &http_client).await?;
let Some(introspection_url) = provider_metadata
.additional_metadata()
.introspection_endpoint
.clone()
else {
panic!("OIDC service does not support token introspection");
};
let client =
CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
.set_introspection_url(IntrospectionUrl::from_url(introspection_url.clone()));
let token = AccessToken::new(token.to_string());
debug!("Starting introspection of access token at URL {introspection_url}");
let introspection = client
.introspect(&token)
.request_async(&http_client)
.await?;
debug!("Introspection response: {introspection:?}");
if !introspection.active() {
debug!("Access token is not active");
return Ok(None);
}
let Some(iss) = introspection.iss() else {
warn!("Received an access token without an issuer");
return Ok(None);
};
let Some(sub) = introspection.sub() else {
warn!("Received an access token without a subject");
return Ok(None);
};
let display_name = introspection.username().map(|s| s.to_string());
let issuer = Issuer(iss.to_string());
let sub = UserId(sub.to_string());
Ok(Some(UserInfo {
issuer,
sub,
display_name,
}))
}
/// Subject identifier (`sub`) of an authenticated user, as reported by token introspection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserId(String);

impl std::fmt::Display for UserId {
    /// Formats the identifier exactly like the wrapped string, honouring width/alignment flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}

/// Issuer (`iss`) of an access token, as reported by token introspection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Issuer(String);

/// Identity of an authenticated caller; inserted into the request extensions on success.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Token issuer.
    pub issuer: Issuer,
    /// Subject identifier, used as the inventory lookup key.
    pub sub: UserId,
    /// Human-readable name taken from the introspection `username`, if present.
    pub display_name: Option<String>,
}
/// Newtype around `reqwest::Client` that implements the `openidconnect` async HTTP client trait.
#[derive(Debug, Clone)]
struct ClientWrapper(reqwest::Client);

impl<'c> AsyncHttpClient<'c> for ClientWrapper {
    type Error = HttpClientError<reqwest::Error>;
    type Future =
        Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;

    /// Sends `request` with the wrapped client and converts the reply into an `HttpResponse`.
    ///
    /// The reply keeps its status, HTTP version, all headers and the full body bytes.
    fn call(&'c self, request: HttpRequest) -> Self::Future {
        let client = &self.0;
        Box::pin(async move {
            let outgoing: reqwest::Request = request.try_into().map_err(Box::new)?;
            let upstream = client.execute(outgoing).await.map_err(Box::new)?;
            let builder = upstream.headers().iter().fold(
                http::Response::builder()
                    .status(upstream.status())
                    .version(upstream.version()),
                |builder, (name, value)| builder.header(name, value),
            );
            let payload = upstream.bytes().await.map_err(Box::new)?;
            builder
                .body(payload.to_vec())
                .map_err(HttpClientError::Http)
        })
    }
}
impl From<reqwest::Client> for ClientWrapper {
fn from(value: reqwest::Client) -> Self {
Self(value)
}
}
impl From<&reqwest::Client> for ClientWrapper {
fn from(value: &reqwest::Client) -> Self {
Self(value.clone())
}
}
impl Deref for ClientWrapper {
type Target = reqwest::Client;
fn deref(&self) -> &Self::Target {
&self.0
}
}