use core::fmt;
use p256::{
elliptic_curve::{
ops::Reduce,
sec1::ToEncodedPoint,
Field,
},
AffinePoint, FieldBytes, ProjectivePoint, Scalar,
};
use rand_core::{CryptoRng, RngCore};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
/// A non-interactive discrete-log-equality proof, stored as a
/// (challenge, response) pair of P-256 scalars.
///
/// `prove` produces it and `verify` checks it; `encode_proof` /
/// `decode_proof` convert it to and from a fixed 64-byte form.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct DleqProof {
/// Challenge scalar: the transcript hash computed by `challenge_scalar`.
pub c: Scalar,
/// Response scalar: `r + c * k` for the prover's nonce `r` and witness `k`.
pub s: Scalar,
}
/// Renders the proof as `DleqProof { c: 0x<hex>, s: 0x<hex> }`, where each
/// scalar is printed as the lowercase hex of its byte encoding.
impl fmt::Debug for DleqProof {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render both scalars up front so the format call only stitches text.
        let challenge_hex = hex32(&self.c);
        let response_hex = hex32(&self.s);
        write!(
            f,
            "DleqProof {{ c: 0x{}, s: 0x{} }}",
            challenge_hex, response_hex
        )
    }
}
/// Fixed domain-separation prefix. `prove` and `verify` both prepend it to
/// the optional caller-supplied tag before hashing the transcript.
const DLEQ_DST: &[u8] = b"DLEQ-P256-v1";
/// Derives the Fiat-Shamir challenge for a DLEQ transcript.
///
/// The SHA-256 input is, in order:
/// 1. the length of `dst` as a big-endian `u32` (a length that does not fit
///    in `u32` is encoded as `0`),
/// 2. the bytes of `dst`,
/// 3. the compressed SEC1 encodings of `g`, `y`, `a`, `b`, `t1`, `t2`.
///
/// The 32-byte digest is then reduced to a scalar via `Scalar::reduce_bytes`.
fn challenge_scalar(
    g: &AffinePoint,
    y: &AffinePoint,
    a: &AffinePoint,
    b: &AffinePoint,
    t1: &AffinePoint,
    t2: &AffinePoint,
    dst: &[u8],
) -> Scalar {
    let len_prefix = u32::try_from(dst.len()).unwrap_or(0).to_be_bytes();

    let mut transcript = Sha256::new();
    transcript.update(len_prefix);
    transcript.update(dst);

    // Absorb the points in a fixed order; prover and verifier must agree on it.
    [g, y, a, b, t1, t2]
        .iter()
        .map(|point| point.to_encoded_point(true))
        .for_each(|encoded| transcript.update(encoded.as_bytes()));

    Scalar::reduce_bytes(&transcript.finalize())
}
/// Produces a DLEQ proof for the statement `(g, y, a, b)` using witness `k`.
///
/// # Parameters
/// - `k`: the scalar used as the witness in the response `s = r + c * k`.
/// - `g`, `a`: the two base points the nonce commitments are computed on.
/// - `y`, `b`: the two public points; they are only hashed into the challenge.
/// - `rng`: source of the fresh random nonce `r`.
/// - `dst`: optional extra domain-separation bytes appended to `DLEQ_DST`.
///
/// # Returns
/// The proof `(c, s)` where `c` is the transcript challenge over
/// `(g, y, a, b, r*g, r*a)` and `s = r + c * k`.
pub fn prove<R: RngCore + CryptoRng>(
    k: &Scalar,
    g: &AffinePoint,
    y: &AffinePoint,
    a: &AffinePoint,
    b: &AffinePoint,
    rng: &mut R,
    dst: Option<&[u8]>,
) -> DleqProof {
    let nonce = Scalar::random(rng);

    // Commit to the nonce on both bases: t1 = r*g, t2 = r*a.
    let commit = |base: &AffinePoint| (ProjectivePoint::from(*base) * nonce).to_affine();
    let (t1, t2) = (commit(g), commit(a));

    // Module-level prefix followed by the caller's tag (if any).
    let full_dst = [DLEQ_DST, dst.unwrap_or(&[])].concat();

    let c = challenge_scalar(g, y, a, b, &t1, &t2, &full_dst);
    DleqProof {
        c,
        s: nonce + c * *k,
    }
}
/// Checks a DLEQ proof against the statement `(g, y, a, b)`.
///
/// The commitments are recomputed as `s*g - c*y` and `s*a - c*b`, the
/// challenge is re-derived from the resulting transcript, and the proof is
/// accepted only if that challenge equals `proof.c`.
///
/// `dst` must carry the same optional tag that was passed to `prove`, since
/// it is hashed into the challenge.
///
/// Returns `true` when the proof is accepted, `false` otherwise. The final
/// comparison of the two challenge encodings goes through `subtle`'s `ct_eq`.
pub fn verify(
    g: &AffinePoint,
    y: &AffinePoint,
    a: &AffinePoint,
    b: &AffinePoint,
    proof: &DleqProof,
    dst: Option<&[u8]>,
) -> bool {
    // Recover a commitment on `base` relative to its public point: s*base - c*public.
    let recover = |base: &AffinePoint, public: &AffinePoint| {
        let response_term = ProjectivePoint::from(*base) * proof.s;
        let challenge_term = ProjectivePoint::from(*public) * proof.c;
        (response_term - challenge_term).to_affine()
    };
    let t1_prime = recover(g, y);
    let t2_prime = recover(a, b);

    // Module-level prefix followed by the caller's tag (if any).
    let full_dst = [DLEQ_DST, dst.unwrap_or(&[])].concat();

    let expected = challenge_scalar(g, y, a, b, &t1_prime, &t2_prime, &full_dst);
    let matches = expected.to_bytes().ct_eq(&proof.c.to_bytes());
    bool::from(matches)
}
/// Serializes a proof into 64 bytes: the byte encoding of `c` in the first
/// 32 bytes followed by the byte encoding of `s` in the last 32.
pub fn encode_proof(proof: &DleqProof) -> [u8; 64] {
    let mut out = [0u8; 64];
    {
        let (c_half, s_half) = out.split_at_mut(32);
        c_half.copy_from_slice(&proof.c.to_bytes());
        s_half.copy_from_slice(&proof.s.to_bytes());
    }
    out
}
/// Parses the 64-byte layout written by `encode_proof`: bytes `0..32` become
/// `c` and bytes `32..64` become `s`.
///
/// Each half is passed through `Scalar::reduce_bytes`, so this function never
/// fails.
///
/// NOTE(review): because the halves are reduced rather than range-checked,
/// distinct byte strings can presumably decode to the same proof. Callers that
/// need a unique encoding should confirm this and check canonicality
/// themselves.
pub fn decode_proof(bytes: &[u8; 64]) -> DleqProof {
    let (c_half, s_half) = bytes.split_at(32);

    // Both halves are exactly 32 bytes, matching the scalar encoding width.
    let to_scalar = |half: &[u8]| Scalar::reduce_bytes(&FieldBytes::clone_from_slice(half));

    let c = to_scalar(c_half);
    let s = to_scalar(s_half);
    DleqProof { c, s }
}
/// Returns the lowercase hexadecimal rendering (two digits per byte, no
/// prefix) of the scalar's byte encoding.
fn hex32(x: &Scalar) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let bytes = x.to_bytes();
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes.iter() {
        // High nibble first, then low nibble.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use p256::{ProjectivePoint, Scalar};
    use rand_core::OsRng;

    /// Samples a valid statement `(k, g, y, a, b)`: `g` is the curve generator,
    /// `a` is a random multiple of it, and `y = k*g`, `b = k*a`.
    fn sample_statement(
        rng: &mut OsRng,
    ) -> (Scalar, AffinePoint, AffinePoint, AffinePoint, AffinePoint) {
        let k = Scalar::random(&mut *rng);
        let g = AffinePoint::GENERATOR;
        let a_factor = Scalar::random(&mut *rng);
        let a = (ProjectivePoint::GENERATOR * a_factor).to_affine();
        let y = (ProjectivePoint::from(g) * k).to_affine();
        let b = (ProjectivePoint::from(a) * k).to_affine();
        (k, g, y, a, b)
    }

    /// An honest proof verifies and survives an encode/decode round trip.
    #[test]
    fn round_trip_proof() {
        let mut rng = OsRng;
        let (k, g, y, a, b) = sample_statement(&mut rng);

        let proof = prove(&k, &g, &y, &a, &b, &mut rng, Some(b"test-dst"));
        assert!(verify(&g, &y, &a, &b, &proof, Some(b"test-dst")));

        let decoded = decode_proof(&encode_proof(&proof));
        assert_eq!(proof, decoded);
    }

    /// Shifting the response by one makes verification fail.
    #[test]
    fn detect_bad_proof() {
        let mut rng = OsRng;
        let (k, g, y, a, b) = sample_statement(&mut rng);

        let honest = prove(&k, &g, &y, &a, &b, &mut rng, None);
        let tampered = DleqProof {
            c: honest.c,
            s: honest.s + Scalar::ONE,
        };
        assert!(!verify(&g, &y, &a, &b, &tampered, None));
    }

    /// Flips every single bit of the encoded challenge in turn and checks that
    /// each modified proof is rejected. Despite its name, this test exercises
    /// rejection of tampered challenges; it does not measure timing.
    #[test]
    fn test_constant_time_verification() {
        let mut rng = OsRng;
        let (k, g, y, a, b) = sample_statement(&mut rng);

        let proof = prove(&k, &g, &y, &a, &b, &mut rng, Some(b"test"));
        assert!(verify(&g, &y, &a, &b, &proof, Some(b"test")));

        let mut c_bytes = proof.c.to_bytes();
        let positions =
            (0..32usize).flat_map(|byte_idx| (0..8u32).map(move |bit_idx| (byte_idx, bit_idx)));
        for (byte_idx, bit_idx) in positions {
            let mask = 1u8 << bit_idx;

            // Flip one bit, test, then flip it back before the next position.
            c_bytes[byte_idx] ^= mask;
            let modified_proof = DleqProof {
                c: Scalar::reduce_bytes(&c_bytes),
                s: proof.s,
            };
            assert!(
                !verify(&g, &y, &a, &b, &modified_proof, Some(b"test")),
                "Failed to detect bit flip at byte {} bit {}",
                byte_idx,
                bit_idx
            );
            c_bytes[byte_idx] ^= mask;
        }
    }

    /// A handful of hand-crafted invalid proofs (and one wrong tag) are rejected.
    #[test]
    fn test_proof_rejection_patterns() {
        let mut rng = OsRng;
        let (k, g, y, a, b) = sample_statement(&mut rng);

        let proof = prove(&k, &g, &y, &a, &b, &mut rng, Some(b"test"));
        assert!(verify(&g, &y, &a, &b, &proof, Some(b"test")));

        let rejects =
            |candidate: &DleqProof, tag: &[u8]| !verify(&g, &y, &a, &b, candidate, Some(tag));

        // Challenge shifted by one.
        let shifted_challenge = DleqProof {
            c: proof.c + Scalar::ONE,
            s: proof.s,
        };
        assert!(rejects(&shifted_challenge, b"test"));

        // Response shifted by one.
        let shifted_response = DleqProof {
            c: proof.c,
            s: proof.s + Scalar::ONE,
        };
        assert!(rejects(&shifted_response, b"test"));

        // Honest proof, but checked under a different domain-separation tag.
        assert!(rejects(&proof, b"wrong-dst"));

        // Challenge and response swapped.
        let swapped = DleqProof {
            c: proof.s,
            s: proof.c,
        };
        assert!(rejects(&swapped, b"test"));

        // Challenge forced to zero.
        let zero_challenge = DleqProof {
            c: Scalar::ZERO,
            s: proof.s,
        };
        assert!(rejects(&zero_challenge, b"test"));
    }
}