1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::middleware::Next;
7use axum::response::{IntoResponse, Response};
8use axum::Json;
9use jsonwebtoken::jwk::JwkSet;
10use jsonwebtoken::{DecodingKey, Validation};
11use sha2::{Digest, Sha256};
12use tokio::sync::RwLock;
13use tracing::{debug, info, warn};
14
15use crate::api::AppState;
16use crate::store::WorkflowStore;
17
/// How long a fetched JWKS is served from the cache before the next lookup
/// refetches it from the issuer.
const JWKS_CACHE_TTL: Duration = Duration::from_secs(300);

/// Authentication strategy applied by `auth_middleware` to incoming requests.
#[derive(Clone, Debug, Default)]
pub enum AuthMode {
    /// No authentication: requests are passed straight to the next handler.
    /// This is the default mode.
    #[default]
    NoAuth,
    /// Requests must carry `Authorization: Bearer <api-key>`; the SHA-256 hash
    /// of the presented key is checked against the workflow store.
    ApiKey,
    /// Requests must carry `Authorization: Bearer <jwt>`, verified against the
    /// issuer's published signing keys.
    Jwt {
        /// Expected `iss` claim. `AuthMode::jwt` also uses it as the base URL
        /// for OIDC discovery.
        issuer: String,
        /// Expected `aud` claim; when `None`, the audience is not checked.
        audience: Option<String>,
        /// Shared cache of the issuer's signing keys.
        jwks_cache: Arc<JwksCache>,
    },
}
37
38impl AuthMode {
39 pub fn jwt(issuer: String, audience: Option<String>) -> Self {
41 Self::Jwt {
42 jwks_cache: Arc::new(JwksCache::new(issuer.clone())),
43 issuer,
44 audience,
45 }
46 }
47}
48
/// Lazily populated cache of an OIDC issuer's JSON Web Key Set.
///
/// The key-set location is discovered from
/// `{issuer}/.well-known/openid-configuration`, and fetched keys are reused
/// for `JWKS_CACHE_TTL`.
pub struct JwksCache {
    /// Issuer base URL used to build the discovery URL.
    issuer: String,
    /// Most recently fetched key set; `None` until the first fetch unless the
    /// cache was seeded via `with_jwks`.
    cache: RwLock<Option<CachedJwks>>,
    /// Client used for the discovery and JWKS requests.
    http: reqwest::Client,
}
58
/// A fetched key set paired with the moment it was stored in the cache.
struct CachedJwks {
    jwks: JwkSet,
    /// When `jwks` was stored; compared against `JWKS_CACHE_TTL` to decide
    /// whether a refetch is needed.
    fetched_at: Instant,
}
63
64impl std::fmt::Debug for JwksCache {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("JwksCache")
67 .field("issuer", &self.issuer)
68 .finish()
69 }
70}
71
72impl JwksCache {
73 pub fn new(issuer: String) -> Self {
74 Self {
75 issuer,
76 cache: RwLock::new(None),
77 http: reqwest::Client::builder()
78 .timeout(Duration::from_secs(10))
79 .build()
80 .expect("building JWKS HTTP client"),
81 }
82 }
83
84 pub fn with_jwks(issuer: String, jwks: JwkSet) -> Self {
86 Self {
87 issuer,
88 cache: RwLock::new(Some(CachedJwks {
89 jwks,
90 fetched_at: Instant::now(),
91 })),
92 http: reqwest::Client::new(),
93 }
94 }
95
96 async fn get_jwks(&self) -> anyhow::Result<JwkSet> {
98 {
100 let cache = self.cache.read().await;
101 if let Some(ref cached) = *cache
102 && cached.fetched_at.elapsed() < JWKS_CACHE_TTL
103 {
104 return Ok(cached.jwks.clone());
105 }
106 }
107
108 self.refresh().await
110 }
111
112 async fn refresh(&self) -> anyhow::Result<JwkSet> {
114 let jwks_uri = self.discover_jwks_uri().await?;
115 debug!("Fetching JWKS from {jwks_uri}");
116
117 let jwks: JwkSet = self.http.get(&jwks_uri).send().await?.json().await?;
118 info!(
119 "Fetched {} keys from JWKS endpoint",
120 jwks.keys.len()
121 );
122
123 let mut cache = self.cache.write().await;
124 *cache = Some(CachedJwks {
125 jwks: jwks.clone(),
126 fetched_at: Instant::now(),
127 });
128
129 Ok(jwks)
130 }
131
132 async fn discover_jwks_uri(&self) -> anyhow::Result<String> {
134 let discovery_url = format!(
135 "{}/.well-known/openid-configuration",
136 self.issuer.trim_end_matches('/')
137 );
138
139 let resp: serde_json::Value = self
140 .http
141 .get(&discovery_url)
142 .send()
143 .await?
144 .json()
145 .await?;
146
147 resp.get("jwks_uri")
148 .and_then(|v| v.as_str())
149 .map(String::from)
150 .ok_or_else(|| anyhow::anyhow!("OIDC discovery response missing jwks_uri"))
151 }
152
153 async fn find_key(&self, kid: &str) -> anyhow::Result<DecodingKey> {
156 let jwks = self.get_jwks().await?;
157
158 if let Some(key) = find_key_in_set(&jwks, kid) {
160 return Ok(key);
161 }
162
163 debug!("kid '{kid}' not in JWKS cache, refreshing");
165 let jwks = self.refresh().await?;
166
167 find_key_in_set(&jwks, kid)
168 .ok_or_else(|| anyhow::anyhow!("No key with kid '{kid}' in JWKS"))
169 }
170
171 async fn find_any_key(&self, alg: jsonwebtoken::Algorithm) -> anyhow::Result<DecodingKey> {
173 let jwks = self.get_jwks().await?;
174
175 for key in &jwks.keys {
176 if let Ok(dk) = DecodingKey::from_jwk(key) {
177 let _ = alg; return Ok(dk);
182 }
183 }
184
185 anyhow::bail!("No suitable key found in JWKS for algorithm {alg:?}")
186 }
187}
188
189fn find_key_in_set(jwks: &JwkSet, kid: &str) -> Option<DecodingKey> {
190 jwks.keys
191 .iter()
192 .find(|k| k.common.key_id.as_deref() == Some(kid))
193 .and_then(|k| DecodingKey::from_jwk(k).ok())
194}
195
196pub async fn auth_middleware<S: WorkflowStore>(
200 State(state): State<Arc<AppState<S>>>,
201 request: Request,
202 next: Next,
203) -> Response {
204 match &state.auth_mode {
205 AuthMode::NoAuth => next.run(request).await,
206 AuthMode::ApiKey => validate_api_key(state, request, next).await,
207 AuthMode::Jwt {
208 issuer,
209 audience,
210 jwks_cache,
211 } => validate_jwt(issuer, audience.as_deref(), jwks_cache, request, next).await,
212 }
213}
214
215async fn validate_api_key<S: WorkflowStore>(
216 state: Arc<AppState<S>>,
217 request: Request,
218 next: Next,
219) -> Response {
220 let token = match extract_bearer(&request) {
221 Some(t) => t,
222 None => return auth_error("Missing Authorization: Bearer <api-key>"),
223 };
224
225 let hash = hash_api_key(token);
226 match state.engine.store().validate_api_key(&hash).await {
227 Ok(true) => next.run(request).await,
228 Ok(false) => {
229 warn!(
230 "Invalid API key (prefix: {}...)",
231 &token[..8.min(token.len())]
232 );
233 auth_error("Invalid API key")
234 }
235 Err(e) => {
236 warn!("API key validation error: {e}");
237 (
238 StatusCode::INTERNAL_SERVER_ERROR,
239 Json(serde_json::json!({"error": "auth check failed"})),
240 )
241 .into_response()
242 }
243 }
244}
245
246async fn validate_jwt(
247 issuer: &str,
248 audience: Option<&str>,
249 jwks_cache: &JwksCache,
250 request: Request,
251 next: Next,
252) -> Response {
253 let token = match extract_bearer(&request) {
254 Some(t) => t,
255 None => return auth_error("Missing Authorization: Bearer <jwt>"),
256 };
257
258 let header = match jsonwebtoken::decode_header(token) {
260 Ok(h) => h,
261 Err(e) => {
262 warn!("Invalid JWT header: {e}");
263 return auth_error("Invalid JWT");
264 }
265 };
266
267 let decoding_key = match &header.kid {
269 Some(kid) => match jwks_cache.find_key(kid).await {
270 Ok(key) => key,
271 Err(e) => {
272 warn!("JWKS key lookup failed: {e}");
273 return auth_error("JWT validation failed: key not found");
274 }
275 },
276 None => match jwks_cache.find_any_key(header.alg).await {
277 Ok(key) => key,
278 Err(e) => {
279 warn!("JWKS key lookup failed (no kid): {e}");
280 return auth_error("JWT validation failed: no suitable key");
281 }
282 },
283 };
284
285 let mut validation = Validation::new(header.alg);
287 validation.set_issuer(&[issuer]);
288 if let Some(aud) = audience {
289 validation.set_audience(&[aud]);
290 } else {
291 validation.validate_aud = false;
292 }
293
294 match jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation) {
296 Ok(_) => next.run(request).await,
297 Err(e) => {
298 warn!("JWT validation failed: {e}");
299 auth_error(&format!("JWT validation failed: {e}"))
300 }
301 }
302}
303
304fn extract_bearer(request: &Request) -> Option<&str> {
305 request
306 .headers()
307 .get("authorization")
308 .and_then(|v| v.to_str().ok())
309 .and_then(|v| v.strip_prefix("Bearer "))
310}
311
312fn auth_error(msg: &str) -> Response {
313 (
314 StatusCode::UNAUTHORIZED,
315 Json(serde_json::json!({"error": msg})),
316 )
317 .into_response()
318}
319
320pub fn hash_api_key(key: &str) -> String {
324 let mut hasher = Sha256::new();
325 hasher.update(key.as_bytes());
326 data_encoding::HEXLOWER.encode(&hasher.finalize())
327}
328
329pub fn generate_api_key() -> String {
331 use rand::Rng;
332 let bytes: [u8; 32] = rand::rng().random();
333 format!("assay_{}", data_encoding::HEXLOWER.encode(&bytes))
334}
335
/// Returns a display-safe abbreviation of an API key: `assay_` plus the first
/// eight characters after the `assay_` prefix (if any), followed by `...`.
///
/// The cut is counted in characters, not bytes, so arbitrary (non-ASCII)
/// input cannot panic on a UTF-8 boundary. For ASCII keys the output is
/// unchanged from a byte-based cut.
pub fn key_prefix(key: &str) -> String {
    let stripped = key.strip_prefix("assay_").unwrap_or(key);
    // Byte offset of the 9th character, or the whole string if it is shorter.
    let end = stripped
        .char_indices()
        .nth(8)
        .map_or(stripped.len(), |(idx, _)| idx);
    format!("assay_{}...", &stripped[..end])
}