1use aes_gcm::aead::Aead;
4use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
5use base64ct::{Base64UrlUnpadded, Encoding};
6use chrono::{DateTime, Utc};
7use rsa::pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey};
8use rsa::rand_core::{OsRng, RngCore};
9use rsa::traits::PublicKeyParts;
10use rsa::{RsaPrivateKey, RsaPublicKey};
11use serde::Serialize;
12
13use crate::db::Db;
14use crate::error::AuthError;
15use crate::types::SigningKeyId;
16
/// One row of the `allowthem_signing_keys` table.
///
/// The private key is stored only in encrypted form (see [`decrypt_private_key`]); the
/// public key is kept as plain PEM so it can be published (e.g. via [`build_jwks`])
/// without access to the encryption key.
///
/// `sqlx::FromRow` maps columns to fields by name, so the SELECT queries in this module
/// list exactly these column names.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct SigningKey {
    /// Primary key; `build_jwks` also uses it as the JWK `kid`.
    pub id: SigningKeyId,
    /// AES-256-GCM output of encrypting the PKCS#8 PEM private key under the
    /// caller-supplied 32-byte encryption key.
    pub private_key_enc: Vec<u8>,
    /// Nonce used when producing `private_key_enc`; keys created in this module use a
    /// random 12-byte nonce.
    pub private_key_nonce: Vec<u8>,
    /// PEM-encoded RSA public key (written via `pkcs8::EncodePublicKey`).
    pub public_key_pem: String,
    /// Signing algorithm name; this module only ever writes `"RS256"`.
    pub algorithm: String,
    /// Whether this is the current signing key. The activate/rotate methods clear the
    /// flag on every other row before setting it on one.
    pub is_active: bool,
    /// Creation timestamp of the key.
    pub created_at: DateTime<Utc>,
}
31
32pub fn decrypt_private_key(
37 key: &SigningKey,
38 encryption_key: &[u8; 32],
39) -> Result<String, AuthError> {
40 let cipher = Aes256Gcm::new(encryption_key.into());
41 let nonce = Nonce::from_slice(&key.private_key_nonce);
42 let plaintext = cipher
43 .decrypt(nonce, key.private_key_enc.as_slice())
44 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
45 String::from_utf8(plaintext).map_err(|e| AuthError::SigningKey(e.to_string()))
46}
47
48impl Db {
49 pub async fn create_signing_key(
53 &self,
54 encryption_key: &[u8; 32],
55 ) -> Result<SigningKey, AuthError> {
56 let private_key = RsaPrivateKey::new(&mut OsRng, 2048)
57 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
58
59 let private_pem = private_key
60 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
61 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
62 let pem_bytes = private_pem.as_bytes();
63
64 let mut nonce_bytes = [0u8; 12];
65 OsRng.fill_bytes(&mut nonce_bytes);
66 let nonce = Nonce::from_slice(&nonce_bytes);
67
68 let cipher = Aes256Gcm::new(encryption_key.into());
69 let private_key_enc = cipher
70 .encrypt(nonce, pem_bytes)
71 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
72
73 let public_key_pem = RsaPublicKey::from(&private_key)
74 .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
75 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
76
77 let id = SigningKeyId::new();
78
79 sqlx::query(
80 "INSERT INTO allowthem_signing_keys \
81 (id, private_key_enc, private_key_nonce, public_key_pem, algorithm, is_active) \
82 VALUES (?, ?, ?, ?, 'RS256', 0)",
83 )
84 .bind(id)
85 .bind(&private_key_enc)
86 .bind(nonce_bytes.as_slice())
87 .bind(&public_key_pem)
88 .execute(self.pool())
89 .await?;
90
91 let key = SigningKey {
92 id,
93 private_key_enc,
94 private_key_nonce: nonce_bytes.to_vec(),
95 public_key_pem,
96 algorithm: "RS256".to_string(),
97 is_active: false,
98 created_at: Utc::now(),
99 };
100
101 Ok(key)
102 }
103
104 pub async fn activate_signing_key(&self, key_id: SigningKeyId) -> Result<(), AuthError> {
108 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
109
110 sqlx::query("UPDATE allowthem_signing_keys SET is_active = 0 WHERE is_active = 1")
111 .execute(&mut *tx)
112 .await
113 .map_err(AuthError::Database)?;
114
115 let result = sqlx::query("UPDATE allowthem_signing_keys SET is_active = 1 WHERE id = ?")
116 .bind(key_id)
117 .execute(&mut *tx)
118 .await
119 .map_err(AuthError::Database)?;
120
121 if result.rows_affected() == 0 {
122 tx.rollback().await.map_err(AuthError::Database)?;
123 return Err(AuthError::NotFound);
124 }
125
126 tx.commit().await.map_err(AuthError::Database)?;
127 Ok(())
128 }
129
130 pub async fn rotate_signing_key(
134 &self,
135 encryption_key: &[u8; 32],
136 ) -> Result<SigningKey, AuthError> {
137 let private_key = RsaPrivateKey::new(&mut OsRng, 2048)
138 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
139
140 let private_pem = private_key
141 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
142 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
143 let pem_bytes = private_pem.as_bytes();
144
145 let mut nonce_bytes = [0u8; 12];
146 OsRng.fill_bytes(&mut nonce_bytes);
147 let nonce = Nonce::from_slice(&nonce_bytes);
148
149 let cipher = Aes256Gcm::new(encryption_key.into());
150 let private_key_enc = cipher
151 .encrypt(nonce, pem_bytes)
152 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
153
154 let public_key_pem = RsaPublicKey::from(&private_key)
155 .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
156 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
157
158 let id = SigningKeyId::new();
159
160 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
161
162 sqlx::query(
163 "INSERT INTO allowthem_signing_keys \
164 (id, private_key_enc, private_key_nonce, public_key_pem, algorithm, is_active) \
165 VALUES (?, ?, ?, ?, 'RS256', 0)",
166 )
167 .bind(id)
168 .bind(&private_key_enc)
169 .bind(nonce_bytes.as_slice())
170 .bind(&public_key_pem)
171 .execute(&mut *tx)
172 .await
173 .map_err(AuthError::Database)?;
174
175 sqlx::query("UPDATE allowthem_signing_keys SET is_active = 0 WHERE is_active = 1")
176 .execute(&mut *tx)
177 .await
178 .map_err(AuthError::Database)?;
179
180 sqlx::query("UPDATE allowthem_signing_keys SET is_active = 1 WHERE id = ?")
181 .bind(id)
182 .execute(&mut *tx)
183 .await
184 .map_err(AuthError::Database)?;
185
186 tx.commit().await.map_err(AuthError::Database)?;
187
188 let key = SigningKey {
189 id,
190 private_key_enc,
191 private_key_nonce: nonce_bytes.to_vec(),
192 public_key_pem,
193 algorithm: "RS256".to_string(),
194 is_active: true,
195 created_at: Utc::now(),
196 };
197
198 Ok(key)
199 }
200
201 pub async fn get_active_signing_key(&self) -> Result<SigningKey, AuthError> {
205 sqlx::query_as(
206 "SELECT id, private_key_enc, private_key_nonce, public_key_pem, \
207 algorithm, is_active, created_at \
208 FROM allowthem_signing_keys WHERE is_active = 1 LIMIT 1",
209 )
210 .fetch_optional(self.pool())
211 .await?
212 .ok_or(AuthError::NotFound)
213 }
214
215 pub async fn get_all_signing_keys(&self) -> Result<Vec<SigningKey>, AuthError> {
219 Ok(sqlx::query_as(
220 "SELECT id, private_key_enc, private_key_nonce, public_key_pem, \
221 algorithm, is_active, created_at \
222 FROM allowthem_signing_keys ORDER BY created_at DESC",
223 )
224 .fetch_all(self.pool())
225 .await?)
226 }
227
228 pub async fn get_signing_key(&self, id: SigningKeyId) -> Result<SigningKey, AuthError> {
232 sqlx::query_as(
233 "SELECT id, private_key_enc, private_key_nonce, public_key_pem, \
234 algorithm, is_active, created_at \
235 FROM allowthem_signing_keys WHERE id = ?",
236 )
237 .bind(id)
238 .fetch_optional(self.pool())
239 .await?
240 .ok_or(AuthError::NotFound)
241 }
242}
243
/// A single RSA public key in JSON Web Key form, as emitted by `build_jwks`.
#[derive(Debug, Serialize)]
pub struct JwkEntry {
    /// Key type; `build_jwks` always sets `"RSA"`.
    pub kty: &'static str,
    /// Intended key use; `build_jwks` always sets `"sig"` (signature verification).
    /// Serialized as `use`, which is a Rust keyword and so cannot be the field name.
    #[serde(rename = "use")]
    pub use_: &'static str,
    /// Signature algorithm; `build_jwks` always sets `"RS256"`.
    pub alg: &'static str,
    /// Key id — the stored signing key's id, which lets verifiers pick the right key.
    pub kid: String,
    /// RSA modulus as unpadded base64url of its big-endian bytes.
    pub n: String,
    /// RSA public exponent as unpadded base64url of its big-endian bytes.
    pub e: String,
}
255
/// A JSON Web Key Set: serializes as `{"keys": [...]}`.
///
/// Presumably the payload served at the discovery document's `jwks_uri` — confirm
/// against the HTTP handler.
#[derive(Debug, Serialize)]
pub struct JwkSet {
    /// One entry per signing key passed to `build_jwks`, in input order.
    pub keys: Vec<JwkEntry>,
}
261
262pub fn build_jwks(keys: &[SigningKey]) -> Result<JwkSet, AuthError> {
267 let mut entries = Vec::with_capacity(keys.len());
268 for key in keys {
269 let pub_key = RsaPublicKey::from_public_key_pem(&key.public_key_pem)
270 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
271 let n = Base64UrlUnpadded::encode_string(&pub_key.n().to_bytes_be());
272 let e = Base64UrlUnpadded::encode_string(&pub_key.e().to_bytes_be());
273 entries.push(JwkEntry {
274 kty: "RSA",
275 use_: "sig",
276 alg: "RS256",
277 kid: key.id.to_string(),
278 n,
279 e,
280 });
281 }
282 Ok(JwkSet { keys: entries })
283}
284
/// OpenID Connect provider metadata (discovery document), as built by `build_discovery`.
///
/// Field names match the OpenID Connect Discovery 1.0 metadata keys, so serde's default
/// field naming produces the wire format directly. NOTE(review): presumably served at
/// `/.well-known/openid-configuration` — confirm against the router.
#[derive(Debug, Serialize)]
pub struct OidcDiscovery {
    /// Issuer identifier.
    pub issuer: String,
    /// URL of the OAuth 2.0 authorization endpoint.
    pub authorization_endpoint: String,
    /// URL of the OAuth 2.0 token endpoint.
    pub token_endpoint: String,
    /// URL of the OIDC UserInfo endpoint.
    pub userinfo_endpoint: String,
    /// URL of the JSON Web Key Set used to verify issued tokens.
    pub jwks_uri: String,
    /// OAuth scopes advertised as supported.
    pub scopes_supported: Vec<&'static str>,
    /// Supported `response_type` values.
    pub response_types_supported: Vec<&'static str>,
    /// Supported OAuth grant types.
    pub grant_types_supported: Vec<&'static str>,
    /// Supported subject identifier types.
    pub subject_types_supported: Vec<&'static str>,
    /// JWS algorithms used to sign ID tokens.
    pub id_token_signing_alg_values_supported: Vec<&'static str>,
    /// Client authentication methods accepted at the token endpoint.
    pub token_endpoint_auth_methods_supported: Vec<&'static str>,
    /// Supported PKCE code challenge methods.
    pub code_challenge_methods_supported: Vec<&'static str>,
}
303
304pub fn build_discovery(issuer: &str) -> OidcDiscovery {
306 OidcDiscovery {
307 authorization_endpoint: format!("{issuer}/oauth/authorize"),
308 token_endpoint: format!("{issuer}/oauth/token"),
309 userinfo_endpoint: format!("{issuer}/oauth/userinfo"),
310 jwks_uri: format!("{issuer}/.well-known/jwks.json"),
311 issuer: issuer.to_string(),
312 scopes_supported: vec!["openid", "profile", "email"],
313 response_types_supported: vec!["code"],
314 grant_types_supported: vec!["authorization_code", "refresh_token"],
315 subject_types_supported: vec!["public"],
316 id_token_signing_alg_values_supported: vec!["RS256"],
317 token_endpoint_auth_methods_supported: vec!["client_secret_post", "client_secret_basic"],
318 code_challenge_methods_supported: vec!["S256"],
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::Db;
    use rsa::pkcs8::DecodePrivateKey;
    use sqlx::SqlitePool;
    use sqlx::sqlite::SqliteConnectOptions;
    use std::str::FromStr;

    const ENC_KEY_A: [u8; 32] = [0x42; 32];
    const ENC_KEY_B: [u8; 32] = [0x99; 32];

    /// Opens a fresh in-memory SQLite database (foreign keys on) wrapped in `Db`.
    async fn test_db() -> Db {
        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
            .unwrap()
            .pragma("foreign_keys", "ON");
        let pool = SqlitePool::connect_with(opts).await.unwrap();
        Db::new(pool).await.unwrap()
    }

    /// Generates a new 2048-bit RSA private key.
    fn new_rsa_key() -> RsaPrivateKey {
        RsaPrivateKey::new(&mut OsRng, 2048).unwrap()
    }

    /// PEM-encodes the public half of `private_key`.
    fn public_pem(private_key: &RsaPrivateKey) -> String {
        RsaPublicKey::from(private_key)
            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap()
    }

    /// Builds an inactive `SigningKey` whose PKCS#8 PEM private key is encrypted under
    /// `enc_key` with a random nonce, mirroring what `Db::create_signing_key` stores.
    /// Also returns the plaintext PEM for comparison.
    fn encrypted_fixture(enc_key: &[u8; 32]) -> (SigningKey, String) {
        let private_key = new_rsa_key();
        let pem = private_key
            .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap();

        let mut nonce_bytes = [0u8; 12];
        OsRng.fill_bytes(&mut nonce_bytes);
        let cipher = Aes256Gcm::new(enc_key.into());
        let ciphertext = cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), pem.as_bytes())
            .unwrap();

        let key = SigningKey {
            id: SigningKeyId::new(),
            private_key_enc: ciphertext,
            private_key_nonce: nonce_bytes.to_vec(),
            public_key_pem: public_pem(&private_key),
            algorithm: "RS256".to_string(),
            is_active: false,
            created_at: Utc::now(),
        };
        (key, pem.as_str().to_owned())
    }

    /// Builds an active `SigningKey` carrying only a public key (empty ciphertext and
    /// nonce). This is enough for JWKS tests, which never touch the private half.
    fn public_only_fixture() -> SigningKey {
        SigningKey {
            id: SigningKeyId::new(),
            private_key_enc: vec![],
            private_key_nonce: vec![],
            public_key_pem: public_pem(&new_rsa_key()),
            algorithm: "RS256".to_string(),
            is_active: true,
            created_at: Utc::now(),
        }
    }

    #[test]
    fn decrypt_round_trip() {
        let (key, pem) = encrypted_fixture(&ENC_KEY_A);
        let decrypted = decrypt_private_key(&key, &ENC_KEY_A).unwrap();
        assert_eq!(decrypted, pem);
    }

    #[test]
    fn decrypt_wrong_key_fails() {
        let (key, _) = encrypted_fixture(&ENC_KEY_A);
        let result = decrypt_private_key(&key, &ENC_KEY_B);
        assert!(result.is_err(), "decryption with wrong key must fail");
    }

    #[tokio::test]
    async fn create_signing_key_stores_row() {
        let db = test_db().await;
        let key = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        assert!(!key.is_active, "new key must not be active");

        let fetched = db.get_signing_key(key.id).await.unwrap();
        assert_eq!(fetched.id, key.id);
        assert_eq!(fetched.algorithm, "RS256");
        assert!(!fetched.is_active);
    }

    #[tokio::test]
    async fn activate_signing_key_marks_active() {
        let db = test_db().await;
        let key1 = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        let key2 = db.create_signing_key(&ENC_KEY_A).await.unwrap();

        db.activate_signing_key(key2.id).await.unwrap();

        let fetched1 = db.get_signing_key(key1.id).await.unwrap();
        let fetched2 = db.get_signing_key(key2.id).await.unwrap();
        assert!(!fetched1.is_active, "first key must be inactive");
        assert!(fetched2.is_active, "second key must be active");
    }

    #[tokio::test]
    async fn activate_nonexistent_returns_not_found() {
        let db = test_db().await;
        let result = db.activate_signing_key(SigningKeyId::new()).await;
        assert!(matches!(result, Err(AuthError::NotFound)));
    }

    #[tokio::test]
    async fn rotate_signing_key_single_active() {
        let db = test_db().await;
        let key1 = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        db.activate_signing_key(key1.id).await.unwrap();

        let new_key = db.rotate_signing_key(&ENC_KEY_A).await.unwrap();
        assert!(new_key.is_active, "rotated key must be reported active");

        let active = db.get_active_signing_key().await.unwrap();
        assert_eq!(active.id, new_key.id, "new key must be the active one");

        let old = db.get_signing_key(key1.id).await.unwrap();
        assert!(!old.is_active, "old key must be inactive after rotation");

        let active_count = db
            .get_all_signing_keys()
            .await
            .unwrap()
            .iter()
            .filter(|k| k.is_active)
            .count();
        assert_eq!(active_count, 1, "exactly one key may be active after rotation");
    }

    #[tokio::test]
    async fn rotate_signing_key_without_prior_active_key() {
        let db = test_db().await;
        let key = db.rotate_signing_key(&ENC_KEY_A).await.unwrap();

        let active = db.get_active_signing_key().await.unwrap();
        assert_eq!(active.id, key.id, "rotation must activate the new key");
    }

    #[tokio::test]
    async fn get_all_signing_keys_returns_all() {
        let db = test_db().await;
        let k1 = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        let k2 = db.create_signing_key(&ENC_KEY_A).await.unwrap();

        let all = db.get_all_signing_keys().await.unwrap();
        assert_eq!(all.len(), 2);
        let ids: Vec<_> = all.iter().map(|k| k.id).collect();
        assert!(ids.contains(&k1.id));
        assert!(ids.contains(&k2.id));
    }

    #[test]
    fn build_jwks_output_format() {
        let key = public_only_fixture();
        let id = key.id;

        let jwks = build_jwks(&[key]).unwrap();
        assert_eq!(jwks.keys.len(), 1);
        let entry = &jwks.keys[0];
        assert_eq!(entry.kty, "RSA");
        assert_eq!(entry.alg, "RS256");
        assert_eq!(entry.use_, "sig");
        assert_eq!(entry.kid, id.to_string());
        // 65537 is big-endian [0x01, 0x00, 0x01], which base64url-encodes to "AQAB".
        assert_eq!(entry.e, "AQAB", "public exponent 65537 must encode as AQAB");
        let modulus = Base64UrlUnpadded::decode_vec(&entry.n).unwrap();
        assert_eq!(modulus.len(), 256, "2048-bit modulus must be 256 bytes, no padding");
    }

    #[test]
    fn build_jwks_empty() {
        let jwks = build_jwks(&[]).unwrap();
        assert!(jwks.keys.is_empty(), "empty input yields empty JWKS");
    }

    #[test]
    fn build_jwks_use_field_serializes_correctly() {
        let jwks = build_jwks(&[public_only_fixture()]).unwrap();
        let json = serde_json::to_string(&jwks).unwrap();
        assert!(
            json.contains(r#""use":"sig"#),
            "JWKS JSON must contain \"use\":\"sig\", got: {json}"
        );
        assert!(
            !json.contains("use_"),
            "JWKS JSON must not contain Rust field name 'use_', got: {json}"
        );
    }

    #[tokio::test]
    async fn get_active_signing_key_no_active_returns_not_found() {
        let db = test_db().await;
        let result = db.get_active_signing_key().await;
        assert!(
            matches!(result, Err(AuthError::NotFound)),
            "expected NotFound, got: {result:?}"
        );

        db.create_signing_key(&ENC_KEY_A).await.unwrap();
        let result = db.get_active_signing_key().await;
        assert!(
            matches!(result, Err(AuthError::NotFound)),
            "inactive key must not be returned as active"
        );
    }

    #[tokio::test]
    async fn create_and_decrypt_round_trip_through_db() {
        let db = test_db().await;
        let key = db.create_signing_key(&ENC_KEY_A).await.unwrap();

        let fetched = db.get_signing_key(key.id).await.unwrap();
        let decrypted_pem = decrypt_private_key(&fetched, &ENC_KEY_A).unwrap();

        let reparsed = RsaPrivateKey::from_pkcs8_pem(&decrypted_pem)
            .expect("decrypted PEM from DB must parse as RsaPrivateKey");

        assert_eq!(
            public_pem(&reparsed),
            fetched.public_key_pem,
            "public key derived from decrypted private key must match stored public key"
        );
    }

    #[test]
    fn build_discovery_fields() {
        let issuer = "https://auth.example.com";
        let doc = build_discovery(issuer);

        assert_eq!(doc.issuer, issuer);
        assert_eq!(
            doc.authorization_endpoint,
            "https://auth.example.com/oauth/authorize"
        );
        assert_eq!(doc.token_endpoint, "https://auth.example.com/oauth/token");
        assert_eq!(
            doc.userinfo_endpoint,
            "https://auth.example.com/oauth/userinfo"
        );
        assert_eq!(
            doc.jwks_uri,
            "https://auth.example.com/.well-known/jwks.json"
        );
        assert!(!doc.scopes_supported.is_empty());
        assert!(!doc.response_types_supported.is_empty());
        assert!(!doc.grant_types_supported.is_empty());
        assert!(!doc.subject_types_supported.is_empty());
        assert!(!doc.id_token_signing_alg_values_supported.is_empty());
        assert!(!doc.token_endpoint_auth_methods_supported.is_empty());
        assert!(!doc.code_challenge_methods_supported.is_empty());
    }
}