Skip to main content

spikard_http/
auth.rs

//! Authentication middleware for JWT and API keys.
//!
//! This module provides tower middleware for authenticating requests using:
//! - JWT tokens (via the `Authorization: Bearer <token>` header)
//! - API keys (via a configurable header, falling back to a query parameter)
6
7use axum::{
8    body::Body,
9    extract::Request,
10    http::{HeaderMap, StatusCode, Uri},
11    middleware::Next,
12    response::{IntoResponse, Response},
13};
14use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
15use serde::{Deserialize, Serialize};
16use std::collections::HashSet;
17
18use crate::{ApiKeyConfig, JwtConfig, ProblemDetails};
19
/// Problem `type` URI sent with authentication failures (HTTP 401 responses).
const TYPE_AUTH_ERROR: &str = "https://spikard.dev/errors/unauthorized";

/// Problem `type` URI sent when the middleware itself is misconfigured (HTTP 500 responses).
const TYPE_CONFIG_ERROR: &str = "https://spikard.dev/errors/configuration-error";
25
26/// JWT claims structure - can be extended based on needs
27#[derive(Debug, Serialize, Deserialize)]
28pub struct Claims {
29    pub sub: String,
30    pub exp: usize,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub iat: Option<usize>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub nbf: Option<usize>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub aud: Option<Vec<String>>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub iss: Option<String>,
39}
40
41/// JWT authentication middleware
42///
43/// Validates JWT tokens from the Authorization header (Bearer scheme).
44/// On success, the validated claims are available to downstream handlers.
45/// On failure, returns 401 Unauthorized with RFC 9457 Problem Details.
46///
47/// Coverage: Tested via integration tests (`auth_integration.rs`)
48///
49/// # Errors
50/// Returns an error response when the Authorization header is missing, malformed,
51/// the token is invalid, or configuration is incorrect.
52#[cfg(not(tarpaulin_include))]
53pub async fn jwt_auth_middleware(
54    config: JwtConfig,
55    headers: HeaderMap,
56    request: Request<Body>,
57    next: Next,
58) -> Result<Response, Response> {
59    let auth_header = headers
60        .get("authorization")
61        .and_then(|v| v.to_str().ok())
62        .ok_or_else(|| {
63            let problem = ProblemDetails::new(
64                TYPE_AUTH_ERROR,
65                "Missing or invalid Authorization header",
66                StatusCode::UNAUTHORIZED,
67            )
68            .with_detail("Expected 'Authorization: Bearer <token>'");
69            (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
70        })?;
71
72    let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
73        let problem = ProblemDetails::new(
74            TYPE_AUTH_ERROR,
75            "Invalid Authorization header format",
76            StatusCode::UNAUTHORIZED,
77        )
78        .with_detail("Authorization header must use Bearer scheme: 'Bearer <token>'");
79        (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
80    })?;
81
82    let parts: Vec<&str> = token.split('.').collect();
83    if parts.len() != 3 {
84        let problem = ProblemDetails::new(TYPE_AUTH_ERROR, "Malformed JWT token", StatusCode::UNAUTHORIZED)
85            .with_detail(format!(
86                "Malformed JWT token: expected 3 parts separated by dots, found {}",
87                parts.len()
88            ));
89        return Err((StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response());
90    }
91
92    let algorithm = parse_algorithm(&config.algorithm).map_err(|_| {
93        let problem = ProblemDetails::new(
94            TYPE_CONFIG_ERROR,
95            "Invalid JWT configuration",
96            StatusCode::INTERNAL_SERVER_ERROR,
97        )
98        .with_detail(format!("Unsupported algorithm: {}", config.algorithm));
99        (StatusCode::INTERNAL_SERVER_ERROR, axum::Json(problem)).into_response()
100    })?;
101
102    let mut validation = Validation::new(algorithm);
103    if let Some(ref aud) = config.audience {
104        validation.set_audience(aud);
105    }
106    if let Some(ref iss) = config.issuer {
107        validation.set_issuer(std::slice::from_ref(iss));
108    }
109    validation.leeway = config.leeway;
110    validation.validate_nbf = true;
111
112    let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
113    let _token_data = decode::<Claims>(token, &decoding_key, &validation).map_err(|e| {
114        let detail = match e.kind() {
115            jsonwebtoken::errors::ErrorKind::ExpiredSignature => "Token has expired".to_string(),
116            jsonwebtoken::errors::ErrorKind::InvalidToken => "Token is invalid".to_string(),
117            jsonwebtoken::errors::ErrorKind::InvalidSignature | jsonwebtoken::errors::ErrorKind::Base64(_) => {
118                "Token signature is invalid".to_string()
119            }
120            jsonwebtoken::errors::ErrorKind::InvalidAudience => "Token audience is invalid".to_string(),
121            jsonwebtoken::errors::ErrorKind::InvalidIssuer => config.issuer.as_ref().map_or_else(
122                || "Token issuer is invalid".to_string(),
123                |expected_iss| format!("Token issuer is invalid, expected '{expected_iss}'"),
124            ),
125            jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
126                "JWT not valid yet, not before claim is in the future".to_string()
127            }
128            _ => format!("Token validation failed: {e}"),
129        };
130
131        let problem =
132            ProblemDetails::new(TYPE_AUTH_ERROR, "JWT validation failed", StatusCode::UNAUTHORIZED).with_detail(detail);
133        (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
134    })?;
135
136    // TODO: Attach claims to request extensions for handlers to access
137    Ok(next.run(request).await)
138}
139
140/// Parse JWT algorithm string to jsonwebtoken Algorithm enum
141fn parse_algorithm(alg: &str) -> Result<Algorithm, String> {
142    match alg {
143        "HS256" => Ok(Algorithm::HS256),
144        "HS384" => Ok(Algorithm::HS384),
145        "HS512" => Ok(Algorithm::HS512),
146        "RS256" => Ok(Algorithm::RS256),
147        "RS384" => Ok(Algorithm::RS384),
148        "RS512" => Ok(Algorithm::RS512),
149        "ES256" => Ok(Algorithm::ES256),
150        "ES384" => Ok(Algorithm::ES384),
151        "PS256" => Ok(Algorithm::PS256),
152        "PS384" => Ok(Algorithm::PS384),
153        "PS512" => Ok(Algorithm::PS512),
154        _ => Err(format!("Unsupported algorithm: {alg}")),
155    }
156}
157
158/// API Key authentication middleware
159///
160/// Validates API keys from a custom header (default: X-API-Key) or query parameter.
161/// Checks header first, then query parameter as fallback.
162/// On success, the request proceeds to the next handler.
163/// On failure, returns 401 Unauthorized with RFC 9457 Problem Details.
164///
165/// Coverage: Tested via integration tests (`auth_integration.rs`)
166///
167/// # Errors
168/// Returns an error response when the API key is missing or invalid.
169#[cfg(not(tarpaulin_include))]
170pub async fn api_key_auth_middleware(
171    config: ApiKeyConfig,
172    headers: HeaderMap,
173    request: Request<Body>,
174    next: Next,
175) -> Result<Response, Response> {
176    let valid_keys: HashSet<String> = config.keys.into_iter().collect();
177
178    let uri = request.uri().clone();
179
180    let api_key_from_header = headers.get(&config.header_name).and_then(|v| v.to_str().ok());
181
182    let api_key = api_key_from_header.map_or_else(|| extract_api_key_from_query(&uri), Some);
183
184    let api_key = api_key.ok_or_else(|| {
185        let problem =
186            ProblemDetails::new(TYPE_AUTH_ERROR, "Missing API key", StatusCode::UNAUTHORIZED).with_detail(format!(
187                "Expected '{}' header or 'api_key' query parameter with valid API key",
188                config.header_name
189            ));
190        (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
191    })?;
192
193    if !valid_keys.contains(api_key) {
194        let problem = ProblemDetails::new(TYPE_AUTH_ERROR, "Invalid API key", StatusCode::UNAUTHORIZED)
195            .with_detail("The provided API key is not valid");
196        return Err((StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response());
197    }
198
199    Ok(next.run(request).await)
200}
201
202/// Extract API key from query parameters
203///
204/// Checks for common API key parameter names: api_key, apiKey, key
205fn extract_api_key_from_query(uri: &Uri) -> Option<&str> {
206    let query = uri.query()?;
207
208    for param in query.split('&') {
209        if let Some((key, value)) = param.split_once('=')
210            && (key == "api_key" || key == "apiKey" || key == "key")
211        {
212            return Some(value);
213        }
214    }
215
216    None
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `target` as a URI and returns the API key found in its query string, if any.
    fn query_key(target: &str) -> Option<String> {
        let uri: axum::http::Uri = target.parse().unwrap();
        extract_api_key_from_query(&uri).map(str::to_owned)
    }

    #[test]
    fn test_parse_algorithm() {
        // Every supported name must map to its namesake variant.
        let supported = [
            ("HS256", Algorithm::HS256),
            ("HS384", Algorithm::HS384),
            ("HS512", Algorithm::HS512),
            ("RS256", Algorithm::RS256),
            ("RS384", Algorithm::RS384),
            ("RS512", Algorithm::RS512),
            ("ES256", Algorithm::ES256),
            ("ES384", Algorithm::ES384),
            ("PS256", Algorithm::PS256),
            ("PS384", Algorithm::PS384),
            ("PS512", Algorithm::PS512),
        ];
        for (name, expected) in supported {
            assert_eq!(parse_algorithm(name), Ok(expected));
        }

        assert!(parse_algorithm("INVALID").is_err());
    }

    #[test]
    fn test_claims_serialization() {
        let claims = Claims {
            sub: "user123".to_string(),
            exp: 1234567890,
            iat: Some(1234567800),
            nbf: None,
            aud: Some(vec!["https://api.example.com".to_string()]),
            iss: Some("https://auth.example.com".to_string()),
        };

        let serialized = serde_json::to_string(&claims).unwrap();
        for needle in ["user123", "1234567890"] {
            assert!(serialized.contains(needle));
        }
    }

    #[test]
    fn test_extract_api_key_from_query_api_key() {
        assert_eq!(
            query_key("/api/endpoint?api_key=secret123").as_deref(),
            Some("secret123")
        );
    }

    #[test]
    fn test_extract_api_key_from_query_api_key_camel_case() {
        assert_eq!(query_key("/api/endpoint?apiKey=mykey456").as_deref(), Some("mykey456"));
    }

    #[test]
    fn test_extract_api_key_from_query_key() {
        assert_eq!(query_key("/api/endpoint?key=testkey789").as_deref(), Some("testkey789"));
    }

    #[test]
    fn test_extract_api_key_from_query_no_key() {
        assert_eq!(query_key("/api/endpoint?foo=bar&baz=qux"), None);
    }

    #[test]
    fn test_extract_api_key_from_query_empty_string() {
        // No query string at all.
        assert_eq!(query_key("/api/endpoint"), None);
    }

    #[test]
    fn test_extract_api_key_from_query_multiple_params() {
        assert_eq!(
            query_key("/api/endpoint?foo=bar&api_key=found&baz=qux").as_deref(),
            Some("found")
        );
    }
}