Skip to main content

psrp_rs/
crypto.rs

1//! Session-key cryptography for MS-PSRP `SecureString` transport.
2//!
3//! PSRP §3.1.2: to transmit a `SecureString` between the client and
4//! server, the two sides negotiate a symmetric session key protected by
5//! RSA-OAEP(SHA-1).
6//!
7//! 1. The client generates a 2048-bit RSA keypair.
8//! 2. The client sends its public key to the server in a `PublicKey`
//!    message (CLIXML body: a single `<S>` containing the key
//!    serialized as a Windows `PUBLICKEYBLOB`: `BLOBHEADER` +
//!    `RSAPUBKEY` (magic, bit length, public exponent) + modulus).
11//! 3. The server generates a 256-bit AES session key, encrypts it with
12//!    the client's public key (RSA-OAEP/SHA-1), and sends the ciphertext
13//!    back in an `EncryptedSessionKey` message.
14//! 4. Both sides use that AES key in CBC mode with PKCS#7 padding to
15//!    wrap every `<SS>` (SecureString) element. The IV is a fresh
16//!    random 16-byte value prefixed to the ciphertext.
17//!
18//! This module implements the **pure-Rust** crypto (no OpenSSL) needed
19//! for that exchange plus a [`SessionKey`] helper that encrypts and
20//! decrypts individual `SecureString` values.
21//!
22//! The exchange itself is driven by the runspace pool — see
23//! [`crate::runspace::RunspacePool::request_session_key`].
24
25use aes::Aes256;
26use aes::cipher::generic_array::GenericArray;
27use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit};
28use rand::RngCore;
29use rsa::traits::PublicKeyParts;
30use rsa::{Oaep, RsaPrivateKey, RsaPublicKey};
31use sha1::Sha1;
32
33use crate::error::{PsrpError, Result};
34
35/// Client-side RSA key used for PSRP session-key negotiation.
36pub struct ClientSessionKey {
37    private: RsaPrivateKey,
38}
39
40impl std::fmt::Debug for ClientSessionKey {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("ClientSessionKey")
43            .field("private", &"<redacted>")
44            .finish()
45    }
46}
47
48impl ClientSessionKey {
49    /// Generate a fresh 2048-bit RSA keypair.
50    pub fn generate() -> Result<Self> {
51        let mut rng = rand::thread_rng();
52        let private = RsaPrivateKey::new(&mut rng, 2048)
53            .map_err(|e| PsrpError::protocol(format!("RSA keygen: {e}")))?;
54        Ok(Self { private })
55    }
56
57    /// Return the Windows `PUBLICKEYBLOB` representation of the public
58    /// key that PSRP expects to transport via the `PublicKey` message.
59    ///
60    /// Layout:
61    /// ```text
62    /// BLOBHEADER (12 bytes):
63    ///   bType = 0x06   (PUBLICKEYBLOB)
64    ///   bVersion = 0x02
65    ///   reserved = 0x0000
66    ///   aiKeyAlg = 0xa400 (CALG_RSA_KEYX)
67    /// RSAPUBKEY (12 bytes):
68    ///   magic = "RSA1"
69    ///   bitlen = 2048
70    ///   pubexp = u32 little-endian
71    /// modulus (256 bytes, little-endian)
72    /// ```
73    #[must_use]
74    pub fn public_blob_hex(&self) -> String {
75        let public = RsaPublicKey::from(&self.private);
76        let mut blob = Vec::with_capacity(12 + 12 + 256);
77        // BLOBHEADER
78        blob.push(0x06);
79        blob.push(0x02);
80        blob.push(0x00);
81        blob.push(0x00);
82        blob.extend_from_slice(&0xa400u32.to_le_bytes());
83        // RSAPUBKEY
84        blob.extend_from_slice(b"RSA1");
85        blob.extend_from_slice(&2048u32.to_le_bytes());
86        let e_bytes = public.e().to_bytes_le();
87        let mut exp = [0u8; 4];
88        for (i, b) in e_bytes.iter().take(4).enumerate() {
89            exp[i] = *b;
90        }
91        blob.extend_from_slice(&exp);
92        // Force modulus to exactly 256 bytes (little-endian).
93        let mut modulus = public.n().to_bytes_le();
94        if modulus.len() > 256 {
95            modulus.truncate(256);
96        } else {
97            modulus.resize(256, 0);
98        }
99        blob.extend_from_slice(&modulus);
100
101        let mut hex = String::with_capacity(blob.len() * 2);
102        for b in &blob {
103            hex.push_str(&format!("{b:02X}"));
104        }
105        hex
106    }
107
108    /// Decrypt an RSA-OAEP/SHA-1 wrapped session key and return the raw
109    /// 32-byte AES key.
110    pub fn decrypt_session_key(&self, ciphertext: &[u8]) -> Result<[u8; 32]> {
111        let padding = Oaep::new::<Sha1>();
112        let decrypted = self
113            .private
114            .decrypt(padding, ciphertext)
115            .map_err(|e| PsrpError::protocol(format!("session key unwrap: {e}")))?;
116        if decrypted.len() != 32 {
117            return Err(PsrpError::protocol(format!(
118                "session key: expected 32 bytes, got {}",
119                decrypted.len()
120            )));
121        }
122        let mut out = [0u8; 32];
123        out.copy_from_slice(&decrypted);
124        Ok(out)
125    }
126}
127
/// A negotiated AES-256-CBC session key ready for `SecureString`
/// encryption / decryption.
///
/// `Debug` is implemented by hand so the raw key bytes never end up in
/// logs or panic messages (mirroring the redaction on `ClientSessionKey`).
#[derive(Clone)]
pub struct SessionKey {
    /// Raw 256-bit AES key.
    key: [u8; 32],
}

impl std::fmt::Debug for SessionKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionKey")
            .field("key", &"<redacted>")
            .finish()
    }
}
134
135impl SessionKey {
136    /// Wrap a raw 256-bit key.
137    #[must_use]
138    pub fn from_bytes(key: [u8; 32]) -> Self {
139        Self { key }
140    }
141
142    /// Generate a fresh random 256-bit key (server-side test helper).
143    #[must_use]
144    pub fn random() -> Self {
145        let mut key = [0u8; 32];
146        rand::thread_rng().fill_bytes(&mut key);
147        Self { key }
148    }
149
150    /// Encrypt a plaintext string and return `IV || ciphertext`.
151    ///
152    /// PSRP transports `SecureString` values by UTF-16LE-encoding the
153    /// plaintext, AES-CBC encrypting with PKCS#7 padding, and prefixing
154    /// the 16-byte IV.
155    pub fn encrypt_secure_string(&self, plaintext: &str) -> Vec<u8> {
156        // UTF-16LE encode + PKCS#7 pad to 16 bytes.
157        let mut padded: Vec<u8> = plaintext
158            .encode_utf16()
159            .flat_map(u16::to_le_bytes)
160            .collect();
161        let pad = 16 - (padded.len() % 16);
162        padded.extend(std::iter::repeat_n(pad as u8, pad));
163
164        let mut iv = [0u8; 16];
165        rand::thread_rng().fill_bytes(&mut iv);
166        let cipher = Aes256::new(GenericArray::from_slice(&self.key));
167
168        // CBC: for each block, XOR with previous ciphertext (or IV for
169        // the first block) then encrypt.
170        let mut out = Vec::with_capacity(16 + padded.len());
171        out.extend_from_slice(&iv);
172        let mut prev: [u8; 16] = iv;
173        for chunk in padded.chunks_exact(16) {
174            let mut block = [0u8; 16];
175            for i in 0..16 {
176                block[i] = chunk[i] ^ prev[i];
177            }
178            let mut ga = GenericArray::clone_from_slice(&block);
179            cipher.encrypt_block(&mut ga);
180            prev.copy_from_slice(ga.as_slice());
181            out.extend_from_slice(&prev);
182        }
183        out
184    }
185
186    /// Decrypt `IV || ciphertext` back into the plaintext string.
187    pub fn decrypt_secure_string(&self, payload: &[u8]) -> Result<String> {
188        if payload.len() < 32 || (payload.len() - 16) % 16 != 0 {
189            return Err(PsrpError::protocol("secure string payload malformed"));
190        }
191        let (iv, ct) = payload.split_at(16);
192        let cipher = Aes256::new(GenericArray::from_slice(&self.key));
193
194        let mut prev: [u8; 16] = iv.try_into().unwrap();
195        let mut pt = Vec::with_capacity(ct.len());
196        for chunk in ct.chunks_exact(16) {
197            let mut ga = GenericArray::clone_from_slice(chunk);
198            cipher.decrypt_block(&mut ga);
199            let mut block = [0u8; 16];
200            for i in 0..16 {
201                block[i] = ga[i] ^ prev[i];
202            }
203            pt.extend_from_slice(&block);
204            prev.copy_from_slice(chunk);
205        }
206
207        // Strip PKCS#7 padding.
208        let pad = *pt
209            .last()
210            .ok_or_else(|| PsrpError::protocol("empty plaintext"))? as usize;
211        if pad == 0 || pad > 16 || pad > pt.len() {
212            return Err(PsrpError::protocol("invalid PKCS#7 padding"));
213        }
214        for &b in &pt[pt.len() - pad..] {
215            if b as usize != pad {
216                return Err(PsrpError::protocol("invalid PKCS#7 padding"));
217            }
218        }
219        pt.truncate(pt.len() - pad);
220
221        if pt.len() % 2 != 0 {
222            return Err(PsrpError::protocol(
223                "secure string plaintext not UTF-16 aligned",
224            ));
225        }
226        let units: Vec<u16> = pt
227            .chunks_exact(2)
228            .map(|c| u16::from_le_bytes([c[0], c[1]]))
229            .collect();
230        String::from_utf16(&units)
231            .map_err(|e| PsrpError::protocol(format!("secure string UTF-16: {e}")))
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn session_key_roundtrip_ascii() {
        let key = SessionKey::random();
        let ct = key.encrypt_secure_string("hello world");
        assert!(ct.len() > 16);
        let pt = key.decrypt_secure_string(&ct).unwrap();
        assert_eq!(pt, "hello world");
    }

    #[test]
    fn session_key_roundtrip_unicode() {
        let key = SessionKey::random();
        let ct = key.encrypt_secure_string("héllo 🌍");
        let pt = key.decrypt_secure_string(&ct).unwrap();
        assert_eq!(pt, "héllo 🌍");
    }

    #[test]
    fn session_key_empty_string() {
        let key = SessionKey::random();
        let ct = key.encrypt_secure_string("");
        let pt = key.decrypt_secure_string(&ct).unwrap();
        assert_eq!(pt, "");
    }

    #[test]
    fn secure_string_payload_layout() {
        let key = SessionKey::random();
        // "x" is 2 UTF-16LE bytes -> IV + one padded block.
        assert_eq!(key.encrypt_secure_string("x").len(), 32);
        // 8 ASCII chars = exactly 16 bytes -> PKCS#7 adds a whole extra block.
        assert_eq!(key.encrypt_secure_string("abcdefgh").len(), 48);
        // Empty input still yields one padding block.
        assert_eq!(key.encrypt_secure_string("").len(), 32);
    }

    #[test]
    fn encryption_uses_fresh_iv() {
        let key = SessionKey::random();
        let a = key.encrypt_secure_string("same");
        let b = key.encrypt_secure_string("same");
        // Random 16-byte IVs collide with negligible probability.
        assert_ne!(&a[..16], &b[..16]);
    }

    #[test]
    fn decrypt_too_short() {
        let key = SessionKey::random();
        assert!(key.decrypt_secure_string(&[0u8; 4]).is_err());
        // An IV with no ciphertext block is also rejected.
        assert!(key.decrypt_secure_string(&[0u8; 16]).is_err());
    }

    #[test]
    fn wrong_key_fails_decrypt() {
        let k1 = SessionKey::random();
        let k2 = SessionKey::random();
        let ct = k1.encrypt_secure_string("x");
        // A wrong key almost always trips the padding / alignment / UTF-16
        // checks, but roughly 1 in 65 536 garbage blocks happens to look
        // valid. What must never happen is recovering the real plaintext.
        assert_ne!(k2.decrypt_secure_string(&ct).ok().as_deref(), Some("x"));
    }

    #[test]
    fn session_key_from_bytes() {
        let key = SessionKey::from_bytes([0u8; 32]);
        let ct = key.encrypt_secure_string("abc");
        let pt = SessionKey::from_bytes([0u8; 32])
            .decrypt_secure_string(&ct)
            .unwrap();
        assert_eq!(pt, "abc");
    }

    #[test]
    fn client_session_key_generates_blob() {
        // RSA keygen is slow (~500 ms on modest hardware); keep the number
        // of tests that call `generate` small.
        let k = ClientSessionKey::generate().unwrap();
        let blob = k.public_blob_hex();
        // 8-byte BLOBHEADER + 12-byte RSAPUBKEY + 256-byte modulus = 276 bytes.
        assert_eq!(blob.len(), 276 * 2);
        // PUBLICKEYBLOB v2, CALG_RSA_KEYX, "RSA1", bitlen 2048.
        assert!(blob.starts_with("0602000000A400005253413100080000"));
        // Public exponent 65537 (0x00010001, little-endian).
        assert_eq!(&blob[32..40], "01000100");
        assert!(blob
            .chars()
            .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_lowercase()));
    }

    #[test]
    fn decrypt_misaligned_payload() {
        let key = SessionKey::random();
        // 16 (IV) + 17 (not a multiple of 16)
        let bad = vec![0u8; 33];
        assert!(key.decrypt_secure_string(&bad).is_err());
    }

    #[test]
    fn decrypt_bad_pkcs7_padding() {
        let key = SessionKey::random();
        // "x" encrypts to IV || one block whose last plaintext byte is the
        // pad length 0x0E. In CBC that byte is XORed with the last IV byte
        // on decryption, so flipping IV[15] deterministically turns it into
        // 0xF1 (> 16), which must be rejected.
        let mut tampered = key.encrypt_secure_string("x");
        tampered[15] ^= 0xFF;
        assert!(key.decrypt_secure_string(&tampered).is_err());
    }

    #[test]
    fn full_rsa_aes_roundtrip() {
        // Local simulation: client generates key, "server" uses its
        // public key to RSA-OAEP encrypt a random AES key, client
        // decrypts, both encrypt/decrypt a SecureString.
        let client = ClientSessionKey::generate().unwrap();
        let aes = {
            let mut k = [0u8; 32];
            rand::thread_rng().fill_bytes(&mut k);
            k
        };
        // Encrypt with the client's public key.
        let public = RsaPublicKey::from(&client.private);
        let padding = Oaep::new::<Sha1>();
        let wrapped = public
            .encrypt(&mut rand::thread_rng(), padding, &aes)
            .unwrap();
        // Client decrypts.
        let unwrapped = client.decrypt_session_key(&wrapped).unwrap();
        assert_eq!(unwrapped, aes);
        let sk = SessionKey::from_bytes(unwrapped);
        let ct = sk.encrypt_secure_string("s3cret");
        let pt = sk.decrypt_secure_string(&ct).unwrap();
        assert_eq!(pt, "s3cret");
    }

    #[test]
    fn client_session_key_debug_redacts_private() {
        let key = ClientSessionKey::generate().unwrap();
        assert_eq!(
            format!("{key:?}"),
            "ClientSessionKey { private: \"<redacted>\" }"
        );
    }
}