use std::{path::Path, sync::Arc, time::Duration};
use anyhow::{bail, Context, Result};
use arc_swap::ArcSwap;
use jsonpath_rust::JsonPathFinder;
use reqwest::{
cookie::Jar,
header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE},
};
use serde::Deserialize;
use serde_json::Value;
use tokio::{sync::RwLock, task::JoinHandle};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use crate::config::Config;
/// Where the credential extracted by an auth flow is placed on outgoing requests.
///
/// Deserialized with snake_case variant names: `"bearer"`, `"basic"`,
/// `{"header": "<header name>"}` or `{"cookie": "<cookie name>"}`.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InjectAs {
    /// Sent as `Authorization: Bearer <value>`.
    Bearer,
    /// Sent as `Authorization: Basic <base64(value)>`; the value is encoded as-is.
    Basic,
    /// Sent verbatim under the named header.
    Header(String),
    /// Sent as `<name>=<value>` inside the `Cookie` header.
    Cookie(String),
}
/// One HTTP request of an [`AuthFlow`].
///
/// `{{NAME}}` placeholders in `url`, in header values and in string leaves of
/// `body` are replaced with the environment variable `NAME` before sending.
#[derive(Clone, Debug, Deserialize)]
pub struct AuthStep {
    /// Request URL (placeholders substituted).
    pub url: String,
    /// HTTP method; defaults to `POST`.
    #[serde(default = "default_post")]
    pub method: String,
    /// Optional JSON request body (placeholders substituted in string values).
    pub body: Option<Value>,
    /// Extra request headers (placeholders substituted in values).
    #[serde(default)]
    pub headers: std::collections::HashMap<String, String>,
    /// Location of the credential in the JSON response: a JSON Pointer when it
    /// starts with `/`, otherwise a JSONPath expression. Only used together
    /// with `inject_as`.
    pub extract: Option<String>,
    /// Optional location of a value stored as `LiveCredential::refresh_value`
    /// (presumably a refresh token).
    pub extract_refresh: Option<String>,
    /// Optional location of the credential lifetime in seconds; it must render
    /// as an unsigned integer to be used.
    pub extract_expires_in: Option<String>,
    /// How the extracted credential is attached to later requests.
    pub inject_as: Option<InjectAs>,
}
/// An ordered list of HTTP steps that obtains a credential, plus the fallback
/// refresh cadence.
#[derive(Clone, Debug, Deserialize)]
pub struct AuthFlow {
    /// Steps executed in order; a credential extracted by one step is attached
    /// to the requests of the steps after it.
    pub steps: Vec<AuthStep>,
    /// Fallback refresh interval in seconds, used when the extracting step
    /// yields no suitable `expires_in`. Defaults to 840 (14 minutes).
    #[serde(default = "default_refresh_secs")]
    pub refresh_interval_secs: u64,
}
/// Serde default for `AuthStep::method`.
fn default_post() -> String {
    String::from("POST")
}
/// Serde default for `AuthFlow::refresh_interval_secs`: 14 minutes.
fn default_refresh_secs() -> u64 {
    14 * 60
}
/// A credential produced by [`execute_flow`], shareable with request code and
/// updated in place by [`spawn_refresh_task`].
#[derive(Clone, Debug)]
pub struct LiveCredential {
    /// Current credential value; the refresh task swaps in new values
    /// atomically, so readers never block.
    pub value: Arc<ArcSwap<String>>,
    /// Value extracted via the step's `extract_refresh` path, if configured
    /// (presumably a refresh token).
    pub refresh_value: Option<Arc<RwLock<String>>>,
    /// How the value is attached to outgoing requests.
    pub inject_as: InjectAs,
    /// Seconds to wait before the next refresh. Despite the name this is the
    /// full refresh interval, not a lead time before expiry.
    pub refresh_lead_secs: u64,
}
/// Handle to the background task started by [`spawn_refresh_task`].
///
/// Stop the task with `RefreshTaskHandle::shutdown`. NOTE(review): no `Drop`
/// impl is visible in this module, and a dropped tokio `JoinHandle` detaches
/// its task, so dropping this handle without `shutdown` presumably leaves the
/// refresh task running.
#[derive(Debug)]
pub struct RefreshTaskHandle {
    /// Parent token; the task listens on a child of it.
    cancel: CancellationToken,
    /// The spawned refresh loop.
    task: JoinHandle<()>,
}
impl RefreshTaskHandle {
pub async fn shutdown(self) {
self.cancel.cancel();
let _ = self.task.await;
}
}
impl LiveCredential {
pub fn current(&self) -> String {
self.value.load().as_ref().clone()
}
pub fn apply_to(&self, map: &mut HeaderMap) {
let val = self.current();
match &self.inject_as {
InjectAs::Bearer => {
if let Ok(v) = HeaderValue::from_str(&format!("Bearer {val}")) {
map.insert(reqwest::header::AUTHORIZATION, v);
}
}
InjectAs::Basic => {
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
let encoded = STANDARD.encode(val.as_bytes());
if let Ok(v) = HeaderValue::from_str(&format!("Basic {encoded}")) {
map.insert(reqwest::header::AUTHORIZATION, v);
}
}
InjectAs::Header(name) => {
if let (Ok(k), Ok(v)) = (
HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_str(&val),
) {
map.insert(k, v);
}
}
InjectAs::Cookie(name) => {
let cookie = format!("{name}={val}");
let key = reqwest::header::COOKIE;
let merged = if let Some(existing) = map.get(&key) {
let existing = existing.to_str().unwrap_or("");
format!("{existing}; {cookie}")
} else {
cookie
};
if let Ok(v) = HeaderValue::from_str(&merged) {
map.insert(key, v);
}
}
}
}
}
/// Reads and parses an [`AuthFlow`] definition from a JSON file.
///
/// # Errors
/// Fails when the file cannot be read, or when its contents are not JSON
/// matching the `AuthFlow` schema. Both error messages name the file.
pub fn load_flow(path: &Path) -> Result<AuthFlow> {
    let content = std::fs::read_to_string(path)
        .with_context(|| format!("Cannot read auth flow file: {}", path.display()))?;
    // serde reports both syntax and schema problems here (e.g. a missing
    // `steps` field), so the message covers both and names the file.
    serde_json::from_str(&content).with_context(|| {
        format!(
            "Auth flow file is not valid JSON or does not match the auth flow schema: {}",
            path.display()
        )
    })
}
pub async fn execute_flow(flow: &AuthFlow, config: &Config) -> Result<LiveCredential> {
let jar = Arc::new(Jar::default());
let mut builder = reqwest::Client::builder()
.timeout(Duration::from_secs(config.politeness.timeout_secs))
.cookie_provider(Arc::clone(&jar));
if let Some(proxy_url) = &config.proxy {
let proxy = reqwest::Proxy::all(proxy_url).context("Invalid proxy URL in auth flow")?;
builder = builder.proxy(proxy);
}
builder = builder.danger_accept_invalid_certs(config.danger_accept_invalid_certs);
let client = builder.build().context("Failed to build auth client")?;
let mut last_credential: Option<LiveCredential> = None;
for (i, step) in flow.steps.iter().enumerate() {
info!(
"Auth flow step {}/{}: {} {}",
i + 1,
flow.steps.len(),
step.method,
step.url
);
let url = substitute_env_vars(&step.url);
let mut req = client.request(
step.method
.parse()
.context("Invalid HTTP method in auth flow")?,
&url,
);
for (k, v) in &step.headers {
if let (Ok(name), Ok(value)) = (
HeaderName::from_bytes(k.as_bytes()),
HeaderValue::from_str(&substitute_env_vars(v)),
) {
req = req.header(name, value);
}
}
if let Some(ref cred) = last_credential {
let mut map = HeaderMap::new();
cred.apply_to(&mut map);
req = req.headers(map);
}
if let Some(ref body) = step.body {
let substituted = substitute_env_vars_in_value(body);
req = req
.header(CONTENT_TYPE, "application/json")
.json(&substituted);
}
let resp = req.send().await.context("Auth flow request failed")?;
let status = resp.status().as_u16();
if status >= 400 {
bail!("Auth flow step {} returned HTTP {status}", i + 1);
}
let body: Value = resp
.json()
.await
.context("Auth flow response is not JSON. If your endpoint returns form-encoded data, see docs/auth-flow.md#non-json-responses")?;
debug!("Auth flow step {} response: {}", i + 1, body);
if let (Some(extract), Some(inject_as)) = (&step.extract, &step.inject_as) {
let token = extract_jsonpath(&body, extract).with_context(|| {
format!("JSONPath '{extract}' matched nothing in auth response")
})?;
let expires_in = step
.extract_expires_in
.as_ref()
.and_then(|p| extract_jsonpath(&body, p).ok())
.and_then(|v| v.parse::<u64>().ok());
let refresh_value = step
.extract_refresh
.as_ref()
.and_then(|p| extract_jsonpath(&body, p).ok())
.map(|v| Arc::new(RwLock::new(v)));
let refresh_interval = expires_in
.map(|e| e.saturating_sub(60)) .filter(|v| *v > 0)
.unwrap_or(flow.refresh_interval_secs);
info!("Auth flow: credential obtained (refresh in {refresh_interval}s)");
last_credential = Some(LiveCredential {
value: Arc::new(ArcSwap::from_pointee(token)),
refresh_value,
inject_as: inject_as.clone(),
refresh_lead_secs: refresh_interval,
});
}
}
last_credential.context("Auth flow completed but no credential was extracted")
}
/// Spawns a background task that periodically re-runs `flow` and publishes the
/// new credential value into `cred`.
///
/// The first refresh happens after `cred.refresh_lead_secs` seconds (at least 1).
/// After each successful refresh:
/// - the next delay comes from the freshly obtained credential, so a changed
///   `expires_in` is honoured;
/// - `cred.refresh_value` is updated when both the old and new credential
///   carry one.
///
/// The `cred.refresh_lead_secs` field itself cannot be mutated through the
/// `Arc`, so the task tracks the current interval internally.
///
/// After a failed refresh the existing token is kept, and the refresh is
/// retried after `min(interval, 30 s)` rather than a whole interval.
///
/// Cancellation via the returned handle is observed both while sleeping and
/// while a refresh is in flight.
///
/// # Panics
/// `tokio::spawn` panics when called outside a Tokio runtime.
pub fn spawn_refresh_task(
    flow: AuthFlow,
    cred: Arc<LiveCredential>,
    config: Config,
) -> RefreshTaskHandle {
    /// Delay before retrying a failed refresh.
    const RETRY_SECS: u64 = 30;

    let cancel = CancellationToken::new();
    let child_cancel = cancel.child_token();
    let task = tokio::spawn(async move {
        // Interval of the most recently obtained credential.
        let mut interval_secs = cred.refresh_lead_secs;
        // Delay before the next attempt (shortened after a failure).
        let mut sleep_secs = interval_secs;
        loop {
            tokio::select! {
                _ = child_cancel.cancelled() => {
                    info!("Auth flow: refresh task cancelled");
                    break;
                }
                _ = tokio::time::sleep(Duration::from_secs(sleep_secs.max(1))) => {}
            }
            info!("Auth flow: refreshing credential…");
            // Race the refresh against cancellation so shutdown does not wait for
            // a full flow (up to one timeout per step) to complete.
            let outcome = tokio::select! {
                _ = child_cancel.cancelled() => {
                    info!("Auth flow: refresh task cancelled");
                    break;
                }
                result = execute_flow(&flow, &config) => result,
            };
            match outcome {
                Ok(new_cred) => {
                    cred.value.store(Arc::new(new_cred.current()));
                    if let (Some(current), Some(fresh)) =
                        (&cred.refresh_value, &new_cred.refresh_value)
                    {
                        // Clone first so the read guard is released before writing.
                        let fresh_value = fresh.read().await.clone();
                        *current.write().await = fresh_value;
                    }
                    interval_secs = new_cred.refresh_lead_secs;
                    sleep_secs = interval_secs;
                    info!("Auth flow: credential refreshed successfully");
                }
                Err(e) => {
                    // `{e:#}` includes the whole context chain (e.g. the HTTP cause).
                    warn!("Auth flow: refresh failed — {e:#}. Continuing with existing token.");
                    sleep_secs = interval_secs.min(RETRY_SECS);
                }
            }
        }
    });
    RefreshTaskHandle { cancel, task }
}
/// Extracts a scalar from `doc` and renders it as a string.
///
/// `path` is treated as an RFC 6901 JSON Pointer when it starts with `/`, and
/// as a JSONPath expression otherwise. For JSONPath, only the first match is
/// used.
///
/// # Errors
/// Fails when the path matches nothing, the match is not a string, number or
/// bool, or the JSONPath expression is invalid.
fn extract_jsonpath(doc: &Value, path: &str) -> Result<String> {
    if path.starts_with('/') {
        // Distinguish "not found" from "found but not a scalar"; previously both
        // were reported as "matched nothing".
        let node = doc
            .pointer(path)
            .with_context(|| format!("JSON Pointer '{path}' matched nothing"))?;
        return json_scalar_to_string(node)
            .context("JSON Pointer result is not a scalar (string/number/bool)");
    }
    let finder = JsonPathFinder::from_str(&doc.to_string(), path)
        .map_err(|e| anyhow::anyhow!("JSONPath error: {e}"))?;
    let found = finder.find();
    if let Value::Array(matches) = &found {
        if let Some(first) = matches.first() {
            return json_scalar_to_string(first)
                .context("JSONPath result is not a scalar (string/number/bool)");
        }
    }
    bail!("JSONPath '{path}' matched nothing in response")
}
fn json_scalar_to_string(v: &Value) -> Option<String> {
if let Some(s) = v.as_str() {
return Some(s.to_string());
}
if let Some(i) = v.as_i64() {
return Some(i.to_string());
}
if let Some(u) = v.as_u64() {
return Some(u.to_string());
}
if let Some(f) = v.as_f64() {
if f.is_finite() {
if f.fract() == 0.0 && f >= i64::MIN as f64 && f <= i64::MAX as f64 {
return Some((f as i64).to_string());
}
return Some(f.to_string());
}
}
if let Some(b) = v.as_bool() {
return Some(b.to_string());
}
None
}
fn substitute_env_vars(s: &str) -> String {
let re = once_cell::sync::Lazy::force(&ENV_RE);
re.replace_all(s, |caps: ®ex::Captures| {
std::env::var(&caps[1]).unwrap_or_default()
})
.into_owned()
}
fn substitute_env_vars_in_value(v: &Value) -> Value {
match v {
Value::String(s) => Value::String(substitute_env_vars(s)),
Value::Object(map) => Value::Object(
map.iter()
.map(|(k, v)| (k.clone(), substitute_env_vars_in_value(v)))
.collect(),
),
Value::Array(arr) => Value::Array(arr.iter().map(substitute_env_vars_in_value).collect()),
other => other.clone(),
}
}
/// Matches `{{NAME}}` placeholders, where `NAME` is ASCII letters, digits or
/// `_`. Capture group 1 is the variable name.
static ENV_RE: once_cell::sync::Lazy<regex::Regex> = once_cell::sync::Lazy::new(|| {
    regex::Regex::new(r"\{\{([A-Za-z0-9_]+)\}\}")
        .expect("ENV_RE is a constant, known-valid regex pattern")
});