1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::{IntoResponse, Response};
8use axum::Json;
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{DecodingKey, Validation};
11use sha2::{Digest, Sha256};
12use tokio::sync::RwLock;
13use tracing::{debug, info, warn};
14
15use crate::api::AppState;
16use crate::store::WorkflowStore;
17
18const JWKS_CACHE_TTL: Duration = Duration::from_secs(300); #[derive(Clone, Debug, Default)]
31pub struct AuthMode {
32 pub api_key: bool,
35 pub jwt: Option<JwtConfig>,
38}
39
40#[derive(Clone, Debug)]
42pub struct JwtConfig {
43 pub issuer: String,
44 pub audience: Option<String>,
45 pub jwks_cache: Arc<JwksCache>,
46}
47
48impl AuthMode {
49 pub fn no_auth() -> Self {
51 Self::default()
52 }
53
54 pub fn jwt(issuer: String, audience: Option<String>) -> Self {
56 Self {
57 api_key: false,
58 jwt: Some(JwtConfig::new(issuer, audience)),
59 }
60 }
61
62 pub fn api_key() -> Self {
64 Self {
65 api_key: true,
66 jwt: None,
67 }
68 }
69
70 pub fn combined(issuer: String, audience: Option<String>) -> Self {
73 Self {
74 api_key: true,
75 jwt: Some(JwtConfig::new(issuer, audience)),
76 }
77 }
78
79 pub fn is_enabled(&self) -> bool {
81 self.api_key || self.jwt.is_some()
82 }
83
84 pub fn describe(&self) -> String {
86 match (self.jwt.as_ref(), self.api_key) {
87 (None, false) => "no-auth (open access)".to_string(),
88 (None, true) => "api-key".to_string(),
89 (Some(c), false) => format!("jwt (issuer: {})", c.issuer),
90 (Some(c), true) => format!("jwt (issuer: {}) + api-key", c.issuer),
91 }
92 }
93}
94
95impl JwtConfig {
96 pub fn new(issuer: String, audience: Option<String>) -> Self {
98 Self {
99 jwks_cache: Arc::new(JwksCache::new(issuer.clone())),
100 issuer,
101 audience,
102 }
103 }
104}
105
106pub struct JwksCache {
111 issuer: String,
112 cache: RwLock<Option<CachedJwks>>,
113 http: reqwest::Client,
114}
115
116struct CachedJwks {
117 jwks: JwkSet,
118 fetched_at: Instant,
119}
120
121impl std::fmt::Debug for JwksCache {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("JwksCache")
124 .field("issuer", &self.issuer)
125 .finish()
126 }
127}
128
129impl JwksCache {
130 pub fn new(issuer: String) -> Self {
131 Self {
132 issuer,
133 cache: RwLock::new(None),
134 http: reqwest::Client::builder()
135 .timeout(Duration::from_secs(10))
136 .build()
137 .expect("building JWKS HTTP client"),
138 }
139 }
140
141 pub fn with_jwks(issuer: String, jwks: JwkSet) -> Self {
143 Self {
144 issuer,
145 cache: RwLock::new(Some(CachedJwks {
146 jwks,
147 fetched_at: Instant::now(),
148 })),
149 http: reqwest::Client::new(),
150 }
151 }
152
153 async fn get_jwks(&self) -> anyhow::Result<JwkSet> {
155 {
157 let cache = self.cache.read().await;
158 if let Some(ref cached) = *cache
159 && cached.fetched_at.elapsed() < JWKS_CACHE_TTL
160 {
161 return Ok(cached.jwks.clone());
162 }
163 }
164
165 self.refresh().await
167 }
168
169 async fn refresh(&self) -> anyhow::Result<JwkSet> {
171 let jwks_uri = self.discover_jwks_uri().await?;
172 debug!("Fetching JWKS from {jwks_uri}");
173
174 let jwks: JwkSet = self.http.get(&jwks_uri).send().await?.json().await?;
175 info!(
176 "Fetched {} keys from JWKS endpoint",
177 jwks.keys.len()
178 );
179
180 let mut cache = self.cache.write().await;
181 *cache = Some(CachedJwks {
182 jwks: jwks.clone(),
183 fetched_at: Instant::now(),
184 });
185
186 Ok(jwks)
187 }
188
189 async fn discover_jwks_uri(&self) -> anyhow::Result<String> {
191 let discovery_url = format!(
192 "{}/.well-known/openid-configuration",
193 self.issuer.trim_end_matches('/')
194 );
195
196 let resp: serde_json::Value = self
197 .http
198 .get(&discovery_url)
199 .send()
200 .await?
201 .json()
202 .await?;
203
204 resp.get("jwks_uri")
205 .and_then(|v| v.as_str())
206 .map(String::from)
207 .ok_or_else(|| anyhow::anyhow!("OIDC discovery response missing jwks_uri"))
208 }
209
210 async fn find_key(&self, kid: &str) -> anyhow::Result<DecodingKey> {
213 let jwks = self.get_jwks().await?;
214
215 if let Some(key) = find_key_in_set(&jwks, kid) {
217 return Ok(key);
218 }
219
220 debug!("kid '{kid}' not in JWKS cache, refreshing");
222 let jwks = self.refresh().await?;
223
224 find_key_in_set(&jwks, kid)
225 .ok_or_else(|| anyhow::anyhow!("No key with kid '{kid}' in JWKS"))
226 }
227
228 async fn find_any_key(&self, alg: jsonwebtoken::Algorithm) -> anyhow::Result<DecodingKey> {
230 let jwks = self.get_jwks().await?;
231
232 for key in &jwks.keys {
233 if let Ok(dk) = DecodingKey::from_jwk(key) {
234 let _ = alg; return Ok(dk);
239 }
240 }
241
242 anyhow::bail!("No suitable key found in JWKS for algorithm {alg:?}")
243 }
244}
245
246fn find_key_in_set(jwks: &JwkSet, kid: &str) -> Option<DecodingKey> {
247 jwks.keys
248 .iter()
249 .find(|k| k.common.key_id.as_deref() == Some(kid))
250 .and_then(|k| DecodingKey::from_jwk(k).ok())
251}
252
253pub async fn auth_middleware<S: WorkflowStore>(
269 State(state): State<Arc<AppState<S>>>,
270 request: Request,
271 next: Next,
272) -> Response {
273 let auth = &state.auth_mode;
274
275 if !auth.is_enabled() {
276 return next.run(request).await;
277 }
278
279 if is_bootstrap_request(&request) {
280 match state.engine.store().api_keys_empty().await {
281 Ok(true) => {
282 info!(
283 "Allowing unauthenticated POST /api/v1/api-keys — api_keys table is empty (bootstrap window)"
284 );
285 return next.run(request).await;
286 }
287 Ok(false) => {
288 }
290 Err(e) => {
291 warn!("api_keys_empty check failed: {e}");
292 return (
293 StatusCode::INTERNAL_SERVER_ERROR,
294 Json(serde_json::json!({"error": "auth bootstrap check failed"})),
295 )
296 .into_response();
297 }
298 }
299 }
300
301 let token = match extract_bearer(&request) {
302 Some(t) => t,
303 None => return auth_error("Missing Authorization: Bearer <token>"),
304 };
305
306 if jsonwebtoken::decode_header(token).is_ok() {
307 match &auth.jwt {
308 Some(jwt) => {
309 validate_jwt(
310 &jwt.issuer,
311 jwt.audience.as_deref(),
312 &jwt.jwks_cache,
313 request,
314 next,
315 )
316 .await
317 }
318 None => auth_error("JWT authentication is not enabled on this server"),
319 }
320 } else if auth.api_key {
321 validate_api_key(state, request, next).await
322 } else {
323 auth_error("Token is not a valid JWT and API-key authentication is not enabled")
324 }
325}
326
327async fn validate_api_key<S: WorkflowStore>(
328 state: Arc<AppState<S>>,
329 request: Request,
330 next: Next,
331) -> Response {
332 let token = match extract_bearer(&request) {
333 Some(t) => t,
334 None => return auth_error("Missing Authorization: Bearer <api-key>"),
335 };
336
337 let hash = hash_api_key(token);
338 match state.engine.store().validate_api_key(&hash).await {
339 Ok(true) => next.run(request).await,
340 Ok(false) => {
341 warn!(
342 "Invalid API key (prefix: {}...)",
343 &token[..8.min(token.len())]
344 );
345 auth_error("Invalid API key")
346 }
347 Err(e) => {
348 warn!("API key validation error: {e}");
349 (
350 StatusCode::INTERNAL_SERVER_ERROR,
351 Json(serde_json::json!({"error": "auth check failed"})),
352 )
353 .into_response()
354 }
355 }
356}
357
358async fn validate_jwt(
359 issuer: &str,
360 audience: Option<&str>,
361 jwks_cache: &JwksCache,
362 request: Request,
363 next: Next,
364) -> Response {
365 let token = match extract_bearer(&request) {
366 Some(t) => t,
367 None => return auth_error("Missing Authorization: Bearer <jwt>"),
368 };
369
370 let header = match jsonwebtoken::decode_header(token) {
372 Ok(h) => h,
373 Err(e) => {
374 warn!("Invalid JWT header: {e}");
375 return auth_error("Invalid JWT");
376 }
377 };
378
379 let decoding_key = match &header.kid {
381 Some(kid) => match jwks_cache.find_key(kid).await {
382 Ok(key) => key,
383 Err(e) => {
384 warn!("JWKS key lookup failed: {e}");
385 return auth_error("JWT validation failed: key not found");
386 }
387 },
388 None => match jwks_cache.find_any_key(header.alg).await {
389 Ok(key) => key,
390 Err(e) => {
391 warn!("JWKS key lookup failed (no kid): {e}");
392 return auth_error("JWT validation failed: no suitable key");
393 }
394 },
395 };
396
397 let mut validation = Validation::new(header.alg);
399 validation.set_issuer(&[issuer]);
400 if let Some(aud) = audience {
401 validation.set_audience(&[aud]);
402 } else {
403 validation.validate_aud = false;
404 }
405
406 match jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation) {
408 Ok(_) => next.run(request).await,
409 Err(e) => {
410 warn!("JWT validation failed: {e}");
411 auth_error(&format!("JWT validation failed: {e}"))
412 }
413 }
414}
415
416fn extract_bearer(request: &Request) -> Option<&str> {
417 request
418 .headers()
419 .get("authorization")
420 .and_then(|v| v.to_str().ok())
421 .and_then(|v| v.strip_prefix("Bearer "))
422}
423
424fn is_bootstrap_request(request: &Request) -> bool {
428 request.method() == axum::http::Method::POST
429 && request.uri().path() == "/api/v1/api-keys"
430}
431
432fn auth_error(msg: &str) -> Response {
433 (
434 StatusCode::UNAUTHORIZED,
435 Json(serde_json::json!({"error": msg})),
436 )
437 .into_response()
438}
439
440pub fn hash_api_key(key: &str) -> String {
444 let mut hasher = Sha256::new();
445 hasher.update(key.as_bytes());
446 data_encoding::HEXLOWER.encode(&hasher.finalize())
447}
448
449pub fn generate_api_key() -> String {
451 use rand::Rng;
452 let bytes: [u8; 32] = rand::rng().random();
453 format!("assay_{}", data_encoding::HEXLOWER.encode(&bytes))
454}
455
/// Builds a display-safe prefix of an API key for listings and logs.
///
/// The result is `assay_`, then up to the first 8 characters of the key body (the
/// key with any leading `assay_` removed), then `...`. The cut is made on a
/// character boundary, so arbitrary UTF-8 input never panics.
pub fn key_prefix(key: &str) -> String {
    let stripped = key.strip_prefix("assay_").unwrap_or(key);
    // Byte offset of the 9th character, or the whole string if it is shorter.
    let end = stripped
        .char_indices()
        .nth(8)
        .map_or(stripped.len(), |(idx, _)| idx);
    format!("assay_{}...", &stripped[..end])
}