1use rsa::{
22 Oaep, Pkcs1v15Sign, RsaPrivateKey, pkcs1::DecodeRsaPrivateKey, pkcs8::DecodePrivateKey,
23 traits::PublicKeyParts,
24};
25use sha1::Sha1;
26use sha2::Sha256;
27
28use crate::cek_envelope;
29use crate::encryption::EncryptionError;
30
31pub struct RsaKeyUnwrapper {
33 private_key: RsaPrivateKey,
34}
35
36impl RsaKeyUnwrapper {
37 pub fn from_pem(pem: &str) -> Result<Self, EncryptionError> {
49 let private_key = RsaPrivateKey::from_pkcs8_pem(pem)
51 .or_else(|_| RsaPrivateKey::from_pkcs1_pem(pem))
52 .map_err(|e| {
53 EncryptionError::CmkError(format!("Failed to parse RSA private key: {e}"))
54 })?;
55
56 Ok(Self { private_key })
57 }
58
59 pub fn from_der(der: &[u8]) -> Result<Self, EncryptionError> {
69 let private_key = RsaPrivateKey::from_pkcs8_der(der)
70 .or_else(|_| RsaPrivateKey::from_pkcs1_der(der))
71 .map_err(|e| {
72 EncryptionError::CmkError(format!("Failed to parse RSA private key: {e}"))
73 })?;
74
75 Ok(Self { private_key })
76 }
77
78 pub fn from_key(private_key: RsaPrivateKey) -> Self {
80 Self { private_key }
81 }
82
83 pub fn decrypt_cek(&self, encrypted_cek: &[u8]) -> Result<Vec<u8>, EncryptionError> {
104 let envelope = cek_envelope::parse(encrypted_cek)?;
105
106 let key_size = self.private_key.size();
107 if envelope.ciphertext.len() != key_size {
108 return Err(EncryptionError::CekDecryptionFailed(format!(
109 "CEK ciphertext length {} does not match RSA key size {key_size}",
110 envelope.ciphertext.len()
111 )));
112 }
113 if envelope.signature.len() != key_size {
114 return Err(EncryptionError::CekDecryptionFailed(format!(
115 "CEK signature length {} does not match RSA key size {key_size}",
116 envelope.signature.len()
117 )));
118 }
119
120 self.private_key
121 .to_public_key()
122 .verify(
123 Pkcs1v15Sign::new::<Sha256>(),
124 &envelope.signed_digest(),
125 envelope.signature,
126 )
127 .map_err(|_| {
128 EncryptionError::CekDecryptionFailed(
129 "CEK envelope signature verification failed".into(),
130 )
131 })?;
132
133 let padding = Oaep::new::<Sha1>();
135 let decrypted = self
136 .private_key
137 .decrypt(padding, envelope.ciphertext)
138 .map_err(|e| {
139 EncryptionError::CekDecryptionFailed(format!("RSA-OAEP decryption failed: {e}"))
140 })?;
141
142 Ok(decrypted)
143 }
144
145 pub fn decrypt_raw(&self, ciphertext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
149 let padding = Oaep::new::<Sha1>();
150 self.private_key.decrypt(padding, ciphertext).map_err(|e| {
151 EncryptionError::CekDecryptionFailed(format!("RSA-OAEP decryption failed: {e}"))
152 })
153 }
154
155 pub fn key_bits(&self) -> usize {
157 self.private_key.size() * 8
158 }
159}
160
/// Builds a signed CEK envelope for tests.
///
/// The envelope is built in two parts:
/// 1. The signed portion from `cek_envelope::build_signed_portion(key_path, ciphertext)`.
/// 2. An RSASSA-PKCS1-v1_5 / SHA-256 signature over that portion, made with `cmk` and
///    appended after it.
///
/// # Panics
///
/// Panics if `cmk` fails to produce a signature.
#[cfg(test)]
// A signing failure means the fixture itself is broken, so panicking is the intended outcome.
#[allow(clippy::expect_used)]
pub fn create_test_encrypted_cek(
    cmk: &RsaPrivateKey,
    key_path: &str,
    ciphertext: &[u8],
) -> Vec<u8> {
    use sha2::Digest;

    let signed_portion = cek_envelope::build_signed_portion(key_path, ciphertext);
    let hash: [u8; 32] = Sha256::digest(&signed_portion).into();
    let signature = cmk
        .sign(Pkcs1v15Sign::new::<Sha256>(), &hash)
        .expect("test CMK signs");

    // Envelope layout: signed portion first, signature appended last.
    [signed_portion, signature].concat()
}
182
#[cfg(test)]
// Tests report unexpected failures by panicking.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use rsa::{RsaPrivateKey, pkcs8::EncodePrivateKey};
    use std::sync::OnceLock;

    /// Generates a fresh 2048-bit RSA key.
    ///
    /// RSA key generation is expensive, especially in unoptimized test builds. Tests therefore
    /// share [`shared_test_key`] and only call this when they need a distinct key.
    fn generate_test_key() -> RsaPrivateKey {
        let mut rng = rand::thread_rng();
        RsaPrivateKey::new(&mut rng, 2048).unwrap()
    }

    /// Returns a 2048-bit key shared by all tests, generated on first use.
    fn shared_test_key() -> &'static RsaPrivateKey {
        static KEY: OnceLock<RsaPrivateKey> = OnceLock::new();
        KEY.get_or_init(generate_test_key)
    }

    /// Encrypts `plaintext` with RSA-OAEP (SHA-1) to the public half of `key`.
    fn oaep_encrypt(key: &RsaPrivateKey, plaintext: &[u8]) -> Vec<u8> {
        let mut rng = rand::thread_rng();
        key.to_public_key()
            .encrypt(&mut rng, Oaep::new::<Sha1>(), plaintext)
            .unwrap()
    }

    #[test]
    fn test_key_unwrapper_from_pem_pkcs8() {
        let pem = shared_test_key()
            .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap();

        let unwrapper = RsaKeyUnwrapper::from_pem(&pem).unwrap();
        assert_eq!(unwrapper.key_bits(), 2048);
    }

    #[test]
    fn test_key_unwrapper_from_pem_pkcs1() {
        use rsa::pkcs1::EncodeRsaPrivateKey;

        let pem = shared_test_key()
            .to_pkcs1_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap();

        let unwrapper = RsaKeyUnwrapper::from_pem(&pem).unwrap();
        assert_eq!(unwrapper.key_bits(), 2048);
    }

    #[test]
    fn test_key_unwrapper_from_der_pkcs8() {
        let der = shared_test_key().to_pkcs8_der().unwrap();

        let unwrapper = RsaKeyUnwrapper::from_der(der.as_bytes()).unwrap();
        assert_eq!(unwrapper.key_bits(), 2048);
    }

    #[test]
    fn test_key_unwrapper_rejects_garbage() {
        assert!(RsaKeyUnwrapper::from_pem("not a key").is_err());
        assert!(RsaKeyUnwrapper::from_der(&[0x00, 0x01, 0x02]).is_err());
    }

    #[test]
    fn test_decrypt_raw() {
        let key = shared_test_key();
        let unwrapper = RsaKeyUnwrapper::from_key(key.clone());
        let test_cek = [0x42u8; 32];

        let ciphertext = oaep_encrypt(key, &test_cek);

        let decrypted = unwrapper.decrypt_raw(&ciphertext).unwrap();
        assert_eq!(decrypted, test_cek);
    }

    #[test]
    fn test_decrypt_cek_full_flow() {
        let key = shared_test_key();
        let unwrapper = RsaKeyUnwrapper::from_key(key.clone());
        let test_cek = [0x55u8; 32];

        let rsa_ciphertext = oaep_encrypt(key, &test_cek);
        let encrypted_cek =
            create_test_encrypted_cek(key, "CurrentUser/My/TestCert", &rsa_ciphertext);

        let decrypted = unwrapper.decrypt_cek(&encrypted_cek).unwrap();
        assert_eq!(decrypted, test_cek);
    }

    #[test]
    fn test_decrypt_cek_rejects_tampered_envelope() {
        let key = shared_test_key();
        let unwrapper = RsaKeyUnwrapper::from_key(key.clone());

        let rsa_ciphertext = oaep_encrypt(key, &[0x55u8; 32]);
        let mut encrypted_cek = create_test_encrypted_cek(key, "Test", &rsa_ciphertext);
        // The signed portion holds the full 256-byte ciphertext and precedes the signature.
        // Offset 20 is therefore inside the signed data, and flipping it must break the
        // signature.
        encrypted_cek[20] ^= 0x01;

        let err = unwrapper.decrypt_cek(&encrypted_cek).unwrap_err();
        assert!(err.to_string().contains("signature verification failed"));
    }

    #[test]
    fn test_decrypt_cek_rejects_wrong_signer() {
        let key = shared_test_key();
        let unwrapper = RsaKeyUnwrapper::from_key(key.clone());

        let rsa_ciphertext = oaep_encrypt(key, &[0x55u8; 32]);
        // The ciphertext targets the unwrapper's key, but an unrelated key signs the envelope.
        let other_key = generate_test_key();
        let encrypted_cek = create_test_encrypted_cek(&other_key, "Test", &rsa_ciphertext);

        let err = unwrapper.decrypt_cek(&encrypted_cek).unwrap_err();
        assert!(err.to_string().contains("signature verification failed"));
    }

    #[test]
    fn test_decrypt_cek_rejects_ciphertext_of_wrong_length() {
        let key = shared_test_key();
        let unwrapper = RsaKeyUnwrapper::from_key(key.clone());

        // A 16-byte "ciphertext" cannot come from a 2048-bit (256-byte) RSA key.
        let encrypted_cek = create_test_encrypted_cek(key, "Test", &[0u8; 16]);

        assert!(unwrapper.decrypt_cek(&encrypted_cek).is_err());
    }
}