use crate::error::Result;
use crate::primitives::ec::{PrivateKey, PublicKey};
use crate::primitives::hash::sha256;
use crate::primitives::BigNumber;
/// Non-interactive proof that a single secret scalar `a` links two public relations,
/// `A = a·G` and `S = a·B`; produced by `Schnorr::generate_proof` and checked by
/// `Schnorr::verify_proof`.
#[derive(Clone, Debug)]
pub struct SchnorrProof {
/// Nonce commitment on the generator: `R = r·G` for a fresh random nonce `r`.
pub r: PublicKey,
/// Nonce commitment on the second base: `S' = r·B` (same nonce `r` as `R`).
pub s_prime: PublicKey,
/// Response scalar `z = r + e·a mod n`, where `e` is the hash-derived challenge.
pub z: BigNumber,
}
/// Namespace for generating and verifying [`SchnorrProof`]s.
pub struct Schnorr;

impl Schnorr {
    /// Proves that the same secret scalar `a` satisfies `A = a·G` and `S = a·B`.
    ///
    /// The caller is expected to pass `big_a = a·G` and `big_s = a·B` (for example the
    /// shared secret derived from `a` and `big_b`). This is not checked here; mismatched
    /// inputs produce a proof that fails [`Schnorr::verify_proof`].
    ///
    /// # Errors
    ///
    /// Returns an error if multiplying `big_b` by the random nonce fails.
    pub fn generate_proof(
        a: &PrivateKey,
        big_a: &PublicKey,
        big_b: &PublicKey,
        big_s: &PublicKey,
    ) -> Result<SchnorrProof> {
        // Fresh random nonce r, committed on both bases: R = r·G and S' = r·B.
        let r = PrivateKey::random();
        let r_bytes = r.to_bytes();
        let big_r = r.public_key();
        let s_prime = big_b.mul_scalar(&r_bytes)?;

        let e = Self::compute_challenge(big_a, big_b, big_s, &s_prime, &big_r);

        // Response: z = r + e·a (mod n).
        let order = BigNumber::secp256k1_order();
        let r_bn = BigNumber::from_bytes_be(&r_bytes);
        let a_bn = BigNumber::from_bytes_be(&a.to_bytes());
        let z = r_bn.add(&e.mul(&a_bn)).modulo(&order);

        Ok(SchnorrProof {
            r: big_r,
            s_prime,
            z,
        })
    }

    /// Verifies `proof` against the public inputs `A`, `B` and `S`.
    ///
    /// Recomputes the challenge `e` from the public inputs and accepts only when both
    /// `z·G == R + e·A` and `z·B == S' + e·S` hold. The proof is untrusted input: a
    /// response `z` outside `[0, n)` or any failing point operation yields `false`
    /// rather than a panic.
    pub fn verify_proof(
        big_a: &PublicKey,
        big_b: &PublicKey,
        big_s: &PublicKey,
        proof: &SchnorrProof,
    ) -> bool {
        Self::check_proof(big_a, big_b, big_s, proof).unwrap_or(false)
    }

    /// Evaluates both verification equations; `None` means the proof is malformed.
    fn check_proof(
        big_a: &PublicKey,
        big_b: &PublicKey,
        big_s: &PublicKey,
        proof: &SchnorrProof,
    ) -> Option<bool> {
        // Only canonical responses are accepted. An out-of-range z would otherwise either
        // fail the fixed-width conversion below or act as a second encoding of a valid proof.
        let order = BigNumber::secp256k1_order();
        if proof.z.clone().modulo(&order) != proof.z {
            return None;
        }

        let e = Self::compute_challenge(big_a, big_b, big_s, &proof.s_prime, &proof.r);
        let e_bytes = Self::scalar_bytes(&e)?;
        let z_bytes = Self::scalar_bytes(&proof.z)?;

        // First equation: z·G == R + e·A.
        let z_g = PublicKey::from_scalar_mul_generator(&z_bytes).ok()?;
        let e_a = big_a.mul_scalar(&e_bytes).ok()?;
        if z_g != proof.r.add(&e_a).ok()? {
            return Some(false);
        }

        // Second equation: z·B == S' + e·S.
        let z_b = big_b.mul_scalar(&z_bytes).ok()?;
        let e_s = big_s.mul_scalar(&e_bytes).ok()?;
        Some(z_b == proof.s_prime.add(&e_s).ok()?)
    }

    /// Big-endian 32-byte encoding of `value`, or `None` if it does not come out as
    /// exactly 32 bytes.
    fn scalar_bytes(value: &BigNumber) -> Option<[u8; 32]> {
        value.to_bytes_be(32).try_into().ok()
    }

    /// Challenge `e = SHA-256(A || B || S || S' || R) mod n`, computed over each point's
    /// compressed encoding.
    ///
    /// The hashing order is part of the protocol: prover and verifier must use the same
    /// order, or honestly generated proofs will not verify.
    fn compute_challenge(
        a: &PublicKey,
        b: &PublicKey,
        s: &PublicKey,
        s_prime: &PublicKey,
        r: &PublicKey,
    ) -> BigNumber {
        // Five compressed secp256k1 points, 33 bytes each.
        let mut msg = Vec::with_capacity(33 * 5);
        for point in [a, b, s, s_prime, r] {
            msg.extend_from_slice(&point.to_compressed());
        }
        let hash = sha256(&msg);
        BigNumber::from_bytes_be(&hash).modulo(&BigNumber::secp256k1_order())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_schnorr_roundtrip() {
        let alice = PrivateKey::random();
        let bob = PrivateKey::random();
        let shared = alice.derive_shared_secret(&bob.public_key()).unwrap();
        let proof =
            Schnorr::generate_proof(&alice, &alice.public_key(), &bob.public_key(), &shared)
                .unwrap();
        assert!(Schnorr::verify_proof(
            &alice.public_key(),
            &bob.public_key(),
            &shared,
            &proof
        ));
    }

    #[test]
    fn test_schnorr_wrong_secret_fails() {
        let alice = PrivateKey::random();
        let bob = PrivateKey::random();
        let shared = alice.derive_shared_secret(&bob.public_key()).unwrap();
        let proof =
            Schnorr::generate_proof(&alice, &alice.public_key(), &bob.public_key(), &shared)
                .unwrap();
        let carol = PrivateKey::random();
        let wrong_shared = alice.derive_shared_secret(&carol.public_key()).unwrap();
        assert!(!Schnorr::verify_proof(
            &alice.public_key(),
            &bob.public_key(),
            &wrong_shared,
            &proof
        ));
    }

    #[test]
    fn test_schnorr_wrong_public_key_fails() {
        let alice = PrivateKey::random();
        let bob = PrivateKey::random();
        let shared = alice.derive_shared_secret(&bob.public_key()).unwrap();
        let proof =
            Schnorr::generate_proof(&alice, &alice.public_key(), &bob.public_key(), &shared)
                .unwrap();
        let wrong_pubkey = PrivateKey::random().public_key();
        assert!(!Schnorr::verify_proof(
            &wrong_pubkey,
            &bob.public_key(),
            &shared,
            &proof
        ));
    }

    #[test]
    fn test_schnorr_wrong_bob_pubkey_fails() {
        let alice = PrivateKey::random();
        let bob = PrivateKey::random();
        let shared = alice.derive_shared_secret(&bob.public_key()).unwrap();
        let proof =
            Schnorr::generate_proof(&alice, &alice.public_key(), &bob.public_key(), &shared)
                .unwrap();
        let wrong_bob_pubkey = PrivateKey::random().public_key();
        assert!(!Schnorr::verify_proof(
            &alice.public_key(),
            &wrong_bob_pubkey,
            &shared,
            &proof
        ));
    }

    /// A prover that does not hold the private key behind `A` cannot produce an
    /// accepting proof, even when it supplies the genuine shared secret.
    #[test]
    fn test_schnorr_wrong_witness_fails() {
        let alice = PrivateKey::random();
        let bob = PrivateKey::random();
        let shared = alice.derive_shared_secret(&bob.public_key()).unwrap();
        let carol = PrivateKey::random();
        let forged =
            Schnorr::generate_proof(&carol, &alice.public_key(), &bob.public_key(), &shared)
                .unwrap();
        assert!(!Schnorr::verify_proof(
            &alice.public_key(),
            &bob.public_key(),
            &shared,
            &forged
        ));
    }

    /// Changing any single component of a valid proof must make verification fail.
    #[test]
    fn test_schnorr_tampered_proof_fails() {
        let alice = PrivateKey::random();
        let bob = PrivateKey::random();
        let shared = alice.derive_shared_secret(&bob.public_key()).unwrap();
        let proof =
            Schnorr::generate_proof(&alice, &alice.public_key(), &bob.public_key(), &shared)
                .unwrap();
        let verify = |p: &SchnorrProof| {
            Schnorr::verify_proof(&alice.public_key(), &bob.public_key(), &shared, p)
        };
        assert!(verify(&proof));

        // z + 1 (mod n): the challenge is unchanged, so z·G shifts by G and cannot match.
        let mut bad_z = proof.clone();
        bad_z.z = bad_z
            .z
            .add(&BigNumber::from_bytes_be(&[1u8]))
            .modulo(&BigNumber::secp256k1_order());
        assert!(!verify(&bad_z));

        let mut bad_r = proof.clone();
        bad_r.r = PrivateKey::random().public_key();
        assert!(!verify(&bad_r));

        let mut bad_s_prime = proof.clone();
        bad_s_prime.s_prime = PrivateKey::random().public_key();
        assert!(!verify(&bad_s_prime));
    }

    #[test]
    fn test_schnorr_mutual_verification() {
        let alice = PrivateKey::random();
        let bob = PrivateKey::random();
        let alice_shared = alice.derive_shared_secret(&bob.public_key()).unwrap();
        let bob_shared = bob.derive_shared_secret(&alice.public_key()).unwrap();
        assert_eq!(alice_shared.to_compressed(), bob_shared.to_compressed());
        let alice_proof = Schnorr::generate_proof(
            &alice,
            &alice.public_key(),
            &bob.public_key(),
            &alice_shared,
        )
        .unwrap();
        let bob_proof =
            Schnorr::generate_proof(&bob, &bob.public_key(), &alice.public_key(), &bob_shared)
                .unwrap();
        assert!(Schnorr::verify_proof(
            &alice.public_key(),
            &bob.public_key(),
            &alice_shared,
            &alice_proof
        ));
        assert!(Schnorr::verify_proof(
            &bob.public_key(),
            &alice.public_key(),
            &bob_shared,
            &bob_proof
        ));
    }

    #[test]
    fn test_schnorr_deterministic_challenge() {
        let alice = PrivateKey::random();
        let bob = PrivateKey::random();
        let shared = alice.derive_shared_secret(&bob.public_key()).unwrap();
        let nonce = PrivateKey::random();
        let r = nonce.public_key();
        let s_prime = bob.public_key().mul_scalar(&nonce.to_bytes()).unwrap();
        let e1 = Schnorr::compute_challenge(
            &alice.public_key(),
            &bob.public_key(),
            &shared,
            &s_prime,
            &r,
        );
        let e2 = Schnorr::compute_challenge(
            &alice.public_key(),
            &bob.public_key(),
            &shared,
            &s_prime,
            &r,
        );
        assert_eq!(e1, e2);
    }

    #[test]
    fn test_schnorr_challenge_changes_with_different_inputs() {
        let alice = PrivateKey::random();
        let bob = PrivateKey::random();
        let shared = alice.derive_shared_secret(&bob.public_key()).unwrap();
        let nonce1 = PrivateKey::random();
        let r1 = nonce1.public_key();
        let s_prime1 = bob.public_key().mul_scalar(&nonce1.to_bytes()).unwrap();
        let nonce2 = PrivateKey::random();
        let r2 = nonce2.public_key();
        let s_prime2 = bob.public_key().mul_scalar(&nonce2.to_bytes()).unwrap();
        let e1 = Schnorr::compute_challenge(
            &alice.public_key(),
            &bob.public_key(),
            &shared,
            &s_prime1,
            &r1,
        );
        let e2 = Schnorr::compute_challenge(
            &alice.public_key(),
            &bob.public_key(),
            &shared,
            &s_prime2,
            &r2,
        );
        assert_ne!(e1, e2);
    }
}