mockforge_http/auth/authenticator.rs

1//! Authentication methods and logic
2//!
3//! This module contains the core authentication logic for different
4//! authentication schemes: JWT, Basic Auth, OAuth2, and API keys.
5
6use base64::Engine;
7use chrono;
8use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
9use serde_json::Value;
10use tracing::debug;
11
12use super::state::AuthState;
13use super::types::{AuthClaims, AuthResult};
14
15/// Authenticate a request using various methods
16pub async fn authenticate_request(
17    state: &AuthState,
18    auth_header: &Option<String>,
19    api_key_header: &Option<String>,
20    api_key_query: &Option<String>,
21) -> AuthResult {
22    // Try JWT/Bearer token first
23    if let Some(header) = auth_header {
24        if header.starts_with("Bearer ") {
25            if let Some(result) = authenticate_jwt(state, header).await {
26                return result;
27            }
28        } else if header.starts_with("Basic ") {
29            if let Some(result) = authenticate_basic(state, header) {
30                return result;
31            }
32        }
33    }
34
35    // Try OAuth2 token introspection
36    if let Some(header) = auth_header {
37        if header.starts_with("Bearer ") {
38            if let Some(result) = authenticate_oauth2(state, header).await {
39                return result;
40            }
41        }
42    }
43
44    // Try API key authentication
45    if let Some(api_key) = api_key_header.as_ref().or(api_key_query.as_ref()) {
46        if let Some(result) = authenticate_api_key(state, api_key) {
47            return result;
48        }
49    }
50
51    // No authentication provided or all methods failed
52    AuthResult::None
53}
54
55/// Authenticate using JWT
56pub async fn authenticate_jwt(state: &AuthState, auth_header: &str) -> Option<AuthResult> {
57    let jwt_config = state.config.jwt.as_ref()?;
58
59    // Extract token from header
60    let token = auth_header.strip_prefix("Bearer ")?;
61
62    // Try to decode header to determine algorithm
63    let header = match decode_header(token) {
64        Ok(h) => h,
65        Err(e) => {
66            debug!("Failed to decode JWT header: {}", e);
67            return Some(AuthResult::Failure("Invalid JWT format".to_string()));
68        }
69    };
70
71    // Check if algorithm is supported
72    let alg_str = match header.alg {
73        Algorithm::HS256 => "HS256",
74        Algorithm::HS384 => "HS384",
75        Algorithm::HS512 => "HS512",
76        Algorithm::RS256 => "RS256",
77        Algorithm::RS384 => "RS384",
78        Algorithm::RS512 => "RS512",
79        Algorithm::ES256 => "ES256",
80        Algorithm::ES384 => "ES384",
81        Algorithm::PS256 => "PS256",
82        Algorithm::PS384 => "PS384",
83        Algorithm::PS512 => "PS512",
84        _ => {
85            debug!("Unsupported JWT algorithm: {:?}", header.alg);
86            return Some(AuthResult::Failure("Unsupported JWT algorithm".to_string()));
87        }
88    };
89
90    if !jwt_config.algorithms.is_empty() && !jwt_config.algorithms.contains(&alg_str.to_string()) {
91        return Some(AuthResult::Failure(format!("Unsupported algorithm: {}", alg_str)));
92    }
93
94    // Create validation
95    let mut validation = Validation::new(header.alg);
96    if let Some(iss) = &jwt_config.issuer {
97        validation.set_issuer(&[iss]);
98    }
99    if let Some(aud) = &jwt_config.audience {
100        validation.set_audience(&[aud]);
101    }
102
103    // Create decoding key based on algorithm
104    let decoding_key = match header.alg {
105        Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
106            let secret = jwt_config
107                .secret
108                .as_ref()
109                .ok_or_else(|| AuthResult::Failure("JWT secret not configured".to_string()))
110                .ok()?;
111            DecodingKey::from_secret(secret.as_bytes())
112        }
113        Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
114            let key = jwt_config
115                .rsa_public_key
116                .as_ref()
117                .ok_or_else(|| AuthResult::Failure("RSA public key not configured".to_string()))
118                .ok()?;
119            DecodingKey::from_rsa_pem(key.as_bytes())
120                .map_err(|e| {
121                    debug!("Failed to parse RSA key: {}", e);
122                    AuthResult::Failure("Invalid RSA key configuration".to_string())
123                })
124                .ok()?
125        }
126        Algorithm::ES256 | Algorithm::ES384 => {
127            let key = jwt_config
128                .ecdsa_public_key
129                .as_ref()
130                .ok_or_else(|| AuthResult::Failure("ECDSA public key not configured".to_string()))
131                .ok()?;
132            DecodingKey::from_ec_pem(key.as_bytes())
133                .map_err(|e| {
134                    debug!("Failed to parse ECDSA key: {}", e);
135                    AuthResult::Failure("Invalid ECDSA key configuration".to_string())
136                })
137                .ok()?
138        }
139        _ => {
140            return Some(AuthResult::Failure("Unsupported algorithm".to_string()));
141        }
142    };
143
144    // Decode and validate token
145    match decode::<Value>(token, &decoding_key, &validation) {
146        Ok(token_data) => {
147            let claims = token_data.claims;
148            let mut auth_claims = AuthClaims::new();
149
150            // Extract standard claims
151            if let Some(sub) = claims.get("sub").and_then(|v| v.as_str()) {
152                auth_claims.sub = Some(sub.to_string());
153            }
154            if let Some(iss) = claims.get("iss").and_then(|v| v.as_str()) {
155                auth_claims.iss = Some(iss.to_string());
156            }
157            if let Some(aud) = claims.get("aud").and_then(|v| v.as_str()) {
158                auth_claims.aud = Some(aud.to_string());
159            }
160            if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) {
161                auth_claims.exp = Some(exp);
162            }
163            if let Some(iat) = claims.get("iat").and_then(|v| v.as_i64()) {
164                auth_claims.iat = Some(iat);
165            }
166            if let Some(username) = claims
167                .get("username")
168                .or_else(|| claims.get("preferred_username"))
169                .and_then(|v| v.as_str())
170            {
171                auth_claims.username = Some(username.to_string());
172            }
173
174            // Extract roles
175            if let Some(roles) = claims.get("roles").and_then(|v| v.as_array()) {
176                for role in roles {
177                    if let Some(role_str) = role.as_str() {
178                        auth_claims.roles.push(role_str.to_string());
179                    }
180                }
181            }
182
183            // Store custom claims
184            for (key, value) in claims.as_object()? {
185                if ![
186                    "sub",
187                    "iss",
188                    "aud",
189                    "exp",
190                    "iat",
191                    "username",
192                    "preferred_username",
193                    "roles",
194                ]
195                .contains(&key.as_str())
196                {
197                    auth_claims.custom.insert(key.clone(), value.clone());
198                }
199            }
200
201            Some(AuthResult::Success(auth_claims))
202        }
203        Err(e) => {
204            debug!("JWT validation failed: {}", e);
205            Some(AuthResult::Failure(format!("Invalid JWT token: {}", e)))
206        }
207    }
208}
209
210/// Authenticate using Basic Auth
211pub fn authenticate_basic(state: &AuthState, auth_header: &str) -> Option<AuthResult> {
212    let basic_config = state.config.basic_auth.as_ref()?;
213
214    // Extract credentials from header
215    let encoded = auth_header.strip_prefix("Basic ")?;
216    let decoded = match base64::engine::general_purpose::STANDARD.decode(encoded) {
217        Ok(d) => d,
218        Err(_) => return Some(AuthResult::Failure("Invalid base64 in Basic auth".to_string())),
219    };
220    let credentials = match String::from_utf8(decoded) {
221        Ok(c) => c,
222        Err(_) => {
223            return Some(AuthResult::Failure("Invalid UTF-8 in Basic auth credentials".to_string()))
224        }
225    };
226    let parts: Vec<&str> = credentials.splitn(2, ':').collect();
227    if parts.len() != 2 {
228        return Some(AuthResult::Failure("Invalid Basic auth format".to_string()));
229    }
230
231    let username = parts[0];
232    let password = parts[1];
233
234    // Check credentials
235    if let Some(expected_password) = basic_config.credentials.get(username) {
236        if expected_password == password {
237            let mut claims = AuthClaims::new();
238            claims.username = Some(username.to_string());
239            return Some(AuthResult::Success(claims));
240        }
241    }
242
243    Some(AuthResult::Failure("Invalid credentials".to_string()))
244}
245
246/// Authenticate using OAuth2 token introspection
247async fn authenticate_oauth2(state: &AuthState, auth_header: &str) -> Option<AuthResult> {
248    let oauth2_config = state.config.oauth2.as_ref()?;
249
250    // Extract token
251    let token = auth_header.strip_prefix("Bearer ")?;
252
253    // Check cache first
254    {
255        let cache = state.introspection_cache.read().await;
256        if let Some(cached) = cache.get(token) {
257            let now = chrono::Utc::now().timestamp();
258            if cached.expires_at > now {
259                return Some(cached.result.clone());
260            }
261        }
262    }
263
264    // Perform token introspection
265    let client = reqwest::Client::new();
266    let response = match client
267        .post(&oauth2_config.introspection_url)
268        .basic_auth(&oauth2_config.client_id, Some(&oauth2_config.client_secret))
269        .form(&[
270            ("token", token),
271            (
272                "token_type_hint",
273                oauth2_config.token_type_hint.as_deref().unwrap_or("access_token"),
274            ),
275        ])
276        .send()
277        .await
278    {
279        Ok(resp) => resp,
280        Err(e) => {
281            debug!("Network error during OAuth2 introspection: {}", e);
282            return Some(AuthResult::NetworkError(format!(
283                "Failed to connect to introspection endpoint: {}",
284                e
285            )));
286        }
287    };
288
289    if !response.status().is_success() {
290        let status = response.status();
291        debug!("OAuth2 introspection server error: {}", status);
292        return Some(AuthResult::ServerError(format!(
293            "Introspection endpoint returned {}: {}",
294            status,
295            status.canonical_reason().unwrap_or("Unknown error")
296        )));
297    }
298
299    let introspection_result: Value = match response.json().await {
300        Ok(json) => json,
301        Err(e) => {
302            debug!("Failed to parse introspection response: {}", e);
303            return Some(AuthResult::ServerError(format!(
304                "Invalid JSON response from introspection endpoint: {}",
305                e
306            )));
307        }
308    };
309
310    // Check if token is active
311    let active = introspection_result.get("active").and_then(|v| v.as_bool()).unwrap_or(false);
312    if !active {
313        let cached_result = AuthResult::TokenInvalid("Token is not active".to_string());
314        // Cache inactive tokens for a shorter time to avoid repeated checks
315        let expires_at = chrono::Utc::now().timestamp() + 300; // 5 minutes
316        let cached = super::state::CachedIntrospection {
317            result: cached_result.clone(),
318            expires_at,
319        };
320        let mut cache = state.introspection_cache.write().await;
321        cache.insert(token.to_string(), cached);
322        return Some(cached_result);
323    }
324
325    // Check if token is expired
326    if let Some(exp) = introspection_result.get("exp").and_then(|v| v.as_i64()) {
327        let now = chrono::Utc::now().timestamp();
328        if exp <= now {
329            let cached_result = AuthResult::TokenExpired;
330            // Cache expired tokens for a short time
331            let expires_at = chrono::Utc::now().timestamp() + 60; // 1 minute
332            let cached = super::state::CachedIntrospection {
333                result: cached_result.clone(),
334                expires_at,
335            };
336            let mut cache = state.introspection_cache.write().await;
337            cache.insert(token.to_string(), cached);
338            return Some(cached_result);
339        }
340    }
341
342    // Extract claims from introspection response
343    let mut claims = AuthClaims::new();
344    if let Some(sub) = introspection_result.get("sub").and_then(|v| v.as_str()) {
345        claims.sub = Some(sub.to_string());
346    }
347    if let Some(username) = introspection_result.get("username").and_then(|v| v.as_str()) {
348        claims.username = Some(username.to_string());
349    }
350    if let Some(exp) = introspection_result.get("exp").and_then(|v| v.as_i64()) {
351        claims.exp = Some(exp);
352    }
353
354    // Cache successful result - use token expiration or default to 1 hour
355    let expires_at = claims.exp.unwrap_or(chrono::Utc::now().timestamp() + 3600);
356    let cached_result = AuthResult::Success(claims);
357    let cached = super::state::CachedIntrospection {
358        result: cached_result.clone(),
359        expires_at,
360    };
361    let mut cache = state.introspection_cache.write().await;
362    cache.insert(token.to_string(), cached);
363
364    Some(cached_result)
365}
366
367/// Authenticate using API key
368pub fn authenticate_api_key(state: &AuthState, api_key: &str) -> Option<AuthResult> {
369    let api_key_config = state.config.api_key.as_ref()?;
370
371    if api_key_config.keys.contains(&api_key.to_string()) {
372        let mut claims = AuthClaims::new();
373        claims.custom.insert("api_key".to_string(), Value::String(api_key.to_string()));
374        Some(AuthResult::Success(claims))
375    } else {
376        Some(AuthResult::Failure("Invalid API key".to_string()))
377    }
378}