1use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
37use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
38use curve25519_dalek::scalar::Scalar;
39use rand::Rng;
40use serde::{Deserialize, Serialize};
41use thiserror::Error;
42
43fn random_scalar() -> Scalar {
45 let mut rng = rand::thread_rng();
46 let mut bytes = [0u8; 32];
47 rng.fill(&mut bytes);
48 Scalar::from_bytes_mod_order(bytes)
49}
50
51fn random_point() -> RistrettoPoint {
53 RISTRETTO_BASEPOINT_POINT * random_scalar()
54}
55
56#[derive(Error, Debug)]
58pub enum BulletproofError {
59 #[error("Invalid proof")]
60 InvalidProof,
61 #[error("Invalid commitment")]
62 InvalidCommitment,
63 #[error("Value out of range")]
64 ValueOutOfRange,
65 #[error("Invalid parameters")]
66 InvalidParameters,
67 #[error("Serialization error: {0}")]
68 SerializationError(String),
69}
70
71pub type BulletproofResult<T> = Result<T, BulletproofError>;
72
73#[derive(Clone, Debug)]
78pub struct BulletproofParams {
79 pub bit_length: usize,
81 g: RistrettoPoint,
83 h: RistrettoPoint,
85 generators: Vec<RistrettoPoint>,
87}
88
89impl BulletproofParams {
90 pub fn new(bit_length: usize) -> Self {
105 let g = random_point();
107 let h = random_point();
108
109 let generators = (0..bit_length).map(|_| random_point()).collect();
111
112 Self {
113 bit_length,
114 g,
115 h,
116 generators,
117 }
118 }
119}
120
121#[derive(Clone, Debug, Serialize, Deserialize)]
123pub struct BulletproofCommitment {
124 #[serde(with = "serde_ristretto")]
126 point: RistrettoPoint,
127}
128
129#[derive(Clone, Debug, Serialize, Deserialize)]
134pub struct BulletproofRangeProof {
135 #[serde(with = "serde_ristretto_vec")]
137 bit_commitments: Vec<RistrettoPoint>,
138 #[serde(with = "serde_ristretto_vec")]
140 initial_commitments: Vec<RistrettoPoint>,
141 #[serde(with = "serde_scalar")]
143 challenge: Scalar,
144 #[serde(with = "serde_scalar_vec")]
146 bit_responses: Vec<Scalar>,
147 #[serde(with = "serde_scalar_vec")]
149 blinding_responses: Vec<Scalar>,
150}
151
152#[derive(Clone, Debug, Serialize, Deserialize)]
157pub struct AggregatedBulletproof {
158 commitments: Vec<BulletproofCommitment>,
160 proof: BulletproofRangeProof,
162}
163
164pub fn prove_range(
183 params: &BulletproofParams,
184 value: u64,
185) -> BulletproofResult<(BulletproofCommitment, BulletproofRangeProof)> {
186 if params.bit_length < 64 && value >= (1u64 << params.bit_length) {
188 return Err(BulletproofError::ValueOutOfRange);
189 }
190
191 let blinding = random_scalar();
193
194 let commitment_point = params.g * Scalar::from(value) + params.h * blinding;
196 let commitment = BulletproofCommitment {
197 point: commitment_point,
198 };
199
200 let bits: Vec<bool> = (0..params.bit_length)
202 .map(|i| (value >> i) & 1 == 1)
203 .collect();
204
205 let bit_blindings: Vec<Scalar> = (0..params.bit_length).map(|_| random_scalar()).collect();
207
208 let bit_commitments: Vec<RistrettoPoint> = bits
210 .iter()
211 .zip(&bit_blindings)
212 .zip(¶ms.generators)
213 .map(|((bit, blinding), generator)| {
214 let bit_scalar = if *bit { Scalar::ONE } else { Scalar::ZERO };
215 generator * bit_scalar + params.h * blinding
216 })
217 .collect();
218
219 let initial_bit_values: Vec<Scalar> = (0..params.bit_length).map(|_| random_scalar()).collect();
221 let initial_blindings: Vec<Scalar> = (0..params.bit_length).map(|_| random_scalar()).collect();
222
223 let initial_commitments: Vec<RistrettoPoint> = initial_bit_values
225 .iter()
226 .zip(&initial_blindings)
227 .zip(¶ms.generators)
228 .map(|((a, t), generator)| generator * a + params.h * t)
229 .collect();
230
231 let challenge =
233 generate_challenge_full(&commitment_point, &bit_commitments, &initial_commitments);
234
235 let bit_responses: Vec<Scalar> = bits
238 .iter()
239 .zip(&initial_bit_values)
240 .map(|(bit, a)| {
241 let bit_scalar = if *bit { Scalar::ONE } else { Scalar::ZERO };
242 a + challenge * bit_scalar
243 })
244 .collect();
245
246 let blinding_responses: Vec<Scalar> = bit_blindings
247 .iter()
248 .zip(&initial_blindings)
249 .map(|(r, t)| t + challenge * r)
250 .collect();
251
252 let proof = BulletproofRangeProof {
253 bit_commitments,
254 initial_commitments,
255 challenge,
256 bit_responses,
257 blinding_responses,
258 };
259
260 Ok((commitment, proof))
261}
262
263pub fn verify_range(
281 params: &BulletproofParams,
282 commitment: &BulletproofCommitment,
283 proof: &BulletproofRangeProof,
284) -> BulletproofResult<()> {
285 if proof.bit_commitments.len() != params.bit_length
287 || proof.initial_commitments.len() != params.bit_length
288 || proof.bit_responses.len() != params.bit_length
289 || proof.blinding_responses.len() != params.bit_length
290 {
291 return Err(BulletproofError::InvalidProof);
292 }
293
294 let challenge = generate_challenge_full(
296 &commitment.point,
297 &proof.bit_commitments,
298 &proof.initial_commitments,
299 );
300 if challenge != proof.challenge {
301 return Err(BulletproofError::InvalidProof);
302 }
303
304 for i in 0..params.bit_length {
309 let lhs =
310 params.generators[i] * proof.bit_responses[i] + params.h * proof.blinding_responses[i];
311 let rhs = proof.initial_commitments[i] + proof.bit_commitments[i] * challenge;
312
313 if lhs != rhs {
314 return Err(BulletproofError::InvalidProof);
315 }
316 }
317
318 Ok(())
319}
320
321pub fn prove_range_aggregated(
330 params: &BulletproofParams,
331 values: &[u64],
332) -> BulletproofResult<AggregatedBulletproof> {
333 if values.is_empty() {
334 return Err(BulletproofError::InvalidParameters);
335 }
336
337 struct ProofData {
339 bits: Vec<bool>,
340 bit_blindings: Vec<Scalar>,
341 initial_bit_values: Vec<Scalar>,
342 initial_blindings: Vec<Scalar>,
343 }
344
345 let mut commitments = Vec::new();
346 let mut all_bit_commitments = Vec::new();
347 let mut all_initial_commitments = Vec::new();
348 let mut proof_data_vec = Vec::new();
349
350 for value in values {
351 if params.bit_length < 64 && *value >= (1u64 << params.bit_length) {
352 return Err(BulletproofError::ValueOutOfRange);
353 }
354
355 let blinding = random_scalar();
356 let commitment_point = params.g * Scalar::from(*value) + params.h * blinding;
357
358 commitments.push(BulletproofCommitment {
359 point: commitment_point,
360 });
361
362 let bits: Vec<bool> = (0..params.bit_length)
364 .map(|i| (*value >> i) & 1 == 1)
365 .collect();
366
367 let bit_blindings: Vec<Scalar> = (0..params.bit_length).map(|_| random_scalar()).collect();
368
369 let bit_commitments: Vec<RistrettoPoint> = bits
370 .iter()
371 .zip(&bit_blindings)
372 .zip(¶ms.generators)
373 .map(|((bit, blinding), generator)| {
374 let bit_scalar = if *bit { Scalar::ONE } else { Scalar::ZERO };
375 generator * bit_scalar + params.h * blinding
376 })
377 .collect();
378
379 all_bit_commitments.extend(bit_commitments);
380
381 let initial_bit_values: Vec<Scalar> =
383 (0..params.bit_length).map(|_| random_scalar()).collect();
384 let initial_blindings: Vec<Scalar> =
385 (0..params.bit_length).map(|_| random_scalar()).collect();
386
387 let initial_commitments: Vec<RistrettoPoint> = initial_bit_values
388 .iter()
389 .zip(&initial_blindings)
390 .zip(¶ms.generators)
391 .map(|((a, t), generator)| generator * a + params.h * t)
392 .collect();
393
394 all_initial_commitments.extend(initial_commitments.clone());
395
396 proof_data_vec.push(ProofData {
398 bits,
399 bit_blindings,
400 initial_bit_values,
401 initial_blindings,
402 });
403 }
404
405 let all_points: Vec<_> = commitments.iter().map(|c| c.point).collect();
407 let challenge =
408 generate_challenge_multi_full(&all_points, &all_bit_commitments, &all_initial_commitments);
409
410 let mut all_bit_responses = Vec::new();
412 let mut all_blinding_responses = Vec::new();
413
414 for proof_data in proof_data_vec {
415 for (bit_idx, bit) in proof_data.bits.iter().enumerate() {
416 let bit_scalar = if *bit { Scalar::ONE } else { Scalar::ZERO };
417 let bit_response = proof_data.initial_bit_values[bit_idx] + challenge * bit_scalar;
419 all_bit_responses.push(bit_response);
420
421 let blinding_response = proof_data.initial_blindings[bit_idx]
423 + challenge * proof_data.bit_blindings[bit_idx];
424 all_blinding_responses.push(blinding_response);
425 }
426 }
427
428 let proof = BulletproofRangeProof {
429 bit_commitments: all_bit_commitments,
430 initial_commitments: all_initial_commitments,
431 challenge,
432 bit_responses: all_bit_responses,
433 blinding_responses: all_blinding_responses,
434 };
435
436 Ok(AggregatedBulletproof { commitments, proof })
437}
438
439pub fn verify_aggregated(
441 params: &BulletproofParams,
442 aggregated: &AggregatedBulletproof,
443) -> BulletproofResult<()> {
444 if aggregated.commitments.is_empty() {
445 return Err(BulletproofError::InvalidParameters);
446 }
447
448 let expected_bits = params.bit_length * aggregated.commitments.len();
449
450 if aggregated.proof.bit_commitments.len() != expected_bits
451 || aggregated.proof.initial_commitments.len() != expected_bits
452 || aggregated.proof.bit_responses.len() != expected_bits
453 || aggregated.proof.blinding_responses.len() != expected_bits
454 {
455 return Err(BulletproofError::InvalidProof);
456 }
457
458 let all_points: Vec<_> = aggregated.commitments.iter().map(|c| c.point).collect();
460 let challenge = generate_challenge_multi_full(
461 &all_points,
462 &aggregated.proof.bit_commitments,
463 &aggregated.proof.initial_commitments,
464 );
465
466 if challenge != aggregated.proof.challenge {
467 return Err(BulletproofError::InvalidProof);
468 }
469
470 for i in 0..expected_bits {
472 let generator_idx = i % params.bit_length;
473 let lhs = params.generators[generator_idx] * aggregated.proof.bit_responses[i]
474 + params.h * aggregated.proof.blinding_responses[i];
475 let rhs = aggregated.proof.initial_commitments[i]
476 + aggregated.proof.bit_commitments[i] * challenge;
477
478 if lhs != rhs {
479 return Err(BulletproofError::InvalidProof);
480 }
481 }
482
483 Ok(())
484}
485
486fn generate_challenge_full(
488 commitment: &RistrettoPoint,
489 bit_commitments: &[RistrettoPoint],
490 initial_commitments: &[RistrettoPoint],
491) -> Scalar {
492 let mut hasher = blake3::Hasher::new();
493 hasher.update(commitment.compress().as_bytes());
494
495 for bc in bit_commitments {
496 hasher.update(bc.compress().as_bytes());
497 }
498
499 for ic in initial_commitments {
500 hasher.update(ic.compress().as_bytes());
501 }
502
503 let hash = hasher.finalize();
504 Scalar::from_bytes_mod_order(*hash.as_bytes())
505}
506
507fn generate_challenge_multi_full(
509 commitments: &[RistrettoPoint],
510 bit_commitments: &[RistrettoPoint],
511 initial_commitments: &[RistrettoPoint],
512) -> Scalar {
513 let mut hasher = blake3::Hasher::new();
514
515 for c in commitments {
516 hasher.update(c.compress().as_bytes());
517 }
518
519 for bc in bit_commitments {
520 hasher.update(bc.compress().as_bytes());
521 }
522
523 for ic in initial_commitments {
524 hasher.update(ic.compress().as_bytes());
525 }
526
527 let hash = hasher.finalize();
528 Scalar::from_bytes_mod_order(*hash.as_bytes())
529}
530
531pub mod serde_ristretto {
533 use super::*;
534 use serde::{Deserializer, Serializer};
535
536 pub fn serialize<S>(point: &RistrettoPoint, serializer: S) -> Result<S::Ok, S::Error>
537 where
538 S: Serializer,
539 {
540 serializer.serialize_bytes(point.compress().as_bytes())
541 }
542
543 pub fn deserialize<'de, D>(deserializer: D) -> Result<RistrettoPoint, D::Error>
544 where
545 D: Deserializer<'de>,
546 {
547 let bytes: Vec<u8> = serde::Deserialize::deserialize(deserializer)?;
548 let compressed =
549 CompressedRistretto::from_slice(&bytes).map_err(serde::de::Error::custom)?;
550 compressed
551 .decompress()
552 .ok_or_else(|| serde::de::Error::custom("Invalid Ristretto point"))
553 }
554}
555
556pub mod serde_ristretto_vec {
557 use super::*;
558 use serde::{Deserializer, Serializer};
559
560 pub fn serialize<S>(points: &[RistrettoPoint], serializer: S) -> Result<S::Ok, S::Error>
561 where
562 S: Serializer,
563 {
564 let bytes: Vec<Vec<u8>> = points
565 .iter()
566 .map(|p| p.compress().as_bytes().to_vec())
567 .collect();
568 bytes.serialize(serializer)
569 }
570
571 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<RistrettoPoint>, D::Error>
572 where
573 D: Deserializer<'de>,
574 {
575 let bytes_vec: Vec<Vec<u8>> = serde::Deserialize::deserialize(deserializer)?;
576 bytes_vec
577 .iter()
578 .map(|bytes| {
579 let compressed =
580 CompressedRistretto::from_slice(bytes).map_err(serde::de::Error::custom)?;
581 compressed
582 .decompress()
583 .ok_or_else(|| serde::de::Error::custom("Invalid Ristretto point"))
584 })
585 .collect()
586 }
587}
588
589pub mod serde_scalar {
590 use super::*;
591 use serde::{Deserializer, Serializer};
592
593 pub fn serialize<S>(scalar: &Scalar, serializer: S) -> Result<S::Ok, S::Error>
594 where
595 S: Serializer,
596 {
597 serializer.serialize_bytes(&scalar.to_bytes())
598 }
599
600 pub fn deserialize<'de, D>(deserializer: D) -> Result<Scalar, D::Error>
601 where
602 D: Deserializer<'de>,
603 {
604 let bytes: Vec<u8> = serde::Deserialize::deserialize(deserializer)?;
605 if bytes.len() != 32 {
606 return Err(serde::de::Error::custom("Invalid scalar length"));
607 }
608 let mut array = [0u8; 32];
609 array.copy_from_slice(&bytes);
610 Ok(Scalar::from_bytes_mod_order(array))
611 }
612}
613
614pub mod serde_scalar_vec {
615 use super::*;
616 use serde::{Deserializer, Serializer};
617
618 pub fn serialize<S>(scalars: &[Scalar], serializer: S) -> Result<S::Ok, S::Error>
619 where
620 S: Serializer,
621 {
622 let bytes: Vec<Vec<u8>> = scalars.iter().map(|s| s.to_bytes().to_vec()).collect();
623 bytes.serialize(serializer)
624 }
625
626 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<Scalar>, D::Error>
627 where
628 D: Deserializer<'de>,
629 {
630 let bytes_vec: Vec<Vec<u8>> = serde::Deserialize::deserialize(deserializer)?;
631 bytes_vec
632 .iter()
633 .map(|bytes| {
634 if bytes.len() != 32 {
635 return Err(serde::de::Error::custom("Invalid scalar length"));
636 }
637 let mut array = [0u8; 32];
638 array.copy_from_slice(bytes);
639 Ok(Scalar::from_bytes_mod_order(array))
640 })
641 .collect()
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bulletproof_basic() {
        let params = BulletproofParams::new(32);
        let value = 1000u64;

        let (commitment, proof) = prove_range(&params, value).unwrap();
        assert!(verify_range(&params, &commitment, &proof).is_ok());
    }

    #[test]
    fn test_bulletproof_zero() {
        let params = BulletproofParams::new(32);
        let value = 0u64;

        let (commitment, proof) = prove_range(&params, value).unwrap();
        assert!(verify_range(&params, &commitment, &proof).is_ok());
    }

    #[test]
    fn test_bulletproof_max_value() {
        let params = BulletproofParams::new(8);
        let value = 255u64;

        let (commitment, proof) = prove_range(&params, value).unwrap();
        assert!(verify_range(&params, &commitment, &proof).is_ok());
    }

    #[test]
    fn test_bulletproof_out_of_range() {
        let params = BulletproofParams::new(8);
        let value = 256u64;

        assert!(prove_range(&params, value).is_err());
    }

    #[test]
    fn test_bulletproof_64bit() {
        let params = BulletproofParams::new(64);
        let value = u64::MAX;

        let (commitment, proof) = prove_range(&params, value).unwrap();
        assert!(verify_range(&params, &commitment, &proof).is_ok());
    }

    #[test]
    fn test_bulletproof_aggregated() {
        let params = BulletproofParams::new(32);
        let values = vec![100u64, 200u64, 300u64];

        let aggregated = prove_range_aggregated(&params, &values).unwrap();
        assert_eq!(aggregated.commitments.len(), 3);
        assert!(verify_aggregated(&params, &aggregated).is_ok());
    }

    #[test]
    fn test_bulletproof_serialization() {
        let params = BulletproofParams::new(32);
        let value = 1000u64;

        let (commitment, proof) = prove_range(&params, value).unwrap();

        let commitment_bytes = crate::codec::encode(&commitment).unwrap();
        let proof_bytes = crate::codec::encode(&proof).unwrap();

        let commitment2: BulletproofCommitment = crate::codec::decode(&commitment_bytes).unwrap();
        let proof2: BulletproofRangeProof = crate::codec::decode(&proof_bytes).unwrap();

        assert!(verify_range(&params, &commitment2, &proof2).is_ok());
    }

    #[test]
    fn test_bulletproof_different_bit_lengths() {
        for bit_length in [8, 16, 32, 48] {
            let params = BulletproofParams::new(bit_length);
            let max_value = (1u64 << bit_length) - 1;

            let (commitment, proof) = prove_range(&params, max_value).unwrap();
            assert!(verify_range(&params, &commitment, &proof).is_ok());
        }
    }

    #[test]
    fn test_bulletproof_rejects_commitment_from_other_proof() {
        let params = BulletproofParams::new(16);
        let (_, proof) = prove_range(&params, 42).unwrap();
        let (other_commitment, _) = prove_range(&params, 42).unwrap();

        assert!(verify_range(&params, &other_commitment, &proof).is_err());
    }

    #[test]
    fn test_bulletproof_rejects_tampered_challenge() {
        let params = BulletproofParams::new(16);
        let (commitment, mut proof) = prove_range(&params, 42).unwrap();
        proof.challenge = proof.challenge + Scalar::ONE;

        assert!(verify_range(&params, &commitment, &proof).is_err());
    }

    #[test]
    fn test_bulletproof_rejects_tampered_response() {
        let params = BulletproofParams::new(16);
        let (commitment, mut proof) = prove_range(&params, 42).unwrap();
        proof.bit_responses[0] = proof.bit_responses[0] + Scalar::ONE;

        assert!(verify_range(&params, &commitment, &proof).is_err());
    }

    #[test]
    fn test_bulletproof_rejects_truncated_proof() {
        let params = BulletproofParams::new(16);
        let (commitment, mut proof) = prove_range(&params, 42).unwrap();
        proof.bit_commitments.pop();

        assert!(verify_range(&params, &commitment, &proof).is_err());
    }

    #[test]
    fn test_bulletproof_aggregated_input_errors() {
        let params = BulletproofParams::new(8);

        assert!(matches!(
            prove_range_aggregated(&params, &[]),
            Err(BulletproofError::InvalidParameters)
        ));
        assert!(matches!(
            prove_range_aggregated(&params, &[1, 256]),
            Err(BulletproofError::ValueOutOfRange)
        ));
    }

    #[test]
    fn test_bulletproof_aggregated_rejects_swapped_commitments() {
        let params = BulletproofParams::new(16);
        let mut aggregated = prove_range_aggregated(&params, &[1, 2]).unwrap();
        aggregated.commitments.swap(0, 1);

        assert!(verify_aggregated(&params, &aggregated).is_err());
    }
}