pib-service-api 0.11.0

pib-service edit API
Documentation
// SPDX-FileCopyrightText: Politik im Blick developers
// SPDX-FileCopyrightText: Wolfgang Silbermayr <wolfgang@silbermayr.at>
//
// SPDX-License-Identifier: AGPL-3.0-or-later OR EUPL-1.2

use std::{ops::Deref, pin::Pin, sync::Arc};

use axum::{body::HttpBody, extract::Request};
use futures_core::future::BoxFuture;
use http::{Response, StatusCode, header::AUTHORIZATION};
use log::{debug, warn};
use openidconnect::{
    AccessToken, AsyncHttpClient, ClientId, ClientSecret, HttpClientError, HttpRequest,
    HttpResponse, IntrospectionUrl, IssuerUrl, TokenIntrospectionResponse as _, core::CoreClient,
};
use pib_service_inventory::{InventoryProvider, NewUser, UpdateUser};
use serde::{Deserialize, Serialize};
use tower_http::auth::AsyncAuthorizeRequest;
use url::Url;

/// Request authorizer that validates bearer tokens through the OIDC provider's token
/// introspection endpoint and records authenticated users in the inventory.
#[derive(Debug, Clone)]
pub struct OidcAuth {
    /// Issuer URL used for OIDC discovery of the introspection endpoint.
    issuer_url: IssuerUrl,
    /// Client ID presented to the provider when introspecting tokens.
    client_id: ClientId,
    /// Client secret presented to the provider when introspecting tokens.
    client_secret: ClientSecret,
    /// Source of inventory handles used to look up, create and update users.
    inventory_provider: Arc<dyn InventoryProvider>,
}

impl OidcAuth {
    pub fn new(
        issuer_url: IssuerUrl,
        client_id: ClientId,
        client_secret: ClientSecret,
        inventory_provider: Arc<dyn InventoryProvider>,
    ) -> Self {
        Self {
            issuer_url,
            client_id,
            client_secret,
            inventory_provider,
        }
    }
}

impl<B: HttpBody + Send + 'static> AsyncAuthorizeRequest<B> for OidcAuth {
    type RequestBody = B;
    type ResponseBody = axum::body::Body;
    type Future =
        BoxFuture<'static, Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;

    fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
        let issuer_url = self.issuer_url.clone();
        let client_id = self.client_id.clone();
        let client_secret = self.client_secret.clone();

        let unauthorized_response = Response::builder()
            .status(StatusCode::UNAUTHORIZED)
            .body(axum::body::Body::empty())
            .unwrap();

        let Some(header) = request.headers().get(AUTHORIZATION).cloned() else {
            return Box::pin(async move { Err(unauthorized_response) });
        };
        let inventory_provider = self.inventory_provider.clone();

        Box::pin(async move {
            let mut inventory = inventory_provider.get_inventory().await.unwrap();

            let user = match check_auth(issuer_url, client_id, client_secret, header).await {
                Ok(Some(user)) => user,
                Ok(None) => return Err(unauthorized_response),
                Err(e) => {
                    return Err(Response::builder()
                        .status(StatusCode::INTERNAL_SERVER_ERROR)
                        .body(format!("Error while checking authentication: {e}").into())
                        .unwrap());
                }
            };

            let sub = user.sub.to_string();
            let display_name = user.display_name.clone();

            match inventory.get_user_by_sub(sub.clone()).await {
                Ok(user) if user.display_name.is_none() => {
                    if let Some(display_name) = display_name {
                        inventory
                            .update_user(
                                UpdateUser::new(user.id)
                                    .display_name(display_name)
                                    .sub(user.sub),
                            )
                            .await
                            .unwrap();
                    }
                }
                Ok(_user) => {}
                Err(e) if e.is_not_found() => {
                    inventory
                        .create_user(NewUser { sub, display_name })
                        .await
                        .unwrap();
                }
                Err(e) => {
                    warn!("Error reading user information from inventory: {e}");
                }
            }

            request.extensions_mut().insert(user);

            Ok(request)
        })
    }
}

async fn check_auth(
    issuer_url: IssuerUrl,
    client_id: ClientId,
    client_secret: ClientSecret,
    authorization_header: http::HeaderValue,
) -> anyhow::Result<Option<UserInfo>> {
    const HEADER_IDENTIFIER: &str = "bearer ";

    let authorization_header = authorization_header.to_str().unwrap();
    if !authorization_header
        .to_lowercase()
        .starts_with(HEADER_IDENTIFIER)
    {
        panic!("Authorization header must start with \"bearer\" (lower- or upper-case)");
    }
    let (_, token) = authorization_header.split_at(HEADER_IDENTIFIER.len());

    let http_client = ClientWrapper(
        reqwest::ClientBuilder::new()
            .connection_verbose(true)
            // Following redirects opens the client up to SSRF vulnerabilities.
            .redirect(reqwest::redirect::Policy::none())
            .build()?,
    );

    #[derive(Debug, Clone, Deserialize, Serialize)]
    struct AdditionalProviderMetadata {
        pub introspection_endpoint: Option<Url>,
    }

    impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {}

    type ProviderMetadata = openidconnect::ProviderMetadata<
        AdditionalProviderMetadata,
        openidconnect::core::CoreAuthDisplay,
        openidconnect::core::CoreClientAuthMethod,
        openidconnect::core::CoreClaimName,
        openidconnect::core::CoreClaimType,
        openidconnect::core::CoreGrantType,
        openidconnect::core::CoreJweContentEncryptionAlgorithm,
        openidconnect::core::CoreJweKeyManagementAlgorithm,
        openidconnect::core::CoreJsonWebKey,
        openidconnect::core::CoreResponseMode,
        openidconnect::core::CoreResponseType,
        openidconnect::core::CoreSubjectIdentifierType,
    >;

    let provider_metadata = ProviderMetadata::discover_async(issuer_url, &http_client).await?;

    let Some(introspection_url) = provider_metadata
        .additional_metadata()
        .introspection_endpoint
        .clone()
    else {
        panic!("OIDC service does not support token introspection");
    };
    let client =
        CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
            .set_introspection_url(IntrospectionUrl::from_url(introspection_url.clone()));

    let token = AccessToken::new(token.to_string());

    debug!("Starting introspection of access token at URL {introspection_url}");

    let introspection = client
        .introspect(&token)
        .request_async(&http_client)
        .await?;

    debug!("Introspection response: {introspection:?}");

    if !introspection.active() {
        debug!("Access token is not active");
        return Ok(None);
    }

    let Some(iss) = introspection.iss() else {
        warn!("Received an access token without an issuer");
        return Ok(None);
    };

    let Some(sub) = introspection.sub() else {
        warn!("Received an access token without a subject");
        return Ok(None);
    };

    let display_name = introspection.username().map(|s| s.to_string());

    let issuer = Issuer(iss.to_string());
    let sub = UserId(sub.to_string());

    Ok(Some(UserInfo {
        issuer,
        sub,
        display_name,
    }))
}

/// Subject identifier (`sub`) of an authenticated user.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserId(String);

impl std::fmt::Display for UserId {
    /// Writes the raw subject string, honouring width, fill and precision flags exactly as
    /// `str` does.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let UserId(subject) = self;
        formatter.pad(subject)
    }
}

/// Issuer (`iss`) of an access token, as reported by token introspection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Issuer(String);

/// Formats the issuer as its raw string, consistent with [`UserId`]'s `Display`.
///
/// Without this impl the private issuer string in [`UserInfo`] could only be observed through
/// `Debug`. Width, fill and precision flags are honoured like for `str`.
impl std::fmt::Display for Issuer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}

/// Identity of the caller, taken from an active, introspected access token.
///
/// Inserted into the request extensions by `OidcAuth` when authorization succeeds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
    /// Issuer (`iss`) reported by the introspection response.
    pub issuer: Issuer,
    /// Subject (`sub`) reported by the introspection response.
    pub sub: UserId,
    /// The `username` from the introspection response, if the provider supplied one.
    pub display_name: Option<String>,
}

/// Newtype around [`reqwest::Client`] that implements `openidconnect`'s [`AsyncHttpClient`],
/// used for the discovery and introspection requests in this module.
#[derive(Debug, Clone)]
struct ClientWrapper(reqwest::Client);

impl<'c> AsyncHttpClient<'c> for ClientWrapper {
    type Error = HttpClientError<reqwest::Error>;

    type Future =
        Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;

    fn call(&'c self, request: HttpRequest) -> Self::Future {
        Box::pin(async move {
            let response = self
                .0
                .execute(request.try_into().map_err(Box::new)?)
                .await
                .map_err(Box::new)?;

            let mut builder = http::Response::builder()
                .status(response.status())
                .version(response.version());

            for (name, value) in response.headers().iter() {
                builder = builder.header(name, value);
            }

            builder
                .body(response.bytes().await.map_err(Box::new)?.to_vec())
                .map_err(HttpClientError::Http)
        })
    }
}

impl From<reqwest::Client> for ClientWrapper {
    fn from(value: reqwest::Client) -> Self {
        Self(value)
    }
}

impl From<&reqwest::Client> for ClientWrapper {
    fn from(value: &reqwest::Client) -> Self {
        Self(value.clone())
    }
}

impl Deref for ClientWrapper {
    type Target = reqwest::Client;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}