// authx_core/crypto/key_store.rs

1use std::sync::{Arc, RwLock};
2
3use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
4use tracing::instrument;
5use uuid::Uuid;
6
7use crate::error::{AuthError, Result};
8
9use super::signing::Claims;
10
11/// A versioned key pair entry.
12struct KeyVersion {
13    /// Unique version identifier — included in the JWT `kid` header.
14    kid: String,
15    encoding: EncodingKey,
16    decoding: DecodingKey,
17}
18
19/// Zero-downtime key rotation store.
20///
21/// Holds up to `max_keys` Ed25519 key pairs simultaneously. The *latest*
22/// added key is used for signing; all retained keys can verify tokens.
23/// During rotation:
24///
25/// 1. Call `rotate(private_pem, public_pem)` to add the new key.
26/// 2. New tokens are signed with the new key immediately.
27/// 3. Old tokens remain verifiable until their natural expiry.
28/// 4. Call `prune()` to remove the oldest key once all old tokens have expired.
29///
30/// # Example
31/// ```rust,ignore
32/// let mut store = KeyRotationStore::new(3);
33/// store.add_key("v1", PRIVATE_PEM, PUBLIC_PEM)?;
34///
35/// // Later, on rotation:
36/// store.rotate("v2", NEW_PRIVATE_PEM, NEW_PUBLIC_PEM)?;
37/// ```
38pub struct KeyRotationStore {
39    inner: Arc<RwLock<Inner>>,
40    max_keys: usize,
41}
42
43struct Inner {
44    keys: Vec<KeyVersion>,
45}
46
47impl KeyRotationStore {
48    /// Create a new store. `max_keys` caps how many key versions are retained
49    /// simultaneously (minimum 1, maximum 16).
50    pub fn new(max_keys: usize) -> Self {
51        let max_keys = max_keys.clamp(1, 16);
52        Self {
53            inner: Arc::new(RwLock::new(Inner { keys: Vec::new() })),
54            max_keys,
55        }
56    }
57
58    /// Load the initial key pair. `kid` is a human-readable version tag.
59    pub fn add_key(
60        &self,
61        kid: impl Into<String>,
62        private_pem: &[u8],
63        public_pem: &[u8],
64    ) -> Result<()> {
65        let encoding = EncodingKey::from_ed_pem(private_pem)
66            .map_err(|e| AuthError::Internal(format!("invalid private key: {e}")))?;
67        let decoding = DecodingKey::from_ed_pem(public_pem)
68            .map_err(|e| AuthError::Internal(format!("invalid public key: {e}")))?;
69
70        let version = KeyVersion {
71            kid: kid.into(),
72            encoding,
73            decoding,
74        };
75        let mut inner = match self.inner.write() {
76            Ok(g) => g,
77            Err(e) => {
78                tracing::error!("key store write-lock poisoned — recovering");
79                e.into_inner()
80            }
81        };
82        inner.keys.push(version);
83
84        // Enforce max_keys by evicting the oldest.
85        while inner.keys.len() > self.max_keys {
86            let removed = inner.keys.remove(0);
87            tracing::info!(kid = %removed.kid, "key version evicted");
88        }
89
90        let current_kid = inner.keys.last().map(|k| k.kid.clone()).unwrap_or_default();
91        tracing::info!(kid = %current_kid, total = inner.keys.len(), "key version added");
92        Ok(())
93    }
94
95    /// Convenience alias — same as `add_key` but semantically signals rotation.
96    pub fn rotate(
97        &self,
98        kid: impl Into<String>,
99        private_pem: &[u8],
100        public_pem: &[u8],
101    ) -> Result<()> {
102        self.add_key(kid, private_pem, public_pem)
103    }
104
105    /// Drop the oldest key version (call after old tokens have expired).
106    pub fn prune_oldest(&self) {
107        let mut inner = match self.inner.write() {
108            Ok(g) => g,
109            Err(e) => {
110                tracing::error!("key store write-lock poisoned — recovering");
111                e.into_inner()
112            }
113        };
114        if inner.keys.len() > 1 {
115            let removed = inner.keys.remove(0);
116            tracing::info!(kid = %removed.kid, "oldest key version pruned");
117        }
118    }
119
120    /// Sign a JWT with the current (newest) key.
121    #[instrument(skip(self, extra), fields(sub = %subject))]
122    pub fn sign(
123        &self,
124        subject: Uuid,
125        ttl_seconds: i64,
126        extra: serde_json::Value,
127    ) -> Result<String> {
128        use chrono::Utc;
129
130        let inner = match self.inner.read() {
131            Ok(g) => g,
132            Err(e) => {
133                tracing::error!("key store read-lock poisoned — recovering");
134                e.into_inner()
135            }
136        };
137        let kv = inner
138            .keys
139            .last()
140            .ok_or_else(|| AuthError::Internal("key store is empty — add a key first".into()))?;
141
142        let now = Utc::now().timestamp();
143        let claims = Claims {
144            sub: subject.to_string(),
145            exp: now + ttl_seconds,
146            iat: now,
147            jti: Uuid::new_v4().to_string(),
148            org: None,
149            extra,
150        };
151
152        let mut header = Header::new(Algorithm::EdDSA);
153        header.kid = Some(kv.kid.clone());
154
155        let token = encode(&header, &claims, &kv.encoding)
156            .map_err(|e| AuthError::Internal(format!("jwt sign failed: {e}")))?;
157
158        tracing::debug!(kid = %kv.kid, sub = %subject, "jwt signed");
159        Ok(token)
160    }
161
162    /// Verify a JWT against *all* retained key versions (newest first).
163    #[instrument(skip(self, token))]
164    pub fn verify(&self, token: &str) -> Result<Claims> {
165        let inner = match self.inner.read() {
166            Ok(g) => g,
167            Err(e) => {
168                tracing::error!("key store read-lock poisoned — recovering");
169                e.into_inner()
170            }
171        };
172
173        let mut validation = Validation::new(Algorithm::EdDSA);
174        validation.validate_exp = true;
175
176        // Extract `kid` from the header to try the right key first.
177        let header = jsonwebtoken::decode_header(token).map_err(|_| AuthError::InvalidToken)?;
178        let preferred_kid = header.kid.as_deref();
179
180        // Try keys newest-first, preferring the kid match.
181        let ordered: Vec<_> = inner.keys.iter().rev().collect();
182        for kv in &ordered {
183            if let Some(kid) = preferred_kid
184                && kv.kid != kid
185            {
186                continue; // skip non-matching first pass
187            }
188            if let Ok(data) = decode::<Claims>(token, &kv.decoding, &validation) {
189                tracing::debug!(kid = %kv.kid, sub = %data.claims.sub, "jwt verified");
190                return Ok(data.claims);
191            }
192        }
193
194        // Fallback: try all keys (handles tokens without kid or mismatched kid).
195        for kv in &ordered {
196            if let Ok(data) = decode::<Claims>(token, &kv.decoding, &validation) {
197                tracing::debug!(kid = %kv.kid, sub = %data.claims.sub, "jwt verified (fallback)");
198                return Ok(data.claims);
199            }
200        }
201
202        tracing::warn!("jwt verification failed against all key versions");
203        Err(AuthError::InvalidToken)
204    }
205
206    /// Number of currently retained key versions.
207    pub fn key_count(&self) -> usize {
208        match self.inner.read() {
209            Ok(g) => g.keys.len(),
210            Err(e) => {
211                tracing::error!("key store read-lock poisoned — recovering");
212                e.into_inner().keys.len()
213            }
214        }
215    }
216}
217
218impl Clone for KeyRotationStore {
219    fn clone(&self) -> Self {
220        Self {
221            inner: Arc::clone(&self.inner),
222            max_keys: self.max_keys,
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    // Minimal Ed25519 PEM pair for testing (generated offline).
    // These are test-only keys — never use in production.
    // NOTE(review): nothing here checks that each public key matches its
    // private key, so these tests deliberately avoid sign→verify round-trips.
    const PRIV_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEIJ+DYDHbiFQiDpMqQR5JN9QOCiIxj7T/XmVbz3Cg+xvL\n-----END PRIVATE KEY-----\n";
    const PUB_PEM: &[u8] = b"-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAoNFBPj4h5jFITR2XlDqz8qFjNXaXFJF3mJoSBpVwC1E=\n-----END PUBLIC KEY-----\n";

    // Second key pair.
    const PRIV2_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEIBBZj4V3sFR3zIieCbxHnrLoAoEJQHBkJPIJlqMvpO5U\n-----END PRIVATE KEY-----\n";
    const PUB2_PEM: &[u8] = b"-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEA2YkJaLvQK1gTnYqQB8djQZfPOvXrJTpGE9nO9A4Xbg0=\n-----END PUBLIC KEY-----\n";

    /// Load a fixture key pair, failing loudly if it does not parse.
    ///
    /// Guarding assertions with `if add_key(..).is_ok()` would let every test
    /// pass vacuously if the fixtures were ever broken.
    fn add_fixture(store: &KeyRotationStore, kid: &str, private_pem: &[u8], public_pem: &[u8]) {
        assert!(
            store.add_key(kid, private_pem, public_pem).is_ok(),
            "fixture key {kid} failed to load"
        );
    }

    #[test]
    fn empty_store_sign_fails() {
        let store = KeyRotationStore::new(2);
        assert!(
            store
                .sign(Uuid::new_v4(), 3600, serde_json::Value::Null)
                .is_err()
        );
    }

    #[test]
    fn empty_store_verify_fails() {
        let store = KeyRotationStore::new(2);
        assert!(store.verify("not.a.token").is_err());
    }

    #[test]
    fn key_count_tracks_additions() {
        let store = KeyRotationStore::new(3);
        assert_eq!(store.key_count(), 0);

        add_fixture(&store, "v1", PRIV_PEM, PUB_PEM);
        assert_eq!(store.key_count(), 1);

        assert!(store.rotate("v2", PRIV2_PEM, PUB2_PEM).is_ok());
        assert_eq!(store.key_count(), 2);
    }

    #[test]
    fn invalid_pem_rejected() {
        let store = KeyRotationStore::new(2);
        let err = store.add_key("bad", b"not-a-pem", b"also-not-a-pem");
        assert!(err.is_err());
        // A rejected pair must not be stored.
        assert_eq!(store.key_count(), 0);
    }

    #[test]
    fn invalid_public_pem_rejected_even_with_valid_private() {
        let store = KeyRotationStore::new(2);
        assert!(store.add_key("half", PRIV_PEM, b"not-a-pem").is_err());
        assert_eq!(store.key_count(), 0);
    }

    #[test]
    fn clone_shares_state() {
        let store = KeyRotationStore::new(2);
        let clone = store.clone();
        // Mutations through one are visible from the other.
        add_fixture(&store, "v1", PRIV_PEM, PUB_PEM);
        assert_eq!(clone.key_count(), 1);
    }

    #[test]
    fn max_keys_evicts_oldest() {
        let store = KeyRotationStore::new(1);
        // Add two keys — only the second should remain.
        add_fixture(&store, "v1", PRIV_PEM, PUB_PEM);
        add_fixture(&store, "v2", PRIV2_PEM, PUB2_PEM);
        assert_eq!(store.key_count(), 1);
    }

    #[test]
    fn zero_max_keys_is_clamped_to_one() {
        let store = KeyRotationStore::new(0);
        add_fixture(&store, "v1", PRIV_PEM, PUB_PEM);
        add_fixture(&store, "v2", PRIV2_PEM, PUB2_PEM);
        assert_eq!(store.key_count(), 1);
    }

    #[test]
    fn prune_oldest_keeps_the_signing_key() {
        let store = KeyRotationStore::new(3);
        add_fixture(&store, "v1", PRIV_PEM, PUB_PEM);
        add_fixture(&store, "v2", PRIV2_PEM, PUB2_PEM);

        store.prune_oldest();
        assert_eq!(store.key_count(), 1);

        // The last remaining key is never pruned.
        store.prune_oldest();
        assert_eq!(store.key_count(), 1);
    }

    #[test]
    fn prune_oldest_on_empty_store_is_noop() {
        let store = KeyRotationStore::new(2);
        store.prune_oldest();
        assert_eq!(store.key_count(), 0);
    }

    #[test]
    fn verify_rejects_malformed_token_with_keys_loaded() {
        let store = KeyRotationStore::new(2);
        add_fixture(&store, "v1", PRIV_PEM, PUB_PEM);
        assert!(store.verify("not.a.token").is_err());
        assert!(store.verify("").is_err());
    }
}