1use serde::{Deserialize, Serialize};
30use sha2::{Digest, Sha256};
31use std::time::{SystemTime, UNIX_EPOCH};
32
33use super::zk_snark::DiamondProof;
34
/// An append-only, hash-linked sequence of [`DiamondProof`]s belonging to one session.
///
/// Each link commits to its predecessor through `prev_hash`, and `merkle_root`
/// summarises every link's `proof_hash` so a single link can be proven with a
/// [`MerkleProof`]. Field order and names feed the serde derive, so changing them
/// changes the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProofChain {
    /// Random 32-byte identifier filled from the OS RNG in `ProofChain::new`.
    pub session_id: [u8; 32],

    /// Links in append order; `add_proof` gives the link at index `i` the sequence `i`.
    pub proofs: Vec<ChainedProof>,

    /// Merkle root over the links' `proof_hash` values; all zeros while `proofs` is empty.
    pub merkle_root: [u8; 32],

    /// Creation time, in seconds since the Unix epoch.
    pub started_at: u64,

    /// Timestamp of the most recently appended link (creation time if none), Unix seconds.
    pub last_updated: u64,

    /// Lifecycle state; `add_proof` only accepts proofs while this is `Active`.
    pub status: ChainStatus,
}
60
/// One link of a [`ProofChain`]: a proof plus the hashes that tie it to its neighbours.
///
/// Field order and names feed the serde derive, so changing them changes the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainedProof {
    /// Zero-based position of this link in the chain.
    pub sequence: u64,

    /// The proof being recorded.
    pub proof: DiamondProof,

    /// `proof_hash` of the previous link, or all zeros for the first link.
    pub prev_hash: [u8; 32],

    /// SHA-256 link hash computed by `ProofChain::compute_proof_hash`.
    pub proof_hash: [u8; 32],

    /// SHA-256 of the input string passed to `add_proof` (the string itself is not kept).
    pub input_hash: [u8; 32],

    /// SHA-256 of the output string passed to `add_proof` (the string itself is not kept).
    pub output_hash: [u8; 32],

    /// Time the link was appended, in seconds since the Unix epoch.
    pub timestamp: u64,
}
85
/// Compact summary of a chain, produced by `ProofChain::finalize`.
///
/// Field order and names feed the serde derive, so changing them changes the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionProof {
    /// The chain's `session_id`.
    pub session_id: [u8; 32],

    /// The chain's Merkle root at the time of finalization.
    pub merkle_root: [u8; 32],

    /// Number of links in the chain.
    pub proof_count: u64,

    /// Chain creation time, Unix seconds.
    pub started_at: u64,
    /// Time `finalize` was called, Unix seconds.
    pub ended_at: u64,

    /// `rules_hash` from the first link's public inputs, or all zeros for an empty chain.
    pub rules_hash: [u8; 32],

    /// SHA-256 over `SESSION_SIG:`, session id, Merkle root, proof count, start and end
    /// times. NOTE(review): this is an unkeyed digest that anyone can recompute; it is
    /// not a cryptographic signature, and it does not cover `rules_hash`.
    pub signature: Vec<u8>,
}
108
/// Aggregate root over many sessions.
///
/// NOTE(review): nothing in the code visible here constructs this type; the field
/// meanings below are inferred from their names and should be confirmed against the
/// producer. Field order and names feed the serde derive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalProofRoot {
    /// Presumably a root hash over the aggregated sessions — TODO confirm.
    pub root_hash: [u8; 32],

    /// Presumably the number of sessions aggregated — TODO confirm.
    pub session_count: u64,

    /// Presumably the start of the covered time range (Unix seconds?) — TODO confirm.
    pub from: u64,
    /// Presumably the end of the covered time range (Unix seconds?) — TODO confirm.
    pub to: u64,

    /// Opaque signature bytes; scheme not visible here — TODO confirm.
    pub signature: Vec<u8>,
}
125
/// Lifecycle state of a [`ProofChain`]. Variant order feeds the serde derive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChainStatus {
    /// Set by `ProofChain::new`; the only state in which `add_proof` accepts proofs.
    Active,

    /// Set by `ProofChain::finalize`; no further proofs can be added.
    Finalized,

    /// Marks a chain whose integrity is in doubt. Presumably assigned by callers after a
    /// failed `verify_integrity` — TODO confirm; `add_proof` rejects proofs in this state.
    Corrupted,
}
138
139impl ProofChain {
144 pub fn new() -> Self {
146 let mut session_id = [0u8; 32];
147 rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut session_id);
148
149 let now = SystemTime::now()
150 .duration_since(UNIX_EPOCH)
151 .unwrap()
152 .as_secs();
153
154 ProofChain {
155 session_id,
156 proofs: Vec::new(),
157 merkle_root: [0u8; 32],
158 started_at: now,
159 last_updated: now,
160 status: ChainStatus::Active,
161 }
162 }
163
164 pub fn add_proof(
166 &mut self,
167 proof: DiamondProof,
168 input: &str,
169 output: &str,
170 ) -> Result<ChainedProof, ChainError> {
171 if self.status != ChainStatus::Active {
172 return Err(ChainError::ChainNotActive);
173 }
174
175 let sequence = self.proofs.len() as u64;
176
177 let prev_hash = self
179 .proofs
180 .last()
181 .map(|p| p.proof_hash)
182 .unwrap_or([0u8; 32]);
183
184 let input_hash: [u8; 32] = Sha256::digest(input.as_bytes()).into();
186 let output_hash: [u8; 32] = Sha256::digest(output.as_bytes()).into();
187
188 let timestamp = SystemTime::now()
189 .duration_since(UNIX_EPOCH)
190 .unwrap()
191 .as_secs();
192
193 let proof_hash = Self::compute_proof_hash(
195 sequence,
196 &prev_hash,
197 &proof,
198 &input_hash,
199 &output_hash,
200 timestamp,
201 );
202
203 let chained = ChainedProof {
204 sequence,
205 proof,
206 prev_hash,
207 proof_hash,
208 input_hash,
209 output_hash,
210 timestamp,
211 };
212
213 self.proofs.push(chained.clone());
214 self.last_updated = timestamp;
215
216 self.merkle_root = self.compute_merkle_root();
218
219 Ok(chained)
220 }
221
222 pub fn finalize(&mut self) -> SessionProof {
224 self.status = ChainStatus::Finalized;
225
226 let ended_at = SystemTime::now()
227 .duration_since(UNIX_EPOCH)
228 .unwrap()
229 .as_secs();
230
231 let rules_hash = self
232 .proofs
233 .first()
234 .map(|p| p.proof.public_inputs.rules_hash)
235 .unwrap_or([0u8; 32]);
236
237 let signature = self.generate_session_signature(ended_at);
239
240 SessionProof {
241 session_id: self.session_id,
242 merkle_root: self.merkle_root,
243 proof_count: self.proofs.len() as u64,
244 started_at: self.started_at,
245 ended_at,
246 rules_hash,
247 signature,
248 }
249 }
250
251 pub fn verify_integrity(&self) -> Result<bool, ChainError> {
253 if self.proofs.is_empty() {
254 return Ok(true);
255 }
256
257 for (i, proof) in self.proofs.iter().enumerate() {
259 if proof.sequence != i as u64 {
261 return Err(ChainError::SequenceMismatch {
262 expected: i as u64,
263 got: proof.sequence,
264 });
265 }
266
267 if i == 0 {
269 if proof.prev_hash != [0u8; 32] {
270 return Err(ChainError::InvalidFirstLink);
271 }
272 } else {
273 let expected_prev = self.proofs[i - 1].proof_hash;
274 if proof.prev_hash != expected_prev {
275 return Err(ChainError::BrokenLink {
276 at_sequence: proof.sequence,
277 });
278 }
279 }
280
281 let computed_hash = Self::compute_proof_hash(
283 proof.sequence,
284 &proof.prev_hash,
285 &proof.proof,
286 &proof.input_hash,
287 &proof.output_hash,
288 proof.timestamp,
289 );
290
291 if proof.proof_hash != computed_hash {
292 return Err(ChainError::HashMismatch {
293 at_sequence: proof.sequence,
294 });
295 }
296 }
297
298 let computed_root = self.compute_merkle_root();
300 if computed_root != self.merkle_root {
301 return Err(ChainError::MerkleRootMismatch);
302 }
303
304 Ok(true)
305 }
306
307 pub fn get_proof(&self, sequence: u64) -> Option<&ChainedProof> {
309 self.proofs.get(sequence as usize)
310 }
311
312 pub fn get_merkle_proof(&self, sequence: u64) -> Option<MerkleProof> {
314 if sequence >= self.proofs.len() as u64 {
315 return None;
316 }
317
318 let path = self.compute_merkle_path(sequence as usize);
319
320 Some(MerkleProof {
321 sequence,
322 proof_hash: self.proofs[sequence as usize].proof_hash,
323 path,
324 root: self.merkle_root,
325 })
326 }
327
328 fn compute_proof_hash(
333 sequence: u64,
334 prev_hash: &[u8; 32],
335 proof: &DiamondProof,
336 input_hash: &[u8; 32],
337 output_hash: &[u8; 32],
338 timestamp: u64,
339 ) -> [u8; 32] {
340 let mut hasher = Sha256::new();
341 hasher.update(b"CHAIN_PROOF:");
342 hasher.update(sequence.to_le_bytes());
343 hasher.update(prev_hash);
344 hasher.update(proof.public_inputs.rules_hash);
345 hasher.update(proof.public_inputs.output_hash);
346 hasher.update(input_hash);
347 hasher.update(output_hash);
348 hasher.update(timestamp.to_le_bytes());
349 hasher.finalize().into()
350 }
351
352 fn compute_merkle_root(&self) -> [u8; 32] {
353 if self.proofs.is_empty() {
354 return [0u8; 32];
355 }
356
357 let mut level: Vec<[u8; 32]> = self.proofs.iter().map(|p| p.proof_hash).collect();
358
359 while level.len() > 1 {
360 let mut next_level = Vec::new();
361
362 for chunk in level.chunks(2) {
363 let mut hasher = Sha256::new();
364 hasher.update(b"MERKLE:");
365 hasher.update(chunk[0]);
366 if chunk.len() > 1 {
367 hasher.update(chunk[1]);
368 } else {
369 hasher.update(chunk[0]); }
371 next_level.push(hasher.finalize().into());
372 }
373
374 level = next_level;
375 }
376
377 level[0]
378 }
379
380 fn compute_merkle_path(&self, index: usize) -> Vec<MerklePathNode> {
381 let mut path = Vec::new();
382 let mut level: Vec<[u8; 32]> = self.proofs.iter().map(|p| p.proof_hash).collect();
383 let mut idx = index;
384
385 while level.len() > 1 {
386 let sibling_idx = if idx.is_multiple_of(2) {
387 idx + 1
388 } else {
389 idx - 1
390 };
391 let sibling = if sibling_idx < level.len() {
392 level[sibling_idx]
393 } else {
394 level[idx] };
396
397 path.push(MerklePathNode {
398 hash: sibling,
399 is_left: idx % 2 == 1,
400 });
401
402 let mut next_level = Vec::new();
404 for chunk in level.chunks(2) {
405 let mut hasher = Sha256::new();
406 hasher.update(b"MERKLE:");
407 hasher.update(chunk[0]);
408 if chunk.len() > 1 {
409 hasher.update(chunk[1]);
410 } else {
411 hasher.update(chunk[0]);
412 }
413 next_level.push(hasher.finalize().into());
414 }
415
416 level = next_level;
417 idx /= 2;
418 }
419
420 path
421 }
422
423 fn generate_session_signature(&self, ended_at: u64) -> Vec<u8> {
424 let mut hasher = Sha256::new();
426 hasher.update(b"SESSION_SIG:");
427 hasher.update(self.session_id);
428 hasher.update(self.merkle_root);
429 hasher.update((self.proofs.len() as u64).to_le_bytes());
430 hasher.update(self.started_at.to_le_bytes());
431 hasher.update(ended_at.to_le_bytes());
432 hasher.finalize().to_vec()
433 }
434}
435
436impl Default for ProofChain {
437 fn default() -> Self {
438 Self::new()
439 }
440}
441
/// Inclusion proof that one link's `proof_hash` is a leaf under `root`.
///
/// Built by `ProofChain::get_merkle_proof`. Field order and names feed the serde
/// derive, so changing them changes the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerkleProof {
    /// Sequence number (leaf index) of the proven link.
    pub sequence: u64,

    /// The proven link's `proof_hash`, i.e. the leaf value.
    pub proof_hash: [u8; 32],

    /// Sibling hashes ordered from the leaf level upwards.
    pub path: Vec<MerklePathNode>,

    /// Root the path should reproduce; copied from the chain's stored `merkle_root`.
    pub root: [u8; 32],
}
457
/// One step of a [`MerkleProof`] path. Field order and names feed the serde derive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerklePathNode {
    /// Sibling hash at this level (the node itself when it had no sibling).
    pub hash: [u8; 32],

    /// `true` when `hash` is the left operand of the pairing hash, i.e. the node being
    /// proven is the right child at this level.
    pub is_left: bool,
}
467
468impl MerkleProof {
469 pub fn verify(&self) -> bool {
471 let mut current = self.proof_hash;
472
473 for node in &self.path {
474 let mut hasher = Sha256::new();
475 hasher.update(b"MERKLE:");
476
477 if node.is_left {
478 hasher.update(node.hash);
479 hasher.update(current);
480 } else {
481 hasher.update(current);
482 hasher.update(node.hash);
483 }
484
485 current = hasher.finalize().into();
486 }
487
488 current == self.root
489 }
490}
491
492#[derive(Debug, Clone, Serialize, Deserialize)]
494pub enum ChainError {
495 ChainNotActive,
496 SequenceMismatch { expected: u64, got: u64 },
497 InvalidFirstLink,
498 BrokenLink { at_sequence: u64 },
499 HashMismatch { at_sequence: u64 },
500 MerkleRootMismatch,
501}
502
503impl std::fmt::Display for ChainError {
504 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
505 match self {
506 Self::ChainNotActive => write!(f, "Chain is not active"),
507 Self::SequenceMismatch { expected, got } => {
508 write!(f, "Sequence mismatch: expected {}, got {}", expected, got)
509 }
510 Self::InvalidFirstLink => write!(f, "First link has invalid prev_hash"),
511 Self::BrokenLink { at_sequence } => {
512 write!(f, "Broken link at sequence {}", at_sequence)
513 }
514 Self::HashMismatch { at_sequence } => {
515 write!(f, "Hash mismatch at sequence {}", at_sequence)
516 }
517 Self::MerkleRootMismatch => write!(f, "Merkle root mismatch"),
518 }
519 }
520}
521
522impl std::error::Error for ChainError {}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diamond::zk_snark::{Curve, ProofMetadata, ProvingSystem, PublicInputs, SnarkPi};

    /// Deterministic dummy proof whose bytes and timestamp vary with `seq`.
    fn create_test_proof(seq: u64) -> DiamondProof {
        DiamondProof {
            version: 1,
            pi: SnarkPi {
                a: vec![seq as u8; 64],
                b: vec![seq as u8; 128],
                c: vec![seq as u8; 64],
            },
            public_inputs: PublicInputs {
                rules_hash: [1u8; 32],
                output_hash: [2u8; 32],
                timestamp: 12345 + seq,
                session_id: [3u8; 32],
            },
            metadata: ProofMetadata {
                system: ProvingSystem::Groth16,
                curve: Curve::Bn254,
                generation_time_us: 1000,
                constraint_count: 100,
            },
        }
    }

    /// Builds an active chain holding `len` links with distinct inputs/outputs.
    fn build_chain(len: u64) -> ProofChain {
        let mut chain = ProofChain::new();
        for i in 0..len {
            chain
                .add_proof(
                    create_test_proof(i),
                    &format!("input{i}"),
                    &format!("output{i}"),
                )
                .expect("an active chain accepts proofs");
        }
        chain
    }

    #[test]
    fn test_chain_creation() {
        let chain = ProofChain::new();
        assert_eq!(chain.status, ChainStatus::Active);
        assert!(chain.proofs.is_empty());
    }

    #[test]
    fn test_empty_chain_is_consistent() {
        let chain = ProofChain::new();
        assert!(matches!(chain.verify_integrity(), Ok(true)));
        assert_eq!(chain.merkle_root, [0u8; 32]);
        assert!(chain.get_proof(0).is_none());
        assert!(chain.get_merkle_proof(0).is_none());
    }

    #[test]
    fn test_add_proof() {
        let mut chain = ProofChain::new();
        let proof = create_test_proof(0);

        let result = chain.add_proof(proof, "input", "output");
        assert!(result.is_ok());
        assert_eq!(chain.proofs.len(), 1);
    }

    #[test]
    fn test_links_reference_previous_hash() {
        let chain = build_chain(4);
        assert_eq!(chain.proofs[0].prev_hash, [0u8; 32]);
        for pair in chain.proofs.windows(2) {
            assert_eq!(pair[1].prev_hash, pair[0].proof_hash);
        }
        for (i, link) in chain.proofs.iter().enumerate() {
            assert_eq!(link.sequence, i as u64);
        }
    }

    #[test]
    fn test_get_proof() {
        let chain = build_chain(3);
        assert_eq!(chain.get_proof(1).map(|p| p.sequence), Some(1));
        assert!(chain.get_proof(3).is_none());
    }

    #[test]
    fn test_chain_integrity() {
        let chain = build_chain(5);

        let result = chain.verify_integrity();
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    #[test]
    fn test_sequence_mismatch_detected() {
        let mut chain = build_chain(3);
        chain.proofs[1].sequence = 7;
        assert!(matches!(
            chain.verify_integrity(),
            Err(ChainError::SequenceMismatch {
                expected: 1,
                got: 7
            })
        ));
    }

    #[test]
    fn test_nonzero_first_link_detected() {
        let mut chain = build_chain(2);
        chain.proofs[0].prev_hash = [1u8; 32];
        assert!(matches!(
            chain.verify_integrity(),
            Err(ChainError::InvalidFirstLink)
        ));
    }

    #[test]
    fn test_broken_link_detected() {
        let mut chain = build_chain(4);
        chain.proofs[3].prev_hash = [0xBB; 32];
        assert!(matches!(
            chain.verify_integrity(),
            Err(ChainError::BrokenLink { at_sequence: 3 })
        ));
    }

    #[test]
    fn test_tampered_payload_detected() {
        let mut chain = build_chain(4);
        chain.proofs[2].input_hash = [0xAA; 32];
        assert!(matches!(
            chain.verify_integrity(),
            Err(ChainError::HashMismatch { at_sequence: 2 })
        ));
    }

    #[test]
    fn test_stale_merkle_root_detected() {
        let mut chain = build_chain(3);
        chain.merkle_root = [0xCC; 32];
        assert!(matches!(
            chain.verify_integrity(),
            Err(ChainError::MerkleRootMismatch)
        ));
    }

    #[test]
    fn test_finalize() {
        let mut chain = ProofChain::new();
        let proof = create_test_proof(0);
        chain.add_proof(proof, "in", "out").unwrap();

        let session = chain.finalize();

        assert_eq!(chain.status, ChainStatus::Finalized);
        assert_eq!(session.proof_count, 1);
        assert_eq!(session.session_id, chain.session_id);
        assert_eq!(session.merkle_root, chain.merkle_root);
        assert_eq!(session.rules_hash, [1u8; 32]);
        assert_eq!(session.signature.len(), 32);
    }

    #[test]
    fn test_cannot_add_after_finalize() {
        let mut chain = ProofChain::new();
        chain.finalize();

        let proof = create_test_proof(0);
        let result = chain.add_proof(proof, "in", "out");

        assert!(matches!(result, Err(ChainError::ChainNotActive)));
    }

    #[test]
    fn test_merkle_proof_verification() {
        let mut chain = ProofChain::new();

        for i in 0..4 {
            let proof = create_test_proof(i);
            chain
                .add_proof(proof, &format!("in{}", i), &format!("out{}", i))
                .unwrap();
        }

        let merkle_proof = chain.get_merkle_proof(2).unwrap();
        assert!(merkle_proof.verify());
    }

    #[test]
    fn test_merkle_proofs_verify_for_every_leaf() {
        // Odd sizes exercise the "unpaired node is hashed with itself" branch.
        for len in 1..=7u64 {
            let chain = build_chain(len);
            for seq in 0..len {
                let merkle_proof = chain.get_merkle_proof(seq).unwrap();
                assert_eq!(merkle_proof.root, chain.merkle_root);
                assert!(merkle_proof.verify(), "len={len} seq={seq}");
            }
            assert!(chain.get_merkle_proof(len).is_none());
        }
    }

    #[test]
    fn test_tampered_merkle_proof_rejected() {
        let chain = build_chain(4);
        let mut merkle_proof = chain.get_merkle_proof(1).unwrap();
        merkle_proof.proof_hash[0] ^= 0xFF;
        assert!(!merkle_proof.verify());
    }
}