use super::bigint::BigInt;
use super::rsa::{RsaPublicKey, RsaSecretKey, rsa_decrypt_raw, rsa_encrypt_raw};
use crate::Hasher;
use crate::hash::sha256::Sha256;
/// Output size of SHA-256 in bytes; used as the OAEP hash length throughout this module.
const HASH_LEN: usize = 32;

/// One-shot SHA-256 of `data`, returned as a fixed-size byte array.
///
/// Panics (via `copy_from_slice`) if the digest produced by `Sha256::hash`
/// is not exactly `HASH_LEN` bytes long.
fn sha256(data: &[u8]) -> [u8; HASH_LEN] {
    let mut digest_bytes = [0u8; HASH_LEN];
    digest_bytes.copy_from_slice(&Sha256::hash(data));
    digest_bytes
}
/// MGF1 mask generation function (RFC 8017 Appendix B.2.1) instantiated with SHA-256.
///
/// Produces `len` bytes by concatenating `SHA-256(seed || counter)` for
/// `counter = 0, 1, 2, ...` (counter encoded as a 4-byte big-endian integer)
/// and keeping only as much of the final digest as is needed.
fn mgf1_sha256(seed: &[u8], len: usize) -> Vec<u8> {
    let mut mask = Vec::with_capacity(len);
    let mut block_index: u32 = 0;
    loop {
        let remaining = len - mask.len();
        if remaining == 0 {
            break;
        }
        let mut hasher = Sha256::new();
        hasher.update(seed);
        hasher.update(&block_index.to_be_bytes());
        let digest = hasher.finalize();
        // Only the last block may be cut short; every take is <= remaining,
        // so `mask` ends up exactly `len` bytes long.
        let n = remaining.min(digest.len());
        mask.extend_from_slice(&digest[..n]);
        block_index += 1;
    }
    mask
}
/// XORs `b` into `a` byte by byte.
///
/// Only the overlapping prefix (`min(a.len(), b.len())` bytes) is touched;
/// any extra bytes in either slice are ignored.
fn xor_in_place(a: &mut [u8], b: &[u8]) {
    a.iter_mut().zip(b).for_each(|(dst, src)| *dst ^= *src);
}
/// Encrypts `msg` with RSAES-OAEP (RFC 8017 §7.1.1) using SHA-256 and MGF1-SHA-256.
///
/// * `label` is hashed into the padding; the same label must be passed to
///   `oaep_decrypt`.
/// * `rng` fills the `HASH_LEN`-byte OAEP seed. OAEP's security relies on
///   this seed being unpredictable, so callers should supply a
///   cryptographically secure source.
///
/// Returns the ciphertext as `k` big-endian bytes, where `k` is the modulus
/// length in bytes.
///
/// # Panics
///
/// * If the modulus is shorter than `2 * HASH_LEN + 2` bytes (too small for
///   OAEP with SHA-256).
/// * If `msg` is longer than `k - 2 * HASH_LEN - 2` bytes.
pub fn oaep_encrypt(pk: &RsaPublicKey, msg: &[u8], label: &[u8], rng: &mut dyn FnMut(&mut [u8])) -> Vec<u8> {
    let k = pk.modulus_byte_len();
    // Check before subtracting: a small modulus would otherwise underflow
    // (a debug-build overflow panic, or a wrapped length and a huge allocation in release).
    assert!(
        k >= 2 * HASH_LEN + 2,
        "OAEP: modulus too small ({} bytes, need at least {})",
        k,
        2 * HASH_LEN + 2
    );
    let max_msg_len = k - 2 * HASH_LEN - 2;
    assert!(
        msg.len() <= max_msg_len,
        "OAEP: message too long (max {} bytes, got {})",
        max_msg_len,
        msg.len()
    );

    // DB = lHash || PS (zero bytes) || 0x01 || M
    let l_hash = sha256(label);
    let db_len = k - HASH_LEN - 1;
    let mut db = vec![0u8; db_len];
    db[..HASH_LEN].copy_from_slice(&l_hash);
    // The zero padding PS is already present from `vec!`; only the separator and M are written.
    let sep = db_len - msg.len() - 1;
    db[sep] = 0x01;
    db[sep + 1..].copy_from_slice(msg);

    // maskedDB = DB xor MGF(seed); maskedSeed = seed xor MGF(maskedDB)
    let mut seed = [0u8; HASH_LEN];
    rng(&mut seed);
    let db_mask = mgf1_sha256(&seed, db_len);
    xor_in_place(&mut db, &db_mask);
    let seed_mask = mgf1_sha256(&db, HASH_LEN);
    xor_in_place(&mut seed, &seed_mask);

    // EM = 0x00 || maskedSeed || maskedDB; the leading zero byte comes from `vec!`.
    let mut em = vec![0u8; k];
    em[1..1 + HASH_LEN].copy_from_slice(&seed);
    em[1 + HASH_LEN..].copy_from_slice(&db);

    let m = BigInt::from_be_bytes(&em);
    let c = rsa_encrypt_raw(pk, &m);
    c.to_be_bytes(k)
}
/// Decrypts an RSAES-OAEP ciphertext (SHA-256 / MGF1-SHA-256, RFC 8017 §7.1.2).
///
/// `label` must equal the label that was passed to `oaep_encrypt`.
///
/// Returns `Some(message)` on success. Returns `None` in three cases:
/// * `ct` is not exactly `k` bytes, where `k` is the modulus length;
/// * the modulus is too small for OAEP with SHA-256;
/// * the decoded block is not a valid OAEP encoding for `label`.
///
/// All padding checks (leading zero byte, label hash, zero padding, 0x01
/// separator) are folded into a single accumulator. The data block is always
/// scanned to its end. This avoids the early returns and data-dependent loop
/// exits that reveal *which* check failed, a timing signal usable in
/// Manger-style chosen-ciphertext attacks. This is best effort: Rust does not
/// guarantee constant-time code generation.
pub fn oaep_decrypt(sk: &RsaSecretKey, ct: &[u8], label: &[u8]) -> Option<Vec<u8>> {
    let k = sk.modulus_byte_len();
    // Both conditions depend only on public lengths, so an early exit leaks nothing secret.
    if ct.len() != k || k < 2 * HASH_LEN + 2 {
        return None;
    }
    // NOTE(review): RFC 8017 also rejects ciphertext representatives >= n.
    // That range check is not performed here; confirm how rsa_decrypt_raw handles such input.
    let c = BigInt::from_be_bytes(ct);
    let m = rsa_decrypt_raw(sk, &c);
    let em = m.to_be_bytes(k);

    // EM = Y || maskedSeed || maskedDB
    let masked_seed = &em[1..1 + HASH_LEN];
    let masked_db = &em[1 + HASH_LEN..];
    let seed_mask = mgf1_sha256(masked_db, HASH_LEN);
    let mut seed = [0u8; HASH_LEN];
    seed.copy_from_slice(masked_seed);
    xor_in_place(&mut seed, &seed_mask);

    let db_len = k - HASH_LEN - 1;
    let db_mask = mgf1_sha256(&seed, db_len);
    let mut db = masked_db.to_vec();
    xor_in_place(&mut db, &db_mask);

    let l_hash = sha256(label);
    let sep = oaep_check_db(em[0], &db, &l_hash)?;
    Some(db[sep + 1..].to_vec())
}

/// Returns `0xFF` when `a == b` and `0x00` otherwise, without branching on the values.
fn ct_eq_mask(a: u8, b: u8) -> u8 {
    let diff = u16::from(a ^ b);
    // `diff - 1` borrows into the high byte only when `diff == 0`.
    (diff.wrapping_sub(1) >> 8) as u8
}

/// Validates an unmasked OAEP data block `lHash' || PS || 0x01 || M`.
///
/// `y` is the first byte of EM and must be zero. Returns the index of the 0x01
/// separator in `db`, or `None` if any check fails. Every byte of `db` is
/// examined regardless of where (or whether) a failure occurs.
fn oaep_check_db(y: u8, db: &[u8], l_hash: &[u8; HASH_LEN]) -> Option<usize> {
    // Any nonzero bit in `bad` marks the encoding as invalid.
    let mut bad = y;
    for (a, b) in db[..HASH_LEN].iter().zip(l_hash) {
        bad |= a ^ b;
    }

    // `found` becomes 0xFF once the first 0x01 after lHash has been seen.
    let mut found: u8 = 0;
    let mut sep: usize = 0;
    for (i, &byte) in db.iter().enumerate().skip(HASH_LEN) {
        let is_one = ct_eq_mask(byte, 0x01);
        let is_zero = ct_eq_mask(byte, 0x00);
        let first_one = is_one & !found;
        // Widen the 0x00/0xFF mask to all-zero/all-one usize and record `i` only
        // for the first separator.
        let pick = usize::from(first_one & 1).wrapping_neg();
        sep = (sep & !pick) | (i & pick);
        // Before the separator, every byte must be zero padding.
        bad |= !found & !is_one & !is_zero;
        found |= is_one;
    }
    // A block with no separator at all is invalid.
    bad |= !found;

    if bad == 0 {
        Some(sep)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic LCG byte source: reproducible across runs, NOT
    /// cryptographically secure. Test use only.
    fn test_rng() -> impl FnMut(&mut [u8]) {
        let mut state: u64 = 0xdeadbeefcafebabe;
        move |buf: &mut [u8]| {
            for b in buf.iter_mut() {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                *b = (state >> 33) as u8;
            }
        }
    }

    #[test]
    fn test_mgf1() {
        let mask1 = mgf1_sha256(b"seed", 64);
        let mask2 = mgf1_sha256(b"seed", 64);
        assert_eq!(mask1.len(), 64);
        assert_eq!(mask1, mask2);
        let mask3 = mgf1_sha256(b"other", 64);
        assert_ne!(mask1, mask3);

        // A zero-length mask is empty.
        assert!(mgf1_sha256(b"seed", 0).is_empty());
        // Lengths that are not a multiple of the digest size are truncated from
        // the same stream, so a shorter mask is a prefix of a longer one.
        let short = mgf1_sha256(b"seed", 45);
        assert_eq!(short.len(), 45);
        assert_eq!(&short[..], &mask1[..45]);
    }

    #[test]
    fn test_oaep_encrypt_decrypt_roundtrip() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let msg = b"Hello, OAEP!";
        let ct = oaep_encrypt(&pk, msg, b"", &mut rng);
        let pt = oaep_decrypt(&sk, &ct, b"").expect("OAEP decryption failed");
        assert_eq!(&pt, msg);
    }

    #[test]
    fn test_oaep_wrong_label() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let msg = b"test";
        let ct = oaep_encrypt(&pk, msg, b"label_a", &mut rng);
        let result = oaep_decrypt(&sk, &ct, b"label_b");
        assert!(result.is_none(), "Decryption should fail with wrong label");
    }

    #[test]
    fn test_oaep_length_boundaries() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let k = pk.modulus_byte_len();
        let max_len = k - 2 * HASH_LEN - 2;
        // Empty message (maximal padding) and maximum-length message (empty PS).
        let cases: [Vec<u8>; 2] = [Vec::new(), vec![0xA5u8; max_len]];
        for msg in cases.iter() {
            let ct = oaep_encrypt(&pk, msg, b"ctx", &mut rng);
            assert_eq!(ct.len(), k);
            let pt = oaep_decrypt(&sk, &ct, b"ctx").expect("OAEP boundary decryption failed");
            assert_eq!(&pt, msg);
        }
    }

    #[test]
    fn test_oaep_rejects_malformed_ciphertext() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let ct = oaep_encrypt(&pk, b"tamper me", b"", &mut rng);

        // Wrong ciphertext lengths are rejected outright.
        assert!(oaep_decrypt(&sk, &ct[1..], b"").is_none());
        let mut longer = ct.clone();
        longer.push(0);
        assert!(oaep_decrypt(&sk, &longer, b"").is_none());

        // A single flipped bit must not decode to any message.
        let mut tampered = ct.clone();
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert!(oaep_decrypt(&sk, &tampered, b"").is_none());
    }
}