use super::super::path::DecisionPath;
use super::super::trace::DecisionTrace;
use super::traits::TraceCollector;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
/// One link in the hash chain: a recorded trace together with the digests
/// that bind it to the entry before it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChainEntry<P: DecisionPath> {
/// Position of this entry in the chain; the first recorded entry is `0`.
pub sequence: u64,
/// `hash` of the preceding entry, or all zeros for the first entry.
pub prev_hash: [u8; 32],
/// The decision trace sealed by this entry.
pub trace: DecisionTrace<P>,
/// SHA-256 digest over `sequence` (little-endian bytes), `prev_hash`,
/// and the byte form of `trace`, in that order.
pub hash: [u8; 32],
}
/// Outcome of checking a hash chain for integrity.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare an outcome
/// against an expected value directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChainVerification {
    /// `true` when every entry passed all checks (an empty chain is valid).
    pub valid: bool,
    /// Number of entries that passed before verification stopped; equals the
    /// chain length when `valid` is `true`.
    pub entries_verified: usize,
    /// Index of the first entry that failed a check, or `None` if none did.
    pub first_break: Option<usize>,
    /// Human-readable description of the failure, or `None` on success.
    pub error: Option<String>,
}
/// In-memory trace collector that links every recorded trace to the previous
/// one through a SHA-256 hash, so later modification of an entry is detectable.
pub struct HashChainCollector<P: DecisionPath> {
/// Recorded entries, in the order they were appended.
pub(crate) entries: Vec<ChainEntry<P>>,
/// Hash of the most recently appended entry; all zeros while the chain is empty.
prev_hash: [u8; 32],
/// Sequence number that will be assigned to the next recorded entry.
sequence: u64,
}
impl<P: DecisionPath + Serialize> HashChainCollector<P> {
pub fn new() -> Self {
Self {
entries: Vec::new(),
prev_hash: [0u8; 32], sequence: 0,
}
}
fn compute_hash(sequence: u64, prev_hash: &[u8; 32], trace: &DecisionTrace<P>) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(sequence.to_le_bytes());
hasher.update(prev_hash);
hasher.update(trace.to_bytes());
hasher.finalize().into()
}
pub fn verify_chain(&self) -> ChainVerification {
if self.entries.is_empty() {
return ChainVerification {
valid: true,
entries_verified: 0,
first_break: None,
error: None,
};
}
let mut prev_hash = [0u8; 32];
for (i, entry) in self.entries.iter().enumerate() {
if entry.sequence != i as u64 {
return ChainVerification {
valid: false,
entries_verified: i,
first_break: Some(i),
error: Some(format!(
"Sequence mismatch at index {}: expected {}, got {}",
i, i, entry.sequence
)),
};
}
if entry.prev_hash != prev_hash {
return ChainVerification {
valid: false,
entries_verified: i,
first_break: Some(i),
error: Some(format!("Previous hash mismatch at index {i}")),
};
}
let computed_hash = Self::compute_hash(entry.sequence, &prev_hash, &entry.trace);
if entry.hash != computed_hash {
return ChainVerification {
valid: false,
entries_verified: i,
first_break: Some(i),
error: Some(format!("Hash mismatch at index {i}")),
};
}
prev_hash = entry.hash;
}
ChainVerification {
valid: true,
entries_verified: self.entries.len(),
first_break: None,
error: None,
}
}
pub fn entries(&self) -> &[ChainEntry<P>] {
&self.entries
}
pub fn get(&self, sequence: u64) -> Option<&ChainEntry<P>> {
self.entries.get(sequence as usize)
}
pub fn latest_hash(&self) -> [u8; 32] {
self.entries.last().map_or([0u8; 32], |e| e.hash)
}
pub fn to_json(&self) -> serde_json::Result<String>
where
P: Serialize,
{
serde_json::to_string_pretty(&self.entries)
}
}
impl<P: DecisionPath + Serialize> TraceCollector<P> for HashChainCollector<P> {
    /// Seals `trace` into a new chain entry and appends it.
    ///
    /// The entry takes the collector's current sequence number and previous
    /// hash; afterwards the collector's previous hash becomes the new entry's
    /// hash and the sequence number advances by one.
    fn record(&mut self, trace: DecisionTrace<P>) {
        // `hash` is listed before `trace` so the digest is computed from a
        // borrow before the trace is moved into the entry.
        let link = ChainEntry {
            sequence: self.sequence,
            prev_hash: self.prev_hash,
            hash: Self::compute_hash(self.sequence, &self.prev_hash, &trace),
            trace,
        };
        self.prev_hash = link.hash;
        self.sequence += 1;
        self.entries.push(link);
    }

    /// Does nothing and always succeeds; entries are only held in memory.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }

    /// Returns the number of entries recorded so far.
    fn len(&self) -> usize {
        self.entries.len()
    }
}
impl<P: DecisionPath + Serialize> Default for HashChainCollector<P> {
fn default() -> Self {
Self::new()
}
}