Skip to main content

hyperstack_auth/
multi_key.rs

1use crate::claims::AuthContext;
2use crate::error::VerifyError;
3use crate::keys::VerifyingKey;
4use crate::token::TokenVerifier;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::RwLock;
9
10/// A key with its metadata for rotation
11#[derive(Clone)]
12pub struct RotationKey {
13    /// The verifying key
14    pub key: VerifyingKey,
15    /// Key ID for JWKS compatibility
16    pub key_id: String,
17    /// When this key was added
18    pub added_at: Instant,
19    /// Optional: when this key should be removed (for grace period rotation)
20    pub expires_at: Option<Instant>,
21    /// Whether this is the primary (current) key
22    pub is_primary: bool,
23}
24
25impl RotationKey {
26    /// Create a new primary key
27    pub fn primary(key: VerifyingKey, key_id: impl Into<String>) -> Self {
28        Self {
29            key,
30            key_id: key_id.into(),
31            added_at: Instant::now(),
32            expires_at: None,
33            is_primary: true,
34        }
35    }
36
37    /// Create a secondary (rotating out) key with expiration
38    pub fn secondary(key: VerifyingKey, key_id: impl Into<String>, grace_period: Duration) -> Self {
39        Self {
40            key,
41            key_id: key_id.into(),
42            added_at: Instant::now(),
43            expires_at: Some(Instant::now() + grace_period),
44            is_primary: false,
45        }
46    }
47
48    /// Check if this key has expired
49    pub fn is_expired(&self) -> bool {
50        self.expires_at
51            .map(|exp| Instant::now() > exp)
52            .unwrap_or(false)
53    }
54}
55
56/// Multi-key verifier supporting graceful key rotation
57///
58/// This verifier maintains multiple keys and attempts verification with each
59/// until one succeeds. This allows zero-downtime key rotation:
60///
61/// 1. Generate new key pair
62/// 2. Add new key as primary, mark old key as secondary with grace period
63/// 3. Update JWKS to include both keys
64/// 4. After grace period, remove old key
65///
66/// # Example
67/// ```rust
68/// use hyperstack_auth::{MultiKeyVerifier, RotationKey, SigningKey};
69/// use std::time::Duration;
70///
71/// // Generate key pairs
72/// let old_signing_key = SigningKey::generate();
73/// let old_verifying_key = old_signing_key.verifying_key();
74/// let new_signing_key = SigningKey::generate();
75/// let new_verifying_key = new_signing_key.verifying_key();
76///
77/// // Create rotation keys
78/// let old_key = RotationKey::secondary(old_verifying_key, "key-1", Duration::from_secs(86400));
79/// let new_key = RotationKey::primary(new_verifying_key, "key-2");
80///
81/// let verifier = MultiKeyVerifier::new(vec![old_key, new_key], "issuer", "audience")
82///     .with_cleanup_interval(Duration::from_secs(3600));
83/// ```
84pub struct MultiKeyVerifier {
85    keys: Arc<RwLock<HashMap<String, RotationKey>>>,
86    issuer: String,
87    audience: String,
88    require_origin: bool,
89    cleanup_interval: Duration,
90    last_cleanup: Arc<RwLock<Instant>>,
91}
92
93impl MultiKeyVerifier {
94    /// Create a new multi-key verifier
95    pub fn new(
96        keys: Vec<RotationKey>,
97        issuer: impl Into<String>,
98        audience: impl Into<String>,
99    ) -> Self {
100        let key_map: HashMap<String, RotationKey> =
101            keys.into_iter().map(|k| (k.key_id.clone(), k)).collect();
102
103        Self {
104            keys: Arc::new(RwLock::new(key_map)),
105            issuer: issuer.into(),
106            audience: audience.into(),
107            require_origin: false,
108            cleanup_interval: Duration::from_secs(3600), // 1 hour default
109            last_cleanup: Arc::new(RwLock::new(Instant::now())),
110        }
111    }
112
113    /// Create from a single key (for backward compatibility)
114    pub fn from_single_key(
115        key: VerifyingKey,
116        key_id: impl Into<String>,
117        issuer: impl Into<String>,
118        audience: impl Into<String>,
119    ) -> Self {
120        Self::new(vec![RotationKey::primary(key, key_id)], issuer, audience)
121    }
122
123    /// Require origin validation
124    pub fn with_origin_validation(mut self) -> Self {
125        self.require_origin = true;
126        self
127    }
128
129    /// Set cleanup interval for expired keys
130    pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
131        self.cleanup_interval = interval;
132        self
133    }
134
135    /// Add a new key to the verifier
136    pub async fn add_key(&self, key: RotationKey) {
137        let mut keys = self.keys.write().await;
138
139        // If adding a primary key, demote existing primary to secondary
140        if key.is_primary {
141            for (_, existing) in keys.iter_mut() {
142                if existing.is_primary {
143                    existing.is_primary = false;
144                    // Set grace period for old primary
145                    existing.expires_at = Some(Instant::now() + Duration::from_secs(86400));
146                    // 24 hours
147                }
148            }
149        }
150
151        keys.insert(key.key_id.clone(), key);
152    }
153
154    /// Remove a key by ID
155    pub async fn remove_key(&self, key_id: &str) {
156        let mut keys = self.keys.write().await;
157        keys.remove(key_id);
158    }
159
160    /// Get all key IDs
161    pub async fn key_ids(&self) -> Vec<String> {
162        let keys = self.keys.read().await;
163        keys.keys().cloned().collect()
164    }
165
166    /// Get primary key ID
167    pub async fn primary_key_id(&self) -> Option<String> {
168        let keys = self.keys.read().await;
169        keys.values()
170            .find(|k| k.is_primary)
171            .map(|k| k.key_id.clone())
172    }
173
174    /// Clean up expired keys
175    async fn cleanup_expired_keys(&self) {
176        let should_cleanup = {
177            let last = self.last_cleanup.read().await;
178            last.elapsed() >= self.cleanup_interval
179        };
180
181        if !should_cleanup {
182            return;
183        }
184
185        let mut keys = self.keys.write().await;
186        let expired: Vec<String> = keys
187            .iter()
188            .filter(|(_, k)| k.is_expired())
189            .map(|(id, _)| id.clone())
190            .collect();
191
192        for key_id in expired {
193            keys.remove(&key_id);
194        }
195
196        // Update last cleanup time
197        let mut last = self.last_cleanup.write().await;
198        *last = Instant::now();
199    }
200
201    /// Verify a token against all keys
202    pub async fn verify(
203        &self,
204        token: &str,
205        expected_origin: Option<&str>,
206        expected_client_ip: Option<&str>,
207    ) -> Result<AuthContext, VerifyError> {
208        // Clean up expired keys periodically
209        self.cleanup_expired_keys().await;
210
211        let keys = self.keys.read().await;
212
213        if keys.is_empty() {
214            return Err(VerifyError::KeyNotFound("no keys configured".to_string()));
215        }
216
217        let mut last_error = None;
218
219        // Try primary key first, then secondary keys
220        let mut key_order: Vec<&RotationKey> = keys.values().collect();
221        key_order.sort_by_key(|k| !k.is_primary); // Primary first
222
223        for key_entry in key_order {
224            if key_entry.is_expired() {
225                continue;
226            }
227
228            let verifier = if self.require_origin {
229                TokenVerifier::new(
230                    key_entry.key.clone(),
231                    self.issuer.clone(),
232                    self.audience.clone(),
233                )
234                .with_origin_validation()
235            } else {
236                TokenVerifier::new(
237                    key_entry.key.clone(),
238                    self.issuer.clone(),
239                    self.audience.clone(),
240                )
241            };
242
243            match verifier.verify(token, expected_origin, expected_client_ip) {
244                Ok(ctx) => {
245                    return Ok(ctx);
246                }
247                Err(VerifyError::InvalidSignature) => {
248                    // Wrong key, try next
249                    last_error = Some(VerifyError::InvalidSignature);
250                    continue;
251                }
252                Err(e) => {
253                    // Other errors (expired, invalid format, etc.) - don't try other keys
254                    return Err(e);
255                }
256            }
257        }
258
259        // All keys failed
260        Err(last_error.unwrap_or(VerifyError::InvalidSignature))
261    }
262
263    /// Verify without cleaning up (for high-throughput scenarios)
264    pub async fn verify_fast(
265        &self,
266        token: &str,
267        expected_origin: Option<&str>,
268        expected_client_ip: Option<&str>,
269    ) -> Result<AuthContext, VerifyError> {
270        let keys = self.keys.read().await;
271
272        if keys.is_empty() {
273            return Err(VerifyError::KeyNotFound("no keys configured".to_string()));
274        }
275
276        let mut last_error = None;
277
278        // Try primary key first, then secondary keys
279        let mut key_order: Vec<&RotationKey> = keys.values().collect();
280        key_order.sort_by_key(|k| !k.is_primary);
281
282        for key_entry in key_order {
283            if key_entry.is_expired() {
284                continue;
285            }
286
287            let verifier = if self.require_origin {
288                TokenVerifier::new(
289                    key_entry.key.clone(),
290                    self.issuer.clone(),
291                    self.audience.clone(),
292                )
293                .with_origin_validation()
294            } else {
295                TokenVerifier::new(
296                    key_entry.key.clone(),
297                    self.issuer.clone(),
298                    self.audience.clone(),
299                )
300            };
301
302            match verifier.verify(token, expected_origin, expected_client_ip) {
303                Ok(ctx) => return Ok(ctx),
304                Err(VerifyError::InvalidSignature) => {
305                    last_error = Some(VerifyError::InvalidSignature);
306                    continue;
307                }
308                Err(e) => return Err(e),
309            }
310        }
311
312        Err(last_error.unwrap_or(VerifyError::InvalidSignature))
313    }
314}
315
316/// Builder for constructing a MultiKeyVerifier with rotation support
317pub struct MultiKeyVerifierBuilder {
318    keys: Vec<RotationKey>,
319    issuer: String,
320    audience: String,
321    require_origin: bool,
322    cleanup_interval: Duration,
323}
324
325impl MultiKeyVerifierBuilder {
326    /// Create a new builder
327    pub fn new(issuer: impl Into<String>, audience: impl Into<String>) -> Self {
328        Self {
329            keys: Vec::new(),
330            issuer: issuer.into(),
331            audience: audience.into(),
332            require_origin: false,
333            cleanup_interval: Duration::from_secs(3600),
334        }
335    }
336
337    /// Add a primary key
338    pub fn with_primary_key(mut self, key: VerifyingKey, key_id: impl Into<String>) -> Self {
339        self.keys.push(RotationKey::primary(key, key_id));
340        self
341    }
342
343    /// Add a secondary key with grace period
344    pub fn with_secondary_key(
345        mut self,
346        key: VerifyingKey,
347        key_id: impl Into<String>,
348        grace_period: Duration,
349    ) -> Self {
350        self.keys
351            .push(RotationKey::secondary(key, key_id, grace_period));
352        self
353    }
354
355    /// Require origin validation
356    pub fn with_origin_validation(mut self) -> Self {
357        self.require_origin = true;
358        self
359    }
360
361    /// Set cleanup interval
362    pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
363        self.cleanup_interval = interval;
364        self
365    }
366
367    /// Build the verifier
368    pub fn build(self) -> MultiKeyVerifier {
369        let mut verifier = MultiKeyVerifier::new(self.keys, self.issuer, self.audience);
370        if self.require_origin {
371            verifier = verifier.with_origin_validation();
372        }
373        verifier.with_cleanup_interval(self.cleanup_interval)
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claims::{KeyClass, SessionClaims};
    use crate::keys::SigningKey;
    use crate::token::TokenSigner;

    /// Build a secondary key whose expiry already lies in the past.
    fn expired_key(key: VerifyingKey, key_id: &str) -> RotationKey {
        let past = Instant::now()
            .checked_sub(Duration::from_secs(1))
            .expect("monotonic clock should be more than one second past its origin");
        RotationKey {
            key,
            key_id: key_id.to_string(),
            added_at: past,
            expires_at: Some(past),
            is_primary: false,
        }
    }

    #[tokio::test]
    async fn test_multi_key_verifier_single_key() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let verifier = MultiKeyVerifier::from_single_key(
            verifying_key,
            "key-1",
            "test-issuer",
            "test-audience",
        );

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-123")
            .with_key_class(KeyClass::Publishable)
            .build();

        let token = signer.sign(claims).unwrap();
        let context = verifier.verify(&token, None, None).await.unwrap();

        assert_eq!(context.subject, "test-subject");
        assert_eq!(verifier.primary_key_id().await, Some("key-1".to_string()));
    }

    #[tokio::test]
    async fn test_key_rotation() {
        // Create old key pair
        let old_signing_key = SigningKey::generate();
        let old_verifying_key = old_signing_key.verifying_key();
        let old_signer = TokenSigner::new(old_signing_key, "test-issuer");

        // Create new key pair
        let new_signing_key = SigningKey::generate();
        let new_verifying_key = new_signing_key.verifying_key();
        let new_signer = TokenSigner::new(new_signing_key, "test-issuer");

        // Start with old key as primary
        let old_key = RotationKey::primary(old_verifying_key.clone(), "key-old");
        let verifier = MultiKeyVerifier::new(vec![old_key], "test-issuer", "test-audience");

        // Sign token with old key
        let old_claims = SessionClaims::builder("test-issuer", "subject-1", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-1")
            .with_key_class(KeyClass::Publishable)
            .build();
        let old_token = old_signer.sign(old_claims).unwrap();

        // Verify old token works
        let ctx = verifier.verify(&old_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-1");

        // Rotate: add new key as primary (old key becomes secondary)
        let new_key = RotationKey::primary(new_verifying_key, "key-new");
        verifier.add_key(new_key).await;

        // Verify old token still works (grace period)
        let ctx = verifier.verify(&old_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-1");

        // Sign and verify new token
        let new_claims = SessionClaims::builder("test-issuer", "subject-2", "test-audience")
            .with_scope("read")
            .with_metering_key("meter-2")
            .with_key_class(KeyClass::Publishable)
            .build();
        let new_token = new_signer.sign(new_claims).unwrap();

        let ctx = verifier.verify(&new_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-2");

        // Check that new key is now primary
        assert_eq!(verifier.primary_key_id().await, Some("key-new".to_string()));

        // Both keys should be present
        let key_ids = verifier.key_ids().await;
        assert!(key_ids.contains(&"key-old".to_string()));
        assert!(key_ids.contains(&"key-new".to_string()));
    }

    #[tokio::test]
    async fn test_verifier_builder() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();

        let verifier = MultiKeyVerifierBuilder::new("test-issuer", "test-audience")
            .with_primary_key(verifying_key, "key-1")
            .with_origin_validation()
            .build();

        let signer = TokenSigner::new(signing_key, "test-issuer");
        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_origin("https://trusted.example.com")
            .with_key_class(KeyClass::Secret)
            .build();

        let token = signer.sign(claims).unwrap();
        let ctx = verifier
            .verify(&token, Some("https://trusted.example.com"), None)
            .await
            .unwrap();
        assert_eq!(ctx.subject, "test-subject");
    }

    #[tokio::test]
    async fn test_invalid_signature_with_multiple_keys() {
        // The token is signed with key 1; the verifier only knows keys 2 and 3, so
        // every key is tried and each one rejects the signature.
        let signer = TokenSigner::new(SigningKey::generate(), "test-issuer");
        let verifier = MultiKeyVerifier::new(
            vec![
                RotationKey::primary(SigningKey::generate().verifying_key(), "key-2"),
                RotationKey::secondary(
                    SigningKey::generate().verifying_key(),
                    "key-3",
                    Duration::from_secs(3600),
                ),
            ],
            "test-issuer",
            "test-audience",
        );

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Publishable)
            .build();

        let token = signer.sign(claims).unwrap();

        let result = verifier.verify(&token, None, None).await;
        assert!(matches!(result, Err(VerifyError::InvalidSignature)));
    }

    #[tokio::test]
    async fn test_expired_secondary_key_is_not_used() {
        let old_signing_key = SigningKey::generate();
        let old_verifying_key = old_signing_key.verifying_key();
        let old_signer = TokenSigner::new(old_signing_key, "test-issuer");

        let verifier = MultiKeyVerifier::new(
            vec![
                expired_key(old_verifying_key, "key-old"),
                RotationKey::primary(SigningKey::generate().verifying_key(), "key-new"),
            ],
            "test-issuer",
            "test-audience",
        );

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = old_signer.sign(claims).unwrap();

        // The only key that could verify this token has expired.
        let result = verifier.verify_fast(&token, None, None).await;
        assert!(matches!(result, Err(VerifyError::InvalidSignature)));

        // `verify_fast` never sweeps, so the expired key is still registered.
        assert!(verifier.key_ids().await.contains(&"key-old".to_string()));
    }

    #[tokio::test]
    async fn test_verify_sweeps_expired_keys() {
        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();
        let signer = TokenSigner::new(signing_key, "test-issuer");

        // A zero interval makes every `verify` call eligible to sweep.
        let verifier = MultiKeyVerifier::new(
            vec![
                expired_key(SigningKey::generate().verifying_key(), "key-expired"),
                RotationKey::primary(verifying_key, "key-live"),
            ],
            "test-issuer",
            "test-audience",
        )
        .with_cleanup_interval(Duration::from_secs(0));

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Publishable)
            .build();
        let token = signer.sign(claims).unwrap();

        let ctx = verifier.verify(&token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "test-subject");
        assert_eq!(verifier.key_ids().await, vec!["key-live".to_string()]);
    }

    #[tokio::test]
    async fn test_jwks_key_rotation_grace_period() {
        use crate::token::{Jwk, Jwks};
        use base64::Engine;

        // Create old key pair with specific key ID
        let old_signing_key = SigningKey::generate();
        let old_verifying_key = old_signing_key.verifying_key();
        let old_kid = old_verifying_key.key_id();
        let old_signer = TokenSigner::new(old_signing_key, "test-issuer");

        // Create new key pair with specific key ID
        let new_signing_key = SigningKey::generate();
        let new_verifying_key = new_signing_key.verifying_key();
        let new_kid = new_verifying_key.key_id();
        let new_signer = TokenSigner::new(new_signing_key, "test-issuer");

        // Create JWKS with both keys using their actual key IDs
        let old_key_b64 =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(old_verifying_key.to_bytes());
        let new_key_b64 =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(new_verifying_key.to_bytes());

        let jwks = Jwks {
            keys: vec![
                Jwk {
                    kty: "OKP".to_string(),
                    use_: Some("sig".to_string()),
                    kid: old_kid,
                    x: old_key_b64,
                },
                Jwk {
                    kty: "OKP".to_string(),
                    use_: Some("sig".to_string()),
                    kid: new_kid,
                    x: new_key_b64,
                },
            ],
        };

        // Create verifier from JWKS
        let verifier =
            crate::verifier::AsyncVerifier::with_jwks(jwks, "test-issuer", "test-audience");

        // Sign and verify token with old key
        let old_claims = SessionClaims::builder("test-issuer", "subject-old", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Secret)
            .build();
        let old_token = old_signer.sign(old_claims).unwrap();

        // Old token should still verify during rotation
        let ctx = verifier.verify(&old_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-old");

        // Sign and verify token with new key
        let new_claims = SessionClaims::builder("test-issuer", "subject-new", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Secret)
            .build();
        let new_token = new_signer.sign(new_claims).unwrap();

        // New token should also verify
        let ctx = verifier.verify(&new_token, None, None).await.unwrap();
        assert_eq!(ctx.subject, "subject-new");
    }

    #[tokio::test]
    async fn test_jwks_key_not_found() {
        use crate::token::{Jwk, Jwks};
        use base64::Engine;

        // Create a key pair
        let signing_key = SigningKey::generate();
        let signer = TokenSigner::new(signing_key, "test-issuer");

        // Create JWKS with a different key (not the one used for signing)
        let different_key = SigningKey::generate();
        let different_verifying_key = different_key.verifying_key();
        let different_key_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(different_verifying_key.to_bytes());

        let jwks = Jwks {
            keys: vec![Jwk {
                kty: "OKP".to_string(),
                use_: Some("sig".to_string()),
                kid: "different-key".to_string(),
                x: different_key_b64,
            }],
        };

        let verifier =
            crate::verifier::AsyncVerifier::with_jwks(jwks, "test-issuer", "test-audience");

        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Secret)
            .build();
        let token = signer.sign(claims).unwrap();

        // Should fail with key not found
        let result = verifier.verify(&token, None, None).await;
        assert!(matches!(result, Err(VerifyError::KeyNotFound(_))));
    }

    #[tokio::test]
    async fn test_jwks_with_origin_validation() {
        use crate::token::{Jwk, Jwks};
        use base64::Engine;

        let signing_key = SigningKey::generate();
        let verifying_key = signing_key.verifying_key();
        let kid = verifying_key.key_id();
        let signer = TokenSigner::new(signing_key, "test-issuer");

        let key_b64 =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(verifying_key.to_bytes());

        let jwks = Jwks {
            keys: vec![Jwk {
                kty: "OKP".to_string(),
                use_: Some("sig".to_string()),
                kid,
                x: key_b64,
            }],
        };

        // Create verifier with origin validation
        let verifier =
            crate::verifier::AsyncVerifier::with_jwks(jwks, "test-issuer", "test-audience")
                .with_origin_validation();

        // Token with matching origin
        let claims = SessionClaims::builder("test-issuer", "test-subject", "test-audience")
            .with_scope("read")
            .with_key_class(KeyClass::Secret)
            .with_origin("https://trusted.example.com")
            .build();
        let token = signer.sign(claims).unwrap();

        // Should succeed with matching origin
        let ctx = verifier
            .verify(&token, Some("https://trusted.example.com"), None)
            .await
            .unwrap();
        assert_eq!(ctx.subject, "test-subject");

        // Should fail with wrong origin
        let result = verifier
            .verify(&token, Some("https://evil.example.com"), None)
            .await;
        assert!(matches!(result, Err(VerifyError::OriginMismatch { .. })));
    }
}