use crate::utils::{from_base64, from_hex, to_base64, to_hex};
use anyhow::Result;
use rand::Rng;
/// A public/private key pair as returned by `generate_key_pair`, which builds both halves
/// from the same modulus `n`.
#[derive(Clone, Debug)]
pub struct RSAKeyPair {
    /// Public half `(n, e)`; safe to share.
    pub public_key: RSAPublicKey,
    /// Private half `(n, d)`; must be kept secret.
    pub private_key: RSAPrivateKey,
}
/// RSA public key: modulus `n` and public exponent `e`.
#[derive(Clone, Debug)]
pub struct RSAPublicKey {
    /// Modulus (`p * q`).
    pub n: u64,
    /// Public exponent used for encryption and signature verification.
    pub e: u64,
}
/// RSA private key: modulus `n` and private exponent `d`.
#[derive(Clone, Debug)]
pub struct RSAPrivateKey {
    /// Modulus (`p * q`), shared with the matching public key.
    pub n: u64,
    /// Private exponent used for decryption and signing.
    pub d: u64,
}
/// Result of `rsa_encrypt`: one ciphertext word per plaintext block, plus the key used.
#[derive(Debug)]
pub struct RSAEncryptedData {
    /// Encrypted blocks, in plaintext order.
    pub ciphertext: Vec<u64>,
    /// Copy of the public key the data was encrypted with.
    // Not read anywhere in this module (`encrypt` only consumes `ciphertext`), hence the
    // allow. NOTE(review): confirm whether any external caller relies on this field.
    #[allow(dead_code)]
    pub public_key: RSAPublicKey,
}
pub fn encrypt(
data: &str,
key_size_or_pem: &str,
encoding: &str,
privkey_format: &str,
) -> Result<(String, String)> {
let is_pem = key_size_or_pem
.trim_start()
.starts_with("-----BEGIN RSA PUBLIC KEY-----");
let (public_key, private_key_string) = if is_pem {
let lines: Vec<&str> = key_size_or_pem
.lines()
.map(str::trim)
.filter(|l| !l.is_empty())
.collect();
let start = lines
.iter()
.position(|l| l.starts_with("-----BEGIN RSA PUBLIC KEY-----"));
let end = lines
.iter()
.position(|l| l.starts_with("-----END RSA PUBLIC KEY-----"));
if let (Some(start), Some(end)) = (start, end) {
let b64 = lines[start + 1..end].join("");
let decoded =
from_base64(&b64).map_err(|_| anyhow::anyhow!("Invalid base64 in PEM"))?;
let key_str = String::from_utf8(decoded)
.map_err(|_| anyhow::anyhow!("Invalid UTF-8 in PEM key data"))?;
let parts: Vec<&str> = key_str.split(':').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("Invalid PEM public key format"));
}
let n = parts[0]
.parse::<u64>()
.map_err(|_| anyhow::anyhow!("Invalid modulus in PEM public key"))?;
let e = parts[1]
.parse::<u64>()
.map_err(|_| anyhow::anyhow!("Invalid exponent in PEM public key"))?;
(RSAPublicKey { n, e }, String::new())
} else {
return Err(anyhow::anyhow!("Invalid PEM format for RSA public key"));
}
} else {
let key_bits: u32 = key_size_or_pem
.parse()
.map_err(|_| anyhow::anyhow!("Invalid key size: {}", key_size_or_pem))?;
if ![512, 1024, 2048].contains(&key_bits) {
return Err(anyhow::anyhow!("Key size must be 512, 1024, or 2048 bits"));
}
println!("\n🔑 Generating RSA key pair...");
let key_pair = generate_key_pair(key_bits)?;
println!("✅ Key pair generated successfully!");
println!(
"📋 Public Key (n={}, e={})",
key_pair.public_key.n, key_pair.public_key.e
);
let priv_str = match privkey_format.to_lowercase().as_str() {
"pem" => export_private_key_pem(&key_pair.private_key),
_ => format!("{}:{}", key_pair.private_key.n, key_pair.private_key.d),
};
(key_pair.public_key, priv_str)
};
let encrypted_data = rsa_encrypt(data.as_bytes(), &public_key)?;
let encrypted_string = match encoding.to_lowercase().as_str() {
"base64" => {
let bytes: Vec<u8> = encrypted_data
.ciphertext
.iter()
.flat_map(|&num| num.to_be_bytes())
.collect();
to_base64(&bytes)
}
"hex" => {
let bytes: Vec<u8> = encrypted_data
.ciphertext
.iter()
.flat_map(|&num| num.to_be_bytes())
.collect();
to_hex(&bytes)
}
_ => {
return Err(anyhow::anyhow!(
"Unsupported encoding: {}. Use 'base64' or 'hex'",
encoding
))
}
};
Ok((encrypted_string, private_key_string))
}
pub fn decrypt(data: &str, private_key_str: &str, encoding: &str) -> Result<String> {
let private_key = if private_key_str
.trim_start()
.starts_with("-----BEGIN RSA PRIVATE KEY-----")
{
import_private_key_pem(private_key_str)?
} else {
parse_private_key(private_key_str)?
};
let ciphertext_nums = match encoding.to_lowercase().as_str() {
"base64" => {
let bytes = from_base64(data)?;
if !bytes.len().is_multiple_of(8) {
return Err(anyhow::anyhow!("Invalid base64 data length"));
}
bytes
.chunks_exact(8)
.map(|chunk| {
u64::from_be_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
])
})
.collect::<Vec<u64>>()
}
"hex" => {
let bytes = from_hex(data)?;
if !bytes.len().is_multiple_of(8) {
return Err(anyhow::anyhow!("Invalid hex data length"));
}
bytes
.chunks_exact(8)
.map(|chunk| {
u64::from_be_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
])
})
.collect()
}
_ => {
return Err(anyhow::anyhow!(
"Unsupported encoding: {}. Use 'base64' or 'hex'",
encoding
))
}
};
let decrypted_bytes = rsa_decrypt(&ciphertext_nums, &private_key)?;
String::from_utf8(decrypted_bytes)
.map_err(|e| anyhow::anyhow!("Invalid UTF-8 in decrypted data: {}", e))
}
pub fn export_public_key_ne(public_key: &RSAPublicKey) -> String {
format!("{}:{}", public_key.n, public_key.e)
}
/// Base64-encodes the `n:e` text and wraps it between RSA PUBLIC KEY PEM markers.
///
/// This is a project-specific container, not a PKCS#1 DER structure.
pub fn export_public_key_pem(public_key: &RSAPublicKey) -> String {
    let encoded = to_base64(format!("{}:{}", public_key.n, public_key.e).as_bytes());
    [
        "-----BEGIN RSA PUBLIC KEY-----",
        encoded.as_str(),
        "-----END RSA PUBLIC KEY-----",
    ]
    .join("\n")
}
/// Base64-encodes the `n:d` text and wraps it between RSA PRIVATE KEY PEM markers.
///
/// This is a project-specific container, not a PKCS#1 DER structure.
pub fn export_private_key_pem(private_key: &RSAPrivateKey) -> String {
    let encoded = to_base64(format!("{}:{}", private_key.n, private_key.d).as_bytes());
    [
        "-----BEGIN RSA PRIVATE KEY-----",
        encoded.as_str(),
        "-----END RSA PRIVATE KEY-----",
    ]
    .join("\n")
}
/// Parses a private key written by `export_private_key_pem`.
///
/// Blank lines and surrounding whitespace are ignored. The base64 lines between the
/// `-----BEGIN RSA PRIVATE KEY-----` marker and the first following
/// `-----END RSA PRIVATE KEY-----` marker are joined, decoded, and parsed as `n:d`.
///
/// # Errors
/// Missing or misordered markers, invalid base64, non-UTF-8 key data, or a malformed `n:d`
/// payload.
pub fn import_private_key_pem(pem: &str) -> Result<RSAPrivateKey> {
    const BEGIN: &str = "-----BEGIN RSA PRIVATE KEY-----";
    const END: &str = "-----END RSA PRIVATE KEY-----";
    let lines: Vec<&str> = pem
        .lines()
        .map(str::trim)
        .filter(|l| !l.is_empty())
        .collect();
    let Some(start) = lines.iter().position(|l| l.starts_with(BEGIN)) else {
        return Err(anyhow::anyhow!("Invalid PEM format for RSA private key"));
    };
    // Search for END only after BEGIN. An END line that precedes BEGIN must be rejected,
    // not turned into an inverted slice range that panics.
    let rest = &lines[start + 1..];
    let Some(body_len) = rest.iter().position(|l| l.starts_with(END)) else {
        return Err(anyhow::anyhow!("Invalid PEM format for RSA private key"));
    };
    let b64 = rest[..body_len].join("");
    let decoded = from_base64(&b64).map_err(|_| anyhow::anyhow!("Invalid base64 in PEM"))?;
    let key_str = String::from_utf8(decoded)
        .map_err(|_| anyhow::anyhow!("Invalid UTF-8 in PEM key data"))?;
    parse_private_key(&key_str)
}
/// Generates a textbook-RSA key pair for one of the supported size labels.
///
/// The labels do not describe the real modulus size. `512`, `1024` and `2048` select primes
/// of 8, 16 and 20 bits respectively, so `n = p * q` is at most 16, 32 or 40 bits. These
/// keys are for demonstration only and provide no security.
///
/// Small prime ranges make `p == q` collisions common (there are only 23 primes of 8 bits).
/// For that reason `q` is redrawn, up to `MAX_Q_REDRAWS` times, instead of failing on the
/// first collision.
///
/// # Errors
/// Unsupported `key_size`, prime generation failure, `p` and `q` still identical after all
/// redraws, or no usable public exponent / modular inverse.
pub fn generate_key_pair(key_size: u32) -> Result<RSAKeyPair> {
    const MAX_Q_REDRAWS: usize = 100;
    let bit_size = match key_size {
        512 => 8,
        1024 => 16,
        2048 => 20,
        _ => return Err(anyhow::anyhow!("Unsupported key size")),
    };
    let p = generate_prime(bit_size)?;
    let mut q = generate_prime(bit_size)?;
    for _ in 0..MAX_Q_REDRAWS {
        if q != p {
            break;
        }
        q = generate_prime(bit_size)?;
    }
    if p == q {
        return Err(anyhow::anyhow!("Generated primes are identical, try again"));
    }
    let n = p * q;
    let phi = (p - 1) * (q - 1);
    let e = find_coprime(phi)?;
    let d = mod_inverse(e, phi)?;
    Ok(RSAKeyPair {
        public_key: RSAPublicKey { n, e },
        private_key: RSAPrivateKey { n, d },
    })
}
pub fn keygen_and_export(key_size: u32, format: &str) -> Result<(String, String)> {
let key_pair = generate_key_pair(key_size)?;
let (pub_str, priv_str) = match format.to_lowercase().as_str() {
"n:e" => (
export_public_key_ne(&key_pair.public_key),
format!("{}:{}", key_pair.private_key.n, key_pair.private_key.d),
),
"pem" => (
export_public_key_pem(&key_pair.public_key),
export_private_key_pem(&key_pair.private_key),
),
_ => return Err(anyhow::anyhow!("Unsupported key output format")),
};
Ok((pub_str, priv_str))
}
/// Encrypts `plaintext` block by block with textbook RSA (`c = m^e mod n`).
///
/// The input is split into `calculate_block_size(n)`-byte blocks. Each block is read as a
/// big-endian integer `m` and encrypted into one `u64`. No padding scheme is applied.
///
/// # Errors
/// Fails on the first block whose integer value is not below `n`.
pub fn rsa_encrypt(plaintext: &[u8], public_key: &RSAPublicKey) -> Result<RSAEncryptedData> {
    let (n, e) = (public_key.n, public_key.e);
    let width = calculate_block_size(n);
    let ciphertext = plaintext
        .chunks(width)
        .map(|block| {
            let m = block
                .iter()
                .fold(0u64, |acc, &byte| acc * 256 + u64::from(byte));
            if m < n {
                Ok(mod_pow(m, e, n))
            } else {
                Err(anyhow::anyhow!("Message block too large for key"))
            }
        })
        .collect::<Result<Vec<u64>>>()?;
    Ok(RSAEncryptedData {
        ciphertext,
        public_key: public_key.clone(),
    })
}
pub fn rsa_decrypt(ciphertext: &[u64], private_key: &RSAPrivateKey) -> Result<Vec<u8>> {
let mut plaintext = Vec::new();
let block_size = calculate_block_size(private_key.n);
for &c in ciphertext {
let m = mod_pow(c, private_key.d, private_key.n);
let mut bytes = Vec::new();
let mut temp = m;
if temp == 0 {
bytes.push(0);
} else {
while temp > 0 {
bytes.push((temp % 256) as u8);
temp /= 256;
}
bytes.reverse();
}
while bytes.len() < block_size && !bytes.is_empty() && temp != 0 {
bytes.insert(0, 0);
}
plaintext.extend(bytes);
}
while plaintext.first() == Some(&0) && plaintext.len() > 1 {
plaintext.remove(0);
}
while plaintext.last() == Some(&0) && plaintext.len() > 1 {
plaintext.pop();
}
Ok(plaintext)
}
/// Parses a private key written as `n:d`.
///
/// Whitespace around the whole string and around each field is ignored, so keys read
/// from files with a trailing newline are accepted.
///
/// # Errors
/// Wrong field count, a non-numeric field, or a zero modulus. A zero modulus would make
/// every later modular operation divide by zero.
fn parse_private_key(key_str: &str) -> Result<RSAPrivateKey> {
    let parts: Vec<&str> = key_str.trim().split(':').map(str::trim).collect();
    let [n_part, d_part] = parts.as_slice() else {
        return Err(anyhow::anyhow!(
            "Invalid private key format. Expected 'n:d'"
        ));
    };
    let n = n_part
        .parse::<u64>()
        .map_err(|_| anyhow::anyhow!("Invalid modulus in private key"))?;
    if n == 0 {
        return Err(anyhow::anyhow!("Invalid modulus in private key"));
    }
    let d = d_part
        .parse::<u64>()
        .map_err(|_| anyhow::anyhow!("Invalid private exponent in private key"))?;
    Ok(RSAPrivateKey { n, d })
}
/// Number of plaintext bytes packed into one RSA block for modulus `n`.
///
/// Takes the count of whole bytes in `n`'s bit length. When that count exceeds 2, one byte
/// is subtracted so the block integer stays below `n`. Otherwise the width is 1.
fn calculate_block_size(n: u64) -> usize {
    let whole_bytes = (u64::BITS - n.leading_zeros()) / 8;
    match whole_bytes {
        0..=2 => 1,
        more => more as usize - 1,
    }
}
/// Draws a random odd prime with exactly `bit_size` bits (top bit set).
///
/// Candidates are uniform in `[2^(bit_size-1), 2^bit_size - 1]` with the low bit forced on,
/// and are tested by trial division (`is_prime`). That is only practical for small sizes;
/// this module uses at most 20 bits.
///
/// # Errors
/// `bit_size` outside `2..=63`, where the bounds would overflow `u64` or the range holds
/// no odd prime, or no prime found within 10000 draws.
fn generate_prime(bit_size: u32) -> Result<u64> {
    const MAX_ATTEMPTS: u32 = 10_000;
    if !(2..=63).contains(&bit_size) {
        return Err(anyhow::anyhow!("Unsupported prime bit size: {}", bit_size));
    }
    let mut rng = rand::rng();
    let min = 1u64 << (bit_size - 1);
    let max = (1u64 << bit_size) - 1;
    for _ in 0..MAX_ATTEMPTS {
        // The range stays fixed. Shrinking it per attempt made it empty, and `random_range`
        // panics on an empty range. `max` is odd (2^k - 1), so setting the low bit keeps the
        // candidate inside [min, max] and skips even numbers without wasting an attempt.
        let candidate = rng.random_range(min..=max) | 1;
        if is_prime(candidate) {
            return Ok(candidate);
        }
    }
    Err(anyhow::anyhow!(
        "Failed to generate prime after 10000 attempts"
    ))
}
/// Deterministic primality test by trial division with odd divisors up to `sqrt(n) + 1`.
fn is_prime(n: u64) -> bool {
    match n {
        0 | 1 => false,
        2 => true,
        _ if n.is_multiple_of(2) => false,
        _ => {
            let limit = (n as f64).sqrt() as u64 + 1;
            !(3..=limit).step_by(2).any(|divisor| n.is_multiple_of(divisor))
        }
    }
}
/// Picks a public exponent coprime to `phi`.
///
/// Tries the conventional exponents 3, 17, 257 and 65537 (those below `phi`) first, then
/// scans upward from 3.
///
/// # Errors
/// No value in the candidate sequence is coprime to `phi`.
fn find_coprime(phi: u64) -> Result<u64> {
    const PREFERRED: [u64; 4] = [3, 17, 257, 65537];
    PREFERRED
        .into_iter()
        .filter(|&candidate| candidate < phi)
        .chain(3..phi)
        .find(|&candidate| gcd(candidate, phi) == 1)
        .ok_or_else(|| anyhow::anyhow!("Could not find coprime to phi"))
}
/// Greatest common divisor by the Euclidean algorithm; `gcd(x, 0) == x`.
fn gcd(a: u64, b: u64) -> u64 {
    if b == 0 {
        a
    } else {
        gcd(b, a % b)
    }
}
/// Extended Euclid: returns `(g, x, y)` with `g = gcd(a, b)` and `a*x + b*y == g`.
///
/// Recursive on `(b % a, a)`; intended for non-negative inputs.
fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
    match a {
        0 => (b, 0, 1),
        _ => {
            let (g, s, t) = extended_gcd(b % a, a);
            (g, t - (b / a) * s, s)
        }
    }
}
/// Multiplicative inverse of `a` modulo `m`, normalised into `[0, m)`.
///
/// Both arguments are cast to `i64`. NOTE(review): this assumes `a, m <= i64::MAX`, which
/// holds for the small `phi` values produced by `generate_key_pair`.
///
/// # Errors
/// `gcd(a, m) != 1`, so no inverse exists.
fn mod_inverse(a: u64, m: u64) -> Result<u64> {
    let modulus = m as i64;
    let (g, coeff, _) = extended_gcd(a as i64, modulus);
    match g {
        1 => Ok((((coeff % modulus) + modulus) % modulus) as u64),
        _ => Err(anyhow::anyhow!("Modular inverse does not exist")),
    }
}
/// Computes `base^exp mod modulus` by square-and-multiply.
///
/// Intermediates are held in `u128`, so any `u64` modulus is safe from overflow. Returns 0
/// when `modulus == 1`; panics (division by zero) when `modulus == 0`.
fn mod_pow(base: u64, exp: u64, modulus: u64) -> u64 {
    if modulus == 1 {
        return 0;
    }
    let m = u128::from(modulus);
    let mut acc: u128 = 1;
    let mut square = u128::from(base) % m;
    let mut remaining = exp;
    while remaining != 0 {
        if remaining & 1 == 1 {
            acc = acc * square % m;
        }
        square = square * square % m;
        remaining >>= 1;
    }
    acc as u64
}
/// Signs `data` with textbook RSA and returns the encoded signature.
///
/// `private_key_str` is a `-----BEGIN RSA PRIVATE KEY-----` PEM block or a bare `n:d`
/// pair. Each signature word is serialised as 8 big-endian bytes, then encoded as
/// `base64` or `hex` per `encoding` (case-insensitive).
///
/// # Errors
/// Malformed key, a message block too large for the key, or an unsupported encoding.
pub fn sign(data: &str, private_key_str: &str, encoding: &str) -> Result<String> {
    let pem_input = private_key_str
        .trim_start()
        .starts_with("-----BEGIN RSA PRIVATE KEY-----");
    let private_key = if pem_input {
        import_private_key_pem(private_key_str)?
    } else {
        parse_private_key(private_key_str)?
    };
    let words = rsa_sign(data.as_bytes(), &private_key)?;
    let serialized: Vec<u8> = words.iter().flat_map(|word| word.to_be_bytes()).collect();
    match encoding.to_lowercase().as_str() {
        "base64" => Ok(to_base64(&serialized)),
        "hex" => Ok(to_hex(&serialized)),
        _ => Err(anyhow::anyhow!(
            "Unsupported encoding: {}. Use 'base64' or 'hex'",
            encoding
        )),
    }
}
pub fn verify(data: &str, signature: &str, public_key_str: &str, encoding: &str) -> Result<bool> {
let public_key = if public_key_str
.trim_start()
.starts_with("-----BEGIN RSA PUBLIC KEY-----")
{
let lines: Vec<&str> = public_key_str
.lines()
.map(str::trim)
.filter(|l| !l.is_empty())
.collect();
let start = lines
.iter()
.position(|l| l.starts_with("-----BEGIN RSA PUBLIC KEY-----"));
let end = lines
.iter()
.position(|l| l.starts_with("-----END RSA PUBLIC KEY-----"));
if let (Some(start), Some(end)) = (start, end) {
let b64 = lines[start + 1..end].join("");
let decoded =
from_base64(&b64).map_err(|_| anyhow::anyhow!("Invalid base64 in PEM"))?;
let key_str = String::from_utf8(decoded)
.map_err(|_| anyhow::anyhow!("Invalid UTF-8 in PEM key data"))?;
let parts: Vec<&str> = key_str.split(':').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("Invalid PEM public key format"));
}
let n = parts[0]
.parse::<u64>()
.map_err(|_| anyhow::anyhow!("Invalid modulus in PEM public key"))?;
let e = parts[1]
.parse::<u64>()
.map_err(|_| anyhow::anyhow!("Invalid exponent in PEM public key"))?;
RSAPublicKey { n, e }
} else {
return Err(anyhow::anyhow!("Invalid PEM format for RSA public key"));
}
} else {
let parts: Vec<&str> = public_key_str.split(':').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("Invalid public key format. Expected 'n:e'"));
}
let n = parts[0]
.parse::<u64>()
.map_err(|_| anyhow::anyhow!("Invalid modulus in public key"))?;
let e = parts[1]
.parse::<u64>()
.map_err(|_| anyhow::anyhow!("Invalid exponent in public key"))?;
RSAPublicKey { n, e }
};
let signature_nums = match encoding.to_lowercase().as_str() {
"base64" => {
let bytes = from_base64(signature)?;
if !bytes.len().is_multiple_of(8) {
return Err(anyhow::anyhow!("Invalid base64 signature length"));
}
bytes
.chunks_exact(8)
.map(|chunk| {
u64::from_be_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
])
})
.collect::<Vec<u64>>()
}
"hex" => {
let bytes = from_hex(signature)?;
if !bytes.len().is_multiple_of(8) {
return Err(anyhow::anyhow!("Invalid hex signature length"));
}
bytes
.chunks_exact(8)
.map(|chunk| {
u64::from_be_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
])
})
.collect()
}
_ => {
return Err(anyhow::anyhow!(
"Unsupported encoding: {}. Use 'base64' or 'hex'",
encoding
))
}
};
rsa_verify(data.as_bytes(), &signature_nums, &public_key)
}
/// Signs `message` block by block with textbook RSA (`s = m^d mod n`).
///
/// Blocks are `calculate_block_size(n)` bytes, read as big-endian integers; one signature
/// word is produced per block.
///
/// # Errors
/// Fails on the first block whose integer value is not below `n`.
pub fn rsa_sign(message: &[u8], private_key: &RSAPrivateKey) -> Result<Vec<u64>> {
    let (n, d) = (private_key.n, private_key.d);
    let width = calculate_block_size(n);
    message
        .chunks(width)
        .map(|block| {
            let m = block
                .iter()
                .fold(0u64, |acc, &byte| acc * 256 + u64::from(byte));
            if m >= n {
                return Err(anyhow::anyhow!("Message block too large for key"));
            }
            Ok(mod_pow(m, d, n))
        })
        .collect()
}
pub fn rsa_verify(message: &[u8], signature: &[u64], public_key: &RSAPublicKey) -> Result<bool> {
let mut recovered_message = Vec::new();
let block_size = calculate_block_size(public_key.n);
for &s in signature {
let m = mod_pow(s, public_key.e, public_key.n);
let mut bytes = Vec::new();
let mut temp = m;
if temp == 0 {
bytes.push(0);
} else {
while temp > 0 {
bytes.push((temp % 256) as u8);
temp /= 256;
}
bytes.reverse();
}
while bytes.len() < block_size && !bytes.is_empty() && m != 0 {
bytes.insert(0, 0);
}
recovered_message.extend(bytes);
}
while recovered_message.first() == Some(&0) && recovered_message.len() > 1 {
recovered_message.remove(0);
}
while recovered_message.last() == Some(&0) && recovered_message.len() > 1 {
recovered_message.pop();
}
Ok(recovered_message == message)
}