1use axum::{
6 Json,
7 extract::{Request, State},
8 http::{HeaderMap, StatusCode, header},
9 middleware::Next,
10 response::IntoResponse,
11};
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, errors::ErrorKind};
13pub use perfgate_auth::{ApiKey, JwtClaims, Role, Scope, validate_key_format};
14use perfgate_error::AuthError;
15use sha2::{Digest, Sha256};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::warn;
20
21use crate::models::ApiError;
22use crate::oidc::OidcRegistry;
23use crate::storage::KeyStore;
24
/// Settings used to validate locally issued, HMAC-signed (HS256) JWTs.
///
/// `Debug` is implemented by hand elsewhere in this file so that the secret
/// is never printed.
#[derive(Clone)]
pub struct JwtConfig {
    /// Shared HMAC secret used to verify token signatures.
    secret: Vec<u8>,
    /// Issuer to require, when set.
    issuer: Option<String>,
    /// Audience to require, when set.
    audience: Option<String>,
}
32
33impl JwtConfig {
34 pub fn hs256(secret: impl Into<Vec<u8>>) -> Self {
36 Self {
37 secret: secret.into(),
38 issuer: None,
39 audience: None,
40 }
41 }
42
43 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
45 self.issuer = Some(issuer.into());
46 self
47 }
48
49 pub fn audience(mut self, audience: impl Into<String>) -> Self {
51 self.audience = Some(audience.into());
52 self
53 }
54
55 pub fn secret_bytes(&self) -> &[u8] {
57 &self.secret
58 }
59
60 fn validation(&self) -> Validation {
61 let mut validation = Validation::new(Algorithm::HS256);
62 if let Some(issuer) = &self.issuer {
63 validation.set_issuer(&[issuer.as_str()]);
64 }
65 if let Some(audience) = &self.audience {
66 validation.set_audience(&[audience.as_str()]);
67 }
68 validation
69 }
70}
71
72impl std::fmt::Debug for JwtConfig {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("JwtConfig")
75 .field("secret", &"<redacted>")
76 .field("issuer", &self.issuer)
77 .field("audience", &self.audience)
78 .finish()
79 }
80}
81
82#[derive(Clone)]
84pub struct AuthState {
85 pub key_store: Arc<ApiKeyStore>,
87
88 pub persistent_key_store: Option<Arc<dyn KeyStore>>,
90
91 pub jwt: Option<JwtConfig>,
93
94 pub oidc: OidcRegistry,
96}
97
98impl AuthState {
99 pub fn new(key_store: Arc<ApiKeyStore>, jwt: Option<JwtConfig>, oidc: OidcRegistry) -> Self {
101 Self {
102 key_store,
103 persistent_key_store: None,
104 jwt,
105 oidc,
106 }
107 }
108
109 pub fn with_persistent_key_store(mut self, store: Arc<dyn KeyStore>) -> Self {
111 self.persistent_key_store = Some(store);
112 self
113 }
114}
115
116#[derive(Debug, Clone)]
118pub struct AuthContext {
119 pub api_key: ApiKey,
121
122 pub source_ip: Option<String>,
124}
125
126#[derive(Debug, Default)]
128pub struct ApiKeyStore {
129 keys: Arc<RwLock<HashMap<String, ApiKey>>>,
131}
132
133impl ApiKeyStore {
134 pub fn new() -> Self {
136 Self {
137 keys: Arc::new(RwLock::new(HashMap::new())),
138 }
139 }
140
141 pub async fn add_key(&self, key: ApiKey, raw_key: &str) {
143 let hash = hash_api_key(raw_key);
144 let mut keys = self.keys.write().await;
145 keys.insert(hash, key);
146 }
147
148 pub async fn get_key(&self, raw_key: &str) -> Option<ApiKey> {
150 let hash = hash_api_key(raw_key);
151 let keys = self.keys.read().await;
152 keys.get(&hash).cloned()
153 }
154
155 pub async fn remove_key(&self, raw_key: &str) -> bool {
157 let hash = hash_api_key(raw_key);
158 let mut keys = self.keys.write().await;
159 keys.remove(&hash).is_some()
160 }
161
162 pub async fn list_keys(&self) -> Vec<ApiKey> {
164 let keys = self.keys.read().await;
165 keys.values().cloned().collect()
166 }
167}
168
/// Credentials parsed from the `Authorization` request header.
enum Credentials {
    /// An API key, sent with the `Bearer` scheme.
    ApiKey(String),
    /// A JWT (local or OIDC), sent with the `Token` scheme.
    Jwt(String),
}
173
174fn hash_api_key(key: &str) -> String {
176 let mut hasher = Sha256::new();
177 hasher.update(key.as_bytes());
178 format!("{:x}", hasher.finalize())
179}
180
181fn extract_credentials(headers: &HeaderMap) -> Option<Credentials> {
182 let auth_header = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
183
184 if let Some(key) = auth_header.strip_prefix("Bearer ") {
185 return Some(Credentials::ApiKey(key.to_string()));
186 }
187
188 if let Some(token) = auth_header.strip_prefix("Token ") {
189 return Some(Credentials::Jwt(token.to_string()));
190 }
191
192 None
193}
194
195fn source_ip(headers: &HeaderMap) -> Option<String> {
196 headers
197 .get("X-Forwarded-For")
198 .and_then(|v| v.to_str().ok())
199 .map(ToOwned::to_owned)
200}
201
202fn unauthorized(message: &str) -> (StatusCode, Json<ApiError>) {
203 (
204 StatusCode::UNAUTHORIZED,
205 Json(ApiError::unauthorized(message)),
206 )
207}
208
209async fn authenticate_api_key(
210 auth_state: &AuthState,
211 api_key_str: &str,
212 headers: &HeaderMap,
213) -> Result<AuthContext, (StatusCode, Json<ApiError>)> {
214 validate_key_format(api_key_str).map_err(|_| {
215 warn!(
216 key_prefix = &api_key_str[..10.min(api_key_str.len())],
217 "Invalid API key format"
218 );
219 unauthorized("Invalid API key format")
220 })?;
221
222 if let Some(api_key) = auth_state.key_store.get_key(api_key_str).await {
224 if api_key.is_expired() {
225 warn!(key_id = %api_key.id, "API key expired");
226 return Err(unauthorized("API key has expired"));
227 }
228 return Ok(AuthContext {
229 api_key,
230 source_ip: source_ip(headers),
231 });
232 }
233
234 if let Some(persistent) = &auth_state.persistent_key_store
236 && let Ok(Some(record)) = persistent.validate_key(api_key_str).await
237 {
238 let mut api_key = ApiKey::new(
239 record.id.clone(),
240 record.description.clone(),
241 record.project.clone(),
242 record.role,
243 );
244 api_key.benchmark_regex = record.pattern.clone();
246 api_key.expires_at = record.expires_at;
247 api_key.created_at = record.created_at;
248
249 return Ok(AuthContext {
250 api_key,
251 source_ip: source_ip(headers),
252 });
253 }
254
255 warn!(
256 key_prefix = &api_key_str[..10.min(api_key_str.len())],
257 "Invalid API key"
258 );
259 Err(unauthorized("Invalid API key"))
260}
261
262fn validate_jwt(token: &str, config: &JwtConfig) -> Result<JwtClaims, AuthError> {
263 let validation = config.validation();
264
265 decode::<JwtClaims>(
266 token,
267 &DecodingKey::from_secret(config.secret_bytes()),
268 &validation,
269 )
270 .map(|data| data.claims)
271 .map_err(|error| match error.kind() {
272 ErrorKind::ExpiredSignature => AuthError::ExpiredToken,
273 _ => AuthError::InvalidToken(error.to_string()),
274 })
275}
276
277async fn authenticate_jwt(
278 auth_state: &AuthState,
279 token: &str,
280 headers: &HeaderMap,
281) -> Result<AuthContext, (StatusCode, Json<ApiError>)> {
282 if let Some(config) = &auth_state.jwt {
284 match validate_jwt(token, config) {
285 Ok(claims) => {
286 return Ok(AuthContext {
287 api_key: api_key_from_jwt_claims(&claims),
288 source_ip: source_ip(headers),
289 });
290 }
291 Err(e) => {
292 if !auth_state.oidc.has_providers() {
295 match &e {
296 AuthError::ExpiredToken => warn!("Expired JWT token"),
297 AuthError::InvalidToken(_) => warn!("Invalid JWT token"),
298 _ => {}
299 }
300 return Err(unauthorized(&e.to_string()));
301 }
302 }
303 }
304 }
305
306 if auth_state.oidc.has_providers() {
308 match auth_state.oidc.validate_token(token).await {
309 Ok(api_key) => {
310 return Ok(AuthContext {
311 api_key,
312 source_ip: source_ip(headers),
313 });
314 }
315 Err(e) => {
316 match &e {
317 AuthError::ExpiredToken => warn!("Expired OIDC token"),
318 AuthError::InvalidToken(msg) => warn!("Invalid OIDC token: {}", msg),
319 _ => {}
320 }
321 return Err(unauthorized(&e.to_string()));
322 }
323 }
324 }
325
326 warn!("JWT token received but no JWT or OIDC authentication is configured");
327 Err(unauthorized("JWT/OIDC authentication is not configured"))
328}
329
330fn api_key_from_jwt_claims(claims: &JwtClaims) -> ApiKey {
331 ApiKey {
332 id: format!("jwt:{}", claims.sub),
333 name: format!("JWT {}", claims.sub),
334 project_id: claims.project_id.clone(),
335 scopes: claims.scopes.clone(),
336 role: Role::from_scopes(&claims.scopes),
337 benchmark_regex: None,
338 expires_at: Some(
339 chrono::DateTime::<chrono::Utc>::from_timestamp(claims.exp as i64, 0)
340 .unwrap_or_else(chrono::Utc::now),
341 ),
342 created_at: claims
343 .iat
344 .and_then(|iat| chrono::DateTime::<chrono::Utc>::from_timestamp(iat as i64, 0))
345 .unwrap_or_else(chrono::Utc::now),
346 last_used_at: None,
347 }
348}
349
350pub async fn auth_middleware(
352 State(auth_state): State<AuthState>,
353 mut request: Request,
354 next: Next,
355) -> Result<impl IntoResponse, (StatusCode, Json<ApiError>)> {
356 if request.uri().path() == "/health" {
358 return Ok(next.run(request).await);
359 }
360
361 let auth_ctx = match extract_credentials(request.headers()) {
362 Some(Credentials::ApiKey(api_key)) => {
363 authenticate_api_key(&auth_state, &api_key, request.headers()).await?
364 }
365 Some(Credentials::Jwt(token)) => {
366 authenticate_jwt(&auth_state, &token, request.headers()).await?
367 }
368 None => {
369 warn!("Missing authentication header");
370 return Err(unauthorized("Missing authentication header"));
371 }
372 };
373
374 request.extensions_mut().insert(auth_ctx);
375
376 Ok(next.run(request).await)
377}
378
379pub async fn local_mode_auth_middleware(mut request: Request, next: Next) -> impl IntoResponse {
386 let auth_ctx = AuthContext {
387 api_key: ApiKey::new(
388 "local-mode".to_string(),
389 "Local Mode".to_string(),
390 "local".to_string(),
391 Role::Admin,
392 ),
393 source_ip: source_ip(request.headers()),
394 };
395 request.extensions_mut().insert(auth_ctx);
396 next.run(request).await
397}
398
399pub fn check_scope(
402 auth_ctx: Option<&AuthContext>,
403 project_id: &str,
404 benchmark: Option<&str>,
405 scope: Scope,
406) -> Result<(), (StatusCode, Json<ApiError>)> {
407 let ctx = match auth_ctx {
408 Some(ctx) => ctx,
409 None => {
410 return Err((
411 StatusCode::UNAUTHORIZED,
412 Json(ApiError::unauthorized("Authentication required")),
413 ));
414 }
415 };
416
417 if !ctx.api_key.has_scope(scope) {
419 warn!(
420 key_id = %ctx.api_key.id,
421 required_scope = %scope,
422 actual_role = %ctx.api_key.role,
423 "Insufficient permissions: scope mismatch"
424 );
425 return Err((
426 StatusCode::FORBIDDEN,
427 Json(ApiError::forbidden(&format!(
428 "Requires '{}' permission",
429 scope
430 ))),
431 ));
432 }
433
434 if !ctx.api_key.has_scope(Scope::Admin) && ctx.api_key.project_id != project_id {
438 warn!(
439 key_id = %ctx.api_key.id,
440 key_project = %ctx.api_key.project_id,
441 requested_project = %project_id,
442 "Insufficient permissions: project isolation violation"
443 );
444 return Err((
445 StatusCode::FORBIDDEN,
446 Json(ApiError::forbidden(&format!(
447 "Key is restricted to project '{}'",
448 ctx.api_key.project_id
449 ))),
450 ));
451 }
452
453 if let (Some(regex_str), Some(bench)) = (&ctx.api_key.benchmark_regex, benchmark) {
456 let regex = regex::Regex::new(regex_str).map_err(|e| {
457 warn!(key_id = %ctx.api_key.id, regex = %regex_str, error = %e, "Invalid benchmark regex in API key");
458 (
459 StatusCode::INTERNAL_SERVER_ERROR,
460 Json(ApiError::internal_error("Invalid security configuration")),
461 )
462 })?;
463
464 if !regex.is_match(bench) {
465 warn!(
466 key_id = %ctx.api_key.id,
467 benchmark = %bench,
468 regex = %regex_str,
469 "Insufficient permissions: benchmark restriction violation"
470 );
471 return Err((
472 StatusCode::FORBIDDEN,
473 Json(ApiError::forbidden(&format!(
474 "Key is restricted to benchmarks matching '{}'",
475 regex_str
476 ))),
477 ));
478 }
479 }
480
481 Ok(())
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Extension, Router, routing::get};
    use jsonwebtoken::{Header, encode};
    use perfgate_auth::generate_api_key;
    use tower::ServiceExt;
    use uselesskey::{Factory, HmacFactoryExt, HmacSpec, Seed};
    use uselesskey_jsonwebtoken::JwtKeyExt;

    /// JWT settings backed by a deterministic HMAC fixture, requiring the
    /// issuer and audience that `create_test_claims` emits.
    fn test_jwt_config() -> JwtConfig {
        let seed = Seed::from_env_value("perfgate-server-auth-tests").unwrap();
        let factory = Factory::deterministic(seed);
        let fixture = factory.hmac("jwt-auth", HmacSpec::hs256());
        JwtConfig::hs256(fixture.secret_bytes())
            .issuer("perfgate-tests")
            .audience("perfgate")
    }

    /// Claims for subject `ci-bot` on `project-1` with the given scopes and
    /// expiry (seconds since the Unix epoch).
    fn create_test_claims(scopes: Vec<Scope>, exp: u64) -> JwtClaims {
        JwtClaims {
            sub: "ci-bot".to_string(),
            project_id: "project-1".to_string(),
            scopes,
            exp,
            iat: Some(chrono::Utc::now().timestamp() as u64),
            iss: Some("perfgate-tests".to_string()),
            aud: Some("perfgate".to_string()),
        }
    }

    /// Signs `claims` with the same deterministic fixture that
    /// `test_jwt_config` verifies against.
    fn create_test_token(claims: &JwtClaims) -> String {
        let seed = Seed::from_env_value("perfgate-server-auth-tests").unwrap();
        let factory = Factory::deterministic(seed);
        let fixture = factory.hmac("jwt-auth", HmacSpec::hs256());
        encode(&Header::default(), claims, &fixture.encoding_key()).unwrap()
    }

    /// Router with one `/protected` route behind `auth_middleware`; the
    /// handler echoes the authenticated key id.
    fn auth_test_router(auth_state: AuthState) -> Router {
        Router::new()
            .route(
                "/protected",
                get(|Extension(auth_ctx): Extension<AuthContext>| async move {
                    auth_ctx.api_key.id
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                auth_state,
                auth_middleware,
            ))
    }

    /// Router with one `/protected` route behind
    /// `local_mode_auth_middleware`; the handler echoes the injected role.
    fn local_auth_test_router() -> Router {
        Router::new()
            .route(
                "/protected",
                get(|Extension(auth_ctx): Extension<AuthContext>| async move {
                    auth_ctx.api_key.role.to_string()
                }),
            )
            .layer(axum::middleware::from_fn(local_mode_auth_middleware))
    }

    #[tokio::test]
    async fn test_api_key_store() {
        let store = ApiKeyStore::new();
        let raw_key = generate_api_key(false);
        let key = ApiKey::new(
            "key-1".to_string(),
            "Test Key".to_string(),
            "project-1".to_string(),
            Role::Contributor,
        );

        store.add_key(key.clone(), &raw_key).await;

        // The key can be looked up by its raw value.
        let retrieved = store.get_key(&raw_key).await;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.id, "key-1");
        assert_eq!(retrieved.role, Role::Contributor);

        let keys = store.list_keys().await;
        assert_eq!(keys.len(), 1);

        // Removal reports success and the key is gone afterwards.
        let removed = store.remove_key(&raw_key).await;
        assert!(removed);

        let retrieved = store.get_key(&raw_key).await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_auth_middleware_accepts_api_key() {
        let store = Arc::new(ApiKeyStore::new());
        let key = "pg_test_abcdefghijklmnopqrstuvwxyz123456";
        store
            .add_key(
                ApiKey::new(
                    "api-key-1".to_string(),
                    "API Key".to_string(),
                    "project-1".to_string(),
                    Role::Viewer,
                ),
                key,
            )
            .await;

        let response = auth_test_router(AuthState::new(store, None, Default::default()))
            .oneshot(
                Request::builder()
                    .uri("/protected")
                    .header(header::AUTHORIZATION, format!("Bearer {}", key))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_missing_credentials() {
        // No Authorization header at all must be rejected before the handler
        // runs.
        let response = auth_test_router(AuthState::new(
            Arc::new(ApiKeyStore::new()),
            None,
            Default::default(),
        ))
        .oneshot(
            Request::builder()
                .uri("/protected")
                .body(axum::body::Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_accepts_jwt_token() {
        let claims = create_test_claims(
            vec![Scope::Read, Scope::Promote],
            (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp() as u64,
        );
        let token = create_test_token(&claims);

        let response = auth_test_router(AuthState::new(
            Arc::new(ApiKeyStore::new()),
            Some(test_jwt_config()),
            Default::default(),
        ))
        .oneshot(
            Request::builder()
                .uri("/protected")
                .header(header::AUTHORIZATION, format!("Token {}", token))
                .body(axum::body::Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_expired_jwt_token() {
        // Expired an hour ago: far enough in the past that any small clock
        // leeway applied by the JWT library cannot make it valid.
        let claims = create_test_claims(
            vec![Scope::Read],
            (chrono::Utc::now() - chrono::Duration::hours(1)).timestamp() as u64,
        );
        let token = create_test_token(&claims);

        let response = auth_test_router(AuthState::new(
            Arc::new(ApiKeyStore::new()),
            Some(test_jwt_config()),
            Default::default(),
        ))
        .oneshot(
            Request::builder()
                .uri("/protected")
                .header(header::AUTHORIZATION, format!("Token {}", token))
                .body(axum::body::Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_jwt_when_unconfigured() {
        let claims = create_test_claims(
            vec![Scope::Read],
            (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp() as u64,
        );
        let token = create_test_token(&claims);

        // Neither local JWT settings nor OIDC providers are configured.
        let response = auth_test_router(AuthState::new(
            Arc::new(ApiKeyStore::new()),
            None,
            Default::default(),
        ))
        .oneshot(
            Request::builder()
                .uri("/protected")
                .header(header::AUTHORIZATION, format!("Token {}", token))
                .body(axum::body::Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_local_mode_auth_middleware_injects_admin_context() {
        // No credentials are sent; the handler only succeeds if the
        // middleware injected an AuthContext extension.
        let response = local_auth_test_router()
            .oneshot(
                Request::builder()
                    .uri("/protected")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_hash_api_key() {
        let key = "pg_live_test123456789012345678901234567890";
        let hash1 = hash_api_key(key);
        let hash2 = hash_api_key(key);

        // Hashing is deterministic.
        assert_eq!(hash1, hash2);

        // Different keys produce different hashes.
        let different_hash = hash_api_key("pg_live_different1234567890123456789012");
        assert_ne!(hash1, different_hash);
    }

    #[test]
    fn test_check_scope_project_isolation() {
        let key = ApiKey::new(
            "k1".to_string(),
            "n1".to_string(),
            "project-a".to_string(),
            Role::Contributor,
        );
        let ctx = AuthContext {
            api_key: key,
            source_ip: None,
        };

        // Scopes the contributor role holds, on its own project.
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Write).is_ok());
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Read).is_ok());

        // A scope the contributor role does not hold.
        let res = check_scope(Some(&ctx), "project-a", None, Scope::Delete);
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().0, StatusCode::FORBIDDEN);

        // A held scope, but on a different project.
        let res = check_scope(Some(&ctx), "project-b", None, Scope::Read);
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().0, StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_check_scope_global_admin() {
        let key = ApiKey::new(
            "k1".to_string(),
            "admin".to_string(),
            "any-project".to_string(),
            Role::Admin,
        );
        let ctx = AuthContext {
            api_key: key,
            source_ip: None,
        };

        // Admin keys are not bound to their own project.
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Read).is_ok());
        assert!(check_scope(Some(&ctx), "project-b", None, Scope::Delete).is_ok());
        assert!(check_scope(Some(&ctx), "other", None, Scope::Admin).is_ok());
    }

    #[test]
    fn test_check_scope_benchmark_restriction() {
        let mut key = ApiKey::new(
            "k1".to_string(),
            "n1".to_string(),
            "project-a".to_string(),
            Role::Contributor,
        );
        key.benchmark_regex = Some("^web-.*$".to_string());

        let ctx = AuthContext {
            api_key: key,
            source_ip: None,
        };

        // Benchmarks matching the pattern are allowed.
        assert!(check_scope(Some(&ctx), "project-a", Some("web-auth"), Scope::Read).is_ok());
        assert!(check_scope(Some(&ctx), "project-a", Some("web-api"), Scope::Write).is_ok());

        // A benchmark outside the pattern is rejected.
        let res = check_scope(Some(&ctx), "project-a", Some("worker-job"), Scope::Read);
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().0, StatusCode::FORBIDDEN);

        // Without a benchmark name the restriction is not evaluated.
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Read).is_ok());
    }
}