// aws_sso_flow 0.3.1
//
// AWS SSO authentication flow
// Documentation
use std::{convert::Infallible, fmt, path::PathBuf};

use aws_config::SdkConfig;
use url::Url;

use crate::{
    cache::{self, Cache},
    sso::{self, GetRoleCredentialsRequest},
    sso_oidc::{self, CreateTokenError, CreateTokenRequest, RegisterClientRequest},
    SessionCredentials, SsoConfig, SsoFlowBuilder, SsoProfileError, CLIENT_NAME,
};

/// A configured AWS SSO authentication flow.
///
/// A default flow can be constructed with [`new`](Self::new).
///
/// The type parameter `V` is the verification prompt that is handed the URL the user must visit
/// to grant access.
// NOTE(review): presumably allowed because the type name repeats the name of its module — confirm.
#[allow(clippy::module_name_repetitions)]
pub struct SsoFlow<V> {
    /// Cache consulted by `authenticate` for the registered client, token and credentials.
    cache: Cache,
    /// Client used to register the SSO client and to create the token.
    sso_oidc_client: sso_oidc::Client,
    /// Client used to fetch role credentials.
    sso_client: sso::Client,
    /// SSO configuration: supplies the region, start URL, account ID and role name.
    config: SsoConfig,
    /// Prompt cloned and handed to token creation on each `authenticate` call.
    verification_prompt: V,
}

impl SsoFlow<Infallible> {
    /// Construct a builder for an SSO flow.
    ///
    /// Use this to override the default flow configuration.
    #[must_use]
    pub fn builder() -> SsoFlowBuilder {
        SsoFlowBuilder::default()
    }
}

impl<V> SsoFlow<V>
where
    V: VerificationPrompt,
{
    /// Create a default AWS SSO authentication flow that uses the given `verification_prompt`.
    ///
    /// The prompt is expected to send the user to the URL it is given, where they will be asked
    /// to grant access. Intermediate tokens are cached, so the user may not be prompted at all
    /// while the cache is still valid. Attempts to obtain credentials will time out if the user
    /// hasn't followed the verification URL and granted access before the user code expires
    /// (default 10 mins).
    ///
    /// SSO configuration is sourced from AWS shared config (located with `AWS_CONFIG_FILE` and
    /// `AWS_PROFILE`) and intermediate tokens are cached in the user's OS cache directory in
    /// `aws_sso_flow/0.1/*`. The cache format is considered part of the crate's API.
    ///
    /// This is shorthand for a default [`SsoFlowBuilder`] with only the verification prompt set;
    /// use the builder directly for more flexible configuration.
    ///
    /// # Example
    ///
    /// A simple prompt implementation could print the URL and ask the user to visit it:
    ///
    /// ```no_run
    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::convert::Infallible;
    ///
    /// let flow = aws_sso_flow::SsoFlow::new(|url| async move {
    ///     println!("Go to {url} to sign in with SSO");
    ///     Ok::<_, Infallible>(())
    /// }).await?;
    ///
    /// let credentials = flow.authenticate().await?;
    /// # Ok(()) }
    /// ```
    ///
    /// # Errors
    ///
    /// An error is returned if a profile cannot be loaded, or if it is missing SSO configuration.
    pub async fn new(verification_prompt: V) -> Result<Self, SsoProfileError> {
        let builder = SsoFlowBuilder::default().verification_prompt(verification_prompt);
        builder.build().await
    }

    /// Assemble a flow from an already-resolved `config`.
    ///
    /// Both AWS clients are created from an SDK config that has only the region from `config`
    /// set, and the cache is created from `cache_dir` and `config`.
    pub(crate) fn build(
        cache_dir: Option<PathBuf>,
        config: SsoConfig,
        verification_prompt: V,
    ) -> Self {
        let region = config.region.0.clone();
        let sdk_config = SdkConfig::builder().region(region).build();

        // The cache only borrows `config`, so it is created before `config` is moved into the
        // flow below.
        let cache = Cache::new(cache_dir, &config);
        let sso_oidc_client = sso_oidc::Client::new(&sdk_config);
        let sso_client = sso::Client::new(&sdk_config);

        Self {
            cache,
            sso_oidc_client,
            sso_client,
            config,
            verification_prompt,
        }
    }

    /// Perform the AWS SSO authentication flow and obtain credentials.
    ///
    /// # Errors
    ///
    /// An errors is returned if the authentication flow cannot complete. See [`SsoFlowError`] for
    /// details of possible errors.
    pub async fn authenticate(&self) -> Result<SessionCredentials, SsoFlowError<V::Error>> {
        let client = self
            .cache
            .get_or_init("client", || {
                self.sso_oidc_client.register_client(RegisterClientRequest {
                    client_name: CLIENT_NAME.to_string(),
                })
            })
            .await
            .map_err(|error| match error {
                cache::Error::Init(error) => SsoFlowError::Api(SsoApiError(error)),
                cache::Error::Cache(error) => SsoFlowError::Cache(SsoCacheError(error)),
            })?;

        let token = self
            .cache
            .get_or_init("token", || {
                self.sso_oidc_client.create_token(
                    CreateTokenRequest {
                        client_id: client.client_id,
                        client_secret: client.client_secret,
                        start_url: self.config.start_url.clone(),
                    },
                    self.verification_prompt.clone(),
                )
            })
            .await
            .map_err(|error| match error {
                cache::Error::Init(CreateTokenError::Api(error)) => {
                    SsoFlowError::Api(SsoApiError(error))
                }
                cache::Error::Init(CreateTokenError::VerificationPrompt(error)) => {
                    SsoFlowError::VerificationPrompt(error)
                }
                cache::Error::Init(CreateTokenError::VerificationPromptTimeout) => {
                    SsoFlowError::VerificationPromptTimeout
                }
                cache::Error::Cache(error) => SsoFlowError::Cache(SsoCacheError(error)),
            })?;

        let credentials = self
            .cache
            .get_or_init("credentials", || {
                self.sso_client
                    .get_role_credentials(GetRoleCredentialsRequest {
                        access_token: token.access_token,
                        account_id: self.config.account_id.clone(),
                        role_name: self.config.role_name.clone(),
                    })
            })
            .await
            .map_err(|error| match error {
                cache::Error::Init(error) => SsoFlowError::Api(SsoApiError(error)),
                cache::Error::Cache(error) => SsoFlowError::Cache(SsoCacheError(error)),
            })?;

        Ok(credentials.into())
    }
}

impl<V> fmt::Debug for SsoFlow<V> {
    /// Formats every field except the verification prompt, which is rendered as the placeholder
    /// string `"_"` since `V` is not required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SsoFlow");
        out.field("cache", &self.cache);
        out.field("sso_oidc_client", &self.sso_oidc_client);
        out.field("sso_client", &self.sso_client);
        out.field("config", &self.config);
        out.field("verification_prompt", &"_");
        out.finish()
    }
}

/// An SSO verification prompt.
///
/// The AWS SSO authentication flow requires users to explicitly grant access by visiting a URL and
/// clicking a button. There are many ways this could be implemented depending on the context, so
/// verification prompts are modelled with this trait.
///
/// The trait is implemented for cloneable, thread-safe async functions and closures that take a
/// single `Url` argument and return `Result<(), E>`, so a trivial prompt could look like:
///
/// ```
/// use std::convert::Infallible;
///
/// use aws_sso_flow::VerificationPrompt;
///
/// fn prompt() -> impl VerificationPrompt {
///     |verification_url| async move {
///         println!("Go to {verification_url} to grant access");
///         Ok::<_, Infallible>(())
///     }
/// }
/// ```
///
/// The `Error` associated type can be used if the prompt is fallible. Type information is preserved
/// in the event of any subsequent [`SsoFlowError`].
///
/// Prompts must be `Clone` because [`prompt`](Self::prompt) consumes the prompt, and the flow
/// passes a fresh clone to token creation each time it authenticates.
pub trait VerificationPrompt: Clone + Send + Sync {
    /// The future returned by the prompt.
    ///
    /// It resolves to `Ok(())` on success, or to `Err` with [`Self::Error`] on failure.
    type Future: std::future::Future<Output = Result<(), Self::Error>> + Send;

    /// An error that could occur when attempting to prompt.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Prompt the user to grant access via the given `verification_url`.
    ///
    /// This consumes the prompt.
    fn prompt(self, verification_url: Url) -> Self::Future;
}

/// Lets any cloneable, thread-safe `FnOnce(Url) -> Fut` act as a verification prompt, where `Fut`
/// is a `Send` future resolving to `Result<(), E>`. Prompting simply calls the function.
impl<F, Fut, E> VerificationPrompt for F
where
    E: std::error::Error + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<(), E>> + Send,
    F: Clone + Send + Sync + FnOnce(Url) -> Fut,
{
    type Error = E;
    type Future = Fut;

    fn prompt(self, verification_url: Url) -> Self::Future {
        self(verification_url)
    }
}

/// An error that occurred during the SSO authentication flow.
///
/// The type parameter `P` is the error type of the verification prompt in use.
#[derive(Debug)]
pub enum SsoFlowError<P: std::error::Error + Send + Sync + 'static> {
    /// Indicates that an AWS API call returned an error.
    ///
    /// This could be due to invalid configuration caught by the server, or a network issue. The
    /// error message should be sufficient to aid end-user debugging.
    Api(SsoApiError),

    /// Indicates an issue with the token cache(s).
    ///
    /// This could be due to insufficient permissions, corrupt data, or an unusual OS
    /// configuration. The error message should be sufficient to aid end-user debugging.
    Cache(SsoCacheError),

    /// Indicates that an error occurred during the verification prompt.
    ///
    /// The payload is the prompt's own error, passed through unchanged. See
    /// [`VerificationPrompt`] for more information.
    VerificationPrompt(P),

    /// Indicates that the verification prompt timed out.
    VerificationPromptTimeout,
}

impl<P> fmt::Display for SsoFlowError<P>
where
    P: std::error::Error + Send + Sync + 'static,
{
    /// Writes a user-facing message. Every variant that wraps an underlying error includes that
    /// error's own `Display` output after a fixed prefix.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Api(cause) => write!(f, "SSO authentication failed due to: {cause}"),
            Self::Cache(cause) => write!(f, "SSO authentication failed due to: {cause}"),
            Self::VerificationPrompt(cause) => {
                write!(f, "SSO authentication failed during verification: {cause}")
            }
            Self::VerificationPromptTimeout => {
                f.write_str("SSO authentication failed: timed out waiting for verification")
            }
        }
    }
}

impl<P: std::error::Error + Send + Sync + 'static> std::error::Error for SsoFlowError<P> {}

/// An API error that occurred during authentication.
///
/// This could be due to invalid configuration caught by the server, or a network issue. The error
/// message should be sufficient to aid end-user debugging.
///
/// The wrapped string is the underlying error message; it is included verbatim in the `Display`
/// output.
#[derive(Debug)]
pub struct SsoApiError(String);

impl fmt::Display for SsoApiError {
    /// Writes the wrapped message after a fixed prefix identifying it as an API error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "API error when attempting authentication: {}", self.0)
    }
}

/// Allows the error to be boxed as `dyn Error`, propagated with `?`, or reported as a `source`.
/// There is no underlying source error since only the message is retained.
impl std::error::Error for SsoApiError {}

/// An error that occurred interacting with the cache during authentication.
///
/// This could be due to insufficient permissions, corrupt data, or an unusual OS configuration.
/// The error message should be sufficient to aid end-user debugging.
///
/// The wrapped string is the underlying error message; it is included verbatim in the `Display`
/// output.
#[derive(Debug)]
pub struct SsoCacheError(String);

impl fmt::Display for SsoCacheError {
    /// Writes the wrapped message after a fixed prefix identifying it as a cache error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "cache error when attempting authentication: {}", self.0)
    }
}

/// Allows the error to be boxed as `dyn Error`, propagated with `?`, or reported as a `source`.
/// There is no underlying source error since only the message is retained.
impl std::error::Error for SsoCacheError {}