1use curve25519_dalek::{
29 constants::RISTRETTO_BASEPOINT_POINT,
30 ristretto::{CompressedRistretto, RistrettoPoint},
31 scalar::Scalar,
32};
33use rand::Rng;
34use serde::{Deserialize, Serialize};
35use thiserror::Error;
36use zeroize::{Zeroize, ZeroizeOnDrop};
37
38#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
40pub struct PedersenCommitment {
41 #[serde(with = "serde_bytes_32")]
43 point: [u8; 32],
44}
45
46#[derive(Clone, Zeroize, ZeroizeOnDrop, Serialize, Deserialize)]
48pub struct PedersenOpening {
49 #[serde(with = "serde_bytes_32")]
51 blinding: [u8; 32],
52}
53
54#[derive(Debug, Error)]
56pub enum PedersenError {
57 #[error("Invalid commitment point")]
59 InvalidCommitment,
60
61 #[error("Verification failed")]
63 VerificationFailed,
64
65 #[error("Invalid blinding factor")]
67 InvalidBlinding,
68}
69
70pub type PedersenResult<T> = Result<T, PedersenError>;
71
72mod serde_bytes_32 {
74 use serde::{Deserialize, Deserializer, Serializer};
75
76 pub fn serialize<S>(bytes: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
77 where
78 S: Serializer,
79 {
80 serializer.serialize_bytes(bytes)
81 }
82
83 pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
84 where
85 D: Deserializer<'de>,
86 {
87 let bytes = <Vec<u8>>::deserialize(deserializer)?;
88 if bytes.len() != 32 {
89 return Err(serde::de::Error::custom("Expected 32 bytes"));
90 }
91 let mut result = [0u8; 32];
92 result.copy_from_slice(&bytes);
93 Ok(result)
94 }
95}
96
97impl PedersenCommitment {
98 pub fn from_bytes(bytes: [u8; 32]) -> PedersenResult<Self> {
100 CompressedRistretto(bytes)
102 .decompress()
103 .ok_or(PedersenError::InvalidCommitment)?;
104 Ok(Self { point: bytes })
105 }
106
107 pub fn as_bytes(&self) -> &[u8; 32] {
109 &self.point
110 }
111
112 pub fn to_bytes(&self) -> [u8; 32] {
114 self.point
115 }
116
117 pub fn add(&self, other: &Self) -> Self {
130 let p1 = CompressedRistretto(self.point).decompress().unwrap();
131 let p2 = CompressedRistretto(other.point).decompress().unwrap();
132 let sum = (p1 + p2).compress();
133 Self {
134 point: sum.to_bytes(),
135 }
136 }
137
138 pub fn sub(&self, other: &Self) -> Self {
142 let p1 = CompressedRistretto(self.point).decompress().unwrap();
143 let p2 = CompressedRistretto(other.point).decompress().unwrap();
144 let diff = (p1 - p2).compress();
145 Self {
146 point: diff.to_bytes(),
147 }
148 }
149
150 pub fn mul(&self, scalar: u64) -> Self {
154 let p = CompressedRistretto(self.point).decompress().unwrap();
155 let s = Scalar::from(scalar);
156 let result = (s * p).compress();
157 Self {
158 point: result.to_bytes(),
159 }
160 }
161}
162
163impl PedersenOpening {
164 pub fn from_bytes(bytes: [u8; 32]) -> Self {
166 Self { blinding: bytes }
167 }
168
169 pub fn as_bytes(&self) -> &[u8; 32] {
171 &self.blinding
172 }
173
174 pub fn to_bytes(&self) -> [u8; 32] {
176 self.blinding
177 }
178
179 pub fn add(&self, other: &Self) -> Self {
192 let s1 = Scalar::from_bytes_mod_order(self.blinding);
193 let s2 = Scalar::from_bytes_mod_order(other.blinding);
194 let sum = s1 + s2;
195 Self {
196 blinding: sum.to_bytes(),
197 }
198 }
199
200 pub fn sub(&self, other: &Self) -> Self {
202 let s1 = Scalar::from_bytes_mod_order(self.blinding);
203 let s2 = Scalar::from_bytes_mod_order(other.blinding);
204 let diff = s1 - s2;
205 Self {
206 blinding: diff.to_bytes(),
207 }
208 }
209
210 pub fn mul(&self, scalar: u64) -> Self {
212 let s1 = Scalar::from_bytes_mod_order(self.blinding);
213 let s2 = Scalar::from(scalar);
214 let product = s1 * s2;
215 Self {
216 blinding: product.to_bytes(),
217 }
218 }
219}
220
221pub fn commit(value: u64) -> (PedersenCommitment, PedersenOpening) {
235 let mut rng = rand::thread_rng();
236 let mut blinding_bytes = [0u8; 32];
237 rng.fill(&mut blinding_bytes);
238 let blinding = Scalar::from_bytes_mod_order(blinding_bytes);
239
240 let opening = PedersenOpening {
241 blinding: blinding.to_bytes(),
242 };
243 let commitment = compute_commitment(value, &opening);
244
245 (commitment, opening)
246}
247
248pub fn commit_with_blinding(value: u64, blinding: &PedersenOpening) -> PedersenCommitment {
257 compute_commitment(value, blinding)
258}
259
260pub fn verify(commitment: &PedersenCommitment, value: u64, opening: &PedersenOpening) -> bool {
279 let expected = compute_commitment(value, opening);
280 expected == *commitment
281}
282
283pub fn verify_batch(
293 commitments: &[PedersenCommitment],
294 values: &[u64],
295 openings: &[PedersenOpening],
296) -> bool {
297 if commitments.len() != values.len() || commitments.len() != openings.len() {
298 return false;
299 }
300
301 commitments
302 .iter()
303 .zip(values.iter())
304 .zip(openings.iter())
305 .all(|((c, v), o)| verify(c, *v, o))
306}
307
308fn compute_commitment(value: u64, opening: &PedersenOpening) -> PedersenCommitment {
310 let g = RISTRETTO_BASEPOINT_POINT;
312
313 let h = get_h_generator();
315
316 let value_scalar = Scalar::from(value);
318 let blinding_scalar = Scalar::from_bytes_mod_order(opening.blinding);
319
320 let commitment_point = value_scalar * g + blinding_scalar * h;
322
323 PedersenCommitment {
324 point: commitment_point.compress().to_bytes(),
325 }
326}
327
328fn get_h_generator() -> RistrettoPoint {
332 let mut hasher = blake3::Hasher::new();
334 hasher.update(b"chie-pedersen-h-generator-v1");
335 let hash = hasher.finalize();
336
337 let scalar = Scalar::from_bytes_mod_order(*hash.as_bytes());
339 scalar * RISTRETTO_BASEPOINT_POINT
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_commit_and_verify() {
        let (commitment, opening) = commit(1024);
        assert!(verify(&commitment, 1024, &opening));
        assert!(!verify(&commitment, 2048, &opening));
    }

    #[test]
    fn test_homomorphic_addition() {
        let (c1, o1) = commit(100);
        let (c2, o2) = commit(200);

        let sum_commitment = c1.add(&c2);
        let sum_opening = o1.add(&o2);

        assert!(verify(&sum_commitment, 300, &sum_opening));
    }

    #[test]
    fn test_homomorphic_subtraction() {
        let (c1, o1) = commit(500);
        let (c2, o2) = commit(200);

        let diff_commitment = c1.sub(&c2);
        let diff_opening = o1.sub(&o2);

        assert!(verify(&diff_commitment, 300, &diff_opening));
    }

    #[test]
    fn test_scalar_multiplication() {
        let (commitment, opening) = commit(100);

        let scaled_commitment = commitment.mul(3);
        let scaled_opening = opening.mul(3);

        assert!(verify(&scaled_commitment, 300, &scaled_opening));
    }

    #[test]
    fn test_batch_verification() {
        let (c1, o1) = commit(100);
        let (c2, o2) = commit(200);
        let (c3, o3) = commit(300);

        let commitments = vec![c1, c2, c3];
        let values = vec![100, 200, 300];
        let openings = vec![o1, o2, o3];

        assert!(verify_batch(&commitments, &values, &openings));

        let wrong_values = vec![100, 200, 400];
        assert!(!verify_batch(&commitments, &wrong_values, &openings));
    }

    #[test]
    fn test_batch_verification_rejects_length_mismatch() {
        let (c1, o1) = commit(1);
        assert!(!verify_batch(&[c1], &[1, 2], &[o1]));
    }

    #[test]
    fn test_batch_verification_empty() {
        assert!(verify_batch(&[], &[], &[]));
    }

    #[test]
    fn test_commitment_serialization() {
        let (commitment, _) = commit(1024);

        let bytes = commitment.to_bytes();
        let restored = PedersenCommitment::from_bytes(bytes).unwrap();

        assert_eq!(commitment, restored);
    }

    #[test]
    fn test_invalid_commitment_bytes_rejected() {
        // 2^256 - 1 is not a canonical field element, so it cannot be a Ristretto encoding.
        assert!(matches!(
            PedersenCommitment::from_bytes([0xFF; 32]),
            Err(PedersenError::InvalidCommitment)
        ));
    }

    #[test]
    fn test_opening_serialization() {
        let (_, opening) = commit(1024);

        let bytes = opening.to_bytes();
        let restored = PedersenOpening::from_bytes(bytes);

        let commitment = commit_with_blinding(1024, &restored);
        assert!(verify(&commitment, 1024, &restored));
    }

    #[test]
    fn test_commit_with_blinding_is_deterministic() {
        let (commitment, opening) = commit(42);
        assert_eq!(commit_with_blinding(42, &opening), commitment);
        assert_eq!(
            commit_with_blinding(42, &opening),
            commit_with_blinding(42, &opening)
        );
    }

    #[test]
    fn test_bandwidth_aggregation_scenario() {
        let (bandwidth1, opening1) = commit(1024);
        let (bandwidth2, opening2) = commit(2048);
        let (bandwidth3, opening3) = commit(4096);

        let total_bandwidth = bandwidth1.add(&bandwidth2).add(&bandwidth3);
        let total_opening = opening1.add(&opening2).add(&opening3);

        assert!(verify(&total_bandwidth, 7168, &total_opening));
    }

    #[test]
    fn test_same_value_different_commitments() {
        // Fresh random blinding factors make two commitments to the same value differ.
        let (c1, _) = commit(100);
        let (c2, _) = commit(100);

        assert_ne!(c1, c2);
    }

    #[test]
    fn test_zero_value_commitment() {
        let (commitment, opening) = commit(0);
        assert!(verify(&commitment, 0, &opening));
    }

    #[test]
    fn test_large_value() {
        let large_value = 1_000_000_000u64;
        let (commitment, opening) = commit(large_value);
        assert!(verify(&commitment, large_value, &opening));
    }

    #[test]
    fn test_commitment_commutativity() {
        let (c1, o1) = commit(100);
        let (c2, o2) = commit(200);

        let sum1 = c1.add(&c2);
        let sum2 = c2.add(&c1);

        assert_eq!(sum1, sum2);

        let opening_sum1 = o1.add(&o2);
        let opening_sum2 = o2.add(&o1);

        assert!(verify(&sum1, 300, &opening_sum1));
        assert!(verify(&sum2, 300, &opening_sum2));
    }

    #[test]
    fn test_h_generator_is_deterministic_and_distinct_from_g() {
        let h = get_h_generator();
        assert_eq!(h.compress(), get_h_generator().compress());
        assert_ne!(h.compress(), RISTRETTO_BASEPOINT_POINT.compress());
    }
}