use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::Identity;
use num_bigint::BigInt;
use sha2::{Digest, Sha512};
use std::sync::Arc;
use crate::group::Group;
/// Handle type that implements the crate's `Group` trait on top of
/// `curve25519_dalek`'s Ristretto types.
///
/// The only state it carries is the group order as a `BigInt`, which is
/// filled in by `Ristretto255Group::new` and exposed through
/// `order_as_bigint`.
#[derive(Debug, Clone)]
pub struct Ristretto255Group {
/// Group order as an arbitrary-precision integer.
order: BigInt,
}
impl Ristretto255Group {
    /// Big-endian encoding of the group order
    /// `l = 2^252 + 27742317777372353535851937790883648493`.
    const ORDER_BYTES_BE: [u8; 32] = [
        0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7,
        0x9c, 0xd6, 0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed,
    ];

    /// Parses [`Self::ORDER_BYTES_BE`] into a positive `BigInt`.
    fn order_bigint() -> BigInt {
        BigInt::from_bytes_be(num_bigint::Sign::Plus, &Self::ORDER_BYTES_BE)
    }

    /// Creates a shared group handle whose `BigInt` order is the
    /// Ristretto255 group order.
    pub fn new() -> Arc<Self> {
        Arc::new(Ristretto255Group {
            order: Self::order_bigint(),
        })
    }

    /// Returns the group order as an arbitrary-precision integer.
    pub fn order_as_bigint(&self) -> &BigInt {
        &self.order
    }

    /// Converts an arbitrary-precision integer into a `Scalar`, reducing it
    /// modulo the group order.
    ///
    /// Non-negative values that fit in 32 bytes are handed straight to
    /// `Scalar::from_bytes_mod_order`. Negative values and values wider than
    /// 32 bytes are first brought into `[0, l)` with `BigInt` arithmetic, so
    /// that e.g. `-1` maps to `l - 1` and wide values are reduced instead of
    /// being truncated.
    pub fn bigint_to_scalar(bigint: &BigInt) -> Scalar {
        let (sign, magnitude_le) = bigint.to_bytes_le();

        let canonical_le = if sign == num_bigint::Sign::Minus
            || magnitude_le.len() > 32
        {
            let order = Self::order_bigint();
            // `%` on `BigInt` truncates toward zero, so a negative dividend
            // yields a remainder in `(-l, 0]`; shift it back into `[0, l)`.
            let mut reduced = bigint % &order;
            if reduced.sign() == num_bigint::Sign::Minus {
                reduced += &order;
            }
            reduced.to_bytes_le().1
        } else {
            magnitude_le
        };

        // `canonical_le` is at most 32 bytes here; the remaining high bytes
        // of the little-endian buffer stay zero.
        let mut bytes_le = [0u8; 32];
        bytes_le[..canonical_le.len()].copy_from_slice(&canonical_le);
        Scalar::from_bytes_mod_order(bytes_le)
    }

    /// Converts a `Scalar` into a non-negative `BigInt`.
    ///
    /// `Scalar::to_bytes` is little-endian; `BigInt::from_bytes_le` ignores
    /// the zero padding and yields `0` for the zero scalar.
    pub fn scalar_to_bigint(scalar: &Scalar) -> BigInt {
        BigInt::from_bytes_le(num_bigint::Sign::Plus, &scalar.to_bytes())
    }
}
/// `Group` implementation backed by `curve25519_dalek` Ristretto points.
///
/// The trait uses multiplicative notation (`mul`, `exp`, `identity`); the
/// methods below map it onto point addition and scalar multiplication.
impl Group for Ristretto255Group {
    type Scalar = Scalar;
    type Element = RistrettoPoint;

    /// Returns a reference to a placeholder scalar.
    ///
    /// NOTE(review): this is `Scalar::ONE`, not the group order. Callers that
    /// need the numeric order should use `order_as_bigint` instead; confirm
    /// that no generic code relies on this value.
    fn order(&self) -> &Self::Scalar {
        static ORDER_PLACEHOLDER: Scalar = Scalar::ONE;
        &ORDER_PLACEHOLDER
    }

    /// Returns a reference to a placeholder scalar.
    ///
    /// NOTE(review): same placeholder (`Scalar::ONE`) as `order`.
    fn subgroup_order(&self) -> &Self::Scalar {
        static ORDER_PLACEHOLDER: Scalar = Scalar::ONE;
        &ORDER_PLACEHOLDER
    }

    /// Returns the Ristretto basepoint.
    fn generator(&self) -> Self::Element {
        RISTRETTO_BASEPOINT_POINT
    }

    /// Returns the Ristretto basepoint (same value as `generator`).
    fn subgroup_generator(&self) -> Self::Element {
        RISTRETTO_BASEPOINT_POINT
    }

    /// Returns the identity element of the group.
    fn identity(&self) -> Self::Element {
        RistrettoPoint::identity()
    }

    /// "Exponentiation" in the trait's multiplicative notation: multiplies
    /// the point `base` by `scalar`.
    fn exp(
        &self,
        base: &Self::Element,
        scalar: &Self::Scalar,
    ) -> Self::Element {
        base * scalar
    }

    /// Group operation in the trait's multiplicative notation: adds the two
    /// points.
    fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
        a + b
    }

    /// Returns the multiplicative inverse of `x`, or `None` when `x` is zero.
    fn scalar_inverse(&self, x: &Self::Scalar) -> Option<Self::Scalar> {
        if *x == Scalar::ZERO {
            return None;
        }
        Some(x.invert())
    }

    /// Returns the inverse element (point negation); always `Some`.
    fn element_inverse(&self, x: &Self::Element) -> Option<Self::Element> {
        Some(-x)
    }

    /// Hashes `data` with SHA-512 and reduces the 64-byte digest to a scalar.
    fn hash_to_scalar(&self, data: &[u8]) -> Self::Scalar {
        let digest = Sha512::digest(data);
        let mut wide_bytes = [0u8; 64];
        // A SHA-512 digest is exactly 64 bytes, matching the wide buffer.
        wide_bytes.copy_from_slice(&digest[..]);
        Scalar::from_bytes_mod_order_wide(&wide_bytes)
    }

    /// Serialises an element as its 32-byte compressed encoding.
    fn element_to_bytes(&self, elem: &Self::Element) -> Vec<u8> {
        elem.compress().to_bytes().to_vec()
    }

    /// Parses a 32-byte compressed encoding.
    ///
    /// Returns `None` when `bytes` is not exactly 32 bytes long or when
    /// decompression fails.
    fn bytes_to_element(&self, bytes: &[u8]) -> Option<Self::Element> {
        if bytes.len() != 32 {
            return None;
        }
        let mut encoded = [0u8; 32];
        encoded.copy_from_slice(bytes);
        CompressedRistretto(encoded).decompress()
    }

    /// Serialises a scalar as the 32 bytes returned by `Scalar::to_bytes`.
    fn scalar_to_bytes(&self, scalar: &Self::Scalar) -> Vec<u8> {
        scalar.to_bytes().to_vec()
    }

    /// Samples a random non-zero scalar.
    ///
    /// 64 random bytes are reduced with `from_bytes_mod_order_wide`, which
    /// avoids the modular bias of reducing only 32 bytes. A zero result is
    /// rejected because its public key would be the identity element.
    fn generate_private_key(&self) -> Self::Scalar {
        loop {
            let mut wide = [0u8; 64];
            wide[..32].copy_from_slice(&rand::random::<[u8; 32]>());
            wide[32..].copy_from_slice(&rand::random::<[u8; 32]>());
            let candidate = Scalar::from_bytes_mod_order_wide(&wide);
            if candidate != Scalar::ZERO {
                return candidate;
            }
        }
    }

    /// Derives the public key `basepoint * private_key`.
    fn generate_public_key(&self, private_key: &Self::Scalar) -> Self::Element {
        RISTRETTO_BASEPOINT_POINT * private_key
    }

    /// Multiplies two scalars.
    fn scalar_mul(&self, a: &Self::Scalar, b: &Self::Scalar) -> Self::Scalar {
        a * b
    }

    /// Subtracts `b` from `a`.
    fn scalar_sub(&self, a: &Self::Scalar, b: &Self::Scalar) -> Self::Scalar {
        a - b
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A small value survives `BigInt -> Scalar -> BigInt` unchanged.
    #[test]
    fn test_bigint_scalar_detailed_conversion() {
        let original = BigInt::from(0x0102030405060708u64);
        let scalar = Ristretto255Group::bigint_to_scalar(&original);
        let recovered = Ristretto255Group::scalar_to_bigint(&scalar);
        assert_eq!(
            original, recovered,
            "Round-trip conversion should work for simple values"
        );
    }

    /// Evaluating a random degree-2 polynomial at 1 gives the same result
    /// whether the sum is taken over `BigInt`s, `Scalar`s or commitments.
    #[test]
    fn test_polynomial_scalar_conversion() {
        use num_bigint::RandBigInt;

        let group = Ristretto255Group::new();
        let g = group.generator();
        let order = group.order_as_bigint().clone();
        let order_uint = order.to_biguint().unwrap();
        let mut rng = rand::thread_rng();
        // Uniform coefficient in `[0, order)`.
        let mut random_coefficient =
            || BigInt::from(rng.gen_biguint_below(&order_uint));

        let a_0: BigInt = random_coefficient();
        let a_1: BigInt = random_coefficient();
        let a_2: BigInt = random_coefficient();

        let a_0_scalar = Ristretto255Group::bigint_to_scalar(&a_0);
        let a_1_scalar = Ristretto255Group::bigint_to_scalar(&a_1);
        let a_2_scalar = Ristretto255Group::bigint_to_scalar(&a_2);

        let c_0 = group.exp(&g, &a_0_scalar);
        let c_1 = group.exp(&g, &a_1_scalar);
        let c_2 = group.exp(&g, &a_2_scalar);
        let x_1_from_commitments = group.mul(&c_0, &group.mul(&c_1, &c_2));

        let p_1_bigint = &a_0 + &a_1 + &a_2;
        let p_1_mod = &p_1_bigint % &order;
        let p_1_scalar = Ristretto255Group::bigint_to_scalar(&p_1_mod);
        let x_1_from_polynomial = group.exp(&g, &p_1_scalar);
        let scalar_sum = a_0_scalar + a_1_scalar + a_2_scalar;

        let a_0_recovered = Ristretto255Group::scalar_to_bigint(&a_0_scalar);
        let a_1_recovered = Ristretto255Group::scalar_to_bigint(&a_1_scalar);
        let a_2_recovered = Ristretto255Group::scalar_to_bigint(&a_2_scalar);
        assert_eq!(a_0, a_0_recovered, "a_0 round-trip conversion should work");
        assert_eq!(a_1, a_1_recovered, "a_1 round-trip conversion should work");
        assert_eq!(a_2, a_2_recovered, "a_2 round-trip conversion should work");
        assert_eq!(
            scalar_sum, p_1_scalar,
            "Scalar sum should equal P(1) converted to Scalar"
        );
        assert_eq!(
            x_1_from_commitments, x_1_from_polynomial,
            "X_1 from commitments should equal g^P(1)"
        );
    }

    /// Three copies of `order / 3` sum consistently in both domains.
    #[test]
    fn test_scalar_addition_modular() {
        let group = Ristretto255Group::new();
        let order = group.order_as_bigint().clone();
        let a = &order / 3;
        let b = &order / 3;
        let c = &order / 3;
        let a_scalar = Ristretto255Group::bigint_to_scalar(&a);
        let b_scalar = Ristretto255Group::bigint_to_scalar(&b);
        let c_scalar = Ristretto255Group::bigint_to_scalar(&c);
        let scalar_sum = a_scalar + b_scalar + c_scalar;
        let bigint_sum = &a + &b + &c;
        let bigint_sum_mod = &bigint_sum % &order;
        let from_bigint = Ristretto255Group::bigint_to_scalar(&bigint_sum_mod);
        assert_eq!(
            scalar_sum, from_bigint,
            "Scalar sum should equal BigInt sum converted to Scalar"
        );
    }

    /// The stored order equals `2^252 + 27742317777372353535851937790883648493`.
    #[test]
    fn test_order_constant() {
        let group = Ristretto255Group::new();
        let order = group.order_as_bigint();
        let two_252 = BigInt::from(1u64) << 252;
        let constant =
            BigInt::parse_bytes(b"27742317777372353535851937790883648493", 10)
                .unwrap();
        let expected = two_252 + constant;
        let order_hex = format!("{:x}", order);
        let expected_hex = format!("{:x}", expected);
        assert_eq!(order_hex, expected_hex, "Hex representations should match");
    }

    /// A sum that exceeds the order wraps identically in both domains.
    #[test]
    fn test_scalar_sum_overflow() {
        let group = Ristretto255Group::new();
        let order = group.order_as_bigint().clone();
        let half_order: BigInt = &order / 2;
        let a = half_order.clone();
        let b = half_order;
        let c = BigInt::from(1000u64);
        let a_scalar = Ristretto255Group::bigint_to_scalar(&a);
        let b_scalar = Ristretto255Group::bigint_to_scalar(&b);
        let c_scalar = Ristretto255Group::bigint_to_scalar(&c);
        let a_recovered = Ristretto255Group::scalar_to_bigint(&a_scalar);
        let b_recovered = Ristretto255Group::scalar_to_bigint(&b_scalar);
        let c_recovered = Ristretto255Group::scalar_to_bigint(&c_scalar);
        assert_eq!(a, a_recovered, "a round-trip should work");
        assert_eq!(b, b_recovered, "b round-trip should work");
        assert_eq!(c, c_recovered, "c round-trip should work");
        let scalar_sum = a_scalar + b_scalar + c_scalar;
        let bigint_sum = &a + &b + &c;
        let bigint_sum_mod = &bigint_sum % &order;
        let from_bigint = Ristretto255Group::bigint_to_scalar(&bigint_sum_mod);
        assert_eq!(
            scalar_sum, from_bigint,
            "Scalar sum should equal BigInt sum converted to Scalar"
        );
    }

    /// Large (but below-order) operands round-trip and sum consistently.
    #[test]
    fn test_scalar_sum_large_values() {
        let group = Ristretto255Group::new();
        let order = group.order_as_bigint().clone();
        let a = BigInt::parse_bytes(
            b"2000000000000000000000000000000000000000000000000000000000000000000000000000",
            10,
        )
        .unwrap();
        let b = a.clone();
        let c = a.clone();
        let a_scalar = Ristretto255Group::bigint_to_scalar(&a);
        let b_scalar = Ristretto255Group::bigint_to_scalar(&b);
        let c_scalar = Ristretto255Group::bigint_to_scalar(&c);
        let a_recovered = Ristretto255Group::scalar_to_bigint(&a_scalar);
        let b_recovered = Ristretto255Group::scalar_to_bigint(&b_scalar);
        let c_recovered = Ristretto255Group::scalar_to_bigint(&c_scalar);
        assert_eq!(a, a_recovered, "a round-trip should work");
        assert_eq!(b, b_recovered, "b round-trip should work");
        assert_eq!(c, c_recovered, "c round-trip should work");
        let scalar_sum = a_scalar + b_scalar + c_scalar;
        let bigint_sum = &a + &b + &c;
        let bigint_sum_mod = &bigint_sum % &order;
        let from_bigint = Ristretto255Group::bigint_to_scalar(&bigint_sum_mod);
        assert_eq!(
            scalar_sum, from_bigint,
            "Scalar sum should equal BigInt sum converted to Scalar"
        );
    }

    /// Small operands: 1000 + 2000 + 3000 = 6000 in both domains.
    #[test]
    fn test_scalar_sum_individual() {
        let group = Ristretto255Group::new();
        let order = group.order_as_bigint().clone();
        let a = BigInt::from(1000u64);
        let b = BigInt::from(2000u64);
        let c = BigInt::from(3000u64);
        let a_scalar = Ristretto255Group::bigint_to_scalar(&a);
        let b_scalar = Ristretto255Group::bigint_to_scalar(&b);
        let c_scalar = Ristretto255Group::bigint_to_scalar(&c);
        let scalar_sum = a_scalar + b_scalar + c_scalar;
        let bigint_sum = &a + &b + &c;
        let bigint_sum_mod = &bigint_sum % &order;
        let from_bigint = Ristretto255Group::bigint_to_scalar(&bigint_sum_mod);
        let sum_recovered = Ristretto255Group::scalar_to_bigint(&scalar_sum);
        assert_eq!(
            sum_recovered,
            BigInt::from(6000u64),
            "Scalar sum should be 6000"
        );
        assert_eq!(
            scalar_sum, from_bigint,
            "Scalar sum should equal BigInt sum converted to Scalar"
        );
    }

    /// Combining commitments matches committing to the summed exponents.
    #[test]
    fn test_commitment_sum() {
        let group = Ristretto255Group::new();
        let g = group.generator();
        let a_0 = group.generate_private_key();
        let a_1 = group.generate_private_key();
        let a_2 = group.generate_private_key();
        let c_0 = group.exp(&g, &a_0);
        let c_1 = group.exp(&g, &a_1);
        let c_2 = group.exp(&g, &a_2);
        let c_sum = group.mul(&c_0, &group.mul(&c_1, &c_2));
        let a_sum = a_0 + a_1 + a_2;
        let g_sum = group.exp(&g, &a_sum);
        assert_eq!(
            c_sum, g_sum,
            "C_0 + C_1 + C_2 should equal g^(a_0 + a_1 + a_2)"
        );
    }

    /// With `r = w - alpha * c`, the relation `g^w == g^r * X^c` holds.
    #[test]
    fn test_dleq_verification_math() {
        let group = Ristretto255Group::new();
        let g = group.generator();
        let w = group.generate_private_key();
        let alpha = group.generate_private_key();
        let c = group.hash_to_scalar(b"test challenge");
        let x = group.exp(&g, &alpha);
        let alpha_c = group.scalar_mul(&alpha, &c);
        let r = group.scalar_sub(&w, &alpha_c);
        let g_w = group.exp(&g, &w);
        let g_r = group.exp(&g, &r);
        let x_c = group.exp(&x, &c);
        let verify_a1 = group.mul(&g_r, &x_c);
        assert_eq!(g_w, verify_a1, "g^w should equal g^r * X^c");
    }

    /// `exp` distributes over scalar addition.
    #[test]
    fn test_scalar_mult_consistency() {
        let group = Ristretto255Group::new();
        let g = group.generator();
        let a = Scalar::from(5u64);
        let b = Scalar::from(7u64);
        let a_plus_b = a + b;
        let g_a = group.exp(&g, &a);
        let g_b = group.exp(&g, &b);
        let g_a_plus_b = group.exp(&g, &a_plus_b);
        let g_a_mul_g_b = group.mul(&g_a, &g_b);
        assert_eq!(g_a_plus_b, g_a_mul_g_b, "g^(a+b) should equal g^a * g^b");
    }

    /// The constructor stores a large, non-trivial order.
    #[test]
    fn test_ristretto255_group_new() {
        let group = Ristretto255Group::new();
        let order = group.order_as_bigint();
        assert_ne!(*order, num_bigint::BigInt::from(1u32));
        assert!(*order > (num_bigint::BigInt::from(1u64) << 250));
    }

    /// A freshly generated key pair never has the identity as public key.
    #[test]
    fn test_generate_keypair() {
        let group = Ristretto255Group::new();
        let privkey = group.generate_private_key();
        let pubkey = group.generate_public_key(&privkey);
        assert_ne!(pubkey, RistrettoPoint::identity());
    }

    /// `g^1 == g` and `g^0 == identity`.
    #[test]
    fn test_exp() {
        let group = Ristretto255Group::new();
        let g = group.generator();
        let one = Scalar::ONE;
        assert_eq!(group.exp(&g, &one), g);
        assert_eq!(group.exp(&g, &Scalar::ZERO), group.identity());
    }

    /// `g * g == g^2`.
    #[test]
    fn test_mul() {
        let group = Ristretto255Group::new();
        let g = group.generator();
        let g_plus_g = group.mul(&g, &g);
        let two_g = group.exp(&g, &Scalar::from(2u64));
        assert_eq!(g_plus_g, two_g);
    }

    /// `x * x^-1 == 1` for a non-zero scalar.
    #[test]
    fn test_scalar_inverse() {
        let group = Ristretto255Group::new();
        let x = Scalar::from(5u64);
        let inv = group.scalar_inverse(&x).unwrap();
        let result = x * inv;
        assert_eq!(result, Scalar::ONE);
    }

    /// Zero has no inverse.
    #[test]
    fn test_scalar_inverse_zero() {
        let group = Ristretto255Group::new();
        let result = group.scalar_inverse(&Scalar::ZERO);
        assert!(result.is_none());
    }

    /// `g * g^-1 == identity`.
    #[test]
    fn test_element_inverse() {
        let group = Ristretto255Group::new();
        let g = group.generator();
        let neg_g = group.element_inverse(&g).unwrap();
        let result = group.mul(&g, &neg_g);
        assert_eq!(result, RistrettoPoint::identity());
    }

    /// Hashing fixed test data does not produce the zero scalar.
    #[test]
    fn test_hash_to_scalar() {
        let group = Ristretto255Group::new();
        let data = b"test data";
        let scalar = group.hash_to_scalar(data);
        assert_ne!(scalar, Scalar::ZERO);
    }

    /// Element encoding is 32 bytes and decodes back to the same point.
    #[test]
    fn test_serialize_roundtrip() {
        let group = Ristretto255Group::new();
        let g = group.generator();
        let bytes = group.element_to_bytes(&g);
        assert_eq!(bytes.len(), 32);
        let restored = group.bytes_to_element(&bytes).unwrap();
        assert_eq!(g, restored);
    }

    /// Scalar encoding is 32 bytes and decodes back to the same scalar.
    #[test]
    fn test_scalar_serialize_roundtrip() {
        let group = Ristretto255Group::new();
        let scalar = Scalar::from(42u64);
        let bytes = group.scalar_to_bytes(&scalar);
        assert_eq!(bytes.len(), 32);
        let mut encoded = [0u8; 32];
        encoded.copy_from_slice(&bytes);
        assert_eq!(Scalar::from_bytes_mod_order(encoded), scalar);
    }

    /// A 27-bit value round-trips through `Scalar`.
    #[test]
    fn test_bigint_scalar_conversion() {
        let original = BigInt::from(123456789u64);
        let scalar = Ristretto255Group::bigint_to_scalar(&original);
        let recovered = Ristretto255Group::scalar_to_bigint(&scalar);
        assert_eq!(original, recovered);
    }

    /// `2^200` round-trips through `Scalar`.
    #[test]
    fn test_bigint_scalar_conversion_large() {
        let original = BigInt::from(1u64) << 200;
        let scalar = Ristretto255Group::bigint_to_scalar(&original);
        let recovered = Ristretto255Group::scalar_to_bigint(&scalar);
        assert_eq!(original, recovered);
    }
}