Skip to main content

assay_workflow/api/
auth.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::{IntoResponse, Response};
8use axum::Json;
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{DecodingKey, Validation};
11use sha2::{Digest, Sha256};
12use tokio::sync::RwLock;
13use tracing::{debug, info, warn};
14
15use crate::api::AppState;
16use crate::store::WorkflowStore;
17
/// How long a fetched JWKS is reused before it is fetched again (5 minutes).
const JWKS_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
19
20// ── Auth Mode ───────────────────────────────────────────────
21
22/// Auth configuration — determines which mode the engine runs in.
23#[derive(Clone, Debug, Default)]
24pub enum AuthMode {
25    /// No authentication — all requests allowed (dev mode).
26    #[default]
27    NoAuth,
28    /// API key authentication — Bearer token validated against hashed keys in DB.
29    ApiKey,
30    /// JWT/OIDC — validate Bearer JWT signature against JWKS from the issuer.
31    Jwt {
32        issuer: String,
33        audience: Option<String>,
34        jwks_cache: Arc<JwksCache>,
35    },
36}
37
38impl AuthMode {
39    /// Create a JWT auth mode that fetches JWKS from the issuer's OIDC discovery.
40    pub fn jwt(issuer: String, audience: Option<String>) -> Self {
41        Self::Jwt {
42            jwks_cache: Arc::new(JwksCache::new(issuer.clone())),
43            issuer,
44            audience,
45        }
46    }
47}
48
49// ── JWKS Cache ──────────────────────────────────────────────
50
51/// Caches JWKS keys fetched from the OIDC provider.
52/// Keys are refreshed after `JWKS_CACHE_TTL` or on cache miss for a specific `kid`.
53pub struct JwksCache {
54    issuer: String,
55    cache: RwLock<Option<CachedJwks>>,
56    http: reqwest::Client,
57}
58
59struct CachedJwks {
60    jwks: JwkSet,
61    fetched_at: Instant,
62}
63
64impl std::fmt::Debug for JwksCache {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("JwksCache")
67            .field("issuer", &self.issuer)
68            .finish()
69    }
70}
71
72impl JwksCache {
73    pub fn new(issuer: String) -> Self {
74        Self {
75            issuer,
76            cache: RwLock::new(None),
77            http: reqwest::Client::builder()
78                .timeout(Duration::from_secs(10))
79                .build()
80                .expect("building JWKS HTTP client"),
81        }
82    }
83
84    /// For testing: create a cache pre-loaded with keys (no HTTP fetching needed).
85    pub fn with_jwks(issuer: String, jwks: JwkSet) -> Self {
86        Self {
87            issuer,
88            cache: RwLock::new(Some(CachedJwks {
89                jwks,
90                fetched_at: Instant::now(),
91            })),
92            http: reqwest::Client::new(),
93        }
94    }
95
96    /// Get the JWKS, fetching from the provider if the cache is stale or empty.
97    async fn get_jwks(&self) -> anyhow::Result<JwkSet> {
98        // Check cache
99        {
100            let cache = self.cache.read().await;
101            if let Some(ref cached) = *cache
102                && cached.fetched_at.elapsed() < JWKS_CACHE_TTL
103            {
104                return Ok(cached.jwks.clone());
105            }
106        }
107
108        // Cache miss or stale — fetch fresh JWKS
109        self.refresh().await
110    }
111
112    /// Force-refresh the JWKS (e.g., when a kid is not found in the current set).
113    async fn refresh(&self) -> anyhow::Result<JwkSet> {
114        let jwks_uri = self.discover_jwks_uri().await?;
115        debug!("Fetching JWKS from {jwks_uri}");
116
117        let jwks: JwkSet = self.http.get(&jwks_uri).send().await?.json().await?;
118        info!(
119            "Fetched {} keys from JWKS endpoint",
120            jwks.keys.len()
121        );
122
123        let mut cache = self.cache.write().await;
124        *cache = Some(CachedJwks {
125            jwks: jwks.clone(),
126            fetched_at: Instant::now(),
127        });
128
129        Ok(jwks)
130    }
131
132    /// Discover the JWKS URI from the OIDC discovery endpoint.
133    async fn discover_jwks_uri(&self) -> anyhow::Result<String> {
134        let discovery_url = format!(
135            "{}/.well-known/openid-configuration",
136            self.issuer.trim_end_matches('/')
137        );
138
139        let resp: serde_json::Value = self
140            .http
141            .get(&discovery_url)
142            .send()
143            .await?
144            .json()
145            .await?;
146
147        resp.get("jwks_uri")
148            .and_then(|v| v.as_str())
149            .map(String::from)
150            .ok_or_else(|| anyhow::anyhow!("OIDC discovery response missing jwks_uri"))
151    }
152
153    /// Find a decoding key by `kid` (key ID) from the cached JWKS.
154    /// If the kid isn't found, refreshes the cache once and retries.
155    async fn find_key(&self, kid: &str) -> anyhow::Result<DecodingKey> {
156        let jwks = self.get_jwks().await?;
157
158        // Try to find by kid
159        if let Some(key) = find_key_in_set(&jwks, kid) {
160            return Ok(key);
161        }
162
163        // kid not found — refresh and retry (key rotation)
164        debug!("kid '{kid}' not in JWKS cache, refreshing");
165        let jwks = self.refresh().await?;
166
167        find_key_in_set(&jwks, kid)
168            .ok_or_else(|| anyhow::anyhow!("No key with kid '{kid}' in JWKS"))
169    }
170
171    /// Find a decoding key when no kid is provided (use the first matching key).
172    async fn find_any_key(&self, alg: jsonwebtoken::Algorithm) -> anyhow::Result<DecodingKey> {
173        let jwks = self.get_jwks().await?;
174
175        for key in &jwks.keys {
176            if let Ok(dk) = DecodingKey::from_jwk(key) {
177                // Use the first key that decodes successfully — if the JWK
178                // specifies an algorithm, the jsonwebtoken Validation will
179                // catch mismatches during decode.
180                let _ = alg; // algorithm check happens at decode time
181                return Ok(dk);
182            }
183        }
184
185        anyhow::bail!("No suitable key found in JWKS for algorithm {alg:?}")
186    }
187}
188
189fn find_key_in_set(jwks: &JwkSet, kid: &str) -> Option<DecodingKey> {
190    jwks.keys
191        .iter()
192        .find(|k| k.common.key_id.as_deref() == Some(kid))
193        .and_then(|k| DecodingKey::from_jwk(k).ok())
194}
195
196// ── Middleware ───────────────────────────────────────────────
197
198/// Axum middleware that enforces authentication based on the configured mode.
199pub async fn auth_middleware<S: WorkflowStore>(
200    State(state): State<Arc<AppState<S>>>,
201    request: Request,
202    next: Next,
203) -> Response {
204    match &state.auth_mode {
205        AuthMode::NoAuth => next.run(request).await,
206        AuthMode::ApiKey => validate_api_key(state, request, next).await,
207        AuthMode::Jwt {
208            issuer,
209            audience,
210            jwks_cache,
211        } => validate_jwt(issuer, audience.as_deref(), jwks_cache, request, next).await,
212    }
213}
214
215async fn validate_api_key<S: WorkflowStore>(
216    state: Arc<AppState<S>>,
217    request: Request,
218    next: Next,
219) -> Response {
220    let token = match extract_bearer(&request) {
221        Some(t) => t,
222        None => return auth_error("Missing Authorization: Bearer <api-key>"),
223    };
224
225    let hash = hash_api_key(token);
226    match state.engine.store().validate_api_key(&hash).await {
227        Ok(true) => next.run(request).await,
228        Ok(false) => {
229            warn!(
230                "Invalid API key (prefix: {}...)",
231                &token[..8.min(token.len())]
232            );
233            auth_error("Invalid API key")
234        }
235        Err(e) => {
236            warn!("API key validation error: {e}");
237            (
238                StatusCode::INTERNAL_SERVER_ERROR,
239                Json(serde_json::json!({"error": "auth check failed"})),
240            )
241                .into_response()
242        }
243    }
244}
245
246async fn validate_jwt(
247    issuer: &str,
248    audience: Option<&str>,
249    jwks_cache: &JwksCache,
250    request: Request,
251    next: Next,
252) -> Response {
253    let token = match extract_bearer(&request) {
254        Some(t) => t,
255        None => return auth_error("Missing Authorization: Bearer <jwt>"),
256    };
257
258    // Decode header to get algorithm and kid
259    let header = match jsonwebtoken::decode_header(token) {
260        Ok(h) => h,
261        Err(e) => {
262            warn!("Invalid JWT header: {e}");
263            return auth_error("Invalid JWT");
264        }
265    };
266
267    // Find the decoding key from JWKS
268    let decoding_key = match &header.kid {
269        Some(kid) => match jwks_cache.find_key(kid).await {
270            Ok(key) => key,
271            Err(e) => {
272                warn!("JWKS key lookup failed: {e}");
273                return auth_error("JWT validation failed: key not found");
274            }
275        },
276        None => match jwks_cache.find_any_key(header.alg).await {
277            Ok(key) => key,
278            Err(e) => {
279                warn!("JWKS key lookup failed (no kid): {e}");
280                return auth_error("JWT validation failed: no suitable key");
281            }
282        },
283    };
284
285    // Build validation rules
286    let mut validation = Validation::new(header.alg);
287    validation.set_issuer(&[issuer]);
288    if let Some(aud) = audience {
289        validation.set_audience(&[aud]);
290    } else {
291        validation.validate_aud = false;
292    }
293
294    // Validate signature + claims
295    match jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation) {
296        Ok(_) => next.run(request).await,
297        Err(e) => {
298            warn!("JWT validation failed: {e}");
299            auth_error(&format!("JWT validation failed: {e}"))
300        }
301    }
302}
303
304fn extract_bearer(request: &Request) -> Option<&str> {
305    request
306        .headers()
307        .get("authorization")
308        .and_then(|v| v.to_str().ok())
309        .and_then(|v| v.strip_prefix("Bearer "))
310}
311
312fn auth_error(msg: &str) -> Response {
313    (
314        StatusCode::UNAUTHORIZED,
315        Json(serde_json::json!({"error": msg})),
316    )
317        .into_response()
318}
319
320// ── API Key Helpers ─────────────────────────────────────────
321
322/// Hash an API key with SHA-256 for storage/lookup.
323pub fn hash_api_key(key: &str) -> String {
324    let mut hasher = Sha256::new();
325    hasher.update(key.as_bytes());
326    data_encoding::HEXLOWER.encode(&hasher.finalize())
327}
328
329/// Generate a new random API key (32 bytes, hex-encoded).
330pub fn generate_api_key() -> String {
331    use rand::Rng;
332    let bytes: [u8; 32] = rand::rng().random();
333    format!("assay_{}", data_encoding::HEXLOWER.encode(&bytes))
334}
335
/// Short, display-safe form of an API key: `assay_` + the first 8 characters
/// after the `assay_` prefix + `...`.
///
/// A key without the `assay_` prefix is shown with it anyway. Keys shorter than
/// 8 characters are shown whole.
pub fn key_prefix(key: &str) -> String {
    let stripped = key.strip_prefix("assay_").unwrap_or(key);
    // Cut on a char boundary: byte-slicing `&stripped[..8]` panics when byte 8
    // falls inside a multi-byte UTF-8 character.
    let end = stripped
        .char_indices()
        .nth(8)
        .map_or(stripped.len(), |(idx, _)| idx);
    format!("assay_{}...", &stripped[..end])
}