use crate::primitives::hash::sha2::sha256;
use crate::zkp::error::{Result, ZkpError};
use k256::{
FieldBytes, ProjectivePoint, Scalar, SecretKey, U256,
elliptic_curve::{PrimeField, group::GroupEncoding, ops::Reduce},
};
use subtle::ConstantTimeEq;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Derives the Fiat–Shamir challenge scalar for a Schnorr proof.
///
/// The hashed transcript is
/// `"arc-zkp/schnorr-v1" || "secp256k1" || public_key || r_bytes || context`.
/// Its SHA-256 digest is turned into a [`Scalar`] via `Reduce<U256>`.
///
/// `public_key` and `r_bytes` are fixed-width (33 bytes each), and `context` is
/// the only variable-length field and comes last. The concatenation is therefore
/// unambiguous without length prefixes.
///
/// # Errors
///
/// Returns [`ZkpError::SerializationError`] if the SHA-256 helper fails.
fn fiat_shamir_challenge(
    public_key: &[u8; 33],
    r_bytes: &[u8; 33],
    context: &[u8],
) -> Result<Scalar> {
    const DOMAIN_LABEL: &[u8] = b"arc-zkp/schnorr-v1";
    const CURVE_NAME: &[u8] = b"secp256k1";

    // `concat` sizes the output exactly from its parts before copying.
    let transcript: Vec<u8> =
        [DOMAIN_LABEL, CURVE_NAME, public_key.as_slice(), r_bytes.as_slice(), context].concat();

    let digest = sha256(&transcript)
        .map_err(|err| ZkpError::SerializationError(format!("SHA-256 failed: {}", err)))?;
    let challenge = <Scalar as Reduce<U256>>::reduce_bytes(FieldBytes::from_slice(&digest));
    Ok(challenge)
}
/// Non-interactive Schnorr proof of knowledge of the secret scalar behind a
/// secp256k1 public key, bound to a caller-supplied context via Fiat–Shamir.
///
/// Produced by `SchnorrProver::prove` and checked by `SchnorrVerifier::verify`.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "zkp-serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "zkp-serde", serde(crate = "serde"))]
pub struct SchnorrProof {
    /// Commitment point `R = k·G`, as a 33-byte compressed SEC1 encoding.
    #[cfg_attr(feature = "zkp-serde", serde(with = "serde_with::As::<serde_with::Bytes>"))]
    commitment: [u8; 33],
    /// Response scalar `s = k + c·x`, serialized with `Scalar::to_bytes`.
    #[cfg_attr(feature = "zkp-serde", serde(with = "serde_with::As::<serde_with::Bytes>"))]
    response: [u8; 32],
}
impl std::fmt::Debug for SchnorrProof {
    /// Renders the struct shape while hiding both byte fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const HIDDEN: &str = "[REDACTED]";
        let mut out = f.debug_struct("SchnorrProof");
        out.field("commitment", &HIDDEN);
        out.field("response", &HIDDEN);
        out.finish()
    }
}
impl ConstantTimeEq for SchnorrProof {
    /// Compares both fields in constant time.
    ///
    /// The two results are combined with a bitwise `&` on `Choice` rather than a
    /// short-circuiting `&&`. Timing therefore does not reveal which field differed.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let same_commitment = self.commitment.ct_eq(&other.commitment);
        let same_response = self.response.ct_eq(&other.response);
        same_commitment & same_response
    }
}
impl SchnorrProof {
#[must_use]
pub fn new(commitment: [u8; 33], response: [u8; 32]) -> Self {
Self { commitment, response }
}
#[must_use]
pub fn commitment(&self) -> &[u8; 33] {
&self.commitment
}
#[must_use]
pub fn response(&self) -> &[u8; 32] {
&self.response
}
}
/// Holds a secp256k1 secret key and produces [`SchnorrProof`]s for it.
///
/// Derives `ZeroizeOnDrop`, so `secret` is wiped when the prover is dropped.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct SchnorrProver {
    /// Secret scalar bytes, as validated by `SecretKey` in the constructors.
    secret: [u8; 32],
    /// 33-byte compressed SEC1 public key.
    /// It is not wiped on drop (`#[zeroize(skip)]`) because it is public data.
    #[zeroize(skip)]
    public_key: [u8; 33],
}
impl std::fmt::Debug for SchnorrProver {
    /// Renders the prover without exposing any key material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SchnorrProver");
        out.field("secret", &"[REDACTED]")
            .field("public_key", &"[public]");
        out.finish()
    }
}
impl ConstantTimeEq for SchnorrProver {
    /// Compares secret and public key in constant time.
    ///
    /// Both comparisons always run and are merged with a bitwise `&` on `Choice`.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let same_secret = self.secret.ct_eq(&other.secret);
        let same_public = self.public_key.ct_eq(&other.public_key);
        same_secret & same_public
    }
}
impl SchnorrProver {
pub fn new() -> Result<(Self, [u8; 33])> {
let secret_bytes = crate::primitives::rand::csprng::random_bytes(32);
let secret_key = SecretKey::from_slice(&secret_bytes)
.map_err(|e| ZkpError::SerializationError(format!("Invalid secret key: {e}")))?;
let public_key = secret_key.public_key();
let secret_bytes: [u8; 32] = secret_key.to_bytes().into();
let public_bytes: [u8; 33] = <[u8; 33]>::try_from(public_key.to_sec1_bytes().as_ref())
.map_err(|e| {
ZkpError::SerializationError(format!("Failed to serialize public key: {}", e))
})?;
let prover = Self { secret: secret_bytes, public_key: public_bytes };
Ok((prover, public_bytes))
}
pub fn from_secret(secret: &[u8; 32]) -> Result<(Self, [u8; 33])> {
let secret_key = SecretKey::from_bytes(secret.into())
.map_err(|e| ZkpError::SerializationError(format!("Invalid secret key format: {e}")))?;
let public_key = secret_key.public_key();
let public_bytes: [u8; 33] = <[u8; 33]>::try_from(public_key.to_sec1_bytes().as_ref())
.map_err(|e| {
ZkpError::SerializationError(format!("Failed to serialize public key: {}", e))
})?;
let prover = Self { secret: *secret, public_key: public_bytes };
Ok((prover, public_bytes))
}
#[allow(clippy::arithmetic_side_effects)] pub fn prove(&self, context: &[u8]) -> Result<SchnorrProof> {
let x: Option<Scalar> = Scalar::from_repr(*FieldBytes::from_slice(&self.secret)).into();
let x = x.ok_or(ZkpError::InvalidScalar)?;
let nonce_bytes = crate::primitives::rand::csprng::random_bytes(32);
let k = <Scalar as Reduce<U256>>::reduce_bytes(FieldBytes::from_slice(&nonce_bytes));
let r_point = ProjectivePoint::GENERATOR * k;
let r_bytes: [u8; 33] = <[u8; 33]>::try_from(r_point.to_affine().to_bytes().as_slice())
.map_err(|e| ZkpError::SerializationError(format!("Failed to serialize R: {}", e)))?;
let c = fiat_shamir_challenge(&self.public_key, &r_bytes, context)?;
let s = k + c * x;
let s_bytes: [u8; 32] = s.to_bytes().into();
Ok(SchnorrProof { commitment: r_bytes, response: s_bytes })
}
#[must_use]
pub fn public_key(&self) -> &[u8; 33] {
&self.public_key
}
}
/// Verifies [`SchnorrProof`]s against a fixed secp256k1 public key.
///
/// It holds only public data, so it is `Clone` and its `Debug` output shows the
/// key bytes. This differs from the prover, which redacts its secret.
#[derive(Clone, Debug)]
pub struct SchnorrVerifier {
    /// 33-byte compressed SEC1 public key that proofs are checked against.
    public_key: [u8; 33],
}
impl SchnorrVerifier {
#[must_use]
pub fn new(public_key: [u8; 33]) -> Self {
Self { public_key }
}
#[allow(clippy::arithmetic_side_effects)] pub fn verify(&self, proof: &SchnorrProof, context: &[u8]) -> Result<bool> {
let p_point = Self::parse_point(&self.public_key)?;
let r_point = Self::parse_point(proof.commitment())?;
let s: Option<Scalar> = Scalar::from_repr(*FieldBytes::from_slice(proof.response())).into();
let s = s.ok_or(ZkpError::InvalidScalar)?;
let c = fiat_shamir_challenge(&self.public_key, proof.commitment(), context)?;
let lhs = ProjectivePoint::GENERATOR * s;
let rhs = r_point + p_point * c;
Ok(bool::from(lhs.ct_eq(&rhs)))
}
fn parse_point(bytes: &[u8; 33]) -> Result<ProjectivePoint> {
use k256::EncodedPoint;
use k256::elliptic_curve::sec1::FromEncodedPoint;
let encoded = EncodedPoint::from_bytes(bytes)
.map_err(|e| ZkpError::SerializationError(format!("Invalid point encoding: {}", e)))?;
let point: Option<ProjectivePoint> = ProjectivePoint::from_encoded_point(&encoded).into();
point.ok_or(ZkpError::InvalidPublicKey)
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// An honest proof verifies under the same public key and context.
    #[test]
    fn test_schnorr_proof_valid_succeeds() {
        let (prover, public_key) = SchnorrProver::new().unwrap();
        let context = b"test challenge context";
        let proof = prover.prove(context).unwrap();
        let verifier = SchnorrVerifier::new(public_key);
        assert!(verifier.verify(&proof, context).unwrap());
    }

    /// The context is bound into the challenge, so a different one must fail.
    #[test]
    fn test_schnorr_proof_wrong_context_fails() {
        let (prover, public_key) = SchnorrProver::new().unwrap();
        let proof = prover.prove(b"context 1").unwrap();
        let verifier = SchnorrVerifier::new(public_key);
        assert!(!verifier.verify(&proof, b"context 2").unwrap());
    }

    /// A proof for one key must not verify under another key.
    #[test]
    fn test_schnorr_proof_wrong_public_key_fails() {
        let (prover, _) = SchnorrProver::new().unwrap();
        let (_, other_public_key) = SchnorrProver::new().unwrap();
        let context = b"test";
        let proof = prover.prove(context).unwrap();
        let verifier = SchnorrVerifier::new(other_public_key);
        assert!(!verifier.verify(&proof, context).unwrap());
    }

    /// `from_secret` is deterministic and yields a working prover.
    #[test]
    fn test_schnorr_from_secret_succeeds() {
        let secret = [42u8; 32];
        let (prover1, pk1) = SchnorrProver::from_secret(&secret).unwrap();
        let (_prover2, pk2) = SchnorrProver::from_secret(&secret).unwrap();
        assert_eq!(pk1, pk2);
        let proof = prover1.prove(b"test").unwrap();
        let verifier = SchnorrVerifier::new(pk2);
        assert!(verifier.verify(&proof, b"test").unwrap());
    }

    /// Flipping one bit of the response must cause rejection.
    #[test]
    fn test_schnorr_tampered_response_fails() {
        let (prover, public_key) = SchnorrProver::new().unwrap();
        let context = b"tamper";
        let proof = prover.prove(context).unwrap();
        let mut response = *proof.response();
        response[31] ^= 0x01;
        let tampered = SchnorrProof::new(*proof.commitment(), response);
        let verifier = SchnorrVerifier::new(public_key);
        // The flip may also push the scalar out of range, which is reported as
        // an error instead of `Ok(false)`. Both outcomes are rejections.
        assert!(!verifier.verify(&tampered, context).unwrap_or(false));
    }

    /// A commitment that is not a valid SEC1 encoding is reported as an error.
    #[test]
    fn test_schnorr_malformed_commitment_errors() {
        let (prover, public_key) = SchnorrProver::new().unwrap();
        let proof = prover.prove(b"ctx").unwrap();
        // 0xFF is not a valid SEC1 tag byte.
        let malformed = SchnorrProof::new([0xFFu8; 33], *proof.response());
        let verifier = SchnorrVerifier::new(public_key);
        assert!(verifier.verify(&malformed, b"ctx").is_err());
    }

    /// `Debug` output of a prover must not leak the secret bytes.
    #[test]
    fn test_schnorr_prover_debug_redacts_secret() {
        let (prover, _) = SchnorrProver::from_secret(&[42u8; 32]).unwrap();
        let rendered = format!("{prover:?}");
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains("42"));
    }
}