use num_bigint_dig::{BigInt, BigUint, Sign, ToBigInt};
use num_integer::Integer;
use num_traits::{One, Zero};
use rand::{rngs::StdRng, RngCore, SeedableRng};
/// Toy DSA domain parameters.
///
/// Signing computes `r = (g^k mod p) mod q` and verification works modulo
/// `p` and `q` with these three values.
#[derive(Clone, Debug)]
pub struct DsaParams {
    /// Modulus for the exponentiations (expected to be prime).
    pub p: BigUint,
    /// Modulus for the signature scalars `r`, `s`, and the private key.
    pub q: BigUint,
    /// Base of the exponentiations (expected to have order `q` mod `p`).
    pub g: BigUint,
}
/// A toy DSA key pair.
///
/// When produced by `toy_dsa_generate_keypair`, `public = g^private mod p`.
#[derive(Clone, Debug)]
pub struct DsaKeyPair {
    /// Secret scalar `x`, a nonzero residue modulo `q`.
    pub private: BigUint,
    /// Public value `y = g^x mod p`.
    pub public: BigUint,
}
/// A toy DSA signature `(r, s)`.
///
/// `toy_dsa_verify` only accepts signatures with both components in `[1, q)`.
#[derive(Clone, Debug)]
pub struct DsaSignature {
    /// `r = (g^k mod p) mod q` for the signing nonce `k`.
    pub r: BigUint,
    /// `s = k^{-1} * (h_m + x*r) mod q`.
    pub s: BigUint,
}
/// Returns a fixed, tiny set of DSA domain parameters for demonstration only.
///
/// All arguments are ignored. The result is always `q = 11`, `p = 2q + 1 = 23`
/// and `g = 2`. Since `2^11 ≡ 1 (mod 23)`, `g` has order `q` modulo `p`.
/// These values provide no security whatsoever.
pub fn toy_generate_dsa_params(_p_bits: usize, _q_bits: usize, _seed: Option<u64>) -> DsaParams {
    let two = BigUint::from(2u64);
    let q = BigUint::from(11u64);
    // Safe-prime construction: p = 2q + 1.
    let p = &q * &two + BigUint::one();
    DsaParams { p, q, g: two }
}
pub fn toy_dsa_generate_keypair(params: &DsaParams, seed: Option<u64>) -> DsaKeyPair {
let mut rng = match seed {
Some(s) => StdRng::seed_from_u64(s),
None => StdRng::from_entropy(),
};
let q = ¶ms.q;
let p = ¶ms.p;
let g = ¶ms.g;
let mut x = BigUint::zero();
for _ in 0..100 {
let mut scalar_bytes = vec![0u8; q.bits() / 8 + 1];
rng.fill_bytes(&mut scalar_bytes);
x = BigUint::from_bytes_be(&scalar_bytes) % q;
if !x.is_zero() {
break;
}
}
if x.is_zero() {
x = BigUint::one();
}
let y = mod_exp(g, &x, p);
DsaKeyPair {
private: x,
public: y,
}
}
pub fn toy_dsa_sign(
params: &DsaParams,
kp: &DsaKeyPair,
h_m: &BigUint,
seed: Option<u64>,
) -> DsaSignature {
let mut rng = match seed {
Some(s) => StdRng::seed_from_u64(s),
None => StdRng::from_entropy(),
};
let p = ¶ms.p;
let q = ¶ms.q;
let g = ¶ms.g;
let x = &kp.private;
let mut k = BigUint::zero();
for _ in 0..100 {
let mut buf = vec![0u8; q.bits() / 8 + 1];
rng.fill_bytes(&mut buf);
let k_candidate = BigUint::from_bytes_be(&buf) % q;
if !k_candidate.is_zero() {
let (gcd, _, _) = extended_gcd(k_candidate.clone(), q.clone());
if gcd == BigUint::one() {
k = k_candidate;
break;
}
}
}
if k.is_zero() {
k = BigUint::one();
}
let gk = mod_exp(g, &k, p);
let r = gk % q;
let k_inv = mod_inv(k, q).expect("No inverse for ephemeral k?! (toy DSA)");
let xr = (x * &r) % q;
let sum = (h_m + &xr) % q;
let s = (&k_inv * &sum) % q;
DsaSignature { r, s }
}
/// Verifies the toy DSA signature `sig` on the pre-hashed message `h_m`
/// against the public key `pub_key`.
///
/// Returns `false` in any of these cases:
/// - `r` or `s` lies outside `[1, q)`,
/// - `pub_key` lies outside `(1, p)`,
/// - `s` has no inverse modulo `q`.
///
/// Otherwise it computes `w = s^{-1} mod q`, `u1 = h_m*w mod q`,
/// `u2 = r*w mod q` and `v = ((g^u1 * y^u2) mod p) mod q`, and accepts iff
/// `v == r`.
///
/// Rejection reasons and intermediate values are printed to stdout for
/// demonstration purposes.
pub fn toy_dsa_verify(
    params: &DsaParams,
    pub_key: &BigUint,
    h_m: &BigUint,
    sig: &DsaSignature,
) -> bool {
    let q = &params.q;
    let p = &params.p;
    let g = &params.g;

    // Range checks: valid signatures never have r or s equal to 0 or >= q.
    if sig.r.is_zero() || sig.r >= *q {
        println!("r out of range: r={}", sig.r);
        return false;
    }
    if sig.s.is_zero() || sig.s >= *q {
        println!("s out of range: s={}", sig.s);
        return false;
    }
    // Reject trivial public keys (0, 1) and values not reduced mod p.
    if pub_key <= &BigUint::one() || pub_key >= p {
        println!("Invalid public key");
        return false;
    }

    let w = match mod_inv(sig.s.clone(), q) {
        Some(val) => val,
        None => {
            println!("no inverse for s: s={}", sig.s);
            return false;
        }
    };
    println!("w = {}", w);

    let u1 = (h_m * &w) % q;
    let u2 = (&sig.r * &w) % q;
    println!("u1 = {}", u1);
    println!("u2 = {}", u2);

    let gu1 = mod_exp(g, &u1, p);
    let yu2 = mod_exp(pub_key, &u2, p);
    println!("gu1 = {}", gu1);
    println!("yu2 = {}", yu2);

    let t = (&gu1 * &yu2) % p;
    println!("t = {}", t);
    let v = t % q;
    println!("v = {}", v);
    println!("r = {}", sig.r);
    v == sig.r
}
/// Computes `base^exp mod m`.
///
/// # Panics
/// Panics with the message `mod_exp with modulus=0` when `m` is zero.
fn mod_exp(base: &BigUint, exp: &BigUint, m: &BigUint) -> BigUint {
    assert!(!m.is_zero(), "mod_exp with modulus=0");
    base.modpow(exp, m)
}
/// Extended Euclidean algorithm on non-negative integers.
///
/// Returns `(g, s, t)` where `g = gcd(a, b)` and `a*s + b*t = g`.
fn extended_gcd(a: BigUint, b: BigUint) -> (BigUint, BigInt, BigInt) {
    // Loop invariant: old_r = a*old_s + b*old_t and r = a*s + b*t.
    let (mut old_r, mut r) = (a.to_bigint().unwrap(), b.to_bigint().unwrap());
    let (mut old_s, mut s) = (BigInt::one(), BigInt::zero());
    let (mut old_t, mut t) = (BigInt::zero(), BigInt::one());

    while !r.is_zero() {
        let (quotient, remainder) = old_r.div_rem(&r);
        // Shift each (old, current) pair one step along the remainder sequence.
        old_r = std::mem::replace(&mut r, remainder);
        let next_s = old_s - &quotient * &s;
        old_s = std::mem::replace(&mut s, next_s);
        let next_t = old_t - &quotient * &t;
        old_t = std::mem::replace(&mut t, next_t);
    }

    (old_r.to_biguint().unwrap(), old_s, old_t)
}
/// Returns the inverse of `x` modulo `m`, reduced into `[0, m)`.
///
/// Returns `None` when `gcd(x, m) != 1`, i.e. when no inverse exists.
fn mod_inv(x: BigUint, m: &BigUint) -> Option<BigUint> {
    let (gcd, coeff, _) = extended_gcd(x, m.clone());
    if gcd != BigUint::one() {
        return None;
    }
    let modulus = m.to_bigint().unwrap();
    // Floored modulo by a positive modulus always lands in [0, m), which
    // normalizes a negative Bezout coefficient.
    let reduced = coeff.mod_floor(&modulus);
    Some(reduced.to_biguint().unwrap())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end round trip: generate params and keys, sign, verify, and
    /// check that a tampered `r` is rejected. All randomness is seeded so the
    /// test is deterministic.
    #[test]
    fn test_dsa_toy() {
        let p_bits = 256;
        let q_bits = 160;
        let params = toy_generate_dsa_params(p_bits, q_bits, Some(42));
        println!("\nDSA Parameters:");
        println!("p = {}", params.p);
        println!("q = {}", params.q);
        println!("g = {}", params.g);

        let kp = toy_dsa_generate_keypair(&params, Some(100));
        println!("\nKey Pair:");
        println!("x (private) = {}", kp.private);
        println!("y (public) = {}", kp.public);

        let msg_hash = BigUint::parse_bytes(b"123456789ABCDEF", 16).unwrap();
        println!("\nMessage Hash:");
        println!("h_m = {}", msg_hash);

        let sig = toy_dsa_sign(&params, &kp, &msg_hash, Some(200));
        println!("\nSignature:");
        println!("r = {}", sig.r);
        println!("s = {}", sig.s);

        let valid = toy_dsa_verify(&params, &kp.public, &msg_hash, &sig);
        println!("\nVerification Result: {}", valid);
        assert!(valid, "DSA signature must verify with correct key/msg.");

        let mut bad_sig = sig.clone();
        bad_sig.r += 1u64;
        let invalid = toy_dsa_verify(&params, &kp.public, &msg_hash, &bad_sig);
        assert!(!invalid, "Tampered signature must fail verification");
    }
}