use crate::error::{FluxError, Result};
use crate::keys::PublicKey;
use rsa::RsaPrivateKey;
use rsa::traits::{PrivateKeyParts, PublicKeyParts};
use zeroize::ZeroizeOnDrop;
#[derive(Clone, ZeroizeOnDrop)]
pub struct PrivateKey {
inner: RsaPrivateKey,
}
impl std::fmt::Debug for PrivateKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PrivateKey")
.field("key_size", &(self.inner.size() * 8))
.field("_key_data", &"[REDACTED]")
.finish()
}
}
impl PrivateKey {
pub fn new(inner: RsaPrivateKey) -> Self {
Self { inner }
}
pub fn key_size_bits(&self) -> usize {
self.inner.size() * 8
}
pub fn key_size_bytes(&self) -> usize {
self.inner.size()
}
pub fn inner(&self) -> &RsaPrivateKey {
&self.inner
}
pub fn to_pem(&self) -> Result<String> {
use rsa::pkcs1::EncodeRsaPrivateKey;
self.inner
.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF)
.map(|zeroizing_string| zeroizing_string.to_string())
.map_err(|e| FluxError::crypto(format!("Failed to encode private key as PEM: {}", e)))
}
pub fn to_encrypted_pem(&self, password: &str) -> Result<String> {
use pkcs8::PrivateKeyInfo;
use rand::rngs::OsRng;
use rsa::pkcs8::EncodePrivateKey;
let private_key_der = self.inner.to_pkcs8_der().map_err(|e| {
FluxError::crypto(format!("Failed to encode private key as PKCS#8 DER: {}", e))
})?;
let private_key_info = PrivateKeyInfo::try_from(private_key_der.as_bytes())
.map_err(|e| FluxError::crypto(format!("Failed to parse PKCS#8 DER: {}", e)))?;
let encrypted_private_key = private_key_info
.encrypt(&mut OsRng, password)
.map_err(|e| FluxError::crypto(format!("Failed to encrypt private key: {}", e)))?;
let pem_string = encrypted_private_key
.to_pem("ENCRYPTED PRIVATE KEY", pkcs8::LineEnding::LF)
.map_err(|e| {
FluxError::crypto(format!(
"Failed to encode encrypted private key as PEM: {}",
e
))
})?;
Ok(pem_string.to_string())
}
pub fn to_der(&self) -> Result<Vec<u8>> {
use rsa::pkcs1::EncodeRsaPrivateKey;
self.inner
.to_pkcs1_der()
.map(|der| der.as_bytes().to_vec())
.map_err(|e| FluxError::crypto(format!("Failed to encode private key as DER: {}", e)))
}
pub fn public_key(&self) -> Result<PublicKey> {
Ok(PublicKey::new(self.inner.to_public_key()))
}
pub fn modulus(&self) -> Vec<u8> {
self.inner.n().to_bytes_be()
}
pub fn private_exponent(&self) -> Vec<u8> {
self.inner.d().to_bytes_be()
}
pub fn prime1(&self) -> Vec<u8> {
self.inner.primes()[0].to_bytes_be()
}
pub fn prime2(&self) -> Vec<u8> {
self.inner.primes()[1].to_bytes_be()
}
pub fn crt_coefficient(&self) -> Vec<u8> {
let primes = self.inner.primes();
if primes.len() >= 2 {
let p = &primes[0];
p.to_bytes_be()
} else {
vec![0u8; 32]
}
}
}
#[cfg(test)]
mod tests {
use crate::keys::KeyPair;
#[test]
fn test_private_key_creation() {
let keypair = KeyPair::generate(2048).unwrap();
let private_key = keypair.private_key();
assert_eq!(private_key.key_size_bits(), 2048);
assert_eq!(private_key.key_size_bytes(), 256);
assert!(!private_key.modulus().is_empty());
assert!(!private_key.private_exponent().is_empty());
assert!(!private_key.prime1().is_empty());
assert!(!private_key.prime2().is_empty());
assert!(!private_key.crt_coefficient().is_empty());
}
#[test]
fn test_private_key_debug_format() {
let keypair = KeyPair::generate(2048).unwrap();
let debug_str = format!("{:?}", keypair.private_key());
assert!(debug_str.contains("PrivateKey"));
assert!(debug_str.contains("key_size"));
assert!(debug_str.contains("[REDACTED]"));
}
#[test]
fn test_private_key_clone() {
let keypair = KeyPair::generate(2048).unwrap();
let private_key1 = keypair.private_key();
let private_key2 = private_key1.clone();
assert_eq!(private_key1.key_size_bits(), private_key2.key_size_bits());
assert_eq!(private_key1.modulus(), private_key2.modulus());
assert_eq!(
private_key1.private_exponent(),
private_key2.private_exponent()
);
}
#[test]
fn test_private_key_to_public_key() {
let keypair = KeyPair::generate(2048).unwrap();
let derived_public = keypair.private_key().public_key().unwrap();
assert_eq!(
derived_public.key_size_bits(),
keypair.public_key().key_size_bits()
);
assert_eq!(derived_public.modulus(), keypair.public_key().modulus());
assert_eq!(derived_public.public_exponent(), vec![0x01, 0x00, 0x01]);
}
#[test]
fn test_private_key_pem_export() {
let keypair = KeyPair::generate(2048).unwrap();
let pem = keypair.private_key().to_pem().unwrap();
assert!(pem.starts_with("-----BEGIN RSA PRIVATE KEY-----\n"));
assert!(pem.ends_with("\n-----END RSA PRIVATE KEY-----\n"));
assert!(pem.len() > 100);
}
#[test]
fn test_private_key_encrypted_pem_export() {
let keypair = KeyPair::generate(2048).unwrap();
let password = "test_password_123";
let encrypted_pem = keypair.private_key().to_encrypted_pem(password).unwrap();
assert!(encrypted_pem.starts_with("-----BEGIN ENCRYPTED PRIVATE KEY-----\n"));
assert!(encrypted_pem.ends_with("\n-----END ENCRYPTED PRIVATE KEY-----\n"));
assert!(encrypted_pem.len() > 100);
let regular_pem = keypair.private_key().to_pem().unwrap();
assert_ne!(encrypted_pem, regular_pem);
}
#[test]
fn test_private_key_der_export() {
let keypair = KeyPair::generate(2048).unwrap();
let der = keypair.private_key().to_der().unwrap();
assert!(!der.is_empty());
}
#[test]
fn test_key_component_lengths() {
for &key_size in &[2048, 3072, 4096] {
let keypair = KeyPair::generate(key_size).unwrap();
let private_key = keypair.private_key();
let expected_modulus_len = key_size / 8;
let expected_prime_len = key_size / 16;
assert_eq!(private_key.modulus().len(), expected_modulus_len);
assert_eq!(private_key.private_exponent().len(), expected_modulus_len);
assert_eq!(private_key.prime1().len(), expected_prime_len);
assert_eq!(private_key.prime2().len(), expected_prime_len);
assert_eq!(private_key.crt_coefficient().len(), expected_prime_len);
}
}
#[test]
fn test_memory_zeroization() {
let keypair = KeyPair::generate(2048).unwrap();
let _private_key = keypair.private_key().clone();
drop(_private_key);
let another_private = keypair.private_key().clone();
drop(another_private);
}
}