// cipherstash-client 0.34.1-alpha.1
//
// The official CipherStash SDK
// Documentation
use std::{path::Path, time::Duration};

use async_mutex::Mutex as AsyncMutex;
use async_trait::async_trait;
use miette::Diagnostic;
use serde_json::json;
use thiserror::Error;
use tracing::debug;
use url::Url;

use super::service_token::ServiceToken;
use crate::credentials::{
    token_store::TokenStore, AutoRefreshable, ClearTokenError, Credentials, GetTokenError,
    TokenExpiry,
};
use crate::reqwest_client::create_client;
use crate::user_agent::get_user_agent;

/// Credentials that exchange a service access key for a [`ServiceToken`].
///
/// Tokens are requested from the `/api/authorise` path under `cts_base_url`
/// and cached in a [`TokenStore`] created from the `token_path` given to
/// [`ServiceAccessKeyCredentials::new`].
pub struct ServiceAccessKeyCredentials {
    /// Access key sent as `accessKey` in the authorise request body.
    access_key: String,
    /// Optional audience sent as `audience` in the authorise request body.
    audience: Option<String>,
    /// Base URL onto which the `/api/authorise` path is joined.
    cts_base_url: Url,
    /// Cached token; every read, write and clear goes through this async mutex.
    token_store: AsyncMutex<TokenStore<ServiceToken>>,
    /// HTTP client obtained from `create_client()`.
    client: reqwest_middleware::ClientWithMiddleware,
}

/// Errors produced while requesting a new token from the authorise endpoint.
#[derive(Diagnostic, Error, Debug)]
pub enum AcquireTokenError {
    /// The authorise request could not be completed, for example because
    /// sending it failed or the server answered with an error status.
    #[error("Failed to acquire token: {0}")]
    RequestFailed(Box<dyn std::error::Error + Sync + Send>),

    /// The response body could not be deserialized as JSON into a token.
    #[error("Failed to parse json response: {0}")]
    BadResponse(Box<dyn std::error::Error + Sync + Send>),
}

impl ServiceAccessKeyCredentials {
    /// Builds credentials for the given access key.
    ///
    /// * `token_path` - path handed to the [`TokenStore`] that caches the token.
    /// * `access_key` - key sent to the authorise endpoint.
    /// * `cts_base_url` - base URL the `/api/authorise` path is resolved against.
    /// * `audience` - optional audience included in the authorise request.
    ///
    /// No network request is made here; the URL is only resolved when a
    /// token is actually requested.
    pub fn new(
        token_path: &Path,
        access_key: &str,
        cts_base_url: &Url,
        audience: Option<&str>,
    ) -> Self {
        Self {
            access_key: access_key.to_string(),
            audience: audience.map(str::to_owned),
            cts_base_url: cts_base_url.to_owned(),
            token_store: AsyncMutex::new(TokenStore::new(token_path)),
            client: create_client(),
        }
    }

    /// Requests a fresh [`ServiceToken`] by POSTing the access key (and
    /// optional audience) as JSON to `/api/authorise`.
    ///
    /// # Errors
    ///
    /// * [`AcquireTokenError::RequestFailed`] if the endpoint URL cannot be
    ///   built from the base URL, the request cannot be sent, or the server
    ///   responds with an error status.
    /// * [`AcquireTokenError::BadResponse`] if the response body cannot be
    ///   deserialized into a [`ServiceToken`].
    async fn authorise(&self) -> Result<ServiceToken, AcquireTokenError> {
        debug!(target: "service_access_key_credentials", "Authorising Access Token with CTS");

        // `Url::join` fails when the base URL cannot act as a base (for
        // example `mailto:` or `data:` URLs). The base URL is supplied by the
        // caller, so report that as an error rather than panicking.
        let url = self
            .cts_base_url
            .join("/api/authorise")
            .map_err(|e| AcquireTokenError::RequestFailed(Box::new(e)))?;

        let token: ServiceToken = self
            .client
            .post(url)
            .json(&json!({ "accessKey": self.access_key, "audience": self.audience }))
            .header("user-agent", get_user_agent())
            .send()
            .await
            .map_err(|e| AcquireTokenError::RequestFailed(Box::new(e)))?
            // Turn 4xx/5xx responses into errors before trying to parse the body.
            .error_for_status()
            .map_err(|e| AcquireTokenError::RequestFailed(Box::new(e)))?
            .json()
            .await
            .map_err(|e| AcquireTokenError::BadResponse(Box::new(e)))?;

        debug!(target: "service_access_key_credentials",
            "Access Token Acquired - expiry(epoch seconds): {}",
            &token.expiry
        );
        Ok(token)
    }
}

#[async_trait]
impl Credentials for ServiceAccessKeyCredentials {
    type Token = ServiceToken;

    /// Returns a token that is not expired, contacting CTS only when the
    /// token store cannot supply one.
    ///
    /// # Errors
    ///
    /// * `GetTokenError::AcquireNewTokenFailed` when the authorise call fails.
    /// * `GetTokenError::PersistTokenError` when the new token cannot be
    ///   written to the token store.
    async fn get_token(&self) -> Result<Self::Token, GetTokenError> {
        debug!(target: "service_access_key_credentials", "getting token (waiting for lock)");

        // First look: take the lock just long enough to read the cache, so
        // callers that find a valid token do not queue behind a refresh.
        let snapshot = {
            let mut store = self.token_store.lock().await;
            debug!(target: "service_access_key_credentials", "getting token (got lock)");
            store.get()
        };

        match snapshot {
            Some(token) if !token.is_expired() => {
                debug!(target: "service_access_key_credentials", "using cached token");
                return Ok(token);
            }
            Some(_) => {
                debug!(target: "service_access_key_credentials", "cached token is expired");
            }
            None => {}
        }

        // Nothing usable was cached. From here on the lock is held until we
        // return, which means:
        //  - at most one caller talks to CTS while the others wait and then
        //    find the refreshed token
        //  - a clear_token() issued meanwhile cannot be overwritten by us
        let mut store = self.token_store.lock().await;

        // Second look: someone else may have refreshed while we were waiting.
        if let Some(token) = store.get().filter(|t| !t.is_expired()) {
            debug!(target: "service_access_key_credentials", "token already refreshed by another caller");
            return Ok(token);
        }

        debug!(target: "service_access_key_credentials", "fetching new token from CTS");

        // Lock is still held, so this is the only in-flight authorise call.
        let fresh = match self.authorise().await {
            Ok(token) => token,
            Err(e) => return Err(GetTokenError::AcquireNewTokenFailed(Box::new(e))),
        };

        if let Err(e) = store.set(&fresh) {
            return Err(GetTokenError::PersistTokenError(Box::new(e)));
        }

        Ok(fresh)
    }

    /// Removes the cached token from the token store.
    ///
    /// # Errors
    ///
    /// Returns `ClearTokenError` wrapping whatever error the store reports.
    async fn clear_token(&self) -> Result<(), ClearTokenError> {
        debug!(target: "service_access_key_credentials", "clearing token");

        let mut store = self.token_store.lock().await;
        let cleared = store.clear();
        cleared.map_err(|e| ClearTokenError(Box::new(e)))
    }
}

#[async_trait]
impl AutoRefreshable for ServiceAccessKeyCredentials {
    /// Refreshes the cached token when it is missing or due for refresh and
    /// returns how long to wait before the next refresh attempt.
    ///
    /// Failures are logged rather than returned; in that case the minimum
    /// refresh interval of the token type is used.
    async fn refresh(&self) -> Duration {
        // First look: hold the lock only while reading the cached token.
        let snapshot = {
            let mut store = self.token_store.lock().await;
            store.get()
        };

        if let Some(token) = snapshot.filter(|t| !t.should_refresh()) {
            debug!(target: "service_access_key_credentials", "Access token is still new");
            return token.refresh_interval();
        }

        // A refresh looks necessary. Keep the lock for the whole request so
        // this call is serialized with get_token() and clear_token(); a
        // concurrent clear must not be undone by a stale refresh.
        let mut store = self.token_store.lock().await;

        // Second look: get_token() or an earlier refresh() may have beaten us.
        if let Some(token) = store.get().filter(|t| !t.should_refresh()) {
            debug!(target: "service_access_key_credentials", "Access token already refreshed by another caller");
            return token.refresh_interval();
        }

        debug!(target: "service_access_key_credentials", "Access token is missing or close to expiry, refreshing");

        let refreshed = match self.authorise().await {
            Ok(token) => token,
            Err(err) => {
                tracing::warn!(
                    target: "service_access_key_credentials",
                    error = %err,
                    "Failed to refresh access key token"
                );
                return Self::Token::min_refresh_interval();
            }
        };

        match store.set(&refreshed) {
            Ok(_) => {
                debug!(target: "service_access_key_credentials", "Access token refreshed and saved to disk");
                refreshed.refresh_interval()
            }
            Err(err) => {
                tracing::warn!(
                    target: "service_access_key_credentials",
                    error = %err,
                    "Failed to persist refreshed token"
                );
                Self::Token::min_refresh_interval()
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::test_utils::CountingState;
    use std::sync::Arc;

    /// Stand-in for the authorise endpoint: marks the request as in flight
    /// on the shared state, waits 100ms, then answers with a token.
    async fn slow_authorise(
        axum::extract::State(state): axum::extract::State<CountingState>,
    ) -> axum::Json<serde_json::Value> {
        state.enter();
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        state.exit();

        let body = serde_json::json!({
            "accessToken": "test-token",
            "expiry": 9999999999u64
        });
        axum::Json(body)
    }

    /// Five concurrent get_token() calls must result in a single authorise
    /// request: the lock serializes the HTTP call and the waiting callers
    /// reuse the token it produced. The same serialization keeps a
    /// concurrent clear_token() from being overwritten.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_service_access_key_get_token_serializes_authorise() {
        let stats = CountingState::new();

        let app = axum::Router::new()
            .route("/api/authorise", axum::routing::post(slow_authorise))
            .with_state(stats.clone());

        // Bind to an ephemeral port and serve the stub in the background.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        let tmp = tempfile::TempDir::new().unwrap();
        let token_path = tmp.path().join("token.json");
        let base_url = Url::parse(&format!("http://{addr}")).unwrap();

        let creds = Arc::new(ServiceAccessKeyCredentials::new(
            &token_path,
            "test-access-key",
            &base_url,
            None,
        ));

        let tasks: Vec<_> = (0..5)
            .map(|_| {
                let creds = Arc::clone(&creds);
                tokio::spawn(async move { creds.get_token().await.unwrap() })
            })
            .collect();

        for task in tasks {
            task.await.unwrap();
        }

        let peak = stats.peak();
        let total = stats.total();
        assert_eq!(
            peak, 1,
            "Expected serialized authorise but peak concurrency was {peak}. \
             Concurrent authorise calls waste resources and race with clear_token().",
        );
        assert_eq!(
            total, 1,
            "Expected exactly 1 authorise request but got {total}. \
             Double-check pattern should let waiters use the refreshed token.",
        );
    }
}