use crate::primitives::hash::sha2::sha256;
use crate::zkp::error::{Result, ZkpError};
use k256::elliptic_curve::{PrimeField, ops::Reduce};
use subtle::ConstantTimeEq;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// A serialized non-interactive sigma-protocol proof, as produced by
/// [`FiatShamir::prove`] and checked by [`FiatShamir::verify`].
///
/// All fields are wiped when the value is dropped (`ZeroizeOnDrop`).
/// NOTE(review): with the `zkp-serde` feature the serde derive follows field declaration
/// order, so reordering fields may change the wire format for positional serde formats.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "zkp-serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SigmaProof {
/// Protocol-specific serialization of the prover's commitment (first message).
commitment: Vec<u8>,
/// SHA-256 Fiat–Shamir challenge over domain separator, statement, commitment and context.
challenge: [u8; 32],
/// Protocol-specific serialization of the prover's response.
response: Vec<u8>,
}
impl SigmaProof {
#[must_use]
pub fn new(commitment: Vec<u8>, challenge: [u8; 32], response: Vec<u8>) -> Self {
Self { commitment, challenge, response }
}
#[must_use]
pub fn commitment(&self) -> &[u8] {
&self.commitment
}
#[must_use]
pub fn challenge(&self) -> &[u8; 32] {
&self.challenge
}
#[must_use]
pub fn response(&self) -> &[u8] {
&self.response
}
#[must_use]
pub fn challenge_mut(&mut self) -> &mut [u8; 32] {
&mut self.challenge
}
}
impl std::fmt::Debug for SigmaProof {
    // Only the commitment length is shown; challenge and response contents are redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        let commitment_summary = format!("[{} bytes]", self.commitment.len());
        let mut out = f.debug_struct("SigmaProof");
        out.field("commitment", &commitment_summary);
        out.field("challenge", &REDACTED);
        out.field("response", &REDACTED);
        out.finish()
    }
}
impl ConstantTimeEq for SigmaProof {
    // Every field is compared and the results are combined with non-short-circuiting
    // bitwise AND. Slices of different lengths compare unequal (subtle checks length first).
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let same_commitment = self.commitment.as_slice().ct_eq(other.commitment.as_slice());
        let same_challenge = self.challenge.ct_eq(&other.challenge);
        let same_response = self.response.as_slice().ct_eq(other.response.as_slice());
        same_commitment & same_challenge & same_response
    }
}
/// An interactive commit / challenge / respond protocol that [`FiatShamir`] turns into
/// a non-interactive proof.
///
/// Implementors also supply byte encodings:
/// - [`FiatShamir`] hashes the serialized statement and commitment into the challenge.
/// - The serialized commitment and response are stored in a [`SigmaProof`].
pub trait SigmaProtocol {
/// Public statement being proven.
type Statement;
/// Secret witness held by the prover.
type Witness;
/// Prover's first message.
type Commitment;
/// Prover's answer to the challenge.
type Response;
/// First move: returns the commitment plus opaque prover state.
/// `FiatShamir::prove` passes that state unchanged to [`SigmaProtocol::respond`].
fn commit(
&self,
statement: &Self::Statement,
witness: &Self::Witness,
) -> Result<(Self::Commitment, Vec<u8>)>;
/// Third move: computes the response from the witness, the state returned by
/// `commit`, and the 32-byte challenge.
fn respond(
&self,
witness: &Self::Witness,
commitment_state: Vec<u8>,
challenge: &[u8; 32],
) -> Result<Self::Response>;
/// Checks the transcript. `FiatShamir::verify` calls this only after the stored
/// challenge matched its own recomputation.
fn verify(
&self,
statement: &Self::Statement,
commitment: &Self::Commitment,
challenge: &[u8; 32],
response: &Self::Response,
) -> Result<bool>;
/// Encodes a commitment; the result is both hashed and stored in the proof.
fn serialize_commitment(&self, commitment: &Self::Commitment) -> Vec<u8>;
/// Decodes proof bytes; errors propagate out of `FiatShamir::verify`.
fn deserialize_commitment(&self, bytes: &[u8]) -> Result<Self::Commitment>;
/// Encodes a response for storage in the proof.
fn serialize_response(&self, response: &Self::Response) -> Vec<u8>;
/// Decodes proof bytes; errors propagate out of `FiatShamir::verify`.
fn deserialize_response(&self, bytes: &[u8]) -> Result<Self::Response>;
/// Encodes the statement for hashing into the challenge. Distinct statements should
/// encode differently; otherwise they share challenges.
fn serialize_statement(&self, statement: &Self::Statement) -> Vec<u8>;
}
/// Fiat–Shamir transform: makes a [`SigmaProtocol`] non-interactive by deriving the
/// challenge as a SHA-256 hash of the transcript.
pub struct FiatShamir<P: SigmaProtocol> {
/// The underlying interactive protocol.
protocol: P,
/// Bytes prepended to every challenge transcript to bind proofs to this use.
domain_separator: Vec<u8>,
}
impl<P: SigmaProtocol> FiatShamir<P> {
    /// Wraps `protocol`; `domain_separator` is hashed into every challenge so proofs are
    /// bound to this particular use.
    #[must_use]
    pub fn new(protocol: P, domain_separator: &[u8]) -> Self {
        Self { protocol, domain_separator: domain_separator.to_vec() }
    }

    /// Produces a non-interactive proof for `statement` using `witness`.
    ///
    /// The steps are:
    /// 1. Run the protocol's commit phase.
    /// 2. Derive the challenge from the domain separator, the statement, the serialized
    ///    commitment and `context`.
    /// 3. Run the response phase with that challenge.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`SigmaProtocol::commit`] and [`SigmaProtocol::respond`].
    /// Returns `ZkpError::SerializationError` if a transcript field is longer than
    /// `u32::MAX` bytes or hashing fails.
    pub fn prove(
        &self,
        statement: &P::Statement,
        witness: &P::Witness,
        context: &[u8],
    ) -> Result<SigmaProof> {
        let (commitment, commit_state) = self.protocol.commit(statement, witness)?;
        let commitment_bytes = self.protocol.serialize_commitment(&commitment);
        let challenge = self.compute_challenge(statement, &commitment_bytes, context)?;
        let response = self.protocol.respond(witness, commit_state, &challenge)?;
        let response_bytes = self.protocol.serialize_response(&response);
        Ok(SigmaProof::new(commitment_bytes, challenge, response_bytes))
    }

    /// Verifies `proof` for `statement` under `context`.
    ///
    /// Returns `Ok(false)` if the stored challenge differs from the recomputed one
    /// (compared in constant time). Otherwise it defers to [`SigmaProtocol::verify`].
    ///
    /// # Errors
    ///
    /// Propagates deserialization and verification errors from the protocol.
    /// Returns `ZkpError::SerializationError` if a transcript field is longer than
    /// `u32::MAX` bytes or hashing fails.
    pub fn verify(
        &self,
        statement: &P::Statement,
        proof: &SigmaProof,
        context: &[u8],
    ) -> Result<bool> {
        let expected_challenge = self.compute_challenge(statement, proof.commitment(), context)?;
        if expected_challenge.ct_eq(proof.challenge()).unwrap_u8() == 0 {
            return Ok(false);
        }
        let commitment = self.protocol.deserialize_commitment(proof.commitment())?;
        let response = self.protocol.deserialize_response(proof.response())?;
        self.protocol.verify(statement, &commitment, proof.challenge(), &response)
    }

    /// Derives the challenge as
    /// `SHA-256(domain || len(stmt) || stmt || len(com) || com || len(ctx) || ctx)`.
    /// Each `len(..)` is a 4-byte little-endian `u32`.
    ///
    /// NOTE(review): the domain separator itself is not length-prefixed. Distinct
    /// separators where one is a prefix of the other could in principle yield
    /// overlapping transcripts. The encoding is left unchanged because altering it would
    /// invalidate every existing proof.
    fn compute_challenge(
        &self,
        statement: &P::Statement,
        commitment: &[u8],
        context: &[u8],
    ) -> Result<[u8; 32]> {
        let statement_bytes = self.protocol.serialize_statement(statement);
        let statement_len = Self::length_prefix(statement_bytes.len(), "statement")?;
        let commitment_len = Self::length_prefix(commitment.len(), "commitment")?;
        let context_len = Self::length_prefix(context.len(), "context")?;
        let mut buf = Vec::with_capacity(
            self.domain_separator
                .len()
                .saturating_add(4)
                .saturating_add(statement_bytes.len())
                .saturating_add(4)
                .saturating_add(commitment.len())
                .saturating_add(4)
                .saturating_add(context.len()),
        );
        buf.extend_from_slice(&self.domain_separator);
        buf.extend_from_slice(&statement_len);
        buf.extend_from_slice(&statement_bytes);
        buf.extend_from_slice(&commitment_len);
        buf.extend_from_slice(commitment);
        buf.extend_from_slice(&context_len);
        buf.extend_from_slice(context);
        sha256(&buf).map_err(|e| ZkpError::SerializationError(format!("SHA-256 failed: {}", e)))
    }

    /// Encodes a transcript field length as a 4-byte little-endian prefix.
    ///
    /// Lengths that do not fit in `u32` are rejected rather than saturated. Saturating
    /// them would give different oversized fields the same prefix and make the
    /// transcript encoding ambiguous.
    fn length_prefix(len: usize, field: &str) -> Result<[u8; 4]> {
        u32::try_from(len).map(u32::to_le_bytes).map_err(|_| {
            ZkpError::SerializationError(format!("{} length {} exceeds u32::MAX", field, len))
        })
    }
}
/// Non-interactive Chaum–Pedersen proof that `P = x·G` and `Q = x·H` share the same `x`.
///
/// All fields are wiped when the value is dropped (`ZeroizeOnDrop`).
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct DlogEqualityProof {
/// Commitment `A = k·G` (33-byte SEC1 encoding), `k` being the prover's nonce.
a: [u8; 33],
/// Commitment `B = k·H` (33-byte SEC1 encoding).
b: [u8; 33],
/// SHA-256 challenge hash; reduced modulo the group order before use.
challenge: [u8; 32],
/// Response `s = k + c·x`, in the `Scalar::to_bytes` encoding.
response: [u8; 32],
}
impl DlogEqualityProof {
#[must_use]
pub fn new(a: [u8; 33], b: [u8; 33], challenge: [u8; 32], response: [u8; 32]) -> Self {
Self { a, b, challenge, response }
}
#[must_use]
pub fn a(&self) -> &[u8; 33] {
&self.a
}
#[must_use]
pub fn b(&self) -> &[u8; 33] {
&self.b
}
#[must_use]
pub fn challenge(&self) -> &[u8; 32] {
&self.challenge
}
#[must_use]
pub fn response(&self) -> &[u8; 32] {
&self.response
}
}
impl std::fmt::Debug for DlogEqualityProof {
    // Every field is printed as a redaction marker; no proof bytes are exposed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("DlogEqualityProof");
        for name in ["a", "b", "challenge", "response"] {
            out.field(name, &"[REDACTED]");
        }
        out.finish()
    }
}
impl ConstantTimeEq for DlogEqualityProof {
    // All four fields are compared; the bitwise AND never short-circuits.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let same_a = self.a.ct_eq(&other.a);
        let same_b = self.b.ct_eq(&other.b);
        let same_challenge = self.challenge.ct_eq(&other.challenge);
        let same_response = self.response.ct_eq(&other.response);
        same_a & same_b & same_challenge & same_response
    }
}
/// Public statement for [`DlogEqualityProof`]: the claim that `P = x·G` and `Q = x·H`
/// for one secret `x`.
///
/// Each point is a 33-byte SEC1 encoding.
#[derive(Debug, Clone)]
pub struct DlogEqualityStatement {
/// First base point `G`.
pub g: [u8; 33],
/// Second base point `H`.
pub h: [u8; 33],
/// `P`, claimed to equal `x·G`.
pub p: [u8; 33],
/// `Q`, claimed to equal `x·H`.
pub q: [u8; 33],
}
impl DlogEqualityProof {
#[allow(clippy::arithmetic_side_effects)] pub fn prove(
statement: &DlogEqualityStatement,
secret: &[u8; 32],
context: &[u8],
) -> Result<Self> {
use k256::{
FieldBytes, Scalar, U256,
elliptic_curve::{group::GroupEncoding, ops::Reduce},
};
let g = Self::parse_point(&statement.g)?;
let h = Self::parse_point(&statement.h)?;
let x: Option<Scalar> = Scalar::from_repr(*FieldBytes::from_slice(secret)).into();
let x = x.ok_or(ZkpError::InvalidScalar)?;
let nonce_bytes = crate::primitives::rand::csprng::random_bytes(32);
let k = <Scalar as Reduce<U256>>::reduce_bytes(FieldBytes::from_slice(&nonce_bytes));
let a_point = g * k;
let b_point = h * k;
let a_bytes: [u8; 33] = <[u8; 33]>::try_from(a_point.to_affine().to_bytes().as_slice())
.map_err(|e| ZkpError::SerializationError(format!("Failed to serialize A: {}", e)))?;
let b_bytes: [u8; 33] = <[u8; 33]>::try_from(b_point.to_affine().to_bytes().as_slice())
.map_err(|e| ZkpError::SerializationError(format!("Failed to serialize B: {}", e)))?;
let challenge = Self::compute_challenge(statement, &a_bytes, &b_bytes, context)?;
let c = <Scalar as Reduce<U256>>::reduce_bytes(FieldBytes::from_slice(&challenge));
let s = k + c * x;
let response: [u8; 32] = s.to_bytes().into();
Ok(Self { a: a_bytes, b: b_bytes, challenge, response })
}
#[allow(clippy::arithmetic_side_effects)] pub fn verify(&self, statement: &DlogEqualityStatement, context: &[u8]) -> Result<bool> {
use k256::{FieldBytes, Scalar, U256};
let g = Self::parse_point(&statement.g)?;
let h = Self::parse_point(&statement.h)?;
let p = Self::parse_point(&statement.p)?;
let q = Self::parse_point(&statement.q)?;
let a = Self::parse_point(&self.a)?;
let b = Self::parse_point(&self.b)?;
let expected_challenge = Self::compute_challenge(statement, &self.a, &self.b, context)?;
if expected_challenge.ct_eq(&self.challenge).unwrap_u8() == 0 {
return Ok(false);
}
let s: Option<Scalar> = Scalar::from_repr(*FieldBytes::from_slice(&self.response)).into();
let s = s.ok_or(ZkpError::InvalidScalar)?;
let c = <Scalar as Reduce<U256>>::reduce_bytes(FieldBytes::from_slice(&self.challenge));
let lhs1 = g * s;
let rhs1 = a + p * c;
let lhs2 = h * s;
let rhs2 = b + q * c;
Ok(bool::from(lhs1.ct_eq(&rhs1)) & bool::from(lhs2.ct_eq(&rhs2)))
}
fn parse_point(bytes: &[u8; 33]) -> Result<k256::ProjectivePoint> {
use k256::EncodedPoint;
use k256::elliptic_curve::sec1::FromEncodedPoint;
let encoded = EncodedPoint::from_bytes(bytes)
.map_err(|e| ZkpError::SerializationError(format!("Invalid point encoding: {}", e)))?;
let point: Option<k256::ProjectivePoint> =
k256::ProjectivePoint::from_encoded_point(&encoded).into();
point.ok_or(ZkpError::InvalidPublicKey)
}
fn compute_challenge(
statement: &DlogEqualityStatement,
a: &[u8; 33],
b: &[u8; 33],
context: &[u8],
) -> Result<[u8; 32]> {
let label = b"arc-zkp/dlog-equality-v1";
let mut buf = Vec::with_capacity(
label.len().saturating_add(33 * 4).saturating_add(33 * 2).saturating_add(context.len()),
);
buf.extend_from_slice(label);
buf.extend_from_slice(&statement.g);
buf.extend_from_slice(&statement.h);
buf.extend_from_slice(&statement.p);
buf.extend_from_slice(&statement.q);
buf.extend_from_slice(a);
buf.extend_from_slice(b);
buf.extend_from_slice(context);
sha256(&buf).map_err(|e| ZkpError::SerializationError(format!("SHA-256 failed: {}", e)))
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use k256::{
        FieldBytes, ProjectivePoint, Scalar, SecretKey, elliptic_curve::group::GroupEncoding,
    };

    /// 33-byte SEC1 encoding of `point`.
    fn encode(point: &ProjectivePoint) -> [u8; 33] {
        <[u8; 33]>::try_from(point.to_affine().to_bytes().as_slice()).unwrap()
    }

    /// Fresh random secret, returned as canonical bytes and as the matching scalar.
    fn random_secret() -> ([u8; 32], Scalar) {
        let secret_key = SecretKey::random(&mut rand::thread_rng());
        let bytes: [u8; 32] = secret_key.to_bytes().into();
        let scalar = Scalar::from_repr(*FieldBytes::from_slice(&bytes)).unwrap();
        (bytes, scalar)
    }

    /// Builds `(G, H = h_factor·G, P = x_p·G, Q = x_q·H)`.
    /// The statement is true exactly when `x_p == x_q`.
    fn statement_for(h_factor: u64, x_p: Scalar, x_q: Scalar) -> DlogEqualityStatement {
        let g = ProjectivePoint::GENERATOR;
        let h = g * Scalar::from(h_factor);
        DlogEqualityStatement {
            g: encode(&g),
            h: encode(&h),
            p: encode(&(g * x_p)),
            q: encode(&(h * x_q)),
        }
    }

    #[test]
    fn test_dlog_equality_proof_succeeds() {
        let (x, x_scalar) = random_secret();
        let statement = statement_for(2, x_scalar, x_scalar);
        let proof = DlogEqualityProof::prove(&statement, &x, b"test").unwrap();
        assert!(proof.verify(&statement, b"test").unwrap());
    }

    #[test]
    fn test_dlog_equality_wrong_context_fails_verification_fails() {
        let (x, x_scalar) = random_secret();
        let statement = statement_for(2, x_scalar, x_scalar);
        let proof = DlogEqualityProof::prove(&statement, &x, b"context1").unwrap();
        assert!(!proof.verify(&statement, b"context2").unwrap());
    }

    #[test]
    fn test_dlog_equality_wrong_secret_fails_verification_fails() {
        let (x, x_scalar) = random_secret();
        let (_, y_scalar) = random_secret();
        // P uses x but Q uses y, so the statement is false.
        let statement = statement_for(3, x_scalar, y_scalar);
        let proof = DlogEqualityProof::prove(&statement, &x, b"test").unwrap();
        assert!(!proof.verify(&statement, b"test").unwrap());
    }

    #[test]
    fn test_dlog_equality_tampered_challenge_fails_verification_fails() {
        let (x, x_scalar) = random_secret();
        let statement = statement_for(2, x_scalar, x_scalar);
        let mut proof = DlogEqualityProof::prove(&statement, &x, b"test").unwrap();
        proof.challenge[0] ^= 0xFF;
        assert!(!proof.verify(&statement, b"test").unwrap());
    }

    #[test]
    fn test_dlog_equality_tampered_response_fails_verification_fails() {
        let (x, x_scalar) = random_secret();
        let statement = statement_for(2, x_scalar, x_scalar);
        let mut proof = DlogEqualityProof::prove(&statement, &x, b"test").unwrap();
        proof.response[0] ^= 0xFF;
        // Flipping the top byte can push the response past the group order. In that case
        // verify returns Err(InvalidScalar) instead of Ok(false); both count as rejection.
        assert!(!proof.verify(&statement, b"test").unwrap_or(false));
    }

    #[test]
    fn test_dlog_equality_invalid_point_returns_error() {
        // 0x05 is not a valid SEC1 tag.
        let invalid_point = [0x05u8; 33];
        let valid_g = encode(&ProjectivePoint::GENERATOR);
        let statement =
            DlogEqualityStatement { g: invalid_point, h: valid_g, p: valid_g, q: valid_g };
        let secret = [1u8; 32];
        let result = DlogEqualityProof::prove(&statement, &secret, b"test");
        assert!(result.is_err());
    }

    #[test]
    fn test_dlog_equality_non_canonical_secret_returns_error() {
        let (_, x_scalar) = random_secret();
        let statement = statement_for(2, x_scalar, x_scalar);
        // All-0xFF exceeds the group order, so it is not a canonical scalar encoding.
        let result = DlogEqualityProof::prove(&statement, &[0xFFu8; 32], b"test");
        assert!(matches!(result, Err(ZkpError::InvalidScalar)));
    }

    #[test]
    fn test_dlog_equality_proof_fields_are_populated_succeeds() {
        let (x, x_scalar) = random_secret();
        let statement = statement_for(2, x_scalar, x_scalar);
        let proof = DlogEqualityProof::prove(&statement, &x, b"test").unwrap();
        assert_eq!(proof.a.len(), 33);
        assert_eq!(proof.b.len(), 33);
        assert_eq!(proof.challenge.len(), 32);
        assert_eq!(proof.response.len(), 32);
        let proof2 = proof.clone();
        assert_eq!(proof.challenge, proof2.challenge);
        let debug = format!("{:?}", proof);
        assert!(debug.contains("DlogEqualityProof"));
    }

    #[test]
    fn test_dlog_equality_statement_clone_debug_succeeds() {
        let g = encode(&ProjectivePoint::GENERATOR);
        let statement = DlogEqualityStatement { g, h: g, p: g, q: g };
        let stmt2 = statement.clone();
        assert_eq!(statement.g, stmt2.g);
        let debug = format!("{:?}", statement);
        assert!(debug.contains("DlogEqualityStatement"));
    }

    #[test]
    fn test_sigma_proof_fields_are_populated_succeeds() {
        let proof = SigmaProof::new(vec![1, 2, 3], [0u8; 32], vec![4, 5, 6]);
        let proof2 = proof.clone();
        assert_eq!(proof.commitment(), proof2.commitment());
        assert_eq!(proof.challenge(), proof2.challenge());
        assert_eq!(proof.response(), proof2.response());
        let debug = format!("{:?}", proof);
        assert!(debug.contains("SigmaProof"));
    }

    #[test]
    fn test_dlog_equality_different_generators_succeeds() {
        let (x, x_scalar) = random_secret();
        let statement = statement_for(7, x_scalar, x_scalar);
        let proof = DlogEqualityProof::prove(&statement, &x, b"ctx").unwrap();
        assert!(proof.verify(&statement, b"ctx").unwrap());
    }

    /// Hash-based stand-in protocol. Its `verify` only checks lengths, so these tests
    /// exercise the Fiat–Shamir binding rather than protocol soundness.
    struct MockSigmaProtocol;

    impl SigmaProtocol for MockSigmaProtocol {
        type Statement = Vec<u8>;
        type Witness = Vec<u8>;
        type Commitment = Vec<u8>;
        type Response = Vec<u8>;

        fn commit(
            &self,
            _statement: &Self::Statement,
            witness: &Self::Witness,
        ) -> Result<(Self::Commitment, Vec<u8>)> {
            let mut buf = Vec::with_capacity(b"mock-commit".len() + witness.len());
            buf.extend_from_slice(b"mock-commit");
            buf.extend_from_slice(witness);
            let commitment = sha256(&buf).unwrap().to_vec();
            Ok((commitment, witness.clone()))
        }

        fn respond(
            &self,
            _witness: &Self::Witness,
            commitment_state: Vec<u8>,
            challenge: &[u8; 32],
        ) -> Result<Self::Response> {
            let mut buf = Vec::with_capacity(
                b"mock-response".len() + commitment_state.len() + challenge.len(),
            );
            buf.extend_from_slice(b"mock-response");
            buf.extend_from_slice(&commitment_state);
            buf.extend_from_slice(challenge);
            Ok(sha256(&buf).unwrap().to_vec())
        }

        fn verify(
            &self,
            _statement: &Self::Statement,
            commitment: &Self::Commitment,
            challenge: &[u8; 32],
            response: &Self::Response,
        ) -> Result<bool> {
            Ok(commitment.len() == 32 && challenge.len() == 32 && response.len() == 32)
        }

        fn serialize_commitment(&self, commitment: &Self::Commitment) -> Vec<u8> {
            commitment.clone()
        }

        fn deserialize_commitment(&self, bytes: &[u8]) -> Result<Self::Commitment> {
            Ok(bytes.to_vec())
        }

        fn serialize_response(&self, response: &Self::Response) -> Vec<u8> {
            response.clone()
        }

        fn deserialize_response(&self, bytes: &[u8]) -> Result<Self::Response> {
            Ok(bytes.to_vec())
        }

        fn serialize_statement(&self, statement: &Self::Statement) -> Vec<u8> {
            statement.clone()
        }
    }

    #[test]
    fn test_fiat_shamir_prove_verify_roundtrip_succeeds() {
        let fs = FiatShamir::new(MockSigmaProtocol, b"test-domain");
        let statement = vec![1u8; 32];
        let witness = vec![42u8; 16];
        let proof = fs.prove(&statement, &witness, b"context").unwrap();
        assert_eq!(proof.commitment().len(), 32);
        assert_eq!(proof.challenge().len(), 32);
        assert_eq!(proof.response().len(), 32);
        assert!(fs.verify(&statement, &proof, b"context").unwrap());
    }

    #[test]
    fn test_fiat_shamir_wrong_context_fails() {
        let fs = FiatShamir::new(MockSigmaProtocol, b"test-domain");
        let statement = vec![1u8; 32];
        let witness = vec![42u8; 16];
        let proof = fs.prove(&statement, &witness, b"context-a").unwrap();
        assert!(!fs.verify(&statement, &proof, b"context-b").unwrap());
    }

    #[test]
    fn test_fiat_shamir_tampered_challenge_fails() {
        let fs = FiatShamir::new(MockSigmaProtocol, b"test-domain");
        let statement = vec![1u8; 32];
        let witness = vec![42u8; 16];
        let mut proof = fs.prove(&statement, &witness, b"ctx").unwrap();
        proof.challenge_mut()[0] ^= 0xFF;
        assert!(!fs.verify(&statement, &proof, b"ctx").unwrap());
    }

    #[test]
    fn test_fiat_shamir_different_domain_separators_produce_different_proofs_succeeds() {
        let fs1 = FiatShamir::new(MockSigmaProtocol, b"domain-1");
        let fs2 = FiatShamir::new(MockSigmaProtocol, b"domain-2");
        let statement = vec![1u8; 32];
        let witness = vec![42u8; 16];
        let proof = fs1.prove(&statement, &witness, b"ctx").unwrap();
        assert!(!fs2.verify(&statement, &proof, b"ctx").unwrap());
    }

    #[test]
    fn test_fiat_shamir_different_statements_produce_different_proofs_succeeds() {
        let fs = FiatShamir::new(MockSigmaProtocol, b"domain");
        let statement1 = vec![1u8; 32];
        let statement2 = vec![2u8; 32];
        let witness = vec![42u8; 16];
        let proof = fs.prove(&statement1, &witness, b"ctx").unwrap();
        assert!(!fs.verify(&statement2, &proof, b"ctx").unwrap());
    }

    #[test]
    fn test_fiat_shamir_empty_domain_and_context_succeeds() {
        let fs = FiatShamir::new(MockSigmaProtocol, b"");
        let statement = vec![0u8; 32];
        let witness = vec![0u8; 8];
        let proof = fs.prove(&statement, &witness, b"").unwrap();
        assert!(fs.verify(&statement, &proof, b"").unwrap());
    }

    #[test]
    fn test_dlog_equality_empty_context_succeeds() {
        let (x, x_scalar) = random_secret();
        let statement = statement_for(2, x_scalar, x_scalar);
        let proof = DlogEqualityProof::prove(&statement, &x, b"").unwrap();
        assert!(proof.verify(&statement, b"").unwrap());
    }
}