Skip to main content

perfgate_server/
auth.rs

1//! Authentication and authorization middleware.
2//!
3//! This module provides API key and JWT token validation for the baseline service.
4
5use axum::{
6    Json,
7    extract::{Request, State},
8    http::{HeaderMap, StatusCode, header},
9    middleware::Next,
10    response::IntoResponse,
11};
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, errors::ErrorKind};
13pub use perfgate_auth::{ApiKey, JwtClaims, Role, Scope, validate_key_format};
14use perfgate_error::AuthError;
15use sha2::{Digest, Sha256};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::warn;
20
21use crate::models::ApiError;
22use crate::oidc::OidcProvider;
23
24/// JWT validation settings.
25#[derive(Clone)]
26pub struct JwtConfig {
27    secret: Vec<u8>,
28    issuer: Option<String>,
29    audience: Option<String>,
30}
31
32impl JwtConfig {
33    /// Creates an HS256 JWT configuration from raw secret bytes.
34    pub fn hs256(secret: impl Into<Vec<u8>>) -> Self {
35        Self {
36            secret: secret.into(),
37            issuer: None,
38            audience: None,
39        }
40    }
41
42    /// Sets the expected issuer claim.
43    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
44        self.issuer = Some(issuer.into());
45        self
46    }
47
48    /// Sets the expected audience claim.
49    pub fn audience(mut self, audience: impl Into<String>) -> Self {
50        self.audience = Some(audience.into());
51        self
52    }
53
54    /// Returns the configured secret bytes.
55    pub fn secret_bytes(&self) -> &[u8] {
56        &self.secret
57    }
58
59    fn validation(&self) -> Validation {
60        let mut validation = Validation::new(Algorithm::HS256);
61        if let Some(issuer) = &self.issuer {
62            validation.set_issuer(&[issuer.as_str()]);
63        }
64        if let Some(audience) = &self.audience {
65            validation.set_audience(&[audience.as_str()]);
66        }
67        validation
68    }
69}
70
71impl std::fmt::Debug for JwtConfig {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("JwtConfig")
74            .field("secret", &"<redacted>")
75            .field("issuer", &self.issuer)
76            .field("audience", &self.audience)
77            .finish()
78    }
79}
80
81/// Authentication state shared by middleware.
82#[derive(Clone)]
83pub struct AuthState {
84    /// In-memory API key store.
85    pub key_store: Arc<ApiKeyStore>,
86
87    /// Optional JWT validation settings.
88    pub jwt: Option<JwtConfig>,
89
90    /// Optional OIDC provider.
91    pub oidc: Option<OidcProvider>,
92}
93
94impl AuthState {
95    /// Creates auth state from a key store and optional JWT/OIDC configuration.
96    pub fn new(
97        key_store: Arc<ApiKeyStore>,
98        jwt: Option<JwtConfig>,
99        oidc: Option<OidcProvider>,
100    ) -> Self {
101        Self {
102            key_store,
103            jwt,
104            oidc,
105        }
106    }
107}
108
109/// Authenticated user context extracted from requests.
110#[derive(Debug, Clone)]
111pub struct AuthContext {
112    /// API key information
113    pub api_key: ApiKey,
114
115    /// Source IP address
116    pub source_ip: Option<String>,
117}
118
119/// In-memory API key store for development and testing.
120#[derive(Debug, Default)]
121pub struct ApiKeyStore {
122    /// Keys indexed by key hash
123    keys: Arc<RwLock<HashMap<String, ApiKey>>>,
124}
125
126impl ApiKeyStore {
127    /// Creates a new empty key store.
128    pub fn new() -> Self {
129        Self {
130            keys: Arc::new(RwLock::new(HashMap::new())),
131        }
132    }
133
134    /// Adds an API key to the store.
135    pub async fn add_key(&self, key: ApiKey, raw_key: &str) {
136        let hash = hash_api_key(raw_key);
137        let mut keys = self.keys.write().await;
138        keys.insert(hash, key);
139    }
140
141    /// Looks up an API key by its hash.
142    pub async fn get_key(&self, raw_key: &str) -> Option<ApiKey> {
143        let hash = hash_api_key(raw_key);
144        let keys = self.keys.read().await;
145        keys.get(&hash).cloned()
146    }
147
148    /// Removes an API key from the store.
149    pub async fn remove_key(&self, raw_key: &str) -> bool {
150        let hash = hash_api_key(raw_key);
151        let mut keys = self.keys.write().await;
152        keys.remove(&hash).is_some()
153    }
154
155    /// Lists all API keys (without sensitive data).
156    pub async fn list_keys(&self) -> Vec<ApiKey> {
157        let keys = self.keys.read().await;
158        keys.values().cloned().collect()
159    }
160}
161
/// Credential parsed from the `Authorization` header, tagged by the scheme it arrived under.
enum Credentials {
    /// Raw API key, sent with the `Bearer ` scheme.
    ApiKey(String),
    /// Encoded JWT, sent with the `Token ` scheme.
    Jwt(String),
}
166
167/// Hashes an API key for storage.
168fn hash_api_key(key: &str) -> String {
169    let mut hasher = Sha256::new();
170    hasher.update(key.as_bytes());
171    format!("{:x}", hasher.finalize())
172}
173
174fn extract_credentials(headers: &HeaderMap) -> Option<Credentials> {
175    let auth_header = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
176
177    if let Some(key) = auth_header.strip_prefix("Bearer ") {
178        return Some(Credentials::ApiKey(key.to_string()));
179    }
180
181    if let Some(token) = auth_header.strip_prefix("Token ") {
182        return Some(Credentials::Jwt(token.to_string()));
183    }
184
185    None
186}
187
188fn source_ip(headers: &HeaderMap) -> Option<String> {
189    headers
190        .get("X-Forwarded-For")
191        .and_then(|v| v.to_str().ok())
192        .map(ToOwned::to_owned)
193}
194
195fn unauthorized(message: &str) -> (StatusCode, Json<ApiError>) {
196    (
197        StatusCode::UNAUTHORIZED,
198        Json(ApiError::unauthorized(message)),
199    )
200}
201
202async fn authenticate_api_key(
203    key_store: &ApiKeyStore,
204    api_key_str: &str,
205    headers: &HeaderMap,
206) -> Result<AuthContext, (StatusCode, Json<ApiError>)> {
207    validate_key_format(api_key_str).map_err(|_| {
208        warn!(
209            key_prefix = &api_key_str[..10.min(api_key_str.len())],
210            "Invalid API key format"
211        );
212        unauthorized("Invalid API key format")
213    })?;
214
215    let api_key = key_store.get_key(api_key_str).await.ok_or_else(|| {
216        warn!(
217            key_prefix = &api_key_str[..10.min(api_key_str.len())],
218            "Invalid API key"
219        );
220        unauthorized("Invalid API key")
221    })?;
222
223    if api_key.is_expired() {
224        warn!(key_id = %api_key.id, "API key expired");
225        return Err(unauthorized("API key has expired"));
226    }
227
228    Ok(AuthContext {
229        api_key,
230        source_ip: source_ip(headers),
231    })
232}
233
234fn validate_jwt(token: &str, config: &JwtConfig) -> Result<JwtClaims, AuthError> {
235    let validation = config.validation();
236
237    decode::<JwtClaims>(
238        token,
239        &DecodingKey::from_secret(config.secret_bytes()),
240        &validation,
241    )
242    .map(|data| data.claims)
243    .map_err(|error| match error.kind() {
244        ErrorKind::ExpiredSignature => AuthError::ExpiredToken,
245        _ => AuthError::InvalidToken(error.to_string()),
246    })
247}
248
249async fn authenticate_jwt(
250    auth_state: &AuthState,
251    token: &str,
252    headers: &HeaderMap,
253) -> Result<AuthContext, (StatusCode, Json<ApiError>)> {
254    // Try static JWT config if available
255    if let Some(config) = &auth_state.jwt {
256        match validate_jwt(token, config) {
257            Ok(claims) => {
258                return Ok(AuthContext {
259                    api_key: api_key_from_jwt_claims(&claims),
260                    source_ip: source_ip(headers),
261                });
262            }
263            Err(e) => {
264                // If we don't have an OIDC provider, fail here.
265                // Otherwise, fall through to OIDC.
266                if auth_state.oidc.is_none() {
267                    match &e {
268                        AuthError::ExpiredToken => warn!("Expired JWT token"),
269                        AuthError::InvalidToken(_) => warn!("Invalid JWT token"),
270                        _ => {}
271                    }
272                    return Err(unauthorized(&e.to_string()));
273                }
274            }
275        }
276    }
277
278    // Try OIDC provider if available
279    if let Some(oidc) = &auth_state.oidc {
280        match oidc.validate_token(token).await {
281            Ok(api_key) => {
282                return Ok(AuthContext {
283                    api_key,
284                    source_ip: source_ip(headers),
285                });
286            }
287            Err(e) => {
288                match &e {
289                    AuthError::ExpiredToken => warn!("Expired OIDC token"),
290                    AuthError::InvalidToken(msg) => warn!("Invalid OIDC token: {}", msg),
291                    _ => {}
292                }
293                return Err(unauthorized(&e.to_string()));
294            }
295        }
296    }
297
298    warn!("JWT token received but no JWT or OIDC authentication is configured");
299    Err(unauthorized("JWT/OIDC authentication is not configured"))
300}
301
302fn api_key_from_jwt_claims(claims: &JwtClaims) -> ApiKey {
303    ApiKey {
304        id: format!("jwt:{}", claims.sub),
305        name: format!("JWT {}", claims.sub),
306        project_id: claims.project_id.clone(),
307        scopes: claims.scopes.clone(),
308        role: Role::from_scopes(&claims.scopes),
309        benchmark_regex: None,
310        expires_at: Some(
311            chrono::DateTime::<chrono::Utc>::from_timestamp(claims.exp as i64, 0)
312                .unwrap_or_else(chrono::Utc::now),
313        ),
314        created_at: claims
315            .iat
316            .and_then(|iat| chrono::DateTime::<chrono::Utc>::from_timestamp(iat as i64, 0))
317            .unwrap_or_else(chrono::Utc::now),
318        last_used_at: None,
319    }
320}
321
322/// Authentication middleware.
323pub async fn auth_middleware(
324    State(auth_state): State<AuthState>,
325    mut request: Request,
326    next: Next,
327) -> Result<impl IntoResponse, (StatusCode, Json<ApiError>)> {
328    // Skip auth for health endpoint
329    if request.uri().path() == "/health" {
330        return Ok(next.run(request).await);
331    }
332
333    let auth_ctx = match extract_credentials(request.headers()) {
334        Some(Credentials::ApiKey(api_key)) => {
335            authenticate_api_key(&auth_state.key_store, &api_key, request.headers()).await?
336        }
337        Some(Credentials::Jwt(token)) => {
338            authenticate_jwt(&auth_state, &token, request.headers()).await?
339        }
340        None => {
341            warn!("Missing authentication header");
342            return Err(unauthorized("Missing authentication header"));
343        }
344    };
345
346    request.extensions_mut().insert(auth_ctx);
347
348    Ok(next.run(request).await)
349}
350
351/// Checks if the current auth context has the required scope, project access, and benchmark access.
352/// Returns an error response if the scope is not present, project mismatch, or benchmark restricted.
353pub fn check_scope(
354    auth_ctx: Option<&AuthContext>,
355    project_id: &str,
356    benchmark: Option<&str>,
357    scope: Scope,
358) -> Result<(), (StatusCode, Json<ApiError>)> {
359    let ctx = match auth_ctx {
360        Some(ctx) => ctx,
361        None => {
362            return Err((
363                StatusCode::UNAUTHORIZED,
364                Json(ApiError::unauthorized("Authentication required")),
365            ));
366        }
367    };
368
369    // 1. Check Scope
370    if !ctx.api_key.has_scope(scope) {
371        warn!(
372            key_id = %ctx.api_key.id,
373            required_scope = %scope,
374            actual_role = %ctx.api_key.role,
375            "Insufficient permissions: scope mismatch"
376        );
377        return Err((
378            StatusCode::FORBIDDEN,
379            Json(ApiError::forbidden(&format!(
380                "Requires '{}' permission",
381                scope
382            ))),
383        ));
384    }
385
386    // 2. Check Project Isolation
387    // Global admins (those with Scope::Admin) can access any project.
388    // Otherwise, the key's project_id must match the requested project_id.
389    if !ctx.api_key.has_scope(Scope::Admin) && ctx.api_key.project_id != project_id {
390        warn!(
391            key_id = %ctx.api_key.id,
392            key_project = %ctx.api_key.project_id,
393            requested_project = %project_id,
394            "Insufficient permissions: project isolation violation"
395        );
396        return Err((
397            StatusCode::FORBIDDEN,
398            Json(ApiError::forbidden(&format!(
399                "Key is restricted to project '{}'",
400                ctx.api_key.project_id
401            ))),
402        ));
403    }
404
405    // 3. Check Benchmark Restriction
406    // If the key has a benchmark_regex, all accessed benchmarks must match it.
407    if let (Some(regex_str), Some(bench)) = (&ctx.api_key.benchmark_regex, benchmark) {
408        let regex = regex::Regex::new(regex_str).map_err(|e| {
409            warn!(key_id = %ctx.api_key.id, regex = %regex_str, error = %e, "Invalid benchmark regex in API key");
410            (
411                StatusCode::INTERNAL_SERVER_ERROR,
412                Json(ApiError::internal_error("Invalid security configuration")),
413            )
414        })?;
415
416        if !regex.is_match(bench) {
417            warn!(
418                key_id = %ctx.api_key.id,
419                benchmark = %bench,
420                regex = %regex_str,
421                "Insufficient permissions: benchmark restriction violation"
422            );
423            return Err((
424                StatusCode::FORBIDDEN,
425                Json(ApiError::forbidden(&format!(
426                    "Key is restricted to benchmarks matching '{}'",
427                    regex_str
428                ))),
429            ));
430        }
431    }
432
433    Ok(())
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Extension, Router, routing::get};
    use jsonwebtoken::{Header, encode};
    use perfgate_auth::generate_api_key;
    use tower::ServiceExt;
    use uselesskey::{Factory, HmacFactoryExt, HmacSpec, Seed};
    use uselesskey_jsonwebtoken::JwtKeyExt;

    /// Seed value for the deterministic HMAC fixture; signer and verifier must share it.
    const FIXTURE_SEED: &str = "perfgate-server-auth-tests";
    /// Label of the HMAC fixture used for both signing and verification.
    const FIXTURE_LABEL: &str = "jwt-auth";

    /// Result type of `check_scope`, named to keep helper signatures short.
    type ScopeResult = Result<(), (StatusCode, Json<ApiError>)>;

    /// JWT configuration that accepts tokens produced by `create_test_token`.
    fn test_jwt_config() -> JwtConfig {
        let factory = Factory::deterministic(Seed::from_env_value(FIXTURE_SEED).unwrap());
        let fixture = factory.hmac(FIXTURE_LABEL, HmacSpec::hs256());
        JwtConfig::hs256(fixture.secret_bytes())
            .issuer("perfgate-tests")
            .audience("perfgate")
    }

    /// Claims for the `ci-bot` subject in `project-1`, issued now and expiring at `exp`.
    fn create_test_claims(scopes: Vec<Scope>, exp: u64) -> JwtClaims {
        let issued_at = chrono::Utc::now().timestamp() as u64;
        JwtClaims {
            sub: "ci-bot".to_string(),
            project_id: "project-1".to_string(),
            scopes,
            exp,
            iat: Some(issued_at),
            iss: Some("perfgate-tests".to_string()),
            aud: Some("perfgate".to_string()),
        }
    }

    /// Signs `claims` with the same fixture secret that `test_jwt_config` verifies against.
    fn create_test_token(claims: &JwtClaims) -> String {
        let factory = Factory::deterministic(Seed::from_env_value(FIXTURE_SEED).unwrap());
        let fixture = factory.hmac(FIXTURE_LABEL, HmacSpec::hs256());
        encode(&Header::default(), claims, &fixture.encoding_key()).unwrap()
    }

    /// Router with one `/protected` route behind the auth middleware; the
    /// handler echoes the authenticated key id.
    fn auth_test_router(auth_state: AuthState) -> Router {
        let echo_key_id =
            |Extension(auth_ctx): Extension<AuthContext>| async move { auth_ctx.api_key.id };

        Router::new()
            .route("/protected", get(echo_key_id))
            .layer(axum::middleware::from_fn_with_state(
                auth_state,
                auth_middleware,
            ))
    }

    /// Unix timestamp five minutes in the future, used as a not-yet-expired `exp`.
    fn five_minutes_from_now() -> u64 {
        (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp() as u64
    }

    /// Sends `GET /protected` with the given `Authorization` value and returns the status.
    async fn protected_status(auth_state: AuthState, authorization: String) -> StatusCode {
        let request = Request::builder()
            .uri("/protected")
            .header(header::AUTHORIZATION, authorization)
            .body(axum::body::Body::empty())
            .unwrap();

        auth_test_router(auth_state)
            .oneshot(request)
            .await
            .unwrap()
            .status()
    }

    /// Wraps `api_key` in a context with no source address.
    fn context_for(api_key: ApiKey) -> AuthContext {
        AuthContext {
            api_key,
            source_ip: None,
        }
    }

    /// Asserts that `result` is a `403 Forbidden` rejection.
    #[track_caller]
    fn assert_forbidden(result: ScopeResult) {
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().0, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_api_key_store() {
        let store = ApiKeyStore::new();
        let raw_key = generate_api_key(false);
        let record = ApiKey::new(
            "key-1".to_string(),
            "Test Key".to_string(),
            "project-1".to_string(),
            Role::Contributor,
        );

        store.add_key(record.clone(), &raw_key).await;

        // The stored record can be looked up by its raw key.
        let found = store.get_key(&raw_key).await;
        assert!(found.is_some());
        let found = found.unwrap();
        assert_eq!(found.id, "key-1");
        assert_eq!(found.role, Role::Contributor);

        assert_eq!(store.list_keys().await.len(), 1);

        // After removal the key is gone.
        assert!(store.remove_key(&raw_key).await);
        assert!(store.get_key(&raw_key).await.is_none());
    }

    #[tokio::test]
    async fn test_auth_middleware_accepts_api_key() {
        let raw_key = "pg_test_abcdefghijklmnopqrstuvwxyz123456";
        let record = ApiKey::new(
            "api-key-1".to_string(),
            "API Key".to_string(),
            "project-1".to_string(),
            Role::Viewer,
        );
        let store = Arc::new(ApiKeyStore::new());
        store.add_key(record, raw_key).await;

        let status = protected_status(
            AuthState::new(store, None, None),
            format!("Bearer {}", raw_key),
        )
        .await;

        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_accepts_jwt_token() {
        let claims = create_test_claims(vec![Scope::Read, Scope::Promote], five_minutes_from_now());
        let token = create_test_token(&claims);
        let auth_state = AuthState::new(
            Arc::new(ApiKeyStore::new()),
            Some(test_jwt_config()),
            None,
        );

        let status = protected_status(auth_state, format!("Token {}", token)).await;

        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_jwt_when_unconfigured() {
        let claims = create_test_claims(vec![Scope::Read], five_minutes_from_now());
        let token = create_test_token(&claims);
        let auth_state = AuthState::new(Arc::new(ApiKeyStore::new()), None, None);

        let status = protected_status(auth_state, format!("Token {}", token)).await;

        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_hash_api_key() {
        let key = "pg_live_test123456789012345678901234567890";
        let first = hash_api_key(key);
        let second = hash_api_key(key);

        // Hashing is deterministic.
        assert_eq!(first, second);

        // A different key yields a different hash.
        let other = hash_api_key("pg_live_different1234567890123456789012");
        assert_ne!(first, other);
    }

    #[test]
    fn test_check_scope_project_isolation() {
        let ctx = context_for(ApiKey::new(
            "k1".to_string(),
            "n1".to_string(),
            "project-a".to_string(),
            Role::Contributor,
        ));

        // Same project, correct scope -> OK
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Write).is_ok());
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Read).is_ok());

        // Same project, wrong scope -> Forbidden
        assert_forbidden(check_scope(Some(&ctx), "project-a", None, Scope::Delete));

        // Different project -> Forbidden
        assert_forbidden(check_scope(Some(&ctx), "project-b", None, Scope::Read));
    }

    #[test]
    fn test_check_scope_global_admin() {
        let ctx = context_for(ApiKey::new(
            "k1".to_string(),
            "admin".to_string(),
            "any-project".to_string(),
            Role::Admin,
        ));

        // Global admin can access ANY project
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Read).is_ok());
        assert!(check_scope(Some(&ctx), "project-b", None, Scope::Delete).is_ok());
        assert!(check_scope(Some(&ctx), "other", None, Scope::Admin).is_ok());
    }

    #[test]
    fn test_check_scope_benchmark_restriction() {
        let mut restricted = ApiKey::new(
            "k1".to_string(),
            "n1".to_string(),
            "project-a".to_string(),
            Role::Contributor,
        );
        restricted.benchmark_regex = Some("^web-.*$".to_string());
        let ctx = context_for(restricted);

        // Matches regex -> OK
        assert!(check_scope(Some(&ctx), "project-a", Some("web-auth"), Scope::Read).is_ok());
        assert!(check_scope(Some(&ctx), "project-a", Some("web-api"), Scope::Write).is_ok());

        // Does not match regex -> Forbidden
        assert_forbidden(check_scope(
            Some(&ctx),
            "project-a",
            Some("worker-job"),
            Scope::Read,
        ));

        // No benchmark name provided (e.g. list operation) -> OK (scoping only applies to explicit access)
        assert!(check_scope(Some(&ctx), "project-a", None, Scope::Read).is_ok());
    }
}