use aes::Aes256;
use aes::cipher::generic_array::GenericArray;
use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit};
use rand::RngCore;
use rsa::traits::PublicKeyParts;
use rsa::{Oaep, RsaPrivateKey, RsaPublicKey};
use sha1::Sha1;
use crate::error::{PsrpError, Result};
/// Client-side RSA key pair used to receive the wrapped session key.
///
/// Only the public half is ever exported (as a hex blob); the private half is
/// kept inside this type and is masked in `Debug` output.
pub struct ClientSessionKey {
    private: RsaPrivateKey,
}

impl std::fmt::Debug for ClientSessionKey {
    /// Prints the type name with the private key replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ClientSessionKey");
        out.field("private", &"<redacted>");
        out.finish()
    }
}
impl ClientSessionKey {
    /// RSA modulus length, in bits, of every key produced by [`Self::generate`].
    const MODULUS_BITS: u32 = 2048;
    /// RSA modulus length in bytes; also the width of the blob's modulus field.
    /// (`u32` -> `usize` is a lossless widening on all supported targets.)
    const MODULUS_BYTES: usize = Self::MODULUS_BITS as usize / 8;

    /// Generates a fresh 2048-bit RSA key pair using the thread-local RNG.
    ///
    /// # Errors
    /// Returns a protocol error if RSA key generation fails.
    pub fn generate() -> Result<Self> {
        let mut rng = rand::thread_rng();
        let private = RsaPrivateKey::new(&mut rng, Self::MODULUS_BITS as usize)
            .map_err(|e| PsrpError::protocol(format!("RSA keygen: {e}")))?;
        Ok(Self { private })
    }

    /// Serializes the public key as an upper-case hex string of a Windows
    /// CryptoAPI `PUBLICKEYBLOB`.
    ///
    /// Byte layout (all integers little-endian):
    /// - `BLOBHEADER`: type `0x06`, version `0x02`, reserved `0x0000`,
    ///   algorithm id `0x0000A400`;
    /// - `RSAPUBKEY`: magic `"RSA1"`, modulus bit length, public exponent
    ///   (low 4 bytes);
    /// - the modulus, little-endian, zero-padded/truncated to 256 bytes.
    #[must_use]
    pub fn public_blob_hex(&self) -> String {
        const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

        let public = RsaPublicKey::from(&self.private);

        // 8-byte BLOBHEADER + 12-byte RSAPUBKEY + modulus.
        let mut blob = Vec::with_capacity(20 + Self::MODULUS_BYTES);

        // BLOBHEADER
        blob.extend_from_slice(&[0x06, 0x02, 0x00, 0x00]);
        blob.extend_from_slice(&0xa400u32.to_le_bytes());

        // RSAPUBKEY
        blob.extend_from_slice(b"RSA1");
        blob.extend_from_slice(&Self::MODULUS_BITS.to_le_bytes());
        let e_bytes = public.e().to_bytes_le();
        let mut exponent = [0u8; 4];
        let take = e_bytes.len().min(exponent.len());
        exponent[..take].copy_from_slice(&e_bytes[..take]);
        blob.extend_from_slice(&exponent);

        // Modulus: `resize` both zero-pads a short value and truncates a long one.
        let mut modulus = public.n().to_bytes_le();
        modulus.resize(Self::MODULUS_BYTES, 0);
        blob.extend_from_slice(&modulus);

        // Hex-encode without allocating a temporary String per byte.
        let mut hex = String::with_capacity(blob.len() * 2);
        for &byte in &blob {
            hex.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            hex.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
        hex
    }

    /// Unwraps an RSA-OAEP (SHA-1) encrypted 256-bit session key.
    ///
    /// `ciphertext` is handed to RSA decryption as-is: no blob header is
    /// stripped and no byte-order reversal is applied here.
    ///
    /// NOTE(review): CryptoAPI `SIMPLEBLOB`s carry a 12-byte header and store
    /// the RSA ciphertext little-endian, and Windows peers commonly use
    /// PKCS#1 v1.5 rather than OAEP — confirm callers normalize the input and
    /// that the peer really uses OAEP/SHA-1.
    ///
    /// # Errors
    /// Returns a protocol error if RSA decryption fails or if the recovered
    /// plaintext is not exactly 32 bytes long.
    pub fn decrypt_session_key(&self, ciphertext: &[u8]) -> Result<[u8; 32]> {
        let decrypted = self
            .private
            .decrypt(Oaep::new::<Sha1>(), ciphertext)
            .map_err(|e| PsrpError::protocol(format!("session key unwrap: {e}")))?;
        <[u8; 32]>::try_from(decrypted.as_slice()).map_err(|_| {
            PsrpError::protocol(format!(
                "session key: expected 32 bytes, got {}",
                decrypted.len()
            ))
        })
    }
}
/// Symmetric AES-256 key used to encrypt and decrypt secure-string payloads.
///
/// `Debug` is implemented by hand, not derived, so the raw key bytes never
/// end up in logs or panic messages. This matches the redaction done for
/// `ClientSessionKey`.
#[derive(Clone)]
pub struct SessionKey {
    key: [u8; 32],
}

impl std::fmt::Debug for SessionKey {
    /// Prints the type name with the key material replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionKey")
            .field("key", &"<redacted>")
            .finish()
    }
}
impl SessionKey {
    /// Wraps an existing 256-bit AES key.
    #[must_use]
    pub fn from_bytes(key: [u8; 32]) -> Self {
        Self { key }
    }

    /// Generates a fresh random 256-bit key from the thread-local RNG.
    #[must_use]
    pub fn random() -> Self {
        let mut key = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut key);
        Self { key }
    }

    /// Encrypts `plaintext` as UTF-16LE using AES-256-CBC with PKCS#7 padding.
    ///
    /// A fresh random 16-byte IV is drawn for every call and prepended to the
    /// ciphertext, so the result is `IV || ciphertext`. Its length is a
    /// multiple of 16 and at least 32, even for an empty string.
    pub fn encrypt_secure_string(&self, plaintext: &str) -> Vec<u8> {
        let mut padded: Vec<u8> = plaintext
            .encode_utf16()
            .flat_map(u16::to_le_bytes)
            .collect();
        // PKCS#7 always appends 1..=16 bytes, each equal to the pad length,
        // so the `as u8` cast below cannot truncate.
        let pad = 16 - (padded.len() % 16);
        padded.extend(std::iter::repeat_n(pad as u8, pad));

        let mut iv = [0u8; 16];
        rand::thread_rng().fill_bytes(&mut iv);

        let cipher = Aes256::new(GenericArray::from_slice(&self.key));
        let mut out = Vec::with_capacity(iv.len() + padded.len());
        out.extend_from_slice(&iv);

        // CBC chaining: XOR each plaintext block with the previous ciphertext
        // block (the IV for the first block), then encrypt it.
        let mut prev = iv;
        for chunk in padded.chunks_exact(16) {
            let block: [u8; 16] = std::array::from_fn(|i| chunk[i] ^ prev[i]);
            let mut ga = GenericArray::clone_from_slice(&block);
            cipher.encrypt_block(&mut ga);
            prev.copy_from_slice(ga.as_slice());
            out.extend_from_slice(&prev);
        }
        out
    }

    /// Decrypts an `IV || ciphertext` payload produced by
    /// [`Self::encrypt_secure_string`] and decodes the UTF-16LE plaintext.
    ///
    /// # Errors
    /// Returns a protocol error in any of these cases:
    /// - the payload is shorter than IV plus one block, or is not block-aligned;
    /// - the PKCS#7 padding is invalid;
    /// - the unpadded plaintext has an odd byte length;
    /// - the plaintext is not valid UTF-16.
    pub fn decrypt_secure_string(&self, payload: &[u8]) -> Result<String> {
        // Need the 16-byte IV plus at least one whole 16-byte block.
        if payload.len() < 32 || (payload.len() - 16) % 16 != 0 {
            return Err(PsrpError::protocol("secure string payload malformed"));
        }
        let (iv, ct) = payload.split_at(16);
        let cipher = Aes256::new(GenericArray::from_slice(&self.key));
        let mut prev: [u8; 16] = iv
            .try_into()
            .expect("split_at(16) always yields a 16-byte IV");

        // CBC: decrypt each block, then XOR with the previous ciphertext block.
        let mut pt = Vec::with_capacity(ct.len());
        for chunk in ct.chunks_exact(16) {
            let mut ga = GenericArray::clone_from_slice(chunk);
            cipher.decrypt_block(&mut ga);
            let block: [u8; 16] = std::array::from_fn(|i| ga[i] ^ prev[i]);
            pt.extend_from_slice(&block);
            prev.copy_from_slice(chunk);
        }

        // PKCS#7: the last byte is the pad length, and every pad byte must equal it.
        let pad = usize::from(
            *pt.last()
                .ok_or_else(|| PsrpError::protocol("empty plaintext"))?,
        );
        if pad == 0 || pad > 16 || pad > pt.len() {
            return Err(PsrpError::protocol("invalid PKCS#7 padding"));
        }
        let body_len = pt.len() - pad;
        if pt[body_len..].iter().any(|&b| usize::from(b) != pad) {
            return Err(PsrpError::protocol("invalid PKCS#7 padding"));
        }
        pt.truncate(body_len);

        if pt.len() % 2 != 0 {
            return Err(PsrpError::protocol(
                "secure string plaintext not UTF-16 aligned",
            ));
        }
        let units: Vec<u16> = pt
            .chunks_exact(2)
            .map(|c| u16::from_le_bytes([c[0], c[1]]))
            .collect();
        String::from_utf16(&units)
            .map_err(|e| PsrpError::protocol(format!("secure string UTF-16: {e}")))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn session_key_roundtrip_ascii() {
        let key = SessionKey::random();
        let ct = key.encrypt_secure_string("hello world");
        assert!(ct.len() > 16);
        let pt = key.decrypt_secure_string(&ct).unwrap();
        assert_eq!(pt, "hello world");
    }

    #[test]
    fn session_key_roundtrip_unicode() {
        let key = SessionKey::random();
        let ct = key.encrypt_secure_string("héllo 🌍");
        let pt = key.decrypt_secure_string(&ct).unwrap();
        assert_eq!(pt, "héllo 🌍");
    }

    #[test]
    fn session_key_empty_string() {
        let key = SessionKey::random();
        let ct = key.encrypt_secure_string("");
        let pt = key.decrypt_secure_string(&ct).unwrap();
        assert_eq!(pt, "");
    }

    #[test]
    fn decrypt_too_short() {
        let key = SessionKey::random();
        assert!(key.decrypt_secure_string(&[0u8; 4]).is_err());
    }

    #[test]
    fn wrong_key_fails_decrypt() {
        let k1 = SessionKey::random();
        let k2 = SessionKey::random();
        let ct = k1.encrypt_secure_string("x");
        assert!(k2.decrypt_secure_string(&ct).is_err());
    }

    #[test]
    fn session_key_from_bytes() {
        let key = SessionKey::from_bytes([0u8; 32]);
        let ct = key.encrypt_secure_string("abc");
        let pt = SessionKey::from_bytes([0u8; 32])
            .decrypt_secure_string(&ct)
            .unwrap();
        assert_eq!(pt, "abc");
    }

    #[test]
    fn client_session_key_generates_blob() {
        let k = ClientSessionKey::generate().unwrap();
        let blob = k.public_blob_hex();
        // 8-byte BLOBHEADER + 12-byte RSAPUBKEY + 256-byte modulus, 2 hex chars each.
        assert_eq!(blob.len(), 2 * (20 + 256));
        // Header (06 02 0000, alg 0x0000A400) + "RSA1" magic + bit length 2048.
        assert!(blob.starts_with("0602000000A400005253413100080000"));
        assert!(blob
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'A'..=b'F').contains(&b)));
    }

    #[test]
    fn decrypt_misaligned_payload() {
        let key = SessionKey::random();
        let bad = vec![0u8; 33];
        assert!(key.decrypt_secure_string(&bad).is_err());
    }

    #[test]
    fn decrypt_bad_pkcs7_padding() {
        let key = SessionKey::random();
        let mut tampered = key.encrypt_secure_string("x");
        // "x" fills a single block whose last plaintext byte is the pad 0x0E.
        // In CBC, flipping the last IV byte flips exactly that plaintext byte,
        // so it becomes 0xF1, which is deterministically invalid padding.
        // Tampering the ciphertext block instead would randomize the whole
        // block and could pass by chance.
        tampered[15] ^= 0xFF;
        assert!(key.decrypt_secure_string(&tampered).is_err());
    }

    #[test]
    fn full_rsa_aes_roundtrip() {
        let client = ClientSessionKey::generate().unwrap();
        let aes = {
            let mut k = [0u8; 32];
            rand::thread_rng().fill_bytes(&mut k);
            k
        };
        let public = RsaPublicKey::from(&client.private);
        let padding = Oaep::new::<Sha1>();
        let wrapped = public
            .encrypt(&mut rand::thread_rng(), padding, &aes)
            .unwrap();
        let unwrapped = client.decrypt_session_key(&wrapped).unwrap();
        assert_eq!(unwrapped, aes);
        let sk = SessionKey::from_bytes(unwrapped);
        let ct = sk.encrypt_secure_string("s3cret");
        let pt = sk.decrypt_secure_string(&ct).unwrap();
        assert_eq!(pt, "s3cret");
    }

    #[test]
    fn client_session_key_debug_redacts_private() {
        let key = ClientSessionKey::generate().unwrap();
        let dbg = format!("{key:?}");
        assert_eq!(dbg, "ClientSessionKey { private: \"<redacted>\" }");
    }
}