Skip to main content

entelix_cloud/
refresh.rs

1//! `CachedTokenProvider<T>` — single-flight cached token wrapper used by
2//! every cloud transport that fronts an OAuth-style credential
3//! source.
4//!
//! The read path never awaits: it takes a short-lived
//! `parking_lot::RwLock` read guard just long enough to clone the
//! cached token out of the `Option<TokenState<T>>`. Single-flight uses
//! an atomic claim flag plus `tokio::sync::Notify`: at most one task
//! runs the user-supplied `TokenRefresher::refresh()` future at a time,
//! and no lock is held across that future (CLAUDE.md "lock ordering").
//! A refresh is triggered once the cached token is within
//! `refresh_buffer` of its recorded expiry, measured on the monotonic
//! `Instant` clock rather than wall-clock time.
12
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15use std::time::{Duration, Instant};
16
17use async_trait::async_trait;
18use parking_lot::RwLock;
19use tokio::sync::Notify;
20
21use crate::CloudError;
22
/// Lead time before a cached token's expiry at which
/// [`CachedTokenProvider::current`] starts fetching a replacement.
///
/// Five minutes trades off pre-empting 401s on long-running calls
/// against thrashing short-lived credentials. Most cloud providers
/// issue 60-minute tokens, so a 5-minute buffer still serves roughly
/// 92% of each token's lifetime straight from the cache.
pub const DEFAULT_REFRESH_BUFFER: Duration = Duration::from_secs(5 * 60);
30
31/// Source-of-truth that yields a fresh `T` plus its absolute expiry
32/// time when called.
33#[async_trait]
34pub trait TokenRefresher<T>: Send + Sync
35where
36    T: Clone + Send + Sync + 'static,
37{
38    /// Fetch a fresh token. Implementors hit the underlying IDP
39    /// (gcp_auth, azure_identity, …); errors propagate to the
40    /// caller of [`CachedTokenProvider::current`].
41    async fn refresh(&self) -> Result<TokenSnapshot<T>, CloudError>;
42}
43
/// The outcome of a single successful [`TokenRefresher::refresh`] call.
#[derive(Clone, Debug)]
pub struct TokenSnapshot<T> {
    /// The token value itself (often a [`secrecy::SecretString`]).
    pub value: T,
    /// Point at which the token stops being valid.
    ///
    /// This is a monotonic [`Instant`], not a wall-clock timestamp.
    /// Refreshers that receive a relative lifetime from their identity
    /// provider (e.g. an `expires_in` field) can compute it as
    /// `Instant::now() + lifetime`. Refreshers that receive an absolute
    /// wall-clock time must convert it to an `Instant` first.
    pub expires_at: Instant,
}
52
/// Internal cache entry: a token plus the moment it expires.
#[derive(Clone)]
struct TokenState<T> {
    /// When `value` stops being valid.
    expires_at: Instant,
    /// The cached token.
    value: T,
}
58
59/// Cache + single-flight wrapper around a [`TokenRefresher`].
60pub struct CachedTokenProvider<T>
61where
62    T: Clone + Send + Sync + 'static,
63{
64    cached: RwLock<Option<TokenState<T>>>,
65    refresh_in_progress: AtomicBool,
66    refresh_done: Notify,
67    refresher: Arc<dyn TokenRefresher<T>>,
68    refresh_buffer: Duration,
69}
70
71impl<T> CachedTokenProvider<T>
72where
73    T: Clone + Send + Sync + 'static,
74{
75    /// Build with the default [`DEFAULT_REFRESH_BUFFER`] lead time.
76    pub fn new(refresher: Arc<dyn TokenRefresher<T>>) -> Self {
77        Self::with_refresh_buffer(refresher, DEFAULT_REFRESH_BUFFER)
78    }
79
80    /// Build with a custom refresh-buffer.
81    pub fn with_refresh_buffer(
82        refresher: Arc<dyn TokenRefresher<T>>,
83        refresh_buffer: Duration,
84    ) -> Self {
85        Self {
86            cached: RwLock::new(None),
87            refresh_in_progress: AtomicBool::new(false),
88            refresh_done: Notify::new(),
89            refresher,
90            refresh_buffer,
91        }
92    }
93
94    /// Return the current valid token, refreshing if the cached
95    /// value is missing or within the refresh-buffer of expiry.
96    ///
97    /// Single-flight: only one caller runs the user-supplied
98    /// [`TokenRefresher::refresh`] at a time. Other callers wait on
99    /// a [`Notify`] for the refresh to complete and then re-check
100    /// the cache. No lock is held across the user-supplied future.
101    pub async fn current(&self) -> Result<T, CloudError> {
102        loop {
103            if let Some(state) = self.read_fresh() {
104                return Ok(state);
105            }
106            // Try to claim refresh ownership atomically.
107            if self
108                .refresh_in_progress
109                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
110                .is_ok()
111            {
112                let result = self.refresher.refresh().await;
113                if let Ok(snap) = &result {
114                    *self.cached.write() = Some(TokenState {
115                        value: snap.value.clone(),
116                        expires_at: snap.expires_at,
117                    });
118                }
119                self.refresh_in_progress.store(false, Ordering::Release);
120                self.refresh_done.notify_waiters();
121                return result.map(|s| s.value);
122            }
123            // Someone else owns the refresh — register for the
124            // notification, then re-check (the writer may complete
125            // between our claim attempt and our subscription).
126            let waiter = self.refresh_done.notified();
127            tokio::pin!(waiter);
128            waiter.as_mut().enable();
129            if let Some(state) = self.read_fresh() {
130                return Ok(state);
131            }
132            waiter.await;
133            // Loop: re-read cache (now likely fresh) or retry the claim.
134        }
135    }
136
137    /// Drop the cached value. Useful when a 401 surfaces to force a
138    /// reload regardless of the recorded expiry (clock skew defence).
139    pub fn invalidate(&self) {
140        *self.cached.write() = None;
141    }
142
143    fn read_fresh(&self) -> Option<T> {
144        let snapshot = {
145            let guard = self.cached.read();
146            guard.as_ref().cloned()
147        };
148        let state = snapshot?;
149        if Instant::now() + self.refresh_buffer < state.expires_at {
150            Some(state.value)
151        } else {
152            None
153        }
154    }
155}