use crate::primitives::hash::sha2::sha256;
use crate::zkp::error::{Result, ZkpError};
use k256::{
FieldBytes, ProjectivePoint, Scalar, SecretKey,
elliptic_curve::{PrimeField, group::GroupEncoding},
};
use subtle::ConstantTimeEq;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Derives the Fiat–Shamir challenge scalar `c` for a Schnorr proof.
///
/// The transcript is `label || curve || public_key || R || context || counter_be32`.
/// It is hashed with SHA-256 and interpreted as a scalar. Digests that are not a
/// canonical scalar encoding, or that decode to zero, are discarded. The 32-bit
/// counter is then incremented and the transcript re-hashed.
///
/// # Errors
///
/// Returns `ZkpError::SerializationError` if hashing fails or if the counter
/// would overflow.
fn fiat_shamir_challenge(
    public_key: &[u8; 33],
    r_bytes: &[u8; 33],
    context: &[u8],
) -> Result<Scalar> {
    const DOMAIN_LABEL: &[u8] = b"arc-zkp/schnorr-v2";
    const CURVE_NAME: &[u8] = b"secp256k1";

    // Everything except the trailing counter is fixed for this call, so assemble it once.
    let prefix_len = DOMAIN_LABEL
        .len()
        .saturating_add(CURVE_NAME.len())
        .saturating_add(public_key.len())
        .saturating_add(r_bytes.len())
        .saturating_add(context.len());
    let mut transcript = Vec::with_capacity(prefix_len.saturating_add(4));
    for part in [DOMAIN_LABEL, CURVE_NAME, public_key.as_slice(), r_bytes.as_slice(), context] {
        transcript.extend_from_slice(part);
    }

    let mut counter: u32 = 0;
    loop {
        // Drop the previous attempt's counter (no-op on the first pass) and append the current one.
        transcript.truncate(prefix_len);
        transcript.extend_from_slice(&counter.to_be_bytes());

        let digest = sha256(&transcript)
            .map_err(|e| ZkpError::SerializationError(format!("SHA-256 failed: {}", e)))?;
        let decoded: Option<Scalar> = Scalar::from_repr(*FieldBytes::from_slice(&digest)).into();
        if let Some(challenge) = decoded.filter(|c| !bool::from(c.ct_eq(&Scalar::ZERO))) {
            return Ok(challenge);
        }

        counter = counter.checked_add(1).ok_or_else(|| {
            ZkpError::SerializationError(
                "schnorr challenge derivation: counter overflow (statistically impossible)"
                    .to_string(),
            )
        })?;
    }
}
/// Non-interactive Schnorr proof of knowledge of a secp256k1 secret scalar.
///
/// Produced by `SchnorrProver::prove` and checked by `SchnorrVerifier::verify`.
/// Both byte arrays are wiped when the value is dropped (`ZeroizeOnDrop`). With the
/// `zkp-serde` feature, the fields (de)serialize as raw byte strings via `serde_with::Bytes`.
#[derive(Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "zkp-serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "zkp-serde", serde(crate = "serde"))]
pub struct SchnorrProof {
    /// Commitment `R = k·G`, encoded as a 33-byte compressed SEC1 point.
    #[cfg_attr(feature = "zkp-serde", serde(with = "serde_with::As::<serde_with::Bytes>"))]
    commitment: [u8; 33],
    /// Response `s = k + c·x`, in the 32-byte encoding produced by `Scalar::to_bytes`.
    #[cfg_attr(feature = "zkp-serde", serde(with = "serde_with::As::<serde_with::Bytes>"))]
    response: [u8; 32],
}
impl std::fmt::Debug for SchnorrProof {
    /// Renders the struct shape with both fields masked, so raw proof bytes never reach logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MASK: &str = "[REDACTED]";
        let mut out = f.debug_struct("SchnorrProof");
        out.field("commitment", &MASK).field("response", &MASK);
        out.finish()
    }
}
impl ConstantTimeEq for SchnorrProof {
    /// Compares commitment and response with `subtle`'s constant-time equality.
    ///
    /// Both comparisons are always evaluated and then combined.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let same_commitment = self.commitment.ct_eq(&other.commitment);
        let same_response = self.response.ct_eq(&other.response);
        same_commitment & same_response
    }
}
impl SchnorrProof {
    /// Assembles a proof from its wire components: a 33-byte compressed commitment
    /// and a 32-byte response scalar encoding.
    #[must_use]
    pub fn new(commitment: [u8; 33], response: [u8; 32]) -> Self {
        Self { response, commitment }
    }

    /// Returns a copy of this proof with its own storage.
    ///
    /// Zeroizing the copy leaves `self` untouched, and vice versa.
    #[must_use]
    pub fn clone_for_transmission(&self) -> Self {
        Self::new(self.commitment, self.response)
    }

    /// Borrows the compressed SEC1 commitment point `R`.
    #[must_use]
    pub fn commitment(&self) -> &[u8; 33] {
        &self.commitment
    }

    /// Borrows the encoded response scalar `s`.
    #[must_use]
    pub fn response(&self) -> &[u8; 32] {
        &self.response
    }
}
/// Holds a secp256k1 secret scalar and produces [`SchnorrProof`]s for it.
///
/// The secret bytes are wiped on drop. The public key is excluded from zeroization
/// because it is not secret.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct SchnorrProver {
    /// Secret scalar `x`, in the 32-byte encoding accepted by `SecretKey::from_bytes`.
    secret: [u8; 32],
    /// Public key `P = x·G`, as a 33-byte compressed SEC1 point.
    #[zeroize(skip)]
    public_key: [u8; 33],
}
impl std::fmt::Debug for SchnorrProver {
    /// Masks the secret and prints a placeholder for the public key, so neither is emitted verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SchnorrProver");
        out.field("secret", &"[REDACTED]");
        out.field("public_key", &"[public]");
        out.finish()
    }
}
impl ConstantTimeEq for SchnorrProver {
    /// Compares the secret and the public key with `subtle`'s constant-time equality.
    ///
    /// Both comparisons are always evaluated and then combined.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let secrets_match = self.secret.ct_eq(&other.secret);
        let keys_match = self.public_key.ct_eq(&other.public_key);
        secrets_match & keys_match
    }
}
impl SchnorrProver {
pub fn new() -> Result<(Self, [u8; 33])> {
let initial_bytes =
zeroize::Zeroizing::new(crate::primitives::rand::csprng::random_bytes(32));
let secret_key = SecretKey::from_slice(&initial_bytes)
.map_err(|e| ZkpError::SerializationError(format!("Invalid secret key: {e}")))?;
let public_key = secret_key.public_key();
let secret_bytes_zeroizing: zeroize::Zeroizing<[u8; 32]> =
zeroize::Zeroizing::new(secret_key.to_bytes().into());
let public_bytes: [u8; 33] = <[u8; 33]>::try_from(public_key.to_sec1_bytes().as_ref())
.map_err(|e| {
ZkpError::SerializationError(format!("Failed to serialize public key: {}", e))
})?;
let prover = Self { secret: *secret_bytes_zeroizing, public_key: public_bytes };
Ok((prover, public_bytes))
}
pub fn from_secret(secret: &[u8; 32]) -> Result<(Self, [u8; 33])> {
let secret_key = SecretKey::from_bytes(secret.into())
.map_err(|e| ZkpError::SerializationError(format!("Invalid secret key format: {e}")))?;
let public_key = secret_key.public_key();
let public_bytes: [u8; 33] = <[u8; 33]>::try_from(public_key.to_sec1_bytes().as_ref())
.map_err(|e| {
ZkpError::SerializationError(format!("Failed to serialize public key: {}", e))
})?;
let prover = Self { secret: *secret, public_key: public_bytes };
Ok((prover, public_bytes))
}
#[expect(
clippy::arithmetic_side_effects,
reason = "EC scalar math is modular, cannot overflow"
)]
pub fn prove(&self, context: &[u8]) -> Result<SchnorrProof> {
use zeroize::Zeroizing;
let x: Option<Scalar> = Scalar::from_repr(*FieldBytes::from_slice(&self.secret)).into();
let x = Zeroizing::new(x.ok_or(ZkpError::InvalidScalar)?);
let k_scalar: Scalar = loop {
let nonce_bytes = Zeroizing::new(crate::primitives::rand::csprng::random_bytes(32));
let candidate: Option<Scalar> =
Scalar::from_repr(*FieldBytes::from_slice(&nonce_bytes)).into();
if let Some(s) = candidate
&& !bool::from(s.ct_eq(&Scalar::ZERO))
{
break s;
}
};
let k = Zeroizing::new(k_scalar);
let r_point = ProjectivePoint::GENERATOR * *k;
let r_bytes: [u8; 33] = <[u8; 33]>::try_from(r_point.to_affine().to_bytes().as_slice())
.map_err(|e| ZkpError::SerializationError(format!("Failed to serialize R: {}", e)))?;
let c = fiat_shamir_challenge(&self.public_key, &r_bytes, context)?;
let s = Zeroizing::new(*k + c * *x);
let mut tmp_bytes = Zeroizing::new(<[u8; 32]>::from(s.to_bytes()));
let response: [u8; 32] = *tmp_bytes;
tmp_bytes.zeroize();
Ok(SchnorrProof { commitment: r_bytes, response })
}
#[must_use]
pub fn public_key(&self) -> &[u8; 33] {
&self.public_key
}
}
/// Verifies [`SchnorrProof`]s against a fixed public key.
pub struct SchnorrVerifier {
    /// Compressed SEC1 public key. It is not validated at construction; `verify`
    /// parses it (rejecting invalid or identity points) on every call.
    public_key: [u8; 33],
}
impl SchnorrVerifier {
    /// Creates a verifier for the given 33-byte compressed SEC1 public key.
    ///
    /// The key is only parsed and validated when `verify` is called.
    #[must_use]
    pub fn new(public_key: [u8; 33]) -> Self {
        Self { public_key }
    }

    /// Checks that `proof` satisfies `s·G == R + c·P`.
    ///
    /// `c` is recomputed from this verifier's public key, the proof's commitment
    /// and `context`.
    ///
    /// Returns `Ok(true)` if the equation holds and `Ok(false)` if it does not.
    ///
    /// # Errors
    ///
    /// Inputs are checked in this order:
    /// - `ZkpError::SerializationError` or `ZkpError::InvalidPublicKey` if the
    ///   public key or the commitment is not a valid non-identity point.
    /// - `ZkpError::InvalidScalar` if the response is not a canonical scalar.
    /// - Any error from challenge derivation.
    #[expect(clippy::arithmetic_side_effects, reason = "EC math is modular, cannot overflow")]
    pub fn verify(&self, proof: &SchnorrProof, context: &[u8]) -> Result<bool> {
        let public_point = Self::parse_point(&self.public_key)?;
        let commitment_point = Self::parse_point(proof.commitment())?;
        let response =
            Option::<Scalar>::from(Scalar::from_repr(*FieldBytes::from_slice(proof.response())))
                .ok_or(ZkpError::InvalidScalar)?;
        let challenge = fiat_shamir_challenge(&self.public_key, proof.commitment(), context)?;

        let actual = ProjectivePoint::GENERATOR * response;
        let expected = commitment_point + public_point * challenge;
        Ok(actual.ct_eq(&expected).into())
    }

    /// Decodes a 33-byte SEC1 point, rejecting malformed encodings, off-curve
    /// points and the identity.
    fn parse_point(bytes: &[u8; 33]) -> Result<ProjectivePoint> {
        use k256::EncodedPoint;
        use k256::elliptic_curve::Group;
        use k256::elliptic_curve::sec1::FromEncodedPoint;

        let encoded = EncodedPoint::from_bytes(bytes)
            .map_err(|e| ZkpError::SerializationError(format!("Invalid point encoding: {}", e)))?;
        let decoded: Option<ProjectivePoint> = ProjectivePoint::from_encoded_point(&encoded).into();
        // Both "not on the curve" and "identity" collapse to the same error.
        decoded
            .filter(|point| !bool::from(point.is_identity()))
            .ok_or(ZkpError::InvalidPublicKey)
    }
}
#[cfg(test)]
#[expect(
    clippy::unwrap_used,
    reason = "test/bench code: unwrap is acceptable when inputs are statically known"
)]
mod tests {
    use super::*;

    #[test]
    fn test_schnorr_proof_valid_succeeds() {
        let (prover, public_key) = SchnorrProver::new().unwrap();
        let context = b"test challenge context";
        let proof = prover.prove(context).unwrap();
        let verifier = SchnorrVerifier::new(public_key);
        assert!(verifier.verify(&proof, context).unwrap());
    }

    #[test]
    fn test_schnorr_proof_wrong_context_fails() {
        let (prover, public_key) = SchnorrProver::new().unwrap();
        let proof = prover.prove(b"context 1").unwrap();
        let verifier = SchnorrVerifier::new(public_key);
        assert!(!verifier.verify(&proof, b"context 2").unwrap());
    }

    #[test]
    fn test_schnorr_proof_wrong_public_key_fails() {
        let (prover, _) = SchnorrProver::new().unwrap();
        let (_, other_public_key) = SchnorrProver::new().unwrap();
        let context = b"test";
        let proof = prover.prove(context).unwrap();
        let verifier = SchnorrVerifier::new(other_public_key);
        assert!(!verifier.verify(&proof, context).unwrap());
    }

    #[test]
    fn test_schnorr_from_secret_succeeds() {
        let secret = [42u8; 32];
        let (prover1, pk1) = SchnorrProver::from_secret(&secret).unwrap();
        let (_prover2, pk2) = SchnorrProver::from_secret(&secret).unwrap();
        assert_eq!(pk1, pk2);
        let proof = prover1.prove(b"test").unwrap();
        let verifier = SchnorrVerifier::new(pk2);
        assert!(verifier.verify(&proof, b"test").unwrap());
    }

    #[test]
    fn test_schnorr_from_secret_rejects_out_of_range() {
        // Zero and values >= the curve order are not valid secret keys.
        assert!(SchnorrProver::from_secret(&[0u8; 32]).is_err());
        assert!(SchnorrProver::from_secret(&[0xFF; 32]).is_err());
    }

    #[test]
    fn test_schnorr_proof_tampered_response_fails() {
        let (prover, public_key) = SchnorrProver::new().unwrap();
        let context = b"tamper response";
        let proof = prover.prove(context).unwrap();
        let mut response = *proof.response();
        if let Some(last) = response.last_mut() {
            *last ^= 0x01;
        }
        let tampered = SchnorrProof::new(*proof.commitment(), response);
        let verifier = SchnorrVerifier::new(public_key);
        // The flipped scalar either fails to decode (Err) or fails the equation (Ok(false)).
        assert!(!matches!(verifier.verify(&tampered, context), Ok(true)));
    }

    #[test]
    fn test_schnorr_proof_tampered_commitment_fails() {
        let (prover, public_key) = SchnorrProver::new().unwrap();
        let context = b"tamper commitment";
        let proof = prover.prove(context).unwrap();
        let mut commitment = *proof.commitment();
        // Toggling the SEC1 parity tag (0x02 <-> 0x03) turns R into -R.
        if let Some(tag) = commitment.first_mut() {
            *tag ^= 0x01;
        }
        let tampered = SchnorrProof::new(commitment, *proof.response());
        let verifier = SchnorrVerifier::new(public_key);
        assert!(!matches!(verifier.verify(&tampered, context), Ok(true)));
    }

    #[test]
    fn test_schnorr_verify_rejects_malformed_points() {
        let (prover, public_key) = SchnorrProver::new().unwrap();
        let proof = prover.prove(b"ctx").unwrap();
        // A 33-byte buffer with tag 0x00 is not a valid SEC1 encoding.
        assert!(SchnorrVerifier::new([0u8; 33]).verify(&proof, b"ctx").is_err());
        let bogus = SchnorrProof::new([0u8; 33], *proof.response());
        assert!(SchnorrVerifier::new(public_key).verify(&bogus, b"ctx").is_err());
    }

    #[test]
    fn test_schnorr_prover_debug_redacts_secret() {
        let (prover, _) = SchnorrProver::from_secret(&[42u8; 32]).unwrap();
        let rendered = format!("{prover:?}");
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains("42"));
    }

    #[test]
    fn test_schnorr_proof_clone_for_transmission_independent_storage() {
        let proof = SchnorrProof::new([0xAA; 33], [0xBB; 32]);
        let mut cloned = proof.clone_for_transmission();
        Zeroize::zeroize(&mut cloned);
        assert_eq!(*proof.commitment(), [0xAA; 33]);
        assert_eq!(*proof.response(), [0xBB; 32]);
        assert_eq!(*cloned.commitment(), [0u8; 33]);
        assert_eq!(*cloned.response(), [0u8; 32]);
    }

    #[test]
    fn test_schnorr_proof_zeroize_wipes_all_fields() {
        let mut proof = SchnorrProof::new([0xAA; 33], [0xBB; 32]);
        Zeroize::zeroize(&mut proof);
        assert_eq!(*proof.commitment(), [0u8; 33]);
        assert_eq!(*proof.response(), [0u8; 32]);
    }
}