1use aes::Aes256;
26use aes::cipher::generic_array::GenericArray;
27use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit};
28use rand::RngCore;
29use rsa::traits::PublicKeyParts;
30use rsa::{Oaep, RsaPrivateKey, RsaPublicKey};
31use sha1::Sha1;
32
33use crate::error::{PsrpError, Result};
34
35pub struct ClientSessionKey {
37 private: RsaPrivateKey,
38}
39
40impl std::fmt::Debug for ClientSessionKey {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("ClientSessionKey")
43 .field("private", &"<redacted>")
44 .finish()
45 }
46}
47
48impl ClientSessionKey {
49 pub fn generate() -> Result<Self> {
51 let mut rng = rand::thread_rng();
52 let private = RsaPrivateKey::new(&mut rng, 2048)
53 .map_err(|e| PsrpError::protocol(format!("RSA keygen: {e}")))?;
54 Ok(Self { private })
55 }
56
57 #[must_use]
74 pub fn public_blob_hex(&self) -> String {
75 let public = RsaPublicKey::from(&self.private);
76 let mut blob = Vec::with_capacity(12 + 12 + 256);
77 blob.push(0x06);
79 blob.push(0x02);
80 blob.push(0x00);
81 blob.push(0x00);
82 blob.extend_from_slice(&0xa400u32.to_le_bytes());
83 blob.extend_from_slice(b"RSA1");
85 blob.extend_from_slice(&2048u32.to_le_bytes());
86 let e_bytes = public.e().to_bytes_le();
87 let mut exp = [0u8; 4];
88 for (i, b) in e_bytes.iter().take(4).enumerate() {
89 exp[i] = *b;
90 }
91 blob.extend_from_slice(&exp);
92 let mut modulus = public.n().to_bytes_le();
94 if modulus.len() > 256 {
95 modulus.truncate(256);
96 } else {
97 modulus.resize(256, 0);
98 }
99 blob.extend_from_slice(&modulus);
100
101 let mut hex = String::with_capacity(blob.len() * 2);
102 for b in &blob {
103 hex.push_str(&format!("{b:02X}"));
104 }
105 hex
106 }
107
108 pub fn decrypt_session_key(&self, ciphertext: &[u8]) -> Result<[u8; 32]> {
111 let padding = Oaep::new::<Sha1>();
112 let decrypted = self
113 .private
114 .decrypt(padding, ciphertext)
115 .map_err(|e| PsrpError::protocol(format!("session key unwrap: {e}")))?;
116 if decrypted.len() != 32 {
117 return Err(PsrpError::protocol(format!(
118 "session key: expected 32 bytes, got {}",
119 decrypted.len()
120 )));
121 }
122 let mut out = [0u8; 32];
123 out.copy_from_slice(&decrypted);
124 Ok(out)
125 }
126}
127
/// 256-bit symmetric key used as the AES-256 key for secure-string
/// encryption and decryption.
#[derive(Clone)]
pub struct SessionKey {
    /// Raw key bytes. Never printed: the `Debug` impl below substitutes a
    /// placeholder for them.
    key: [u8; 32],
}

// `Debug` is written by hand rather than derived: a derived impl would print
// all 32 key bytes whenever the value is formatted (for example in a log
// line). `ClientSessionKey` redacts its secret the same way.
impl std::fmt::Debug for SessionKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionKey")
            .field("key", &"<redacted>")
            .finish()
    }
}
134
135impl SessionKey {
136 #[must_use]
138 pub fn from_bytes(key: [u8; 32]) -> Self {
139 Self { key }
140 }
141
142 #[must_use]
144 pub fn random() -> Self {
145 let mut key = [0u8; 32];
146 rand::thread_rng().fill_bytes(&mut key);
147 Self { key }
148 }
149
150 pub fn encrypt_secure_string(&self, plaintext: &str) -> Vec<u8> {
156 let mut padded: Vec<u8> = plaintext
158 .encode_utf16()
159 .flat_map(u16::to_le_bytes)
160 .collect();
161 let pad = 16 - (padded.len() % 16);
162 padded.extend(std::iter::repeat_n(pad as u8, pad));
163
164 let mut iv = [0u8; 16];
165 rand::thread_rng().fill_bytes(&mut iv);
166 let cipher = Aes256::new(GenericArray::from_slice(&self.key));
167
168 let mut out = Vec::with_capacity(16 + padded.len());
171 out.extend_from_slice(&iv);
172 let mut prev: [u8; 16] = iv;
173 for chunk in padded.chunks_exact(16) {
174 let mut block = [0u8; 16];
175 for i in 0..16 {
176 block[i] = chunk[i] ^ prev[i];
177 }
178 let mut ga = GenericArray::clone_from_slice(&block);
179 cipher.encrypt_block(&mut ga);
180 prev.copy_from_slice(ga.as_slice());
181 out.extend_from_slice(&prev);
182 }
183 out
184 }
185
186 pub fn decrypt_secure_string(&self, payload: &[u8]) -> Result<String> {
188 if payload.len() < 32 || (payload.len() - 16) % 16 != 0 {
189 return Err(PsrpError::protocol("secure string payload malformed"));
190 }
191 let (iv, ct) = payload.split_at(16);
192 let cipher = Aes256::new(GenericArray::from_slice(&self.key));
193
194 let mut prev: [u8; 16] = iv.try_into().unwrap();
195 let mut pt = Vec::with_capacity(ct.len());
196 for chunk in ct.chunks_exact(16) {
197 let mut ga = GenericArray::clone_from_slice(chunk);
198 cipher.decrypt_block(&mut ga);
199 let mut block = [0u8; 16];
200 for i in 0..16 {
201 block[i] = ga[i] ^ prev[i];
202 }
203 pt.extend_from_slice(&block);
204 prev.copy_from_slice(chunk);
205 }
206
207 let pad = *pt
209 .last()
210 .ok_or_else(|| PsrpError::protocol("empty plaintext"))? as usize;
211 if pad == 0 || pad > 16 || pad > pt.len() {
212 return Err(PsrpError::protocol("invalid PKCS#7 padding"));
213 }
214 for &b in &pt[pt.len() - pad..] {
215 if b as usize != pad {
216 return Err(PsrpError::protocol("invalid PKCS#7 padding"));
217 }
218 }
219 pt.truncate(pt.len() - pad);
220
221 if pt.len() % 2 != 0 {
222 return Err(PsrpError::protocol(
223 "secure string plaintext not UTF-16 aligned",
224 ));
225 }
226 let units: Vec<u16> = pt
227 .chunks_exact(2)
228 .map(|c| u16::from_le_bytes([c[0], c[1]]))
229 .collect();
230 String::from_utf16(&units)
231 .map_err(|e| PsrpError::protocol(format!("secure string UTF-16: {e}")))
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Encrypts `text` with `key` and decrypts the result with the same key.
    fn roundtrip(key: &SessionKey, text: &str) -> String {
        let sealed = key.encrypt_secure_string(text);
        key.decrypt_secure_string(&sealed).unwrap()
    }

    #[test]
    fn session_key_roundtrip_ascii() {
        let key = SessionKey::random();
        let sealed = key.encrypt_secure_string("hello world");
        // The output must hold more than just the 16-byte IV.
        assert!(sealed.len() > 16);
        let opened = key.decrypt_secure_string(&sealed).unwrap();
        assert_eq!(opened, "hello world");
    }

    #[test]
    fn session_key_roundtrip_unicode() {
        let key = SessionKey::random();
        assert_eq!(roundtrip(&key, "héllo 🌍"), "héllo 🌍");
    }

    #[test]
    fn session_key_empty_string() {
        let key = SessionKey::random();
        assert_eq!(roundtrip(&key, ""), "");
    }

    #[test]
    fn decrypt_too_short() {
        let outcome = SessionKey::random().decrypt_secure_string(&[0u8; 4]);
        assert!(outcome.is_err());
    }

    // NOTE(review): a wrong key can, with small probability, still decrypt
    // to bytes that pass both the padding and UTF-16 checks, so this test is
    // very slightly probabilistic.
    #[test]
    fn wrong_key_fails_decrypt() {
        let sender = SessionKey::random();
        let stranger = SessionKey::random();
        let sealed = sender.encrypt_secure_string("x");
        assert!(stranger.decrypt_secure_string(&sealed).is_err());
    }

    #[test]
    fn session_key_from_bytes() {
        // Two independent instances built from the same bytes must agree.
        let writer = SessionKey::from_bytes([0u8; 32]);
        let reader = SessionKey::from_bytes([0u8; 32]);
        let sealed = writer.encrypt_secure_string("abc");
        let opened = reader.decrypt_secure_string(&sealed).unwrap();
        assert_eq!(opened, "abc");
    }

    #[test]
    fn client_session_key_generates_blob() {
        let client = ClientSessionKey::generate().unwrap();
        let hex = client.public_blob_hex();
        assert!(hex.len() >= 48);
        // The first four blob bytes are 06 02 00 00.
        assert!(hex.starts_with("06020000"));
    }

    #[test]
    fn decrypt_misaligned_payload() {
        let key = SessionKey::random();
        // 33 bytes: long enough, but not IV + whole blocks.
        let misaligned = [0u8; 33];
        assert!(key.decrypt_secure_string(&misaligned).is_err());
    }

    #[test]
    fn decrypt_bad_pkcs7_padding() {
        let key = SessionKey::random();
        let mut tampered = key.encrypt_secure_string("x");
        // Flip every bit of the final ciphertext byte.
        let last = tampered.len() - 1;
        tampered[last] ^= 0xFF;
        assert!(key.decrypt_secure_string(&tampered).is_err());
    }

    #[test]
    fn full_rsa_aes_roundtrip() {
        let client = ClientSessionKey::generate().unwrap();

        let mut aes = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut aes);

        // Play the peer: wrap the AES key with the client's public key.
        let wrapped = RsaPublicKey::from(&client.private)
            .encrypt(&mut rand::thread_rng(), Oaep::new::<Sha1>(), &aes)
            .unwrap();

        let unwrapped = client.decrypt_session_key(&wrapped).unwrap();
        assert_eq!(unwrapped, aes);

        let session = SessionKey::from_bytes(unwrapped);
        assert_eq!(roundtrip(&session, "s3cret"), "s3cret");
    }

    #[test]
    fn client_session_key_debug_redacts_private() {
        let client = ClientSessionKey::generate().unwrap();
        let rendered = format!("{client:?}");
        assert!(rendered.contains("<redacted>"));
        assert!(!rendered.contains("BEGIN"));
    }
}