1use std::{path::Path, sync::Arc, time::Duration};
7
8use anyhow::{bail, Context, Result};
9use arc_swap::ArcSwap;
10use jsonpath_rust::JsonPathFinder;
11use reqwest::{
12 cookie::Jar,
13 header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE},
14};
15use serde::Deserialize;
16use serde_json::Value;
17use tokio::{sync::RwLock, task::JoinHandle};
18use tokio_util::sync::CancellationToken;
19use tracing::{debug, info, warn};
20
21use crate::config::Config;
22
23#[derive(Debug, Clone, Deserialize)]
27#[serde(rename_all = "snake_case")]
28pub enum InjectAs {
29 Bearer,
31 Basic,
33 Header(String),
35 Cookie(String),
37}
38
39#[derive(Debug, Clone, Deserialize)]
41pub struct AuthStep {
42 pub url: String,
44 #[serde(default = "default_post")]
46 pub method: String,
47 pub body: Option<Value>,
49 #[serde(default)]
51 pub headers: std::collections::HashMap<String, String>,
52 pub extract: Option<String>,
55 pub extract_refresh: Option<String>,
57 pub extract_expires_in: Option<String>,
60 pub inject_as: Option<InjectAs>,
62}
63
64#[derive(Debug, Clone, Deserialize)]
66pub struct AuthFlow {
67 pub steps: Vec<AuthStep>,
69 #[serde(default = "default_refresh_secs")]
72 pub refresh_interval_secs: u64,
73}
74
/// Serde default for `AuthStep::method`.
fn default_post() -> String {
    String::from("POST")
}
78fn default_refresh_secs() -> u64 {
79 840
80} #[derive(Debug, Clone)]
86pub struct LiveCredential {
87 pub value: Arc<ArcSwap<String>>,
89 pub refresh_value: Option<Arc<RwLock<String>>>,
91 pub inject_as: InjectAs,
93 pub refresh_lead_secs: u64,
95}
96
97#[derive(Debug)]
99pub struct RefreshTaskHandle {
100 cancel: CancellationToken,
101 task: JoinHandle<()>,
102}
103
104impl RefreshTaskHandle {
105 pub async fn shutdown(self) {
107 self.cancel.cancel();
108 let _ = self.task.await;
109 }
110}
111
112impl LiveCredential {
113 pub fn current(&self) -> String {
115 self.value.load().as_ref().clone()
116 }
117
118 pub fn apply_to(&self, map: &mut HeaderMap) {
120 let val = self.current();
121 match &self.inject_as {
122 InjectAs::Bearer => {
123 if let Ok(v) = HeaderValue::from_str(&format!("Bearer {val}")) {
124 map.insert(reqwest::header::AUTHORIZATION, v);
125 }
126 }
127 InjectAs::Basic => {
128 use base64::engine::general_purpose::STANDARD;
129 use base64::Engine;
130 let encoded = STANDARD.encode(val.as_bytes());
131 if let Ok(v) = HeaderValue::from_str(&format!("Basic {encoded}")) {
132 map.insert(reqwest::header::AUTHORIZATION, v);
133 }
134 }
135 InjectAs::Header(name) => {
136 if let (Ok(k), Ok(v)) = (
137 HeaderName::from_bytes(name.as_bytes()),
138 HeaderValue::from_str(&val),
139 ) {
140 map.insert(k, v);
141 }
142 }
143 InjectAs::Cookie(name) => {
144 let cookie = format!("{name}={val}");
145 let key = reqwest::header::COOKIE;
147 let merged = if let Some(existing) = map.get(&key) {
148 let existing = existing.to_str().unwrap_or("");
149 format!("{existing}; {cookie}")
150 } else {
151 cookie
152 };
153 if let Ok(v) = HeaderValue::from_str(&merged) {
154 map.insert(key, v);
155 }
156 }
157 }
158 }
159}
160
161pub fn load_flow(path: &Path) -> Result<AuthFlow> {
164 let content = std::fs::read_to_string(path)
165 .with_context(|| format!("Cannot read auth flow file: {}", path.display()))?;
166 serde_json::from_str(&content).with_context(|| "Auth flow file is not valid JSON")
167}
168
169pub async fn execute_flow(flow: &AuthFlow, config: &Config) -> Result<LiveCredential> {
175 let jar = Arc::new(Jar::default());
176 let mut builder = reqwest::Client::builder()
177 .timeout(Duration::from_secs(config.politeness.timeout_secs))
178 .cookie_provider(Arc::clone(&jar));
179
180 if let Some(proxy_url) = &config.proxy {
182 let proxy = reqwest::Proxy::all(proxy_url).context("Invalid proxy URL in auth flow")?;
183 builder = builder.proxy(proxy);
184 }
185
186 builder = builder.danger_accept_invalid_certs(config.danger_accept_invalid_certs);
188
189 let client = builder.build().context("Failed to build auth client")?;
190
191 let mut last_credential: Option<LiveCredential> = None;
192
193 for (i, step) in flow.steps.iter().enumerate() {
194 info!(
195 "Auth flow step {}/{}: {} {}",
196 i + 1,
197 flow.steps.len(),
198 step.method,
199 step.url
200 );
201
202 let url = substitute_env_vars(&step.url);
203
204 let mut req = client.request(
205 step.method
206 .parse()
207 .context("Invalid HTTP method in auth flow")?,
208 &url,
209 );
210
211 for (k, v) in &step.headers {
213 if let (Ok(name), Ok(value)) = (
214 HeaderName::from_bytes(k.as_bytes()),
215 HeaderValue::from_str(&substitute_env_vars(v)),
216 ) {
217 req = req.header(name, value);
218 }
219 }
220
221 if let Some(ref cred) = last_credential {
223 let mut map = HeaderMap::new();
224 cred.apply_to(&mut map);
225 req = req.headers(map);
226 }
227
228 if let Some(ref body) = step.body {
229 let substituted = substitute_env_vars_in_value(body);
230 req = req
231 .header(CONTENT_TYPE, "application/json")
232 .json(&substituted);
233 }
234
235 let resp = req.send().await.context("Auth flow request failed")?;
236 let status = resp.status().as_u16();
237
238 if status >= 400 {
239 bail!("Auth flow step {} returned HTTP {status}", i + 1);
240 }
241
242 let body: Value = resp
245 .json()
246 .await
247 .context("Auth flow response is not JSON. If your endpoint returns form-encoded data, see docs/auth-flow.md#non-json-responses")?;
248 debug!("Auth flow step {} response: {}", i + 1, body);
249
250 if let (Some(extract), Some(inject_as)) = (&step.extract, &step.inject_as) {
251 let token = extract_jsonpath(&body, extract).with_context(|| {
252 format!("JSONPath '{extract}' matched nothing in auth response")
253 })?;
254
255 let expires_in = step
256 .extract_expires_in
257 .as_ref()
258 .and_then(|p| extract_jsonpath(&body, p).ok())
259 .and_then(|v| v.parse::<u64>().ok());
260
261 let refresh_value = step
262 .extract_refresh
263 .as_ref()
264 .and_then(|p| extract_jsonpath(&body, p).ok())
265 .map(|v| Arc::new(RwLock::new(v)));
266
267 let refresh_interval = expires_in
268 .map(|e| e.saturating_sub(60)) .filter(|v| *v > 0)
270 .unwrap_or(flow.refresh_interval_secs);
271
272 info!("Auth flow: credential obtained (refresh in {refresh_interval}s)");
273
274 last_credential = Some(LiveCredential {
275 value: Arc::new(ArcSwap::from_pointee(token)),
276 refresh_value,
277 inject_as: inject_as.clone(),
278 refresh_lead_secs: refresh_interval,
279 });
280 }
281 }
282
283 last_credential.context("Auth flow completed but no credential was extracted")
284}
285
286pub fn spawn_refresh_task(
292 flow: AuthFlow,
293 cred: Arc<LiveCredential>,
294 config: Config,
295) -> RefreshTaskHandle {
296 let cancel = CancellationToken::new();
297 let child_cancel = cancel.child_token();
298
299 let task = tokio::spawn(async move {
300 loop {
301 let sleep_secs = cred.refresh_lead_secs.max(1);
302 tokio::select! {
303 _ = child_cancel.cancelled() => {
304 info!("Auth flow: refresh task cancelled");
305 break;
306 }
307 _ = tokio::time::sleep(Duration::from_secs(sleep_secs)) => {}
308 }
309
310 info!("Auth flow: refreshing credential…");
311 match execute_flow(&flow, &config).await {
312 Ok(new_cred) => {
313 let new_val = new_cred.current();
314 cred.value.store(Arc::new(new_val));
315 info!("Auth flow: credential refreshed successfully");
316 }
317 Err(e) => {
318 warn!("Auth flow: refresh failed — {e}. Continuing with existing token.");
319 }
320 }
321 }
322 });
323
324 RefreshTaskHandle { cancel, task }
325}
326
327fn extract_jsonpath(doc: &Value, path: &str) -> Result<String> {
332 if path.starts_with('/') {
334 return doc
335 .pointer(path)
336 .and_then(json_scalar_to_string)
337 .context("JSON Pointer matched nothing");
338 }
339
340 let finder = JsonPathFinder::from_str(&doc.to_string(), path)
342 .map_err(|e| anyhow::anyhow!("JSONPath error: {e}"))?;
343 let first = finder.find();
344 if let Value::Array(arr) = &first {
345 if let Some(v) = arr.first() {
346 return json_scalar_to_string(v)
347 .context("JSONPath result is not a scalar (string/number/bool)");
348 }
349 }
350 bail!("JSONPath '{path}' matched nothing in response")
351}
352
353fn json_scalar_to_string(v: &Value) -> Option<String> {
354 if let Some(s) = v.as_str() {
355 return Some(s.to_string());
356 }
357 if let Some(i) = v.as_i64() {
358 return Some(i.to_string());
359 }
360 if let Some(u) = v.as_u64() {
361 return Some(u.to_string());
362 }
363 if let Some(f) = v.as_f64() {
364 if f.is_finite() {
365 if f.fract() == 0.0 && f >= i64::MIN as f64 && f <= i64::MAX as f64 {
366 return Some((f as i64).to_string());
367 }
368 return Some(f.to_string());
369 }
370 }
371 if let Some(b) = v.as_bool() {
372 return Some(b.to_string());
373 }
374 None
375}
376
377fn substitute_env_vars(s: &str) -> String {
379 let re = once_cell::sync::Lazy::force(&ENV_RE);
380 re.replace_all(s, |caps: ®ex::Captures| {
381 std::env::var(&caps[1]).unwrap_or_default()
382 })
383 .into_owned()
384}
385
386fn substitute_env_vars_in_value(v: &Value) -> Value {
387 match v {
388 Value::String(s) => Value::String(substitute_env_vars(s)),
389 Value::Object(map) => Value::Object(
390 map.iter()
391 .map(|(k, v)| (k.clone(), substitute_env_vars_in_value(v)))
392 .collect(),
393 ),
394 Value::Array(arr) => Value::Array(arr.iter().map(substitute_env_vars_in_value).collect()),
395 other => other.clone(),
396 }
397}
398
399static ENV_RE: once_cell::sync::Lazy<regex::Regex> =
400 once_cell::sync::Lazy::new(|| regex::Regex::new(r"\{\{([A-Za-z0-9_]+)\}\}").unwrap());