1use axum::{
8 body::Body,
9 extract::Request,
10 http::{HeaderMap, StatusCode, Uri},
11 middleware::Next,
12 response::{IntoResponse, Response},
13};
14use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
15use serde::{Deserialize, Serialize};
16use std::collections::HashSet;
17
18use crate::{ApiKeyConfig, JwtConfig, ProblemDetails};
19
/// `ProblemDetails` type URI attached to every authentication failure
/// (always paired with `401 Unauthorized` in this module).
const TYPE_AUTH_ERROR: &str = "https://spikard.dev/errors/unauthorized";

/// `ProblemDetails` type URI used when the middleware's own configuration is
/// unusable (paired with `500 Internal Server Error` in this module).
const TYPE_CONFIG_ERROR: &str = "https://spikard.dev/errors/configuration-error";
25
26#[derive(Debug, Serialize, Deserialize)]
28pub struct Claims {
29 pub sub: String,
30 pub exp: usize,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub iat: Option<usize>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub nbf: Option<usize>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub aud: Option<Vec<String>>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub iss: Option<String>,
39}
40
41#[cfg(not(tarpaulin_include))]
53pub async fn jwt_auth_middleware(
54 config: JwtConfig,
55 headers: HeaderMap,
56 request: Request<Body>,
57 next: Next,
58) -> Result<Response, Response> {
59 let auth_header = headers
60 .get("authorization")
61 .and_then(|v| v.to_str().ok())
62 .ok_or_else(|| {
63 let problem = ProblemDetails::new(
64 TYPE_AUTH_ERROR,
65 "Missing or invalid Authorization header",
66 StatusCode::UNAUTHORIZED,
67 )
68 .with_detail("Expected 'Authorization: Bearer <token>'");
69 (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
70 })?;
71
72 let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
73 let problem = ProblemDetails::new(
74 TYPE_AUTH_ERROR,
75 "Invalid Authorization header format",
76 StatusCode::UNAUTHORIZED,
77 )
78 .with_detail("Authorization header must use Bearer scheme: 'Bearer <token>'");
79 (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
80 })?;
81
82 let parts: Vec<&str> = token.split('.').collect();
83 if parts.len() != 3 {
84 let problem = ProblemDetails::new(TYPE_AUTH_ERROR, "Malformed JWT token", StatusCode::UNAUTHORIZED)
85 .with_detail(format!(
86 "Malformed JWT token: expected 3 parts separated by dots, found {}",
87 parts.len()
88 ));
89 return Err((StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response());
90 }
91
92 let algorithm = parse_algorithm(&config.algorithm).map_err(|_| {
93 let problem = ProblemDetails::new(
94 TYPE_CONFIG_ERROR,
95 "Invalid JWT configuration",
96 StatusCode::INTERNAL_SERVER_ERROR,
97 )
98 .with_detail(format!("Unsupported algorithm: {}", config.algorithm));
99 (StatusCode::INTERNAL_SERVER_ERROR, axum::Json(problem)).into_response()
100 })?;
101
102 let mut validation = Validation::new(algorithm);
103 if let Some(ref aud) = config.audience {
104 validation.set_audience(aud);
105 }
106 if let Some(ref iss) = config.issuer {
107 validation.set_issuer(std::slice::from_ref(iss));
108 }
109 validation.leeway = config.leeway;
110 validation.validate_nbf = true;
111
112 let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
113 let _token_data = decode::<Claims>(token, &decoding_key, &validation).map_err(|e| {
114 let detail = match e.kind() {
115 jsonwebtoken::errors::ErrorKind::ExpiredSignature => "Token has expired".to_string(),
116 jsonwebtoken::errors::ErrorKind::InvalidToken => "Token is invalid".to_string(),
117 jsonwebtoken::errors::ErrorKind::InvalidSignature | jsonwebtoken::errors::ErrorKind::Base64(_) => {
118 "Token signature is invalid".to_string()
119 }
120 jsonwebtoken::errors::ErrorKind::InvalidAudience => "Token audience is invalid".to_string(),
121 jsonwebtoken::errors::ErrorKind::InvalidIssuer => config.issuer.as_ref().map_or_else(
122 || "Token issuer is invalid".to_string(),
123 |expected_iss| format!("Token issuer is invalid, expected '{expected_iss}'"),
124 ),
125 jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
126 "JWT not valid yet, not before claim is in the future".to_string()
127 }
128 _ => format!("Token validation failed: {e}"),
129 };
130
131 let problem =
132 ProblemDetails::new(TYPE_AUTH_ERROR, "JWT validation failed", StatusCode::UNAUTHORIZED).with_detail(detail);
133 (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
134 })?;
135
136 Ok(next.run(request).await)
138}
139
140fn parse_algorithm(alg: &str) -> Result<Algorithm, String> {
142 match alg {
143 "HS256" => Ok(Algorithm::HS256),
144 "HS384" => Ok(Algorithm::HS384),
145 "HS512" => Ok(Algorithm::HS512),
146 "RS256" => Ok(Algorithm::RS256),
147 "RS384" => Ok(Algorithm::RS384),
148 "RS512" => Ok(Algorithm::RS512),
149 "ES256" => Ok(Algorithm::ES256),
150 "ES384" => Ok(Algorithm::ES384),
151 "PS256" => Ok(Algorithm::PS256),
152 "PS384" => Ok(Algorithm::PS384),
153 "PS512" => Ok(Algorithm::PS512),
154 _ => Err(format!("Unsupported algorithm: {alg}")),
155 }
156}
157
158#[cfg(not(tarpaulin_include))]
170pub async fn api_key_auth_middleware(
171 config: ApiKeyConfig,
172 headers: HeaderMap,
173 request: Request<Body>,
174 next: Next,
175) -> Result<Response, Response> {
176 let valid_keys: HashSet<String> = config.keys.into_iter().collect();
177
178 let uri = request.uri().clone();
179
180 let api_key_from_header = headers.get(&config.header_name).and_then(|v| v.to_str().ok());
181
182 let api_key = api_key_from_header.map_or_else(|| extract_api_key_from_query(&uri), Some);
183
184 let api_key = api_key.ok_or_else(|| {
185 let problem =
186 ProblemDetails::new(TYPE_AUTH_ERROR, "Missing API key", StatusCode::UNAUTHORIZED).with_detail(format!(
187 "Expected '{}' header or 'api_key' query parameter with valid API key",
188 config.header_name
189 ));
190 (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
191 })?;
192
193 if !valid_keys.contains(api_key) {
194 let problem = ProblemDetails::new(TYPE_AUTH_ERROR, "Invalid API key", StatusCode::UNAUTHORIZED)
195 .with_detail("The provided API key is not valid");
196 return Err((StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response());
197 }
198
199 Ok(next.run(request).await)
200}
201
202fn extract_api_key_from_query(uri: &Uri) -> Option<&str> {
206 let query = uri.query()?;
207
208 for param in query.split('&') {
209 if let Some((key, value)) = param.split_once('=')
210 && (key == "api_key" || key == "apiKey" || key == "key")
211 {
212 return Some(value);
213 }
214 }
215
216 None
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_algorithm() {
        assert!(matches!(parse_algorithm("HS256"), Ok(Algorithm::HS256)));
        assert!(matches!(parse_algorithm("HS384"), Ok(Algorithm::HS384)));
        assert!(matches!(parse_algorithm("HS512"), Ok(Algorithm::HS512)));
        assert!(matches!(parse_algorithm("RS256"), Ok(Algorithm::RS256)));
        assert!(matches!(parse_algorithm("RS384"), Ok(Algorithm::RS384)));
        assert!(matches!(parse_algorithm("RS512"), Ok(Algorithm::RS512)));
        assert!(matches!(parse_algorithm("ES256"), Ok(Algorithm::ES256)));
        assert!(matches!(parse_algorithm("ES384"), Ok(Algorithm::ES384)));
        assert!(matches!(parse_algorithm("PS256"), Ok(Algorithm::PS256)));
        assert!(matches!(parse_algorithm("PS384"), Ok(Algorithm::PS384)));
        assert!(matches!(parse_algorithm("PS512"), Ok(Algorithm::PS512)));
        assert!(parse_algorithm("INVALID").is_err());
    }

    #[test]
    fn test_parse_algorithm_is_case_sensitive_and_rejects_empty() {
        assert!(parse_algorithm("hs256").is_err());
        assert!(parse_algorithm("").is_err());
        assert_eq!(
            parse_algorithm("none").unwrap_err(),
            "Unsupported algorithm: none"
        );
    }

    #[test]
    fn test_claims_serialization() {
        let claims = Claims {
            sub: "user123".to_string(),
            exp: 1234567890,
            iat: Some(1234567800),
            nbf: None,
            aud: Some(vec!["https://api.example.com".to_string()]),
            iss: Some("https://auth.example.com".to_string()),
        };

        let json = serde_json::to_string(&claims).unwrap();
        assert!(json.contains("user123"));
        assert!(json.contains("1234567890"));
        // `None` optional claims are omitted entirely.
        assert!(!json.contains("nbf"));
    }

    #[test]
    fn test_claims_deserialize_array_audience() {
        let claims: Claims =
            serde_json::from_str(r#"{"sub":"user123","exp":1234567890,"aud":["a","b"]}"#).unwrap();
        assert_eq!(claims.aud, Some(vec!["a".to_string(), "b".to_string()]));
        assert_eq!(claims.iss, None);
        assert_eq!(claims.nbf, None);
    }

    #[test]
    fn test_extract_api_key_from_query_api_key() {
        let uri: axum::http::Uri = "/api/endpoint?api_key=secret123".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, Some("secret123"));
    }

    #[test]
    fn test_extract_api_key_from_query_api_key_camel_case() {
        let uri: axum::http::Uri = "/api/endpoint?apiKey=mykey456".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, Some("mykey456"));
    }

    #[test]
    fn test_extract_api_key_from_query_key() {
        let uri: axum::http::Uri = "/api/endpoint?key=testkey789".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, Some("testkey789"));
    }

    #[test]
    fn test_extract_api_key_from_query_no_key() {
        let uri: axum::http::Uri = "/api/endpoint?foo=bar&baz=qux".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, None);
    }

    #[test]
    fn test_extract_api_key_from_query_empty_string() {
        let uri: axum::http::Uri = "/api/endpoint".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, None);
    }

    #[test]
    fn test_extract_api_key_from_query_multiple_params() {
        let uri: axum::http::Uri = "/api/endpoint?foo=bar&api_key=found&baz=qux".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, Some("found"));
    }

    #[test]
    fn test_extract_api_key_from_query_first_match_wins() {
        let uri: axum::http::Uri = "/api/endpoint?key=first&api_key=second".parse().unwrap();
        assert_eq!(extract_api_key_from_query(&uri), Some("first"));
    }

    #[test]
    fn test_extract_api_key_from_query_skips_params_without_value_separator() {
        let uri: axum::http::Uri = "/api/endpoint?api_key&key=fallback".parse().unwrap();
        assert_eq!(extract_api_key_from_query(&uri), Some("fallback"));
    }

    #[test]
    fn test_extract_api_key_from_query_empty_value() {
        let uri: axum::http::Uri = "/api/endpoint?api_key=".parse().unwrap();
        assert_eq!(extract_api_key_from_query(&uri), Some(""));
    }
}