1use blake3::Hasher;
37use curve25519_dalek::{RistrettoPoint, Scalar, constants::RISTRETTO_BASEPOINT_POINT};
38use rand::RngCore;
39use serde::{Deserialize, Serialize};
40use std::fmt;
41use zeroize::{Zeroize, ZeroizeOnDrop};
42
/// Result alias for the fallible operations of this module.
pub type AdvancedCommitmentResult<T> = Result<T, AdvancedCommitmentError>;

/// Errors surfaced by the commitment schemes in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdvancedCommitmentError {
    /// Displayed as "Commitment verification failed".
    VerificationFailed,
    /// Displayed as "Invalid opening proof".
    InvalidOpening,
    /// Returned by `VectorCommitment::open` when the index is past the end of the vector.
    InvalidIndex,
    /// Displayed as "Serialization failed".
    SerializationFailed,
    /// Displayed as "Deserialization failed".
    DeserializationFailed,
}

impl fmt::Display for AdvancedCommitmentError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Map each variant to its fixed human-readable message.
        let message = match self {
            Self::VerificationFailed => "Commitment verification failed",
            Self::InvalidOpening => "Invalid opening proof",
            Self::InvalidIndex => "Invalid index",
            Self::SerializationFailed => "Serialization failed",
            Self::DeserializationFailed => "Deserialization failed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for AdvancedCommitmentError {}
76
77#[derive(Clone)]
86pub struct TrapdoorCommitment {
87 #[allow(dead_code)]
89 g: RistrettoPoint,
90 h: RistrettoPoint,
92}
93
94#[derive(Clone, Zeroize, ZeroizeOnDrop)]
96pub struct Trapdoor {
97 #[zeroize(skip)]
98 alpha: Scalar,
99}
100
101#[derive(Clone, Debug, Serialize, Deserialize)]
103pub struct TrapdoorCom {
104 #[serde(with = "serde_ristretto_point")]
105 c: RistrettoPoint,
106}
107
108#[derive(Clone, Debug, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
110pub struct TrapdoorOpening {
111 #[serde(with = "serde_scalar")]
112 #[zeroize(skip)]
113 r: Scalar,
114}
115
116impl TrapdoorCommitment {
117 pub fn setup() -> (Self, Trapdoor) {
121 let mut rng = rand::thread_rng();
122 let mut alpha_bytes = [0u8; 32];
123 rng.fill_bytes(&mut alpha_bytes);
124 let alpha = Scalar::from_bytes_mod_order(alpha_bytes);
125
126 let g = RISTRETTO_BASEPOINT_POINT;
127 let h = alpha * g;
128
129 let commitment = Self { g, h };
130 let trapdoor = Trapdoor { alpha };
131
132 (commitment, trapdoor)
133 }
134
135 pub fn setup_without_trapdoor() -> Self {
137 let g = RISTRETTO_BASEPOINT_POINT;
138
139 let mut hasher = Hasher::new();
141 hasher.update(b"TrapdoorCommitment-H-Generator");
142 let hash = hasher.finalize();
143 let h_scalar = Scalar::from_bytes_mod_order(*hash.as_bytes());
144 let h = h_scalar * g;
145
146 Self { g, h }
147 }
148
149 pub fn commit(&self, value: &[u8]) -> (TrapdoorCom, TrapdoorOpening) {
151 let mut rng = rand::thread_rng();
152 let mut r_bytes = [0u8; 32];
153 rng.fill_bytes(&mut r_bytes);
154 let r = Scalar::from_bytes_mod_order(r_bytes);
155
156 let m = hash_to_scalar(value);
157 let c = m * self.h + r * RISTRETTO_BASEPOINT_POINT;
158
159 (TrapdoorCom { c }, TrapdoorOpening { r })
160 }
161
162 pub fn verify(&self, com: &TrapdoorCom, value: &[u8], opening: &TrapdoorOpening) -> bool {
164 let m = hash_to_scalar(value);
165 let expected = m * self.h + opening.r * RISTRETTO_BASEPOINT_POINT;
166 com.c == expected
167 }
168
169 pub fn equivocate(
171 &self,
172 _com: &TrapdoorCom,
173 original_value: &[u8],
174 original_opening: &TrapdoorOpening,
175 new_value: &[u8],
176 trapdoor: &Trapdoor,
177 ) -> TrapdoorOpening {
178 let m_old = hash_to_scalar(original_value);
179 let m_new = hash_to_scalar(new_value);
180
181 let r_new = original_opening.r + (m_old - m_new) * trapdoor.alpha;
184
185 TrapdoorOpening { r: r_new }
186 }
187}
188
/// Merkle-tree based vector commitment over byte-string elements.
///
/// Apart from `tree_depth`, the scheme keeps no state. Commitments and openings
/// are computed directly from the value slice passed to each method.
#[derive(Clone)]
pub struct VectorCommitment {
    // Recorded by `new` but not read by any method; the allow silences the lint.
    #[allow(dead_code)]
    tree_depth: usize,
}
202
203#[derive(Clone, Debug, Serialize, Deserialize)]
205pub struct VectorCom {
206 root: [u8; 32],
207}
208
209#[derive(Clone, Debug, Serialize, Deserialize)]
211pub struct VectorOpening {
212 index: usize,
213 value: Vec<u8>,
214 proof: Vec<[u8; 32]>,
215}
216
217impl VectorCommitment {
218 pub fn new(max_size: usize) -> Self {
220 let tree_depth = (max_size as f64).log2().ceil() as usize;
221 Self { tree_depth }
222 }
223
224 pub fn commit(&self, values: &[Vec<u8>]) -> VectorCom {
226 let root = build_merkle_root(values);
227 VectorCom { root }
228 }
229
230 pub fn open(
232 &self,
233 values: &[Vec<u8>],
234 index: usize,
235 ) -> AdvancedCommitmentResult<VectorOpening> {
236 if index >= values.len() {
237 return Err(AdvancedCommitmentError::InvalidIndex);
238 }
239
240 let proof = build_merkle_proof(values, index);
241
242 Ok(VectorOpening {
243 index,
244 value: values[index].clone(),
245 proof,
246 })
247 }
248
249 pub fn verify(&self, com: &VectorCom, opening: &VectorOpening) -> bool {
251 verify_merkle_proof(&com.root, &opening.value, opening.index, &opening.proof)
252 }
253
254 pub fn open_batch(
256 &self,
257 values: &[Vec<u8>],
258 indices: &[usize],
259 ) -> AdvancedCommitmentResult<Vec<VectorOpening>> {
260 indices
261 .iter()
262 .map(|&index| self.open(values, index))
263 .collect()
264 }
265}
266
267#[derive(Clone)]
276pub struct ExtractableCommitment {
277 g: RistrettoPoint,
278 h: RistrettoPoint,
279}
280
281#[derive(Clone, Debug, Serialize, Deserialize)]
283pub struct ExtractableCom {
284 #[serde(with = "serde_ristretto_point")]
285 c: RistrettoPoint,
286}
287
288#[derive(Clone, Debug, Serialize, Deserialize)]
290pub struct ExtractableOpening {
291 #[serde(with = "serde_scalar")]
292 r: Scalar,
293 proof: SchnorrProof,
294}
295
296#[derive(Clone, Debug, Serialize, Deserialize)]
298struct SchnorrProof {
299 #[serde(with = "serde_ristretto_point")]
300 t: RistrettoPoint,
301 #[serde(with = "serde_scalar")]
302 s: Scalar,
303}
304
305impl ExtractableCommitment {
306 pub fn setup() -> Self {
308 let g = RISTRETTO_BASEPOINT_POINT;
309
310 let mut hasher = Hasher::new();
311 hasher.update(b"ExtractableCommitment-H");
312 let hash = hasher.finalize();
313 let h_scalar = Scalar::from_bytes_mod_order(*hash.as_bytes());
314 let h = h_scalar * g;
315
316 Self { g, h }
317 }
318
319 pub fn commit(&self, value: &[u8]) -> (ExtractableCom, ExtractableOpening) {
321 let mut rng = rand::thread_rng();
322 let mut r_bytes = [0u8; 32];
323 rng.fill_bytes(&mut r_bytes);
324 let r = Scalar::from_bytes_mod_order(r_bytes);
325
326 let m = hash_to_scalar(value);
327 let c = m * self.g + r * self.h;
328
329 let mut k_bytes = [0u8; 32];
331 rng.fill_bytes(&mut k_bytes);
332 let k = Scalar::from_bytes_mod_order(k_bytes);
333 let t = k * self.g;
334
335 let challenge = compute_challenge(&c, &t);
336 let s = k + challenge * m;
337
338 let proof = SchnorrProof { t, s };
339 let opening = ExtractableOpening { r, proof };
340
341 (ExtractableCom { c }, opening)
342 }
343
344 pub fn verify(&self, com: &ExtractableCom, value: &[u8], opening: &ExtractableOpening) -> bool {
346 let m = hash_to_scalar(value);
347
348 let expected_c = m * self.g + opening.r * self.h;
350 if com.c != expected_c {
351 return false;
352 }
353
354 let challenge = compute_challenge(&com.c, &opening.proof.t);
356 let lhs = opening.proof.s * self.g;
357 let rhs = opening.proof.t + challenge * m * self.g;
358
359 lhs == rhs
360 }
361}
362
363fn hash_to_scalar(value: &[u8]) -> Scalar {
369 let mut hasher = Hasher::new();
370 hasher.update(b"AdvancedCommitment-Hash:");
371 hasher.update(value);
372 let hash = hasher.finalize();
373
374 Scalar::from_bytes_mod_order(*hash.as_bytes())
375}
376
377fn compute_challenge(c: &RistrettoPoint, t: &RistrettoPoint) -> Scalar {
379 let mut hasher = Hasher::new();
380 hasher.update(b"Challenge:");
381 hasher.update(&c.compress().to_bytes());
382 hasher.update(&t.compress().to_bytes());
383 let hash = hasher.finalize();
384
385 Scalar::from_bytes_mod_order(*hash.as_bytes())
386}
387
388fn build_merkle_root(values: &[Vec<u8>]) -> [u8; 32] {
390 if values.is_empty() {
391 return [0u8; 32];
392 }
393
394 let mut layer: Vec<[u8; 32]> = values.iter().map(|v| hash_leaf(v)).collect();
395
396 while layer.len() > 1 {
397 layer = layer
398 .chunks(2)
399 .map(|chunk| {
400 if chunk.len() == 2 {
401 hash_pair(&chunk[0], &chunk[1])
402 } else {
403 chunk[0]
404 }
405 })
406 .collect();
407 }
408
409 layer[0]
410}
411
412fn build_merkle_proof(values: &[Vec<u8>], index: usize) -> Vec<[u8; 32]> {
414 let mut proof = Vec::new();
415 let mut layer: Vec<[u8; 32]> = values.iter().map(|v| hash_leaf(v)).collect();
416 let mut pos = index;
417
418 while layer.len() > 1 {
419 let sibling_pos = if pos % 2 == 0 { pos + 1 } else { pos - 1 };
421
422 if sibling_pos < layer.len() {
423 proof.push(layer[sibling_pos]);
425 } else {
426 proof.push([0u8; 32]);
429 }
430
431 layer = layer
432 .chunks(2)
433 .map(|chunk| {
434 if chunk.len() == 2 {
435 hash_pair(&chunk[0], &chunk[1])
436 } else {
437 chunk[0]
438 }
439 })
440 .collect();
441
442 pos /= 2;
443 }
444
445 proof
446}
447
448fn verify_merkle_proof(root: &[u8; 32], value: &[u8], index: usize, proof: &[[u8; 32]]) -> bool {
450 let mut current = hash_leaf(value);
451 let mut pos = index;
452
453 for sibling in proof {
454 if sibling == &[0u8; 32] {
456 } else {
459 current = if pos % 2 == 0 {
460 hash_pair(¤t, sibling)
461 } else {
462 hash_pair(sibling, ¤t)
463 };
464 }
465 pos /= 2;
466 }
467
468 ¤t == root
469}
470
471fn hash_leaf(value: &[u8]) -> [u8; 32] {
473 let mut hasher = Hasher::new();
474 hasher.update(b"Leaf:");
475 hasher.update(value);
476 *hasher.finalize().as_bytes()
477}
478
479fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
481 let mut hasher = Hasher::new();
482 hasher.update(b"Node:");
483 hasher.update(left);
484 hasher.update(right);
485 *hasher.finalize().as_bytes()
486}
487
488mod serde_ristretto_point {
493 use curve25519_dalek::RistrettoPoint;
494 use serde::{Deserialize, Deserializer, Serialize, Serializer};
495
496 pub fn serialize<S>(point: &RistrettoPoint, serializer: S) -> Result<S::Ok, S::Error>
497 where
498 S: Serializer,
499 {
500 point.compress().to_bytes().serialize(serializer)
501 }
502
503 pub fn deserialize<'de, D>(deserializer: D) -> Result<RistrettoPoint, D::Error>
504 where
505 D: Deserializer<'de>,
506 {
507 let bytes: [u8; 32] = Deserialize::deserialize(deserializer)?;
508 let compressed = curve25519_dalek::ristretto::CompressedRistretto(bytes);
509 compressed
510 .decompress()
511 .ok_or_else(|| serde::de::Error::custom("Invalid RistrettoPoint"))
512 }
513}
514
515mod serde_scalar {
516 use curve25519_dalek::Scalar;
517 use serde::{Deserialize, Deserializer, Serialize, Serializer};
518
519 pub fn serialize<S>(scalar: &Scalar, serializer: S) -> Result<S::Ok, S::Error>
520 where
521 S: Serializer,
522 {
523 scalar.to_bytes().serialize(serializer)
524 }
525
526 pub fn deserialize<'de, D>(deserializer: D) -> Result<Scalar, D::Error>
527 where
528 D: Deserializer<'de>,
529 {
530 let bytes: [u8; 32] = Deserialize::deserialize(deserializer)?;
531 Ok(Scalar::from_bytes_mod_order(bytes))
532 }
533}
534
#[cfg(test)]
mod tests {
    //! Round-trip, tamper and serialization tests for the three commitment schemes.

    use super::*;

    /// An honest trapdoor commitment verifies against its own value and opening.
    #[test]
    fn test_trapdoor_commitment_basic() {
        let (commitment, _) = TrapdoorCommitment::setup();
        let value = b"test value";

        let (com, opening) = commitment.commit(value);
        assert!(commitment.verify(&com, value, &opening));
    }

    /// Verification must fail when a different value is presented with the opening.
    #[test]
    fn test_trapdoor_commitment_wrong_value() {
        let (commitment, _) = TrapdoorCommitment::setup();
        let value = b"test value";

        let (com, opening) = commitment.commit(value);
        assert!(!commitment.verify(&com, b"wrong value", &opening));
    }

    /// The trapdoor holder can open one commitment to a second value.
    #[test]
    fn test_trapdoor_equivocation() {
        let (commitment, trapdoor) = TrapdoorCommitment::setup();

        let original = b"original value";
        let (com, opening) = commitment.commit(original);

        // The honest opening is valid before equivocating.
        assert!(commitment.verify(&com, original, &opening));

        let fake = b"different value";
        // Derive a second opening of the same commitment using the trapdoor.
        let fake_opening = commitment.equivocate(&com, original, &opening, fake, &trapdoor);

        // Both the original and the equivocated openings verify.
        assert!(commitment.verify(&com, original, &opening));
        assert!(commitment.verify(&com, fake, &fake_opening));
    }

    /// Opening an element returns its value and a path that verifies.
    #[test]
    fn test_vector_commitment_basic() {
        let vc = VectorCommitment::new(10);
        let values = vec![b"value0".to_vec(), b"value1".to_vec(), b"value2".to_vec()];

        let com = vc.commit(&values);
        let opening = vc.open(&values, 1).unwrap();

        assert!(vc.verify(&com, &opening));
        assert_eq!(opening.value, b"value1");
    }

    /// Opening past the end of the vector is rejected.
    #[test]
    fn test_vector_commitment_wrong_index() {
        let vc = VectorCommitment::new(10);
        let values = vec![b"value0".to_vec(), b"value1".to_vec()];

        assert!(vc.open(&values, 5).is_err());
    }

    /// Replacing the opened value must invalidate the opening.
    #[test]
    fn test_vector_commitment_tampered() {
        let vc = VectorCommitment::new(10);
        let values = vec![b"value0".to_vec(), b"value1".to_vec()];

        let com = vc.commit(&values);
        let mut opening = vc.open(&values, 1).unwrap();

        opening.value = b"tampered".to_vec();
        // The path no longer hashes up to the committed root.
        assert!(!vc.verify(&com, &opening));
    }

    /// Every opening returned by `open_batch` verifies individually.
    #[test]
    fn test_vector_commitment_batch() {
        let vc = VectorCommitment::new(10);
        let values = vec![
            b"value0".to_vec(),
            b"value1".to_vec(),
            b"value2".to_vec(),
            b"value3".to_vec(),
        ];

        let com = vc.commit(&values);
        let openings = vc.open_batch(&values, &[0, 2, 3]).unwrap();

        assert_eq!(openings.len(), 3);
        for opening in openings {
            assert!(vc.verify(&com, &opening));
        }
    }

    /// An honest extractable commitment verifies.
    #[test]
    fn test_extractable_commitment_basic() {
        let ec = ExtractableCommitment::setup();
        let value = b"test value";

        let (com, opening) = ec.commit(value);
        assert!(ec.verify(&com, value, &opening));
    }

    /// Verification must fail for a value other than the committed one.
    #[test]
    fn test_extractable_commitment_wrong_value() {
        let ec = ExtractableCommitment::setup();
        let value = b"test value";

        let (com, opening) = ec.commit(value);
        assert!(!ec.verify(&com, b"wrong value", &opening));
    }

    /// Perturbing the proof response must make verification fail.
    #[test]
    fn test_extractable_commitment_proof_soundness() {
        let ec = ExtractableCommitment::setup();
        let value = b"test value";

        let (com, mut opening) = ec.commit(value);

        opening.proof.s += Scalar::ONE;
        // `s * g` no longer equals `t + e * m * g`.
        assert!(!ec.verify(&com, value, &opening));
    }

    /// Trapdoor commitment and opening survive an encode/decode round trip.
    #[test]
    fn test_trapdoor_serialization() {
        let (commitment, _) = TrapdoorCommitment::setup();
        let value = b"test";

        let (com, opening) = commitment.commit(value);

        let com_bytes = crate::codec::encode(&com).unwrap();
        let opening_bytes = crate::codec::encode(&opening).unwrap();

        let com_de: TrapdoorCom = crate::codec::decode(&com_bytes).unwrap();
        let opening_de: TrapdoorOpening = crate::codec::decode(&opening_bytes).unwrap();

        assert!(commitment.verify(&com_de, value, &opening_de));
    }

    /// Vector commitment and opening survive an encode/decode round trip.
    #[test]
    fn test_vector_commitment_serialization() {
        let vc = VectorCommitment::new(10);
        let values = vec![b"value0".to_vec(), b"value1".to_vec()];

        let com = vc.commit(&values);
        let opening = vc.open(&values, 0).unwrap();

        let com_bytes = crate::codec::encode(&com).unwrap();
        let opening_bytes = crate::codec::encode(&opening).unwrap();

        let com_de: VectorCom = crate::codec::decode(&com_bytes).unwrap();
        let opening_de: VectorOpening = crate::codec::decode(&opening_bytes).unwrap();

        assert!(vc.verify(&com_de, &opening_de));
    }

    /// Extractable commitment and opening (with proof) survive an encode/decode round trip.
    #[test]
    fn test_extractable_serialization() {
        let ec = ExtractableCommitment::setup();
        let value = b"test";

        let (com, opening) = ec.commit(value);

        let com_bytes = crate::codec::encode(&com).unwrap();
        let opening_bytes = crate::codec::encode(&opening).unwrap();

        let com_de: ExtractableCom = crate::codec::decode(&com_bytes).unwrap();
        let opening_de: ExtractableOpening = crate::codec::decode(&opening_bytes).unwrap();

        assert!(ec.verify(&com_de, value, &opening_de));
    }

    /// Parameters built without a trapdoor still commit and verify honestly.
    #[test]
    fn test_trapdoor_without_trapdoor() {
        let commitment = TrapdoorCommitment::setup_without_trapdoor();
        let value = b"test value";

        let (com, opening) = commitment.commit(value);
        assert!(commitment.verify(&com, value, &opening));
    }

    /// A one-element vector can be opened at index 0.
    #[test]
    fn test_vector_commitment_single_element() {
        let vc = VectorCommitment::new(10);
        let values = vec![b"single".to_vec()];

        let com = vc.commit(&values);
        let opening = vc.open(&values, 0).unwrap();

        assert!(vc.verify(&com, &opening));
    }

    /// Every index of a 50-element vector (not a power of two) opens and verifies.
    #[test]
    fn test_vector_commitment_large() {
        let vc = VectorCommitment::new(100);
        let values: Vec<Vec<u8>> = (0..50)
            .map(|i| format!("value{}", i).into_bytes())
            .collect();

        let com = vc.commit(&values);

        for i in 0..values.len() {
            let opening = vc.open(&values, i).unwrap();
            assert!(vc.verify(&com, &opening));
        }
    }
}