// systemprompt_security/keys/mod.rs

//! RSA signing-key infrastructure for systemprompt.io's federated JWT plane.
//!
//! Provides an [`RsaSigningKey`] wrapper around an `rsa::RsaPrivateKey` that
//! can be generated, loaded from PKCS#8 PEM, persisted to PEM, and exposes a
//! deterministic `kid` (SHA-256 of the DER-encoded `SubjectPublicKeyInfo`,
//! base64url-encoded without padding). The accompanying [`jwks`] module turns
//! the public half into a JWKS document.

9use std::fs;
10use std::path::Path;
11
12use base64::Engine;
13use base64::engine::general_purpose::URL_SAFE_NO_PAD;
14use pkcs8::LineEnding;
15use rsa::pkcs8::{DecodePrivateKey, EncodePrivateKey, EncodePublicKey};
16use rsa::rand_core::OsRng;
17use rsa::{RsaPrivateKey, RsaPublicKey};
18use sha2::{Digest, Sha256};
19
20pub mod authority;
21pub mod jwks;
22pub mod jwks_client;
23
24pub use authority::{TokenAuthorityError, TokenAuthorityResult};
25pub use jwks::{Jwk, Jwks};
26pub use jwks_client::{JwksClient, JwksClientError};
27
/// Modulus size, in bits, used by [`RsaSigningKey::generate`].
pub const DEFAULT_RSA_BITS: usize = 2_048;
29
30#[derive(Debug, thiserror::Error)]
31pub enum KeyError {
32    #[error("RSA key generation failed: {0}")]
33    Generation(#[source] rsa::Error),
34    #[error("PKCS#8 encoding failed: {0}")]
35    Encode(#[source] pkcs8::Error),
36    #[error("SPKI encoding failed: {0}")]
37    EncodeSpki(#[source] pkcs8::spki::Error),
38    #[error("PKCS#8 decoding failed: {0}")]
39    Decode(#[source] pkcs8::Error),
40    #[error("I/O error for {path}: {source}")]
41    Io {
42        path: String,
43        #[source]
44        source: std::io::Error,
45    },
46}
47
48#[derive(Clone)]
49pub struct RsaSigningKey {
50    private_key: RsaPrivateKey,
51    public_key: RsaPublicKey,
52    kid: String,
53}
54
55impl std::fmt::Debug for RsaSigningKey {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("RsaSigningKey")
58            .field("kid", &self.kid)
59            .finish_non_exhaustive()
60    }
61}
62
63impl RsaSigningKey {
64    pub fn generate() -> Result<Self, KeyError> {
65        Self::generate_bits(DEFAULT_RSA_BITS)
66    }
67
68    pub fn generate_bits(bits: usize) -> Result<Self, KeyError> {
69        let mut rng = OsRng;
70        let private_key = RsaPrivateKey::new(&mut rng, bits).map_err(KeyError::Generation)?;
71        Self::from_private(private_key)
72    }
73
74    pub fn from_pkcs8_pem(pem: &str) -> Result<Self, KeyError> {
75        let private_key = RsaPrivateKey::from_pkcs8_pem(pem).map_err(KeyError::Decode)?;
76        Self::from_private(private_key)
77    }
78
79    pub fn load_from_pem_file(path: &Path) -> Result<Self, KeyError> {
80        let pem = fs::read_to_string(path).map_err(|source| KeyError::Io {
81            path: path.display().to_string(),
82            source,
83        })?;
84        Self::from_pkcs8_pem(&pem)
85    }
86
87    pub fn to_pkcs8_pem(&self) -> Result<String, KeyError> {
88        self.private_key
89            .to_pkcs8_pem(LineEnding::LF)
90            .map(|s| s.to_string())
91            .map_err(KeyError::Encode)
92    }
93
94    pub fn write_pem_file(&self, path: &Path) -> Result<(), KeyError> {
95        let pem = self.to_pkcs8_pem()?;
96        fs::write(path, pem).map_err(|source| KeyError::Io {
97            path: path.display().to_string(),
98            source,
99        })
100    }
101
102    pub const fn public_key(&self) -> &RsaPublicKey {
103        &self.public_key
104    }
105
106    pub const fn private_key(&self) -> &RsaPrivateKey {
107        &self.private_key
108    }
109
110    pub fn kid(&self) -> &str {
111        &self.kid
112    }
113
114    pub fn jwk(&self) -> Jwk {
115        Jwk::from_rsa_public_key(&self.public_key, self.kid.clone())
116    }
117
118    pub fn jwks(&self) -> Jwks {
119        Jwks {
120            keys: vec![self.jwk()],
121        }
122    }
123
124    fn from_private(private_key: RsaPrivateKey) -> Result<Self, KeyError> {
125        let public_key = RsaPublicKey::from(&private_key);
126        let kid = compute_kid(&public_key)?;
127        Ok(Self {
128            private_key,
129            public_key,
130            kid,
131        })
132    }
133}
134
135pub fn compute_kid(public_key: &RsaPublicKey) -> Result<String, KeyError> {
136    let der = public_key
137        .to_public_key_der()
138        .map_err(KeyError::EncodeSpki)?;
139    let digest = Sha256::digest(der.as_bytes());
140    Ok(URL_SAFE_NO_PAD.encode(digest))
141}