use crate::error::ProverError;
use crate::policy::Policy;
use crate::trace::TraceBuilder;
use crate::witness::ComplianceWitness;
use serde::{Deserialize, Serialize};
use ves_stark_air::compliance::{ComplianceAir, PublicInputs};
use ves_stark_air::options::ProofOptions;
use ves_stark_air::policies::aml_threshold::AmlThresholdPolicy;
use ves_stark_primitives::public_inputs::witness_commitment_u64_to_hex;
use ves_stark_primitives::rescue::rescue_hash;
use ves_stark_primitives::{Felt, Hash256};
use winter_air::TraceInfo;
use winter_crypto::{hashers::Blake3_256, DefaultRandomCoin, MerkleTree};
use winter_prover::{Prover, Trace, TraceTable};
/// Hash function used by the prover (`Prover::HashFn`): BLAKE3-256 over [`Felt`].
pub type Hasher = Blake3_256<Felt>;
/// Vector commitment scheme used by the prover (`Prover::VC`): a Merkle tree over [`Hasher`].
pub type VectorCommit = MerkleTree<Hasher>;
/// Random coin used by the prover (`Prover::RandomCoin`), seeded through [`Hasher`].
pub type RandCoin = DefaultRandomCoin<Hasher>;
/// Start marker for measuring proving time.
///
/// Native targets use a monotonic [`std::time::Instant`]. On `wasm32` the marker is
/// the value of `js_sys::Date::now()` in whole milliseconds. This is presumably
/// because `std::time::Instant` is not usable on `wasm32-unknown-unknown`.
#[cfg(not(target_arch = "wasm32"))]
type Timer = std::time::Instant;
#[cfg(target_arch = "wasm32")]
type Timer = u64;

/// Captures the current time as a [`Timer`].
#[cfg(not(target_arch = "wasm32"))]
fn start_timer() -> Timer {
    std::time::Instant::now()
}
#[cfg(target_arch = "wasm32")]
fn start_timer() -> Timer {
    // `f64 as u64` saturates (and maps NaN to 0), so this cast cannot wrap.
    js_sys::Date::now() as u64
}

/// Returns the whole milliseconds elapsed since `start`.
///
/// `Duration::as_millis` yields a `u128`. The value is clamped to `u64::MAX`
/// rather than silently truncated by an `as` cast.
#[cfg(not(target_arch = "wasm32"))]
fn elapsed_ms(start: Timer) -> u64 {
    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
}
#[cfg(target_arch = "wasm32")]
fn elapsed_ms(start: Timer) -> u64 {
    // `Date::now()` is wall-clock time and can step backwards. Saturate so that
    // case reports 0 instead of wrapping around.
    (js_sys::Date::now() as u64).saturating_sub(start)
}
/// A serialized compliance proof together with the data needed to identify it.
///
/// Instances are produced by `ComplianceProver::prove`.
// NOTE(review): field order determines serialized output order; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceProof {
    /// Serialized STARK proof, as returned by the proof's `to_bytes`.
    /// Serialized through `serde_bytes`, so byte-oriented formats store it as a
    /// byte string rather than as a sequence of integers.
    #[serde(with = "serde_bytes")]
    pub proof_bytes: Vec<u8>,
    /// Hex encoding of the domain-separated SHA-256 digest of `proof_bytes`
    /// (see `ComplianceProof::compute_hash`).
    pub proof_hash: String,
    /// Timing and size information recorded while generating the proof.
    pub metadata: ProofMetadata,
    /// Witness commitment as four `u64` values. In `prove` these are the first
    /// four Rescue-hash output elements over the amount limbs.
    pub witness_commitment: [u64; 4],
    /// Hex form of `witness_commitment`. It is omitted from serialized output
    /// when `None` and defaults to `None` when absent during deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub witness_commitment_hex: Option<String>,
}
/// Diagnostic metadata attached to a `ComplianceProof`.
// NOTE(review): field order determines serialized output order; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProofMetadata {
    /// Milliseconds measured by `ComplianceProver::prove` from its entry until
    /// the result is assembled. This includes validation, trace building,
    /// proving and hashing.
    pub proving_time_ms: u64,
    /// Number of AIR constraints (`ves_stark_air::compliance::NUM_CONSTRAINTS`).
    pub num_constraints: usize,
    /// Length of the execution trace that was proven.
    pub trace_length: usize,
    /// Size of `ComplianceProof::proof_bytes` in bytes.
    pub proof_size: usize,
    /// Version of this crate (`CARGO_PKG_VERSION`) that produced the proof.
    pub prover_version: String,
}
impl ComplianceProof {
    /// Computes the canonical identifier of a serialized proof.
    ///
    /// The digest is SHA-256 under a fixed domain tag. This keeps it distinct
    /// from other SHA-256 digests that might be taken over the same bytes.
    pub fn compute_hash(proof_bytes: &[u8]) -> Hash256 {
        let domain = b"STATESET_VES_COMPLIANCE_PROOF_HASH_V1";
        Hash256::sha256_with_domain(domain, proof_bytes)
    }
}
/// Generates compliance proofs for witnesses against a single configured policy.
pub struct ComplianceProver {
    /// STARK proof parameters. Constructors start from `ProofOptions::default()`,
    /// and `with_options` overrides them.
    options: ProofOptions,
    /// Policy that every proven witness must satisfy.
    policy: Policy,
}
impl ComplianceProver {
    /// Creates a prover for an AML threshold policy with default [`ProofOptions`].
    pub fn new(policy: AmlThresholdPolicy) -> Self {
        Self {
            policy: policy.into(),
            options: ProofOptions::default(),
        }
    }

    /// Creates a prover for an arbitrary [`Policy`] with default [`ProofOptions`].
    pub fn with_policy(policy: Policy) -> Self {
        Self {
            policy,
            options: ProofOptions::default(),
        }
    }

    /// Builder-style setter that replaces the proof options.
    pub fn with_options(mut self, options: ProofOptions) -> Self {
        self.options = options;
        self
    }

    /// Returns the policy this prover enforces.
    pub fn policy(&self) -> &Policy {
        &self.policy
    }

    /// Proves that `witness` satisfies the configured policy.
    ///
    /// The witness amount is committed with a Rescue hash of its limbs and the
    /// execution trace is built. A STARK proof is then generated over the public
    /// inputs. The result carries the serialized proof, its domain-separated
    /// hash, the witness commitment and timing/size metadata.
    ///
    /// # Errors
    ///
    /// - A policy-validation error if the amount does not satisfy the policy,
    ///   or if the policy's effective limit cannot be computed.
    /// - `ProverError::InvalidPublicInputs` in these cases:
    ///   - the public inputs name a different or unparsable policy;
    ///   - the public inputs cannot be converted to field elements;
    ///   - the public inputs or proof options are rejected.
    /// - Any error from `ComplianceWitness::validate` or trace building.
    /// - `ProverError::ProofGenerationFailed` if the STARK prover fails.
    pub fn prove(&self, witness: &ComplianceWitness) -> Result<ComplianceProof, ProverError> {
        let start = start_timer();

        self.ensure_policy_consistency(witness)?;
        witness.validate(&self.policy)?;

        let witness_commitment = Self::commit_to_amount(witness);
        let commitment_u64: [u64; 4] = witness_commitment.map(|limb| limb.as_int());
        // Debug-only sanity check: the public inputs are expected to already
        // carry the commitment derived from the witness amount.
        debug_assert_eq!(
            witness
                .public_inputs
                .witness_commitment_u64()
                .ok()
                .flatten(),
            Some(commitment_u64)
        );

        let trace = TraceBuilder::new(witness.clone(), self.policy.clone()).build()?;
        let trace_length = trace.length();

        let pub_inputs_felts = witness
            .public_inputs
            .to_field_elements()
            .map_err(|e| ProverError::InvalidPublicInputs(format!("{e}")))?;
        let policy_limit = self
            .policy
            .effective_limit()
            .map_err(|e| ProverError::PolicyValidationFailed(format!("{e}")))?;
        let pub_inputs = PublicInputs::try_with_commitment(
            policy_limit,
            pub_inputs_felts.to_vec(),
            witness_commitment,
        )
        .map_err(|e| ProverError::InvalidPublicInputs(format!("{e}")))?;

        let prover =
            VesComplianceProver::try_new(self.policy.clone(), self.options.clone(), pub_inputs)?;
        let proof = prover
            .prove(trace)
            .map_err(|e| ProverError::ProofGenerationFailed(format!("{:?}", e)))?;

        let proof_bytes = proof.to_bytes();
        let proof_hash = ComplianceProof::compute_hash(&proof_bytes);
        // Record the size before moving the buffer into the result. This avoids
        // copying the (potentially large) serialized proof just to measure it.
        let proof_size = proof_bytes.len();

        Ok(ComplianceProof {
            proof_bytes,
            proof_hash: proof_hash.to_hex(),
            metadata: ProofMetadata {
                proving_time_ms: elapsed_ms(start),
                num_constraints: ves_stark_air::compliance::NUM_CONSTRAINTS,
                trace_length,
                proof_size,
                prover_version: env!("CARGO_PKG_VERSION").to_string(),
            },
            witness_commitment: commitment_u64,
            witness_commitment_hex: Some(witness_commitment_u64_to_hex(&commitment_u64)),
        })
    }

    /// Returns the configured policy limit.
    pub fn limit(&self) -> u64 {
        self.policy.limit()
    }

    /// Alias for [`ComplianceProver::limit`]; returns the same value.
    pub fn threshold(&self) -> u64 {
        self.limit()
    }

    /// Checks the witness amount against the policy. Also ensures the public
    /// inputs were issued for the same policy this prover enforces.
    fn ensure_policy_consistency(&self, witness: &ComplianceWitness) -> Result<(), ProverError> {
        if !self.policy.validate_amount(witness.amount) {
            return Err(ProverError::policy_validation_failed(format!(
                "Amount {} does not satisfy {} policy with limit {}",
                witness.amount,
                self.policy.policy_id(),
                self.policy.limit()
            )));
        }

        let inputs_policy = Policy::from_public_inputs(
            &witness.public_inputs.policy_id,
            &witness.public_inputs.policy_params,
        )
        .map_err(|e| ProverError::InvalidPublicInputs(format!("Invalid policy params: {e}")))?;
        if inputs_policy != self.policy {
            return Err(ProverError::InvalidPublicInputs(format!(
                "Policy mismatch: public inputs are for {}, prover configured for {}",
                inputs_policy.policy_id(),
                self.policy.policy_id()
            )));
        }
        Ok(())
    }

    /// Rescue-hashes the witness amount limbs. The first four output elements
    /// are kept as the witness commitment.
    fn commit_to_amount(witness: &ComplianceWitness) -> [Felt; 4] {
        let amount_limbs = witness.amount_limbs();
        let hash_output = rescue_hash(&amount_limbs);
        [hash_output[0], hash_output[1], hash_output[2], hash_output[3]]
    }
}
/// Adapter that implements winterfell's [`Prover`] trait for the compliance AIR.
struct VesComplianceProver {
    /// Public inputs handed back verbatim from `get_pub_inputs`.
    pub_inputs: PublicInputs,
    /// Winterfell proof options converted from the crate-level `ProofOptions`.
    options: winter_air::ProofOptions,
    /// Policy the proof is generated for. The `Prover` impl never reads it,
    /// which is why the `dead_code` lint is allowed on this field.
    #[allow(dead_code)]
    policy: Policy,
}
impl VesComplianceProver {
    /// Builds the adapter, converting `options` into winterfell's representation.
    ///
    /// # Errors
    ///
    /// Returns `ProverError::InvalidPublicInputs` when
    /// `ProofOptions::try_to_winterfell` rejects the options.
    fn try_new(
        policy: Policy,
        options: ProofOptions,
        pub_inputs: PublicInputs,
    ) -> Result<Self, ProverError> {
        options
            .try_to_winterfell()
            .map(|winter_options| Self {
                policy,
                options: winter_options,
                pub_inputs,
            })
            .map_err(|e| ProverError::InvalidPublicInputs(format!("Invalid proof options: {e}")))
    }
}
impl Prover for VesComplianceProver {
type BaseField = Felt;
type Air = ComplianceAir;
type Trace = TraceTable<Felt>;
type HashFn = Hasher;
type RandomCoin = RandCoin;
type VC = VectorCommit;
type TraceLde<E: winter_math::FieldElement<BaseField = Felt>> =
winter_prover::DefaultTraceLde<E, Self::HashFn, Self::VC>;
type ConstraintEvaluator<'a, E: winter_math::FieldElement<BaseField = Felt>> =
winter_prover::DefaultConstraintEvaluator<'a, Self::Air, E>;
fn get_pub_inputs(&self, _trace: &Self::Trace) -> PublicInputs {
self.pub_inputs.clone()
}
fn options(&self) -> &winter_air::ProofOptions {
&self.options
}
fn new_trace_lde<E: winter_math::FieldElement<BaseField = Felt>>(
&self,
trace_info: &TraceInfo,
main_trace: &winter_prover::matrix::ColMatrix<Felt>,
domain: &winter_prover::StarkDomain<Felt>,
partition_option: winter_air::PartitionOptions,
) -> (Self::TraceLde<E>, winter_prover::TracePolyTable<E>) {
winter_prover::DefaultTraceLde::new(trace_info, main_trace, domain, partition_option)
}
fn new_evaluator<'a, E: winter_math::FieldElement<BaseField = Felt>>(
&self,
air: &'a Self::Air,
aux_rand_elements: Option<winter_air::AuxRandElements<E>>,
composition_coefficients: winter_air::ConstraintCompositionCoefficients<E>,
) -> Self::ConstraintEvaluator<'a, E> {
winter_prover::DefaultConstraintEvaluator::new(
air,
aux_rand_elements,
composition_coefficients,
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;
    use ves_stark_primitives::public_inputs::{
        compute_policy_hash, CompliancePublicInputs, PolicyParams,
    };

    const POLICY_ID: &str = "aml.threshold";

    /// Builds AML-threshold public inputs with fresh UUIDs and no witness
    /// commitment. `hash_fill` gives the character repeated 64 times for the
    /// payload-plain, payload-cipher and event-signing hashes, in that order.
    fn sample_inputs(threshold: u64, hash_fill: [&str; 3]) -> CompliancePublicInputs {
        let params = PolicyParams::threshold(threshold);
        let policy_hash = compute_policy_hash(POLICY_ID, &params).unwrap();
        CompliancePublicInputs {
            event_id: Uuid::new_v4(),
            tenant_id: Uuid::new_v4(),
            store_id: Uuid::new_v4(),
            sequence_number: 1,
            payload_kind: 1,
            payload_plain_hash: hash_fill[0].repeat(64),
            payload_cipher_hash: hash_fill[1].repeat(64),
            event_signing_hash: hash_fill[2].repeat(64),
            policy_id: POLICY_ID.to_string(),
            policy_params: params,
            policy_hash: policy_hash.to_hex(),
            witness_commitment: None,
            authorization_receipt_hash: None,
            amount_binding_hash: None,
        }
    }

    /// Asserts the self-consistency invariants that `prove` establishes on its output.
    fn assert_proof_consistent(proof: &ComplianceProof) {
        assert_eq!(
            proof.proof_hash,
            ComplianceProof::compute_hash(&proof.proof_bytes).to_hex()
        );
        assert_eq!(proof.metadata.proof_size, proof.proof_bytes.len());
        assert_eq!(
            proof.metadata.num_constraints,
            ves_stark_air::compliance::NUM_CONSTRAINTS
        );
        assert_eq!(
            proof.witness_commitment_hex,
            Some(witness_commitment_u64_to_hex(&proof.witness_commitment))
        );
    }

    #[test]
    fn test_prover_creation() {
        let policy = AmlThresholdPolicy::new(10000);
        let prover = ComplianceProver::new(policy);
        assert_eq!(prover.threshold(), 10000);
    }

    #[test]
    fn test_prove_benchmark() {
        let threshold = 10000u64;
        let amount = 5000u64;
        let policy = Policy::aml_threshold(threshold);
        let witness = ComplianceWitness::new(amount, sample_inputs(threshold, ["a", "b", "c"]));
        let prover = ComplianceProver::with_policy(policy);

        // Warm up twice. The first proof doubles as a happy-path correctness check.
        let warmup = prover.prove(&witness).expect("warmup prove failed");
        assert_proof_consistent(&warmup);
        let _ = prover.prove(&witness).expect("warmup prove failed");

        let n = 5;
        let start = std::time::Instant::now();
        let mut total_proof_bytes = 0usize;
        for _ in 0..n {
            let proof = prover.prove(&witness).expect("prove failed");
            total_proof_bytes += proof.proof_bytes.len();
        }
        let elapsed = start.elapsed();
        let avg_ms = elapsed.as_millis() as f64 / n as f64;
        let avg_proof_bytes = total_proof_bytes / n;
        println!("bench_e2e_ms: {:.2}", avg_ms);
        println!("bench_proof_bytes: {}", avg_proof_bytes);
    }

    #[test]
    fn test_prover_rejects_public_inputs_witness_commitment_mismatch() {
        let threshold = 10000u64;
        let policy = Policy::aml_threshold(threshold);
        let mut inputs = sample_inputs(threshold, ["0", "0", "0"]);
        inputs.witness_commitment = Some(witness_commitment_u64_to_hex(&[0u64; 4]));
        let witness = ComplianceWitness::new(5000, inputs);
        let prover = ComplianceProver::with_policy(policy);
        let err = prover.prove(&witness).unwrap_err();
        assert!(matches!(err, ProverError::InvalidPublicInputs(_)));
    }

    #[test]
    fn test_prover_rejects_policy_mismatch() {
        // The public inputs are issued for a 10k threshold, but the prover enforces 20k.
        let inputs = sample_inputs(10_000, ["0", "0", "0"]);
        let witness = ComplianceWitness::new(5000, inputs);
        let prover = ComplianceProver::with_policy(Policy::aml_threshold(20_000));
        let err = prover.prove(&witness).unwrap_err();
        assert!(matches!(err, ProverError::InvalidPublicInputs(_)));
    }
}