cipherstash-client 0.34.1-alpha.3

The official CipherStash SDK
Documentation
use serde::Deserialize;
use serde_json::json;
use std::time::Duration;
use tracing::{debug, trace};
use url::Url;

use crate::credentials::user_credentials::{
    prompt_user, AccessTokenResponse, NewTokenError, PollingInfo, DEFAULT_REQUESTED_SCOPES,
};

use super::{RefreshTokenError, UserToken};

/// User credentials backed by an Auth0 identity provider (IDP).
///
/// New tokens are obtained with the OAuth 2.0 device authorization flow
/// (`oauth/device/code` followed by polling `oauth/token`), and existing tokens are
/// renewed with the `refresh_token` grant against `oauth/token`.
#[derive(Debug)]
pub struct Auth0UserCredentials {
    /// Base URL of the IDP; endpoint paths such as `oauth/token` are joined onto it.
    idp_base_url: Url,
    /// API audience sent when requesting a device code.
    idp_audience: String,
    /// OAuth client id sent with every request to the IDP.
    idp_client_id: String,
}

impl Auth0UserCredentials {
    /// Create credentials for the given Auth0 tenant and client.
    ///
    /// * `idp_base_url` - base URL of the identity provider. The `oauth/device/code` and
    ///   `oauth/token` paths are resolved against it with [`Url::join`]. `join` replaces the
    ///   last path segment unless the base URL ends with `/`.
    /// * `idp_audience` - API audience requested when starting the device flow.
    /// * `idp_client_id` - OAuth client id sent with every request to the identity provider.
    pub fn new(idp_base_url: &Url, idp_audience: &str, idp_client_id: &str) -> Self {
        Self {
            idp_base_url: idp_base_url.to_owned(),
            idp_audience: idp_audience.to_string(),
            idp_client_id: idp_client_id.to_string(),
        }
    }

    /// Poll the IDP token endpoint until it issues an access token or reports a terminal error.
    ///
    /// The first request is sent immediately. Later attempts wait `interval` between polls.
    /// The interval starts at 5 seconds and grows by 5 seconds each time the IDP answers
    /// `slow_down`, as RFC 8628 §3.5 requires.
    ///
    /// # Errors
    ///
    /// * [`NewTokenError::PollTokenRequestFailed`] if a request cannot be sent.
    /// * [`NewTokenError::PollTokenBadResponse`] if a 200 response body cannot be parsed.
    /// * [`NewTokenError::PollTokenBadPendingResponse`] if an error response body cannot be
    ///   parsed (including unrecognised `error` codes).
    /// * [`NewTokenError::PollTokenAuthFailed`] if the IDP reports `invalid_grant`,
    ///   `access_denied` or `expired_token`.
    /// * [`NewTokenError::PollTokenUnexpected`] for any other HTTP status code.
    ///
    /// # Panics
    ///
    /// Panics if `oauth/token` cannot be resolved against `idp_base_url`.
    async fn poll_access_token(
        &self,
        client: &reqwest::Client,
        polling_info: &PollingInfo,
    ) -> Result<UserToken, NewTokenError> {
        debug!(target: "console_credentials", "Logging in - polling for access token");
        let mut interval = Duration::from_secs(5);
        let url = self.idp_base_url.join("oauth/token").expect("Invalid url");

        // The request body is identical for every attempt, so build it once.
        let body = json!({
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
            "device_code": polling_info.device_code,
            "client_id": self.idp_client_id
        });

        loop {
            let response = client
                .post(url.clone())
                .json(&body)
                .send()
                .await
                .map_err(NewTokenError::PollTokenRequestFailed)?;

            match response.status().as_u16() {
                200 => {
                    let token: AccessTokenResponse = response.json().await.map_err(|e| {
                        trace!(target: "console_credentials",
                            "Failed to parse token access response: {e:?}"
                        );
                        NewTokenError::PollTokenBadResponse(e)
                    })?;

                    debug!(target: "console_credentials",
                        "Access Token Acquired - expires_in(s): {}",
                        &token.expires_in
                    );
                    return Ok(token.into());
                }
                // Auth0 documents 403 for `authorization_pending`, `expired_token`,
                // `access_denied` and `invalid_grant`, but 429 for `slow_down`.
                // RFC 8628 (via RFC 6749 §5.2) uses 400 for all of them. Accepting every
                // variant lets a `slow_down` back off instead of aborting the login.
                400 | 403 | 429 => {
                    #[derive(Deserialize, Debug)]
                    #[serde(rename_all = "snake_case")]
                    enum AuthError {
                        AuthorizationPending,
                        SlowDown,
                        InvalidGrant,
                        AccessDenied,
                        ExpiredToken,
                    }

                    #[derive(Deserialize)]
                    struct PendingResponse {
                        error: AuthError,
                        error_description: Option<String>,
                    }

                    let PendingResponse {
                        error,
                        error_description,
                    } = response
                        .json()
                        .await
                        .map_err(NewTokenError::PollTokenBadPendingResponse)?;

                    match error {
                        // The user has not finished authorising yet; keep polling.
                        AuthError::AuthorizationPending => {}
                        AuthError::SlowDown => {
                            interval += Duration::from_secs(5);
                        }
                        _ => {
                            let reason =
                                error_description.unwrap_or_else(|| format!("{error:?}"));
                            return Err(NewTokenError::PollTokenAuthFailed(reason));
                        }
                    }

                    crate::sleep::sleep(interval).await
                }
                code => {
                    return Err(NewTokenError::PollTokenUnexpected(format!(
                        "Unexpected response code: {code}"
                    )));
                }
            }
        }
    }

    /// Log the user in with the OAuth 2.0 device authorization flow.
    ///
    /// Requests a device code from `oauth/device/code` with the configured audience, client
    /// id and `DEFAULT_REQUESTED_SCOPES`. It then passes the returned `PollingInfo` to
    /// `prompt_user`, which presumably displays the verification URL and code; confirm in
    /// `user_credentials`. Finally it polls `oauth/token` until a token is issued.
    ///
    /// # Errors
    ///
    /// * [`NewTokenError::DeviceCodeRequestFailed`] if the device-code request cannot be sent
    ///   or returns an HTTP error status.
    /// * [`NewTokenError::DeviceCodeBadResponse`] if the device-code response cannot be parsed.
    /// * Any error returned while polling for the access token.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint paths cannot be resolved against `idp_base_url`.
    pub async fn acquire_new_token(&self) -> Result<UserToken, NewTokenError> {
        debug!(target: "console_credentials", "Logging in...");
        let client = reqwest::Client::new();
        let url = self
            .idp_base_url
            .join("oauth/device/code")
            .expect("Invalid url");

        let info_response = client
            .post(url)
            .json(&json!({
                "audience": &self.idp_audience,
                "client_id": &self.idp_client_id,
                "scope": DEFAULT_REQUESTED_SCOPES
            }))
            .send()
            .await
            .map_err(NewTokenError::DeviceCodeRequestFailed)?
            .error_for_status()
            .map_err(NewTokenError::DeviceCodeRequestFailed)?;

        let polling_info: PollingInfo = info_response
            .json()
            .await
            .map_err(NewTokenError::DeviceCodeBadResponse)?;

        prompt_user(&polling_info);

        self.poll_access_token(&client, &polling_info).await
    }

    /// Exchange the refresh token held in `cached_token` for a fresh access token.
    ///
    /// Returns `Ok(Some(token))` on success. Returns `Ok(None)` when the IDP answers with an
    /// HTTP error status; the rejection is logged at debug level rather than returned.
    ///
    /// # Errors
    ///
    /// * [`RefreshTokenError::RequestFailed`] if the request cannot be sent.
    /// * [`RefreshTokenError::BadResponse`] if a successful response cannot be parsed.
    ///
    /// # Panics
    ///
    /// Panics if `oauth/token` cannot be resolved against `idp_base_url`.
    pub async fn refresh_access_token(
        &self,
        cached_token: &UserToken,
    ) -> Result<Option<UserToken>, RefreshTokenError> {
        debug!(target: "console_credentials", "Refreshing Access Token...");
        let client = reqwest::Client::new();
        let url = self.idp_base_url.join("oauth/token").expect("Invalid url");

        let response = client
            .post(url)
            .json(&json!({
                "grant_type": "refresh_token",
                "refresh_token": cached_token.refresh_token,
                "client_id": self.idp_client_id,
                "scope": DEFAULT_REQUESTED_SCOPES
            }))
            .send()
            .await
            .map_err(RefreshTokenError::RequestFailed)?;

        match response.error_for_status() {
            Ok(r) => {
                let response: AccessTokenResponse =
                    r.json().await.map_err(RefreshTokenError::BadResponse)?;
                debug!(target: "console_credentials",
                    "Access Token Acquired - expires_in(s): {}",
                    &response.expires_in
                );

                Ok(Some(response.into()))
            }
            Err(e) => {
                // A rejected refresh is not an error for the caller, but record why it failed.
                debug!(target: "console_credentials", "Access Token refresh rejected: {e}");
                Ok(None)
            }
        }
    }
}