cipherstash-client 0.34.1-alpha.1

The official CipherStash SDK
Documentation
//! Module for credential providers for various internal services.

pub mod service_credentials;
mod static_credentials;
pub mod token_store;
pub mod user_credentials;

// Re-export the public API
pub use service_credentials::{ServiceAccessKeyCredentials, ServiceCredentials, ServiceToken};
pub use static_credentials::StaticCredentials;

#[cfg(feature = "tokio")]
pub mod auto_refresh;

use std::{
    sync::Arc,
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use async_trait::async_trait;
use miette::Diagnostic;
use serde::{Deserialize, Serialize};
use thiserror::Error;

/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than `UNIX_EPOCH`.
fn now_secs() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Expected system time to be greater than UNIX_EPOCH");
    since_epoch.as_secs()
}

/// Expiry bookkeeping for serializable tokens.
///
/// Implementors only supply [`TokenExpiry::expires_at_secs`]. The provided methods derive
/// expiry and refresh decisions from it by comparing against the current wall-clock time
/// (whole seconds since the Unix epoch) shifted forward by the leeway constants below.
pub trait TokenExpiry<'a>: Clone + Serialize + Deserialize<'a> {
    /// A token with fewer than this many seconds left is reported as expired.
    const EXPIRY_LEEWAY_SECONDS: u64 = 60;
    /// A token with fewer than this many seconds left is reported as due for refresh.
    const REFRESH_LEEWAY_SECONDS: u64 = 180;
    /// Interval returned by [`TokenExpiry::refresh_interval`] once the refresh point has been
    /// reached or passed. Note that it is not a lower bound: when the refresh point is still
    /// ahead, the exact (possibly shorter) wait is returned.
    const MIN_REFRESH_INTERVAL_SECONDS: u64 = 10;

    /// Absolute expiry time of the token, in whole seconds since the Unix epoch.
    fn expires_at_secs(&self) -> u64;

    /// `true` when fewer than `EXPIRY_LEEWAY_SECONDS` remain before
    /// [`TokenExpiry::expires_at_secs`], including when that time has already passed.
    fn is_expired(&self) -> bool {
        let cutoff = now_secs() + Self::EXPIRY_LEEWAY_SECONDS;
        self.expires_at_secs() < cutoff
    }

    /// `true` when fewer than `REFRESH_LEEWAY_SECONDS` remain before
    /// [`TokenExpiry::expires_at_secs`], including when that time has already passed.
    fn should_refresh(&self) -> bool {
        let cutoff = now_secs() + Self::REFRESH_LEEWAY_SECONDS;
        self.expires_at_secs() < cutoff
    }

    /// How long to wait before refreshing.
    ///
    /// Returns the time remaining until the refresh point (`REFRESH_LEEWAY_SECONDS` before
    /// expiry). If that point is now or already behind us, returns
    /// `MIN_REFRESH_INTERVAL_SECONDS` instead.
    fn refresh_interval(&self) -> Duration {
        let refresh_at = now_secs() + Self::REFRESH_LEEWAY_SECONDS;

        match self.expires_at_secs().checked_sub(refresh_at) {
            Some(remaining) if remaining > 0 => Duration::from_secs(remaining),
            _ => Duration::from_secs(Self::MIN_REFRESH_INTERVAL_SECONDS),
        }
    }

    /// `MIN_REFRESH_INTERVAL_SECONDS` as a [`Duration`].
    fn min_refresh_interval() -> Duration {
        Duration::from_secs(Self::MIN_REFRESH_INTERVAL_SECONDS)
    }
}

/// Errors returned by [`Credentials::get_token`].
///
/// The boxed variants wrap an underlying diagnostic. `#[diagnostic(transparent)]` forwards
/// that inner value's miette `Diagnostic` information rather than defining its own, while the
/// Display text prefixes the inner message with the variant name.
#[derive(Diagnostic, Error, Debug)]
pub enum GetTokenError {
    /// Refreshing the token failed; wraps the underlying diagnostic.
    #[error("RefreshTokenFailed: {0}")]
    #[diagnostic(transparent)]
    RefreshTokenFailed(Box<dyn Diagnostic + Send + Sync>),

    /// Acquiring a new token failed; wraps the underlying diagnostic.
    #[error("AcquireNewTokenFailed: {0}")]
    #[diagnostic(transparent)]
    AcquireNewTokenFailed(Box<dyn Diagnostic + Send + Sync>),

    /// Persisting the token failed; wraps the underlying diagnostic.
    #[error("PersistTokenError: {0}")]
    #[diagnostic(transparent)]
    PersistTokenError(Box<dyn Diagnostic + Send + Sync>),

    /// No usable token: it is either absent or expired.
    #[error("Token missing or expired")]
    #[diagnostic(help("Token is missing or expired"))]
    MissingOrExpired,
}

#[derive(Error, Debug)]
#[error("RefreshTokenFailed: {0}")]
pub struct ClearTokenError(pub Box<dyn Diagnostic + Send + Sync>);

// TODO: We can remove async_trait now (check MSRV)
/// A provider of tokens for a service.
///
/// Implementors must be `Send + Sync + 'static` so a provider can be shared freely, for
/// example behind an `Arc`.
#[async_trait]
pub trait Credentials: Send + Sync + 'static {
    // TODO: Token should probably be bound to some trait
    /// The token type produced by [`Credentials::get_token`].
    type Token;

    /// Produces a token, or a [`GetTokenError`] explaining why none could be obtained.
    async fn get_token(&self) -> Result<Self::Token, GetTokenError>;

    /// Asks the provider to clear its token. What "clearing" involves is up to the implementor.
    async fn clear_token(&self) -> Result<(), ClearTokenError>;

    /// Check if the token is valid.
    ///
    /// The default implementation logs a debug event on the `credentials` target, then calls
    /// [`Credentials::get_token`] and reports whether it succeeded. Both the token and any
    /// error are discarded.
    async fn valid(&self) -> bool {
        tracing::debug!(target: "credentials", "Attempting to acquire token");
        let attempt = self.get_token().await;
        attempt.is_ok()
    }
}

/// [`Credentials`] whose token can be refreshed on request.
#[async_trait]
pub trait AutoRefreshable: Credentials {
    /// Refreshes the token and caches the result.
    ///
    /// Returns how long to wait before the token should be refreshed again.
    async fn refresh(&self) -> Duration;
}

/// Implement `Credentials` for `Arc<C>` where `C: Credentials`, so shared providers can be used
/// wherever a provider is expected.
///
/// Every trait method delegates to the wrapped value, including [`Credentials::valid`]. Without
/// that delegation, the trait's default `valid` would run instead and silently bypass any
/// override `C` provides.
#[async_trait]
impl<C: Credentials> Credentials for Arc<C> {
    type Token = C::Token;

    async fn get_token(&self) -> Result<Self::Token, GetTokenError> {
        self.as_ref().get_token().await
    }

    async fn clear_token(&self) -> Result<(), ClearTokenError> {
        self.as_ref().clear_token().await
    }

    async fn valid(&self) -> bool {
        self.as_ref().valid().await
    }
}

#[cfg(test)]
pub(crate) mod test_utils {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    /// Concurrency counters for test assertions.
    ///
    /// Records the number of calls made, the number currently in flight, and the highest number
    /// ever in flight at once. Clones share the same counters, so one copy can go into a request
    /// handler (e.g. an axum handler) while the test keeps another to check lock contention.
    #[derive(Clone)]
    pub(crate) struct CountingState {
        /// Number of `enter` calls so far.
        pub total: Arc<AtomicUsize>,
        /// Number of `enter` calls not yet matched by an `exit`.
        pub current: Arc<AtomicUsize>,
        /// Largest in-flight count observed by any `enter` call.
        pub peak: Arc<AtomicUsize>,
    }

    impl CountingState {
        /// Creates a state with every counter at zero.
        pub fn new() -> Self {
            let zeroed = || Arc::new(AtomicUsize::new(0));
            Self {
                total: zeroed(),
                current: zeroed(),
                peak: zeroed(),
            }
        }

        /// Records the start of a call. Must be paired with a later [`CountingState::exit`].
        pub fn enter(&self) {
            self.total.fetch_add(1, Ordering::SeqCst);
            // `fetch_add` yields the value before the increment; add one for this call.
            let in_flight = 1 + self.current.fetch_add(1, Ordering::SeqCst);
            self.peak.fetch_max(in_flight, Ordering::SeqCst);
        }

        /// Records the end of a call previously announced with [`CountingState::enter`].
        pub fn exit(&self) {
            self.current.fetch_sub(1, Ordering::SeqCst);
        }

        /// Highest in-flight count observed so far.
        pub fn peak(&self) -> usize {
            self.peak.load(Ordering::SeqCst)
        }

        /// Total number of `enter` calls recorded so far.
        pub fn total(&self) -> usize {
            self.total.load(Ordering::SeqCst)
        }
    }
}