sett 0.4.0

Rust port of sett (data compression, encryption and transfer tool).
Documentation
//! OIDC authentication and token management.

pub mod error;

use openidconnect::OAuth2TokenResponse;
use openidconnect::{AsyncHttpClient, core::CoreUserInfoClaims};
use serde::{Deserialize, Serialize};
use tracing::trace;

/// Additional provider metadata required for the device authorization grant.
///
/// This is plugged into [`openidconnect::ProviderMetadata`] so that discovery also captures the
/// provider's device authorization endpoint. The field name doubles as the serde key, so it
/// must stay `device_authorization_endpoint`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DeviceEndpointProviderMetadata {
    /// URL of the provider's device authorization endpoint.
    device_authorization_endpoint: openidconnect::DeviceAuthorizationUrl,
}

/// Marker impl allowing this type to be used as additional provider metadata.
impl openidconnect::AdditionalProviderMetadata for DeviceEndpointProviderMetadata {}

/// Provider metadata obtained through OpenID Connect discovery.
///
/// The first type parameter is [`DeviceEndpointProviderMetadata`], so the discovered metadata
/// also carries the device authorization endpoint (read back in `Oidc::authenticate`). The
/// remaining parameters are the standard `openidconnect::core` types.
type DeviceProviderMetadata = openidconnect::ProviderMetadata<
    DeviceEndpointProviderMetadata,
    openidconnect::core::CoreAuthDisplay,
    openidconnect::core::CoreClientAuthMethod,
    openidconnect::core::CoreClaimName,
    openidconnect::core::CoreClaimType,
    openidconnect::core::CoreGrantType,
    openidconnect::core::CoreJweContentEncryptionAlgorithm,
    openidconnect::core::CoreJweKeyManagementAlgorithm,
    openidconnect::core::CoreJsonWebKey,
    openidconnect::core::CoreResponseMode,
    openidconnect::core::CoreResponseType,
    openidconnect::core::CoreSubjectIdentifierType,
>;

/// Client type stored in [`Oidc`], as built by `CoreClient::from_provider_metadata` in
/// `Oidc::new`.
///
/// The type-state parameters record which endpoints are configured.
/// NOTE(review): presumed order is auth, device authorization, introspection, revocation,
/// token, user info. Confirm against the `openidconnect::core::CoreClient` docs.
/// `Oidc::authenticate` calls `set_device_authorization_url` on a clone of this client.
type OidcClient = openidconnect::core::CoreClient<
    openidconnect::EndpointSet,
    openidconnect::EndpointNotSet,
    openidconnect::EndpointNotSet,
    openidconnect::EndpointNotSet,
    openidconnect::EndpointMaybeSet,
    openidconnect::EndpointMaybeSet,
>;

/// OAuth2 access token.
///
/// `Debug` output is redacted so the secret does not leak into logs. The raw value is
/// available through `as_ref()`. Note that the derived `Serialize` writes the token in plain
/// text.
#[derive(Clone, Serialize, Deserialize)]
pub struct AccessToken(pub(crate) String);

impl std::fmt::Debug for AccessToken {
    /// Formats as `AccessToken("********")`, never revealing the token itself.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = formatter.debug_tuple("AccessToken");
        tuple.field(&"********");
        tuple.finish()
    }
}

impl AsRef<str> for AccessToken {
    /// Returns the raw (unredacted) access token.
    fn as_ref(&self) -> &str {
        &self.0
    }
}

/// OAuth2 refresh token.
///
/// The inner value is private and `Debug` output is redacted. The derived `Serialize`
/// writes the token in plain text, so handle serialized forms as secrets.
#[derive(Clone, Serialize, Deserialize)]
pub struct RefreshToken(String);

impl std::fmt::Debug for RefreshToken {
    /// Formats as `RefreshToken("********")`, never revealing the token itself.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = formatter.debug_tuple("RefreshToken");
        tuple.field(&"********");
        tuple.finish()
    }
}

/// User info.
///
/// A subset of the standard claims returned by the provider's user info endpoint. Every
/// field is optional because the provider may omit any claim.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct UserInfo {
    /// Public username (the `preferred_username` claim), e.g. "johndoe"
    pub username: Option<String>,
    /// Given name, e.g. "John"
    pub given_name: Option<String>,
    /// Family name, e.g. "Doe"
    pub family_name: Option<String>,
    /// Email address
    pub email: Option<String>,
}

/// OAuth2 tokens.
///
/// Returned by [`Oidc::authenticate`] and [`Oidc::refresh`].
#[derive(Debug, Clone)]
pub struct Token {
    /// Access token.
    pub access_token: AccessToken,
    /// Refresh token, if available.
    pub refresh_token: Option<RefreshToken>,
    /// Access token lifespan as reported by the identity provider (`None` if not reported).
    lifespan: Option<std::time::Duration>,
    /// Local instant at which the token response was processed; the reference point for
    /// [`Token::expires_in`].
    created_at: std::time::Instant,
}

impl Token {
    /// Remaining time the access token is alive.
    ///
    /// Please note that this time is an approximation, which doesn't take into account network
    /// latency or other factors that add a delay between generating a token (by the identity
    /// provider) and its reception by the application.
    ///
    /// Returns `None` if the token lifespan is unknown (not provided by the identity provider),
    /// and `Some(Duration::ZERO)` once the lifespan has fully elapsed.
    pub fn expires_in(&self) -> Option<std::time::Duration> {
        // `saturating_sub` clamps to zero for an expired token; plain `Duration` subtraction
        // panics on underflow.
        self.lifespan
            .map(|lifespan| lifespan.saturating_sub(self.created_at.elapsed()))
    }
}

/// OpenID Connect client.
///
/// Authenticates against an OpenID Connect provider using the device authorization grant
/// and renews access tokens with refresh tokens.
#[derive(Clone, Debug)]
pub struct Oidc {
    /// HTTP transport used for every request to the provider.
    http_client: HttpClient,
    /// Client configured from the discovered provider metadata.
    oidc_client: OidcClient,
    /// Discovered provider metadata; kept to look up the device authorization endpoint.
    provider_metadata: DeviceProviderMetadata,
}

/// Issuer identifier.
///
/// URL using the https scheme that contains scheme, host, and optionally,
/// port number and path components and no query or fragment components.
#[derive(Debug, Clone)]
pub struct IssuerUrl {
    inner: openidconnect::IssuerUrl,
}

impl std::str::FromStr for IssuerUrl {
    type Err = error::UrlParseError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self {
            inner: openidconnect::IssuerUrl::from_url(
                s.parse()
                    .map_err(|e| error::UrlParseError(format!("{e}")))?,
            ),
        })
    }
}

/// Details passed to the authentication handler of [`Oidc::authenticate`].
///
/// The handler is expected to present these to the user so they can complete the device
/// authorization in a browser.
#[derive(Clone, Debug, Serialize)]
pub struct VerificationDetails {
    /// Code the user must enter on the verification page.
    pub user_code: String,
    /// URI of the verification page.
    pub verification_uri: String,
    /// Verification URI that already embeds the user code as a query parameter, if the
    /// provider supplied one.
    pub verification_uri_complete: Option<String>,
}

impl Oidc {
    /// Create a new OpenID Connect client.
    pub async fn new(client_id: &str, issuer_url: IssuerUrl) -> Result<Self, error::ConfigError> {
        let http_client = HttpClient::new()?;
        let provider_metadata =
            DeviceProviderMetadata::discover_async(issuer_url.inner, &http_client).await?;
        trace!(?provider_metadata, "provider metadata");
        let oidc_client = openidconnect::core::CoreClient::from_provider_metadata(
            provider_metadata.clone(),
            openidconnect::ClientId::new(client_id.into()),
            None,
        );
        Ok(Self {
            http_client,
            oidc_client,
            provider_metadata,
        })
    }

    /// Get user info from the OpenID Connect provider.
    pub async fn user_info(
        &self,
        access_token: &AccessToken,
    ) -> Result<UserInfo, error::UserInfoError> {
        let userinfo: CoreUserInfoClaims = self
            .oidc_client
            .user_info(
                openidconnect::AccessToken::new(access_token.0.to_string()),
                None,
            )?
            .request_async(&self.http_client)
            .await?;
        trace!(?userinfo, "user info");

        fn localized_claim<T: core::ops::Deref<Target = String>>(
            name: Option<&openidconnect::LocalizedClaim<T>>,
        ) -> Option<String> {
            name.and_then(|claim| claim.get(None))
                .map(|value| value.deref().clone())
        }
        Ok(UserInfo {
            given_name: localized_claim(userinfo.given_name()),
            family_name: localized_claim(userinfo.family_name()),
            username: userinfo
                .preferred_username()
                .map(|username| username.to_string()),
            email: userinfo.email().map(|email| email.to_string()),
        })
    }

    /// Authenticate with the OpenID Connect provider using device authorization.
    ///
    /// The `auth_handler` is a callback for returning the verification URL that the user should
    /// open in their browser to authenticate.
    pub async fn authenticate<E>(
        &self,
        auth_handler: impl FnOnce(VerificationDetails) -> Result<(), E>,
    ) -> Result<Token, error::AuthError<E>> {
        let client = self
            .oidc_client
            .clone()
            .set_device_authorization_url(
                self.provider_metadata
                    .additional_metadata()
                    .device_authorization_endpoint
                    .clone(),
            )
            .set_auth_type(openidconnect::AuthType::RequestBody);
        let (pkce_challenge, pkce_verifier) = openidconnect::PkceCodeChallenge::new_random_sha256();

        let details: openidconnect::core::CoreDeviceAuthorizationResponse = client
            .exchange_device_code()
            .add_extra_param("code_challenge", pkce_challenge.as_str())
            .add_extra_param("code_challenge_method", pkce_challenge.method().as_str())
            .request_async(&self.http_client)
            .await?;
        auth_handler(VerificationDetails {
            user_code: details.user_code().secret().to_string(),
            verification_uri: details.verification_uri().to_string(),
            verification_uri_complete: details
                .verification_uri_complete()
                .map(|uri| uri.secret().to_string()),
        })
        .map_err(error::AuthError::Handler)?;
        let token = client
            .exchange_device_access_token(&details)?
            .add_extra_param("code_verifier", pkce_verifier.secret())
            .request_async(&self.http_client, tokio::time::sleep, None)
            .await?;
        Ok(Token {
            access_token: AccessToken(token.access_token().secret().to_string()),
            refresh_token: token
                .refresh_token()
                .map(|t| RefreshToken(t.secret().to_string())),
            lifespan: token.expires_in(),
            created_at: std::time::Instant::now(),
        })
    }

    /// Refresh the access token using a refresh token.
    pub async fn refresh(
        &self,
        refresh_token: &RefreshToken,
    ) -> Result<Token, error::RefreshTokenError> {
        let token = self
            .oidc_client
            .exchange_refresh_token(&openidconnect::RefreshToken::new(refresh_token.0.clone()))?
            .request_async(&self.http_client)
            .await?;
        Ok(Token {
            access_token: AccessToken(token.access_token().secret().to_string()),
            refresh_token: token
                .refresh_token()
                .map(|t| RefreshToken(t.secret().to_string())),
            lifespan: token.expires_in(),
            created_at: std::time::Instant::now(),
        })
    }
}

/// HTTP transport used for all requests to the OpenID Connect provider.
///
/// Wraps a `reqwest` client configured with rustls and with redirects disabled.
#[derive(Clone, Debug)]
struct HttpClient {
    client: reqwest::Client,
}

impl HttpClient {
    /// Builds the underlying `reqwest` client.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest::Error` produced if the client cannot be built.
    fn new() -> Result<Self, reqwest::Error> {
        // Never follow redirects, to prevent SSRF (server-side request forgery) vulnerabilities.
        let no_redirects = reqwest::redirect::Policy::none();
        reqwest::Client::builder()
            .use_rustls_tls()
            .redirect(no_redirects)
            .build()
            .map(|client| Self { client })
    }
}

/// Adapter letting `openidconnect` issue its requests through [`HttpClient`].
impl<'c> AsyncHttpClient<'c> for HttpClient {
    type Error = reqwest::Error;
    type Future = std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<openidconnect::HttpResponse, Self::Error>>
                + Send
                + 'c,
        >,
    >;

    /// Sends `request` with `reqwest` and converts the reply back into an
    /// `openidconnect::HttpResponse`, preserving HTTP version, status, headers and body.
    fn call(&'c self, request: openidconnect::HttpRequest) -> Self::Future {
        Box::pin(async move {
            let outgoing: reqwest::Request = request.try_into()?;
            let incoming = self.client.execute(outgoing).await?;
            // Capture the metadata first: reading the body consumes `incoming`.
            let (version, status, headers) = (
                incoming.version(),
                incoming.status(),
                incoming.headers().clone(),
            );
            let body = incoming.bytes().await?.to_vec();
            let mut converted = openidconnect::HttpResponse::new(body);
            *converted.version_mut() = version;
            *converted.status_mut() = status;
            *converted.headers_mut() = headers;
            Ok(converted)
        })
    }
}