1use axum::{
6 Json,
7 extract::{Request, State},
8 http::{HeaderMap, StatusCode, header},
9 middleware::Next,
10 response::IntoResponse,
11};
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, errors::ErrorKind};
13use serde::{Deserialize, Serialize};
14use sha2::{Digest, Sha256};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::RwLock;
18use tracing::warn;
19
20use crate::error::AuthError;
21use crate::models::ApiError;
22
/// Prefix of production ("live") API keys.
///
/// Accepted by `validate_key_format` and emitted by `generate_api_key(false)`.
pub const API_KEY_PREFIX_LIVE: &str = "pg_live_";

/// Prefix of test API keys.
///
/// Accepted by `validate_key_format` and emitted by `generate_api_key(true)`.
pub const API_KEY_PREFIX_TEST: &str = "pg_test_";
28
/// A single permission that an API key or JWT may carry.
///
/// Serialized in `snake_case` (`"read"`, `"write"`, `"promote"`, `"delete"`,
/// `"admin"`). These are the same strings its `Display` impl writes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Scope {
    /// Granted to every role.
    Read,
    /// Granted to `Contributor`, `Promoter` and `Admin`.
    Write,
    /// Granted to `Promoter` and `Admin`.
    Promote,
    /// Granted only to `Admin`.
    Delete,
    /// Granted only to `Admin`.
    Admin,
}
44
45impl std::fmt::Display for Scope {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 Scope::Read => write!(f, "read"),
49 Scope::Write => write!(f, "write"),
50 Scope::Promote => write!(f, "promote"),
51 Scope::Delete => write!(f, "delete"),
52 Scope::Admin => write!(f, "admin"),
53 }
54 }
55}
56
/// Coarse permission level. The scopes of each role are listed by
/// `Role::allowed_scopes`; each role includes every scope of the roles before it.
///
/// Serialized in `snake_case` (e.g. `"contributor"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    /// `read` only.
    Viewer,
    /// `read` and `write`.
    Contributor,
    /// `read`, `write` and `promote`.
    Promoter,
    /// Every scope, including `delete` and `admin`.
    Admin,
}
70
71impl Role {
72 pub fn allowed_scopes(&self) -> Vec<Scope> {
74 match self {
75 Role::Viewer => vec![Scope::Read],
76 Role::Contributor => vec![Scope::Read, Scope::Write],
77 Role::Promoter => vec![Scope::Read, Scope::Write, Scope::Promote],
78 Role::Admin => vec![
79 Scope::Read,
80 Scope::Write,
81 Scope::Promote,
82 Scope::Delete,
83 Scope::Admin,
84 ],
85 }
86 }
87
88 pub fn has_scope(&self, scope: Scope) -> bool {
90 self.allowed_scopes().contains(&scope)
91 }
92
93 pub fn from_scopes(scopes: &[Scope]) -> Self {
95 if scopes.contains(&Scope::Admin) || scopes.contains(&Scope::Delete) {
96 Self::Admin
97 } else if scopes.contains(&Scope::Promote) {
98 Self::Promoter
99 } else if scopes.contains(&Scope::Write) {
100 Self::Contributor
101 } else {
102 Self::Viewer
103 }
104 }
105}
106
107impl std::fmt::Display for Role {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 match self {
110 Role::Viewer => write!(f, "viewer"),
111 Role::Contributor => write!(f, "contributor"),
112 Role::Promoter => write!(f, "promoter"),
113 Role::Admin => write!(f, "admin"),
114 }
115 }
116}
117
/// Settings for validating HS256-signed JWTs.
///
/// Build it with `JwtConfig::hs256` and optionally narrow it with the `issuer` /
/// `audience` builder methods. `Debug` is implemented by hand and prints the
/// secret as `<redacted>`.
#[derive(Clone)]
pub struct JwtConfig {
    /// Shared HMAC secret used to verify token signatures.
    secret: Vec<u8>,
    /// Expected `iss` claim, when issuer checking is enabled.
    issuer: Option<String>,
    /// Expected `aud` claim, when audience checking is enabled.
    audience: Option<String>,
}
125
126impl JwtConfig {
127 pub fn hs256(secret: impl Into<Vec<u8>>) -> Self {
129 Self {
130 secret: secret.into(),
131 issuer: None,
132 audience: None,
133 }
134 }
135
136 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
138 self.issuer = Some(issuer.into());
139 self
140 }
141
142 pub fn audience(mut self, audience: impl Into<String>) -> Self {
144 self.audience = Some(audience.into());
145 self
146 }
147
148 pub fn secret_bytes(&self) -> &[u8] {
150 &self.secret
151 }
152
153 fn validation(&self) -> Validation {
154 let mut validation = Validation::new(Algorithm::HS256);
155 if let Some(issuer) = &self.issuer {
156 validation.set_issuer(&[issuer.as_str()]);
157 }
158 if let Some(audience) = &self.audience {
159 validation.set_audience(&[audience.as_str()]);
160 }
161 validation
162 }
163}
164
165impl std::fmt::Debug for JwtConfig {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 f.debug_struct("JwtConfig")
168 .field("secret", &"<redacted>")
169 .field("issuer", &self.issuer)
170 .field("audience", &self.audience)
171 .finish()
172 }
173}
174
/// Claims carried by a JWT accepted by this module.
///
/// After validation these are turned into a synthetic `ApiKey` (see
/// `ApiKey::from_jwt_claims`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct JwtClaims {
    /// Subject; becomes the synthetic key id `jwt:<sub>`.
    pub sub: String,

    /// Project identifier; copied into the synthetic key's `project_id`.
    pub project_id: String,

    /// Granted scopes; copied verbatim into the synthetic key and used to derive
    /// its role.
    pub scopes: Vec<Scope>,

    /// Expiry in Unix seconds; checked by `jsonwebtoken` during decoding.
    pub exp: u64,

    /// Issued-at in Unix seconds; used as the synthetic key's `created_at` when
    /// present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub iat: Option<u64>,

    /// Issuer; compared with the configured issuer, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,

    /// Audience as a single string. NOTE(review): RFC 7519 also allows an array of
    /// audiences; such tokens will not deserialize into this field.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub aud: Option<String>,
}
202
/// Shared state consumed by `auth_middleware`.
#[derive(Clone, Debug)]
pub struct AuthState {
    /// Store used to resolve `Bearer` API keys.
    pub key_store: Arc<ApiKeyStore>,

    /// JWT settings; `None` disables JWT (`Token` scheme) authentication.
    pub jwt: Option<JwtConfig>,
}
212
213impl AuthState {
214 pub fn new(key_store: Arc<ApiKeyStore>, jwt: Option<JwtConfig>) -> Self {
216 Self { key_store, jwt }
217 }
218}
219
/// An API key record. It is either held by `ApiKeyStore` under the hash of the
/// raw key, or synthesized from validated JWT claims.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKey {
    /// Key identifier; JWT-derived keys use `jwt:<sub>`.
    pub id: String,

    /// Human-readable label; JWT-derived keys use `JWT <sub>`.
    pub name: String,

    /// Project identifier associated with the key.
    pub project_id: String,

    /// The scope list consulted by `ApiKey::has_scope`.
    pub scopes: Vec<Scope>,

    /// Role recorded for the key. `ApiKey::new` derives `scopes` from it, and
    /// `check_scope` logs it when access is denied.
    pub role: Role,

    /// Optional expiry, expected in RFC 3339 form (parsed by `is_expired`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<String>,

    /// Creation time; set as an RFC 3339 string by this module's constructors.
    pub created_at: String,

    /// Last-use timestamp; no code visible in this module writes it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_used_at: Option<String>,
}
249
250impl ApiKey {
251 pub fn new(id: String, name: String, project_id: String, role: Role) -> Self {
253 Self {
254 id,
255 name,
256 project_id,
257 scopes: role.allowed_scopes(),
258 role,
259 expires_at: None,
260 created_at: chrono::Utc::now().to_rfc3339(),
261 last_used_at: None,
262 }
263 }
264
265 fn from_jwt_claims(claims: &JwtClaims) -> Self {
267 Self {
268 id: format!("jwt:{}", claims.sub),
269 name: format!("JWT {}", claims.sub),
270 project_id: claims.project_id.clone(),
271 scopes: claims.scopes.clone(),
272 role: Role::from_scopes(&claims.scopes),
273 expires_at: format_timestamp(claims.exp),
274 created_at: claims
275 .iat
276 .and_then(format_timestamp)
277 .unwrap_or_else(|| chrono::Utc::now().to_rfc3339()),
278 last_used_at: None,
279 }
280 }
281
282 pub fn is_expired(&self) -> bool {
284 if let Some(exp) = self
285 .expires_at
286 .as_ref()
287 .and_then(|e| chrono::DateTime::parse_from_rfc3339(e).ok())
288 {
289 return exp.timestamp() < chrono::Utc::now().timestamp();
290 }
291 false
292 }
293
294 pub fn has_scope(&self, scope: Scope) -> bool {
296 self.scopes.contains(&scope)
297 }
298}
299
/// Authenticated caller information. `auth_middleware` inserts it into the request
/// extensions for downstream handlers.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// The stored API key, or a synthetic key built from JWT claims.
    pub api_key: ApiKey,

    /// Raw `X-Forwarded-For` header value, if present. NOTE(review): this header
    /// may be client-supplied; only trust it behind a proxy that overwrites it.
    pub source_ip: Option<String>,
}
309
/// In-memory registry of API keys, indexed by the SHA-256 hex digest of each raw
/// key. The raw key is never stored.
#[derive(Debug, Default)]
pub struct ApiKeyStore {
    /// Hashed raw key -> key record, behind an async `RwLock` so lookups can
    /// share the lock.
    keys: Arc<RwLock<HashMap<String, ApiKey>>>,
}
316
317impl ApiKeyStore {
318 pub fn new() -> Self {
320 Self {
321 keys: Arc::new(RwLock::new(HashMap::new())),
322 }
323 }
324
325 pub async fn add_key(&self, key: ApiKey, raw_key: &str) {
327 let hash = hash_api_key(raw_key);
328 let mut keys = self.keys.write().await;
329 keys.insert(hash, key);
330 }
331
332 pub async fn get_key(&self, raw_key: &str) -> Option<ApiKey> {
334 let hash = hash_api_key(raw_key);
335 let keys = self.keys.read().await;
336 keys.get(&hash).cloned()
337 }
338
339 pub async fn remove_key(&self, raw_key: &str) -> bool {
341 let hash = hash_api_key(raw_key);
342 let mut keys = self.keys.write().await;
343 keys.remove(&hash).is_some()
344 }
345
346 pub async fn list_keys(&self) -> Vec<ApiKey> {
348 let keys = self.keys.read().await;
349 keys.values().cloned().collect()
350 }
351}
352
/// Credentials parsed from the `Authorization` header by `extract_credentials`.
enum Credentials {
    /// Raw API key from the `Bearer` scheme.
    ApiKey(String),
    /// Encoded JWT from the `Token` scheme.
    Jwt(String),
}
357
358fn hash_api_key(key: &str) -> String {
360 let mut hasher = Sha256::new();
361 hasher.update(key.as_bytes());
362 format!("{:x}", hasher.finalize())
363}
364
365pub fn validate_key_format(key: &str) -> Result<(), AuthError> {
367 if key.starts_with(API_KEY_PREFIX_LIVE) || key.starts_with(API_KEY_PREFIX_TEST) {
368 let remainder = key
369 .strip_prefix(API_KEY_PREFIX_LIVE)
370 .or_else(|| key.strip_prefix(API_KEY_PREFIX_TEST))
371 .unwrap();
372
373 if remainder.len() >= 32 && remainder.chars().all(|c| c.is_alphanumeric()) {
375 return Ok(());
376 }
377 }
378
379 Err(AuthError::InvalidKeyFormat)
380}
381
382fn format_timestamp(timestamp: u64) -> Option<String> {
383 chrono::DateTime::<chrono::Utc>::from_timestamp(timestamp as i64, 0).map(|dt| dt.to_rfc3339())
384}
385
386fn extract_credentials(headers: &HeaderMap) -> Option<Credentials> {
387 let auth_header = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
388
389 if let Some(key) = auth_header.strip_prefix("Bearer ") {
390 return Some(Credentials::ApiKey(key.to_string()));
391 }
392
393 if let Some(token) = auth_header.strip_prefix("Token ") {
394 return Some(Credentials::Jwt(token.to_string()));
395 }
396
397 None
398}
399
400fn source_ip(headers: &HeaderMap) -> Option<String> {
401 headers
402 .get("X-Forwarded-For")
403 .and_then(|v| v.to_str().ok())
404 .map(ToOwned::to_owned)
405}
406
407fn unauthorized(message: &str) -> (StatusCode, Json<ApiError>) {
408 (
409 StatusCode::UNAUTHORIZED,
410 Json(ApiError::unauthorized(message)),
411 )
412}
413
414async fn authenticate_api_key(
415 key_store: &ApiKeyStore,
416 api_key_str: &str,
417 headers: &HeaderMap,
418) -> Result<AuthContext, (StatusCode, Json<ApiError>)> {
419 validate_key_format(api_key_str).map_err(|_| {
420 warn!(
421 key_prefix = &api_key_str[..10.min(api_key_str.len())],
422 "Invalid API key format"
423 );
424 unauthorized("Invalid API key format")
425 })?;
426
427 let api_key = key_store.get_key(api_key_str).await.ok_or_else(|| {
428 warn!(
429 key_prefix = &api_key_str[..10.min(api_key_str.len())],
430 "Invalid API key"
431 );
432 unauthorized("Invalid API key")
433 })?;
434
435 if api_key.is_expired() {
436 warn!(key_id = %api_key.id, "API key expired");
437 return Err(unauthorized("API key has expired"));
438 }
439
440 Ok(AuthContext {
441 api_key,
442 source_ip: source_ip(headers),
443 })
444}
445
446fn validate_jwt(token: &str, config: &JwtConfig) -> Result<JwtClaims, AuthError> {
447 let validation = config.validation();
448
449 decode::<JwtClaims>(
450 token,
451 &DecodingKey::from_secret(config.secret_bytes()),
452 &validation,
453 )
454 .map(|data| data.claims)
455 .map_err(|error| match error.kind() {
456 ErrorKind::ExpiredSignature => AuthError::ExpiredToken,
457 _ => AuthError::InvalidToken(error.to_string()),
458 })
459}
460
461fn authenticate_jwt(
462 config: Option<&JwtConfig>,
463 token: &str,
464 headers: &HeaderMap,
465) -> Result<AuthContext, (StatusCode, Json<ApiError>)> {
466 let config = config.ok_or_else(|| {
467 warn!("JWT token received but JWT authentication is not configured");
468 unauthorized("JWT token authentication is not configured")
469 })?;
470
471 let claims = validate_jwt(token, config).map_err(|error| {
472 match &error {
473 AuthError::ExpiredToken => warn!("Expired JWT token"),
474 AuthError::InvalidToken(_) => warn!("Invalid JWT token"),
475 _ => {}
476 }
477 unauthorized(&error.to_string())
478 })?;
479
480 Ok(AuthContext {
481 api_key: ApiKey::from_jwt_claims(&claims),
482 source_ip: source_ip(headers),
483 })
484}
485
486pub async fn auth_middleware(
488 State(auth_state): State<AuthState>,
489 mut request: Request,
490 next: Next,
491) -> Result<impl IntoResponse, (StatusCode, Json<ApiError>)> {
492 if request.uri().path() == "/health" {
494 return Ok(next.run(request).await);
495 }
496
497 let auth_ctx = match extract_credentials(request.headers()) {
498 Some(Credentials::ApiKey(api_key)) => {
499 authenticate_api_key(&auth_state.key_store, &api_key, request.headers()).await?
500 }
501 Some(Credentials::Jwt(token)) => {
502 authenticate_jwt(auth_state.jwt.as_ref(), &token, request.headers())?
503 }
504 None => {
505 warn!("Missing authentication header");
506 return Err(unauthorized("Missing authentication header"));
507 }
508 };
509
510 request.extensions_mut().insert(auth_ctx);
511
512 Ok(next.run(request).await)
513}
514
515pub fn check_scope(
518 auth_ctx: Option<&AuthContext>,
519 scope: Scope,
520) -> Result<(), (StatusCode, Json<ApiError>)> {
521 match auth_ctx {
522 Some(ctx) if ctx.api_key.has_scope(scope) => Ok(()),
523 Some(ctx) => {
524 warn!(
525 key_id = %ctx.api_key.id,
526 required_scope = %scope,
527 actual_role = %ctx.api_key.role,
528 "Insufficient permissions"
529 );
530 Err((
531 StatusCode::FORBIDDEN,
532 Json(ApiError::forbidden(&format!(
533 "Requires '{}' permission",
534 scope
535 ))),
536 ))
537 }
538 None => Err((
539 StatusCode::UNAUTHORIZED,
540 Json(ApiError::unauthorized("Authentication required")),
541 )),
542 }
543}
544
545pub fn generate_api_key(test: bool) -> String {
547 let prefix = if test {
548 API_KEY_PREFIX_TEST
549 } else {
550 API_KEY_PREFIX_LIVE
551 };
552 let random: String = uuid::Uuid::new_v4()
553 .simple()
554 .to_string()
555 .chars()
556 .take(32)
557 .collect();
558 format!("{}{}", prefix, random)
559}
560
// Tests for key-format validation, role/scope mapping, JWT validation, the
// in-memory key store, and `auth_middleware` end to end.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Extension, Router, routing::get};
    use jsonwebtoken::{Header, encode};
    use tower::ServiceExt;
    use uselesskey::{Factory, HmacFactoryExt, HmacSpec, Seed};
    use uselesskey_jsonwebtoken::JwtKeyExt;

    /// JWT config backed by a seeded HS256 fixture secret. It expects issuer
    /// `perfgate-tests` and audience `perfgate`.
    fn test_jwt_config() -> JwtConfig {
        let seed = Seed::from_env_value("perfgate-server-auth-tests").unwrap();
        let factory = Factory::deterministic(seed);
        let fixture = factory.hmac("jwt-auth", HmacSpec::hs256());
        JwtConfig::hs256(fixture.secret_bytes())
            .issuer("perfgate-tests")
            .audience("perfgate")
    }

    /// Claims for subject `ci-bot` in `project-1`. `iss`/`aud` match
    /// `test_jwt_config` and `iat` is set to now.
    fn create_test_claims(scopes: Vec<Scope>, exp: u64) -> JwtClaims {
        JwtClaims {
            sub: "ci-bot".to_string(),
            project_id: "project-1".to_string(),
            scopes,
            exp,
            iat: Some(chrono::Utc::now().timestamp() as u64),
            iss: Some("perfgate-tests".to_string()),
            aud: Some("perfgate".to_string()),
        }
    }

    /// Signs `claims` with a fixture built from the same seed and label as
    /// `test_jwt_config`, so the resulting token verifies against that config.
    fn create_test_token(claims: &JwtClaims) -> String {
        let seed = Seed::from_env_value("perfgate-server-auth-tests").unwrap();
        let factory = Factory::deterministic(seed);
        let fixture = factory.hmac("jwt-auth", HmacSpec::hs256());
        encode(&Header::default(), claims, &fixture.encoding_key()).unwrap()
    }

    /// Router with a single `/protected` route behind `auth_middleware`. The
    /// handler echoes the authenticated key id.
    fn auth_test_router(auth_state: AuthState) -> Router {
        Router::new()
            .route(
                "/protected",
                get(|Extension(auth_ctx): Extension<AuthContext>| async move {
                    auth_ctx.api_key.id
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                auth_state,
                auth_middleware,
            ))
    }

    #[test]
    fn test_validate_key_format() {
        // Both known prefixes with 32+ alphanumerics are accepted.
        assert!(validate_key_format("pg_live_abcdefghijklmnopqrstuvwxyz123456").is_ok());
        assert!(validate_key_format("pg_test_abcdefghijklmnopqrstuvwxyz123456").is_ok());
        // Unknown prefix, too short, and punctuation are rejected.
        assert!(validate_key_format("invalid_abcdefghijklmnopqrstuvwxyz123456").is_err());
        assert!(validate_key_format("pg_live_short").is_err());
        assert!(validate_key_format("pg_live_abcdefghijklmnopqrstuvwxyz12345!@").is_err());
    }

    #[test]
    fn test_role_scopes() {
        // Each role grants its own scope but not the next one up.
        let viewer = Role::Viewer;
        assert!(viewer.has_scope(Scope::Read));
        assert!(!viewer.has_scope(Scope::Write));

        let contributor = Role::Contributor;
        assert!(contributor.has_scope(Scope::Read));
        assert!(contributor.has_scope(Scope::Write));
        assert!(!contributor.has_scope(Scope::Promote));

        let promoter = Role::Promoter;
        assert!(promoter.has_scope(Scope::Promote));
        assert!(!promoter.has_scope(Scope::Delete));

        let admin = Role::Admin;
        assert!(admin.has_scope(Scope::Delete));
        assert!(admin.has_scope(Scope::Admin));
    }

    #[test]
    fn test_role_from_scopes() {
        assert_eq!(Role::from_scopes(&[Scope::Read]), Role::Viewer);
        assert_eq!(
            Role::from_scopes(&[Scope::Read, Scope::Write]),
            Role::Contributor
        );
        assert_eq!(
            Role::from_scopes(&[Scope::Read, Scope::Write, Scope::Promote]),
            Role::Promoter
        );
        // Only Admin grants Delete, so a Delete scope implies Admin.
        assert_eq!(Role::from_scopes(&[Scope::Delete]), Role::Admin);
    }

    #[test]
    fn test_validate_jwt_success() {
        // A correctly signed token that expires in five minutes decodes intact.
        let config = test_jwt_config();
        let claims = create_test_claims(
            vec![Scope::Read, Scope::Write],
            (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp() as u64,
        );
        let token = create_test_token(&claims);

        let decoded = validate_jwt(&token, &config).unwrap();

        assert_eq!(decoded.sub, "ci-bot");
        assert_eq!(decoded.project_id, "project-1");
        assert_eq!(decoded.scopes, vec![Scope::Read, Scope::Write]);
    }

    #[test]
    fn test_validate_jwt_expired() {
        // A token that expired five minutes ago maps to `AuthError::ExpiredToken`.
        let config = test_jwt_config();
        let claims = create_test_claims(
            vec![Scope::Read],
            (chrono::Utc::now() - chrono::Duration::minutes(5)).timestamp() as u64,
        );
        let token = create_test_token(&claims);

        let err = validate_jwt(&token, &config).unwrap_err();
        assert!(matches!(err, AuthError::ExpiredToken));
    }

    #[test]
    fn test_api_key_expiration() {
        let mut key = ApiKey::new(
            "key-1".to_string(),
            "Test Key".to_string(),
            "project-1".to_string(),
            Role::Viewer,
        );

        // No expiry set: never expired.
        assert!(!key.is_expired());

        // Past expiry.
        key.expires_at = Some("2020-01-01T00:00:00Z".to_string());
        assert!(key.is_expired());

        // Future expiry.
        key.expires_at = Some("2099-01-01T00:00:00Z".to_string());
        assert!(!key.is_expired());
    }

    #[tokio::test]
    async fn test_api_key_store() {
        // Round-trips a key through add / get / list / remove.
        let store = ApiKeyStore::new();
        let raw_key = generate_api_key(false);
        let key = ApiKey::new(
            "key-1".to_string(),
            "Test Key".to_string(),
            "project-1".to_string(),
            Role::Contributor,
        );

        store.add_key(key.clone(), &raw_key).await;

        let retrieved = store.get_key(&raw_key).await;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.id, "key-1");
        assert_eq!(retrieved.role, Role::Contributor);

        let keys = store.list_keys().await;
        assert_eq!(keys.len(), 1);

        let removed = store.remove_key(&raw_key).await;
        assert!(removed);

        let retrieved = store.get_key(&raw_key).await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_auth_middleware_accepts_api_key() {
        // API keys travel in the `Bearer` scheme.
        let store = Arc::new(ApiKeyStore::new());
        let key = "pg_test_abcdefghijklmnopqrstuvwxyz123456";
        store
            .add_key(
                ApiKey::new(
                    "api-key-1".to_string(),
                    "API Key".to_string(),
                    "project-1".to_string(),
                    Role::Viewer,
                ),
                key,
            )
            .await;

        let response = auth_test_router(AuthState::new(store, None))
            .oneshot(
                Request::builder()
                    .uri("/protected")
                    .header(header::AUTHORIZATION, format!("Bearer {}", key))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_accepts_jwt_token() {
        // JWTs travel in the `Token` scheme and need a configured `JwtConfig`.
        let claims = create_test_claims(
            vec![Scope::Read, Scope::Promote],
            (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp() as u64,
        );
        let token = create_test_token(&claims);

        let response = auth_test_router(AuthState::new(
            Arc::new(ApiKeyStore::new()),
            Some(test_jwt_config()),
        ))
        .oneshot(
            Request::builder()
                .uri("/protected")
                .header(header::AUTHORIZATION, format!("Token {}", token))
                .body(axum::body::Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_jwt_when_unconfigured() {
        // A valid JWT is still rejected when `AuthState.jwt` is `None`.
        let claims = create_test_claims(
            vec![Scope::Read],
            (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp() as u64,
        );
        let token = create_test_token(&claims);

        let response = auth_test_router(AuthState::new(Arc::new(ApiKeyStore::new()), None))
            .oneshot(
                Request::builder()
                    .uri("/protected")
                    .header(header::AUTHORIZATION, format!("Token {}", token))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_generate_api_key() {
        // 8-byte prefix + 32 random characters = 40 bytes.
        let live_key = generate_api_key(false);
        assert!(live_key.starts_with(API_KEY_PREFIX_LIVE));
        assert!(live_key.len() >= 40);

        let test_key = generate_api_key(true);
        assert!(test_key.starts_with(API_KEY_PREFIX_TEST));
        assert!(test_key.len() >= 40);
    }

    #[test]
    fn test_hash_api_key() {
        // Hashing is deterministic and distinguishes different keys.
        let key = "pg_live_test123456789012345678901234567890";
        let hash1 = hash_api_key(key);
        let hash2 = hash_api_key(key);

        assert_eq!(hash1, hash2);

        let different_hash = hash_api_key("pg_live_different1234567890123456789012");
        assert_ne!(hash1, different_hash);
    }
}