// systemprompt_security/keys/mod.rs

use std::fs;
use std::io::Write;
#[cfg(unix)]
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
use std::path::Path;

use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use pkcs8::LineEnding;
use rsa::pkcs8::{DecodePrivateKey, EncodePrivateKey, EncodePublicKey};
use rsa::rand_core::OsRng;
use rsa::{RsaPrivateKey, RsaPublicKey};
use sha2::{Digest, Sha256};

pub mod authority;
pub mod jwks;
pub mod jwks_client;

pub use authority::{TokenAuthorityError, TokenAuthorityResult};
pub use jwks::{Jwk, Jwks};
pub use jwks_client::{JwksClient, JwksClientError};
27
/// RSA modulus size, in bits, used by [`RsaSigningKey::generate`].
pub const DEFAULT_RSA_BITS: usize = 2_048;
29
30#[derive(Debug, thiserror::Error)]
31pub enum KeyError {
32 #[error("RSA key generation failed: {0}")]
33 Generation(#[source] rsa::Error),
34 #[error("PKCS#8 encoding failed: {0}")]
35 Encode(#[source] pkcs8::Error),
36 #[error("SPKI encoding failed: {0}")]
37 EncodeSpki(#[source] pkcs8::spki::Error),
38 #[error("PKCS#8 decoding failed: {0}")]
39 Decode(#[source] pkcs8::Error),
40 #[error("I/O error for {path}: {source}")]
41 Io {
42 path: String,
43 #[source]
44 source: std::io::Error,
45 },
46}
47
48#[derive(Clone)]
49pub struct RsaSigningKey {
50 private_key: RsaPrivateKey,
51 public_key: RsaPublicKey,
52 kid: String,
53}
54
55impl std::fmt::Debug for RsaSigningKey {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("RsaSigningKey")
58 .field("kid", &self.kid)
59 .finish_non_exhaustive()
60 }
61}
62
63impl RsaSigningKey {
64 pub fn generate() -> Result<Self, KeyError> {
65 Self::generate_bits(DEFAULT_RSA_BITS)
66 }
67
68 pub fn generate_bits(bits: usize) -> Result<Self, KeyError> {
69 let mut rng = OsRng;
70 let private_key = RsaPrivateKey::new(&mut rng, bits).map_err(KeyError::Generation)?;
71 Self::from_private(private_key)
72 }
73
74 pub fn from_pkcs8_pem(pem: &str) -> Result<Self, KeyError> {
75 let private_key = RsaPrivateKey::from_pkcs8_pem(pem).map_err(KeyError::Decode)?;
76 Self::from_private(private_key)
77 }
78
79 pub fn load_from_pem_file(path: &Path) -> Result<Self, KeyError> {
80 let pem = fs::read_to_string(path).map_err(|source| KeyError::Io {
81 path: path.display().to_string(),
82 source,
83 })?;
84 Self::from_pkcs8_pem(&pem)
85 }
86
87 pub fn to_pkcs8_pem(&self) -> Result<String, KeyError> {
88 self.private_key
89 .to_pkcs8_pem(LineEnding::LF)
90 .map(|s| s.to_string())
91 .map_err(KeyError::Encode)
92 }
93
94 pub fn write_pem_file(&self, path: &Path) -> Result<(), KeyError> {
95 let pem = self.to_pkcs8_pem()?;
96 fs::write(path, pem).map_err(|source| KeyError::Io {
97 path: path.display().to_string(),
98 source,
99 })
100 }
101
102 pub const fn public_key(&self) -> &RsaPublicKey {
103 &self.public_key
104 }
105
106 pub const fn private_key(&self) -> &RsaPrivateKey {
107 &self.private_key
108 }
109
110 pub fn kid(&self) -> &str {
111 &self.kid
112 }
113
114 pub fn jwk(&self) -> Jwk {
115 Jwk::from_rsa_public_key(&self.public_key, self.kid.clone())
116 }
117
118 pub fn jwks(&self) -> Jwks {
119 Jwks {
120 keys: vec![self.jwk()],
121 }
122 }
123
124 fn from_private(private_key: RsaPrivateKey) -> Result<Self, KeyError> {
125 let public_key = RsaPublicKey::from(&private_key);
126 let kid = compute_kid(&public_key)?;
127 Ok(Self {
128 private_key,
129 public_key,
130 kid,
131 })
132 }
133}
134
135pub fn compute_kid(public_key: &RsaPublicKey) -> Result<String, KeyError> {
136 let der = public_key
137 .to_public_key_der()
138 .map_err(KeyError::EncodeSpki)?;
139 let digest = Sha256::digest(der.as_bytes());
140 Ok(URL_SAFE_NO_PAD.encode(digest))
141}