1use async_trait::async_trait;
5use faucet_core::{AuthProvider, Credential, FaucetError};
6use jsonpath_rust::JsonPath;
7use reqwest::Client;
8use serde_json::Value;
9use tokio::sync::Mutex;
10use tokio::time::Instant;
11
12use crate::expiry_instant;
13
14#[derive(Default)]
15struct CachedToken {
16 token: Option<String>,
17 expires_at: Option<Instant>,
18}
19
20pub struct TokenEndpointProvider {
24 http: Client,
25 url: String,
26 method: reqwest::Method,
27 body: Option<Value>,
28 token_path: String,
29 expiry_path: Option<String>,
30 expiry_ratio: f64,
31 state: Mutex<CachedToken>,
32}
33
34impl std::fmt::Debug for TokenEndpointProvider {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("TokenEndpointProvider")
39 .field("url", &self.url)
40 .field("method", &self.method)
41 .field("token_path", &self.token_path)
42 .field("expiry_path", &self.expiry_path)
43 .field("expiry_ratio", &self.expiry_ratio)
44 .finish_non_exhaustive()
45 }
46}
47
48impl TokenEndpointProvider {
49 pub fn from_config(config: &Value) -> Result<Self, FaucetError> {
53 let url = config
54 .get("url")
55 .and_then(Value::as_str)
56 .ok_or_else(|| {
57 FaucetError::Config("token_endpoint auth provider: missing `url`".into())
58 })?
59 .to_string();
60 let method = config
61 .get("method")
62 .and_then(Value::as_str)
63 .unwrap_or("POST")
64 .parse::<reqwest::Method>()
65 .map_err(|e| FaucetError::Config(format!("token_endpoint: invalid method: {e}")))?;
66 let token_path = config
67 .get("token_path")
68 .and_then(Value::as_str)
69 .ok_or_else(|| {
70 FaucetError::Config("token_endpoint auth provider: missing `token_path`".into())
71 })?
72 .to_string();
73 Ok(Self {
74 http: crate::auth_http_client(),
75 url,
76 method,
77 body: config.get("body").cloned().filter(|v| !v.is_null()),
78 token_path,
79 expiry_path: config
80 .get("expiry_path")
81 .and_then(Value::as_str)
82 .map(str::to_string),
83 expiry_ratio: crate::parse_expiry_ratio(config)?,
84 state: Mutex::new(CachedToken::default()),
85 })
86 }
87
88 async fn fetch(&self) -> Result<(String, Option<u64>), FaucetError> {
89 let mut req = self.http.request(self.method.clone(), &self.url);
90 if let Some(body) = &self.body {
91 req = req.json(body);
92 }
93 let resp = req.send().await?;
94 if !resp.status().is_success() {
95 let status = resp.status().as_u16();
96 let body = resp.text().await.unwrap_or_default();
97 return Err(FaucetError::Auth(format!(
98 "token endpoint request failed (HTTP {status}): {body}"
99 )));
100 }
101 let body: Value = resp.json().await?;
102 let token = extract_string(&body, &self.token_path).ok_or_else(|| {
103 FaucetError::Auth(format!(
104 "token_path '{}' did not match a string value in the response",
105 self.token_path
106 ))
107 })?;
108 let expires_in = self
109 .expiry_path
110 .as_deref()
111 .and_then(|p| extract_u64(&body, p));
112 Ok((token, expires_in))
113 }
114}
115
116#[async_trait]
117impl AuthProvider for TokenEndpointProvider {
118 async fn credential(&self) -> Result<Credential, FaucetError> {
119 let mut state = self.state.lock().await;
120 let still_valid = match (&state.token, state.expires_at) {
121 (Some(_), Some(exp)) => Instant::now() < exp,
122 (Some(_), None) => true,
123 _ => false,
124 };
125 if still_valid {
126 return Ok(Credential::Bearer(state.token.clone().unwrap()));
127 }
128 let (token, expires_in) = self.fetch().await?;
129 state.token = Some(token.clone());
130 state.expires_at = expiry_instant(expires_in, self.expiry_ratio);
131 Ok(Credential::Bearer(token))
132 }
133
134 async fn invalidate(&self, stale: &Credential) -> Result<Credential, FaucetError> {
135 let mut state = self.state.lock().await;
136 let current_valid = match (&state.token, state.expires_at) {
144 (Some(t), Some(exp)) if Instant::now() < exp => Some(t.clone()),
145 (Some(t), None) => Some(t.clone()),
146 _ => None,
147 };
148 if let (Some(cur), Credential::Bearer(stale_tok)) = (¤t_valid, stale)
149 && cur != stale_tok
150 {
151 return Ok(Credential::Bearer(cur.clone()));
152 }
153 let (token, expires_in) = self.fetch().await?;
154 state.token = Some(token.clone());
155 state.expires_at = expiry_instant(expires_in, self.expiry_ratio);
156 Ok(Credential::Bearer(token))
157 }
158
159 fn provider_name(&self) -> &'static str {
160 "token_endpoint"
161 }
162}
163
164fn extract_string(body: &Value, path: &str) -> Option<String> {
165 let results = body.query(path).ok()?;
166 match results.first()? {
167 Value::String(s) => Some(s.clone()),
168 Value::Number(n) => Some(n.to_string()),
169 _ => None,
170 }
171}
172
173fn extract_u64(body: &Value, path: &str) -> Option<u64> {
174 let results = body.query(path).ok()?;
175 results.first()?.as_u64()
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, Respond, ResponseTemplate};

    /// Mock responder that issues `tok1`, `tok2`, ... (one per request) with `"ttl": 3600`
    /// and counts how many requests it served.
    struct Counting(Arc<AtomicUsize>);
    impl Respond for Counting {
        fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
            let n = self.0.fetch_add(1, Ordering::SeqCst) + 1;
            ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "auth": { "access_token": format!("tok{n}") },
                "ttl": 3600
            }))
        }
    }

    /// Starts a mock token endpoint answering every POST with [`Counting`]; returns the
    /// server together with its hit counter.
    async fn counting_server() -> (MockServer, Arc<AtomicUsize>) {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .respond_with(Counting(hits.clone()))
            .mount(&server)
            .await;
        (server, hits)
    }

    /// Provider pointed at `server` that reads the token from `$.auth.access_token` and the
    /// lifetime from `$.ttl`.
    fn provider_with_expiry(server: &MockServer) -> TokenEndpointProvider {
        TokenEndpointProvider::from_config(&serde_json::json!({
            "url": server.uri(),
            "token_path": "$.auth.access_token",
            "expiry_path": "$.ttl",
        }))
        .unwrap()
    }

    #[tokio::test]
    async fn extracts_token_via_jsonpath_and_single_flights() {
        let (server, hits) = counting_server().await;
        let p = provider_with_expiry(&server);
        let results = futures::future::join_all((0..3).map(|_| p.credential())).await;
        for r in &results {
            assert_eq!(r.as_ref().unwrap(), &Credential::Bearer("tok1".into()));
        }
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn token_without_expiry_path_is_reused() {
        let (server, hits) = counting_server().await;
        let p = TokenEndpointProvider::from_config(&serde_json::json!({
            "url": server.uri(),
            "token_path": "$.auth.access_token",
        }))
        .unwrap();
        for _ in 0..2 {
            assert_eq!(
                p.credential().await.unwrap(),
                Credential::Bearer("tok1".into())
            );
        }
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn provider_debug_does_not_leak_body_secrets() {
        let p = TokenEndpointProvider::from_config(&serde_json::json!({
            "url": "https://idp.example/token",
            "token_path": "$.access_token",
            "body": { "client_secret": "topsecretbody" },
        }))
        .unwrap();
        let s = format!("{p:?}");
        assert!(
            !s.contains("topsecretbody"),
            "request body secret leaked: {s}"
        );
        assert!(
            s.contains("token_path"),
            "non-secret fields should remain: {s}"
        );
    }

    #[test]
    fn missing_url_errors() {
        assert!(
            TokenEndpointProvider::from_config(&serde_json::json!({"token_path": "$.t"})).is_err()
        );
    }

    #[test]
    fn missing_token_path_errors() {
        assert!(
            TokenEndpointProvider::from_config(&serde_json::json!({"url": "http://x"})).is_err()
        );
    }

    #[test]
    fn invalid_method_errors() {
        assert!(
            TokenEndpointProvider::from_config(&serde_json::json!({
                "url": "http://x", "token_path": "$.t", "method": "NOT A METHOD"
            }))
            .is_err()
        );
    }

    #[tokio::test]
    async fn invalidate_forces_a_refresh_of_the_stale_token() {
        let (server, hits) = counting_server().await;
        let p = provider_with_expiry(&server);

        assert_eq!(
            p.credential().await.unwrap(),
            Credential::Bearer("tok1".into())
        );
        assert_eq!(hits.load(Ordering::SeqCst), 1);

        assert_eq!(
            p.invalidate(&Credential::Bearer("tok1".into()))
                .await
                .unwrap(),
            Credential::Bearer("tok2".into())
        );
        assert_eq!(hits.load(Ordering::SeqCst), 2);

        assert_eq!(
            p.credential().await.unwrap(),
            Credential::Bearer("tok2".into())
        );
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn invalidate_short_circuits_when_token_already_rotated() {
        let (server, hits) = counting_server().await;
        let p = provider_with_expiry(&server);

        assert_eq!(
            p.credential().await.unwrap(),
            Credential::Bearer("tok1".into())
        );
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        assert_eq!(
            p.invalidate(&Credential::Bearer("old-token".into()))
                .await
                .unwrap(),
            Credential::Bearer("tok1".into())
        );
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn rejects_out_of_range_expiry_ratio() {
        assert!(
            TokenEndpointProvider::from_config(&serde_json::json!({
                "url": "http://x", "token_path": "$.t", "expiry_ratio": 0
            }))
            .is_err()
        );
        assert!(
            TokenEndpointProvider::from_config(&serde_json::json!({
                "url": "http://x", "token_path": "$.t", "expiry_ratio": 1.5
            }))
            .is_err()
        );
        assert!(
            TokenEndpointProvider::from_config(&serde_json::json!({
                "url": "http://x", "token_path": "$.t", "expiry_ratio": 0.5
            }))
            .is_ok()
        );
    }
}