jwt_simple/algorithms/jwe/
rsa_oaep.rs1#[cfg(any(feature = "pure-rust", target_arch = "wasm32", target_arch = "wasm64"))]
9use superboring as boring;
10
11use boring::pkey::{Private, Public};
12use boring::rsa::{Padding, Rsa};
13use serde::{de::DeserializeOwned, Serialize};
14
15use crate::claims::*;
16use crate::error::*;
17use crate::jwe_header::JWEHeader;
18use crate::jwe_token::{DecryptionOptions, EncryptionOptions, JWEToken, JWETokenMetadata};
19
/// Smallest RSA modulus size, in bits, accepted when a key is imported.
///
/// Enforced by the `validate_key_size` helpers of both the encryption and the
/// decryption key types. `RsaOaepDecryptionKey::generate` does not consult it,
/// but it only offers sizes at or above this value.
const MIN_RSA_MODULUS_BITS: u32 = 2048;
21
22#[derive(Debug, Clone)]
24pub struct RsaOaepEncryptionKey {
25 pk: Rsa<Public>,
26 key_id: Option<String>,
27}
28
29impl RsaOaepEncryptionKey {
30 pub fn from_der(der: &[u8]) -> Result<Self, Error> {
32 let pk = Rsa::<Public>::public_key_from_der(der)
33 .or_else(|_| Rsa::<Public>::public_key_from_der_pkcs1(der))?;
34 Self::validate_key_size(&pk)?;
35 Ok(RsaOaepEncryptionKey { pk, key_id: None })
36 }
37
38 pub fn from_pem(pem: &str) -> Result<Self, Error> {
40 let pem = pem.trim();
41 let pk = Rsa::<Public>::public_key_from_pem(pem.as_bytes())
42 .or_else(|_| Rsa::<Public>::public_key_from_pem_pkcs1(pem.as_bytes()))?;
43 Self::validate_key_size(&pk)?;
44 Ok(RsaOaepEncryptionKey { pk, key_id: None })
45 }
46
47 pub fn to_der(&self) -> Result<Vec<u8>, Error> {
49 self.pk.public_key_to_der().map_err(Into::into)
50 }
51
52 pub fn to_pem(&self) -> Result<String, Error> {
54 let bytes = self.pk.public_key_to_pem()?;
55 Ok(String::from_utf8(bytes)?)
56 }
57
58 pub fn with_key_id(mut self, key_id: impl Into<String>) -> Self {
60 self.key_id = Some(key_id.into());
61 self
62 }
63
64 pub fn key_id(&self) -> Option<&str> {
66 self.key_id.as_deref()
67 }
68
69 fn validate_key_size(pk: &Rsa<Public>) -> Result<(), Error> {
70 let bits = pk.size() * 8;
71 ensure!(bits >= MIN_RSA_MODULUS_BITS, JWTError::WeakKey);
72 Ok(())
73 }
74
75 fn wrap_key(&self, cek: &[u8]) -> Result<Vec<u8>, Error> {
76 let mut encrypted = vec![0u8; self.pk.size() as usize];
77 let encrypted_len = self
78 .pk
79 .public_encrypt(cek, &mut encrypted, Padding::PKCS1_OAEP)
80 .map_err(|_| JWTError::InvalidEncryptionKey)?;
81 encrypted.truncate(encrypted_len);
82
83 Ok(encrypted)
84 }
85
86 pub fn encrypt<CustomClaims: Serialize>(
88 &self,
89 claims: JWTClaims<CustomClaims>,
90 ) -> Result<String, Error> {
91 self.encrypt_with_options(claims, &EncryptionOptions::default())
92 }
93
94 pub fn encrypt_with_options<CustomClaims: Serialize>(
96 &self,
97 claims: JWTClaims<CustomClaims>,
98 options: &EncryptionOptions,
99 ) -> Result<String, Error> {
100 let content_encryption = options.content_encryption;
101 let mut header = JWEHeader::new("RSA-OAEP", content_encryption.alg_name());
102
103 if let Some(key_id) = &self.key_id {
104 header.key_id = Some(key_id.clone());
105 }
106 if let Some(key_id) = &options.key_id {
107 header.key_id = Some(key_id.clone());
108 }
109 if let Some(cty) = &options.content_type {
110 header.content_type = Some(cty.clone());
111 }
112
113 JWEToken::build_from_claims(&header, &claims, content_encryption, |cek| {
114 self.wrap_key(cek)
115 })
116 }
117}
118
119#[derive(Clone)]
121pub struct RsaOaepDecryptionKey {
122 sk: Rsa<Private>,
123 key_id: Option<String>,
124}
125
126impl std::fmt::Debug for RsaOaepDecryptionKey {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 f.debug_struct("RsaOaepDecryptionKey")
129 .field("key_id", &self.key_id)
130 .field("modulus_bits", &(self.sk.size() * 8))
131 .finish_non_exhaustive()
132 }
133}
134
135impl RsaOaepDecryptionKey {
136 pub fn from_der(der: &[u8]) -> Result<Self, Error> {
138 let sk = Rsa::<Private>::private_key_from_der(der)?;
139 if !sk.check_key()? {
140 bail!(JWTError::InvalidKeyPair);
141 }
142 Self::validate_key_size(&sk)?;
143 Ok(RsaOaepDecryptionKey { sk, key_id: None })
144 }
145
146 pub fn from_pem(pem: &str) -> Result<Self, Error> {
148 let pem = pem.trim();
149 let sk = Rsa::<Private>::private_key_from_pem(pem.as_bytes())?;
150 if !sk.check_key()? {
151 bail!(JWTError::InvalidKeyPair);
152 }
153 Self::validate_key_size(&sk)?;
154 Ok(RsaOaepDecryptionKey { sk, key_id: None })
155 }
156
157 pub fn generate(modulus_bits: usize) -> Result<Self, Error> {
159 match modulus_bits {
160 2048 | 3072 | 4096 => {}
161 _ => bail!(JWTError::UnsupportedRSAModulus),
162 };
163 let sk = Rsa::<Private>::generate(modulus_bits as u32)?;
164 Ok(RsaOaepDecryptionKey { sk, key_id: None })
165 }
166
167 pub fn to_der(&self) -> Result<Vec<u8>, Error> {
169 self.sk.private_key_to_der().map_err(Into::into)
170 }
171
172 pub fn to_pem(&self) -> Result<String, Error> {
174 let bytes = self.sk.private_key_to_pem()?;
175 Ok(String::from_utf8(bytes)?)
176 }
177
178 pub fn encryption_key(&self) -> RsaOaepEncryptionKey {
180 let pk = Rsa::<Public>::from_public_components(
181 self.sk.n().to_owned().expect("failed to get modulus"),
182 self.sk.e().to_owned().expect("failed to get exponent"),
183 )
184 .expect("failed to create public key");
185 RsaOaepEncryptionKey {
186 pk,
187 key_id: self.key_id.clone(),
188 }
189 }
190
191 pub fn with_key_id(mut self, key_id: impl Into<String>) -> Self {
193 self.key_id = Some(key_id.into());
194 self
195 }
196
197 pub fn key_id(&self) -> Option<&str> {
199 self.key_id.as_deref()
200 }
201
202 fn validate_key_size(sk: &Rsa<Private>) -> Result<(), Error> {
203 let bits = sk.size() * 8;
204 ensure!(bits >= MIN_RSA_MODULUS_BITS, JWTError::WeakKey);
205 Ok(())
206 }
207
208 fn unwrap_key(&self, encrypted_key: &[u8]) -> Result<Vec<u8>, Error> {
209 let mut cek = vec![0u8; self.sk.size() as usize];
210 let cek_len = self
211 .sk
212 .private_decrypt(encrypted_key, &mut cek, Padding::PKCS1_OAEP)
213 .map_err(|_| JWTError::KeyUnwrapFailed)?;
214 cek.truncate(cek_len);
215
216 Ok(cek)
217 }
218
219 pub fn encrypt<CustomClaims: Serialize>(
221 &self,
222 claims: JWTClaims<CustomClaims>,
223 ) -> Result<String, Error> {
224 self.encryption_key().encrypt(claims)
225 }
226
227 pub fn encrypt_with_options<CustomClaims: Serialize>(
229 &self,
230 claims: JWTClaims<CustomClaims>,
231 options: &EncryptionOptions,
232 ) -> Result<String, Error> {
233 self.encryption_key().encrypt_with_options(claims, options)
234 }
235
236 pub fn decrypt_token<CustomClaims: DeserializeOwned>(
238 &self,
239 token: &str,
240 options: Option<DecryptionOptions>,
241 ) -> Result<JWTClaims<CustomClaims>, Error> {
242 JWEToken::decrypt("RSA-OAEP", token, options, |_header, encrypted_key| {
243 self.unwrap_key(encrypted_key)
244 })
245 }
246
247 pub fn decode_metadata(token: &str) -> Result<JWETokenMetadata, Error> {
249 JWEToken::decode_metadata(token)
250 }
251}