use crate::errors::{Error, Result};
use crypto::digest::Digest;
use crypto::sha2::Sha256;
use num_bigint::{BigUint, RandPrime};
use num_bigint_dig::traits::ModInverse;
use num_traits::{FromPrimitive, Num, One, Zero};
use rand::Rng;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use zeroize::Zeroize;
lazy_static! {
    /// Smallest public exponent accepted by `check_public`.
    static ref MIN_PUB_EXPONENT: BigUint = BigUint::from_u64(2).unwrap();
    /// Largest public exponent accepted by `check_public`: 2^31 - 1.
    ///
    /// Presumably carried over from Go's `crypto/rsa` limit `1<<31-1`. Go binds `<<` tighter
    /// than `-`, so that limit is `(1 << 31) - 1`, not `1 << 30`.
    static ref MAX_PUB_EXPONENT: BigUint = BigUint::from_u64((1 << 31) - 1).unwrap();
}
/// Fixed public exponent used when generating keys: 65537, i.e. 2^16 + 1.
const EXP: u64 = 0x1_0001;
/// RSA public key: the modulus `n` together with the public exponent `e`.
///
/// Build one with `PublicKey::new`, which validates the exponent, or convert from a
/// `PrivateKey`.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub struct PublicKey {
    /// Modulus.
    n: BigUint,
    /// Public exponent.
    e: BigUint,
}
/// RSA private key: modulus `n`, public exponent `e`, private exponent `d`, and the primes
/// whose product is `n` when the key comes from `PrivateKey::new`.
///
/// NOTE(review): the derived `Debug` prints `d` and `primes` verbatim, and the optional serde
/// derives serialize them too. Avoid logging or persisting this type carelessly.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub struct PrivateKey {
    /// Modulus.
    n: BigUint,
    /// Public exponent.
    e: BigUint,
    /// Private exponent (secret).
    d: BigUint,
    /// Prime factors (secret).
    primes: Vec<BigUint>,
}
/// Signature checked by `PublicKey::verify`.
///
/// The field descriptions below follow how `verify` uses each field. Their protocol-level
/// names are not shown in this file.
pub struct Signature {
    /// Text that `verify` hashes with SHA-256 and multiplies into the right-hand side.
    /// Presumably the information shared by signer and requester — confirm.
    pub a: String,
    /// Enters verification as `(c² + 1)² mod n`.
    pub c: BigUint,
    /// Signature value; verification compares `s^e mod n` against the right-hand side.
    pub s: BigUint,
}
impl PartialEq for PrivateKey {
    /// Keys are equal when `n`, `e`, `d` and `primes` all match. Fields are compared in
    /// that order, and the comparison stops at the first mismatch.
    ///
    /// NOTE(review): this is an ordinary, non-constant-time comparison of secret values.
    #[inline]
    fn eq(&self, other: &PrivateKey) -> bool {
        let mine = (&self.n, &self.e, &self.d, &self.primes);
        let theirs = (&other.n, &other.e, &other.d, &other.primes);
        mine == theirs
    }
}
/// Equality compares plain integers field by field, so it is a full equivalence relation.
impl Eq for PrivateKey {}
impl Zeroize for PrivateKey {
    /// Overwrites the secret material: first the private exponent `d`, then every prime.
    /// After that the prime list is emptied. `n` and `e` are public and left as they are.
    fn zeroize(&mut self) {
        self.d.zeroize();
        self.primes.iter_mut().for_each(|p| p.zeroize());
        self.primes.clear();
    }
}
impl Drop for PrivateKey {
fn drop(&mut self) {
self.zeroize();
}
}
impl PublicKey {
pub fn new(n: BigUint, e: BigUint) -> Result<Self> {
let k = PublicKey { n, e };
check_public(&k)?;
Ok(k)
}
pub fn verify(&self, message: String, sig: &Signature) -> Result<()> {
let mut hasher = Sha256::new();
hasher.input_str(&message);
let m = hasher.result_str();
let m = BigUint::from_str_radix(&m, 16).unwrap();
hasher.reset();
hasher.input_str(&sig.a);
let a = hasher.result_str();
let a = BigUint::from_str_radix(&a, 16).unwrap();
let left = sig.s.modpow(self.e(), self.n());
let mid_val = sig.c.modpow(&BigUint::from_u64(2).unwrap(), self.n()) + BigUint::one();
let mid_val = mid_val.modpow(&BigUint::from_u64(2).unwrap(), self.n());
let mut right = a * m.modpow(&BigUint::from_u64(2).unwrap(), self.n()) * mid_val;
right %= self.n();
if left != right {
return Err(Error::Verification);
}
Ok(())
}
pub fn n(&self) -> &BigUint {
&self.n
}
pub fn e(&self) -> &BigUint {
&self.e
}
}
impl PrivateKey {
pub fn new<R: Rng>(rng: &mut R, bit_size: usize) -> Result<PrivateKey> {
let nprimes = 2;
if bit_size < 64 {
let prime_limit = (1u64 << (bit_size / nprimes) as u64) as f64;
let mut pi = prime_limit / (prime_limit.ln() - 1f64);
pi /= 4f64;
pi /= 2f64;
if pi < nprimes as f64 {
return Err(Error::TooFewPrimes);
}
}
let mut primes = vec![BigUint::zero(); nprimes];
let n_final: BigUint;
let d_final: BigUint;
'next: loop {
let mut todo = bit_size;
if nprimes >= 7 {
todo += (nprimes - 2) / 5;
}
for (i, prime) in primes.iter_mut().enumerate() {
*prime = rng.gen_prime(todo / (nprimes - i));
todo -= prime.bits();
}
for (i, prime1) in primes.iter().enumerate() {
for prime2 in primes.iter().take(i) {
if prime1 == prime2 {
continue 'next;
}
}
}
let mut n = BigUint::one();
let mut totient = BigUint::one();
for prime in &primes {
n *= prime;
totient *= prime - BigUint::one();
}
if n.bits() != bit_size {
continue 'next;
}
let exp = BigUint::from_u64(EXP).expect("invalid static exponent");
if let Some(d) = exp.mod_inverse(totient) {
n_final = n;
d_final = d.to_biguint().unwrap();
break;
}
}
Ok(PrivateKey::from_components(
n_final,
BigUint::from_u64(EXP).expect("invalid static exponent"),
d_final,
primes,
))
}
pub fn from_components(n: BigUint, e: BigUint, d: BigUint, primes: Vec<BigUint>) -> PrivateKey {
PrivateKey { n, e, d, primes }
}
pub fn n(&self) -> &BigUint {
&self.n
}
pub fn e(&self) -> &BigUint {
&self.e
}
pub fn d(&self) -> &BigUint {
&self.d
}
pub fn primes(&self) -> &[BigUint] {
&self.primes
}
pub fn sign(&self, a: String, alpha: BigUint, beta: BigUint, x: BigUint) -> (BigUint, BigUint) {
let beta_invert = beta.mod_inverse(self.n()).unwrap();
let beta_invert = beta_invert.to_biguint().unwrap();
let mut hasher = Sha256::new();
hasher.input_str(&a);
let a = hasher.result_str();
let a = BigUint::from_str_radix(&a, 16).unwrap();
let mut mid_val = x.modpow(&BigUint::from_u64(2).unwrap(), self.n()) + BigUint::one();
mid_val *= beta_invert.modpow(&BigUint::from_u64(2).unwrap(), self.n());
mid_val *= alpha;
mid_val = mid_val.modpow(&BigUint::from_u64(2).unwrap(), self.n());
mid_val *= a;
let d_1 = self.d() - BigUint::one();
let t = mid_val.modpow(&d_1, self.n());
(beta_invert, t)
}
}
impl From<PrivateKey> for PublicKey {
fn from(private_key: PrivateKey) -> Self {
(&private_key).into()
}
}
impl From<&PrivateKey> for PublicKey {
    /// Builds a public key by cloning the modulus and public exponent of `private_key`.
    /// The exponent range is not re-checked here.
    fn from(private_key: &PrivateKey) -> Self {
        PublicKey {
            n: private_key.n().clone(),
            e: private_key.e().clone(),
        }
    }
}
#[inline]
pub fn check_public(public_key: &PublicKey) -> Result<()> {
if public_key.e() < &*MIN_PUB_EXPONENT {
return Err(Error::PublicExponentTooSmall);
}
if public_key.e() > &*MAX_PUB_EXPONENT {
return Err(Error::PublicExponentTooLarge);
}
Ok(())
}