Skip to main content

allowthem_core/
signing_keys.rs

1//! RS256 signing key management — key generation, encrypted storage, JWKS, and OIDC discovery.
2
3use aes_gcm::aead::Aead;
4use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
5use base64ct::{Base64UrlUnpadded, Encoding};
6use chrono::{DateTime, Utc};
7use rsa::pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey};
8use rsa::rand_core::{OsRng, RngCore};
9use rsa::traits::PublicKeyParts;
10use rsa::{RsaPrivateKey, RsaPublicKey};
11use serde::Serialize;
12
13use crate::db::Db;
14use crate::error::AuthError;
15use crate::types::SigningKeyId;
16
17/// An RS256 signing key pair stored in the database.
18///
19/// The private key is AES-256-GCM encrypted at rest. Call `decrypt_private_key`
20/// to recover the PKCS#8 PEM for signing.
21#[derive(Debug, Clone, sqlx::FromRow)]
22pub struct SigningKey {
23    pub id: SigningKeyId,
24    pub private_key_enc: Vec<u8>,
25    pub private_key_nonce: Vec<u8>,
26    pub public_key_pem: String,
27    pub algorithm: String,
28    pub is_active: bool,
29    pub created_at: DateTime<Utc>,
30}
31
32/// Decrypt a signing key's private key PEM from its encrypted storage.
33///
34/// Uses AES-256-GCM with the stored nonce. Returns the PKCS#8 PEM string.
35/// This is a free function — decryption is pure computation, not a `Db` method.
36pub fn decrypt_private_key(
37    key: &SigningKey,
38    encryption_key: &[u8; 32],
39) -> Result<String, AuthError> {
40    let cipher = Aes256Gcm::new(encryption_key.into());
41    let nonce = Nonce::from_slice(&key.private_key_nonce);
42    let plaintext = cipher
43        .decrypt(nonce, key.private_key_enc.as_slice())
44        .map_err(|e| AuthError::SigningKey(e.to_string()))?;
45    String::from_utf8(plaintext).map_err(|e| AuthError::SigningKey(e.to_string()))
46}
47
48impl Db {
49    /// Generate an RS256 key pair, encrypt the private key, and store both in the database.
50    ///
51    /// The new key is NOT automatically marked active — call `activate_signing_key` separately.
52    pub async fn create_signing_key(
53        &self,
54        encryption_key: &[u8; 32],
55    ) -> Result<SigningKey, AuthError> {
56        let private_key = RsaPrivateKey::new(&mut OsRng, 2048)
57            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
58
59        let private_pem = private_key
60            .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
61            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
62        let pem_bytes = private_pem.as_bytes();
63
64        let mut nonce_bytes = [0u8; 12];
65        OsRng.fill_bytes(&mut nonce_bytes);
66        let nonce = Nonce::from_slice(&nonce_bytes);
67
68        let cipher = Aes256Gcm::new(encryption_key.into());
69        let private_key_enc = cipher
70            .encrypt(nonce, pem_bytes)
71            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
72
73        let public_key_pem = RsaPublicKey::from(&private_key)
74            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
75            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
76
77        let id = SigningKeyId::new();
78
79        sqlx::query(
80            "INSERT INTO allowthem_signing_keys \
81             (id, private_key_enc, private_key_nonce, public_key_pem, algorithm, is_active) \
82             VALUES (?, ?, ?, ?, 'RS256', 0)",
83        )
84        .bind(id)
85        .bind(&private_key_enc)
86        .bind(nonce_bytes.as_slice())
87        .bind(&public_key_pem)
88        .execute(self.pool())
89        .await?;
90
91        let key = SigningKey {
92            id,
93            private_key_enc,
94            private_key_nonce: nonce_bytes.to_vec(),
95            public_key_pem,
96            algorithm: "RS256".to_string(),
97            is_active: false,
98            created_at: Utc::now(),
99        };
100
101        Ok(key)
102    }
103
104    /// Mark a key as the active signing key. Deactivates all other keys in a single transaction.
105    ///
106    /// Returns `AuthError::NotFound` if the key ID does not exist.
107    pub async fn activate_signing_key(&self, key_id: SigningKeyId) -> Result<(), AuthError> {
108        let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
109
110        sqlx::query("UPDATE allowthem_signing_keys SET is_active = 0 WHERE is_active = 1")
111            .execute(&mut *tx)
112            .await
113            .map_err(AuthError::Database)?;
114
115        let result = sqlx::query("UPDATE allowthem_signing_keys SET is_active = 1 WHERE id = ?")
116            .bind(key_id)
117            .execute(&mut *tx)
118            .await
119            .map_err(AuthError::Database)?;
120
121        if result.rows_affected() == 0 {
122            tx.rollback().await.map_err(AuthError::Database)?;
123            return Err(AuthError::NotFound);
124        }
125
126        tx.commit().await.map_err(AuthError::Database)?;
127        Ok(())
128    }
129
130    /// Generate a new key and activate it, deactivating the current active key.
131    ///
132    /// Combines creation and activation in a single transaction.
133    pub async fn rotate_signing_key(
134        &self,
135        encryption_key: &[u8; 32],
136    ) -> Result<SigningKey, AuthError> {
137        let private_key = RsaPrivateKey::new(&mut OsRng, 2048)
138            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
139
140        let private_pem = private_key
141            .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
142            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
143        let pem_bytes = private_pem.as_bytes();
144
145        let mut nonce_bytes = [0u8; 12];
146        OsRng.fill_bytes(&mut nonce_bytes);
147        let nonce = Nonce::from_slice(&nonce_bytes);
148
149        let cipher = Aes256Gcm::new(encryption_key.into());
150        let private_key_enc = cipher
151            .encrypt(nonce, pem_bytes)
152            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
153
154        let public_key_pem = RsaPublicKey::from(&private_key)
155            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
156            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
157
158        let id = SigningKeyId::new();
159
160        let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
161
162        sqlx::query(
163            "INSERT INTO allowthem_signing_keys \
164             (id, private_key_enc, private_key_nonce, public_key_pem, algorithm, is_active) \
165             VALUES (?, ?, ?, ?, 'RS256', 0)",
166        )
167        .bind(id)
168        .bind(&private_key_enc)
169        .bind(nonce_bytes.as_slice())
170        .bind(&public_key_pem)
171        .execute(&mut *tx)
172        .await
173        .map_err(AuthError::Database)?;
174
175        sqlx::query("UPDATE allowthem_signing_keys SET is_active = 0 WHERE is_active = 1")
176            .execute(&mut *tx)
177            .await
178            .map_err(AuthError::Database)?;
179
180        sqlx::query("UPDATE allowthem_signing_keys SET is_active = 1 WHERE id = ?")
181            .bind(id)
182            .execute(&mut *tx)
183            .await
184            .map_err(AuthError::Database)?;
185
186        tx.commit().await.map_err(AuthError::Database)?;
187
188        let key = SigningKey {
189            id,
190            private_key_enc,
191            private_key_nonce: nonce_bytes.to_vec(),
192            public_key_pem,
193            algorithm: "RS256".to_string(),
194            is_active: true,
195            created_at: Utc::now(),
196        };
197
198        Ok(key)
199    }
200
201    /// Get the currently active signing key.
202    ///
203    /// Returns `AuthError::NotFound` if no key is active.
204    pub async fn get_active_signing_key(&self) -> Result<SigningKey, AuthError> {
205        sqlx::query_as(
206            "SELECT id, private_key_enc, private_key_nonce, public_key_pem, \
207             algorithm, is_active, created_at \
208             FROM allowthem_signing_keys WHERE is_active = 1 LIMIT 1",
209        )
210        .fetch_optional(self.pool())
211        .await?
212        .ok_or(AuthError::NotFound)
213    }
214
215    /// Get all signing keys ordered by creation date descending.
216    ///
217    /// Used by the JWKS endpoint to serve all public keys (active + rotated-out).
218    pub async fn get_all_signing_keys(&self) -> Result<Vec<SigningKey>, AuthError> {
219        Ok(sqlx::query_as(
220            "SELECT id, private_key_enc, private_key_nonce, public_key_pem, \
221             algorithm, is_active, created_at \
222             FROM allowthem_signing_keys ORDER BY created_at DESC",
223        )
224        .fetch_all(self.pool())
225        .await?)
226    }
227
228    /// Get a specific signing key by ID.
229    ///
230    /// Returns `AuthError::NotFound` if no key matches the ID.
231    pub async fn get_signing_key(&self, id: SigningKeyId) -> Result<SigningKey, AuthError> {
232        sqlx::query_as(
233            "SELECT id, private_key_enc, private_key_nonce, public_key_pem, \
234             algorithm, is_active, created_at \
235             FROM allowthem_signing_keys WHERE id = ?",
236        )
237        .bind(id)
238        .fetch_optional(self.pool())
239        .await?
240        .ok_or(AuthError::NotFound)
241    }
242}
243
244/// A single JWK entry for the JWKS endpoint.
245#[derive(Debug, Serialize)]
246pub struct JwkEntry {
247    pub kty: &'static str,
248    #[serde(rename = "use")]
249    pub use_: &'static str,
250    pub alg: &'static str,
251    pub kid: String,
252    pub n: String,
253    pub e: String,
254}
255
256/// The full JWKS document.
257#[derive(Debug, Serialize)]
258pub struct JwkSet {
259    pub keys: Vec<JwkEntry>,
260}
261
262/// Build a JWKS document from all signing keys.
263///
264/// Parses each key's public PEM to extract the RSA modulus and exponent,
265/// base64url-encoding them per RFC 7518 Section 6.3.1.
266pub fn build_jwks(keys: &[SigningKey]) -> Result<JwkSet, AuthError> {
267    let mut entries = Vec::with_capacity(keys.len());
268    for key in keys {
269        let pub_key = RsaPublicKey::from_public_key_pem(&key.public_key_pem)
270            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
271        let n = Base64UrlUnpadded::encode_string(&pub_key.n().to_bytes_be());
272        let e = Base64UrlUnpadded::encode_string(&pub_key.e().to_bytes_be());
273        entries.push(JwkEntry {
274            kty: "RSA",
275            use_: "sig",
276            alg: "RS256",
277            kid: key.id.to_string(),
278            n,
279            e,
280        });
281    }
282    Ok(JwkSet { keys: entries })
283}
284
285/// OpenID Connect discovery metadata.
286///
287/// See: <https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata>
288#[derive(Debug, Serialize)]
289pub struct OidcDiscovery {
290    pub issuer: String,
291    pub authorization_endpoint: String,
292    pub token_endpoint: String,
293    pub userinfo_endpoint: String,
294    pub jwks_uri: String,
295    pub scopes_supported: Vec<&'static str>,
296    pub response_types_supported: Vec<&'static str>,
297    pub grant_types_supported: Vec<&'static str>,
298    pub subject_types_supported: Vec<&'static str>,
299    pub id_token_signing_alg_values_supported: Vec<&'static str>,
300    pub token_endpoint_auth_methods_supported: Vec<&'static str>,
301    pub code_challenge_methods_supported: Vec<&'static str>,
302}
303
304/// Build the OIDC discovery document for the given issuer URL.
305pub fn build_discovery(issuer: &str) -> OidcDiscovery {
306    OidcDiscovery {
307        authorization_endpoint: format!("{issuer}/oauth/authorize"),
308        token_endpoint: format!("{issuer}/oauth/token"),
309        userinfo_endpoint: format!("{issuer}/oauth/userinfo"),
310        jwks_uri: format!("{issuer}/.well-known/jwks.json"),
311        issuer: issuer.to_string(),
312        scopes_supported: vec!["openid", "profile", "email"],
313        response_types_supported: vec!["code"],
314        grant_types_supported: vec!["authorization_code", "refresh_token"],
315        subject_types_supported: vec!["public"],
316        id_token_signing_alg_values_supported: vec!["RS256"],
317        token_endpoint_auth_methods_supported: vec!["client_secret_post", "client_secret_basic"],
318        code_challenge_methods_supported: vec!["S256"],
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::Db;
    use rsa::pkcs8::DecodePrivateKey;
    use sqlx::SqlitePool;
    use sqlx::sqlite::SqliteConnectOptions;
    use std::str::FromStr;

    const ENC_KEY_A: [u8; 32] = [0x42; 32];
    const ENC_KEY_B: [u8; 32] = [0x99; 32];

    // ---------------------------------------------------------------- fixtures

    /// Fresh in-memory SQLite database (foreign keys enforced) wrapped in `Db`.
    async fn test_db() -> Db {
        let options = SqliteConnectOptions::from_str("sqlite::memory:")
            .unwrap()
            .pragma("foreign_keys", "ON");
        Db::new(SqlitePool::connect_with(options).await.unwrap())
            .await
            .unwrap()
    }

    /// A newly generated RSA-2048 private key.
    fn fresh_rsa_key() -> RsaPrivateKey {
        RsaPrivateKey::new(&mut OsRng, 2048).unwrap()
    }

    /// Public-key PEM (LF line endings) derived from `private_key`.
    fn public_pem_of(private_key: &RsaPrivateKey) -> String {
        RsaPublicKey::from(private_key)
            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap()
    }

    /// An inactive `SigningKey` built without touching the DB. Its PKCS#8 PEM is encrypted
    /// with AES-256-GCM under `enc_key` using a random nonce. The plaintext PEM bytes are
    /// returned alongside the key.
    fn encrypted_key_fixture(enc_key: &[u8; 32]) -> (SigningKey, Vec<u8>) {
        let private_key = fresh_rsa_key();
        let pem = private_key
            .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap();

        let mut nonce_bytes = [0u8; 12];
        OsRng.fill_bytes(&mut nonce_bytes);
        let ciphertext = Aes256Gcm::new(enc_key.into())
            .encrypt(Nonce::from_slice(&nonce_bytes), pem.as_bytes())
            .unwrap();

        let key = SigningKey {
            id: SigningKeyId::new(),
            private_key_enc: ciphertext,
            private_key_nonce: nonce_bytes.to_vec(),
            public_key_pem: public_pem_of(&private_key),
            algorithm: "RS256".to_string(),
            is_active: false,
            created_at: Utc::now(),
        };
        (key, pem.as_bytes().to_vec())
    }

    /// An active, public-only `SigningKey` (empty private material) for JWKS tests.
    fn public_only_key(id: SigningKeyId) -> SigningKey {
        SigningKey {
            id,
            private_key_enc: Vec::new(),
            private_key_nonce: Vec::new(),
            public_key_pem: public_pem_of(&fresh_rsa_key()),
            algorithm: "RS256".to_string(),
            is_active: true,
            created_at: Utc::now(),
        }
    }

    // ------------------------------------------------------ decrypt_private_key

    #[test]
    fn decrypt_round_trip() {
        let (key, pem_bytes) = encrypted_key_fixture(&ENC_KEY_A);
        let decrypted = decrypt_private_key(&key, &ENC_KEY_A).unwrap();
        assert_eq!(decrypted.as_bytes(), pem_bytes.as_slice());
    }

    #[test]
    fn decrypt_wrong_key_fails() {
        let (key, _) = encrypted_key_fixture(&ENC_KEY_A);
        let result = decrypt_private_key(&key, &ENC_KEY_B);
        assert!(result.is_err(), "decryption with wrong key must fail");
    }

    // ------------------------------------------------------------ Db methods

    #[tokio::test]
    async fn create_signing_key_stores_row() {
        let db = test_db().await;
        let created = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        assert!(!created.is_active, "new key must not be active");

        let stored = db.get_signing_key(created.id).await.unwrap();
        assert_eq!(stored.id, created.id);
        assert_eq!(stored.algorithm, "RS256");
        assert!(!stored.is_active);
    }

    #[tokio::test]
    async fn activate_signing_key_marks_active() {
        let db = test_db().await;
        let first = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        let second = db.create_signing_key(&ENC_KEY_A).await.unwrap();

        db.activate_signing_key(second.id).await.unwrap();

        let first_now = db.get_signing_key(first.id).await.unwrap();
        let second_now = db.get_signing_key(second.id).await.unwrap();
        assert!(!first_now.is_active, "first key must be inactive");
        assert!(second_now.is_active, "second key must be active");
    }

    #[tokio::test]
    async fn activate_nonexistent_returns_not_found() {
        let db = test_db().await;
        let outcome = db.activate_signing_key(SigningKeyId::new()).await;
        assert!(matches!(outcome, Err(AuthError::NotFound)));
    }

    #[tokio::test]
    async fn rotate_signing_key_single_active() {
        let db = test_db().await;
        let original = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        db.activate_signing_key(original.id).await.unwrap();

        let rotated = db.rotate_signing_key(&ENC_KEY_A).await.unwrap();

        let active = db.get_active_signing_key().await.unwrap();
        assert_eq!(active.id, rotated.id, "new key must be the active one");

        let retired = db.get_signing_key(original.id).await.unwrap();
        assert!(!retired.is_active, "old key must be inactive after rotation");
    }

    #[tokio::test]
    async fn get_all_signing_keys_returns_all() {
        let db = test_db().await;
        let k1 = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        let k2 = db.create_signing_key(&ENC_KEY_A).await.unwrap();

        let ids: Vec<SigningKeyId> = db
            .get_all_signing_keys()
            .await
            .unwrap()
            .iter()
            .map(|k| k.id)
            .collect();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&k1.id));
        assert!(ids.contains(&k2.id));
    }

    #[tokio::test]
    async fn get_active_signing_key_no_active_returns_not_found() {
        let db = test_db().await;

        // Empty table: nothing can be active.
        let result = db.get_active_signing_key().await;
        assert!(
            matches!(result, Err(AuthError::NotFound)),
            "expected NotFound, got: {result:?}"
        );

        // A created-but-never-activated key must not count as active either.
        db.create_signing_key(&ENC_KEY_A).await.unwrap();
        let result = db.get_active_signing_key().await;
        assert!(
            matches!(result, Err(AuthError::NotFound)),
            "inactive key must not be returned as active"
        );
    }

    #[tokio::test]
    async fn create_and_decrypt_round_trip_through_db() {
        // Covers the full create→store→fetch→decrypt path, including the BLOB round-trip
        // through SQLite, which `decrypt_round_trip` never touches.
        let db = test_db().await;
        let created = db.create_signing_key(&ENC_KEY_A).await.unwrap();

        let stored = db.get_signing_key(created.id).await.unwrap();
        let decrypted_pem = decrypt_private_key(&stored, &ENC_KEY_A).unwrap();

        // The decrypted bytes must be a parseable RSA private key PEM...
        let private_key = RsaPrivateKey::from_pkcs8_pem(&decrypted_pem)
            .expect("decrypted PEM from DB must parse as RsaPrivateKey");

        // ...whose public half matches the stored public key.
        assert_eq!(
            public_pem_of(&private_key),
            stored.public_key_pem,
            "public key derived from decrypted private key must match stored public key"
        );
    }

    // -------------------------------------------------------------- build_jwks

    #[test]
    fn build_jwks_output_format() {
        let id = SigningKeyId::new();
        let jwks = build_jwks(&[public_only_key(id)]).unwrap();

        assert_eq!(jwks.keys.len(), 1);
        let entry = &jwks.keys[0];
        assert_eq!(entry.kty, "RSA");
        assert_eq!(entry.alg, "RS256");
        assert_eq!(entry.use_, "sig");
        assert!(!entry.n.is_empty(), "modulus must be non-empty");
        assert!(!entry.e.is_empty(), "exponent must be non-empty");
        assert_eq!(entry.kid, id.to_string());
    }

    #[test]
    fn build_jwks_empty() {
        let jwks = build_jwks(&[]).unwrap();
        assert!(jwks.keys.is_empty(), "empty input yields empty JWKS");
    }

    #[test]
    fn build_jwks_use_field_serializes_correctly() {
        // The #[serde(rename = "use")] on JwkEntry is load-bearing for OIDC relying parties:
        // the JSON must carry "use":"sig", never the Rust field name "use_".
        let jwks = build_jwks(&[public_only_key(SigningKeyId::new())]).unwrap();
        let json = serde_json::to_string(&jwks).unwrap();
        assert!(
            json.contains(r#""use":"sig"#),
            "JWKS JSON must contain \"use\":\"sig\", got: {json}"
        );
        assert!(
            !json.contains("use_"),
            "JWKS JSON must not contain Rust field name 'use_', got: {json}"
        );
    }

    // --------------------------------------------------------- build_discovery

    #[test]
    fn build_discovery_fields() {
        let issuer = "https://auth.example.com";
        let doc = build_discovery(issuer);

        assert_eq!(doc.issuer, issuer);
        assert_eq!(
            doc.authorization_endpoint,
            "https://auth.example.com/oauth/authorize"
        );
        assert_eq!(doc.token_endpoint, "https://auth.example.com/oauth/token");
        assert_eq!(
            doc.userinfo_endpoint,
            "https://auth.example.com/oauth/userinfo"
        );
        assert_eq!(
            doc.jwks_uri,
            "https://auth.example.com/.well-known/jwks.json"
        );

        let lists = [
            ("scopes_supported", &doc.scopes_supported),
            ("response_types_supported", &doc.response_types_supported),
            ("grant_types_supported", &doc.grant_types_supported),
            ("subject_types_supported", &doc.subject_types_supported),
            (
                "id_token_signing_alg_values_supported",
                &doc.id_token_signing_alg_values_supported,
            ),
            (
                "token_endpoint_auth_methods_supported",
                &doc.token_endpoint_auth_methods_supported,
            ),
            (
                "code_challenge_methods_supported",
                &doc.code_challenge_methods_supported,
            ),
        ];
        for (name, values) in lists {
            assert!(!values.is_empty(), "{name} must not be empty");
        }
    }
}