entrenar/monitor/inference/collector/
hash_chain.rs1use super::super::path::DecisionPath;
4use super::super::trace::DecisionTrace;
5use super::traits::TraceCollector;
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct ChainEntry<P: DecisionPath> {
12 pub sequence: u64,
14 pub prev_hash: [u8; 32],
16 pub trace: DecisionTrace<P>,
18 pub hash: [u8; 32],
20}
21
/// Outcome of checking a hash chain for integrity.
///
/// `PartialEq`/`Eq` are derived so that results can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChainVerification {
    /// `true` when every entry passed all checks.
    pub valid: bool,

    /// Number of entries that passed before verification stopped; equals the
    /// chain length when the chain is valid.
    pub entries_verified: usize,

    /// Index of the first entry that failed a check, if any.
    pub first_break: Option<usize>,

    /// Human-readable description of the failure, if any.
    pub error: Option<String>,
}
34
35pub struct HashChainCollector<P: DecisionPath> {
56 pub(crate) entries: Vec<ChainEntry<P>>,
57 prev_hash: [u8; 32],
58 sequence: u64,
59}
60
61impl<P: DecisionPath + Serialize> HashChainCollector<P> {
62 pub fn new() -> Self {
64 Self {
65 entries: Vec::new(),
66 prev_hash: [0u8; 32], sequence: 0,
68 }
69 }
70
71 fn compute_hash(sequence: u64, prev_hash: &[u8; 32], trace: &DecisionTrace<P>) -> [u8; 32] {
73 let mut hasher = Sha256::new();
74 hasher.update(sequence.to_le_bytes());
75 hasher.update(prev_hash);
76 hasher.update(trace.to_bytes());
77 hasher.finalize().into()
78 }
79
80 pub fn verify_chain(&self) -> ChainVerification {
82 if self.entries.is_empty() {
83 return ChainVerification {
84 valid: true,
85 entries_verified: 0,
86 first_break: None,
87 error: None,
88 };
89 }
90
91 let mut prev_hash = [0u8; 32]; for (i, entry) in self.entries.iter().enumerate() {
94 if entry.sequence != i as u64 {
96 return ChainVerification {
97 valid: false,
98 entries_verified: i,
99 first_break: Some(i),
100 error: Some(format!(
101 "Sequence mismatch at index {}: expected {}, got {}",
102 i, i, entry.sequence
103 )),
104 };
105 }
106
107 if entry.prev_hash != prev_hash {
109 return ChainVerification {
110 valid: false,
111 entries_verified: i,
112 first_break: Some(i),
113 error: Some(format!("Previous hash mismatch at index {i}")),
114 };
115 }
116
117 let computed_hash = Self::compute_hash(entry.sequence, &prev_hash, &entry.trace);
119 if entry.hash != computed_hash {
120 return ChainVerification {
121 valid: false,
122 entries_verified: i,
123 first_break: Some(i),
124 error: Some(format!("Hash mismatch at index {i}")),
125 };
126 }
127
128 prev_hash = entry.hash;
129 }
130
131 ChainVerification {
132 valid: true,
133 entries_verified: self.entries.len(),
134 first_break: None,
135 error: None,
136 }
137 }
138
139 pub fn entries(&self) -> &[ChainEntry<P>] {
141 &self.entries
142 }
143
144 pub fn get(&self, sequence: u64) -> Option<&ChainEntry<P>> {
146 self.entries.get(sequence as usize)
147 }
148
149 pub fn latest_hash(&self) -> [u8; 32] {
151 self.entries.last().map_or([0u8; 32], |e| e.hash)
152 }
153
154 pub fn to_json(&self) -> serde_json::Result<String>
156 where
157 P: Serialize,
158 {
159 serde_json::to_string_pretty(&self.entries)
160 }
161}
162
163impl<P: DecisionPath + Serialize> TraceCollector<P> for HashChainCollector<P> {
164 fn record(&mut self, trace: DecisionTrace<P>) {
165 let hash = Self::compute_hash(self.sequence, &self.prev_hash, &trace);
166
167 let entry = ChainEntry { sequence: self.sequence, prev_hash: self.prev_hash, trace, hash };
168
169 self.prev_hash = hash;
170 self.sequence += 1;
171 self.entries.push(entry);
172 }
173
174 fn flush(&mut self) -> std::io::Result<()> {
175 Ok(())
177 }
178
179 fn len(&self) -> usize {
180 self.entries.len()
181 }
182}
183
184impl<P: DecisionPath + Serialize> Default for HashChainCollector<P> {
185 fn default() -> Self {
186 Self::new()
187 }
188}