Skip to main content

faucet_source_rest/auth/
token_endpoint.rs

//! Generic token-endpoint authentication with caching.
//!
//! Fetches a token from an arbitrary HTTP endpoint, extracts it from the
//! response via JSONPath, and caches it with optional expiry tracking.
5
6use faucet_core::FaucetError;
7use jsonpath_rust::JsonPath;
8use reqwest::Client;
9use reqwest::header::HeaderMap;
10use serde_json::Value;
11use std::fmt;
12use std::sync::Arc;
13use tokio::sync::Mutex;
14
/// Optional hook that decides whether a token endpoint response counts as
/// successful.
///
/// The wrapped closure is handed the numeric HTTP status code and answers
/// `true` when the response should be accepted. If no validator is supplied,
/// the token fetcher falls back to `status.is_success()`, i.e. any 2xx code.
///
/// # Example
///
/// ```
/// use faucet_source_rest::ResponseValidator;
///
/// // Accept 200 and 201 only:
/// let validator = ResponseValidator::new(|status| status == 200 || status == 201);
///
/// // Accept anything below 400:
/// let validator = ResponseValidator::new(|status| status < 400);
/// ```
#[derive(Clone)]
pub struct ResponseValidator(Arc<dyn Fn(u16) -> bool + Send + Sync>);

impl ResponseValidator {
    /// Wrap a status-code predicate in a validator.
    ///
    /// `f` receives the status as a `u16`; returning `true` marks the response
    /// as successful. The closure is stored behind an `Arc`, so cloning the
    /// validator is cheap and every clone shares the same predicate.
    pub fn new(f: impl Fn(u16) -> bool + Send + Sync + 'static) -> Self {
        let predicate: Arc<dyn Fn(u16) -> bool + Send + Sync> = Arc::new(f);
        Self(predicate)
    }

    /// Run the wrapped predicate against `status`.
    pub(crate) fn is_success(&self, status: u16) -> bool {
        let Self(predicate) = self;
        predicate(status)
    }
}

impl fmt::Debug for ResponseValidator {
    // The closure itself is opaque, so only a fixed placeholder is printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ResponseValidator(<fn>)")
    }
}
56
/// Default share (90%) of a token's reported `expires_in` lifetime that may
/// elapse before the cached token is considered stale and fetched again.
pub const DEFAULT_TOKEN_ENDPOINT_EXPIRY_RATIO: f64 = 0.9;
59
60/// Cached token with expiration tracking.
61#[derive(Debug, Clone)]
62struct CachedToken {
63    token: String,
64    expires_at: Option<tokio::time::Instant>,
65}
66
67impl CachedToken {
68    fn is_valid(&self) -> bool {
69        match self.expires_at {
70            Some(exp) => tokio::time::Instant::now() < exp,
71            None => true,
72        }
73    }
74}
75
76/// Thread-safe token cache for `Auth::TokenEndpoint`.
77#[derive(Debug, Clone, Default)]
78pub struct TokenEndpointCache(Arc<Mutex<Option<CachedToken>>>);
79
80impl TokenEndpointCache {
81    pub fn new() -> Self {
82        Self(Arc::new(Mutex::new(None)))
83    }
84
85    /// Return a valid cached token or fetch a new one from the endpoint.
86    #[allow(clippy::too_many_arguments)]
87    pub async fn get_or_refresh(
88        &self,
89        client: &Client,
90        url: &str,
91        method: &reqwest::Method,
92        headers: &HeaderMap,
93        body: Option<&Value>,
94        token_path: &str,
95        expiry_path: Option<&str>,
96        expiry_ratio: f64,
97        response_validator: Option<&ResponseValidator>,
98    ) -> Result<String, FaucetError> {
99        let mut guard = self.0.lock().await;
100        if let Some(cached) = guard.as_ref() {
101            if cached.is_valid() {
102                return Ok(cached.token.clone());
103            }
104            tracing::debug!("TokenEndpoint token expired; refreshing");
105        }
106
107        let (token, expires_in) = fetch_token(
108            client,
109            url,
110            method,
111            headers,
112            body,
113            token_path,
114            expiry_path,
115            response_validator,
116        )
117        .await?;
118
119        let expires_at = expires_in.map(|secs| {
120            let effective = (secs as f64 * expiry_ratio) as u64;
121            tokio::time::Instant::now() + std::time::Duration::from_secs(effective)
122        });
123
124        *guard = Some(CachedToken {
125            token: token.clone(),
126            expires_at,
127        });
128
129        Ok(token)
130    }
131}
132
133/// Fetch a token from the given endpoint and extract it using JSONPath.
134///
135/// This is the public one-shot variant for callers who want to fetch a token
136/// without caching (e.g. for use with `Auth::Bearer`).
137pub async fn fetch_token_from_endpoint(
138    url: &str,
139    method: &reqwest::Method,
140    headers: &HeaderMap,
141    body: Option<&Value>,
142    token_path: &str,
143    response_validator: Option<&ResponseValidator>,
144) -> Result<String, FaucetError> {
145    let client = Client::new();
146    let (token, _) = fetch_token(
147        &client,
148        url,
149        method,
150        headers,
151        body,
152        token_path,
153        None,
154        response_validator,
155    )
156    .await?;
157    Ok(token)
158}
159
160#[allow(clippy::too_many_arguments)]
161async fn fetch_token(
162    client: &Client,
163    url: &str,
164    method: &reqwest::Method,
165    headers: &HeaderMap,
166    body: Option<&Value>,
167    token_path: &str,
168    expiry_path: Option<&str>,
169    response_validator: Option<&ResponseValidator>,
170) -> Result<(String, Option<u64>), FaucetError> {
171    let mut req = client.request(method.clone(), url).headers(headers.clone());
172    if let Some(b) = body {
173        req = req.json(b);
174    }
175
176    let resp = req.send().await?;
177
178    let status = resp.status();
179    let is_success = match response_validator {
180        Some(v) => v.is_success(status.as_u16()),
181        None => status.is_success(),
182    };
183    if !is_success {
184        let status_code = status.as_u16();
185        let body_text = resp.text().await.unwrap_or_default();
186        return Err(FaucetError::Auth(format!(
187            "token endpoint request failed (HTTP {status_code}): {body_text}"
188        )));
189    }
190
191    let resp_body: Value = resp.json().await?;
192
193    let token = extract_string(&resp_body, token_path).ok_or_else(|| {
194        FaucetError::Auth(format!(
195            "token_path '{token_path}' did not match a string value in the response"
196        ))
197    })?;
198
199    let expires_in = expiry_path.and_then(|ep| extract_u64(&resp_body, ep));
200
201    Ok((token, expires_in))
202}
203
204/// Extract a single string value from a JSON body using a JSONPath expression.
205fn extract_string(body: &Value, path: &str) -> Option<String> {
206    let results = body.query(path).ok()?;
207    match results.first()? {
208        Value::String(s) => Some(s.clone()),
209        // Accept numbers/bools as tokens by converting to string.
210        Value::Number(n) => Some(n.to_string()),
211        _ => None,
212    }
213}
214
215/// Extract a single u64 value from a JSON body using a JSONPath expression.
216fn extract_u64(body: &Value, path: &str) -> Option<u64> {
217    let results = body.query(path).ok()?;
218    results.first()?.as_u64()
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ── extract_string / extract_u64 basics ──────────────────────────────────

    #[test]
    fn extract_string_from_nested_json() {
        let body = json!({"auth": {"token": "abc123"}});
        assert_eq!(extract_string(&body, "$.auth.token"), Some("abc123".into()));
    }

    #[test]
    fn extract_string_returns_none_for_missing_path() {
        let body = json!({"auth": {}});
        assert_eq!(extract_string(&body, "$.auth.token"), None);
    }

    #[test]
    fn extract_string_converts_number_to_string() {
        let body = json!({"token": 12345});
        assert_eq!(extract_string(&body, "$.token"), Some("12345".into()));
    }

    #[test]
    fn extract_u64_from_json() {
        let body = json!({"expires_in": 3600});
        assert_eq!(extract_u64(&body, "$.expires_in"), Some(3600));
    }

    #[test]
    fn extract_u64_returns_none_for_string() {
        let body = json!({"expires_in": "not a number"});
        assert_eq!(extract_u64(&body, "$.expires_in"), None);
    }

    #[test]
    fn extract_u64_returns_none_for_missing() {
        let body = json!({});
        assert_eq!(extract_u64(&body, "$.expires_in"), None);
    }

    // ── ResponseValidator tests ──────────────────────────────────────────────

    #[test]
    fn response_validator_accepts_matching_status() {
        let v = ResponseValidator::new(|s| s == 200);
        assert!(v.is_success(200));
        assert!(!v.is_success(201));
    }

    #[test]
    fn response_validator_range_check() {
        let v = ResponseValidator::new(|s| s < 400);
        assert!(v.is_success(200));
        assert!(v.is_success(301));
        assert!(v.is_success(399));
        assert!(!v.is_success(400));
        assert!(!v.is_success(500));
    }

    #[test]
    fn response_validator_debug_format() {
        let v = ResponseValidator::new(|_| true);
        assert_eq!(format!("{v:?}"), "ResponseValidator(<fn>)");
    }

    #[test]
    fn response_validator_clone() {
        let v = ResponseValidator::new(|s| s == 200);
        let cloned = v.clone();
        assert!(cloned.is_success(200));
        assert!(!cloned.is_success(404));
    }

    // ── CachedToken tests ────────────────────────────────────────────────────

    #[test]
    fn cached_token_without_expiry_is_always_valid() {
        let token = CachedToken {
            token: "abc".into(),
            expires_at: None,
        };
        assert!(token.is_valid());
    }

    #[test]
    fn cached_token_with_future_expiry_is_valid() {
        let token = CachedToken {
            token: "abc".into(),
            expires_at: Some(tokio::time::Instant::now() + std::time::Duration::from_secs(3600)),
        };
        assert!(token.is_valid());
    }

    #[test]
    fn cached_token_at_or_past_expiry_is_invalid() {
        // The clock is monotonic, so when `is_valid` reads it again the current
        // instant is >= the deadline and the strict `<` comparison must fail.
        let token = CachedToken {
            token: "abc".into(),
            expires_at: Some(tokio::time::Instant::now()),
        };
        assert!(!token.is_valid());
    }

    // ── extract edge cases ───────────────────────────────────────────────────

    #[test]
    fn extract_string_from_array_path() {
        let body = json!({"tokens": ["first", "second"]});
        assert_eq!(extract_string(&body, "$.tokens[0]"), Some("first".into()));
    }

    #[test]
    fn extract_string_returns_none_for_object() {
        let body = json!({"token": {"nested": "value"}});
        assert_eq!(extract_string(&body, "$.token"), None);
    }

    #[test]
    fn extract_string_returns_none_for_null() {
        let body = json!({"token": null});
        assert_eq!(extract_string(&body, "$.token"), None);
    }

    #[test]
    fn extract_string_returns_none_for_bool() {
        // Only strings and numbers are accepted as tokens.
        let body = json!({"token": true});
        assert_eq!(extract_string(&body, "$.token"), None);
    }

    #[test]
    fn extract_u64_returns_none_for_negative() {
        let body = json!({"expires_in": -1});
        assert_eq!(extract_u64(&body, "$.expires_in"), None);
    }

    #[test]
    fn extract_u64_returns_none_for_float() {
        let body = json!({"expires_in": 3600.5});
        assert_eq!(extract_u64(&body, "$.expires_in"), None);
    }
}