Skip to main content

api_scanner/
auth.rs

1// src/auth.rs
2//
3// Authentication flow: load a JSON flow descriptor, execute the login
4// sequence, extract credentials, and optionally refresh them mid-scan.
5
6use std::{path::Path, sync::Arc, time::Duration};
7
8use anyhow::{bail, Context, Result};
9use arc_swap::ArcSwap;
10use jsonpath_rust::JsonPathFinder;
11use reqwest::{
12    cookie::Jar,
13    header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE},
14};
15use serde::Deserialize;
16use serde_json::Value;
17use tokio::{sync::RwLock, task::JoinHandle};
18use tokio_util::sync::CancellationToken;
19use tracing::{debug, info, warn};
20
21use crate::config::Config;
22
23// ── Flow descriptor ───────────────────────────────────────────────────────────
24
25/// How the extracted credential is injected into every request.
26#[derive(Debug, Clone, Deserialize)]
27#[serde(rename_all = "snake_case")]
28pub enum InjectAs {
29    /// Sets `Authorization: Bearer <value>`
30    Bearer,
31    /// Sets `Authorization: Basic <base64(value)>`  (value must be "user:pass")
32    Basic,
33    /// Sets a named header: `header_name: <value>`
34    Header(String),
35    /// Adds the value as a cookie: `cookie_name=<value>`
36    Cookie(String),
37}
38
39/// One HTTP step in the auth flow.
40#[derive(Debug, Clone, Deserialize)]
41pub struct AuthStep {
42    /// Full URL to hit.
43    pub url: String,
44    /// HTTP method (GET, POST, PUT…). Default: POST.
45    #[serde(default = "default_post")]
46    pub method: String,
47    /// Optional JSON request body. Supports `{{ENV_VAR}}` substitution.
48    pub body: Option<Value>,
49    /// Optional extra headers for this step only.
50    #[serde(default)]
51    pub headers: std::collections::HashMap<String, String>,
52    /// JSONPath expression to extract the credential value from the response.
53    /// e.g. `"$.data.access_token"` or `"$.access_token"`
54    pub extract: Option<String>,
55    /// Where to extract the refresh token (optional).
56    pub extract_refresh: Option<String>,
57    /// Where to extract expiry in seconds from now (optional).
58    /// e.g. `"$.expires_in"` — if absent, refresh_interval_secs is used.
59    pub extract_expires_in: Option<String>,
60    /// How to inject the extracted value into all subsequent requests.
61    pub inject_as: Option<InjectAs>,
62}
63
64/// Top-level auth flow descriptor — loaded from `--auth-flow <file>`.
65#[derive(Debug, Clone, Deserialize)]
66pub struct AuthFlow {
67    /// Ordered list of HTTP steps. Usually 1 step (POST /login).
68    pub steps: Vec<AuthStep>,
69    /// How often (seconds) to refresh the token proactively.
70    /// If `extract_expires_in` is set, that value overrides this.
71    #[serde(default = "default_refresh_secs")]
72    pub refresh_interval_secs: u64,
73}
74
/// Serde default for `AuthStep::method`: the literal `POST`.
fn default_post() -> String {
    String::from("POST")
}
/// Serde default for `AuthFlow::refresh_interval_secs`.
///
/// 14 minutes — safe for most 15-min tokens.
fn default_refresh_secs() -> u64 {
    const FOURTEEN_MINUTES_SECS: u64 = 14 * 60;
    FOURTEEN_MINUTES_SECS
}
81
82// ── Loaded / live credential ──────────────────────────────────────────────────
83
84/// The resolved, live credential produced by executing the flow.
85#[derive(Debug, Clone)]
86pub struct LiveCredential {
87    /// The primary credential value (token, cookie value…).
88    pub value: Arc<ArcSwap<String>>,
89    /// Optional refresh token.
90    pub refresh_value: Option<Arc<RwLock<String>>>,
91    /// How to apply it.
92    pub inject_as: InjectAs,
93    /// For token refresh: seconds before expiry to trigger refresh.
94    pub refresh_lead_secs: u64,
95}
96
97/// Handle for a spawned auth refresh background task.
98#[derive(Debug)]
99pub struct RefreshTaskHandle {
100    cancel: CancellationToken,
101    task: JoinHandle<()>,
102}
103
104impl RefreshTaskHandle {
105    /// Signal cancellation and wait for the task to stop.
106    pub async fn shutdown(self) {
107        self.cancel.cancel();
108        let _ = self.task.await;
109    }
110}
111
112impl LiveCredential {
113    /// Read the current token value.
114    pub fn current(&self) -> String {
115        self.value.load().as_ref().clone()
116    }
117
118    /// Apply this credential to a HeaderMap (called per-request in HttpClient).
119    pub fn apply_to(&self, map: &mut HeaderMap) {
120        let val = self.current();
121        match &self.inject_as {
122            InjectAs::Bearer => {
123                if let Ok(v) = HeaderValue::from_str(&format!("Bearer {val}")) {
124                    map.insert(reqwest::header::AUTHORIZATION, v);
125                }
126            }
127            InjectAs::Basic => {
128                use base64::engine::general_purpose::STANDARD;
129                use base64::Engine;
130                let encoded = STANDARD.encode(val.as_bytes());
131                if let Ok(v) = HeaderValue::from_str(&format!("Basic {encoded}")) {
132                    map.insert(reqwest::header::AUTHORIZATION, v);
133                }
134            }
135            InjectAs::Header(name) => {
136                if let (Ok(k), Ok(v)) = (
137                    HeaderName::from_bytes(name.as_bytes()),
138                    HeaderValue::from_str(&val),
139                ) {
140                    map.insert(k, v);
141                }
142            }
143            InjectAs::Cookie(name) => {
144                let cookie = format!("{name}={val}");
145                // Merge with existing Cookie header if present
146                let key = reqwest::header::COOKIE;
147                let merged = if let Some(existing) = map.get(&key) {
148                    let existing = existing.to_str().unwrap_or("");
149                    format!("{existing}; {cookie}")
150                } else {
151                    cookie
152                };
153                if let Ok(v) = HeaderValue::from_str(&merged) {
154                    map.insert(key, v);
155                }
156            }
157        }
158    }
159}
160
161// ── Auth flow loader ──────────────────────────────────────────────────────────
162
163pub fn load_flow(path: &Path) -> Result<AuthFlow> {
164    let content = std::fs::read_to_string(path)
165        .with_context(|| format!("Cannot read auth flow file: {}", path.display()))?;
166    serde_json::from_str(&content).with_context(|| "Auth flow file is not valid JSON")
167}
168
169// ── Flow executor ─────────────────────────────────────────────────────────────
170
171/// Execute all steps in the auth flow using a plain reqwest client
172/// (not the scanner's HttpClient, to avoid circular dependency).
173/// Returns the live credential ready for injection.
174pub async fn execute_flow(flow: &AuthFlow, config: &Config) -> Result<LiveCredential> {
175    let jar = Arc::new(Jar::default());
176    let mut builder = reqwest::Client::builder()
177        .timeout(Duration::from_secs(config.politeness.timeout_secs))
178        .cookie_provider(Arc::clone(&jar));
179
180    // Apply proxy settings if configured
181    if let Some(proxy_url) = &config.proxy {
182        let proxy = reqwest::Proxy::all(proxy_url).context("Invalid proxy URL in auth flow")?;
183        builder = builder.proxy(proxy);
184    }
185
186    // Apply TLS settings
187    builder = builder.danger_accept_invalid_certs(config.danger_accept_invalid_certs);
188
189    let client = builder.build().context("Failed to build auth client")?;
190
191    let mut last_credential: Option<LiveCredential> = None;
192
193    for (i, step) in flow.steps.iter().enumerate() {
194        info!(
195            "Auth flow step {}/{}: {} {}",
196            i + 1,
197            flow.steps.len(),
198            step.method,
199            step.url
200        );
201
202        let url = substitute_env_vars(&step.url);
203
204        let mut req = client.request(
205            step.method
206                .parse()
207                .context("Invalid HTTP method in auth flow")?,
208            &url,
209        );
210
211        // Apply step-level headers
212        for (k, v) in &step.headers {
213            if let (Ok(name), Ok(value)) = (
214                HeaderName::from_bytes(k.as_bytes()),
215                HeaderValue::from_str(&substitute_env_vars(v)),
216            ) {
217                req = req.header(name, value);
218            }
219        }
220
221        // Apply the previous step's credential to subsequent steps
222        if let Some(ref cred) = last_credential {
223            let mut map = HeaderMap::new();
224            cred.apply_to(&mut map);
225            req = req.headers(map);
226        }
227
228        if let Some(ref body) = step.body {
229            let substituted = substitute_env_vars_in_value(body);
230            req = req
231                .header(CONTENT_TYPE, "application/json")
232                .json(&substituted);
233        }
234
235        let resp = req.send().await.context("Auth flow request failed")?;
236        let status = resp.status().as_u16();
237
238        if status >= 400 {
239            bail!("Auth flow step {} returned HTTP {status}", i + 1);
240        }
241
242        // Note: this implementation expects JSON responses from all steps.
243        // application/x-www-form-urlencoded token responses are not supported.
244        let body: Value = resp
245            .json()
246            .await
247            .context("Auth flow response is not JSON. If your endpoint returns form-encoded data, see docs/auth-flow.md#non-json-responses")?;
248        debug!("Auth flow step {} response: {}", i + 1, body);
249
250        if let (Some(extract), Some(inject_as)) = (&step.extract, &step.inject_as) {
251            let token = extract_jsonpath(&body, extract).with_context(|| {
252                format!("JSONPath '{extract}' matched nothing in auth response")
253            })?;
254
255            let expires_in = step
256                .extract_expires_in
257                .as_ref()
258                .and_then(|p| extract_jsonpath(&body, p).ok())
259                .and_then(|v| v.parse::<u64>().ok());
260
261            let refresh_value = step
262                .extract_refresh
263                .as_ref()
264                .and_then(|p| extract_jsonpath(&body, p).ok())
265                .map(|v| Arc::new(RwLock::new(v)));
266
267            let refresh_interval = expires_in
268                .map(|e| e.saturating_sub(60)) // refresh 60s before expiry
269                .filter(|v| *v > 0)
270                .unwrap_or(flow.refresh_interval_secs);
271
272            info!("Auth flow: credential obtained (refresh in {refresh_interval}s)");
273
274            last_credential = Some(LiveCredential {
275                value: Arc::new(ArcSwap::from_pointee(token)),
276                refresh_value,
277                inject_as: inject_as.clone(),
278                refresh_lead_secs: refresh_interval,
279            });
280        }
281    }
282
283    last_credential.context("Auth flow completed but no credential was extracted")
284}
285
286// ── Token refresh background task ─────────────────────────────────────────────
287
288/// Spawn a background task that re-executes the auth flow before the token
289/// expires. Writes the new token into `cred.value` so all in-flight requests
290/// automatically pick it up on the next read.
291pub fn spawn_refresh_task(
292    flow: AuthFlow,
293    cred: Arc<LiveCredential>,
294    config: Config,
295) -> RefreshTaskHandle {
296    let cancel = CancellationToken::new();
297    let child_cancel = cancel.child_token();
298
299    let task = tokio::spawn(async move {
300        loop {
301            let sleep_secs = cred.refresh_lead_secs.max(1);
302            tokio::select! {
303                _ = child_cancel.cancelled() => {
304                    info!("Auth flow: refresh task cancelled");
305                    break;
306                }
307                _ = tokio::time::sleep(Duration::from_secs(sleep_secs)) => {}
308            }
309
310            info!("Auth flow: refreshing credential…");
311            match execute_flow(&flow, &config).await {
312                Ok(new_cred) => {
313                    let new_val = new_cred.current();
314                    cred.value.store(Arc::new(new_val));
315                    info!("Auth flow: credential refreshed successfully");
316                }
317                Err(e) => {
318                    warn!("Auth flow: refresh failed — {e}. Continuing with existing token.");
319                }
320            }
321        }
322    });
323
324    RefreshTaskHandle { cancel, task }
325}
326
327// ── Helpers ───────────────────────────────────────────────────────────────────
328
329/// Extract a value from a JSON document using a JSONPath expression.
330/// Supports both `$.foo.bar` (dot notation) and `/foo/bar` (JSON Pointer).
331fn extract_jsonpath(doc: &Value, path: &str) -> Result<String> {
332    // Fast path: JSON Pointer (RFC 6901)
333    if path.starts_with('/') {
334        return doc
335            .pointer(path)
336            .and_then(json_scalar_to_string)
337            .context("JSON Pointer matched nothing");
338    }
339
340    // JSONPath: $.foo.bar style
341    let finder = JsonPathFinder::from_str(&doc.to_string(), path)
342        .map_err(|e| anyhow::anyhow!("JSONPath error: {e}"))?;
343    let first = finder.find();
344    if let Value::Array(arr) = &first {
345        if let Some(v) = arr.first() {
346            return json_scalar_to_string(v)
347                .context("JSONPath result is not a scalar (string/number/bool)");
348        }
349    }
350    bail!("JSONPath '{path}' matched nothing in response")
351}
352
353fn json_scalar_to_string(v: &Value) -> Option<String> {
354    if let Some(s) = v.as_str() {
355        return Some(s.to_string());
356    }
357    if let Some(i) = v.as_i64() {
358        return Some(i.to_string());
359    }
360    if let Some(u) = v.as_u64() {
361        return Some(u.to_string());
362    }
363    if let Some(f) = v.as_f64() {
364        if f.is_finite() {
365            if f.fract() == 0.0 && f >= i64::MIN as f64 && f <= i64::MAX as f64 {
366                return Some((f as i64).to_string());
367            }
368            return Some(f.to_string());
369        }
370    }
371    if let Some(b) = v.as_bool() {
372        return Some(b.to_string());
373    }
374    None
375}
376
377/// Replace `{{ENV_VAR}}` placeholders with environment variable values.
378fn substitute_env_vars(s: &str) -> String {
379    let re = once_cell::sync::Lazy::force(&ENV_RE);
380    re.replace_all(s, |caps: &regex::Captures| {
381        std::env::var(&caps[1]).unwrap_or_default()
382    })
383    .into_owned()
384}
385
386fn substitute_env_vars_in_value(v: &Value) -> Value {
387    match v {
388        Value::String(s) => Value::String(substitute_env_vars(s)),
389        Value::Object(map) => Value::Object(
390            map.iter()
391                .map(|(k, v)| (k.clone(), substitute_env_vars_in_value(v)))
392                .collect(),
393        ),
394        Value::Array(arr) => Value::Array(arr.iter().map(substitute_env_vars_in_value).collect()),
395        other => other.clone(),
396    }
397}
398
399static ENV_RE: once_cell::sync::Lazy<regex::Regex> =
400    once_cell::sync::Lazy::new(|| regex::Regex::new(r"\{\{([A-Za-z0-9_]+)\}\}").unwrap());