entelix_cloud/refresh.rs
1//! `CachedTokenProvider<T>` — single-flight cached token wrapper used by
2//! every cloud transport that fronts an OAuth-style credential
3//! source.
4//!
//! Read path never blocks across an `.await`: it takes a short
//! `parking_lot::RwLock` read guard, clones the cached
//! `Option<TokenState<T>>`, and drops the guard. Single-flight uses an
//! atomic claim flag plus `tokio::sync::Notify`: at most one task runs the
//! user-supplied `TokenRefresher::refresh()` future at a time, and
//! no lock is held across that future (CLAUDE.md "lock ordering").
//! Refresh fires `refresh_buffer` before the cached token's expiry,
//! which is tracked as a monotonic `std::time::Instant`.
12
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15use std::time::{Duration, Instant};
16
17use async_trait::async_trait;
18use parking_lot::RwLock;
19use tokio::sync::Notify;
20
21use crate::CloudError;
22
/// Lead time applied when a provider is built without an explicit
/// refresh-buffer: a cached token stops being served this long before
/// its recorded expiry.
///
/// Five minutes is a compromise between refreshing early enough that a
/// long-running call does not hit a 401 and not burning through
/// short-TTL credentials. With the 60-minute tokens most cloud providers
/// issue, roughly 92% of the lifetime is still served from the cache.
pub const DEFAULT_REFRESH_BUFFER: Duration = Duration::from_secs(5 * 60);
30
31/// Source-of-truth that yields a fresh `T` plus its absolute expiry
32/// time when called.
33#[async_trait]
34pub trait TokenRefresher<T>: Send + Sync
35where
36 T: Clone + Send + Sync + 'static,
37{
38 /// Fetch a fresh token. Implementors hit the underlying IDP
39 /// (gcp_auth, azure_identity, …); errors propagate to the
40 /// caller of [`CachedTokenProvider::current`].
41 async fn refresh(&self) -> Result<TokenSnapshot<T>, CloudError>;
42}
43
/// The outcome of a single successful refresh: the token plus the point
/// in time after which it must no longer be used.
#[derive(Debug, Clone)]
pub struct TokenSnapshot<T> {
    /// The credential itself (often a [`secrecy::SecretString`]).
    pub value: T,
    /// Instant at which the token stops being valid. This is a
    /// monotonic [`Instant`], not a calendar timestamp.
    pub expires_at: Instant,
}
52
/// Private cache entry holding the most recently stored token and the
/// instant at which it expires.
#[derive(Clone)]
struct TokenState<T> {
    /// Cached token value, cloned out to callers while still fresh.
    value: T,
    /// Expiry instant copied from the refresh result that produced `value`.
    expires_at: Instant,
}
58
59/// Cache + single-flight wrapper around a [`TokenRefresher`].
60pub struct CachedTokenProvider<T>
61where
62 T: Clone + Send + Sync + 'static,
63{
64 cached: RwLock<Option<TokenState<T>>>,
65 refresh_in_progress: AtomicBool,
66 refresh_done: Notify,
67 refresher: Arc<dyn TokenRefresher<T>>,
68 refresh_buffer: Duration,
69}
70
71impl<T> CachedTokenProvider<T>
72where
73 T: Clone + Send + Sync + 'static,
74{
75 /// Build with the default [`DEFAULT_REFRESH_BUFFER`] lead time.
76 pub fn new(refresher: Arc<dyn TokenRefresher<T>>) -> Self {
77 Self::with_refresh_buffer(refresher, DEFAULT_REFRESH_BUFFER)
78 }
79
80 /// Build with a custom refresh-buffer.
81 pub fn with_refresh_buffer(
82 refresher: Arc<dyn TokenRefresher<T>>,
83 refresh_buffer: Duration,
84 ) -> Self {
85 Self {
86 cached: RwLock::new(None),
87 refresh_in_progress: AtomicBool::new(false),
88 refresh_done: Notify::new(),
89 refresher,
90 refresh_buffer,
91 }
92 }
93
94 /// Return the current valid token, refreshing if the cached
95 /// value is missing or within the refresh-buffer of expiry.
96 ///
97 /// Single-flight: only one caller runs the user-supplied
98 /// [`TokenRefresher::refresh`] at a time. Other callers wait on
99 /// a [`Notify`] for the refresh to complete and then re-check
100 /// the cache. No lock is held across the user-supplied future.
101 pub async fn current(&self) -> Result<T, CloudError> {
102 loop {
103 if let Some(state) = self.read_fresh() {
104 return Ok(state);
105 }
106 // Try to claim refresh ownership atomically.
107 if self
108 .refresh_in_progress
109 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
110 .is_ok()
111 {
112 let result = self.refresher.refresh().await;
113 if let Ok(snap) = &result {
114 *self.cached.write() = Some(TokenState {
115 value: snap.value.clone(),
116 expires_at: snap.expires_at,
117 });
118 }
119 self.refresh_in_progress.store(false, Ordering::Release);
120 self.refresh_done.notify_waiters();
121 return result.map(|s| s.value);
122 }
123 // Someone else owns the refresh — register for the
124 // notification, then re-check (the writer may complete
125 // between our claim attempt and our subscription).
126 let waiter = self.refresh_done.notified();
127 tokio::pin!(waiter);
128 waiter.as_mut().enable();
129 if let Some(state) = self.read_fresh() {
130 return Ok(state);
131 }
132 waiter.await;
133 // Loop: re-read cache (now likely fresh) or retry the claim.
134 }
135 }
136
137 /// Drop the cached value. Useful when a 401 surfaces to force a
138 /// reload regardless of the recorded expiry (clock skew defence).
139 pub fn invalidate(&self) {
140 *self.cached.write() = None;
141 }
142
143 fn read_fresh(&self) -> Option<T> {
144 let snapshot = {
145 let guard = self.cached.read();
146 guard.as_ref().cloned()
147 };
148 let state = snapshot?;
149 if Instant::now() + self.refresh_buffer < state.expires_at {
150 Some(state.value)
151 } else {
152 None
153 }
154 }
155}