1use crate::claims::AuthContext;
2use crate::error::VerifyError;
3use crate::keys::VerifyingKey;
4use crate::token::TokenVerifier;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::RwLock;
9
10#[derive(Clone)]
12pub struct RotationKey {
13 pub key: VerifyingKey,
15 pub key_id: String,
17 pub added_at: Instant,
19 pub expires_at: Option<Instant>,
21 pub is_primary: bool,
23}
24
25impl RotationKey {
26 pub fn primary(key: VerifyingKey, key_id: impl Into<String>) -> Self {
28 Self {
29 key,
30 key_id: key_id.into(),
31 added_at: Instant::now(),
32 expires_at: None,
33 is_primary: true,
34 }
35 }
36
37 pub fn secondary(key: VerifyingKey, key_id: impl Into<String>, grace_period: Duration) -> Self {
39 Self {
40 key,
41 key_id: key_id.into(),
42 added_at: Instant::now(),
43 expires_at: Some(Instant::now() + grace_period),
44 is_primary: false,
45 }
46 }
47
48 pub fn is_expired(&self) -> bool {
50 self.expires_at
51 .map(|exp| Instant::now() > exp)
52 .unwrap_or(false)
53 }
54}
55
56pub struct MultiKeyVerifier {
85 keys: Arc<RwLock<HashMap<String, RotationKey>>>,
86 issuer: String,
87 audience: String,
88 require_origin: bool,
89 cleanup_interval: Duration,
90 last_cleanup: Arc<RwLock<Instant>>,
91}
92
93impl MultiKeyVerifier {
94 pub fn new(
96 keys: Vec<RotationKey>,
97 issuer: impl Into<String>,
98 audience: impl Into<String>,
99 ) -> Self {
100 let key_map: HashMap<String, RotationKey> =
101 keys.into_iter().map(|k| (k.key_id.clone(), k)).collect();
102
103 Self {
104 keys: Arc::new(RwLock::new(key_map)),
105 issuer: issuer.into(),
106 audience: audience.into(),
107 require_origin: false,
108 cleanup_interval: Duration::from_secs(3600), last_cleanup: Arc::new(RwLock::new(Instant::now())),
110 }
111 }
112
113 pub fn from_single_key(
115 key: VerifyingKey,
116 key_id: impl Into<String>,
117 issuer: impl Into<String>,
118 audience: impl Into<String>,
119 ) -> Self {
120 Self::new(vec![RotationKey::primary(key, key_id)], issuer, audience)
121 }
122
123 pub fn with_origin_validation(mut self) -> Self {
125 self.require_origin = true;
126 self
127 }
128
129 pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
131 self.cleanup_interval = interval;
132 self
133 }
134
135 pub async fn add_key(&self, key: RotationKey) {
137 let mut keys = self.keys.write().await;
138
139 if key.is_primary {
141 for (_, existing) in keys.iter_mut() {
142 if existing.is_primary {
143 existing.is_primary = false;
144 existing.expires_at = Some(Instant::now() + Duration::from_secs(86400));
146 }
148 }
149 }
150
151 keys.insert(key.key_id.clone(), key);
152 }
153
154 pub async fn remove_key(&self, key_id: &str) {
156 let mut keys = self.keys.write().await;
157 keys.remove(key_id);
158 }
159
160 pub async fn key_ids(&self) -> Vec<String> {
162 let keys = self.keys.read().await;
163 keys.keys().cloned().collect()
164 }
165
166 pub async fn primary_key_id(&self) -> Option<String> {
168 let keys = self.keys.read().await;
169 keys.values()
170 .find(|k| k.is_primary)
171 .map(|k| k.key_id.clone())
172 }
173
174 async fn cleanup_expired_keys(&self) {
176 let should_cleanup = {
177 let last = self.last_cleanup.read().await;
178 last.elapsed() >= self.cleanup_interval
179 };
180
181 if !should_cleanup {
182 return;
183 }
184
185 let mut keys = self.keys.write().await;
186 let expired: Vec<String> = keys
187 .iter()
188 .filter(|(_, k)| k.is_expired())
189 .map(|(id, _)| id.clone())
190 .collect();
191
192 for key_id in expired {
193 keys.remove(&key_id);
194 }
195
196 let mut last = self.last_cleanup.write().await;
198 *last = Instant::now();
199 }
200
201 pub async fn verify(
203 &self,
204 token: &str,
205 expected_origin: Option<&str>,
206 expected_client_ip: Option<&str>,
207 ) -> Result<AuthContext, VerifyError> {
208 self.cleanup_expired_keys().await;
210
211 let keys = self.keys.read().await;
212
213 if keys.is_empty() {
214 return Err(VerifyError::KeyNotFound("no keys configured".to_string()));
215 }
216
217 let mut last_error = None;
218
219 let mut key_order: Vec<&RotationKey> = keys.values().collect();
221 key_order.sort_by_key(|k| !k.is_primary); for key_entry in key_order {
224 if key_entry.is_expired() {
225 continue;
226 }
227
228 let verifier = if self.require_origin {
229 TokenVerifier::new(
230 key_entry.key.clone(),
231 self.issuer.clone(),
232 self.audience.clone(),
233 )
234 .with_origin_validation()
235 } else {
236 TokenVerifier::new(
237 key_entry.key.clone(),
238 self.issuer.clone(),
239 self.audience.clone(),
240 )
241 };
242
243 match verifier.verify(token, expected_origin, expected_client_ip) {
244 Ok(ctx) => {
245 return Ok(ctx);
246 }
247 Err(VerifyError::InvalidSignature) => {
248 last_error = Some(VerifyError::InvalidSignature);
250 continue;
251 }
252 Err(e) => {
253 return Err(e);
255 }
256 }
257 }
258
259 Err(last_error.unwrap_or(VerifyError::InvalidSignature))
261 }
262
263 pub async fn verify_fast(
265 &self,
266 token: &str,
267 expected_origin: Option<&str>,
268 expected_client_ip: Option<&str>,
269 ) -> Result<AuthContext, VerifyError> {
270 let keys = self.keys.read().await;
271
272 if keys.is_empty() {
273 return Err(VerifyError::KeyNotFound("no keys configured".to_string()));
274 }
275
276 let mut last_error = None;
277
278 let mut key_order: Vec<&RotationKey> = keys.values().collect();
280 key_order.sort_by_key(|k| !k.is_primary);
281
282 for key_entry in key_order {
283 if key_entry.is_expired() {
284 continue;
285 }
286
287 let verifier = if self.require_origin {
288 TokenVerifier::new(
289 key_entry.key.clone(),
290 self.issuer.clone(),
291 self.audience.clone(),
292 )
293 .with_origin_validation()
294 } else {
295 TokenVerifier::new(
296 key_entry.key.clone(),
297 self.issuer.clone(),
298 self.audience.clone(),
299 )
300 };
301
302 match verifier.verify(token, expected_origin, expected_client_ip) {
303 Ok(ctx) => return Ok(ctx),
304 Err(VerifyError::InvalidSignature) => {
305 last_error = Some(VerifyError::InvalidSignature);
306 continue;
307 }
308 Err(e) => return Err(e),
309 }
310 }
311
312 Err(last_error.unwrap_or(VerifyError::InvalidSignature))
313 }
314}
315
316pub struct MultiKeyVerifierBuilder {
318 keys: Vec<RotationKey>,
319 issuer: String,
320 audience: String,
321 require_origin: bool,
322 cleanup_interval: Duration,
323}
324
325impl MultiKeyVerifierBuilder {
326 pub fn new(issuer: impl Into<String>, audience: impl Into<String>) -> Self {
328 Self {
329 keys: Vec::new(),
330 issuer: issuer.into(),
331 audience: audience.into(),
332 require_origin: false,
333 cleanup_interval: Duration::from_secs(3600),
334 }
335 }
336
337 pub fn with_primary_key(mut self, key: VerifyingKey, key_id: impl Into<String>) -> Self {
339 self.keys.push(RotationKey::primary(key, key_id));
340 self
341 }
342
343 pub fn with_secondary_key(
345 mut self,
346 key: VerifyingKey,
347 key_id: impl Into<String>,
348 grace_period: Duration,
349 ) -> Self {
350 self.keys
351 .push(RotationKey::secondary(key, key_id, grace_period));
352 self
353 }
354
355 pub fn with_origin_validation(mut self) -> Self {
357 self.require_origin = true;
358 self
359 }
360
361 pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
363 self.cleanup_interval = interval;
364 self
365 }
366
367 pub fn build(self) -> MultiKeyVerifier {
369 let mut verifier = MultiKeyVerifier::new(self.keys, self.issuer, self.audience);
370 if self.require_origin {
371 verifier = verifier.with_origin_validation();
372 }
373 verifier.with_cleanup_interval(self.cleanup_interval)
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claims::{KeyClass, SessionClaims};
    use crate::keys::SigningKey;
    use crate::token::{Jwk, Jwks, TokenSigner};
    use base64::Engine;

    /// Builds the OKP signature JWK the JWKS tests publish for `key`.
    fn okp_jwk(key: VerifyingKey, kid: String) -> Jwk {
        Jwk {
            kty: "OKP".to_string(),
            use_: Some("sig".to_string()),
            kid,
            x: base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(key.to_bytes()),
        }
    }

    // A verifier with one primary key accepts a token signed by that key.
    #[tokio::test]
    async fn test_multi_key_verifier_single_key() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = MultiKeyVerifier::from_single_key(
            verifying_key,
            "key-1",
            "test-issuer",
            "test-audience",
        );

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();

        let token = signer.sign(claims).unwrap();
        let context = verifier.verify(&token, None, None).await.unwrap();

        assert_eq!(context.subject, "test-subject");
        assert_eq!(verifier.primary_key_id().await, Some("key-1".to_string()));
    }

    // After a new primary is added, tokens from both the old and the new key
    // verify, and the new key is reported as primary.
    #[tokio::test]
    async fn test_key_rotation() {
        let old_signing_key = SigningKey::generate();
        let old_verifying_key = old_signing_key.verifying_key();
        let old_signer = TokenSigner::new(old_signing_key, "test-issuer");

        let new_signing_key = SigningKey::generate();
        let new_verifying_key = new_signing_key.verifying_key();
        let new_signer = TokenSigner::new(new_signing_key, "test-issuer");

        let old_key = RotationKey::primary(old_verifying_key, "key-old");
        let verifier = MultiKeyVerifier::new(vec![old_key], "test-issuer", "test-audience");

        let old_claims = SessionClaims::builder("test-issuer", "subject-1", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-1")
            .with_key_class(KeyClass::Publishable)
            .build();
        let old_token = old_signer.sign(old_claims).unwrap();

        // Before rotation the old key is the only key.
        let ctx = verifier.verify(&old_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-1");

        let new_key = RotationKey::primary(new_verifying_key, "key-new");
        verifier.add_key(new_key).await;

        // The demoted key still verifies during its grace period.
        let ctx = verifier.verify(&old_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-1");

        let new_claims = SessionClaims::builder("test-issuer", "subject-2", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-2")
            .with_key_class(KeyClass::Publishable)
            .build();
        let new_token = new_signer.sign(new_claims).unwrap();

        let ctx = verifier.verify(&new_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-2");

        assert_eq!(verifier.primary_key_id().await, Some("key-new".to_string()));

        let key_ids = verifier.key_ids().await;
        assert!(key_ids.contains(&"key-old".to_string()));
        assert!(key_ids.contains(&"key-new".to_string()));
    }

    // The builder produces a verifier that accepts a token whose origin matches.
    #[tokio::test]
    async fn test_verifier_builder() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let verifier = MultiKeyVerifierBuilder::new("test-issuer", "test-audience")
            .with_primary_key(verifying_key, "key-1")
            .with_origin_validation()
            .build();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_origin("https://trusted.example.com")
            .with_key_class(KeyClass::Secret)
            .build();

        let token = signer.sign(claims).unwrap();
        let ctx = verifier
            .verify(&token, Some("https://trusted.example.com"), None)
            .await
            .unwrap();
        assert_eq!(ctx.subject, "test-subject");
    }

    // A token signed by a key the verifier does not know is rejected with
    // `InvalidSignature` even after every configured key has been tried.
    #[tokio::test]
    async fn test_invalid_signature_with_multiple_keys() {
        let signer = TokenSigner::new(SigningKey::generate(), "test-issuer");

        // Two configured keys, neither of which signed the token.
        let primary = RotationKey::primary(SigningKey::generate().verifying_key(), "key-1");
        let secondary = RotationKey::secondary(
            SigningKey::generate().verifying_key(),
            "key-2",
            Duration::from_secs(60),
        );
        let verifier = MultiKeyVerifier::new(
            vec![primary, secondary],
            "test-issuer",
            "test-audience",
        );

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Publishable)
            .build();

        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None).await;
        assert!(matches!(result, Err(VerifyError::InvalidSignature)));
    }

    // A JWKS that publishes both the old and the new key verifies tokens from
    // either signer.
    #[tokio::test]
    async fn test_jwks_key_rotation_grace_period() {
        let old_signing_key = SigningKey::generate();
        let old_verifying_key = old_signing_key.verifying_key();
        let old_kid = old_verifying_key.key_id();
        let old_signer = TokenSigner::new(old_signing_key, "test-issuer");

        let new_signing_key = SigningKey::generate();
        let new_verifying_key = new_signing_key.verifying_key();
        let new_kid = new_verifying_key.key_id();
        let new_signer = TokenSigner::new(new_signing_key, "test-issuer");

        let jwks = Jwks {
            keys: vec![
                okp_jwk(old_verifying_key, old_kid),
                okp_jwk(new_verifying_key, new_kid),
            ],
        };

        let verifier =
            crate::verifier::AsyncVerifier::with_jwks(jwks, "test-issuer", "test-audience");

        let old_claims = SessionClaims::builder("test-issuer", "subject-old", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Secret)
            .build();
        let old_token = old_signer.sign(old_claims).unwrap();

        let ctx = verifier.verify(&old_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-old");

        let new_claims = SessionClaims::builder("test-issuer", "subject-new", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Secret)
            .build();
        let new_token = new_signer.sign(new_claims).unwrap();

        let ctx = verifier.verify(&new_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-new");
    }

    // A token whose signing key is absent from the JWKS yields `KeyNotFound`.
    #[tokio::test]
    async fn test_jwks_key_not_found() {
        let signer = TokenSigner::new(SigningKey::generate(), "test-issuer");

        let different_key = SigningKey::generate();
        let jwks = Jwks {
            keys: vec![okp_jwk(
                different_key.verifying_key(),
                "different-key".to_string(),
            )],
        };

        let verifier =
            crate::verifier::AsyncVerifier::with_jwks(jwks, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Secret)
            .build();
        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None).await;
        assert!(matches!(result, Err(VerifyError::KeyNotFound(_))));
    }

    // With origin validation on, a matching origin passes and a different one
    // is rejected with `OriginMismatch`.
    #[tokio::test]
    async fn test_jwks_with_origin_validation() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();
        let kid = verifying_key.key_id();
        let signer = TokenSigner::new(signing_key, "test-issuer");

        let jwks = Jwks {
            keys: vec![okp_jwk(verifying_key, kid)],
        };

        let verifier =
            crate::verifier::AsyncVerifier::with_jwks(jwks, "test-issuer", "test-audience")
                .with_origin_validation();

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Secret)
            .with_origin("https://trusted.example.com")
            .build();
        let token = signer.sign(claims).unwrap();

        let ctx = verifier
            .verify(&token, Some("https://trusted.example.com"), None)
            .await
            .unwrap();
        assert_eq!(ctx.subject, "test-subject");

        let result = verifier
            .verify(&token, Some("https://evil.example.com"), None)
            .await;
        assert!(matches!(result, Err(VerifyError::OriginMismatch { .. })));
    }
}