Skip to main content

assay_workflow/api/
auth.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::{IntoResponse, Response};
8use axum::Json;
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{DecodingKey, Validation};
11use sha2::{Digest, Sha256};
12use tokio::sync::RwLock;
13use tracing::{debug, info, warn};
14
15use crate::api::AppState;
16use crate::store::WorkflowStore;
17
/// How long a fetched JWKS is reused before it is fetched again (five minutes).
const JWKS_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
19
20// ── Auth Mode ───────────────────────────────────────────────
21
22/// Auth configuration for the engine's HTTP API.
23///
24/// Both authentication methods (JWT and API key) can be enabled at the same time.
25/// When both are enabled, the middleware dispatches on token shape — tokens that
26/// parse as a JWS header are validated as JWTs, everything else is validated as
27/// an API key. This lets the same server accept long-lived machine API keys
28/// alongside short-lived OIDC-issued user tokens without the caller picking a
29/// mode up front.
30#[derive(Clone, Debug, Default)]
31pub struct AuthMode {
32    /// API-key authentication enabled. When true, Bearer tokens that are not
33    /// JWT-shaped are validated against the `api_keys` table.
34    pub api_key: bool,
35    /// JWT authentication enabled. When set, Bearer tokens that parse as a
36    /// JWS header are validated against the issuer's JWKS.
37    pub jwt: Option<JwtConfig>,
38}
39
40/// JWT validation configuration.
41#[derive(Clone, Debug)]
42pub struct JwtConfig {
43    pub issuer: String,
44    pub audience: Option<String>,
45    pub jwks_cache: Arc<JwksCache>,
46}
47
48impl AuthMode {
49    /// Open access — no authentication. All requests allowed.
50    pub fn no_auth() -> Self {
51        Self::default()
52    }
53
54    /// JWT/OIDC only. Tokens are validated against the issuer's JWKS.
55    pub fn jwt(issuer: String, audience: Option<String>) -> Self {
56        Self {
57            api_key: false,
58            jwt: Some(JwtConfig::new(issuer, audience)),
59        }
60    }
61
62    /// API key only. Bearer tokens are hashed and looked up in the store.
63    pub fn api_key() -> Self {
64        Self {
65            api_key: true,
66            jwt: None,
67        }
68    }
69
70    /// Both JWT and API key. Tokens that parse as JWTs take the JWT path;
71    /// everything else takes the API-key path.
72    pub fn combined(issuer: String, audience: Option<String>) -> Self {
73        Self {
74            api_key: true,
75            jwt: Some(JwtConfig::new(issuer, audience)),
76        }
77    }
78
79    /// True if any authentication method is enabled.
80    pub fn is_enabled(&self) -> bool {
81        self.api_key || self.jwt.is_some()
82    }
83
84    /// Human-readable summary for startup logging.
85    pub fn describe(&self) -> String {
86        match (self.jwt.as_ref(), self.api_key) {
87            (None, false) => "no-auth (open access)".to_string(),
88            (None, true) => "api-key".to_string(),
89            (Some(c), false) => format!("jwt (issuer: {})", c.issuer),
90            (Some(c), true) => format!("jwt (issuer: {}) + api-key", c.issuer),
91        }
92    }
93}
94
95impl JwtConfig {
96    /// Build a JwtConfig with a fresh JWKS cache pointed at `issuer`'s OIDC discovery endpoint.
97    pub fn new(issuer: String, audience: Option<String>) -> Self {
98        Self {
99            jwks_cache: Arc::new(JwksCache::new(issuer.clone())),
100            issuer,
101            audience,
102        }
103    }
104}
105
106// ── JWKS Cache ──────────────────────────────────────────────
107
108/// Caches JWKS keys fetched from the OIDC provider.
109/// Keys are refreshed after `JWKS_CACHE_TTL` or on cache miss for a specific `kid`.
110pub struct JwksCache {
111    issuer: String,
112    cache: RwLock<Option<CachedJwks>>,
113    http: reqwest::Client,
114}
115
116struct CachedJwks {
117    jwks: JwkSet,
118    fetched_at: Instant,
119}
120
121impl std::fmt::Debug for JwksCache {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("JwksCache")
124            .field("issuer", &self.issuer)
125            .finish()
126    }
127}
128
129impl JwksCache {
130    pub fn new(issuer: String) -> Self {
131        Self {
132            issuer,
133            cache: RwLock::new(None),
134            http: reqwest::Client::builder()
135                .timeout(Duration::from_secs(10))
136                .build()
137                .expect("building JWKS HTTP client"),
138        }
139    }
140
141    /// For testing: create a cache pre-loaded with keys (no HTTP fetching needed).
142    pub fn with_jwks(issuer: String, jwks: JwkSet) -> Self {
143        Self {
144            issuer,
145            cache: RwLock::new(Some(CachedJwks {
146                jwks,
147                fetched_at: Instant::now(),
148            })),
149            http: reqwest::Client::new(),
150        }
151    }
152
153    /// Get the JWKS, fetching from the provider if the cache is stale or empty.
154    async fn get_jwks(&self) -> anyhow::Result<JwkSet> {
155        // Check cache
156        {
157            let cache = self.cache.read().await;
158            if let Some(ref cached) = *cache
159                && cached.fetched_at.elapsed() < JWKS_CACHE_TTL
160            {
161                return Ok(cached.jwks.clone());
162            }
163        }
164
165        // Cache miss or stale — fetch fresh JWKS
166        self.refresh().await
167    }
168
169    /// Force-refresh the JWKS (e.g., when a kid is not found in the current set).
170    async fn refresh(&self) -> anyhow::Result<JwkSet> {
171        let jwks_uri = self.discover_jwks_uri().await?;
172        debug!("Fetching JWKS from {jwks_uri}");
173
174        let jwks: JwkSet = self.http.get(&jwks_uri).send().await?.json().await?;
175        info!(
176            "Fetched {} keys from JWKS endpoint",
177            jwks.keys.len()
178        );
179
180        let mut cache = self.cache.write().await;
181        *cache = Some(CachedJwks {
182            jwks: jwks.clone(),
183            fetched_at: Instant::now(),
184        });
185
186        Ok(jwks)
187    }
188
189    /// Discover the JWKS URI from the OIDC discovery endpoint.
190    async fn discover_jwks_uri(&self) -> anyhow::Result<String> {
191        let discovery_url = format!(
192            "{}/.well-known/openid-configuration",
193            self.issuer.trim_end_matches('/')
194        );
195
196        let resp: serde_json::Value = self
197            .http
198            .get(&discovery_url)
199            .send()
200            .await?
201            .json()
202            .await?;
203
204        resp.get("jwks_uri")
205            .and_then(|v| v.as_str())
206            .map(String::from)
207            .ok_or_else(|| anyhow::anyhow!("OIDC discovery response missing jwks_uri"))
208    }
209
210    /// Find a decoding key by `kid` (key ID) from the cached JWKS.
211    /// If the kid isn't found, refreshes the cache once and retries.
212    async fn find_key(&self, kid: &str) -> anyhow::Result<DecodingKey> {
213        let jwks = self.get_jwks().await?;
214
215        // Try to find by kid
216        if let Some(key) = find_key_in_set(&jwks, kid) {
217            return Ok(key);
218        }
219
220        // kid not found — refresh and retry (key rotation)
221        debug!("kid '{kid}' not in JWKS cache, refreshing");
222        let jwks = self.refresh().await?;
223
224        find_key_in_set(&jwks, kid)
225            .ok_or_else(|| anyhow::anyhow!("No key with kid '{kid}' in JWKS"))
226    }
227
228    /// Find a decoding key when no kid is provided (use the first matching key).
229    async fn find_any_key(&self, alg: jsonwebtoken::Algorithm) -> anyhow::Result<DecodingKey> {
230        let jwks = self.get_jwks().await?;
231
232        for key in &jwks.keys {
233            if let Ok(dk) = DecodingKey::from_jwk(key) {
234                // Use the first key that decodes successfully — if the JWK
235                // specifies an algorithm, the jsonwebtoken Validation will
236                // catch mismatches during decode.
237                let _ = alg; // algorithm check happens at decode time
238                return Ok(dk);
239            }
240        }
241
242        anyhow::bail!("No suitable key found in JWKS for algorithm {alg:?}")
243    }
244}
245
246fn find_key_in_set(jwks: &JwkSet, kid: &str) -> Option<DecodingKey> {
247    jwks.keys
248        .iter()
249        .find(|k| k.common.key_id.as_deref() == Some(kid))
250        .and_then(|k| DecodingKey::from_jwk(k).ok())
251}
252
253// ── Middleware ───────────────────────────────────────────────
254
255/// Axum middleware that enforces authentication based on the configured mode.
256///
257/// When both JWT and API-key auth are enabled, dispatch is based on token shape:
258/// if the Bearer token parses as a JWS header it takes the JWT path, otherwise
259/// the API-key path. A semantically-invalid JWT (expired, forged signature, wrong
260/// audience) is rejected and is *not* retried as an API key — a token that looks
261/// like a JWT is treated as a JWT.
262///
263/// **Bootstrap window:** `POST /api/v1/api-keys` is accepted without a Bearer
264/// token iff the `api_keys` table is empty. This is the only way a freshly
265/// deployed server running in API-key or combined mode can receive its first
266/// credential without operator shell access. The window closes the moment any
267/// key exists.
268pub async fn auth_middleware<S: WorkflowStore>(
269    State(state): State<Arc<AppState<S>>>,
270    request: Request,
271    next: Next,
272) -> Response {
273    let auth = &state.auth_mode;
274
275    if !auth.is_enabled() {
276        return next.run(request).await;
277    }
278
279    if is_bootstrap_request(&request) {
280        match state.engine.store().api_keys_empty().await {
281            Ok(true) => {
282                info!(
283                    "Allowing unauthenticated POST /api/v1/api-keys — api_keys table is empty (bootstrap window)"
284                );
285                return next.run(request).await;
286            }
287            Ok(false) => {
288                // Fall through to normal auth — bootstrap window is closed.
289            }
290            Err(e) => {
291                warn!("api_keys_empty check failed: {e}");
292                return (
293                    StatusCode::INTERNAL_SERVER_ERROR,
294                    Json(serde_json::json!({"error": "auth bootstrap check failed"})),
295                )
296                    .into_response();
297            }
298        }
299    }
300
301    let token = match extract_bearer(&request) {
302        Some(t) => t,
303        None => return auth_error("Missing Authorization: Bearer <token>"),
304    };
305
306    if jsonwebtoken::decode_header(token).is_ok() {
307        match &auth.jwt {
308            Some(jwt) => {
309                validate_jwt(
310                    &jwt.issuer,
311                    jwt.audience.as_deref(),
312                    &jwt.jwks_cache,
313                    request,
314                    next,
315                )
316                .await
317            }
318            None => auth_error("JWT authentication is not enabled on this server"),
319        }
320    } else if auth.api_key {
321        validate_api_key(state, request, next).await
322    } else {
323        auth_error("Token is not a valid JWT and API-key authentication is not enabled")
324    }
325}
326
327async fn validate_api_key<S: WorkflowStore>(
328    state: Arc<AppState<S>>,
329    request: Request,
330    next: Next,
331) -> Response {
332    let token = match extract_bearer(&request) {
333        Some(t) => t,
334        None => return auth_error("Missing Authorization: Bearer <api-key>"),
335    };
336
337    let hash = hash_api_key(token);
338    match state.engine.store().validate_api_key(&hash).await {
339        Ok(true) => next.run(request).await,
340        Ok(false) => {
341            warn!(
342                "Invalid API key (prefix: {}...)",
343                &token[..8.min(token.len())]
344            );
345            auth_error("Invalid API key")
346        }
347        Err(e) => {
348            warn!("API key validation error: {e}");
349            (
350                StatusCode::INTERNAL_SERVER_ERROR,
351                Json(serde_json::json!({"error": "auth check failed"})),
352            )
353                .into_response()
354        }
355    }
356}
357
358async fn validate_jwt(
359    issuer: &str,
360    audience: Option<&str>,
361    jwks_cache: &JwksCache,
362    request: Request,
363    next: Next,
364) -> Response {
365    let token = match extract_bearer(&request) {
366        Some(t) => t,
367        None => return auth_error("Missing Authorization: Bearer <jwt>"),
368    };
369
370    // Decode header to get algorithm and kid
371    let header = match jsonwebtoken::decode_header(token) {
372        Ok(h) => h,
373        Err(e) => {
374            warn!("Invalid JWT header: {e}");
375            return auth_error("Invalid JWT");
376        }
377    };
378
379    // Find the decoding key from JWKS
380    let decoding_key = match &header.kid {
381        Some(kid) => match jwks_cache.find_key(kid).await {
382            Ok(key) => key,
383            Err(e) => {
384                warn!("JWKS key lookup failed: {e}");
385                return auth_error("JWT validation failed: key not found");
386            }
387        },
388        None => match jwks_cache.find_any_key(header.alg).await {
389            Ok(key) => key,
390            Err(e) => {
391                warn!("JWKS key lookup failed (no kid): {e}");
392                return auth_error("JWT validation failed: no suitable key");
393            }
394        },
395    };
396
397    // Build validation rules
398    let mut validation = Validation::new(header.alg);
399    validation.set_issuer(&[issuer]);
400    if let Some(aud) = audience {
401        validation.set_audience(&[aud]);
402    } else {
403        validation.validate_aud = false;
404    }
405
406    // Validate signature + claims
407    match jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation) {
408        Ok(_) => next.run(request).await,
409        Err(e) => {
410            warn!("JWT validation failed: {e}");
411            auth_error(&format!("JWT validation failed: {e}"))
412        }
413    }
414}
415
416fn extract_bearer(request: &Request) -> Option<&str> {
417    request
418        .headers()
419        .get("authorization")
420        .and_then(|v| v.to_str().ok())
421        .and_then(|v| v.strip_prefix("Bearer "))
422}
423
424/// True iff the request is the bootstrap-window endpoint
425/// (`POST /api/v1/api-keys`). The caller is still responsible for checking
426/// that the `api_keys` table is empty before actually allowing unauth access.
427fn is_bootstrap_request(request: &Request) -> bool {
428    request.method() == axum::http::Method::POST
429        && request.uri().path() == "/api/v1/api-keys"
430}
431
432fn auth_error(msg: &str) -> Response {
433    (
434        StatusCode::UNAUTHORIZED,
435        Json(serde_json::json!({"error": msg})),
436    )
437        .into_response()
438}
439
440// ── API Key Helpers ─────────────────────────────────────────
441
442/// Hash an API key with SHA-256 for storage/lookup.
443pub fn hash_api_key(key: &str) -> String {
444    let mut hasher = Sha256::new();
445    hasher.update(key.as_bytes());
446    data_encoding::HEXLOWER.encode(&hasher.finalize())
447}
448
449/// Generate a new random API key (32 bytes, hex-encoded).
450pub fn generate_api_key() -> String {
451    use rand::Rng;
452    let bytes: [u8; 32] = rand::rng().random();
453    format!("assay_{}", data_encoding::HEXLOWER.encode(&bytes))
454}
455
/// Returns a display-safe prefix of `key`: `assay_`, then the first 8
/// characters after an optional leading `assay_`, then `...`.
///
/// Truncation counts characters rather than bytes, so multibyte input never
/// panics. (Slicing at byte 8 would panic when byte 8 falls inside a character.)
pub fn key_prefix(key: &str) -> String {
    let stripped = key.strip_prefix("assay_").unwrap_or(key);
    // Byte offset of the 9th character, or the whole string if it is shorter.
    let end = stripped
        .char_indices()
        .nth(8)
        .map_or(stripped.len(), |(idx, _)| idx);
    format!("assay_{}...", &stripped[..end])
}