use curve25519_dalek::{
ristretto::RistrettoPoint,
traits::{Identity, MultiscalarMul},
};
#[cfg(feature = "precomputed_tables")]
use crate::ristretto::pedersen::scalar_mul_with_pre_computation_tables;
use crate::{
commitment::{HomomorphicCommitment, HomomorphicCommitmentFactory},
ristretto::{
RistrettoPublicKey,
RistrettoSecretKey,
pedersen::{PedersenCommitment, RISTRETTO_PEDERSEN_G, ristretto_pedersen_h},
},
};
/// Factory that produces Pedersen commitments over the Ristretto group from a pair of generator points.
///
/// The `HomomorphicCommitmentFactory::commit` implementation for this type evaluates `k·G + v·H`, where `k` and
/// `v` are the two secret-key arguments handed to `commit`.
#[derive(Debug, PartialEq, Eq, Clone)]
// The generator fields deliberately use the upper-case names `G` and `H`.
#[allow(non_snake_case)]
pub struct PedersenCommitmentFactory {
/// Generator point that is multiplied by the `k` argument of `commit`.
pub(crate) G: RistrettoPoint,
/// Generator point that is multiplied by the `v` argument of `commit`.
pub(crate) H: RistrettoPoint,
}
impl PedersenCommitmentFactory {
#[allow(non_snake_case)]
pub fn new(G: RistrettoPoint, H: RistrettoPoint) -> PedersenCommitmentFactory {
PedersenCommitmentFactory { G, H }
}
}
impl Default for PedersenCommitmentFactory {
    /// Returns a factory whose generators are `RISTRETTO_PEDERSEN_G` and the point returned by
    /// `ristretto_pedersen_h()`.
    fn default() -> Self {
        Self {
            G: RISTRETTO_PEDERSEN_G,
            H: *ristretto_pedersen_h(),
        }
    }
}
impl HomomorphicCommitmentFactory for PedersenCommitmentFactory {
    type P = RistrettoPublicKey;

    /// Commits to `v` under the key `k`, i.e. evaluates `k·G + v·H` with this factory's generators.
    ///
    /// When the factory holds the default generators and the `precomputed_tables` feature is enabled, the
    /// feature-gated helper `scalar_mul_with_pre_computation_tables` is called instead of the generic
    /// multiscalar multiplication. That helper is defined outside this file; it is expected to yield the
    /// same point.
    fn commit(&self, k: &RistrettoSecretKey, v: &RistrettoSecretKey) -> PedersenCommitment {
        // Generic path: works for any generator pair.
        let multiscalar = || RistrettoPoint::multiscalar_mul(&[v.0, k.0], &[self.H, self.G]);

        let default_h = *ristretto_pedersen_h();
        let uses_default_generators = self.G == RISTRETTO_PEDERSEN_G && self.H == default_h;

        let point = if uses_default_generators {
            #[cfg(feature = "precomputed_tables")]
            {
                scalar_mul_with_pre_computation_tables(&k.0, &v.0)
            }
            #[cfg(not(feature = "precomputed_tables"))]
            {
                multiscalar()
            }
        } else {
            multiscalar()
        };

        HomomorphicCommitment(RistrettoPublicKey::new_from_pk(point))
    }

    /// Returns the commitment that wraps the group identity element.
    fn zero(&self) -> PedersenCommitment {
        let identity = RistrettoPoint::identity();
        HomomorphicCommitment(RistrettoPublicKey::new_from_pk(identity))
    }

    /// Returns `true` when re-committing to `v` with `k` reproduces `commitment`.
    fn open(&self, k: &RistrettoSecretKey, v: &RistrettoSecretKey, commitment: &PedersenCommitment) -> bool {
        commitment.0 == self.commit(k, v).0
    }

    /// Commits to the integer `value` after converting it into a `RistrettoSecretKey`.
    fn commit_value(&self, k: &RistrettoSecretKey, value: u64) -> PedersenCommitment {
        self.commit(k, &RistrettoSecretKey::from(value))
    }

    /// Same check as `open`, with the value given as a `u64` that is first converted into a
    /// `RistrettoSecretKey`.
    fn open_value(&self, k: &RistrettoSecretKey, v: u64, commitment: &PedersenCommitment) -> bool {
        self.open(k, &RistrettoSecretKey::from(v), commitment)
    }
}
#[cfg(test)]
mod test {
use alloc::vec::Vec;
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};
use curve25519_dalek::scalar::Scalar;
use super::*;
use crate::keys::{PublicKey, SecretKey};
/// The default factory must carry `RISTRETTO_PEDERSEN_G` and the `ristretto_pedersen_h()` point.
#[test]
fn check_default_base() {
    let PedersenCommitmentFactory { G: g, H: h } = PedersenCommitmentFactory::default();
    assert_eq!(g, RISTRETTO_PEDERSEN_G);
    assert_eq!(h, *ristretto_pedersen_h());
}
/// `zero()` must equal the commitment obtained by multiplying both default generators by the zero scalar.
#[test]
fn check_zero() {
    let zero_scalars = [Scalar::ZERO, Scalar::ZERO];
    let generators = [RISTRETTO_PEDERSEN_G, *ristretto_pedersen_h()];
    let expected_point = RistrettoPoint::multiscalar_mul(&zero_scalars, &generators);

    let factory = PedersenCommitmentFactory::default();
    assert_eq!(
        HomomorphicCommitment(RistrettoPublicKey::new_from_pk(expected_point)),
        factory.zero()
    );
}
/// For random `(k, v)` pairs the commitment must match the hand-computed point, open with the right
/// pair, and fail to open when either the value or the key is altered.
#[test]
fn check_open() {
    let factory = PedersenCommitmentFactory::default();
    let h_generator = *ristretto_pedersen_h();
    let mut rng = rand::rng();

    for _ in 0..100 {
        let v = RistrettoSecretKey::random(&mut rng);
        let k = RistrettoSecretKey::random(&mut rng);
        let commitment = factory.commit(&k, &v);

        // Independent evaluation of k·G + v·H.
        let expected: RistrettoPoint = k.0 * RISTRETTO_PEDERSEN_G + v.0 * h_generator;
        assert_eq!(RistrettoPoint::from(commitment.as_public_key()), expected);

        assert!(factory.open(&k, &v, &commitment));

        // Wrong value, right key.
        let wrong_value = &v + &v;
        assert!(!factory.open(&k, &wrong_value, &commitment));

        // Right value, wrong key.
        let wrong_key = &k + &v;
        assert!(!factory.open(&wrong_key, &v, &commitment));
    }
}
/// Adding two commitments must give the commitment to the summed values under the summed keys.
#[test]
fn check_homomorphism() {
    let factory = PedersenCommitmentFactory::default();
    let mut rng = rand::rng();

    for _ in 0..100 {
        let v1 = RistrettoSecretKey::random(&mut rng);
        let v2 = RistrettoSecretKey::random(&mut rng);
        let k1 = RistrettoSecretKey::random(&mut rng);
        let k2 = RistrettoSecretKey::random(&mut rng);

        let first = factory.commit(&k1, &v1);
        let second = factory.commit(&k2, &v2);
        assert!(factory.open(&k1, &v1, &first));
        assert!(factory.open(&k2, &v2, &second));

        let total_value = &v1 + &v2;
        let total_key = &k1 + &k2;
        let added = &first + &second;
        let committed_directly = factory.commit(&total_key, &total_value);

        assert_eq!(added, committed_directly);
        assert!(factory.open(&total_key, &total_value, &added));
    }
}
/// Adding a public key to a commitment must equal committing to the same value with the matching
/// secret key added to the original key.
#[test]
fn check_homomorphism_with_public_key() {
    let mut rng = rand::rng();
    let v1 = RistrettoSecretKey::random(&mut rng);
    let k1 = RistrettoSecretKey::random(&mut rng);
    let factory = PedersenCommitmentFactory::default();
    let base = factory.commit(&k1, &v1);

    let (k2, k2_pub) = RistrettoPublicKey::random_keypair(&mut rng);
    let shifted = &base + &k2_pub;

    let combined_key = &k1 + &k2;
    let expected = factory.commit(&combined_key, &v1);

    assert_eq!(shifted, expected);
    assert!(factory.open(&combined_key, &v1, &expected));
}
/// A running sum of 100 random commitments must open with the summed keys and values, and must equal
/// the result of summing the collected commitments through the iterator `sum()`.
#[test]
fn sum_commitment_vector() {
    let mut rng = rand::rng();
    let factory = PedersenCommitmentFactory::default();

    let zero = RistrettoSecretKey::default();
    let mut value_total = RistrettoSecretKey::default();
    let mut key_total = RistrettoSecretKey::default();
    // Start the running sum from the commitment to (0, 0).
    let mut running_sum = factory.commit(&zero, &zero);

    let commitments: Vec<_> = (0..100)
        .map(|_| {
            let v = RistrettoSecretKey::random(&mut rng);
            value_total = &value_total + &v;
            let k = RistrettoSecretKey::random(&mut rng);
            key_total = &key_total + &k;

            let commitment = factory.commit(&k, &v);
            running_sum = &running_sum + &commitment;
            commitment
        })
        .collect();

    assert!(factory.open(&key_total, &value_total, &running_sum));
    assert_eq!(running_sum, commitments.iter().sum());
}
/// A commitment must survive a base64 round trip and a binary round trip, and malformed base64 input
/// must be rejected.
#[cfg(feature = "serde")]
#[test]
fn serialize_deserialize() {
    use tari_utilities::message_format::MessageFormat;

    const VALUE: u64 = 420;

    let mut rng = rand::rng();
    let factory = PedersenCommitmentFactory::default();
    let k = RistrettoSecretKey::random(&mut rng);
    let original = factory.commit_value(&k, VALUE);

    // Base64 round trip.
    let encoded = original.to_base64().unwrap();
    let from_base64 = PedersenCommitment::from_base64(&encoded).unwrap();
    assert!(factory.open_value(&k, VALUE, &from_base64));

    // Binary round trip.
    let bytes = original.to_binary().unwrap();
    let from_binary = PedersenCommitment::from_binary(&bytes).unwrap();
    assert!(factory.open_value(&k, VALUE, &from_binary));

    // Input that is not valid base64 must produce an error.
    assert!(PedersenCommitment::from_base64("bad@ser$").is_err());
}
/// Exercises the `Debug`, `Clone`, equality, `Hash` and ordering behaviour of a commitment.
#[test]
// `c1` is cloned on purpose so that `Clone` itself is exercised.
#[allow(clippy::redundant_clone)]
fn derived_methods() {
    use std::cmp::Ordering;

    let factory = PedersenCommitmentFactory::default();
    let k = RistrettoSecretKey::from(1024);
    let value = 2048;
    let c1 = factory.commit_value(&k, value);

    // Debug formatting is pinned to a known hex string.
    assert_eq!(
        format!("{c1:?}"),
        "HomomorphicCommitment(601cdc5c97e94bb16ae56f75430f8ab3ef4703c7d89ca9592e8acadc81629f0e)"
    );

    // A clone compares equal to its source.
    let c2 = c1.clone();
    assert_eq!(c1, c2);

    // Hash output is pinned to a known digest.
    let mut hasher = DefaultHasher::new();
    c1.hash(&mut hasher);
    let digest = format!("{:x}", hasher.finish());
    assert_eq!(&digest, "699d38210741194e");

    // Commit to the 100 values below and the 100 values above `value`, until both a smaller and a
    // larger commitment (relative to `c2`) have been observed.
    let neighbours = (value - 100..value).chain(value + 1..value + 101);
    let mut seen_smaller = false;
    let mut seen_larger = false;
    for val in neighbours {
        let c3 = factory.commit_value(&k, val);
        assert_ne!(c2, c3);

        let forward = c2.cmp(&c3);
        let backward = c3.cmp(&c2);
        assert_ne!(forward, backward);

        if c2 > c3 {
            assert!(c3 < c2);
            assert_eq!(forward, Ordering::Greater);
            assert_eq!(backward, Ordering::Less);
            seen_smaller = true;
        }
        if c2 < c3 {
            assert!(c3 > c2);
            assert_eq!(forward, Ordering::Less);
            assert_eq!(backward, Ordering::Greater);
            seen_larger = true;
        }
        if seen_smaller && seen_larger {
            break;
        }
    }

    assert!(
        seen_smaller && seen_larger,
        "Try extending the range of values to compare"
    );
}
}