use log::debug;
use rand::rngs::OsRng;
use rsa::traits::{PrivateKeyParts, PublicKeyParts};
use rsa::{BigUint, Pkcs1v15Encrypt, Pkcs1v15Sign, RsaPrivateKey, RsaPublicKey};
/// Errors returned by the RSA helpers in this module.
///
/// Every variant carries a human-readable detail string, usually the text of the
/// underlying `rsa` crate error or a description of what was malformed.
#[derive(Debug)]
pub enum RsaError {
    /// Generating a new keypair failed.
    KeyGenerationFailed(String),
    /// Producing a signature failed.
    SigningFailed(String),
    /// Decrypting a ciphertext failed.
    DecryptionFailed(String),
    /// Key material was malformed, truncated, or internally inconsistent.
    InvalidKey(String),
    /// Input data was unsuitable for the requested operation (e.g. too large for the key).
    InvalidData(String),
}

impl std::fmt::Display for RsaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::KeyGenerationFailed(msg) => write!(f, "key generation failed: {msg}"),
            Self::SigningFailed(msg) => write!(f, "signing failed: {msg}"),
            Self::DecryptionFailed(msg) => write!(f, "decryption failed: {msg}"),
            Self::InvalidKey(msg) => write!(f, "invalid key: {msg}"),
            Self::InvalidData(msg) => write!(f, "invalid data: {msg}"),
        }
    }
}

// Lets `RsaError` flow through `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for RsaError {}
pub struct RsaOperations;
impl RsaOperations {
pub fn generate_keypair(bits: usize) -> Result<(Vec<u8>, Vec<u8>), RsaError> {
debug!("Generating RSA-{} keypair", bits);
let private_key = RsaPrivateKey::new(&mut OsRng, bits)
.map_err(|e| RsaError::KeyGenerationFailed(e.to_string()))?;
let public_key = RsaPublicKey::from(&private_key);
let private_data = Self::encode_private_key(&private_key)?;
let public_data = Self::encode_public_key(&public_key);
Ok((private_data, public_data))
}
fn encode_private_key(key: &RsaPrivateKey) -> Result<Vec<u8>, RsaError> {
let e = key.e().to_bytes_be();
let primes = key.primes();
if primes.len() < 2 {
return Err(RsaError::InvalidKey("Missing prime factors".to_string()));
}
let p = primes[0].to_bytes_be();
let q = primes[1].to_bytes_be();
let mut data = Vec::new();
data.extend_from_slice(&(e.len() as u16).to_be_bytes());
data.extend_from_slice(&e);
data.extend_from_slice(&(p.len() as u16).to_be_bytes());
data.extend_from_slice(&p);
data.extend_from_slice(&(q.len() as u16).to_be_bytes());
data.extend_from_slice(&q);
Ok(data)
}
fn encode_public_key(key: &RsaPublicKey) -> Vec<u8> {
let n = key.n().to_bytes_be();
let e = key.e().to_bytes_be();
let mut data = Vec::new();
data.extend_from_slice(&(n.len() as u16).to_be_bytes());
data.extend_from_slice(&n);
data.extend_from_slice(&(e.len() as u16).to_be_bytes());
data.extend_from_slice(&e);
data
}
pub fn decode_private_key(data: &[u8], n_bytes: &[u8]) -> Result<RsaPrivateKey, RsaError> {
if data.len() < 6 {
return Err(RsaError::InvalidKey("Data too short".to_string()));
}
let mut offset = 0;
let e_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
offset += 2;
if offset + e_len > data.len() {
return Err(RsaError::InvalidKey("Invalid e length".to_string()));
}
let e = BigUint::from_bytes_be(&data[offset..offset + e_len]);
offset += e_len;
if offset + 2 > data.len() {
return Err(RsaError::InvalidKey("Missing p length".to_string()));
}
let p_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
offset += 2;
if offset + p_len > data.len() {
return Err(RsaError::InvalidKey("Invalid p length".to_string()));
}
let p = BigUint::from_bytes_be(&data[offset..offset + p_len]);
offset += p_len;
if offset + 2 > data.len() {
return Err(RsaError::InvalidKey("Missing q length".to_string()));
}
let q_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
offset += 2;
if offset + q_len > data.len() {
return Err(RsaError::InvalidKey("Invalid q length".to_string()));
}
let q = BigUint::from_bytes_be(&data[offset..offset + q_len]);
let n = BigUint::from_bytes_be(n_bytes);
let one = BigUint::from(1u64);
let p_minus_1 = &p - &one;
let q_minus_1 = &q - &one;
let phi_n = &p_minus_1 * &q_minus_1;
let d = Self::mod_inverse(&e, &phi_n)
.ok_or_else(|| RsaError::InvalidKey("Cannot compute private exponent d".to_string()))?;
RsaPrivateKey::from_components(n, e, d, vec![p, q])
.map_err(|e| RsaError::InvalidKey(e.to_string()))
}
pub fn sign_pkcs1v15(private_key: &RsaPrivateKey, digest_info: &[u8]) -> Result<Vec<u8>, RsaError> {
let key_size = private_key.size();
if digest_info.len() + 11 > key_size {
return Err(RsaError::InvalidData("DigestInfo too large for key size".to_string()));
}
let padding_len = key_size - digest_info.len() - 3;
let mut padded = Vec::with_capacity(key_size);
padded.push(0x00);
padded.push(0x01);
padded.extend(std::iter::repeat_n(0xFF, padding_len));
padded.push(0x00);
padded.extend_from_slice(digest_info);
let m = BigUint::from_bytes_be(&padded);
let d = private_key.d();
let n = private_key.n();
let signature = m.modpow(d, n);
let mut sig_bytes = signature.to_bytes_be();
while sig_bytes.len() < key_size {
sig_bytes.insert(0, 0);
}
Ok(sig_bytes)
}
#[allow(dead_code)]
pub fn raw_sign(private_key: &RsaPrivateKey, padded_data: &[u8]) -> Result<Vec<u8>, RsaError> {
let m = BigUint::from_bytes_be(padded_data);
let d = private_key.d();
let n = private_key.n();
let signature = m.modpow(d, n);
let key_size = private_key.size().div_ceil(8);
let mut sig_bytes = signature.to_bytes_be();
while sig_bytes.len() < key_size {
sig_bytes.insert(0, 0);
}
Ok(sig_bytes)
}
pub fn decrypt(private_key: &RsaPrivateKey, ciphertext: &[u8]) -> Result<Vec<u8>, RsaError> {
private_key.decrypt(Pkcs1v15Encrypt, ciphertext)
.map_err(|e| RsaError::DecryptionFailed(e.to_string()))
}
pub fn get_modulus(public_key_data: &[u8]) -> Option<Vec<u8>> {
if public_key_data.len() < 4 {
return None;
}
let n_len = u16::from_be_bytes([public_key_data[0], public_key_data[1]]) as usize;
if public_key_data.len() < 2 + n_len {
return None;
}
Some(public_key_data[2..2 + n_len].to_vec())
}
pub fn get_exponent(public_key_data: &[u8]) -> Option<Vec<u8>> {
if public_key_data.len() < 4 {
return None;
}
let n_len = u16::from_be_bytes([public_key_data[0], public_key_data[1]]) as usize;
let e_offset = 2 + n_len;
if public_key_data.len() < e_offset + 2 {
return None;
}
let e_len = u16::from_be_bytes([public_key_data[e_offset], public_key_data[e_offset + 1]]) as usize;
if public_key_data.len() < e_offset + 2 + e_len {
return None;
}
Some(public_key_data[e_offset + 2..e_offset + 2 + e_len].to_vec())
}
fn mod_inverse(a: &BigUint, m: &BigUint) -> Option<BigUint> {
let one = BigUint::from(1u64);
let zero = BigUint::from(0u64);
let mut old_r = m.clone();
let mut r = a.clone();
let mut old_s = zero.clone();
let mut s = one.clone();
let mut old_s_neg = false;
let mut s_neg = false;
while r != zero {
let quotient = &old_r / &r;
let temp_r = old_r;
old_r = r.clone();
r = temp_r - "ient * &r;
let (new_s, new_s_neg) = {
let qs = "ient * &s;
if old_s_neg == s_neg {
if old_s >= qs {
(old_s.clone() - &qs, old_s_neg)
} else {
(qs - &old_s, !old_s_neg)
}
} else {
(old_s.clone() + &qs, old_s_neg)
}
};
old_s = s;
old_s_neg = s_neg;
s = new_s;
s_neg = new_s_neg;
}
if old_r != one {
return None;
}
let result = if old_s_neg {
m - &old_s
} else {
old_s
};
Some(result)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::OsRng;
    use rsa::traits::PublicKeyParts;
    use rsa::{Pkcs1v15Encrypt, Pkcs1v15Sign, RsaPrivateKey, RsaPublicKey};

    /// Generates a small (fast) keypair and rebuilds the private key from its
    /// encoding plus the modulus taken from the public data.
    fn decoded_keypair() -> (RsaPrivateKey, Vec<u8>) {
        let (private, public) = RsaOperations::generate_keypair(1024).unwrap();
        let n = RsaOperations::get_modulus(&public).unwrap();
        let key = RsaOperations::decode_private_key(&private, &n).unwrap();
        (key, public)
    }

    #[test]
    fn test_generate_keypair() {
        let (private, public) = RsaOperations::generate_keypair(2048).unwrap();
        assert!(!private.is_empty());
        assert!(!public.is_empty());
    }

    #[test]
    fn test_get_modulus_exponent() {
        let (_, public) = RsaOperations::generate_keypair(2048).unwrap();
        let n = RsaOperations::get_modulus(&public).unwrap();
        let e = RsaOperations::get_exponent(&public).unwrap();
        assert_eq!(n.len(), 256);
        assert!(!e.is_empty());
    }

    #[test]
    fn test_accessors_reject_truncated_data() {
        assert_eq!(RsaOperations::get_modulus(&[0, 1, 7]), None);
        assert_eq!(RsaOperations::get_modulus(&[0, 1, 7, 0]), Some(vec![7]));
        assert_eq!(RsaOperations::get_exponent(&[0, 1, 7, 0]), None);
        assert_eq!(RsaOperations::get_exponent(&[0, 1, 7, 0, 1, 3]), Some(vec![3]));
        assert_eq!(RsaOperations::get_exponent(&[0, 1, 7, 0, 2, 3]), None);
    }

    #[test]
    fn test_decode_rejects_truncated_data() {
        assert!(matches!(
            RsaOperations::decode_private_key(&[0; 5], &[]),
            Err(RsaError::InvalidKey(_))
        ));
        // `e` claims 10 bytes but only 4 follow the prefix.
        assert!(matches!(
            RsaOperations::decode_private_key(&[0, 10, 1, 2, 3, 4], &[]),
            Err(RsaError::InvalidKey(_))
        ));
    }

    #[test]
    fn test_decode_private_key_round_trip() {
        let (key, public) = decoded_keypair();
        assert_eq!(key.n().to_bytes_be(), RsaOperations::get_modulus(&public).unwrap());
        assert_eq!(key.e().to_bytes_be(), RsaOperations::get_exponent(&public).unwrap());
        let ciphertext = RsaPublicKey::from(&key)
            .encrypt(&mut OsRng, Pkcs1v15Encrypt, b"round trip")
            .unwrap();
        assert_eq!(RsaOperations::decrypt(&key, &ciphertext).unwrap(), b"round trip");
    }

    #[test]
    fn test_sign_pkcs1v15() {
        let (key, _) = decoded_keypair();
        // Arbitrary bytes: the signer treats DigestInfo as opaque.
        let digest_info = [0x30u8; 51];
        let signature = RsaOperations::sign_pkcs1v15(&key, &digest_info).unwrap();
        assert_eq!(signature.len(), key.size());
        RsaPublicKey::from(&key)
            .verify(Pkcs1v15Sign::new_unprefixed(), &digest_info, &signature)
            .unwrap();

        let too_big = vec![0u8; key.size() - 10];
        assert!(matches!(
            RsaOperations::sign_pkcs1v15(&key, &too_big),
            Err(RsaError::InvalidData(_))
        ));
    }
}