use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
use curve25519_dalek::scalar::Scalar;
use rand::RngExt;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Samples a scalar by drawing 32 bytes from `rand::rng()` and reducing them modulo the
/// group order.
fn random_scalar() -> Scalar {
    let mut seed = [0u8; 32];
    rand::rng().fill(&mut seed);
    Scalar::from_bytes_mod_order(seed)
}
/// Returns the Ristretto basepoint multiplied by a freshly sampled random scalar.
///
/// The exponent is not returned, so callers never see the discrete log of the result.
fn random_point() -> RistrettoPoint {
    let exponent = random_scalar();
    exponent * RISTRETTO_BASEPOINT_POINT
}
/// Errors returned by the proving and verification functions in this module.
#[derive(Error, Debug)]
pub enum BulletproofError {
    /// A proof failed a shape (vector length), challenge, or verification-equation check.
    #[error("Invalid proof")]
    InvalidProof,
    /// NOTE(review): not constructed by any code visible in this file.
    #[error("Invalid commitment")]
    InvalidCommitment,
    /// A value to prove is `>= 2^bit_length` (checked only when `bit_length < 64`).
    #[error("Value out of range")]
    ValueOutOfRange,
    /// Unusable caller input, e.g. an empty list of values or commitments.
    #[error("Invalid parameters")]
    InvalidParameters,
    /// NOTE(review): not constructed by any code visible in this file; the serde helper
    /// modules report failures through the deserializer's own error type instead.
    #[error("Serialization error: {0}")]
    SerializationError(String),
}
/// Result alias used by the proving and verification functions in this module.
pub type BulletproofResult<T> = Result<T, BulletproofError>;
/// Public parameters shared by the prover and the verifier.
///
/// Every point is sampled at random by `BulletproofParams::new`, so a proof is only
/// expected to verify against the same parameter instance (or a clone of it) that was used
/// to produce it.
#[derive(Clone, Debug)]
pub struct BulletproofParams {
    /// Number of bits in the proven range: values are accepted when `< 2^bit_length`.
    ///
    /// NOTE(review): this field is public while `generators` is not; changing it after
    /// construction desynchronizes it from `generators.len()`, which the provers and
    /// verifiers assume are equal.
    pub bit_length: usize,
    /// Generator multiplying the committed value.
    g: RistrettoPoint,
    /// Generator multiplying blinding factors.
    h: RistrettoPoint,
    /// One generator per bit position, used for the per-bit commitments.
    generators: Vec<RistrettoPoint>,
}
impl BulletproofParams {
pub fn new(bit_length: usize) -> Self {
let g = random_point();
let h = random_point();
let generators = (0..bit_length).map(|_| random_point()).collect();
Self {
bit_length,
g,
h,
generators,
}
}
}
/// Commitment `g*v + h*r` to a value `v` under a random blinding `r`, as built by
/// `prove_range` and `prove_range_aggregated`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BulletproofCommitment {
    /// The commitment point; serialized as its compressed encoding via `serde_ristretto`.
    #[serde(with = "serde_ristretto")]
    point: RistrettoPoint,
}
/// Non-interactive proof produced by `prove_range` (and, concatenated per value, by
/// `prove_range_aggregated`).
///
/// For each bit position `i` with generator `G_i`, bit `b_i`, and random scalars
/// `r_i`, `a_i`, `t_i`:
/// - `B_i = G_i*b_i + h*r_i` is the bit commitment,
/// - `A_i = G_i*a_i + h*t_i` is the first-round commitment,
/// - `z_i = a_i + c*b_i` and `s_i = t_i + c*r_i` are the responses,
///
/// where `c` is the hash-derived challenge.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BulletproofRangeProof {
    /// The `B_i` points, one per bit.
    #[serde(with = "serde_ristretto_vec")]
    bit_commitments: Vec<RistrettoPoint>,
    /// The `A_i` points, one per bit.
    #[serde(with = "serde_ristretto_vec")]
    initial_commitments: Vec<RistrettoPoint>,
    /// The challenge `c`; verifiers recompute it and require equality.
    #[serde(with = "serde_scalar")]
    challenge: Scalar,
    /// The `z_i` responses, one per bit.
    #[serde(with = "serde_scalar_vec")]
    bit_responses: Vec<Scalar>,
    /// The `s_i` responses, one per bit.
    #[serde(with = "serde_scalar_vec")]
    blinding_responses: Vec<Scalar>,
}
/// Proofs for several values sharing one challenge.
///
/// `proof`'s vectors are the per-value vectors concatenated in the same order as
/// `commitments`, each chunk holding `bit_length` entries.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AggregatedBulletproof {
    /// One value commitment per proven value.
    commitments: Vec<BulletproofCommitment>,
    /// Concatenated per-bit commitments and responses for all values.
    proof: BulletproofRangeProof,
}
/// Commits to `value` and builds a proof over its `bit_length`-bit decomposition.
///
/// Returns the value commitment `g*v + h*r` together with a [`BulletproofRangeProof`]
/// whose challenge is derived by hashing the commitment and all first-round points.
///
/// # Errors
/// Returns [`BulletproofError::ValueOutOfRange`] when `bit_length < 64` and
/// `value >= 2^bit_length`. For `bit_length >= 64` every `u64` is in range.
///
/// NOTE(review): `verify_range` checks only `G_i*z_i + h*s_i == A_i + c*B_i` for each bit.
/// That equation shows knowledge of *an* opening of each `B_i`. It does not show that the
/// opened value is 0 or 1, and the value commitment is bound only through the challenge
/// hash. Treat this as a proof of knowledge, not as a sound range proof.
pub fn prove_range(
    params: &BulletproofParams,
    value: u64,
) -> BulletproofResult<(BulletproofCommitment, BulletproofRangeProof)> {
    if params.bit_length < 64 && value >= (1u64 << params.bit_length) {
        return Err(BulletproofError::ValueOutOfRange);
    }

    // Value commitment g*v + h*r.
    let blinding = random_scalar();
    let commitment_point = params.g * Scalar::from(value) + params.h * blinding;
    let commitment = BulletproofCommitment {
        point: commitment_point,
    };

    // Little-endian bit decomposition. A u64 has no set bits at positions >= 64; the
    // `i < 64` guard keeps the shift amount in range when bit_length > 64 (an oversized
    // shift panics in debug builds and uses a masked shift amount in release builds).
    let bits: Vec<bool> = (0..params.bit_length)
        .map(|i| i < 64 && (value >> i) & 1 == 1)
        .collect();

    // Bit commitments B_i = G_i*b_i + h*r_i.
    let bit_blindings: Vec<Scalar> = (0..params.bit_length).map(|_| random_scalar()).collect();
    let bit_commitments: Vec<RistrettoPoint> = bits
        .iter()
        .zip(&bit_blindings)
        .zip(&params.generators)
        .map(|((bit, blinding), generator)| {
            let bit_scalar = if *bit { Scalar::ONE } else { Scalar::ZERO };
            generator * bit_scalar + params.h * blinding
        })
        .collect();

    // First-round commitments A_i = G_i*a_i + h*t_i with fresh random a_i, t_i.
    let initial_bit_values: Vec<Scalar> = (0..params.bit_length).map(|_| random_scalar()).collect();
    let initial_blindings: Vec<Scalar> = (0..params.bit_length).map(|_| random_scalar()).collect();
    let initial_commitments: Vec<RistrettoPoint> = initial_bit_values
        .iter()
        .zip(&initial_blindings)
        .zip(&params.generators)
        .map(|((a, t), generator)| generator * a + params.h * t)
        .collect();

    // Fiat–Shamir challenge over the value commitment and every point sent so far.
    let challenge =
        generate_challenge_full(&commitment_point, &bit_commitments, &initial_commitments);

    // Responses z_i = a_i + c*b_i.
    let bit_responses: Vec<Scalar> = bits
        .iter()
        .zip(&initial_bit_values)
        .map(|(bit, a)| {
            let bit_scalar = if *bit { Scalar::ONE } else { Scalar::ZERO };
            a + challenge * bit_scalar
        })
        .collect();
    // Responses s_i = t_i + c*r_i.
    let blinding_responses: Vec<Scalar> = bit_blindings
        .iter()
        .zip(&initial_blindings)
        .map(|(r, t)| t + challenge * r)
        .collect();

    let proof = BulletproofRangeProof {
        bit_commitments,
        initial_commitments,
        challenge,
        bit_responses,
        blinding_responses,
    };
    Ok((commitment, proof))
}
/// Verifies a proof produced by [`prove_range`] against `commitment`.
///
/// Checks, in order:
/// 1. every proof vector has exactly `bit_length` entries;
/// 2. the stored challenge equals the one recomputed from `commitment` and the proof's points;
/// 3. for every bit `i`, `G_i*z_i + h*s_i == A_i + c*B_i`.
///
/// # Errors
/// Returns [`BulletproofError::InvalidProof`] if any check fails.
///
/// NOTE(review): check 3 shows knowledge of an opening of each `B_i` only. It does not show
/// that the opened value is 0 or 1, and `commitment` enters only through the challenge hash.
pub fn verify_range(
    params: &BulletproofParams,
    commitment: &BulletproofCommitment,
    proof: &BulletproofRangeProof,
) -> BulletproofResult<()> {
    let n = params.bit_length;
    let well_formed = [
        proof.bit_commitments.len(),
        proof.initial_commitments.len(),
        proof.bit_responses.len(),
        proof.blinding_responses.len(),
    ]
    .iter()
    .all(|&len| len == n);
    if !well_formed {
        return Err(BulletproofError::InvalidProof);
    }

    let recomputed = generate_challenge_full(
        &commitment.point,
        &proof.bit_commitments,
        &proof.initial_commitments,
    );
    if recomputed != proof.challenge {
        return Err(BulletproofError::InvalidProof);
    }

    let every_bit_opens = (0..n).all(|i| {
        let lhs = params.generators[i] * proof.bit_responses[i]
            + params.h * proof.blinding_responses[i];
        let rhs = proof.initial_commitments[i] + proof.bit_commitments[i] * recomputed;
        lhs == rhs
    });
    if every_bit_opens {
        Ok(())
    } else {
        Err(BulletproofError::InvalidProof)
    }
}
/// Commits to every value in `values` and proves all of them under one shared challenge.
///
/// Per value, this builds the same bit commitments and first-round commitments as
/// [`prove_range`]. The outputs are concatenated value by value, in input order, into a
/// single [`BulletproofRangeProof`]. The challenge hashes all value commitments, then all
/// bit commitments, then all first-round commitments.
///
/// # Errors
/// - [`BulletproofError::InvalidParameters`] if `values` is empty.
/// - [`BulletproofError::ValueOutOfRange`] if `bit_length < 64` and any value is
///   `>= 2^bit_length`.
///
/// NOTE(review): as with [`prove_range`], verification only shows knowledge of openings of
/// the bit commitments, not that they are bits or that they recompose to the committed values.
pub fn prove_range_aggregated(
    params: &BulletproofParams,
    values: &[u64],
) -> BulletproofResult<AggregatedBulletproof> {
    if values.is_empty() {
        return Err(BulletproofError::InvalidParameters);
    }
    // Reject out-of-range input before drawing any randomness.
    if params.bit_length < 64 && values.iter().any(|&v| v >= (1u64 << params.bit_length)) {
        return Err(BulletproofError::ValueOutOfRange);
    }

    // Per-value secrets kept until the shared challenge is known.
    struct ProofData {
        bits: Vec<bool>,
        bit_blindings: Vec<Scalar>,
        initial_bit_values: Vec<Scalar>,
        initial_blindings: Vec<Scalar>,
    }

    let mut commitments = Vec::with_capacity(values.len());
    let mut all_bit_commitments: Vec<RistrettoPoint> = Vec::new();
    let mut all_initial_commitments: Vec<RistrettoPoint> = Vec::new();
    let mut proof_data_vec = Vec::with_capacity(values.len());

    for &value in values {
        // Value commitment g*v + h*r.
        let blinding = random_scalar();
        commitments.push(BulletproofCommitment {
            point: params.g * Scalar::from(value) + params.h * blinding,
        });

        // Little-endian bits; positions >= 64 are zero for a u64 and must not be shifted by.
        let bits: Vec<bool> = (0..params.bit_length)
            .map(|i| i < 64 && (value >> i) & 1 == 1)
            .collect();

        // Bit commitments B_i = G_i*b_i + h*r_i.
        let bit_blindings: Vec<Scalar> =
            (0..params.bit_length).map(|_| random_scalar()).collect();
        all_bit_commitments.extend(
            bits.iter()
                .zip(&bit_blindings)
                .zip(&params.generators)
                .map(|((bit, blinding), generator)| {
                    let bit_scalar = if *bit { Scalar::ONE } else { Scalar::ZERO };
                    generator * bit_scalar + params.h * blinding
                }),
        );

        // First-round commitments A_i = G_i*a_i + h*t_i.
        let initial_bit_values: Vec<Scalar> =
            (0..params.bit_length).map(|_| random_scalar()).collect();
        let initial_blindings: Vec<Scalar> =
            (0..params.bit_length).map(|_| random_scalar()).collect();
        all_initial_commitments.extend(
            initial_bit_values
                .iter()
                .zip(&initial_blindings)
                .zip(&params.generators)
                .map(|((a, t), generator)| generator * a + params.h * t),
        );

        proof_data_vec.push(ProofData {
            bits,
            bit_blindings,
            initial_bit_values,
            initial_blindings,
        });
    }

    let all_points: Vec<RistrettoPoint> = commitments.iter().map(|c| c.point).collect();
    let challenge =
        generate_challenge_multi_full(&all_points, &all_bit_commitments, &all_initial_commitments);

    // Responses z_i = a_i + c*b_i and s_i = t_i + c*r_i, in the same order as the commitments.
    let mut all_bit_responses = Vec::with_capacity(all_bit_commitments.len());
    let mut all_blinding_responses = Vec::with_capacity(all_bit_commitments.len());
    for data in &proof_data_vec {
        for (((bit, a), r), t) in data
            .bits
            .iter()
            .zip(&data.initial_bit_values)
            .zip(&data.bit_blindings)
            .zip(&data.initial_blindings)
        {
            let bit_scalar = if *bit { Scalar::ONE } else { Scalar::ZERO };
            all_bit_responses.push(a + challenge * bit_scalar);
            all_blinding_responses.push(t + challenge * r);
        }
    }

    let proof = BulletproofRangeProof {
        bit_commitments: all_bit_commitments,
        initial_commitments: all_initial_commitments,
        challenge,
        bit_responses: all_bit_responses,
        blinding_responses: all_blinding_responses,
    };
    Ok(AggregatedBulletproof { commitments, proof })
}
pub fn verify_aggregated(
params: &BulletproofParams,
aggregated: &AggregatedBulletproof,
) -> BulletproofResult<()> {
if aggregated.commitments.is_empty() {
return Err(BulletproofError::InvalidParameters);
}
let expected_bits = params.bit_length * aggregated.commitments.len();
if aggregated.proof.bit_commitments.len() != expected_bits
|| aggregated.proof.initial_commitments.len() != expected_bits
|| aggregated.proof.bit_responses.len() != expected_bits
|| aggregated.proof.blinding_responses.len() != expected_bits
{
return Err(BulletproofError::InvalidProof);
}
let all_points: Vec<_> = aggregated.commitments.iter().map(|c| c.point).collect();
let challenge = generate_challenge_multi_full(
&all_points,
&aggregated.proof.bit_commitments,
&aggregated.proof.initial_commitments,
);
if challenge != aggregated.proof.challenge {
return Err(BulletproofError::InvalidProof);
}
for i in 0..expected_bits {
let generator_idx = i % params.bit_length;
let lhs = params.generators[generator_idx] * aggregated.proof.bit_responses[i]
+ params.h * aggregated.proof.blinding_responses[i];
let rhs = aggregated.proof.initial_commitments[i]
+ aggregated.proof.bit_commitments[i] * challenge;
if lhs != rhs {
return Err(BulletproofError::InvalidProof);
}
}
Ok(())
}
/// Derives the single-value Fiat–Shamir challenge.
///
/// Feeds BLAKE3 the compressed encodings of `commitment`, then each bit commitment, then
/// each first-round commitment. The 32-byte digest is reduced modulo the group order.
///
/// NOTE(review): no length prefixes, domain separator, or public parameters (`g`, `h`,
/// generators) are hashed; confirm this is acceptable for the intended threat model.
fn generate_challenge_full(
    commitment: &RistrettoPoint,
    bit_commitments: &[RistrettoPoint],
    initial_commitments: &[RistrettoPoint],
) -> Scalar {
    let transcript = std::iter::once(commitment)
        .chain(bit_commitments)
        .chain(initial_commitments);
    let mut hasher = blake3::Hasher::new();
    for point in transcript {
        hasher.update(point.compress().as_bytes());
    }
    let digest = hasher.finalize();
    Scalar::from_bytes_mod_order(*digest.as_bytes())
}
/// Derives the aggregated Fiat–Shamir challenge.
///
/// Feeds BLAKE3 the compressed encodings of every value commitment, then every bit
/// commitment, then every first-round commitment. The 32-byte digest is reduced modulo
/// the group order.
///
/// NOTE(review): the three lists are concatenated without length prefixes. The verifier's
/// length checks against `bit_length` are what pin down the layout.
fn generate_challenge_multi_full(
    commitments: &[RistrettoPoint],
    bit_commitments: &[RistrettoPoint],
    initial_commitments: &[RistrettoPoint],
) -> Scalar {
    let mut transcript = blake3::Hasher::new();
    commitments
        .iter()
        .chain(bit_commitments)
        .chain(initial_commitments)
        .for_each(|point| {
            transcript.update(point.compress().as_bytes());
        });
    let digest = transcript.finalize();
    Scalar::from_bytes_mod_order(*digest.as_bytes())
}
pub mod serde_ristretto {
use super::*;
use serde::{Deserializer, Serializer};
pub fn serialize<S>(point: &RistrettoPoint, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(point.compress().as_bytes())
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<RistrettoPoint, D::Error>
where
D: Deserializer<'de>,
{
let bytes: Vec<u8> = serde::Deserialize::deserialize(deserializer)?;
let compressed =
CompressedRistretto::from_slice(&bytes).map_err(serde::de::Error::custom)?;
compressed
.decompress()
.ok_or_else(|| serde::de::Error::custom("Invalid Ristretto point"))
}
}
pub mod serde_ristretto_vec {
use super::*;
use serde::{Deserializer, Serializer};
pub fn serialize<S>(points: &[RistrettoPoint], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let bytes: Vec<Vec<u8>> = points
.iter()
.map(|p| p.compress().as_bytes().to_vec())
.collect();
bytes.serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<RistrettoPoint>, D::Error>
where
D: Deserializer<'de>,
{
let bytes_vec: Vec<Vec<u8>> = serde::Deserialize::deserialize(deserializer)?;
bytes_vec
.iter()
.map(|bytes| {
let compressed =
CompressedRistretto::from_slice(bytes).map_err(serde::de::Error::custom)?;
compressed
.decompress()
.ok_or_else(|| serde::de::Error::custom("Invalid Ristretto point"))
})
.collect()
}
}
pub mod serde_scalar {
use super::*;
use serde::{Deserializer, Serializer};
pub fn serialize<S>(scalar: &Scalar, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(&scalar.to_bytes())
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Scalar, D::Error>
where
D: Deserializer<'de>,
{
let bytes: Vec<u8> = serde::Deserialize::deserialize(deserializer)?;
if bytes.len() != 32 {
return Err(serde::de::Error::custom("Invalid scalar length"));
}
let mut array = [0u8; 32];
array.copy_from_slice(&bytes);
Ok(Scalar::from_bytes_mod_order(array))
}
}
pub mod serde_scalar_vec {
use super::*;
use serde::{Deserializer, Serializer};
pub fn serialize<S>(scalars: &[Scalar], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let bytes: Vec<Vec<u8>> = scalars.iter().map(|s| s.to_bytes().to_vec()).collect();
bytes.serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<Scalar>, D::Error>
where
D: Deserializer<'de>,
{
let bytes_vec: Vec<Vec<u8>> = serde::Deserialize::deserialize(deserializer)?;
bytes_vec
.iter()
.map(|bytes| {
if bytes.len() != 32 {
return Err(serde::de::Error::custom("Invalid scalar length"));
}
let mut array = [0u8; 32];
array.copy_from_slice(bytes);
Ok(Scalar::from_bytes_mod_order(array))
})
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bulletproof_basic() {
        let params = BulletproofParams::new(32);
        let value = 1000u64;
        let (commitment, proof) = prove_range(&params, value).unwrap();
        assert!(verify_range(&params, &commitment, &proof).is_ok());
    }

    #[test]
    fn test_bulletproof_zero() {
        let params = BulletproofParams::new(32);
        let value = 0u64;
        let (commitment, proof) = prove_range(&params, value).unwrap();
        assert!(verify_range(&params, &commitment, &proof).is_ok());
    }

    #[test]
    fn test_bulletproof_max_value() {
        let params = BulletproofParams::new(8);
        let value = 255u64;
        let (commitment, proof) = prove_range(&params, value).unwrap();
        assert!(verify_range(&params, &commitment, &proof).is_ok());
    }

    #[test]
    fn test_bulletproof_out_of_range() {
        let params = BulletproofParams::new(8);
        let value = 256u64;
        assert!(prove_range(&params, value).is_err());
    }

    #[test]
    fn test_bulletproof_64bit() {
        let params = BulletproofParams::new(64);
        let value = u64::MAX;
        let (commitment, proof) = prove_range(&params, value).unwrap();
        assert!(verify_range(&params, &commitment, &proof).is_ok());
    }

    #[test]
    fn test_bulletproof_aggregated() {
        let params = BulletproofParams::new(32);
        let values = vec![100u64, 200u64, 300u64];
        let aggregated = prove_range_aggregated(&params, &values).unwrap();
        assert_eq!(aggregated.commitments.len(), 3);
        assert!(verify_aggregated(&params, &aggregated).is_ok());
    }

    #[test]
    fn test_bulletproof_aggregated_rejects_bad_input() {
        let params = BulletproofParams::new(8);
        assert!(matches!(
            prove_range_aggregated(&params, &[1, 256]),
            Err(BulletproofError::ValueOutOfRange)
        ));
        assert!(matches!(
            prove_range_aggregated(&params, &[]),
            Err(BulletproofError::InvalidParameters)
        ));
    }

    #[test]
    fn test_bulletproof_rejects_tampered_response() {
        let params = BulletproofParams::new(16);
        let (commitment, mut proof) = prove_range(&params, 1234).unwrap();
        proof.bit_responses[0] += Scalar::ONE;
        assert!(verify_range(&params, &commitment, &proof).is_err());
    }

    #[test]
    fn test_bulletproof_rejects_mismatched_commitment() {
        let params = BulletproofParams::new(16);
        let (_, proof) = prove_range(&params, 1234).unwrap();
        // A fresh commitment to the same value uses a different blinding factor.
        let (other_commitment, _) = prove_range(&params, 1234).unwrap();
        assert!(verify_range(&params, &other_commitment, &proof).is_err());
    }

    #[test]
    fn test_bulletproof_serialization() {
        let params = BulletproofParams::new(32);
        let value = 1000u64;
        let (commitment, proof) = prove_range(&params, value).unwrap();
        let commitment_bytes = crate::codec::encode(&commitment).unwrap();
        let proof_bytes = crate::codec::encode(&proof).unwrap();
        let commitment2: BulletproofCommitment = crate::codec::decode(&commitment_bytes).unwrap();
        let proof2: BulletproofRangeProof = crate::codec::decode(&proof_bytes).unwrap();
        assert!(verify_range(&params, &commitment2, &proof2).is_ok());
    }

    #[test]
    fn test_bulletproof_different_bit_lengths() {
        for bit_length in [8, 16, 32, 48] {
            let params = BulletproofParams::new(bit_length);
            let max_value = (1u64 << bit_length) - 1;
            let (commitment, proof) = prove_range(&params, max_value).unwrap();
            assert!(verify_range(&params, &commitment, &proof).is_ok());
        }
    }
}