1use axum::{
8 body::Body,
9 extract::Request,
10 http::{HeaderMap, StatusCode, Uri},
11 middleware::Next,
12 response::{IntoResponse, Response},
13};
14use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
15use serde::{Deserialize, Serialize};
16use std::collections::HashSet;
17
18use crate::{ApiKeyConfig, JwtConfig, ProblemDetails};
19
/// Problem-details `type` URI attached to authentication failures (HTTP 401).
const TYPE_AUTH_ERROR: &str = "https://spikard.dev/errors/unauthorized";

/// Problem-details `type` URI attached to server-side auth misconfiguration (HTTP 500),
/// e.g. an unsupported JWT algorithm in the configuration.
const TYPE_CONFIG_ERROR: &str = "https://spikard.dev/errors/configuration-error";

/// Name of an internal header associated with decoded JWT claims.
///
/// NOTE(review): not referenced by the middleware in this section, which stores claims in
/// the request extensions instead. Confirm where this header is produced and consumed.
/// If anything trusts it, confirm that client-supplied values are stripped first.
pub const INTERNAL_JWT_CLAIMS_HEADER: &str = "x-spikard-jwt-claims";
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Claims {
32 pub sub: String,
33 pub exp: usize,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub iat: Option<usize>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub nbf: Option<usize>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub aud: Option<Vec<String>>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub iss: Option<String>,
42}
43
44#[cfg(not(tarpaulin_include))]
56pub async fn jwt_auth_middleware(
57 config: JwtConfig,
58 headers: HeaderMap,
59 request: Request<Body>,
60 next: Next,
61) -> Result<Response, Response> {
62 let auth_header = headers
63 .get("authorization")
64 .and_then(|v| v.to_str().ok())
65 .ok_or_else(|| {
66 let problem = ProblemDetails::new(
67 TYPE_AUTH_ERROR,
68 "Missing or invalid Authorization header",
69 StatusCode::UNAUTHORIZED,
70 )
71 .with_detail("Expected 'Authorization: Bearer <token>'");
72 (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
73 })?;
74
75 let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
76 let problem = ProblemDetails::new(
77 TYPE_AUTH_ERROR,
78 "Invalid Authorization header format",
79 StatusCode::UNAUTHORIZED,
80 )
81 .with_detail("Authorization header must use Bearer scheme: 'Bearer <token>'");
82 (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
83 })?;
84
85 let parts: Vec<&str> = token.split('.').collect();
86 if parts.len() != 3 {
87 let problem = ProblemDetails::new(TYPE_AUTH_ERROR, "Malformed JWT token", StatusCode::UNAUTHORIZED)
88 .with_detail(format!(
89 "Malformed JWT token: expected 3 parts separated by dots, found {}",
90 parts.len()
91 ));
92 return Err((StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response());
93 }
94
95 let algorithm = parse_algorithm(&config.algorithm).map_err(|_| {
96 let problem = ProblemDetails::new(
97 TYPE_CONFIG_ERROR,
98 "Invalid JWT configuration",
99 StatusCode::INTERNAL_SERVER_ERROR,
100 )
101 .with_detail(format!("Unsupported algorithm: {}", config.algorithm));
102 (StatusCode::INTERNAL_SERVER_ERROR, axum::Json(problem)).into_response()
103 })?;
104
105 let mut validation = Validation::new(algorithm);
106 if let Some(ref aud) = config.audience {
107 validation.set_audience(aud);
108 }
109 if let Some(ref iss) = config.issuer {
110 validation.set_issuer(std::slice::from_ref(iss));
111 }
112 validation.leeway = config.leeway;
113 validation.validate_nbf = true;
114
115 let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
116 let token_data = decode::<Claims>(token, &decoding_key, &validation).map_err(|e| {
117 let detail = match e.kind() {
118 jsonwebtoken::errors::ErrorKind::ExpiredSignature => "Token has expired".to_string(),
119 jsonwebtoken::errors::ErrorKind::InvalidToken => "Token is invalid".to_string(),
120 jsonwebtoken::errors::ErrorKind::InvalidSignature | jsonwebtoken::errors::ErrorKind::Base64(_) => {
121 "Token signature is invalid".to_string()
122 }
123 jsonwebtoken::errors::ErrorKind::InvalidAudience => "Token audience is invalid".to_string(),
124 jsonwebtoken::errors::ErrorKind::InvalidIssuer => config.issuer.as_ref().map_or_else(
125 || "Token issuer is invalid".to_string(),
126 |expected_iss| format!("Token issuer is invalid, expected '{expected_iss}'"),
127 ),
128 jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
129 "JWT not valid yet, not before claim is in the future".to_string()
130 }
131 _ => format!("Token validation failed: {e}"),
132 };
133
134 let problem =
135 ProblemDetails::new(TYPE_AUTH_ERROR, "JWT validation failed", StatusCode::UNAUTHORIZED).with_detail(detail);
136 (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
137 })?;
138
139 let mut request = request;
140 request.extensions_mut().insert(token_data.claims);
141
142 Ok(next.run(request).await)
143}
144
145fn parse_algorithm(alg: &str) -> Result<Algorithm, String> {
147 match alg {
148 "HS256" => Ok(Algorithm::HS256),
149 "HS384" => Ok(Algorithm::HS384),
150 "HS512" => Ok(Algorithm::HS512),
151 "RS256" => Ok(Algorithm::RS256),
152 "RS384" => Ok(Algorithm::RS384),
153 "RS512" => Ok(Algorithm::RS512),
154 "ES256" => Ok(Algorithm::ES256),
155 "ES384" => Ok(Algorithm::ES384),
156 "PS256" => Ok(Algorithm::PS256),
157 "PS384" => Ok(Algorithm::PS384),
158 "PS512" => Ok(Algorithm::PS512),
159 _ => Err(format!("Unsupported algorithm: {alg}")),
160 }
161}
162
163#[cfg(not(tarpaulin_include))]
175pub async fn api_key_auth_middleware(
176 config: ApiKeyConfig,
177 headers: HeaderMap,
178 request: Request<Body>,
179 next: Next,
180) -> Result<Response, Response> {
181 let valid_keys: HashSet<String> = config.keys.into_iter().collect();
182
183 let uri = request.uri().clone();
184
185 let api_key_from_header = headers.get(&config.header_name).and_then(|v| v.to_str().ok());
186
187 let api_key = api_key_from_header.map_or_else(|| extract_api_key_from_query(&uri), Some);
188
189 let api_key = api_key.ok_or_else(|| {
190 let problem =
191 ProblemDetails::new(TYPE_AUTH_ERROR, "Missing API key", StatusCode::UNAUTHORIZED).with_detail(format!(
192 "Expected '{}' header or 'api_key' query parameter with valid API key",
193 config.header_name
194 ));
195 (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
196 })?;
197
198 if !valid_keys.contains(api_key) {
199 let problem = ProblemDetails::new(TYPE_AUTH_ERROR, "Invalid API key", StatusCode::UNAUTHORIZED)
200 .with_detail("The provided API key is not valid");
201 return Err((StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response());
202 }
203
204 Ok(next.run(request).await)
205}
206
207fn extract_api_key_from_query(uri: &Uri) -> Option<&str> {
211 let query = uri.query()?;
212
213 for param in query.split('&') {
214 if let Some((key, value)) = param.split_once('=')
215 && (key == "api_key" || key == "apiKey" || key == "key")
216 {
217 return Some(value);
218 }
219 }
220
221 None
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` into a URI and returns an owned copy of the API key in its query string.
    fn query_key(raw: &str) -> Option<String> {
        let uri: axum::http::Uri = raw.parse().unwrap();
        extract_api_key_from_query(&uri).map(str::to_owned)
    }

    #[test]
    fn test_parse_algorithm() {
        let supported = [
            ("HS256", Algorithm::HS256),
            ("HS384", Algorithm::HS384),
            ("HS512", Algorithm::HS512),
            ("RS256", Algorithm::RS256),
            ("RS384", Algorithm::RS384),
            ("RS512", Algorithm::RS512),
            ("ES256", Algorithm::ES256),
            ("ES384", Algorithm::ES384),
            ("PS256", Algorithm::PS256),
            ("PS384", Algorithm::PS384),
            ("PS512", Algorithm::PS512),
        ];
        for (name, expected) in supported {
            assert_eq!(parse_algorithm(name), Ok(expected), "parsing {name}");
        }
        assert!(parse_algorithm("INVALID").is_err());
    }

    #[test]
    fn test_claims_serialization() {
        let claims = Claims {
            sub: "user123".to_string(),
            exp: 1234567890,
            iat: Some(1234567800),
            nbf: None,
            aud: Some(vec!["https://api.example.com".to_string()]),
            iss: Some("https://auth.example.com".to_string()),
        };

        let json = serde_json::to_string(&claims).unwrap();
        for needle in ["user123", "1234567890"] {
            assert!(json.contains(needle), "{needle} missing from {json}");
        }
    }

    #[test]
    fn test_extract_api_key_from_query_api_key() {
        assert_eq!(query_key("/api/endpoint?api_key=secret123").as_deref(), Some("secret123"));
    }

    #[test]
    fn test_extract_api_key_from_query_api_key_camel_case() {
        assert_eq!(query_key("/api/endpoint?apiKey=mykey456").as_deref(), Some("mykey456"));
    }

    #[test]
    fn test_extract_api_key_from_query_key() {
        assert_eq!(query_key("/api/endpoint?key=testkey789").as_deref(), Some("testkey789"));
    }

    #[test]
    fn test_extract_api_key_from_query_no_key() {
        assert_eq!(query_key("/api/endpoint?foo=bar&baz=qux"), None);
    }

    #[test]
    fn test_extract_api_key_from_query_empty_string() {
        assert_eq!(query_key("/api/endpoint"), None);
    }

    #[test]
    fn test_extract_api_key_from_query_multiple_params() {
        assert_eq!(
            query_key("/api/endpoint?foo=bar&api_key=found&baz=qux").as_deref(),
            Some("found")
        );
    }
}