authx_core/crypto/
key_store.rs1use std::sync::{Arc, RwLock};
2
3use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
4use tracing::instrument;
5use uuid::Uuid;
6
7use crate::error::{AuthError, Result};
8
9use super::signing::Claims;
10
11struct KeyVersion {
13 kid: String,
15 encoding: EncodingKey,
16 decoding: DecodingKey,
17}
18
19pub struct KeyRotationStore {
39 inner: Arc<RwLock<Inner>>,
40 max_keys: usize,
41}
42
43struct Inner {
44 keys: Vec<KeyVersion>,
45}
46
47impl KeyRotationStore {
48 pub fn new(max_keys: usize) -> Self {
51 let max_keys = max_keys.clamp(1, 16);
52 Self {
53 inner: Arc::new(RwLock::new(Inner { keys: Vec::new() })),
54 max_keys,
55 }
56 }
57
58 pub fn add_key(
60 &self,
61 kid: impl Into<String>,
62 private_pem: &[u8],
63 public_pem: &[u8],
64 ) -> Result<()> {
65 let encoding = EncodingKey::from_ed_pem(private_pem)
66 .map_err(|e| AuthError::Internal(format!("invalid private key: {e}")))?;
67 let decoding = DecodingKey::from_ed_pem(public_pem)
68 .map_err(|e| AuthError::Internal(format!("invalid public key: {e}")))?;
69
70 let version = KeyVersion {
71 kid: kid.into(),
72 encoding,
73 decoding,
74 };
75 let mut inner = match self.inner.write() {
76 Ok(g) => g,
77 Err(e) => {
78 tracing::error!("key store write-lock poisoned — recovering");
79 e.into_inner()
80 }
81 };
82 inner.keys.push(version);
83
84 while inner.keys.len() > self.max_keys {
86 let removed = inner.keys.remove(0);
87 tracing::info!(kid = %removed.kid, "key version evicted");
88 }
89
90 let current_kid = inner.keys.last().map(|k| k.kid.clone()).unwrap_or_default();
91 tracing::info!(kid = %current_kid, total = inner.keys.len(), "key version added");
92 Ok(())
93 }
94
95 pub fn rotate(
97 &self,
98 kid: impl Into<String>,
99 private_pem: &[u8],
100 public_pem: &[u8],
101 ) -> Result<()> {
102 self.add_key(kid, private_pem, public_pem)
103 }
104
105 pub fn prune_oldest(&self) {
107 let mut inner = match self.inner.write() {
108 Ok(g) => g,
109 Err(e) => {
110 tracing::error!("key store write-lock poisoned — recovering");
111 e.into_inner()
112 }
113 };
114 if inner.keys.len() > 1 {
115 let removed = inner.keys.remove(0);
116 tracing::info!(kid = %removed.kid, "oldest key version pruned");
117 }
118 }
119
120 #[instrument(skip(self, extra), fields(sub = %subject))]
122 pub fn sign(
123 &self,
124 subject: Uuid,
125 ttl_seconds: i64,
126 extra: serde_json::Value,
127 ) -> Result<String> {
128 use chrono::Utc;
129
130 let inner = match self.inner.read() {
131 Ok(g) => g,
132 Err(e) => {
133 tracing::error!("key store read-lock poisoned — recovering");
134 e.into_inner()
135 }
136 };
137 let kv = inner
138 .keys
139 .last()
140 .ok_or_else(|| AuthError::Internal("key store is empty — add a key first".into()))?;
141
142 let now = Utc::now().timestamp();
143 let claims = Claims {
144 sub: subject.to_string(),
145 exp: now + ttl_seconds,
146 iat: now,
147 jti: Uuid::new_v4().to_string(),
148 org: None,
149 extra,
150 };
151
152 let mut header = Header::new(Algorithm::EdDSA);
153 header.kid = Some(kv.kid.clone());
154
155 let token = encode(&header, &claims, &kv.encoding)
156 .map_err(|e| AuthError::Internal(format!("jwt sign failed: {e}")))?;
157
158 tracing::debug!(kid = %kv.kid, sub = %subject, "jwt signed");
159 Ok(token)
160 }
161
162 #[instrument(skip(self, token))]
164 pub fn verify(&self, token: &str) -> Result<Claims> {
165 let inner = match self.inner.read() {
166 Ok(g) => g,
167 Err(e) => {
168 tracing::error!("key store read-lock poisoned — recovering");
169 e.into_inner()
170 }
171 };
172
173 let mut validation = Validation::new(Algorithm::EdDSA);
174 validation.validate_exp = true;
175
176 let header = jsonwebtoken::decode_header(token).map_err(|_| AuthError::InvalidToken)?;
178 let preferred_kid = header.kid.as_deref();
179
180 let ordered: Vec<_> = inner.keys.iter().rev().collect();
182 for kv in &ordered {
183 if let Some(kid) = preferred_kid
184 && kv.kid != kid
185 {
186 continue; }
188 if let Ok(data) = decode::<Claims>(token, &kv.decoding, &validation) {
189 tracing::debug!(kid = %kv.kid, sub = %data.claims.sub, "jwt verified");
190 return Ok(data.claims);
191 }
192 }
193
194 for kv in &ordered {
196 if let Ok(data) = decode::<Claims>(token, &kv.decoding, &validation) {
197 tracing::debug!(kid = %kv.kid, sub = %data.claims.sub, "jwt verified (fallback)");
198 return Ok(data.claims);
199 }
200 }
201
202 tracing::warn!("jwt verification failed against all key versions");
203 Err(AuthError::InvalidToken)
204 }
205
206 pub fn key_count(&self) -> usize {
208 match self.inner.read() {
209 Ok(g) => g.keys.len(),
210 Err(e) => {
211 tracing::error!("key store read-lock poisoned — recovering");
212 e.into_inner().keys.len()
213 }
214 }
215 }
216}
217
218impl Clone for KeyRotationStore {
219 fn clone(&self) -> Self {
220 Self {
221 inner: Arc::clone(&self.inner),
222 max_keys: self.max_keys,
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    // PEM fixtures for two key versions.
    // NOTE(review): nothing here checks that each private key corresponds to
    // its public key; no test below performs a sign/verify round trip.
    const PRIV_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEIJ+DYDHbiFQiDpMqQR5JN9QOCiIxj7T/XmVbz3Cg+xvL\n-----END PRIVATE KEY-----\n";
    const PUB_PEM: &[u8] = b"-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAoNFBPj4h5jFITR2XlDqz8qFjNXaXFJF3mJoSBpVwC1E=\n-----END PUBLIC KEY-----\n";

    const PRIV2_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEIBBZj4V3sFR3zIieCbxHnrLoAoEJQHBkJPIJlqMvpO5U\n-----END PRIVATE KEY-----\n";
    const PUB2_PEM: &[u8] = b"-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEA2YkJaLvQK1gTnYqQB8djQZfPOvXrJTpGE9nO9A4Xbg0=\n-----END PUBLIC KEY-----\n";

    /// Signing with no keys loaded must report an error.
    #[test]
    fn empty_store_sign_fails() {
        let store = KeyRotationStore::new(2);
        let outcome = store.sign(Uuid::new_v4(), 3600, serde_json::Value::Null);
        assert!(outcome.is_err());
    }

    /// Verifying a malformed token against an empty store must report an error.
    #[test]
    fn empty_store_verify_fails() {
        let store = KeyRotationStore::new(2);
        let outcome = store.verify("not.a.token");
        assert!(outcome.is_err());
    }

    /// The key count starts at zero and reflects a successful addition.
    #[test]
    fn key_count_tracks_additions() {
        let store = KeyRotationStore::new(3);
        assert_eq!(store.key_count(), 0);

        // NOTE(review): the post-add count is only asserted when the fixture
        // PEM is accepted; a parse failure makes this test pass vacuously.
        if store.add_key("v1", PRIV_PEM, PUB_PEM).is_err() {
            return;
        }
        assert_eq!(store.key_count(), 1);
    }

    /// Input that is not PEM at all is rejected.
    #[test]
    fn invalid_pem_rejected() {
        let store = KeyRotationStore::new(2);
        let outcome = store.add_key("bad", b"not-a-pem", b"also-not-a-pem");
        assert!(outcome.is_err());
    }

    /// A clone observes keys added through the original handle.
    #[test]
    fn clone_shares_state() {
        let original = KeyRotationStore::new(2);
        let handle = original.clone();

        // NOTE(review): conditional on the fixture PEM being accepted.
        if original.add_key("v1", PRIV_PEM, PUB_PEM).is_err() {
            return;
        }
        assert_eq!(handle.key_count(), 1);
    }

    /// With a capacity of one, adding a second key evicts the first.
    #[test]
    fn max_keys_evicts_oldest() {
        let store = KeyRotationStore::new(1);
        let outcomes = [
            store.add_key("v1", PRIV_PEM, PUB_PEM),
            store.add_key("v2", PRIV2_PEM, PUB2_PEM),
        ];

        // NOTE(review): conditional on both fixture PEMs being accepted.
        if outcomes.iter().all(|r| r.is_ok()) {
            assert_eq!(store.key_count(), 1);
        }
    }
}