1#![forbid(unsafe_code)]
2
3use axum::extract::{FromRequestParts, State};
11use axum::http::{Request, StatusCode, header::AUTHORIZATION, request::Parts};
12use axum::middleware::Next;
13use axum::response::{IntoResponse, Response};
14use axum::routing::get;
15use axum::{Json, Router};
16use jsonwebtoken::{
17 Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, decode_header, encode,
18};
19use serde::{Deserialize, Serialize};
20use serde_json::{Value, json};
21use std::sync::Arc;
22use uselesskey_core::{Factory, Seed};
23use uselesskey_rsa::{RsaFactoryExt, RsaKeyPair, RsaSpec};
24
/// Route at which [`jwks_router`] serves the JWKS document; also appended to
/// the base URL to form `jwks_uri` in the OIDC discovery document.
const DEFAULT_JWKS_PATH: &str = "/.well-known/jwks.json";
/// Route at which [`oidc_router`] serves the OIDC discovery document.
const DEFAULT_OIDC_PATH: &str = "/.well-known/openid-configuration";
27
28#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
30pub struct AuthExpectations {
31 pub issuer: String,
33 pub audience: String,
35 pub kid: String,
37}
38
39impl AuthExpectations {
40 pub fn new(
42 issuer: impl Into<String>,
43 audience: impl Into<String>,
44 kid: impl Into<String>,
45 ) -> Self {
46 Self {
47 issuer: issuer.into(),
48 audience: audience.into(),
49 kid: kid.into(),
50 }
51 }
52
53 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
55 self.issuer = issuer.into();
56 self
57 }
58
59 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
61 self.audience = audience.into();
62 self
63 }
64
65 pub fn with_kid(mut self, kid: impl Into<String>) -> Self {
67 self.kid = kid.into();
68 self
69 }
70}
71
/// Identifies which key of a rotation pair a fixture should use.
///
/// The phase is folded into the key label, so the two phases yield
/// differently-labelled keys for the same base label.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RotationPhase {
    /// The first key of the pair.
    Primary,
    /// The second key of the pair.
    Next,
}

impl RotationPhase {
    /// Label suffix that distinguishes this phase's key from the other one.
    fn suffix(self) -> &'static str {
        match self {
            RotationPhase::Next => "next",
            RotationPhase::Primary => "primary",
        }
    }
}
89
90#[derive(Clone)]
92pub struct DeterministicJwksPhase {
93 keypair: RsaKeyPair,
94 expectations: AuthExpectations,
95}
96
97impl DeterministicJwksPhase {
98 pub fn new(
100 seed: Seed,
101 label: impl AsRef<str>,
102 phase: RotationPhase,
103 issuer: impl Into<String>,
104 audience: impl Into<String>,
105 ) -> Self {
106 let fx = Factory::deterministic(seed);
107 let keypair = fx.rsa(
108 format!("{}:{}", label.as_ref(), phase.suffix()),
109 RsaSpec::rs256(),
110 );
111 let kid = keypair.kid();
112 Self {
113 keypair,
114 expectations: AuthExpectations::new(issuer, audience, kid),
115 }
116 }
117
118 pub fn jwks_json(&self) -> Value {
120 self.keypair.public_jwks_json()
121 }
122
123 pub fn expectations(&self) -> &AuthExpectations {
125 &self.expectations
126 }
127
128 pub fn issue_token(&self, mut claims: Value, ttl_seconds: u64) -> String {
130 let now = current_unix_seconds();
131 if claims.get("iss").is_none() {
132 claims["iss"] = Value::String(self.expectations.issuer.clone());
133 }
134 if claims.get("aud").is_none() {
135 claims["aud"] = Value::String(self.expectations.audience.clone());
136 }
137 if claims.get("iat").is_none() {
138 claims["iat"] = Value::Number((now as u64).into());
139 }
140 if claims.get("exp").is_none() {
141 claims["exp"] = Value::Number((now as u64 + ttl_seconds).into());
142 }
143
144 let mut header = Header::new(Algorithm::RS256);
145 header.kid = Some(self.expectations.kid.clone());
146
147 encode(
148 &header,
149 &claims,
150 &EncodingKey::from_rsa_pem(self.keypair.private_key_pkcs8_pem().as_bytes())
151 .expect("deterministic fixture key should produce valid RSA encoding key"),
152 )
153 .expect("deterministic fixture key should produce valid JWT")
154 }
155
156 fn decoding_key(&self) -> DecodingKey {
157 DecodingKey::from_rsa_pem(self.keypair.public_key_spki_pem().as_bytes())
158 .expect("deterministic fixture key should produce valid RSA decoding key")
159 }
160}
161
162#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
164pub struct TestAuthContext {
165 pub sub: String,
166 pub iss: String,
167 pub aud: String,
168 pub kid: String,
169 pub exp: u64,
170}
171
172impl<S> FromRequestParts<S> for TestAuthContext
173where
174 S: Send + Sync,
175{
176 type Rejection = (StatusCode, &'static str);
177
178 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
179 parts
180 .extensions
181 .get::<Self>()
182 .cloned()
183 .ok_or((StatusCode::UNAUTHORIZED, "missing auth context"))
184 }
185}
186
187#[derive(Clone)]
189pub struct MockJwtVerifierState {
190 signer: DeterministicJwksPhase,
191}
192
193impl MockJwtVerifierState {
194 pub fn new(signer: DeterministicJwksPhase) -> Self {
196 Self { signer }
197 }
198
199 pub fn jwks_json(&self) -> Value {
201 self.signer.jwks_json()
202 }
203
204 pub fn oidc_json(&self, base_url: impl AsRef<str>) -> Value {
206 let base = base_url.as_ref().trim_end_matches('/');
207 json!({
208 "issuer": self.signer.expectations().issuer,
209 "jwks_uri": format!("{base}{DEFAULT_JWKS_PATH}"),
210 "id_token_signing_alg_values_supported": ["RS256"],
211 "token_endpoint_auth_methods_supported": ["none"],
212 "response_types_supported": ["token"],
213 "subject_types_supported": ["public"],
214 })
215 }
216
217 pub fn issue_token(&self, claims: Value, ttl_seconds: u64) -> String {
219 self.signer.issue_token(claims, ttl_seconds)
220 }
221
222 pub fn expectations(&self) -> AuthExpectations {
224 self.signer.expectations().clone()
225 }
226}
227
228pub fn jwks_router(state: MockJwtVerifierState) -> Router {
230 Router::new()
231 .route(DEFAULT_JWKS_PATH, get(jwks_handler))
232 .with_state(state)
233}
234
235pub fn oidc_router(state: MockJwtVerifierState, base_url: impl Into<String>) -> Router {
237 let state = OidcState {
238 verifier: state,
239 base_url: base_url.into(),
240 };
241 Router::new()
242 .route(DEFAULT_OIDC_PATH, get(oidc_handler))
243 .with_state(state)
244}
245
246pub fn mock_jwt_verifier_layer(router: Router, state: MockJwtVerifierState) -> Router {
248 let state = Arc::new(state);
249 router.layer(axum::middleware::from_fn(move |request, next| {
250 let state = Arc::clone(&state);
251 async move { verify_bearer_token(state.as_ref().clone(), request, next).await }
252 }))
253}
254
255pub fn inject_auth_context_layer(router: Router, context: TestAuthContext) -> Router {
257 let context = Arc::new(context);
258 router.layer(axum::middleware::from_fn(move |request, next| {
259 let context = Arc::clone(&context);
260 async move { inject_auth_context(context.as_ref().clone(), request, next).await }
261 }))
262}
263
264#[derive(Clone)]
265struct OidcState {
266 verifier: MockJwtVerifierState,
267 base_url: String,
268}
269
270async fn jwks_handler(State(state): State<MockJwtVerifierState>) -> Json<Value> {
271 Json(state.jwks_json())
272}
273
274async fn oidc_handler(State(state): State<OidcState>) -> Json<Value> {
275 Json(state.verifier.oidc_json(&state.base_url))
276}
277
278async fn inject_auth_context(
279 context: TestAuthContext,
280 mut request: Request<axum::body::Body>,
281 next: Next,
282) -> Response {
283 request.extensions_mut().insert(context);
284 next.run(request).await
285}
286
287async fn verify_bearer_token(
288 state: MockJwtVerifierState,
289 mut request: Request<axum::body::Body>,
290 next: Next,
291) -> Response {
292 let bearer = match extract_bearer(request.headers()) {
293 Ok(token) => token,
294 Err((code, msg)) => return (code, msg).into_response(),
295 };
296
297 let header = match decode_header(bearer) {
298 Ok(header) => header,
299 Err(_) => return (StatusCode::UNAUTHORIZED, "invalid jwt header").into_response(),
300 };
301
302 let expected = state.expectations();
303 if header.kid.as_deref() != Some(expected.kid.as_str()) {
304 return (StatusCode::UNAUTHORIZED, "unexpected kid").into_response();
305 }
306
307 let mut validation = Validation::new(Algorithm::RS256);
308 validation.set_issuer(std::slice::from_ref(&expected.issuer));
309 validation.set_audience(std::slice::from_ref(&expected.audience));
310 validation.leeway = 0;
311
312 let token = match decode::<Value>(bearer, &state.signer.decoding_key(), &validation) {
313 Ok(token) => token,
314 Err(_) => return (StatusCode::UNAUTHORIZED, "token verification failed").into_response(),
315 };
316
317 let sub = token
318 .claims
319 .get("sub")
320 .and_then(Value::as_str)
321 .unwrap_or("unknown")
322 .to_owned();
323 let iss = token
324 .claims
325 .get("iss")
326 .and_then(Value::as_str)
327 .unwrap_or_default()
328 .to_owned();
329 let aud = token
330 .claims
331 .get("aud")
332 .and_then(Value::as_str)
333 .unwrap_or_default()
334 .to_owned();
335 let exp = token
336 .claims
337 .get("exp")
338 .and_then(Value::as_u64)
339 .unwrap_or_default();
340
341 request.extensions_mut().insert(TestAuthContext {
342 sub,
343 iss,
344 aud,
345 kid: expected.kid,
346 exp,
347 });
348
349 next.run(request).await
350}
351
352fn extract_bearer(headers: &axum::http::HeaderMap) -> Result<&str, (StatusCode, &'static str)> {
353 let header = headers
354 .get(AUTHORIZATION)
355 .and_then(|value| value.to_str().ok())
356 .ok_or((StatusCode::UNAUTHORIZED, "missing authorization header"))?;
357 let token = header
358 .strip_prefix("Bearer ")
359 .ok_or((StatusCode::UNAUTHORIZED, "invalid authorization scheme"))?;
360 if token.is_empty() {
361 return Err((StatusCode::UNAUTHORIZED, "empty bearer token"));
362 }
363 Ok(token)
364}
365
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// The value saturates at `usize::MAX` instead of silently truncating on
/// targets where `usize` is narrower than the `u64` second count.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn current_unix_seconds() -> usize {
    let seconds = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("current time should be >= unix epoch")
        .as_secs();
    usize::try_from(seconds).unwrap_or(usize::MAX)
}
372
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use axum::response::IntoResponse;
    use axum::routing::get;
    use tower::ServiceExt;

    /// Builds the signer fixture shared by all tests for the given phase.
    fn phase(phase: RotationPhase) -> DeterministicJwksPhase {
        let seed = Seed::from_env_value("uselesskey-axum-tests").expect("seed parse");
        DeterministicJwksPhase::new(
            seed,
            "auth-suite",
            phase,
            "https://issuer.example.test",
            "api://example-aud",
        )
    }

    /// Builds a body-less GET request, optionally carrying a bearer token.
    fn get_request(uri: &str, bearer: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some(token) = bearer {
            builder = builder.header(AUTHORIZATION, format!("Bearer {token}"));
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Sends a single request through `app` and returns the response.
    async fn send(app: Router, request: Request<Body>) -> Response {
        app.oneshot(request).await.unwrap()
    }

    /// App whose `/me` handler needs no auth context, so any 401 response
    /// must have been produced by the verifier middleware itself.
    fn guarded_app(state: MockJwtVerifierState) -> Router {
        mock_jwt_verifier_layer(
            Router::new().route("/me", get(|| async { StatusCode::OK })),
            state,
        )
    }

    #[tokio::test]
    async fn jwks_and_oidc_routes_respond() {
        let state = MockJwtVerifierState::new(phase(RotationPhase::Primary));
        let app = jwks_router(state.clone()).merge(oidc_router(state, "http://localhost:3000"));

        let jwks_res = send(app.clone(), get_request(DEFAULT_JWKS_PATH, None)).await;
        assert_eq!(jwks_res.status(), StatusCode::OK);

        let oidc_res = send(app, get_request(DEFAULT_OIDC_PATH, None)).await;
        assert_eq!(oidc_res.status(), StatusCode::OK);
    }

    // Nothing here awaits, so a plain synchronous test is sufficient.
    #[test]
    fn rotation_phase_produces_distinct_kids() {
        let primary = phase(RotationPhase::Primary);
        let next = phase(RotationPhase::Next);
        assert_ne!(primary.expectations().kid, next.expectations().kid);
    }

    // Positive control: without it, the rejection tests below would still
    // pass if the verifier rejected every token.
    #[tokio::test]
    async fn verifier_accepts_valid_token() {
        let state = MockJwtVerifierState::new(phase(RotationPhase::Primary));
        let token = state.issue_token(json!({"sub":"alice"}), 300);

        let app = mock_jwt_verifier_layer(
            Router::new().route(
                "/me",
                get(|auth: TestAuthContext| async move {
                    Json(json!({"sub": auth.sub})).into_response()
                }),
            ),
            state,
        );

        let response = send(app, get_request("/me", Some(&token))).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn verifier_rejects_missing_authorization_header() {
        let state = MockJwtVerifierState::new(phase(RotationPhase::Primary));

        let response = send(guarded_app(state), get_request("/me", None)).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn verifier_rejects_wrong_audience() {
        let state = MockJwtVerifierState::new(phase(RotationPhase::Primary));
        let token = state.issue_token(json!({"sub":"alice", "aud":"api://wrong-aud"}), 300);

        let response = send(guarded_app(state), get_request("/me", Some(&token))).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn verifier_rejects_expired_token() {
        let state = MockJwtVerifierState::new(phase(RotationPhase::Primary));
        let now = current_unix_seconds() as u64;
        let token = state.issue_token(
            json!({"sub":"alice", "exp": now.saturating_sub(5), "iat": now.saturating_sub(10)}),
            300,
        );

        let response = send(guarded_app(state), get_request("/me", Some(&token))).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn deterministic_auth_context_injection_works() {
        let app = inject_auth_context_layer(
            Router::new().route(
                "/me",
                get(|auth: TestAuthContext| async move {
                    Json(json!({"sub": auth.sub, "kid": auth.kid})).into_response()
                }),
            ),
            TestAuthContext {
                sub: "test-user".into(),
                iss: "iss".into(),
                aud: "aud".into(),
                kid: "kid-1".into(),
                exp: 42,
            },
        );

        let response = send(app, get_request("/me", None)).await;

        assert_eq!(response.status(), StatusCode::OK);
    }
}