Skip to main content

assay_workflow/api/
auth.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::{IntoResponse, Response};
8use axum::Json;
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{DecodingKey, Validation};
11use sha2::{Digest, Sha256};
12use tokio::sync::RwLock;
13use tracing::{debug, info, warn};
14
15use crate::api::AppState;
16use crate::store::WorkflowStore;
17
/// How long a fetched JWKS is served from cache before it is fetched again.
const JWKS_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
19
20// ── Auth Mode ───────────────────────────────────────────────
21
22/// Auth configuration for the engine's HTTP API.
23///
24/// Both authentication methods (JWT and API key) can be enabled at the same time.
25/// When both are enabled, the middleware dispatches on token shape — tokens that
26/// parse as a JWS header are validated as JWTs, everything else is validated as
27/// an API key. This lets the same server accept long-lived machine API keys
28/// alongside short-lived OIDC-issued user tokens without the caller picking a
29/// mode up front.
30#[derive(Clone, Debug, Default)]
31pub struct AuthMode {
32    /// API-key authentication enabled. When true, Bearer tokens that are not
33    /// JWT-shaped are validated against the `api_keys` table.
34    pub api_key: bool,
35    /// JWT authentication enabled. When set, Bearer tokens that parse as a
36    /// JWS header are validated against the issuer's JWKS.
37    pub jwt: Option<JwtConfig>,
38}
39
40/// JWT validation configuration.
41#[derive(Clone, Debug)]
42pub struct JwtConfig {
43    pub issuer: String,
44    pub audience: Option<String>,
45    pub jwks_cache: Arc<JwksCache>,
46}
47
48impl AuthMode {
49    /// Open access — no authentication. All requests allowed.
50    pub fn no_auth() -> Self {
51        Self::default()
52    }
53
54    /// JWT/OIDC only. Tokens are validated against the issuer's JWKS.
55    pub fn jwt(issuer: String, audience: Option<String>) -> Self {
56        Self {
57            api_key: false,
58            jwt: Some(JwtConfig::new(issuer, audience)),
59        }
60    }
61
62    /// API key only. Bearer tokens are hashed and looked up in the store.
63    pub fn api_key() -> Self {
64        Self {
65            api_key: true,
66            jwt: None,
67        }
68    }
69
70    /// Both JWT and API key. Tokens that parse as JWTs take the JWT path;
71    /// everything else takes the API-key path.
72    pub fn combined(issuer: String, audience: Option<String>) -> Self {
73        Self {
74            api_key: true,
75            jwt: Some(JwtConfig::new(issuer, audience)),
76        }
77    }
78
79    /// True if any authentication method is enabled.
80    pub fn is_enabled(&self) -> bool {
81        self.api_key || self.jwt.is_some()
82    }
83
84    /// Human-readable summary for startup logging.
85    pub fn describe(&self) -> String {
86        match (self.jwt.as_ref(), self.api_key) {
87            (None, false) => "no-auth (open access)".to_string(),
88            (None, true) => "api-key".to_string(),
89            (Some(c), false) => format!("jwt (issuer: {})", c.issuer),
90            (Some(c), true) => format!("jwt (issuer: {}) + api-key", c.issuer),
91        }
92    }
93}
94
95impl JwtConfig {
96    /// Build a JwtConfig with a fresh JWKS cache pointed at `issuer`'s OIDC discovery endpoint.
97    pub fn new(issuer: String, audience: Option<String>) -> Self {
98        Self {
99            jwks_cache: Arc::new(JwksCache::new(issuer.clone())),
100            issuer,
101            audience,
102        }
103    }
104}
105
106// ── JWKS Cache ──────────────────────────────────────────────
107
108/// Caches JWKS keys fetched from the OIDC provider.
109/// Keys are refreshed after `JWKS_CACHE_TTL` or on cache miss for a specific `kid`.
110pub struct JwksCache {
111    issuer: String,
112    cache: RwLock<Option<CachedJwks>>,
113    http: reqwest::Client,
114}
115
116struct CachedJwks {
117    jwks: JwkSet,
118    fetched_at: Instant,
119}
120
121impl std::fmt::Debug for JwksCache {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("JwksCache")
124            .field("issuer", &self.issuer)
125            .finish()
126    }
127}
128
129impl JwksCache {
130    pub fn new(issuer: String) -> Self {
131        Self {
132            issuer,
133            cache: RwLock::new(None),
134            http: reqwest::Client::builder()
135                .timeout(Duration::from_secs(10))
136                .build()
137                .expect("building JWKS HTTP client"),
138        }
139    }
140
141    /// For testing: create a cache pre-loaded with keys (no HTTP fetching needed).
142    pub fn with_jwks(issuer: String, jwks: JwkSet) -> Self {
143        Self {
144            issuer,
145            cache: RwLock::new(Some(CachedJwks {
146                jwks,
147                fetched_at: Instant::now(),
148            })),
149            http: reqwest::Client::new(),
150        }
151    }
152
153    /// Get the JWKS, fetching from the provider if the cache is stale or empty.
154    async fn get_jwks(&self) -> anyhow::Result<JwkSet> {
155        // Check cache
156        {
157            let cache = self.cache.read().await;
158            if let Some(ref cached) = *cache
159                && cached.fetched_at.elapsed() < JWKS_CACHE_TTL
160            {
161                return Ok(cached.jwks.clone());
162            }
163        }
164
165        // Cache miss or stale — fetch fresh JWKS
166        self.refresh().await
167    }
168
169    /// Force-refresh the JWKS (e.g., when a kid is not found in the current set).
170    async fn refresh(&self) -> anyhow::Result<JwkSet> {
171        let jwks_uri = self.discover_jwks_uri().await?;
172        debug!("Fetching JWKS from {jwks_uri}");
173
174        let jwks: JwkSet = self.http.get(&jwks_uri).send().await?.json().await?;
175        info!(
176            "Fetched {} keys from JWKS endpoint",
177            jwks.keys.len()
178        );
179
180        let mut cache = self.cache.write().await;
181        *cache = Some(CachedJwks {
182            jwks: jwks.clone(),
183            fetched_at: Instant::now(),
184        });
185
186        Ok(jwks)
187    }
188
189    /// Discover the JWKS URI from the OIDC discovery endpoint.
190    async fn discover_jwks_uri(&self) -> anyhow::Result<String> {
191        let discovery_url = format!(
192            "{}/.well-known/openid-configuration",
193            self.issuer.trim_end_matches('/')
194        );
195
196        let resp: serde_json::Value = self
197            .http
198            .get(&discovery_url)
199            .send()
200            .await?
201            .json()
202            .await?;
203
204        resp.get("jwks_uri")
205            .and_then(|v| v.as_str())
206            .map(String::from)
207            .ok_or_else(|| anyhow::anyhow!("OIDC discovery response missing jwks_uri"))
208    }
209
210    /// Find a decoding key by `kid` (key ID) from the cached JWKS.
211    /// If the kid isn't found, refreshes the cache once and retries.
212    async fn find_key(&self, kid: &str) -> anyhow::Result<DecodingKey> {
213        let jwks = self.get_jwks().await?;
214
215        // Try to find by kid
216        if let Some(key) = find_key_in_set(&jwks, kid) {
217            return Ok(key);
218        }
219
220        // kid not found — refresh and retry (key rotation)
221        debug!("kid '{kid}' not in JWKS cache, refreshing");
222        let jwks = self.refresh().await?;
223
224        find_key_in_set(&jwks, kid)
225            .ok_or_else(|| anyhow::anyhow!("No key with kid '{kid}' in JWKS"))
226    }
227
228    /// Find a decoding key when no kid is provided (use the first matching key).
229    async fn find_any_key(&self, alg: jsonwebtoken::Algorithm) -> anyhow::Result<DecodingKey> {
230        let jwks = self.get_jwks().await?;
231
232        for key in &jwks.keys {
233            if let Ok(dk) = DecodingKey::from_jwk(key) {
234                // Use the first key that decodes successfully — if the JWK
235                // specifies an algorithm, the jsonwebtoken Validation will
236                // catch mismatches during decode.
237                let _ = alg; // algorithm check happens at decode time
238                return Ok(dk);
239            }
240        }
241
242        anyhow::bail!("No suitable key found in JWKS for algorithm {alg:?}")
243    }
244}
245
246fn find_key_in_set(jwks: &JwkSet, kid: &str) -> Option<DecodingKey> {
247    jwks.keys
248        .iter()
249        .find(|k| k.common.key_id.as_deref() == Some(kid))
250        .and_then(|k| DecodingKey::from_jwk(k).ok())
251}
252
253// ── Middleware ───────────────────────────────────────────────
254
255/// Axum middleware that enforces authentication based on the configured mode.
256///
257/// When both JWT and API-key auth are enabled, dispatch is based on token shape:
258/// if the Bearer token parses as a JWS header it takes the JWT path, otherwise
259/// the API-key path. A semantically-invalid JWT (expired, forged signature, wrong
260/// audience) is rejected and is *not* retried as an API key — a token that looks
261/// like a JWT is treated as a JWT.
262pub async fn auth_middleware<S: WorkflowStore>(
263    State(state): State<Arc<AppState<S>>>,
264    request: Request,
265    next: Next,
266) -> Response {
267    let auth = &state.auth_mode;
268
269    if !auth.is_enabled() {
270        return next.run(request).await;
271    }
272
273    let token = match extract_bearer(&request) {
274        Some(t) => t,
275        None => return auth_error("Missing Authorization: Bearer <token>"),
276    };
277
278    if jsonwebtoken::decode_header(token).is_ok() {
279        match &auth.jwt {
280            Some(jwt) => {
281                validate_jwt(
282                    &jwt.issuer,
283                    jwt.audience.as_deref(),
284                    &jwt.jwks_cache,
285                    request,
286                    next,
287                )
288                .await
289            }
290            None => auth_error("JWT authentication is not enabled on this server"),
291        }
292    } else if auth.api_key {
293        validate_api_key(state, request, next).await
294    } else {
295        auth_error("Token is not a valid JWT and API-key authentication is not enabled")
296    }
297}
298
299async fn validate_api_key<S: WorkflowStore>(
300    state: Arc<AppState<S>>,
301    request: Request,
302    next: Next,
303) -> Response {
304    let token = match extract_bearer(&request) {
305        Some(t) => t,
306        None => return auth_error("Missing Authorization: Bearer <api-key>"),
307    };
308
309    let hash = hash_api_key(token);
310    match state.engine.store().validate_api_key(&hash).await {
311        Ok(true) => next.run(request).await,
312        Ok(false) => {
313            warn!(
314                "Invalid API key (prefix: {}...)",
315                &token[..8.min(token.len())]
316            );
317            auth_error("Invalid API key")
318        }
319        Err(e) => {
320            warn!("API key validation error: {e}");
321            (
322                StatusCode::INTERNAL_SERVER_ERROR,
323                Json(serde_json::json!({"error": "auth check failed"})),
324            )
325                .into_response()
326        }
327    }
328}
329
330async fn validate_jwt(
331    issuer: &str,
332    audience: Option<&str>,
333    jwks_cache: &JwksCache,
334    request: Request,
335    next: Next,
336) -> Response {
337    let token = match extract_bearer(&request) {
338        Some(t) => t,
339        None => return auth_error("Missing Authorization: Bearer <jwt>"),
340    };
341
342    // Decode header to get algorithm and kid
343    let header = match jsonwebtoken::decode_header(token) {
344        Ok(h) => h,
345        Err(e) => {
346            warn!("Invalid JWT header: {e}");
347            return auth_error("Invalid JWT");
348        }
349    };
350
351    // Find the decoding key from JWKS
352    let decoding_key = match &header.kid {
353        Some(kid) => match jwks_cache.find_key(kid).await {
354            Ok(key) => key,
355            Err(e) => {
356                warn!("JWKS key lookup failed: {e}");
357                return auth_error("JWT validation failed: key not found");
358            }
359        },
360        None => match jwks_cache.find_any_key(header.alg).await {
361            Ok(key) => key,
362            Err(e) => {
363                warn!("JWKS key lookup failed (no kid): {e}");
364                return auth_error("JWT validation failed: no suitable key");
365            }
366        },
367    };
368
369    // Build validation rules
370    let mut validation = Validation::new(header.alg);
371    validation.set_issuer(&[issuer]);
372    if let Some(aud) = audience {
373        validation.set_audience(&[aud]);
374    } else {
375        validation.validate_aud = false;
376    }
377
378    // Validate signature + claims
379    match jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation) {
380        Ok(_) => next.run(request).await,
381        Err(e) => {
382            warn!("JWT validation failed: {e}");
383            auth_error(&format!("JWT validation failed: {e}"))
384        }
385    }
386}
387
388fn extract_bearer(request: &Request) -> Option<&str> {
389    request
390        .headers()
391        .get("authorization")
392        .and_then(|v| v.to_str().ok())
393        .and_then(|v| v.strip_prefix("Bearer "))
394}
395
396fn auth_error(msg: &str) -> Response {
397    (
398        StatusCode::UNAUTHORIZED,
399        Json(serde_json::json!({"error": msg})),
400    )
401        .into_response()
402}
403
404// ── API Key Helpers ─────────────────────────────────────────
405
406/// Hash an API key with SHA-256 for storage/lookup.
407pub fn hash_api_key(key: &str) -> String {
408    let mut hasher = Sha256::new();
409    hasher.update(key.as_bytes());
410    data_encoding::HEXLOWER.encode(&hasher.finalize())
411}
412
413/// Generate a new random API key (32 bytes, hex-encoded).
414pub fn generate_api_key() -> String {
415    use rand::Rng;
416    let bytes: [u8; 32] = rand::rng().random();
417    format!("assay_{}", data_encoding::HEXLOWER.encode(&bytes))
418}
419
/// Build a display-safe prefix for an API key.
///
/// The result is `"assay_"`, then the first 8 characters after the `"assay_"`
/// marker, then `"..."`. Keys without the marker are handled too: their first
/// 8 characters are used, still under an `"assay_"` label. Truncation counts
/// characters rather than bytes, so non-ASCII input never splits a UTF-8
/// sequence (the previous byte slice could panic on such input).
pub fn key_prefix(key: &str) -> String {
    const VISIBLE_CHARS: usize = 8;

    let stripped = key.strip_prefix("assay_").unwrap_or(key);
    // Byte offset of the 9th character, or the whole string if it is shorter.
    let end = stripped
        .char_indices()
        .nth(VISIBLE_CHARS)
        .map_or(stripped.len(), |(idx, _)| idx);
    format!("assay_{}...", &stripped[..end])
}