cipherstash-client 0.34.1-alpha.1

The official CipherStash SDK
Documentation
use super::{AutoRefreshable, ClearTokenError, Credentials, GetTokenError};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::task::JoinHandle;

/// Credentials wrapper that keeps the inner credentials fresh from a background Tokio task.
///
/// Token requests are forwarded straight to the shared inner value without any extra
/// locking, so concurrent callers are not serialized by this wrapper. The background
/// refresh task is aborted when the wrapper is dropped.
pub struct AutoRefresh<C: AutoRefreshable> {
    /// The wrapped credentials, shared (via `Arc`) with the background refresh task.
    refreshable: Arc<C>,
    /// Handle to the background refresh task; kept so the task can be aborted on drop.
    job: JoinHandle<()>,
}

impl<C: AutoRefreshable> AutoRefresh<C> {
    /// Wraps `credentials` and starts a background task that keeps them refreshed.
    ///
    /// The task calls `refresh()` right away, then sleeps for whatever `Duration`
    /// that call returned before refreshing again. It loops until it is aborted,
    /// which the `Drop` impl does.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the refresh task is started
    /// with `tokio::spawn`.
    pub fn new(credentials: C) -> Self {
        let refreshable = Arc::new(credentials);
        let job = tokio::spawn(Self::refresh_forever(Arc::clone(&refreshable)));

        Self { refreshable, job }
    }

    /// Body of the background task: refresh, log, sleep for the returned interval, repeat.
    async fn refresh_forever(refreshable: Arc<C>) {
        // Captured on the task's first poll; used only for the `elapsed_secs` log field.
        let started_at = tokio::time::Instant::now();

        for attempt in 1_u64.. {
            tracing::debug!(
                target: "auto_refresh",
                attempt,
                elapsed_secs = started_at.elapsed().as_secs(),
                "Refreshing credentials"
            );
            let refresh_interval = refreshable.refresh().await;

            tracing::debug!(
                target: "auto_refresh",
                attempt,
                elapsed_secs = started_at.elapsed().as_secs(),
                next_refresh_secs = refresh_interval.as_secs_f64(),
                "Refresh complete, sleeping until next refresh"
            );
            tokio::time::sleep(refresh_interval).await;
        }
    }
}

/// Token access for `AutoRefresh` is a thin pass-through to the wrapped credentials.
///
/// No lock is taken here, so concurrent `get_token` / `clear_token` calls reach the
/// inner implementation in parallel; any synchronization is left to the inner type.
#[async_trait]
impl<C: AutoRefreshable> Credentials for AutoRefresh<C> {
    type Token = C::Token;

    async fn get_token(&self) -> Result<Self::Token, GetTokenError> {
        let inner: &C = &self.refreshable;
        inner.get_token().await
    }

    async fn clear_token(&self) -> Result<(), ClearTokenError> {
        let inner: &C = &self.refreshable;
        inner.clear_token().await
    }
}

/// Stops the background refresh task when the wrapper is dropped.
///
/// Dropping a `JoinHandle` only detaches its task. Without this explicit abort, the
/// refresh loop would keep running after the wrapper is gone, and would keep the inner
/// credentials alive through its `Arc` clone.
// The type parameter is `C` (not `Credentials`) so it does not shadow the
// `Credentials` trait imported into this module.
impl<C: AutoRefreshable> Drop for AutoRefresh<C> {
    fn drop(&mut self) {
        tracing::debug!(target: "auto_refresh", "Aborting refresh job");
        self.job.abort();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Mock credential that records the peak number of `get_token()` calls in flight
    /// at the same time.
    ///
    /// It has no internal lock, so concurrent calls are free to overlap. If many
    /// concurrent callers still produce a peak of 1, something outside the mock
    /// (i.e. the `AutoRefresh` wrapper) serialized them.
    struct ConcurrencyTracker {
        /// Number of `get_token()` calls currently in flight.
        current: AtomicUsize,
        /// Highest value `current` has reached. It is shared with the test body so it
        /// can be read after the tracker has been moved into `AutoRefresh`.
        peak: Arc<AtomicUsize>,
    }

    impl ConcurrencyTracker {
        /// Returns the tracker plus a handle to its peak-concurrency counter.
        fn new() -> (Self, Arc<AtomicUsize>) {
            let peak = Arc::new(AtomicUsize::new(0));
            let tracker = Self {
                current: AtomicUsize::new(0),
                peak: Arc::clone(&peak),
            };
            (tracker, peak)
        }
    }

    #[async_trait]
    impl Credentials for ConcurrencyTracker {
        type Token = String;

        async fn get_token(&self) -> Result<String, GetTokenError> {
            let prev = self.current.fetch_add(1, Ordering::SeqCst);
            self.peak.fetch_max(prev + 1, Ordering::SeqCst);

            // Stay "in flight" long enough for the other spawned tasks to enter.
            tokio::time::sleep(Duration::from_millis(50)).await;

            self.current.fetch_sub(1, Ordering::SeqCst);
            Ok("token".to_string())
        }

        async fn clear_token(&self) -> Result<(), ClearTokenError> {
            Ok(())
        }
    }

    #[async_trait]
    impl AutoRefreshable for ConcurrencyTracker {
        // Long interval: the background task refreshes once at startup and then sleeps
        // for the rest of the test. `refresh` does not touch the concurrency counters.
        async fn refresh(&self) -> Duration {
            Duration::from_secs(300)
        }
    }

    /// Regression test: `AutoRefresh` must forward `get_token()` without serializing
    /// callers. If a lock around the inner credentials is ever reintroduced, the
    /// observed peak concurrency drops to 1 and this test fails.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_auto_refresh_does_not_serialize_concurrent_calls() {
        let (tracker, peak) = ConcurrencyTracker::new();
        let creds = Arc::new(AutoRefresh::new(tracker));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let creds = Arc::clone(&creds);
                tokio::spawn(async move { creds.get_token().await.unwrap() })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        let peak = peak.load(Ordering::SeqCst);
        assert!(
            peak > 1,
            "Expected concurrent get_token() calls but peak concurrency was {peak}. \
             AutoRefresh is serializing get_token() calls.",
        );
    }
}