1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::{IntoResponse, Response};
8use axum::Json;
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{DecodingKey, Validation};
11use sha2::{Digest, Sha256};
12use tokio::sync::RwLock;
13use tracing::{debug, info, warn};
14
15use crate::api::AppState;
16use crate::store::WorkflowStore;
17
18const JWKS_CACHE_TTL: Duration = Duration::from_secs(300); #[derive(Clone, Debug, Default)]
31pub struct AuthMode {
32 pub api_key: bool,
35 pub jwt: Option<JwtConfig>,
38}
39
40#[derive(Clone, Debug)]
42pub struct JwtConfig {
43 pub issuer: String,
44 pub audience: Option<String>,
45 pub jwks_cache: Arc<JwksCache>,
46}
47
48impl AuthMode {
49 pub fn no_auth() -> Self {
51 Self::default()
52 }
53
54 pub fn jwt(issuer: String, audience: Option<String>) -> Self {
56 Self {
57 api_key: false,
58 jwt: Some(JwtConfig::new(issuer, audience)),
59 }
60 }
61
62 pub fn api_key() -> Self {
64 Self {
65 api_key: true,
66 jwt: None,
67 }
68 }
69
70 pub fn combined(issuer: String, audience: Option<String>) -> Self {
73 Self {
74 api_key: true,
75 jwt: Some(JwtConfig::new(issuer, audience)),
76 }
77 }
78
79 pub fn is_enabled(&self) -> bool {
81 self.api_key || self.jwt.is_some()
82 }
83
84 pub fn describe(&self) -> String {
86 match (self.jwt.as_ref(), self.api_key) {
87 (None, false) => "no-auth (open access)".to_string(),
88 (None, true) => "api-key".to_string(),
89 (Some(c), false) => format!("jwt (issuer: {})", c.issuer),
90 (Some(c), true) => format!("jwt (issuer: {}) + api-key", c.issuer),
91 }
92 }
93}
94
95impl JwtConfig {
96 pub fn new(issuer: String, audience: Option<String>) -> Self {
98 Self {
99 jwks_cache: Arc::new(JwksCache::new(issuer.clone())),
100 issuer,
101 audience,
102 }
103 }
104}
105
106pub struct JwksCache {
111 issuer: String,
112 cache: RwLock<Option<CachedJwks>>,
113 http: reqwest::Client,
114}
115
116struct CachedJwks {
117 jwks: JwkSet,
118 fetched_at: Instant,
119}
120
121impl std::fmt::Debug for JwksCache {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("JwksCache")
124 .field("issuer", &self.issuer)
125 .finish()
126 }
127}
128
129impl JwksCache {
130 pub fn new(issuer: String) -> Self {
131 Self {
132 issuer,
133 cache: RwLock::new(None),
134 http: reqwest::Client::builder()
135 .timeout(Duration::from_secs(10))
136 .build()
137 .expect("building JWKS HTTP client"),
138 }
139 }
140
141 pub fn with_jwks(issuer: String, jwks: JwkSet) -> Self {
143 Self {
144 issuer,
145 cache: RwLock::new(Some(CachedJwks {
146 jwks,
147 fetched_at: Instant::now(),
148 })),
149 http: reqwest::Client::new(),
150 }
151 }
152
153 async fn get_jwks(&self) -> anyhow::Result<JwkSet> {
155 {
157 let cache = self.cache.read().await;
158 if let Some(ref cached) = *cache
159 && cached.fetched_at.elapsed() < JWKS_CACHE_TTL
160 {
161 return Ok(cached.jwks.clone());
162 }
163 }
164
165 self.refresh().await
167 }
168
169 async fn refresh(&self) -> anyhow::Result<JwkSet> {
171 let jwks_uri = self.discover_jwks_uri().await?;
172 debug!("Fetching JWKS from {jwks_uri}");
173
174 let jwks: JwkSet = self.http.get(&jwks_uri).send().await?.json().await?;
175 info!(
176 "Fetched {} keys from JWKS endpoint",
177 jwks.keys.len()
178 );
179
180 let mut cache = self.cache.write().await;
181 *cache = Some(CachedJwks {
182 jwks: jwks.clone(),
183 fetched_at: Instant::now(),
184 });
185
186 Ok(jwks)
187 }
188
189 async fn discover_jwks_uri(&self) -> anyhow::Result<String> {
191 let discovery_url = format!(
192 "{}/.well-known/openid-configuration",
193 self.issuer.trim_end_matches('/')
194 );
195
196 let resp: serde_json::Value = self
197 .http
198 .get(&discovery_url)
199 .send()
200 .await?
201 .json()
202 .await?;
203
204 resp.get("jwks_uri")
205 .and_then(|v| v.as_str())
206 .map(String::from)
207 .ok_or_else(|| anyhow::anyhow!("OIDC discovery response missing jwks_uri"))
208 }
209
210 async fn find_key(&self, kid: &str) -> anyhow::Result<DecodingKey> {
213 let jwks = self.get_jwks().await?;
214
215 if let Some(key) = find_key_in_set(&jwks, kid) {
217 return Ok(key);
218 }
219
220 debug!("kid '{kid}' not in JWKS cache, refreshing");
222 let jwks = self.refresh().await?;
223
224 find_key_in_set(&jwks, kid)
225 .ok_or_else(|| anyhow::anyhow!("No key with kid '{kid}' in JWKS"))
226 }
227
228 async fn find_any_key(&self, alg: jsonwebtoken::Algorithm) -> anyhow::Result<DecodingKey> {
230 let jwks = self.get_jwks().await?;
231
232 for key in &jwks.keys {
233 if let Ok(dk) = DecodingKey::from_jwk(key) {
234 let _ = alg; return Ok(dk);
239 }
240 }
241
242 anyhow::bail!("No suitable key found in JWKS for algorithm {alg:?}")
243 }
244}
245
246fn find_key_in_set(jwks: &JwkSet, kid: &str) -> Option<DecodingKey> {
247 jwks.keys
248 .iter()
249 .find(|k| k.common.key_id.as_deref() == Some(kid))
250 .and_then(|k| DecodingKey::from_jwk(k).ok())
251}
252
253pub async fn auth_middleware<S: WorkflowStore>(
263 State(state): State<Arc<AppState<S>>>,
264 request: Request,
265 next: Next,
266) -> Response {
267 let auth = &state.auth_mode;
268
269 if !auth.is_enabled() {
270 return next.run(request).await;
271 }
272
273 let token = match extract_bearer(&request) {
274 Some(t) => t,
275 None => return auth_error("Missing Authorization: Bearer <token>"),
276 };
277
278 if jsonwebtoken::decode_header(token).is_ok() {
279 match &auth.jwt {
280 Some(jwt) => {
281 validate_jwt(
282 &jwt.issuer,
283 jwt.audience.as_deref(),
284 &jwt.jwks_cache,
285 request,
286 next,
287 )
288 .await
289 }
290 None => auth_error("JWT authentication is not enabled on this server"),
291 }
292 } else if auth.api_key {
293 validate_api_key(state, request, next).await
294 } else {
295 auth_error("Token is not a valid JWT and API-key authentication is not enabled")
296 }
297}
298
299async fn validate_api_key<S: WorkflowStore>(
300 state: Arc<AppState<S>>,
301 request: Request,
302 next: Next,
303) -> Response {
304 let token = match extract_bearer(&request) {
305 Some(t) => t,
306 None => return auth_error("Missing Authorization: Bearer <api-key>"),
307 };
308
309 let hash = hash_api_key(token);
310 match state.engine.store().validate_api_key(&hash).await {
311 Ok(true) => next.run(request).await,
312 Ok(false) => {
313 warn!(
314 "Invalid API key (prefix: {}...)",
315 &token[..8.min(token.len())]
316 );
317 auth_error("Invalid API key")
318 }
319 Err(e) => {
320 warn!("API key validation error: {e}");
321 (
322 StatusCode::INTERNAL_SERVER_ERROR,
323 Json(serde_json::json!({"error": "auth check failed"})),
324 )
325 .into_response()
326 }
327 }
328}
329
330async fn validate_jwt(
331 issuer: &str,
332 audience: Option<&str>,
333 jwks_cache: &JwksCache,
334 request: Request,
335 next: Next,
336) -> Response {
337 let token = match extract_bearer(&request) {
338 Some(t) => t,
339 None => return auth_error("Missing Authorization: Bearer <jwt>"),
340 };
341
342 let header = match jsonwebtoken::decode_header(token) {
344 Ok(h) => h,
345 Err(e) => {
346 warn!("Invalid JWT header: {e}");
347 return auth_error("Invalid JWT");
348 }
349 };
350
351 let decoding_key = match &header.kid {
353 Some(kid) => match jwks_cache.find_key(kid).await {
354 Ok(key) => key,
355 Err(e) => {
356 warn!("JWKS key lookup failed: {e}");
357 return auth_error("JWT validation failed: key not found");
358 }
359 },
360 None => match jwks_cache.find_any_key(header.alg).await {
361 Ok(key) => key,
362 Err(e) => {
363 warn!("JWKS key lookup failed (no kid): {e}");
364 return auth_error("JWT validation failed: no suitable key");
365 }
366 },
367 };
368
369 let mut validation = Validation::new(header.alg);
371 validation.set_issuer(&[issuer]);
372 if let Some(aud) = audience {
373 validation.set_audience(&[aud]);
374 } else {
375 validation.validate_aud = false;
376 }
377
378 match jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation) {
380 Ok(_) => next.run(request).await,
381 Err(e) => {
382 warn!("JWT validation failed: {e}");
383 auth_error(&format!("JWT validation failed: {e}"))
384 }
385 }
386}
387
388fn extract_bearer(request: &Request) -> Option<&str> {
389 request
390 .headers()
391 .get("authorization")
392 .and_then(|v| v.to_str().ok())
393 .and_then(|v| v.strip_prefix("Bearer "))
394}
395
396fn auth_error(msg: &str) -> Response {
397 (
398 StatusCode::UNAUTHORIZED,
399 Json(serde_json::json!({"error": msg})),
400 )
401 .into_response()
402}
403
404pub fn hash_api_key(key: &str) -> String {
408 let mut hasher = Sha256::new();
409 hasher.update(key.as_bytes());
410 data_encoding::HEXLOWER.encode(&hasher.finalize())
411}
412
413pub fn generate_api_key() -> String {
415 use rand::Rng;
416 let bytes: [u8; 32] = rand::rng().random();
417 format!("assay_{}", data_encoding::HEXLOWER.encode(&bytes))
418}
419
/// Returns a short, display-safe identifier for an API key, such as
/// `assay_01234567...`.
///
/// An optional `assay_` prefix is stripped, at most the next 8 characters are
/// kept, and the result is re-prefixed with `assay_` and suffixed with `...`.
/// The cut is made on a character boundary, so keys containing multi-byte
/// UTF-8 never cause a slicing panic.
pub fn key_prefix(key: &str) -> String {
    const VISIBLE_CHARS: usize = 8;

    let stripped = key.strip_prefix("assay_").unwrap_or(key);
    // Byte offset of the 9th character, or the whole string if it is shorter.
    let end = stripped
        .char_indices()
        .nth(VISIBLE_CHARS)
        .map_or(stripped.len(), |(idx, _)| idx);
    format!("assay_{}...", &stripped[..end])
}