use crate::error::ProverError;
use crate::policy::Policy;
use ves_stark_air::range_check::validate_limbs;
use ves_stark_primitives::public_inputs::CompliancePublicInputs;
use ves_stark_primitives::{felt_from_u64, Felt, FELT_ZERO};
/// Witness for a compliance proof: the `amount` being checked together with
/// the public inputs it is bound to.
///
/// Construct it with [`ComplianceWitness::try_new`] (or the panicking
/// [`ComplianceWitness::new`]) so that `public_inputs` carries binding fields
/// derived from `amount`. Both fields are public, so they can be changed after
/// construction. [`ComplianceWitness::validate`] re-checks that the binding
/// still holds.
#[derive(Debug, Clone)]
pub struct ComplianceWitness {
    /// Amount checked against the policy limit.
    pub amount: u64,
    /// Public inputs for the proof; expected to be bound to `amount`.
    pub public_inputs: CompliancePublicInputs,
}
impl ComplianceWitness {
    /// Creates a witness for `amount`, binding the amount into `public_inputs`.
    ///
    /// The stored public inputs are the result of
    /// `CompliancePublicInputs::bind_amount(amount)`. Their
    /// `witness_commitment` and `amount_binding_hash` fields therefore
    /// correspond to `amount`.
    ///
    /// # Errors
    ///
    /// Returns [`ProverError::InvalidPublicInputs`], carrying the underlying
    /// message, if `bind_amount` rejects the inputs. One example is inputs
    /// that already contain binding fields that do not match `amount`.
    pub fn try_new(
        amount: u64,
        public_inputs: CompliancePublicInputs,
    ) -> Result<Self, ProverError> {
        let public_inputs = public_inputs
            .bind_amount(amount)
            .map_err(|e| ProverError::InvalidPublicInputs(e.to_string()))?;
        Ok(Self {
            amount,
            public_inputs,
        })
    }

    /// Infallible counterpart of [`ComplianceWitness::try_new`].
    ///
    /// # Panics
    ///
    /// Panics with `"invalid compliance witness public inputs"` if
    /// [`ComplianceWitness::try_new`] returns an error. Prefer `try_new` when
    /// the public inputs are not known to be valid.
    pub fn new(amount: u64, public_inputs: CompliancePublicInputs) -> Self {
        Self::try_new(amount, public_inputs).expect("invalid compliance witness public inputs")
    }

    /// Checks that this witness is consistent with `policy` and with its own
    /// public inputs.
    ///
    /// The checks run in this order and stop at the first failure:
    /// 1. `policy.validate_amount` accepts `amount`;
    /// 2. `validate_policy_hash` reports the public inputs' policy hash as valid;
    /// 3. the policy rebuilt from the public inputs' id and params equals `policy`;
    /// 4. `witness_commitment` and `amount_binding_hash` equal the values
    ///    produced by re-binding `amount`;
    /// 5. the limbs from [`ComplianceWitness::amount_limbs`] pass `validate_limbs`.
    ///
    /// # Errors
    ///
    /// - check 1: the error built by `ProverError::policy_validation_failed`;
    /// - checks 2-4: [`ProverError::InvalidPublicInputs`];
    /// - check 5: the error built by `ProverError::invalid_witness`.
    pub fn validate(&self, policy: &Policy) -> Result<(), ProverError> {
        self.check_amount_satisfies(policy)?;
        self.check_policy_hash()?;
        self.check_policy_matches(policy)?;
        self.check_amount_binding()?;
        self.check_amount_limbs()
    }

    /// Rejects an amount that `policy` does not accept.
    fn check_amount_satisfies(&self, policy: &Policy) -> Result<(), ProverError> {
        if policy.validate_amount(self.amount) {
            return Ok(());
        }
        Err(ProverError::policy_validation_failed(format!(
            "Amount {} does not satisfy policy {} with limit {}",
            self.amount,
            policy.policy_id(),
            policy.limit(),
        )))
    }

    /// Rejects public inputs whose policy hash does not validate.
    fn check_policy_hash(&self) -> Result<(), ProverError> {
        let policy_hash_valid = self
            .public_inputs
            .validate_policy_hash()
            .map_err(|e| ProverError::InvalidPublicInputs(e.to_string()))?;
        if policy_hash_valid {
            Ok(())
        } else {
            Err(ProverError::InvalidPublicInputs(
                "Policy hash mismatch".to_string(),
            ))
        }
    }

    /// Rejects a `policy` that differs from the one encoded in the public inputs.
    fn check_policy_matches(&self, policy: &Policy) -> Result<(), ProverError> {
        let inputs_policy = Policy::from_public_inputs(
            &self.public_inputs.policy_id,
            &self.public_inputs.policy_params,
        )
        .map_err(|e| ProverError::InvalidPublicInputs(format!("Invalid policy params: {e}")))?;
        if &inputs_policy != policy {
            return Err(ProverError::InvalidPublicInputs(format!(
                "Policy mismatch: public inputs are for {}, witness validated against {}",
                inputs_policy.policy_id(),
                policy.policy_id()
            )));
        }
        Ok(())
    }

    /// Rejects public inputs whose binding fields differ from a fresh binding of `amount`.
    fn check_amount_binding(&self) -> Result<(), ProverError> {
        let expected = self
            .public_inputs
            .bind_amount(self.amount)
            .map_err(|e| ProverError::InvalidPublicInputs(e.to_string()))?;
        // Compared field by field so the error message names the offending field.
        if self.public_inputs.witness_commitment != expected.witness_commitment {
            return Err(ProverError::InvalidPublicInputs(
                "public inputs witnessCommitment is missing or does not match the witness amount"
                    .to_string(),
            ));
        }
        if self.public_inputs.amount_binding_hash != expected.amount_binding_hash {
            return Err(ProverError::InvalidPublicInputs(
                "public inputs amountBindingHash is missing or does not match the witness amount"
                    .to_string(),
            ));
        }
        Ok(())
    }

    /// Rejects an amount whose limb decomposition fails `validate_limbs`.
    fn check_amount_limbs(&self) -> Result<(), ProverError> {
        // `amount_limbs` only produces values below 2^32, so this is expected to
        // pass. It is kept as a defensive check. NOTE(review): this assumes
        // `validate_limbs` is a per-limb u32 range check; confirm.
        if validate_limbs(&self.amount_limbs()) {
            Ok(())
        } else {
            Err(ProverError::invalid_witness(
                "Amount limbs contain invalid u32 values",
            ))
        }
    }

    /// Splits `amount` into eight field-element limbs, as checked by `validate_limbs`.
    ///
    /// Limb 0 holds the low 32 bits and limb 1 the high 32 bits. Limbs 2..8
    /// are always zero.
    pub fn amount_limbs(&self) -> [Felt; 8] {
        let mut limbs = [FELT_ZERO; 8];
        limbs[0] = felt_from_u64(self.amount & 0xFFFF_FFFF);
        limbs[1] = felt_from_u64(self.amount >> 32);
        limbs
    }

    /// Returns `amount` widened losslessly to `u128`.
    pub fn amount_u128(&self) -> u128 {
        u128::from(self.amount)
    }
}
/// Step-by-step builder for [`ComplianceWitness`].
///
/// Both the amount and the public inputs are required. Building without
/// either one returns an error naming the missing field.
#[derive(Debug, Clone)]
pub struct WitnessBuilder {
    /// Amount to bind into the witness, if set.
    amount: Option<u64>,
    /// Public inputs to bind the amount into, if set.
    public_inputs: Option<CompliancePublicInputs>,
}
impl WitnessBuilder {
    /// Creates an empty builder with neither an amount nor public inputs set.
    #[must_use]
    pub fn new() -> Self {
        Self {
            amount: None,
            public_inputs: None,
        }
    }

    /// Sets the witness amount, replacing any previously set value.
    #[must_use]
    pub fn amount(mut self, amount: u64) -> Self {
        self.amount = Some(amount);
        self
    }

    /// Sets the public inputs, replacing any previously set value.
    #[must_use]
    pub fn public_inputs(mut self, inputs: CompliancePublicInputs) -> Self {
        self.public_inputs = Some(inputs);
        self
    }

    /// Consumes the builder and creates the witness via
    /// [`ComplianceWitness::try_new`].
    ///
    /// # Errors
    ///
    /// - Returns the error built by `ProverError::invalid_witness` if the
    ///   amount is missing ("Amount is required"). The amount is checked first.
    /// - Returns the same kind of error if the public inputs are missing
    ///   ("Public inputs are required").
    /// - Returns any error produced by [`ComplianceWitness::try_new`].
    pub fn build(self) -> Result<ComplianceWitness, ProverError> {
        let amount = self
            .amount
            .ok_or_else(|| ProverError::invalid_witness("Amount is required"))?;
        let public_inputs = self
            .public_inputs
            .ok_or_else(|| ProverError::invalid_witness("Public inputs are required"))?;
        ComplianceWitness::try_new(amount, public_inputs)
    }
}
impl Default for WitnessBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::Policy;
    use uuid::Uuid;
    use ves_stark_primitives::public_inputs::{compute_policy_hash, PolicyParams};

    /// Builds unbound `aml.threshold` public inputs for `threshold`.
    ///
    /// The policy hash comes from `compute_policy_hash`, and the identifiers
    /// are fresh random UUIDs.
    fn sample_public_inputs(threshold: u64) -> CompliancePublicInputs {
        let policy_id = "aml.threshold";
        let params = PolicyParams::threshold(threshold);
        let hash = compute_policy_hash(policy_id, &params).unwrap();
        CompliancePublicInputs {
            event_id: Uuid::new_v4(),
            tenant_id: Uuid::new_v4(),
            store_id: Uuid::new_v4(),
            sequence_number: 1,
            payload_kind: 1,
            payload_plain_hash: "0".repeat(64),
            payload_cipher_hash: "0".repeat(64),
            event_signing_hash: "0".repeat(64),
            policy_id: policy_id.to_string(),
            policy_params: params,
            policy_hash: hash.to_hex(),
            witness_commitment: None,
            authorization_receipt_hash: None,
            amount_binding_hash: None,
        }
    }

    #[test]
    fn test_witness_validation_valid() {
        let threshold = 10000u64;
        let inputs = sample_public_inputs(threshold);
        let witness = ComplianceWitness::new(5000, inputs);
        let policy = Policy::aml_threshold(threshold);
        assert!(witness.validate(&policy).is_ok());
        assert!(witness.public_inputs.witness_commitment.is_some());
        assert!(witness.public_inputs.amount_binding_hash.is_some());
    }

    #[test]
    fn test_witness_validation_invalid() {
        let threshold = 10000u64;
        let inputs = sample_public_inputs(threshold);
        let witness = ComplianceWitness::new(15000, inputs);
        let policy = Policy::aml_threshold(threshold);
        assert!(witness.validate(&policy).is_err());
    }

    #[test]
    fn test_witness_builder() {
        let threshold = 10000u64;
        let inputs = sample_public_inputs(threshold);
        let witness = WitnessBuilder::new()
            .amount(5000)
            .public_inputs(inputs)
            .build()
            .unwrap();
        assert_eq!(witness.amount, 5000);
    }

    #[test]
    fn test_witness_try_new_rejects_mismatched_binding_fields() {
        let threshold = 10000u64;
        let mut inputs = sample_public_inputs(threshold);
        inputs.witness_commitment = Some("0".repeat(64));
        inputs.amount_binding_hash = Some("0".repeat(64));
        let err = ComplianceWitness::try_new(5000, inputs).unwrap_err();
        assert!(matches!(err, ProverError::InvalidPublicInputs(_)));
    }

    #[test]
    fn test_witness_validate_rejects_missing_amount_binding_hash() {
        let threshold = 10000u64;
        let inputs = sample_public_inputs(threshold);
        let mut witness = ComplianceWitness::new(5000, inputs);
        witness.public_inputs.amount_binding_hash = None;
        let err = witness
            .validate(&Policy::aml_threshold(threshold))
            .unwrap_err();
        assert!(matches!(err, ProverError::InvalidPublicInputs(_)));
    }

    #[test]
    fn test_witness_validate_rejects_policy_mismatch() {
        // The public inputs encode a 10000 threshold, but the witness is
        // checked against 20000. The amount satisfies both limits, so only
        // the policy comparison can reject it.
        let inputs = sample_public_inputs(10000);
        let witness = ComplianceWitness::new(5000, inputs);
        let err = witness
            .validate(&Policy::aml_threshold(20000))
            .unwrap_err();
        assert!(matches!(err, ProverError::InvalidPublicInputs(_)));
    }

    #[test]
    fn test_witness_validate_rejects_policy_hash_mismatch() {
        let threshold = 10000u64;
        let inputs = sample_public_inputs(threshold);
        let mut witness = ComplianceWitness::new(5000, inputs);
        witness.public_inputs.policy_hash = "0".repeat(64);
        let err = witness
            .validate(&Policy::aml_threshold(threshold))
            .unwrap_err();
        assert!(matches!(err, ProverError::InvalidPublicInputs(_)));
    }

    #[test]
    fn test_amount_limbs() {
        let inputs = sample_public_inputs(10000);
        let witness = ComplianceWitness::new(0x1234567890ABCDEF, inputs);
        let limbs = witness.amount_limbs();
        assert_eq!(limbs[0].as_int(), 0x90ABCDEF);
        assert_eq!(limbs[1].as_int(), 0x12345678);
    }

    #[test]
    fn test_witness_builder_missing_amount() {
        let threshold = 10000u64;
        let inputs = sample_public_inputs(threshold);
        let result = WitnessBuilder::new().public_inputs(inputs).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_witness_builder_missing_public_inputs() {
        let result = WitnessBuilder::new().amount(5000).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_witness_zero_amount() {
        let threshold = 10000u64;
        let inputs = sample_public_inputs(threshold);
        let witness = ComplianceWitness::new(0, inputs);
        let policy = Policy::aml_threshold(threshold);
        assert!(witness.validate(&policy).is_ok());
        let limbs = witness.amount_limbs();
        assert_eq!(limbs[0].as_int(), 0);
        assert_eq!(limbs[1].as_int(), 0);
    }

    #[test]
    fn test_witness_max_valid_amount() {
        let threshold = 10000u64;
        let inputs = sample_public_inputs(threshold);
        let witness = ComplianceWitness::new(9999, inputs);
        let policy = Policy::aml_threshold(threshold);
        assert!(witness.validate(&policy).is_ok());
    }

    #[test]
    fn test_witness_boundary_amount_fails() {
        let threshold = 10000u64;
        let inputs = sample_public_inputs(threshold);
        let witness = ComplianceWitness::new(10000, inputs);
        let policy = Policy::aml_threshold(threshold);
        assert!(witness.validate(&policy).is_err());
    }

    #[test]
    fn test_amount_u128_conversion() {
        let inputs = sample_public_inputs(10000);
        let witness = ComplianceWitness::new(u64::MAX, inputs);
        assert_eq!(witness.amount_u128(), u64::MAX as u128);
    }

    #[test]
    fn test_amount_limbs_max_value() {
        let inputs = sample_public_inputs(u64::MAX);
        let witness = ComplianceWitness::new(u64::MAX, inputs);
        let limbs = witness.amount_limbs();
        assert_eq!(limbs[0].as_int(), 0xFFFFFFFF);
        assert_eq!(limbs[1].as_int(), 0xFFFFFFFF);
        for limb in limbs.iter().skip(2) {
            assert_eq!(limb.as_int(), 0);
        }
    }

    #[test]
    fn test_witness_builder_default() {
        let builder = WitnessBuilder::default();
        assert!(builder.build().is_err());
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::policy::Policy;
    use proptest::prelude::*;
    use uuid::Uuid;
    use ves_stark_primitives::public_inputs::{compute_policy_hash, PolicyParams};

    /// Builds unbound `aml.threshold` public inputs for `threshold`.
    ///
    /// The policy hash comes from `compute_policy_hash`, and the identifiers
    /// are fresh random UUIDs.
    fn sample_public_inputs(threshold: u64) -> CompliancePublicInputs {
        let policy_id = "aml.threshold";
        let params = PolicyParams::threshold(threshold);
        let hash = compute_policy_hash(policy_id, &params).unwrap();
        CompliancePublicInputs {
            event_id: Uuid::new_v4(),
            tenant_id: Uuid::new_v4(),
            store_id: Uuid::new_v4(),
            sequence_number: 1,
            payload_kind: 1,
            payload_plain_hash: "0".repeat(64),
            payload_cipher_hash: "0".repeat(64),
            event_signing_hash: "0".repeat(64),
            policy_id: policy_id.to_string(),
            policy_params: params,
            policy_hash: hash.to_hex(),
            witness_commitment: None,
            authorization_receipt_hash: None,
            amount_binding_hash: None,
        }
    }

    proptest! {
        #[test]
        fn prop_amount_less_than_threshold_validates(
            threshold in 1u64..=u64::MAX,
            amount_offset in 1u64..=u64::MAX
        ) {
            // Any offset in 1..=threshold gives an amount strictly below the
            // threshold. Larger offsets fall back to 0.
            let amount = if amount_offset <= threshold {
                threshold.saturating_sub(amount_offset)
            } else {
                0
            };
            let inputs = sample_public_inputs(threshold);
            let witness = ComplianceWitness::new(amount, inputs);
            let policy = Policy::aml_threshold(threshold);
            prop_assert!(witness.validate(&policy).is_ok());
        }

        #[test]
        fn prop_amount_gte_threshold_fails(
            threshold in 1u64..u64::MAX,
            extra in 0u64..1000
        ) {
            let amount = threshold.saturating_add(extra);
            let inputs = sample_public_inputs(threshold);
            let witness = ComplianceWitness::new(amount, inputs);
            let policy = Policy::aml_threshold(threshold);
            prop_assert!(witness.validate(&policy).is_err());
        }

        #[test]
        fn prop_amount_limb_decomposition_low(amount in any::<u64>()) {
            let inputs = sample_public_inputs(u64::MAX);
            let witness = ComplianceWitness::new(amount, inputs);
            let limbs = witness.amount_limbs();
            let expected_low = amount & 0xFFFFFFFF;
            prop_assert_eq!(limbs[0].as_int(), expected_low);
        }

        #[test]
        fn prop_amount_limb_decomposition_high(amount in any::<u64>()) {
            let inputs = sample_public_inputs(u64::MAX);
            let witness = ComplianceWitness::new(amount, inputs);
            let limbs = witness.amount_limbs();
            let expected_high = amount >> 32;
            prop_assert_eq!(limbs[1].as_int(), expected_high);
        }

        #[test]
        fn prop_limb_recombination(amount in any::<u64>()) {
            let inputs = sample_public_inputs(u64::MAX);
            let witness = ComplianceWitness::new(amount, inputs);
            let limbs = witness.amount_limbs();
            let recombined = limbs[0].as_int() | (limbs[1].as_int() << 32);
            prop_assert_eq!(recombined, amount);
        }

        #[test]
        fn prop_upper_limbs_zero(amount in any::<u64>()) {
            let inputs = sample_public_inputs(u64::MAX);
            let witness = ComplianceWitness::new(amount, inputs);
            let limbs = witness.amount_limbs();
            for (i, limb) in limbs.iter().enumerate().skip(2) {
                prop_assert_eq!(limb.as_int(), 0, "Limb {} should be zero", i);
            }
        }

        #[test]
        fn prop_amount_u128_preserves(amount in any::<u64>()) {
            let inputs = sample_public_inputs(u64::MAX);
            let witness = ComplianceWitness::new(amount, inputs);
            prop_assert_eq!(witness.amount_u128(), amount as u128);
        }

        #[test]
        fn prop_builder_equals_direct(amount in any::<u64>()) {
            let inputs = sample_public_inputs(u64::MAX);
            let direct = ComplianceWitness::new(amount, inputs.clone());
            let built = WitnessBuilder::new()
                .amount(amount)
                .public_inputs(inputs)
                .build()
                .unwrap();
            prop_assert_eq!(direct.amount, built.amount);
            // Binding is deterministic (`validate` relies on re-deriving it), so
            // both construction paths must produce identical binding fields.
            prop_assert_eq!(
                &direct.public_inputs.witness_commitment,
                &built.public_inputs.witness_commitment
            );
            prop_assert_eq!(
                &direct.public_inputs.amount_binding_hash,
                &built.public_inputs.amount_binding_hash
            );
        }
    }
}