1use aes_gcm::aead::Aead;
4use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
5use base64ct::{Base64UrlUnpadded, Encoding};
6use chrono::{DateTime, Utc};
7use rsa::pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey};
8use rsa::rand_core::{OsRng, RngCore};
9use rsa::traits::PublicKeyParts;
10use rsa::{RsaPrivateKey, RsaPublicKey};
11use serde::Serialize;
12
13use crate::db::Db;
14use crate::error::AuthError;
15use crate::types::SigningKeyId;
16
/// A row of the `allowthem_signing_keys` table: an RSA key pair whose private
/// half is stored encrypted.
///
/// The private key is only recoverable through [`decrypt_private_key`] with the
/// 32-byte encryption key that was used when the row was created.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct SigningKey {
    /// Unique identifier of the key; also published as the JWK `kid`.
    pub id: SigningKeyId,
    /// AES-256-GCM ciphertext of the PKCS#8 PEM-encoded private key.
    pub private_key_enc: Vec<u8>,
    /// Nonce used to encrypt `private_key_enc` (12 bytes when written by this module).
    pub private_key_nonce: Vec<u8>,
    /// PEM-encoded public key matching the encrypted private key.
    pub public_key_pem: String,
    /// Signing algorithm name; rows written by this module always use `"RS256"`.
    pub algorithm: String,
    /// Whether this key is currently marked as the active signing key.
    pub is_active: bool,
    /// Creation timestamp of the key.
    pub created_at: DateTime<Utc>,
}
31
32pub fn decrypt_private_key(
37 key: &SigningKey,
38 encryption_key: &[u8; 32],
39) -> Result<String, AuthError> {
40 let cipher = Aes256Gcm::new(encryption_key.into());
41 let nonce = Nonce::from_slice(&key.private_key_nonce);
42 let plaintext = cipher
43 .decrypt(nonce, key.private_key_enc.as_slice())
44 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
45 String::from_utf8(plaintext).map_err(|e| AuthError::SigningKey(e.to_string()))
46}
47
48impl Db {
49 pub async fn create_signing_key(
53 &self,
54 encryption_key: &[u8; 32],
55 ) -> Result<SigningKey, AuthError> {
56 let private_key = RsaPrivateKey::new(&mut OsRng, 2048)
57 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
58
59 let private_pem = private_key
60 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
61 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
62 let pem_bytes = private_pem.as_bytes();
63
64 let mut nonce_bytes = [0u8; 12];
65 OsRng.fill_bytes(&mut nonce_bytes);
66 let nonce = Nonce::from_slice(&nonce_bytes);
67
68 let cipher = Aes256Gcm::new(encryption_key.into());
69 let private_key_enc = cipher
70 .encrypt(nonce, pem_bytes)
71 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
72
73 let public_key_pem = RsaPublicKey::from(&private_key)
74 .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
75 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
76
77 let id = SigningKeyId::new();
78
79 sqlx::query(
80 "INSERT INTO allowthem_signing_keys \
81 (id, private_key_enc, private_key_nonce, public_key_pem, algorithm, is_active) \
82 VALUES (?, ?, ?, ?, 'RS256', 0)",
83 )
84 .bind(id)
85 .bind(&private_key_enc)
86 .bind(nonce_bytes.as_slice())
87 .bind(&public_key_pem)
88 .execute(self.pool())
89 .await?;
90
91 let key = SigningKey {
92 id,
93 private_key_enc,
94 private_key_nonce: nonce_bytes.to_vec(),
95 public_key_pem,
96 algorithm: "RS256".to_string(),
97 is_active: false,
98 created_at: Utc::now(),
99 };
100
101 Ok(key)
102 }
103
104 pub async fn activate_signing_key(&self, key_id: SigningKeyId) -> Result<(), AuthError> {
108 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
109
110 sqlx::query("UPDATE allowthem_signing_keys SET is_active = 0 WHERE is_active = 1")
111 .execute(&mut *tx)
112 .await
113 .map_err(AuthError::Database)?;
114
115 let result = sqlx::query("UPDATE allowthem_signing_keys SET is_active = 1 WHERE id = ?")
116 .bind(key_id)
117 .execute(&mut *tx)
118 .await
119 .map_err(AuthError::Database)?;
120
121 if result.rows_affected() == 0 {
122 tx.rollback().await.map_err(AuthError::Database)?;
123 return Err(AuthError::NotFound);
124 }
125
126 tx.commit().await.map_err(AuthError::Database)?;
127 Ok(())
128 }
129
130 pub async fn rotate_signing_key(
134 &self,
135 encryption_key: &[u8; 32],
136 ) -> Result<SigningKey, AuthError> {
137 let private_key = RsaPrivateKey::new(&mut OsRng, 2048)
138 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
139
140 let private_pem = private_key
141 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
142 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
143 let pem_bytes = private_pem.as_bytes();
144
145 let mut nonce_bytes = [0u8; 12];
146 OsRng.fill_bytes(&mut nonce_bytes);
147 let nonce = Nonce::from_slice(&nonce_bytes);
148
149 let cipher = Aes256Gcm::new(encryption_key.into());
150 let private_key_enc = cipher
151 .encrypt(nonce, pem_bytes)
152 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
153
154 let public_key_pem = RsaPublicKey::from(&private_key)
155 .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
156 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
157
158 let id = SigningKeyId::new();
159
160 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
161
162 sqlx::query(
163 "INSERT INTO allowthem_signing_keys \
164 (id, private_key_enc, private_key_nonce, public_key_pem, algorithm, is_active) \
165 VALUES (?, ?, ?, ?, 'RS256', 0)",
166 )
167 .bind(id)
168 .bind(&private_key_enc)
169 .bind(nonce_bytes.as_slice())
170 .bind(&public_key_pem)
171 .execute(&mut *tx)
172 .await
173 .map_err(AuthError::Database)?;
174
175 sqlx::query("UPDATE allowthem_signing_keys SET is_active = 0 WHERE is_active = 1")
176 .execute(&mut *tx)
177 .await
178 .map_err(AuthError::Database)?;
179
180 sqlx::query("UPDATE allowthem_signing_keys SET is_active = 1 WHERE id = ?")
181 .bind(id)
182 .execute(&mut *tx)
183 .await
184 .map_err(AuthError::Database)?;
185
186 tx.commit().await.map_err(AuthError::Database)?;
187
188 let key = SigningKey {
189 id,
190 private_key_enc,
191 private_key_nonce: nonce_bytes.to_vec(),
192 public_key_pem,
193 algorithm: "RS256".to_string(),
194 is_active: true,
195 created_at: Utc::now(),
196 };
197
198 Ok(key)
199 }
200
201 pub async fn get_active_signing_key(&self) -> Result<SigningKey, AuthError> {
205 sqlx::query_as(
206 "SELECT id, private_key_enc, private_key_nonce, public_key_pem, \
207 algorithm, is_active, created_at \
208 FROM allowthem_signing_keys WHERE is_active = 1 LIMIT 1",
209 )
210 .fetch_optional(self.pool())
211 .await?
212 .ok_or(AuthError::NotFound)
213 }
214
215 pub async fn get_all_signing_keys(&self) -> Result<Vec<SigningKey>, AuthError> {
219 Ok(sqlx::query_as(
220 "SELECT id, private_key_enc, private_key_nonce, public_key_pem, \
221 algorithm, is_active, created_at \
222 FROM allowthem_signing_keys ORDER BY created_at DESC",
223 )
224 .fetch_all(self.pool())
225 .await?)
226 }
227
228 pub async fn get_signing_key(&self, id: SigningKeyId) -> Result<SigningKey, AuthError> {
232 sqlx::query_as(
233 "SELECT id, private_key_enc, private_key_nonce, public_key_pem, \
234 algorithm, is_active, created_at \
235 FROM allowthem_signing_keys WHERE id = ?",
236 )
237 .bind(id)
238 .fetch_optional(self.pool())
239 .await?
240 .ok_or(AuthError::NotFound)
241 }
242}
243
/// One RSA public key in JSON Web Key form, as serialized into a JWKS document.
#[derive(Debug, Serialize)]
pub struct JwkEntry {
    /// Key type (`"RSA"` for entries built by `build_jwks`).
    pub kty: &'static str,
    /// Intended key use; serialized as `use` because `use` is a Rust keyword.
    #[serde(rename = "use")]
    pub use_: &'static str,
    /// Algorithm the key is used with (`"RS256"` for entries built by `build_jwks`).
    pub alg: &'static str,
    /// Key identifier; the signing key's id rendered as a string.
    pub kid: String,
    /// RSA modulus: big-endian bytes, base64url-encoded without padding.
    pub n: String,
    /// RSA public exponent: big-endian bytes, base64url-encoded without padding.
    pub e: String,
}
255
/// A JSON Web Key Set: the serializable wrapper around a list of [`JwkEntry`] values.
#[derive(Debug, Serialize)]
pub struct JwkSet {
    /// The keys in the set, in the order they were supplied to `build_jwks`.
    pub keys: Vec<JwkEntry>,
}
261
262pub fn build_jwks(keys: &[SigningKey]) -> Result<JwkSet, AuthError> {
267 let mut entries = Vec::with_capacity(keys.len());
268 for key in keys {
269 let pub_key = RsaPublicKey::from_public_key_pem(&key.public_key_pem)
270 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
271 let n = Base64UrlUnpadded::encode_string(&pub_key.n().to_bytes_be());
272 let e = Base64UrlUnpadded::encode_string(&pub_key.e().to_bytes_be());
273 entries.push(JwkEntry {
274 kty: "RSA",
275 use_: "sig",
276 alg: "RS256",
277 kid: key.id.to_string(),
278 n,
279 e,
280 });
281 }
282 Ok(JwkSet { keys: entries })
283}
284
/// Serializable OpenID Connect discovery (provider metadata) document.
///
/// Field names match the JSON keys they serialize to.
#[derive(Debug, Serialize)]
pub struct OidcDiscovery {
    /// Issuer identifier, exactly as supplied to `build_discovery`.
    pub issuer: String,
    /// URL of the authorization endpoint.
    pub authorization_endpoint: String,
    /// URL of the token endpoint.
    pub token_endpoint: String,
    /// URL of the userinfo endpoint.
    pub userinfo_endpoint: String,
    /// URL of the JWKS document.
    pub jwks_uri: String,
    /// Scopes advertised as supported.
    pub scopes_supported: Vec<&'static str>,
    /// Response types advertised as supported.
    pub response_types_supported: Vec<&'static str>,
    /// Grant types advertised as supported.
    pub grant_types_supported: Vec<&'static str>,
    /// Subject identifier types advertised as supported.
    pub subject_types_supported: Vec<&'static str>,
    /// Signing algorithms advertised for ID tokens.
    pub id_token_signing_alg_values_supported: Vec<&'static str>,
    /// Client authentication methods advertised for the token endpoint.
    pub token_endpoint_auth_methods_supported: Vec<&'static str>,
    /// PKCE code challenge methods advertised as supported.
    pub code_challenge_methods_supported: Vec<&'static str>,
}
303
304pub fn build_discovery(issuer: &str) -> OidcDiscovery {
306 OidcDiscovery {
307 authorization_endpoint: format!("{issuer}/oauth/authorize"),
308 token_endpoint: format!("{issuer}/oauth/token"),
309 userinfo_endpoint: format!("{issuer}/oauth/userinfo"),
310 jwks_uri: format!("{issuer}/.well-known/jwks.json"),
311 issuer: issuer.to_string(),
312 scopes_supported: vec!["openid", "profile", "email", "offline_access"],
313 response_types_supported: vec!["code"],
314 grant_types_supported: vec!["authorization_code", "refresh_token"],
315 subject_types_supported: vec!["public"],
316 id_token_signing_alg_values_supported: vec!["RS256"],
317 token_endpoint_auth_methods_supported: vec![
318 "client_secret_post",
319 "client_secret_basic",
320 "none",
321 ],
322 code_challenge_methods_supported: vec!["S256"],
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::Db;
    use rsa::pkcs8::DecodePrivateKey;
    use sqlx::SqlitePool;
    use sqlx::sqlite::SqliteConnectOptions;
    use std::str::FromStr;

    const ENC_KEY_A: [u8; 32] = [0x42; 32];
    const ENC_KEY_B: [u8; 32] = [0x99; 32];

    /// Opens a fresh in-memory SQLite database wrapped in `Db`.
    async fn test_db() -> Db {
        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
            .unwrap()
            .pragma("foreign_keys", "ON");
        let pool = SqlitePool::connect_with(opts).await.unwrap();
        Db::new(pool).await.unwrap()
    }

    /// Generates a 2048-bit RSA key and returns it with its public key PEM.
    fn rsa_fixture() -> (RsaPrivateKey, String) {
        let private_key = RsaPrivateKey::new(&mut OsRng, 2048).unwrap();
        let public_key_pem = RsaPublicKey::from(&private_key)
            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap();
        (private_key, public_key_pem)
    }

    /// Builds an in-memory `SigningKey` whose private key is encrypted under
    /// `ENC_KEY_A`, and returns it together with the plaintext private PEM.
    fn encrypted_key_fixture() -> (SigningKey, String) {
        let (private_key, public_key_pem) = rsa_fixture();
        let pem = private_key
            .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap();

        let mut nonce_bytes = [0u8; 12];
        OsRng.fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::from_slice(&nonce_bytes);
        let cipher = Aes256Gcm::new((&ENC_KEY_A).into());
        let ciphertext = cipher.encrypt(nonce, pem.as_bytes()).unwrap();

        let key = SigningKey {
            id: SigningKeyId::new(),
            private_key_enc: ciphertext,
            private_key_nonce: nonce_bytes.to_vec(),
            public_key_pem,
            algorithm: "RS256".to_string(),
            is_active: false,
            created_at: Utc::now(),
        };

        (key, pem.as_str().to_owned())
    }

    /// Builds an active `SigningKey` that carries only a public key; the
    /// encrypted private key and nonce are left empty.
    fn public_only_key_fixture(id: SigningKeyId) -> SigningKey {
        let (_, public_key_pem) = rsa_fixture();
        SigningKey {
            id,
            private_key_enc: vec![],
            private_key_nonce: vec![],
            public_key_pem,
            algorithm: "RS256".to_string(),
            is_active: true,
            created_at: Utc::now(),
        }
    }

    #[test]
    fn decrypt_round_trip() {
        let (key, pem) = encrypted_key_fixture();

        let decrypted = decrypt_private_key(&key, &ENC_KEY_A).unwrap();
        assert_eq!(decrypted.as_bytes(), pem.as_bytes());
    }

    #[test]
    fn decrypt_wrong_key_fails() {
        let (key, _pem) = encrypted_key_fixture();

        let result = decrypt_private_key(&key, &ENC_KEY_B);
        assert!(result.is_err(), "decryption with wrong key must fail");
    }

    #[tokio::test]
    async fn create_signing_key_stores_row() {
        let db = test_db().await;
        let key = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        assert!(!key.is_active, "new key must not be active");

        let fetched = db.get_signing_key(key.id).await.unwrap();
        assert_eq!(fetched.id, key.id);
        assert_eq!(fetched.algorithm, "RS256");
        assert!(!fetched.is_active);
    }

    #[tokio::test]
    async fn activate_signing_key_marks_active() {
        let db = test_db().await;
        let key1 = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        let key2 = db.create_signing_key(&ENC_KEY_A).await.unwrap();

        db.activate_signing_key(key2.id).await.unwrap();

        let fetched1 = db.get_signing_key(key1.id).await.unwrap();
        let fetched2 = db.get_signing_key(key2.id).await.unwrap();
        assert!(!fetched1.is_active, "first key must be inactive");
        assert!(fetched2.is_active, "second key must be active");
    }

    #[tokio::test]
    async fn activate_nonexistent_returns_not_found() {
        let db = test_db().await;
        let fake_id = SigningKeyId::new();
        let result = db.activate_signing_key(fake_id).await;
        assert!(matches!(result, Err(AuthError::NotFound)));
    }

    #[tokio::test]
    async fn rotate_signing_key_single_active() {
        let db = test_db().await;
        let key1 = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        db.activate_signing_key(key1.id).await.unwrap();

        let new_key = db.rotate_signing_key(&ENC_KEY_A).await.unwrap();

        let active = db.get_active_signing_key().await.unwrap();
        assert_eq!(active.id, new_key.id, "new key must be the active one");

        let old = db.get_signing_key(key1.id).await.unwrap();
        assert!(!old.is_active, "old key must be inactive after rotation");
    }

    #[tokio::test]
    async fn get_all_signing_keys_returns_all() {
        let db = test_db().await;
        let k1 = db.create_signing_key(&ENC_KEY_A).await.unwrap();
        let k2 = db.create_signing_key(&ENC_KEY_A).await.unwrap();

        let all = db.get_all_signing_keys().await.unwrap();
        assert_eq!(all.len(), 2);
        let ids: Vec<_> = all.iter().map(|k| k.id).collect();
        assert!(ids.contains(&k1.id));
        assert!(ids.contains(&k2.id));
    }

    #[test]
    fn build_jwks_output_format() {
        let id = SigningKeyId::new();
        let key = public_only_key_fixture(id);

        let jwks = build_jwks(&[key]).unwrap();
        assert_eq!(jwks.keys.len(), 1);
        let entry = &jwks.keys[0];
        assert_eq!(entry.kty, "RSA");
        assert_eq!(entry.alg, "RS256");
        assert_eq!(entry.use_, "sig");
        assert!(!entry.n.is_empty(), "modulus must be non-empty");
        assert!(!entry.e.is_empty(), "exponent must be non-empty");
        assert_eq!(entry.kid, id.to_string());
    }

    #[test]
    fn build_jwks_empty() {
        let jwks = build_jwks(&[]).unwrap();
        assert!(jwks.keys.is_empty(), "empty input yields empty JWKS");
    }

    #[test]
    fn build_jwks_use_field_serializes_correctly() {
        let key = public_only_key_fixture(SigningKeyId::new());

        let jwks = build_jwks(&[key]).unwrap();
        let json = serde_json::to_string(&jwks).unwrap();
        // The Rust field is `use_`; the serialized key must be the plain `use`.
        assert!(
            json.contains(r#""use":"sig"#),
            "JWKS JSON must contain \"use\":\"sig\", got: {json}"
        );
        assert!(
            !json.contains("use_"),
            "JWKS JSON must not contain Rust field name 'use_', got: {json}"
        );
    }

    #[tokio::test]
    async fn get_active_signing_key_no_active_returns_not_found() {
        let db = test_db().await;

        // Empty table: nothing can be active.
        let result = db.get_active_signing_key().await;
        assert!(
            matches!(result, Err(AuthError::NotFound)),
            "expected NotFound, got: {result:?}"
        );

        // A newly created key is inactive, so the lookup must still fail.
        db.create_signing_key(&ENC_KEY_A).await.unwrap();
        let result = db.get_active_signing_key().await;
        assert!(
            matches!(result, Err(AuthError::NotFound)),
            "inactive key must not be returned as active"
        );
    }

    #[tokio::test]
    async fn create_and_decrypt_round_trip_through_db() {
        let db = test_db().await;
        let key = db.create_signing_key(&ENC_KEY_A).await.unwrap();

        let fetched = db.get_signing_key(key.id).await.unwrap();
        let decrypted_pem = decrypt_private_key(&fetched, &ENC_KEY_A).unwrap();

        let reparsed = RsaPrivateKey::from_pkcs8_pem(&decrypted_pem)
            .expect("decrypted PEM from DB must parse as RsaPrivateKey");

        let derived_pub = RsaPublicKey::from(&reparsed)
            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap();
        assert_eq!(
            derived_pub, fetched.public_key_pem,
            "public key derived from decrypted private key must match stored public key"
        );
    }

    #[test]
    fn build_discovery_fields() {
        let issuer = "https://auth.example.com";
        let doc = build_discovery(issuer);

        assert_eq!(doc.issuer, issuer);
        assert_eq!(
            doc.authorization_endpoint,
            "https://auth.example.com/oauth/authorize"
        );
        assert_eq!(doc.token_endpoint, "https://auth.example.com/oauth/token");
        assert_eq!(
            doc.userinfo_endpoint,
            "https://auth.example.com/oauth/userinfo"
        );
        assert_eq!(
            doc.jwks_uri,
            "https://auth.example.com/.well-known/jwks.json"
        );
        assert!(!doc.scopes_supported.is_empty());
        assert!(doc.scopes_supported.contains(&"offline_access"));
        assert!(!doc.response_types_supported.is_empty());
        assert!(!doc.grant_types_supported.is_empty());
        assert!(!doc.subject_types_supported.is_empty());
        assert!(!doc.id_token_signing_alg_values_supported.is_empty());
        assert!(doc.token_endpoint_auth_methods_supported.contains(&"none"));
        assert!(!doc.token_endpoint_auth_methods_supported.is_empty());
        assert!(!doc.code_challenge_methods_supported.is_empty());
    }
}