1use faucet_core::FaucetError;
7use jsonpath_rust::JsonPath;
8use reqwest::Client;
9use reqwest::header::HeaderMap;
10use serde_json::Value;
11use std::fmt;
12use std::sync::Arc;
13use tokio::sync::Mutex;
14
/// Caller-supplied predicate that decides which HTTP status codes count as a
/// successful token-endpoint response.
///
/// The closure is stored behind an [`Arc`], so cloning a validator is cheap and
/// every clone shares the same predicate. The `Send + Sync` bound lets the
/// validator be shared across threads/tasks.
#[derive(Clone)]
pub struct ResponseValidator(Arc<dyn Fn(u16) -> bool + Send + Sync>);

impl ResponseValidator {
    /// Wraps a thread-safe `Fn(u16) -> bool` closure.
    ///
    /// The closure receives the numeric HTTP status code and must return `true`
    /// when the response should be treated as a success.
    pub fn new(f: impl Fn(u16) -> bool + Send + Sync + 'static) -> Self {
        let predicate: Arc<dyn Fn(u16) -> bool + Send + Sync> = Arc::new(f);
        ResponseValidator(predicate)
    }

    /// Evaluates the wrapped predicate for `status`.
    pub(crate) fn is_success(&self, status: u16) -> bool {
        let predicate = &*self.0;
        predicate(status)
    }
}

impl fmt::Debug for ResponseValidator {
    /// Closures have no meaningful `Debug` output, so a fixed placeholder is printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ResponseValidator(<fn>)")
    }
}
56
/// Fraction of a token's server-reported lifetime for which a cached token is
/// reused before it is refreshed.
///
/// NOTE(review): presumably the value callers pass as `expiry_ratio` to
/// `TokenEndpointCache::get_or_refresh` when none is configured — confirm at
/// call sites. At 0.9, a token reported to live 3600 s is reused for 3240 s.
pub const DEFAULT_TOKEN_ENDPOINT_EXPIRY_RATIO: f64 = 0.9;
59
60#[derive(Debug, Clone)]
62struct CachedToken {
63 token: String,
64 expires_at: Option<tokio::time::Instant>,
65}
66
67impl CachedToken {
68 fn is_valid(&self) -> bool {
69 match self.expires_at {
70 Some(exp) => tokio::time::Instant::now() < exp,
71 None => true,
72 }
73 }
74}
75
76#[derive(Debug, Clone, Default)]
78pub struct TokenEndpointCache(Arc<Mutex<Option<CachedToken>>>);
79
80impl TokenEndpointCache {
81 pub fn new() -> Self {
82 Self(Arc::new(Mutex::new(None)))
83 }
84
85 #[allow(clippy::too_many_arguments)]
87 pub async fn get_or_refresh(
88 &self,
89 client: &Client,
90 url: &str,
91 method: &reqwest::Method,
92 headers: &HeaderMap,
93 body: Option<&Value>,
94 token_path: &str,
95 expiry_path: Option<&str>,
96 expiry_ratio: f64,
97 response_validator: Option<&ResponseValidator>,
98 ) -> Result<String, FaucetError> {
99 let mut guard = self.0.lock().await;
100 if let Some(cached) = guard.as_ref() {
101 if cached.is_valid() {
102 return Ok(cached.token.clone());
103 }
104 tracing::debug!("TokenEndpoint token expired; refreshing");
105 }
106
107 let (token, expires_in) = fetch_token(
108 client,
109 url,
110 method,
111 headers,
112 body,
113 token_path,
114 expiry_path,
115 response_validator,
116 )
117 .await?;
118
119 let expires_at = expires_in.map(|secs| {
120 let effective = (secs as f64 * expiry_ratio) as u64;
121 tokio::time::Instant::now() + std::time::Duration::from_secs(effective)
122 });
123
124 *guard = Some(CachedToken {
125 token: token.clone(),
126 expires_at,
127 });
128
129 Ok(token)
130 }
131}
132
133pub async fn fetch_token_from_endpoint(
138 url: &str,
139 method: &reqwest::Method,
140 headers: &HeaderMap,
141 body: Option<&Value>,
142 token_path: &str,
143 response_validator: Option<&ResponseValidator>,
144) -> Result<String, FaucetError> {
145 let client = Client::new();
146 let (token, _) = fetch_token(
147 &client,
148 url,
149 method,
150 headers,
151 body,
152 token_path,
153 None,
154 response_validator,
155 )
156 .await?;
157 Ok(token)
158}
159
160#[allow(clippy::too_many_arguments)]
161async fn fetch_token(
162 client: &Client,
163 url: &str,
164 method: &reqwest::Method,
165 headers: &HeaderMap,
166 body: Option<&Value>,
167 token_path: &str,
168 expiry_path: Option<&str>,
169 response_validator: Option<&ResponseValidator>,
170) -> Result<(String, Option<u64>), FaucetError> {
171 let mut req = client.request(method.clone(), url).headers(headers.clone());
172 if let Some(b) = body {
173 req = req.json(b);
174 }
175
176 let resp = req.send().await?;
177
178 let status = resp.status();
179 let is_success = match response_validator {
180 Some(v) => v.is_success(status.as_u16()),
181 None => status.is_success(),
182 };
183 if !is_success {
184 let status_code = status.as_u16();
185 let body_text = resp.text().await.unwrap_or_default();
186 return Err(FaucetError::Auth(format!(
187 "token endpoint request failed (HTTP {status_code}): {body_text}"
188 )));
189 }
190
191 let resp_body: Value = resp.json().await?;
192
193 let token = extract_string(&resp_body, token_path).ok_or_else(|| {
194 FaucetError::Auth(format!(
195 "token_path '{token_path}' did not match a string value in the response"
196 ))
197 })?;
198
199 let expires_in = expiry_path.and_then(|ep| extract_u64(&resp_body, ep));
200
201 Ok((token, expires_in))
202}
203
204fn extract_string(body: &Value, path: &str) -> Option<String> {
206 let results = body.query(path).ok()?;
207 match results.first()? {
208 Value::String(s) => Some(s.clone()),
209 Value::Number(n) => Some(n.to_string()),
211 _ => None,
212 }
213}
214
215fn extract_u64(body: &Value, path: &str) -> Option<u64> {
217 let results = body.query(path).ok()?;
218 results.first()?.as_u64()
219}
220
// Unit tests for the pure helpers in this module: JSONPath extraction,
// `ResponseValidator`, and `CachedToken` validity. Nothing here performs
// network I/O.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // --- extract_string / extract_u64: basic lookups ---

    #[test]
    fn extract_string_from_nested_json() {
        let body = json!({"auth": {"token": "abc123"}});
        assert_eq!(extract_string(&body, "$.auth.token"), Some("abc123".into()));
    }

    #[test]
    fn extract_string_returns_none_for_missing_path() {
        let body = json!({"auth": {}});
        assert_eq!(extract_string(&body, "$.auth.token"), None);
    }

    // Numeric tokens are accepted and stringified.
    #[test]
    fn extract_string_converts_number_to_string() {
        let body = json!({"token": 12345});
        assert_eq!(extract_string(&body, "$.token"), Some("12345".into()));
    }

    #[test]
    fn extract_u64_from_json() {
        let body = json!({"expires_in": 3600});
        assert_eq!(extract_u64(&body, "$.expires_in"), Some(3600));
    }

    // String-encoded lifetimes are not parsed.
    #[test]
    fn extract_u64_returns_none_for_string() {
        let body = json!({"expires_in": "not a number"});
        assert_eq!(extract_u64(&body, "$.expires_in"), None);
    }

    #[test]
    fn extract_u64_returns_none_for_missing() {
        let body = json!({});
        assert_eq!(extract_u64(&body, "$.expires_in"), None);
    }

    // --- ResponseValidator ---

    #[test]
    fn response_validator_accepts_matching_status() {
        let v = ResponseValidator::new(|s| s == 200);
        assert!(v.is_success(200));
        assert!(!v.is_success(201));
    }

    #[test]
    fn response_validator_range_check() {
        let v = ResponseValidator::new(|s| s < 400);
        assert!(v.is_success(200));
        assert!(v.is_success(301));
        assert!(v.is_success(399));
        assert!(!v.is_success(400));
        assert!(!v.is_success(500));
    }

    // The closure itself is opaque, so Debug prints a fixed placeholder.
    #[test]
    fn response_validator_debug_format() {
        let v = ResponseValidator::new(|_| true);
        assert_eq!(format!("{v:?}"), "ResponseValidator(<fn>)");
    }

    // Clones share the same predicate.
    #[test]
    fn response_validator_clone() {
        let v = ResponseValidator::new(|s| s == 200);
        let cloned = v.clone();
        assert!(cloned.is_success(200));
        assert!(!cloned.is_success(404));
    }

    // --- CachedToken validity ---

    #[test]
    fn cached_token_without_expiry_is_always_valid() {
        let token = CachedToken {
            token: "abc".into(),
            expires_at: None,
        };
        assert!(token.is_valid());
    }

    #[test]
    fn cached_token_with_future_expiry_is_valid() {
        let token = CachedToken {
            token: "abc".into(),
            expires_at: Some(tokio::time::Instant::now() + std::time::Duration::from_secs(3600)),
        };
        assert!(token.is_valid());
    }

    // --- extract_string / extract_u64: edge cases ---

    // Only the first JSONPath match is used.
    #[test]
    fn extract_string_from_array_path() {
        let body = json!({"tokens": ["first", "second"]});
        assert_eq!(extract_string(&body, "$.tokens[0]"), Some("first".into()));
    }

    #[test]
    fn extract_string_returns_none_for_object() {
        let body = json!({"token": {"nested": "value"}});
        assert_eq!(extract_string(&body, "$.token"), None);
    }

    #[test]
    fn extract_string_returns_none_for_null() {
        let body = json!({"token": null});
        assert_eq!(extract_string(&body, "$.token"), None);
    }

    #[test]
    fn extract_u64_returns_none_for_negative() {
        let body = json!({"expires_in": -1});
        assert_eq!(extract_u64(&body, "$.expires_in"), None);
    }

    // Fractional lifetimes are rejected rather than truncated.
    #[test]
    fn extract_u64_returns_none_for_float() {
        let body = json!({"expires_in": 3600.5});
        assert_eq!(extract_u64(&body, "$.expires_in"), None);
    }
}