1use serde::{Deserialize, Serialize};
31use sha2::{Digest, Sha256};
32use std::time::{SystemTime, UNIX_EPOCH};
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct DiamondProof {
44 pub version: u32,
46
47 pub pi: SnarkPi,
49
50 pub public_inputs: PublicInputs,
52
53 pub metadata: ProofMetadata,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct SnarkPi {
63 pub a: Vec<u8>,
65
66 pub b: Vec<u8>,
68
69 pub c: Vec<u8>,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct PublicInputs {
76 pub rules_hash: [u8; 32],
78
79 pub output_hash: [u8; 32],
81
82 pub timestamp: u64,
84
85 pub session_id: [u8; 32],
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct ProofMetadata {
92 pub system: ProvingSystem,
94
95 pub curve: Curve,
97
98 pub generation_time_us: u64,
100
101 pub constraint_count: usize,
103}
104
105#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
107pub enum ProvingSystem {
108 Groth16,
110 Plonk,
112 Bulletproofs,
114 Stark,
116}
117
118#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
120pub enum Curve {
121 Bn254,
123 Bls12_381,
125 Pasta,
127}
128
129#[derive(Debug, Clone)]
138pub struct SnarkCircuit {
139 rule_constraints: Vec<CircuitConstraint>,
141
142 num_public: usize,
144
145 num_private: usize,
147
148 num_constraints: usize,
150}
151
152#[derive(Debug, Clone)]
154pub struct CircuitConstraint {
155 pub constraint_type: ConstraintType,
157
158 pub left: Vec<(usize, i64)>,
160
161 pub right: Vec<(usize, i64)>,
163
164 pub output: Vec<(usize, i64)>,
166}
167
/// Category of a `CircuitConstraint`.
///
/// Only `Boolean` is produced by `SnarkCircuit::rule_to_constraints`; the
/// exact semantics of the other variants are not defined in this module.
#[derive(Clone, Debug)]
pub enum ConstraintType {
    /// Multiplicative relation.
    Multiplication,
    /// Additive relation.
    Addition,
    /// Boolean relation (used for every rule constraint).
    Boolean,
    /// Range relation; the `u32` is presumably a bit width — TODO confirm.
    Range(u32),
    /// Named custom relation.
    Custom(String),
}
182
183impl SnarkCircuit {
184 pub fn from_rules(rules: &[String]) -> Self {
186 let mut constraints = Vec::new();
187
188 for (i, rule) in rules.iter().enumerate() {
189 let rule_constraints = Self::rule_to_constraints(rule, i);
191 constraints.extend(rule_constraints);
192 }
193
194 let num_constraints = constraints.len();
195
196 SnarkCircuit {
197 rule_constraints: constraints,
198 num_public: 3, num_private: num_constraints * 2, num_constraints,
201 }
202 }
203
204 fn rule_to_constraints(rule: &str, index: usize) -> Vec<CircuitConstraint> {
206 let _rule_hash = Sha256::digest(rule.as_bytes());
209
210 vec![CircuitConstraint {
213 constraint_type: ConstraintType::Boolean,
214 left: vec![(index, 1)],
215 right: vec![(index, -1), (0, 1)], output: vec![],
217 }]
218 }
219
220 pub fn stats(&self) -> CircuitStats {
222 CircuitStats {
223 num_constraints: self.num_constraints,
224 num_public_inputs: self.num_public,
225 num_private_inputs: self.num_private,
226 num_rules: self.rule_constraints.len(),
227 }
228 }
229}
230
231#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct CircuitStats {
234 pub num_constraints: usize,
235 pub num_public_inputs: usize,
236 pub num_private_inputs: usize,
237 pub num_rules: usize,
238}
239
240#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct ProvingKey {
250 pub id: [u8; 32],
252
253 pub key_material: Vec<u8>,
255
256 pub circuit_hash: [u8; 32],
258
259 pub created_at: u64,
261}
262
263#[derive(Debug, Clone, Serialize, Deserialize)]
267pub struct VerifyingKey {
268 pub id: [u8; 32],
270
271 pub key_material: Vec<u8>,
273
274 pub circuit_hash: [u8; 32],
276
277 pub alpha: Vec<u8>,
279
280 pub beta: Vec<u8>,
282
283 pub gamma: Vec<u8>,
285
286 pub delta: Vec<u8>,
288}
289
290pub struct ProofVerifier {
299 #[allow(dead_code)]
301 vk: VerifyingKey,
302
303 expected_rules_hash: [u8; 32],
305
306 max_age: u64,
308}
309
310impl ProofVerifier {
311 pub fn new(vk: VerifyingKey, rules: &[String], max_age: u64) -> Self {
313 let expected_rules_hash = Self::hash_rules(rules);
314
315 ProofVerifier {
316 vk,
317 expected_rules_hash,
318 max_age,
319 }
320 }
321
322 fn hash_rules(rules: &[String]) -> [u8; 32] {
324 let mut hasher = Sha256::new();
325 hasher.update(b"DIAMOND_RULES:");
326 for rule in rules {
327 hasher.update(rule.as_bytes());
328 hasher.update(b"\x00");
329 }
330 hasher.finalize().into()
331 }
332
333 pub fn verify(&self, proof: &DiamondProof) -> Result<bool, VerificationError> {
339 if proof.version != 1 {
341 return Err(VerificationError::UnsupportedVersion(proof.version));
342 }
343
344 if proof.public_inputs.rules_hash != self.expected_rules_hash {
346 return Err(VerificationError::RulesMismatch);
347 }
348
349 let now = SystemTime::now()
351 .duration_since(UNIX_EPOCH)
352 .unwrap()
353 .as_secs();
354
355 if now - proof.public_inputs.timestamp > self.max_age {
356 return Err(VerificationError::ProofExpired {
357 age: now - proof.public_inputs.timestamp,
358 max: self.max_age,
359 });
360 }
361
362 self.verify_pairing(&proof.pi, &proof.public_inputs)?;
365
366 Ok(true)
367 }
368
369 fn verify_pairing(
376 &self,
377 pi: &SnarkPi,
378 public_inputs: &PublicInputs,
379 ) -> Result<(), VerificationError> {
380 if pi.a.is_empty() || pi.a.iter().all(|&b| b == 0) {
386 return Err(VerificationError::InvalidProofStructure(
387 "A element is zero or empty".into(),
388 ));
389 }
390
391 if pi.b.is_empty() || pi.b.iter().all(|&b| b == 0) {
392 return Err(VerificationError::InvalidProofStructure(
393 "B element is zero or empty".into(),
394 ));
395 }
396
397 if pi.c.is_empty() || pi.c.iter().all(|&b| b == 0) {
398 return Err(VerificationError::InvalidProofStructure(
399 "C element is zero or empty".into(),
400 ));
401 }
402
403 if public_inputs.output_hash.iter().all(|&b| b == 0) {
405 return Err(VerificationError::InvalidPublicInput(
406 "Output hash is zero".into(),
407 ));
408 }
409
410 Ok(())
414 }
415
416 pub fn batch_verify(&self, proofs: &[DiamondProof]) -> Result<bool, VerificationError> {
418 for proof in proofs {
422 self.verify(proof)?;
423 }
424
425 Ok(true)
426 }
427}
428
429#[derive(Debug, Clone, Serialize, Deserialize)]
431pub enum VerificationError {
432 UnsupportedVersion(u32),
433 RulesMismatch,
434 ProofExpired { age: u64, max: u64 },
435 InvalidProofStructure(String),
436 InvalidPublicInput(String),
437 PairingCheckFailed,
438}
439
440impl std::fmt::Display for VerificationError {
441 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
442 match self {
443 Self::UnsupportedVersion(v) => write!(f, "Unsupported proof version: {}", v),
444 Self::RulesMismatch => write!(f, "Rules hash mismatch"),
445 Self::ProofExpired { age, max } => {
446 write!(f, "Proof expired: age {} > max {}", age, max)
447 }
448 Self::InvalidProofStructure(s) => write!(f, "Invalid proof structure: {}", s),
449 Self::InvalidPublicInput(s) => write!(f, "Invalid public input: {}", s),
450 Self::PairingCheckFailed => write!(f, "Pairing check failed"),
451 }
452 }
453}
454
455impl std::error::Error for VerificationError {}
456
457pub struct ProofGenerator {
465 pk: ProvingKey,
467
468 circuit: SnarkCircuit,
470
471 session_id: [u8; 32],
473}
474
475impl ProofGenerator {
476 pub fn new(pk: ProvingKey, rules: &[String]) -> Self {
478 let circuit = SnarkCircuit::from_rules(rules);
479
480 let mut session_id = [0u8; 32];
481 rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut session_id);
482
483 ProofGenerator {
484 pk,
485 circuit,
486 session_id,
487 }
488 }
489
490 pub fn prove(&self, output: &str, rules_hash: [u8; 32]) -> DiamondProof {
492 let start = std::time::Instant::now();
493
494 let output_hash: [u8; 32] = Sha256::digest(output.as_bytes()).into();
496
497 let timestamp = SystemTime::now()
499 .duration_since(UNIX_EPOCH)
500 .unwrap()
501 .as_secs();
502
503 let pi = self.generate_snark(output, &rules_hash);
506
507 let generation_time = start.elapsed().as_micros() as u64;
508
509 DiamondProof {
510 version: 1,
511 pi,
512 public_inputs: PublicInputs {
513 rules_hash,
514 output_hash,
515 timestamp,
516 session_id: self.session_id,
517 },
518 metadata: ProofMetadata {
519 system: ProvingSystem::Groth16,
520 curve: Curve::Bn254,
521 generation_time_us: generation_time,
522 constraint_count: self.circuit.num_constraints,
523 },
524 }
525 }
526
527 fn generate_snark(&self, output: &str, rules_hash: &[u8; 32]) -> SnarkPi {
529 let mut hasher = Sha256::new();
536 hasher.update(b"SNARK_A:");
537 hasher.update(output.as_bytes());
538 hasher.update(rules_hash);
539 hasher.update(&self.pk.key_material);
540 let a_hash = hasher.finalize();
541
542 let mut a = vec![0u8; 64];
543 a[..32].copy_from_slice(&a_hash);
544 a[32..].copy_from_slice(&a_hash);
545
546 let mut hasher = Sha256::new();
547 hasher.update(b"SNARK_B:");
548 hasher.update(&a);
549 let b_hash = hasher.finalize();
550
551 let mut b = vec![0u8; 128];
552 for i in 0..4 {
553 b[i * 32..(i + 1) * 32].copy_from_slice(&b_hash);
554 }
555
556 let mut hasher = Sha256::new();
557 hasher.update(b"SNARK_C:");
558 hasher.update(&b);
559 let c_hash = hasher.finalize();
560
561 let mut c = vec![0u8; 64];
562 c[..32].copy_from_slice(&c_hash);
563 c[32..].copy_from_slice(&c_hash);
564
565 SnarkPi { a, b, c }
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a proving/verifying key pair sharing random key material.
    fn create_test_keys() -> (ProvingKey, VerifyingKey) {
        let mut key_material = vec![0u8; 256];
        rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut key_material);

        let circuit_hash: [u8; 32] = Sha256::digest(b"test_circuit").into();

        let pk = ProvingKey {
            id: [1u8; 32],
            key_material: key_material.clone(),
            circuit_hash,
            created_at: 0,
        };

        let vk = VerifyingKey {
            id: [1u8; 32],
            key_material,
            circuit_hash,
            alpha: vec![1u8; 64],
            beta: vec![2u8; 128],
            gamma: vec![3u8; 128],
            delta: vec![4u8; 128],
        };

        (pk, vk)
    }

    /// Generator and verifier over the same rules with a 300-second window.
    fn setup(rules: &[String]) -> (ProofGenerator, ProofVerifier) {
        let (pk, vk) = create_test_keys();
        (
            ProofGenerator::new(pk, rules),
            ProofVerifier::new(vk, rules, 300),
        )
    }

    #[test]
    fn test_circuit_from_rules() {
        let rules = vec!["Do no harm".to_string(), "Respect privacy".to_string()];

        let circuit = SnarkCircuit::from_rules(&rules);
        let stats = circuit.stats();

        assert!(stats.num_constraints > 0);
        assert_eq!(stats.num_public_inputs, 3);
    }

    #[test]
    fn test_proof_generation() {
        let rules = vec!["Do no harm".to_string()];
        let (pk, _vk) = create_test_keys();

        let generator = ProofGenerator::new(pk, &rules);

        let rules_hash = ProofVerifier::hash_rules(&rules);
        let proof = generator.prove("Hello, world!", rules_hash);

        assert_eq!(proof.version, 1);
        assert!(!proof.pi.a.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_proof_verification() {
        let rules = vec!["Do no harm".to_string()];
        let (pk, vk) = create_test_keys();

        let generator = ProofGenerator::new(pk, &rules);
        let verifier = ProofVerifier::new(vk, &rules, 300);

        let rules_hash = ProofVerifier::hash_rules(&rules);
        let proof = generator.prove("Test output", rules_hash);

        let result = verifier.verify(&proof);
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    #[test]
    fn test_wrong_rules_fails() {
        let rules1 = vec!["Rule A".to_string()];
        let rules2 = vec!["Rule B".to_string()];
        let (pk, vk) = create_test_keys();

        let generator = ProofGenerator::new(pk, &rules1);
        let verifier = ProofVerifier::new(vk, &rules2, 300);

        let rules_hash = ProofVerifier::hash_rules(&rules1);
        let proof = generator.prove("Test", rules_hash);

        let result = verifier.verify(&proof);
        assert!(matches!(result, Err(VerificationError::RulesMismatch)));
    }

    #[test]
    fn test_unsupported_version_fails() {
        let rules = vec!["Do no harm".to_string()];
        let (generator, verifier) = setup(&rules);

        let mut proof = generator.prove("Test", ProofVerifier::hash_rules(&rules));
        proof.version = 2;

        assert!(matches!(
            verifier.verify(&proof),
            Err(VerificationError::UnsupportedVersion(2))
        ));
    }

    #[test]
    fn test_expired_proof_fails() {
        let rules = vec!["Do no harm".to_string()];
        let (generator, verifier) = setup(&rules);

        let mut proof = generator.prove("Test", ProofVerifier::hash_rules(&rules));
        // Backdate well past the 300-second window.
        proof.public_inputs.timestamp -= 1_000;

        assert!(matches!(
            verifier.verify(&proof),
            Err(VerificationError::ProofExpired { max: 300, .. })
        ));
    }

    #[test]
    fn test_empty_proof_element_fails() {
        let rules = vec!["Do no harm".to_string()];
        let (generator, verifier) = setup(&rules);

        let mut proof = generator.prove("Test", ProofVerifier::hash_rules(&rules));
        proof.pi.b.clear();

        assert!(matches!(
            verifier.verify(&proof),
            Err(VerificationError::InvalidProofStructure(_))
        ));
    }

    #[test]
    fn test_rules_hash_is_order_sensitive() {
        let ab = vec!["a".to_string(), "b".to_string()];
        let ba = vec!["b".to_string(), "a".to_string()];

        assert_eq!(ProofVerifier::hash_rules(&ab), ProofVerifier::hash_rules(&ab));
        assert_ne!(ProofVerifier::hash_rules(&ab), ProofVerifier::hash_rules(&ba));
    }

    #[test]
    fn test_batch_verify() {
        let rules = vec!["Do no harm".to_string()];
        let (generator, verifier) = setup(&rules);
        let rules_hash = ProofVerifier::hash_rules(&rules);

        let first = generator.prove("one", rules_hash);
        let second = generator.prove("two", rules_hash);
        assert!(verifier.batch_verify(&[first.clone(), second]).unwrap());
        assert!(verifier.batch_verify(&[]).unwrap());

        let mut bad = generator.prove("three", rules_hash);
        bad.version = 7;
        assert!(matches!(
            verifier.batch_verify(&[first, bad]),
            Err(VerificationError::UnsupportedVersion(7))
        ));
    }
}