sett 0.3.0

Rust port of sett (data compression, encryption and transfer tool).
Documentation
//! OIDC authentication and token management.

use anyhow::anyhow;
use openidconnect::core::CoreUserInfoClaims;
use openidconnect::OAuth2TokenResponse;
use serde::{Deserialize, Serialize};
use tracing::trace;

/// Additional provider metadata read during OpenID Connect discovery.
///
/// Captures the `device_authorization_endpoint` entry of the provider's discovery document,
/// which `Oidc::authenticate` needs to start the device authorization flow. The field is not
/// optional, so deserializing the discovery document fails if the provider does not advertise
/// this endpoint.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct DeviceEndpointProviderMetadata {
    device_authorization_endpoint: openidconnect::DeviceAuthorizationUrl,
}

// Marker impl so this struct can be plugged in as the additional-metadata type parameter of
// `openidconnect::ProviderMetadata` (see `DeviceProviderMetadata`).
impl openidconnect::AdditionalProviderMetadata for DeviceEndpointProviderMetadata {}

/// Provider metadata used for discovery in the device authorization flow.
///
/// The first type parameter is [`DeviceEndpointProviderMetadata`], so the device authorization
/// endpoint is captured from the discovery document. All remaining parameters are the
/// `openidconnect::core` types, which is what allows this metadata to be handed to
/// `openidconnect::core::CoreClient::from_provider_metadata`.
type DeviceProviderMetadata = openidconnect::ProviderMetadata<
    DeviceEndpointProviderMetadata,
    openidconnect::core::CoreAuthDisplay,
    openidconnect::core::CoreClientAuthMethod,
    openidconnect::core::CoreClaimName,
    openidconnect::core::CoreClaimType,
    openidconnect::core::CoreGrantType,
    openidconnect::core::CoreJweContentEncryptionAlgorithm,
    openidconnect::core::CoreJweKeyManagementAlgorithm,
    openidconnect::core::CoreJwsSigningAlgorithm,
    openidconnect::core::CoreJsonWebKeyType,
    openidconnect::core::CoreJsonWebKeyUse,
    openidconnect::core::CoreJsonWebKey,
    openidconnect::core::CoreResponseMode,
    openidconnect::core::CoreResponseType,
    openidconnect::core::CoreSubjectIdentifierType,
>;

/// OAuth2 access token.
///
/// An access token is a bearer credential. Its `Debug` output is therefore redacted so the
/// secret does not end up in logs, traces or error messages. This mirrors how the
/// `oauth2`/`openidconnect` secret types format themselves. The raw value stays accessible
/// inside the crate through the `pub(crate)` field.
#[derive(Clone, Serialize, Deserialize)]
pub struct AccessToken(pub(crate) String);

impl std::fmt::Debug for AccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("AccessToken([redacted])")
    }
}

/// OAuth2 refresh token.
///
/// The wrapped secret is a private field. Outside this module a value can only be obtained from
/// `Oidc::authenticate` / `Oidc::refresh` or through `Deserialize`, and it is consumed by
/// `Oidc::refresh`.
#[derive(Clone, Serialize, Deserialize)]
pub struct RefreshToken(String);

/// User info, as returned by `Oidc::user_info`.
///
/// Every field is optional because the identity provider may omit the corresponding claim.
#[derive(Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// Public username, e.g. "johndoe" (the `preferred_username` claim).
    pub username: Option<String>,
    /// Given name, e.g. "John" (non-localized value of the `given_name` claim).
    pub given_name: Option<String>,
    /// Family name, e.g. "Doe" (non-localized value of the `family_name` claim).
    pub family_name: Option<String>,
    /// Email address (the `email` claim).
    pub email: Option<String>,
}

/// OAuth2 tokens obtained from the identity provider.
///
/// The lifetime bookkeeping fields are private; use `Token::expires_in` to query how long the
/// access token remains valid.
#[derive(Clone)]
pub struct Token {
    /// Access token.
    pub access_token: AccessToken,
    /// Refresh token, if one is available.
    pub refresh_token: Option<RefreshToken>,
    /// Access token lifespan as reported by the identity provider (`None` if not reported).
    lifespan: Option<std::time::Duration>,
    /// Local time at which the token response was processed; the reference point for
    /// `lifespan`.
    created_at: std::time::Instant,
}

impl Token {
    /// Remaining time the access token is alive.
    ///
    /// Please note that this time is an approximation, which doesn't take into account network
    /// latency or other factors that add a delay between generating a token (by the identity
    /// provider) and its reception by the application.
    ///
    /// Returns `None` if the token lifespan is unknown (not provided by the identity provider).
    /// Returns `Some(Duration::ZERO)` once the token has expired.
    pub fn expires_in(&self) -> Option<std::time::Duration> {
        // `Duration - Duration` panics on underflow, which happened as soon as the elapsed time
        // exceeded the lifespan (i.e. for every expired token). Saturate at zero instead.
        self.lifespan
            .map(|lifespan| lifespan.saturating_sub(self.created_at.elapsed()))
    }
}

/// OpenID Connect client.
///
/// This client can be used to authenticate with an OpenID Connect provider using device
/// authorization and refresh tokens.
///
/// It only stores the client ID and the issuer URL. The provider metadata is discovered again
/// by each method that talks to the provider.
pub struct Oidc {
    client_id: openidconnect::ClientId,
    issuer_url: openidconnect::IssuerUrl,
}

/// Details for the authentication handler.
///
/// Taken from the provider's device authorization response and passed to the `auth_handler`
/// callback of `Oidc::authenticate`. Derives `Serialize`; the field names below are the
/// serialized keys.
#[derive(Clone, Serialize)]
pub struct VerificationDetails {
    /// User Code the user must enter at the verification URI.
    pub user_code: String,
    /// Verification URI.
    pub verification_uri: String,
    /// Verification URI with User Code as query parameter, if the provider supplies one.
    pub verification_uri_complete: Option<String>,
}

/// Callback invoked exactly once by `Oidc::authenticate` with the device verification details.
///
/// Returning `Err` aborts authentication before the token endpoint is polled. The handler must
/// be `Send`.
type AuthHandler = dyn FnOnce(&VerificationDetails) -> anyhow::Result<()> + Send;

impl Oidc {
    /// Create a new OpenID Connect client.
    pub fn new(client_id: &str, issuer_url: &str) -> anyhow::Result<Self> {
        Ok(Self {
            client_id: openidconnect::ClientId::new(client_id.into()),
            issuer_url: openidconnect::IssuerUrl::new(issuer_url.into())?,
        })
    }

    async fn client(&self) -> anyhow::Result<openidconnect::core::CoreClient> {
        let provider_metadata = openidconnect::core::CoreProviderMetadata::discover_async(
            self.issuer_url.clone(),
            http_client,
        )
        .await?;
        trace!(?provider_metadata, "provider metadata");
        Ok(openidconnect::core::CoreClient::from_provider_metadata(
            provider_metadata,
            self.client_id.clone(),
            None,
        ))
    }

    /// Get user info from the OpenID Connect provider.
    pub async fn user_info(&self, access_token: &AccessToken) -> anyhow::Result<UserInfo> {
        let client = self.client().await?;
        let userinfo: CoreUserInfoClaims = client
            .user_info(
                openidconnect::AccessToken::new(access_token.0.to_string()),
                None,
            )
            .map_err(|err| anyhow!("No user info endpoint: {:?}", err))?
            .request_async(http_client)
            .await
            .map_err(|err| anyhow!("Failed requesting user info: {:?}", err))?;
        trace!(?userinfo, "user info");
        macro_rules! localized_claim {
            ($input:expr) => {
                $input
                    .and_then(|claim| claim.get(None))
                    .map(|value| value.to_string())
            };
        }
        Ok(UserInfo {
            given_name: localized_claim!(userinfo.given_name()),
            family_name: localized_claim!(userinfo.family_name()),
            username: userinfo
                .preferred_username()
                .map(|username| username.to_string()),
            email: userinfo.email().map(|email| email.to_string()),
        })
    }

    /// Authenticate with the OpenID Connect provider using device authorization.
    ///
    /// The `auth_handler` is a callback for returning the verification URL that the user should
    /// open in their browser to authenticate.
    pub async fn authenticate(&self, auth_handler: Box<AuthHandler>) -> anyhow::Result<Token> {
        let provider_metadata =
            DeviceProviderMetadata::discover_async(self.issuer_url.clone(), http_client).await?;
        trace!(?provider_metadata, "provider metadata");
        let device_authorization_uri = provider_metadata
            .additional_metadata()
            .device_authorization_endpoint
            .clone();
        let client = openidconnect::core::CoreClient::from_provider_metadata(
            provider_metadata,
            self.client_id.clone(),
            None,
        )
        .set_device_authorization_uri(device_authorization_uri)
        .set_auth_type(openidconnect::AuthType::RequestBody);
        let (pkce_challenge, pkce_verifier) = openidconnect::PkceCodeChallenge::new_random_sha256();

        let details: openidconnect::core::CoreDeviceAuthorizationResponse = client
            .exchange_device_code()?
            .add_extra_param("code_challenge", pkce_challenge.as_str())
            .add_extra_param("code_challenge_method", pkce_challenge.method().as_str())
            .request_async(http_client)
            .await?;
        auth_handler(&VerificationDetails {
            user_code: details.user_code().secret().to_string(),
            verification_uri: details.verification_uri().to_string(),
            verification_uri_complete: details
                .verification_uri_complete()
                .map(|uri| uri.secret().to_string()),
        })?;
        let token = client
            .exchange_device_access_token(&details)
            .add_extra_param("code_verifier", pkce_verifier.secret())
            .request_async(http_client, tokio::time::sleep, None)
            .await?;
        Ok(Token {
            access_token: AccessToken(token.access_token().secret().to_string()),
            refresh_token: token
                .refresh_token()
                .map(|t| RefreshToken(t.secret().to_string())),
            lifespan: token.expires_in(),
            created_at: std::time::Instant::now(),
        })
    }

    /// Refresh the access token using a refresh token.
    pub async fn refresh(&self, refresh_token: &RefreshToken) -> anyhow::Result<Token> {
        let client = self.client().await?;
        let token = client
            .exchange_refresh_token(&openidconnect::RefreshToken::new(refresh_token.0.clone()))
            .request_async(http_client)
            .await?;
        Ok(Token {
            access_token: AccessToken(token.access_token().secret().to_string()),
            refresh_token: token
                .refresh_token()
                .map(|t| RefreshToken(t.secret().to_string())),
            lifespan: token.expires_in(),
            created_at: std::time::Instant::now(),
        })
    }
}

/// HTTP client for OpenID Connect requests.
///
/// Bridges `openidconnect`'s request/response types to `reqwest`. A new `reqwest::Client` using
/// rustls is built for every call, and HTTP redirects are never followed. The method, status
/// code and headers are converted between the `http` types re-exported by `openidconnect` and
/// those used by `reqwest` through their string/byte representations.
///
/// # Errors
///
/// Returns the underlying `reqwest::Error` in any of these cases:
/// - building the client or the request fails;
/// - sending the request fails;
/// - reading the response body fails.
///
/// # Panics
///
/// Panics if a method, status code, header name or header value is rejected by the other
/// side's type during conversion.
async fn http_client(
    request: openidconnect::HttpRequest,
) -> Result<openidconnect::HttpResponse, reqwest::Error> {
    let client = reqwest::Client::builder()
        .use_rustls_tls()
        // Prevent SSRF (Server-side request forgery) vulnerabilities.
        .redirect(reqwest::redirect::Policy::none())
        .build()?;

    let method =
        reqwest::Method::from_bytes(request.method.as_str().as_bytes()).expect("valid method");
    let base = client
        .request(method, request.url.as_str())
        .body(request.body);
    // Copy every outgoing header (duplicates included) onto the reqwest request.
    let outgoing = request
        .headers
        .iter()
        .fold(base, |builder, (name, value)| {
            builder.header(name.as_str(), value.as_bytes())
        })
        .build()?;

    let response = client.execute(outgoing).await?;

    let status_code = openidconnect::http::StatusCode::from_u16(response.status().as_u16())
        .expect("valid status code");
    let headers: openidconnect::http::HeaderMap = response
        .headers()
        .iter()
        .map(|(name, value)| {
            let converted_name =
                openidconnect::http::HeaderName::from_bytes(name.as_str().as_bytes())
                    .expect("valid header name");
            let converted_value = openidconnect::http::HeaderValue::from_bytes(value.as_bytes())
                .expect("valid header value");
            (converted_name, converted_value)
        })
        .collect();
    let body = response.bytes().await?.to_vec();

    Ok(openidconnect::HttpResponse {
        status_code,
        headers,
        body,
    })
}