Skip to main content

faucet_auth/
token_endpoint.rs

1//! Generic token-endpoint provider: fetch a token from an arbitrary HTTP
2//! endpoint and extract it from the JSON response via JSONPath.
3
4use async_trait::async_trait;
5use faucet_core::{AuthProvider, Credential, FaucetError};
6use jsonpath_rust::JsonPath;
7use reqwest::Client;
8use serde_json::Value;
9use tokio::sync::Mutex;
10use tokio::time::Instant;
11
12use crate::expiry_instant;
13
14#[derive(Default)]
15struct CachedToken {
16    token: Option<String>,
17    expires_at: Option<Instant>,
18}
19
20/// Fetches a token from an arbitrary endpoint, extracts it via `token_path`
21/// (JSONPath), and caches it with optional expiry tracking. Single-flight
22/// refresh via an internal [`Mutex`].
23pub struct TokenEndpointProvider {
24    http: Client,
25    url: String,
26    method: reqwest::Method,
27    body: Option<Value>,
28    token_path: String,
29    expiry_path: Option<String>,
30    expiry_ratio: f64,
31    state: Mutex<CachedToken>,
32}
33
34// Hand-written so `{:?}` never prints the cached token in `state` or the request
35// `body` (which can carry a `client_secret`). `finish_non_exhaustive` omits both.
36impl std::fmt::Debug for TokenEndpointProvider {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("TokenEndpointProvider")
39            .field("url", &self.url)
40            .field("method", &self.method)
41            .field("token_path", &self.token_path)
42            .field("expiry_path", &self.expiry_path)
43            .field("expiry_ratio", &self.expiry_ratio)
44            .finish_non_exhaustive()
45    }
46}
47
48impl TokenEndpointProvider {
49    /// Build from a config object with `url`, optional `method` (default `POST`),
50    /// optional `body`, `token_path` (JSONPath), optional `expiry_path`, and
51    /// optional `expiry_ratio`.
52    pub fn from_config(config: &Value) -> Result<Self, FaucetError> {
53        let url = config
54            .get("url")
55            .and_then(Value::as_str)
56            .ok_or_else(|| {
57                FaucetError::Config("token_endpoint auth provider: missing `url`".into())
58            })?
59            .to_string();
60        let method = config
61            .get("method")
62            .and_then(Value::as_str)
63            .unwrap_or("POST")
64            .parse::<reqwest::Method>()
65            .map_err(|e| FaucetError::Config(format!("token_endpoint: invalid method: {e}")))?;
66        let token_path = config
67            .get("token_path")
68            .and_then(Value::as_str)
69            .ok_or_else(|| {
70                FaucetError::Config("token_endpoint auth provider: missing `token_path`".into())
71            })?
72            .to_string();
73        Ok(Self {
74            http: crate::auth_http_client(),
75            url,
76            method,
77            body: config.get("body").cloned().filter(|v| !v.is_null()),
78            token_path,
79            expiry_path: config
80                .get("expiry_path")
81                .and_then(Value::as_str)
82                .map(str::to_string),
83            expiry_ratio: crate::parse_expiry_ratio(config)?,
84            state: Mutex::new(CachedToken::default()),
85        })
86    }
87
88    async fn fetch(&self) -> Result<(String, Option<u64>), FaucetError> {
89        let mut req = self.http.request(self.method.clone(), &self.url);
90        if let Some(body) = &self.body {
91            req = req.json(body);
92        }
93        let resp = req.send().await?;
94        if !resp.status().is_success() {
95            let status = resp.status().as_u16();
96            let body = resp.text().await.unwrap_or_default();
97            return Err(FaucetError::Auth(format!(
98                "token endpoint request failed (HTTP {status}): {body}"
99            )));
100        }
101        let body: Value = resp.json().await?;
102        let token = extract_string(&body, &self.token_path).ok_or_else(|| {
103            FaucetError::Auth(format!(
104                "token_path '{}' did not match a string value in the response",
105                self.token_path
106            ))
107        })?;
108        let expires_in = self
109            .expiry_path
110            .as_deref()
111            .and_then(|p| extract_u64(&body, p));
112        Ok((token, expires_in))
113    }
114}
115
116#[async_trait]
117impl AuthProvider for TokenEndpointProvider {
118    async fn credential(&self) -> Result<Credential, FaucetError> {
119        let mut state = self.state.lock().await;
120        let still_valid = match (&state.token, state.expires_at) {
121            (Some(_), Some(exp)) => Instant::now() < exp,
122            (Some(_), None) => true,
123            _ => false,
124        };
125        if still_valid {
126            return Ok(Credential::Bearer(state.token.clone().unwrap()));
127        }
128        let (token, expires_in) = self.fetch().await?;
129        state.token = Some(token.clone());
130        state.expires_at = expiry_instant(expires_in, self.expiry_ratio);
131        Ok(Credential::Bearer(token))
132    }
133
134    async fn invalidate(&self, stale: &Credential) -> Result<Credential, FaucetError> {
135        let mut state = self.state.lock().await;
136        // CAS: if the cache already holds a *different* still-valid token, a
137        // concurrent caller already refreshed after the same 401 — hand that
138        // back instead of fetching again (single-flight). Only refetch when the
139        // cached token is the stale one (or itself expired). Without this
140        // override the default `invalidate` just returns `credential()`, which
141        // serves the still-cached stale token straight back, so a connector that
142        // hit a 401 can never force a refresh (#146 M15).
143        let current_valid = match (&state.token, state.expires_at) {
144            (Some(t), Some(exp)) if Instant::now() < exp => Some(t.clone()),
145            (Some(t), None) => Some(t.clone()),
146            _ => None,
147        };
148        if let (Some(cur), Credential::Bearer(stale_tok)) = (&current_valid, stale)
149            && cur != stale_tok
150        {
151            return Ok(Credential::Bearer(cur.clone()));
152        }
153        let (token, expires_in) = self.fetch().await?;
154        state.token = Some(token.clone());
155        state.expires_at = expiry_instant(expires_in, self.expiry_ratio);
156        Ok(Credential::Bearer(token))
157    }
158
159    fn provider_name(&self) -> &'static str {
160        "token_endpoint"
161    }
162}
163
164fn extract_string(body: &Value, path: &str) -> Option<String> {
165    let results = body.query(path).ok()?;
166    match results.first()? {
167        Value::String(s) => Some(s.clone()),
168        Value::Number(n) => Some(n.to_string()),
169        _ => None,
170    }
171}
172
173fn extract_u64(body: &Value, path: &str) -> Option<u64> {
174    let results = body.query(path).ok()?;
175    results.first()?.as_u64()
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, Respond, ResponseTemplate};

    /// Wiremock responder that counts requests.
    ///
    /// The n-th request is answered with token `tok{n}` under
    /// `auth.access_token`, plus `"ttl": 3600`.
    struct Counting(Arc<AtomicUsize>);

    impl Respond for Counting {
        fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
            let nth = self.0.fetch_add(1, Ordering::SeqCst) + 1;
            let payload = serde_json::json!({
                "auth": { "access_token": format!("tok{nth}") },
                "ttl": 3600
            });
            ResponseTemplate::new(200).set_body_json(payload)
        }
    }

    /// Starts a mock endpoint answering every POST via [`Counting`], plus a
    /// provider configured against it.
    ///
    /// The `MockServer` is returned so the caller keeps it alive for the whole
    /// test.
    async fn counting_setup() -> (MockServer, Arc<AtomicUsize>, TokenEndpointProvider) {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .respond_with(Counting(Arc::clone(&hits)))
            .mount(&server)
            .await;
        let provider = TokenEndpointProvider::from_config(&serde_json::json!({
            "url": server.uri(),
            "token_path": "$.auth.access_token",
            "expiry_path": "$.ttl",
        }))
        .unwrap();
        (server, hits, provider)
    }

    fn bearer(token: &str) -> Credential {
        Credential::Bearer(token.into())
    }

    fn hit_count(hits: &AtomicUsize) -> usize {
        hits.load(Ordering::SeqCst)
    }

    #[tokio::test]
    async fn extracts_token_via_jsonpath_and_single_flights() {
        let (_server, hits, p) = counting_setup().await;
        let outcomes = futures::future::join_all((0..3).map(|_| p.credential())).await;
        for outcome in outcomes {
            assert_eq!(outcome.unwrap(), bearer("tok1"));
        }
        assert_eq!(hit_count(&hits), 1);
    }

    #[test]
    fn provider_debug_does_not_leak_body_secrets() {
        // The request `body` may carry a client secret. Formatting the
        // provider with `{:?}` (it is held as `Arc<dyn AuthProvider>`) must
        // not reveal it.
        let p = TokenEndpointProvider::from_config(&serde_json::json!({
            "url": "https://idp.example/token",
            "token_path": "$.access_token",
            "body": { "client_secret": "topsecretbody" },
        }))
        .unwrap();
        let s = format!("{p:?}");
        assert!(!s.contains("topsecretbody"), "request body secret leaked: {s}");
        assert!(s.contains("token_path"), "non-secret fields should remain: {s}");
    }

    #[test]
    fn missing_url_errors() {
        let outcome =
            TokenEndpointProvider::from_config(&serde_json::json!({"token_path": "$.t"}));
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn invalidate_forces_a_refresh_of_the_stale_token() {
        // M15 (#146): after a 401 the connector calls invalidate(stale) and
        // must receive a freshly fetched token, not the cached stale one.
        let (_server, hits, p) = counting_setup().await;

        assert_eq!(p.credential().await.unwrap(), bearer("tok1"));
        assert_eq!(hit_count(&hits), 1);

        // Invalidating tok1 triggers a refetch → tok2.
        assert_eq!(p.invalidate(&bearer("tok1")).await.unwrap(), bearer("tok2"));
        assert_eq!(hit_count(&hits), 2);

        // tok2 is now cached, so no further request is made.
        assert_eq!(p.credential().await.unwrap(), bearer("tok2"));
        assert_eq!(hit_count(&hits), 2);
    }

    #[tokio::test]
    async fn invalidate_short_circuits_when_token_already_rotated() {
        // CAS: a cached token that differs from the stale one means a
        // concurrent caller already refreshed. Return it without refetching.
        let (_server, hits, p) = counting_setup().await;

        assert_eq!(p.credential().await.unwrap(), bearer("tok1"));
        assert_eq!(hit_count(&hits), 1);

        // Invalidating an already-superseded token yields the cached tok1.
        assert_eq!(p.invalidate(&bearer("old-token")).await.unwrap(), bearer("tok1"));
        assert_eq!(hit_count(&hits), 1);
    }

    #[test]
    fn rejects_out_of_range_expiry_ratio() {
        // M16 (#146): an out-of-range expiry_ratio breaks caching, so it must
        // be rejected.
        let with_ratio = |ratio: Value| {
            TokenEndpointProvider::from_config(&serde_json::json!({
                "url": "http://x", "token_path": "$.t", "expiry_ratio": ratio
            }))
        };
        for bad in [serde_json::json!(0), serde_json::json!(1.5)] {
            assert!(with_ratio(bad).is_err());
        }
        // A valid ratio still constructs.
        assert!(with_ratio(serde_json::json!(0.5)).is_ok());
    }
}