1use axum::{
6 Json,
7 extract::{Request, State},
8 http::{HeaderMap, StatusCode, header},
9 middleware::Next,
10 response::IntoResponse,
11};
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, errors::ErrorKind};
13pub use perfgate_auth::{ApiKey, JwtClaims, Role, Scope, validate_key_format};
14use perfgate_error::AuthError;
15use sha2::{Digest, Sha256};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::warn;
20
21use crate::models::ApiError;
22use crate::oidc::OidcProvider;
23
24#[derive(Clone)]
26pub struct JwtConfig {
27 secret: Vec<u8>,
28 issuer: Option<String>,
29 audience: Option<String>,
30}
31
32impl JwtConfig {
33 pub fn hs256(secret: impl Into<Vec<u8>>) -> Self {
35 Self {
36 secret: secret.into(),
37 issuer: None,
38 audience: None,
39 }
40 }
41
42 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
44 self.issuer = Some(issuer.into());
45 self
46 }
47
48 pub fn audience(mut self, audience: impl Into<String>) -> Self {
50 self.audience = Some(audience.into());
51 self
52 }
53
54 pub fn secret_bytes(&self) -> &[u8] {
56 &self.secret
57 }
58
59 fn validation(&self) -> Validation {
60 let mut validation = Validation::new(Algorithm::HS256);
61 if let Some(issuer) = &self.issuer {
62 validation.set_issuer(&[issuer.as_str()]);
63 }
64 if let Some(audience) = &self.audience {
65 validation.set_audience(&[audience.as_str()]);
66 }
67 validation
68 }
69}
70
71impl std::fmt::Debug for JwtConfig {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("JwtConfig")
74 .field("secret", &"<redacted>")
75 .field("issuer", &self.issuer)
76 .field("audience", &self.audience)
77 .finish()
78 }
79}
80
81#[derive(Clone)]
83pub struct AuthState {
84 pub key_store: Arc<ApiKeyStore>,
86
87 pub jwt: Option<JwtConfig>,
89
90 pub oidc: Option<OidcProvider>,
92}
93
94impl AuthState {
95 pub fn new(
97 key_store: Arc<ApiKeyStore>,
98 jwt: Option<JwtConfig>,
99 oidc: Option<OidcProvider>,
100 ) -> Self {
101 Self {
102 key_store,
103 jwt,
104 oidc,
105 }
106 }
107}
108
109#[derive(Debug, Clone)]
111pub struct AuthContext {
112 pub api_key: ApiKey,
114
115 pub source_ip: Option<String>,
117}
118
119#[derive(Debug, Default)]
121pub struct ApiKeyStore {
122 keys: Arc<RwLock<HashMap<String, ApiKey>>>,
124}
125
126impl ApiKeyStore {
127 pub fn new() -> Self {
129 Self {
130 keys: Arc::new(RwLock::new(HashMap::new())),
131 }
132 }
133
134 pub async fn add_key(&self, key: ApiKey, raw_key: &str) {
136 let hash = hash_api_key(raw_key);
137 let mut keys = self.keys.write().await;
138 keys.insert(hash, key);
139 }
140
141 pub async fn get_key(&self, raw_key: &str) -> Option<ApiKey> {
143 let hash = hash_api_key(raw_key);
144 let keys = self.keys.read().await;
145 keys.get(&hash).cloned()
146 }
147
148 pub async fn remove_key(&self, raw_key: &str) -> bool {
150 let hash = hash_api_key(raw_key);
151 let mut keys = self.keys.write().await;
152 keys.remove(&hash).is_some()
153 }
154
155 pub async fn list_keys(&self) -> Vec<ApiKey> {
157 let keys = self.keys.read().await;
158 keys.values().cloned().collect()
159 }
160}
161
/// Credential material pulled out of the `Authorization` header.
enum Credentials {
    /// Raw API key, sent with the `Bearer ` prefix.
    ApiKey(String),
    /// Encoded JWT, sent with the `Token ` prefix.
    Jwt(String),
}
166
167fn hash_api_key(key: &str) -> String {
169 let mut hasher = Sha256::new();
170 hasher.update(key.as_bytes());
171 format!("{:x}", hasher.finalize())
172}
173
174fn extract_credentials(headers: &HeaderMap) -> Option<Credentials> {
175 let auth_header = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
176
177 if let Some(key) = auth_header.strip_prefix("Bearer ") {
178 return Some(Credentials::ApiKey(key.to_string()));
179 }
180
181 if let Some(token) = auth_header.strip_prefix("Token ") {
182 return Some(Credentials::Jwt(token.to_string()));
183 }
184
185 None
186}
187
188fn source_ip(headers: &HeaderMap) -> Option<String> {
189 headers
190 .get("X-Forwarded-For")
191 .and_then(|v| v.to_str().ok())
192 .map(ToOwned::to_owned)
193}
194
195fn unauthorized(message: &str) -> (StatusCode, Json<ApiError>) {
196 (
197 StatusCode::UNAUTHORIZED,
198 Json(ApiError::unauthorized(message)),
199 )
200}
201
202async fn authenticate_api_key(
203 key_store: &ApiKeyStore,
204 api_key_str: &str,
205 headers: &HeaderMap,
206) -> Result<AuthContext, (StatusCode, Json<ApiError>)> {
207 validate_key_format(api_key_str).map_err(|_| {
208 warn!(
209 key_prefix = &api_key_str[..10.min(api_key_str.len())],
210 "Invalid API key format"
211 );
212 unauthorized("Invalid API key format")
213 })?;
214
215 let api_key = key_store.get_key(api_key_str).await.ok_or_else(|| {
216 warn!(
217 key_prefix = &api_key_str[..10.min(api_key_str.len())],
218 "Invalid API key"
219 );
220 unauthorized("Invalid API key")
221 })?;
222
223 if api_key.is_expired() {
224 warn!(key_id = %api_key.id, "API key expired");
225 return Err(unauthorized("API key has expired"));
226 }
227
228 Ok(AuthContext {
229 api_key,
230 source_ip: source_ip(headers),
231 })
232}
233
234fn validate_jwt(token: &str, config: &JwtConfig) -> Result<JwtClaims, AuthError> {
235 let validation = config.validation();
236
237 decode::<JwtClaims>(
238 token,
239 &DecodingKey::from_secret(config.secret_bytes()),
240 &validation,
241 )
242 .map(|data| data.claims)
243 .map_err(|error| match error.kind() {
244 ErrorKind::ExpiredSignature => AuthError::ExpiredToken,
245 _ => AuthError::InvalidToken(error.to_string()),
246 })
247}
248
249async fn authenticate_jwt(
250 auth_state: &AuthState,
251 token: &str,
252 headers: &HeaderMap,
253) -> Result<AuthContext, (StatusCode, Json<ApiError>)> {
254 if let Some(config) = &auth_state.jwt {
256 match validate_jwt(token, config) {
257 Ok(claims) => {
258 return Ok(AuthContext {
259 api_key: api_key_from_jwt_claims(&claims),
260 source_ip: source_ip(headers),
261 });
262 }
263 Err(e) => {
264 if auth_state.oidc.is_none() {
267 match &e {
268 AuthError::ExpiredToken => warn!("Expired JWT token"),
269 AuthError::InvalidToken(_) => warn!("Invalid JWT token"),
270 _ => {}
271 }
272 return Err(unauthorized(&e.to_string()));
273 }
274 }
275 }
276 }
277
278 if let Some(oidc) = &auth_state.oidc {
280 match oidc.validate_token(token).await {
281 Ok(api_key) => {
282 return Ok(AuthContext {
283 api_key,
284 source_ip: source_ip(headers),
285 });
286 }
287 Err(e) => {
288 match &e {
289 AuthError::ExpiredToken => warn!("Expired OIDC token"),
290 AuthError::InvalidToken(msg) => warn!("Invalid OIDC token: {}", msg),
291 _ => {}
292 }
293 return Err(unauthorized(&e.to_string()));
294 }
295 }
296 }
297
298 warn!("JWT token received but no JWT or OIDC authentication is configured");
299 Err(unauthorized("JWT/OIDC authentication is not configured"))
300}
301
302fn api_key_from_jwt_claims(claims: &JwtClaims) -> ApiKey {
303 ApiKey {
304 id: format!("jwt:{}", claims.sub),
305 name: format!("JWT {}", claims.sub),
306 project_id: claims.project_id.clone(),
307 scopes: claims.scopes.clone(),
308 role: Role::from_scopes(&claims.scopes),
309 benchmark_regex: None,
310 expires_at: Some(
311 chrono::DateTime::<chrono::Utc>::from_timestamp(claims.exp as i64, 0)
312 .unwrap_or_else(chrono::Utc::now),
313 ),
314 created_at: claims
315 .iat
316 .and_then(|iat| chrono::DateTime::<chrono::Utc>::from_timestamp(iat as i64, 0))
317 .unwrap_or_else(chrono::Utc::now),
318 last_used_at: None,
319 }
320}
321
322pub async fn auth_middleware(
324 State(auth_state): State<AuthState>,
325 mut request: Request,
326 next: Next,
327) -> Result<impl IntoResponse, (StatusCode, Json<ApiError>)> {
328 if request.uri().path() == "/health" {
330 return Ok(next.run(request).await);
331 }
332
333 let auth_ctx = match extract_credentials(request.headers()) {
334 Some(Credentials::ApiKey(api_key)) => {
335 authenticate_api_key(&auth_state.key_store, &api_key, request.headers()).await?
336 }
337 Some(Credentials::Jwt(token)) => {
338 authenticate_jwt(&auth_state, &token, request.headers()).await?
339 }
340 None => {
341 warn!("Missing authentication header");
342 return Err(unauthorized("Missing authentication header"));
343 }
344 };
345
346 request.extensions_mut().insert(auth_ctx);
347
348 Ok(next.run(request).await)
349}
350
351pub fn check_scope(
354 auth_ctx: Option<&AuthContext>,
355 project_id: &str,
356 benchmark: Option<&str>,
357 scope: Scope,
358) -> Result<(), (StatusCode, Json<ApiError>)> {
359 let ctx = match auth_ctx {
360 Some(ctx) => ctx,
361 None => {
362 return Err((
363 StatusCode::UNAUTHORIZED,
364 Json(ApiError::unauthorized("Authentication required")),
365 ));
366 }
367 };
368
369 if !ctx.api_key.has_scope(scope) {
371 warn!(
372 key_id = %ctx.api_key.id,
373 required_scope = %scope,
374 actual_role = %ctx.api_key.role,
375 "Insufficient permissions: scope mismatch"
376 );
377 return Err((
378 StatusCode::FORBIDDEN,
379 Json(ApiError::forbidden(&format!(
380 "Requires '{}' permission",
381 scope
382 ))),
383 ));
384 }
385
386 if !ctx.api_key.has_scope(Scope::Admin) && ctx.api_key.project_id != project_id {
390 warn!(
391 key_id = %ctx.api_key.id,
392 key_project = %ctx.api_key.project_id,
393 requested_project = %project_id,
394 "Insufficient permissions: project isolation violation"
395 );
396 return Err((
397 StatusCode::FORBIDDEN,
398 Json(ApiError::forbidden(&format!(
399 "Key is restricted to project '{}'",
400 ctx.api_key.project_id
401 ))),
402 ));
403 }
404
405 if let (Some(regex_str), Some(bench)) = (&ctx.api_key.benchmark_regex, benchmark) {
408 let regex = regex::Regex::new(regex_str).map_err(|e| {
409 warn!(key_id = %ctx.api_key.id, regex = %regex_str, error = %e, "Invalid benchmark regex in API key");
410 (
411 StatusCode::INTERNAL_SERVER_ERROR,
412 Json(ApiError::internal_error("Invalid security configuration")),
413 )
414 })?;
415
416 if !regex.is_match(bench) {
417 warn!(
418 key_id = %ctx.api_key.id,
419 benchmark = %bench,
420 regex = %regex_str,
421 "Insufficient permissions: benchmark restriction violation"
422 );
423 return Err((
424 StatusCode::FORBIDDEN,
425 Json(ApiError::forbidden(&format!(
426 "Key is restricted to benchmarks matching '{}'",
427 regex_str
428 ))),
429 ));
430 }
431 }
432
433 Ok(())
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Extension, Router, routing::get};
    use jsonwebtoken::{Header, encode};
    use perfgate_auth::generate_api_key;
    use tower::ServiceExt;
    use uselesskey::{Factory, HmacFactoryExt, HmacSpec, Seed};
    use uselesskey_jsonwebtoken::JwtKeyExt;

    /// Seed shared by every helper that derives the deterministic HMAC fixture.
    const FIXTURE_SEED: &str = "perfgate-server-auth-tests";
    /// Label under which the HMAC fixture is requested from the factory.
    const FIXTURE_LABEL: &str = "jwt-auth";

    /// JWT configuration whose secret matches the tokens from `create_test_token`.
    fn test_jwt_config() -> JwtConfig {
        let factory = Factory::deterministic(Seed::from_env_value(FIXTURE_SEED).unwrap());
        let fixture = factory.hmac(FIXTURE_LABEL, HmacSpec::hs256());
        JwtConfig::hs256(fixture.secret_bytes())
            .issuer("perfgate-tests")
            .audience("perfgate")
    }

    /// Claims for the `ci-bot` subject on `project-1` with the given scopes and expiry.
    fn create_test_claims(scopes: Vec<Scope>, exp: u64) -> JwtClaims {
        let issued_at = chrono::Utc::now().timestamp() as u64;
        JwtClaims {
            sub: "ci-bot".to_string(),
            project_id: "project-1".to_string(),
            scopes,
            exp,
            iat: Some(issued_at),
            iss: Some("perfgate-tests".to_string()),
            aud: Some("perfgate".to_string()),
        }
    }

    /// Signs `claims` with the same deterministic secret used by `test_jwt_config`.
    fn create_test_token(claims: &JwtClaims) -> String {
        let factory = Factory::deterministic(Seed::from_env_value(FIXTURE_SEED).unwrap());
        let fixture = factory.hmac(FIXTURE_LABEL, HmacSpec::hs256());
        encode(&Header::default(), claims, &fixture.encoding_key()).unwrap()
    }

    /// Router with a single `/protected` route behind `auth_middleware`; the
    /// handler echoes the authenticated key id.
    fn auth_test_router(auth_state: AuthState) -> Router {
        let echo_key_id =
            |Extension(auth_ctx): Extension<AuthContext>| async move { auth_ctx.api_key.id };

        Router::new()
            .route("/protected", get(echo_key_id))
            .layer(axum::middleware::from_fn_with_state(
                auth_state,
                auth_middleware,
            ))
    }

    /// Unix timestamp five minutes in the future.
    fn five_minutes_from_now() -> u64 {
        (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp() as u64
    }

    /// Sends `GET /protected` with the given `Authorization` value and returns the status.
    async fn protected_status(auth_state: AuthState, authorization: String) -> StatusCode {
        let request = Request::builder()
            .uri("/protected")
            .header(header::AUTHORIZATION, authorization)
            .body(axum::body::Body::empty())
            .unwrap();

        let response = auth_test_router(auth_state)
            .oneshot(request)
            .await
            .unwrap();

        response.status()
    }

    /// Wraps `api_key` in a context with no source address.
    fn context_for(api_key: ApiKey) -> AuthContext {
        AuthContext {
            api_key,
            source_ip: None,
        }
    }

    /// Asserts that `outcome` is a `403 Forbidden` rejection.
    #[track_caller]
    fn assert_forbidden(outcome: Result<(), (StatusCode, Json<ApiError>)>) {
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().0, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_api_key_store() {
        let store = ApiKeyStore::new();
        let raw_key = generate_api_key(false);
        let key = ApiKey::new(
            "key-1".to_string(),
            "Test Key".to_string(),
            "project-1".to_string(),
            Role::Contributor,
        );

        store.add_key(key.clone(), &raw_key).await;

        let found = store.get_key(&raw_key).await;
        assert!(found.is_some());
        let found = found.unwrap();
        assert_eq!(found.id, "key-1");
        assert_eq!(found.role, Role::Contributor);

        assert_eq!(store.list_keys().await.len(), 1);

        assert!(store.remove_key(&raw_key).await);

        assert!(store.get_key(&raw_key).await.is_none());
    }

    #[tokio::test]
    async fn test_auth_middleware_accepts_api_key() {
        let key = "pg_test_abcdefghijklmnopqrstuvwxyz123456";
        let store = Arc::new(ApiKeyStore::new());
        let registered = ApiKey::new(
            "api-key-1".to_string(),
            "API Key".to_string(),
            "project-1".to_string(),
            Role::Viewer,
        );
        store.add_key(registered, key).await;

        let status = protected_status(
            AuthState::new(store, None, None),
            format!("Bearer {}", key),
        )
        .await;

        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_accepts_jwt_token() {
        let claims = create_test_claims(vec![Scope::Read, Scope::Promote], five_minutes_from_now());
        let token = create_test_token(&claims);

        let auth_state = AuthState::new(
            Arc::new(ApiKeyStore::new()),
            Some(test_jwt_config()),
            None,
        );
        let status = protected_status(auth_state, format!("Token {}", token)).await;

        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_jwt_when_unconfigured() {
        let claims = create_test_claims(vec![Scope::Read], five_minutes_from_now());
        let token = create_test_token(&claims);

        let auth_state = AuthState::new(Arc::new(ApiKeyStore::new()), None, None);
        let status = protected_status(auth_state, format!("Token {}", token)).await;

        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_hash_api_key() {
        let key = "pg_live_test123456789012345678901234567890";
        let first = hash_api_key(key);
        let second = hash_api_key(key);

        // Hashing is deterministic for the same input.
        assert_eq!(first, second);

        // A different key produces a different digest.
        let other = hash_api_key("pg_live_different1234567890123456789012");
        assert_ne!(first, other);
    }

    #[test]
    fn test_check_scope_project_isolation() {
        let ctx = context_for(ApiKey::new(
            "k1".to_string(),
            "n1".to_string(),
            "project-a".to_string(),
            Role::Contributor,
        ));

        // Scopes held by the role, on the key's own project.
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Write).is_ok());
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Read).is_ok());

        // A scope the role lacks.
        assert_forbidden(check_scope(Some(&ctx), "project-a", None, Scope::Delete));

        // A held scope, but on another project.
        assert_forbidden(check_scope(Some(&ctx), "project-b", None, Scope::Read));
    }

    #[test]
    fn test_check_scope_global_admin() {
        let ctx = context_for(ApiKey::new(
            "k1".to_string(),
            "admin".to_string(),
            "any-project".to_string(),
            Role::Admin,
        ));

        // Admin keys are not pinned to their own project.
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Read).is_ok());
        assert!(check_scope(Some(&ctx), "project-b", None, Scope::Delete).is_ok());
        assert!(check_scope(Some(&ctx), "other", None, Scope::Admin).is_ok());
    }

    #[test]
    fn test_check_scope_benchmark_restriction() {
        let mut key = ApiKey::new(
            "k1".to_string(),
            "n1".to_string(),
            "project-a".to_string(),
            Role::Contributor,
        );
        key.benchmark_regex = Some("^web-.*$".to_string());
        let ctx = context_for(key);

        // Benchmarks matching the pattern are allowed.
        assert!(check_scope(Some(&ctx), "project-a", Some("web-auth"), Scope::Read).is_ok());
        assert!(check_scope(Some(&ctx), "project-a", Some("web-api"), Scope::Write).is_ok());

        // A benchmark outside the pattern is rejected.
        assert_forbidden(check_scope(
            Some(&ctx),
            "project-a",
            Some("worker-job"),
            Scope::Read,
        ));

        // With no benchmark named, the restriction is not applied.
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Read).is_ok());
    }
}