Skip to main content

mnemo_core/
hash.rs

1use sha2::{Digest, Sha256};
2use subtle::ConstantTimeEq;
3
4/// Constant-time comparison for hash values to prevent timing side-channels.
5fn hashes_equal(a: &[u8], b: &[u8]) -> bool {
6    a.ct_eq(b).into()
7}
8
9pub fn compute_content_hash(content: &str, agent_id: &str, timestamp: &str) -> Vec<u8> {
10    let mut hasher = Sha256::new();
11    hasher.update(content.as_bytes());
12    hasher.update(agent_id.as_bytes());
13    hasher.update(timestamp.as_bytes());
14    hasher.finalize().to_vec()
15}
16
17pub fn compute_chain_hash(content_hash: &[u8], prev_hash: Option<&[u8]>) -> Vec<u8> {
18    let mut hasher = Sha256::new();
19    hasher.update(content_hash);
20    if let Some(prev) = prev_hash {
21        hasher.update(prev);
22    }
23    hasher.finalize().to_vec()
24}
25
26use serde::{Deserialize, Serialize};
27use uuid::Uuid;
28
29use crate::model::event::AgentEvent;
30use crate::model::memory::MemoryRecord;
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct ChainVerificationResult {
34    pub valid: bool,
35    pub total_records: usize,
36    pub verified_records: usize,
37    pub first_broken_at: Option<Uuid>,
38    pub error_message: Option<String>,
39}
40
41pub fn verify_chain(records: &[MemoryRecord]) -> ChainVerificationResult {
42    if records.is_empty() {
43        return ChainVerificationResult {
44            valid: true,
45            total_records: 0,
46            verified_records: 0,
47            first_broken_at: None,
48            error_message: None,
49        };
50    }
51
52    let mut verified = 0;
53
54    for (i, record) in records.iter().enumerate() {
55        // Verify content hash (constant-time comparison)
56        let expected_hash =
57            compute_content_hash(&record.content, &record.agent_id, &record.created_at);
58        if !hashes_equal(&expected_hash, &record.content_hash) {
59            return ChainVerificationResult {
60                valid: false,
61                total_records: records.len(),
62                verified_records: verified,
63                first_broken_at: Some(record.id),
64                error_message: Some(format!("content hash mismatch at record {}", record.id)),
65            };
66        }
67
68        // Verify chain linking (prev_hash)
69        if i > 0 {
70            let prev_record = &records[i - 1];
71            let expected_chain =
72                compute_chain_hash(&record.content_hash, Some(&prev_record.content_hash));
73            if let Some(ref prev_hash) = record.prev_hash
74                && !hashes_equal(prev_hash, &expected_chain)
75            {
76                return ChainVerificationResult {
77                    valid: false,
78                    total_records: records.len(),
79                    verified_records: verified,
80                    first_broken_at: Some(record.id),
81                    error_message: Some(format!("chain hash mismatch at record {}", record.id)),
82                };
83            }
84        }
85
86        verified += 1;
87    }
88
89    ChainVerificationResult {
90        valid: true,
91        total_records: records.len(),
92        verified_records: verified,
93        first_broken_at: None,
94        error_message: None,
95    }
96}
97
98/// Verify the integrity of an ordered list of agent events.
99/// Verifies that content_hash fields are non-empty and that
100/// prev_hash chain linkage between consecutive events is valid.
101/// Note: event content_hash is computed from the operation's source data
102/// (memory content or query string), not from the event payload JSON,
103/// so we verify it is present but do not recompute it.
104pub fn verify_event_chain(events: &[AgentEvent]) -> ChainVerificationResult {
105    if events.is_empty() {
106        return ChainVerificationResult {
107            valid: true,
108            total_records: 0,
109            verified_records: 0,
110            first_broken_at: None,
111            error_message: None,
112        };
113    }
114
115    let mut verified = 0;
116
117    for (i, event) in events.iter().enumerate() {
118        // Verify content hash is present (non-empty)
119        if event.content_hash.is_empty() {
120            return ChainVerificationResult {
121                valid: false,
122                total_records: events.len(),
123                verified_records: verified,
124                first_broken_at: Some(event.id),
125                error_message: Some(format!("event content hash is empty at {}", event.id)),
126            };
127        }
128
129        // Verify chain linking (prev_hash)
130        if i > 0 {
131            let prev_event = &events[i - 1];
132            let expected_chain =
133                compute_chain_hash(&event.content_hash, Some(&prev_event.content_hash));
134            if let Some(ref prev_hash) = event.prev_hash
135                && !hashes_equal(prev_hash, &expected_chain)
136            {
137                return ChainVerificationResult {
138                    valid: false,
139                    total_records: events.len(),
140                    verified_records: verified,
141                    first_broken_at: Some(event.id),
142                    error_message: Some(format!("event chain hash mismatch at {}", event.id)),
143                };
144            }
145        }
146
147        verified += 1;
148    }
149
150    ChainVerificationResult {
151        valid: true,
152        total_records: events.len(),
153        verified_records: verified,
154        first_broken_at: None,
155        error_message: None,
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `len` correctly hashed and linked records for `agent-1`.
    ///
    /// Record `i` has content `"memory content {i}"` and a `created_at` of
    /// January `i + 1`, 2025; the timestamp format only stays well-formed for
    /// `len <= 9`. Every record's `prev_hash` is the chain hash of its own
    /// content hash and the previous record's content hash (or no predecessor
    /// for the first record).
    fn build_chain(len: usize) -> Vec<MemoryRecord> {
        use crate::model::memory::*;

        let agent_id = "agent-1";
        let mut records: Vec<MemoryRecord> = Vec::with_capacity(len);

        for i in 0..len {
            let content = format!("memory content {i}");
            let timestamp = format!("2025-01-0{:01}T00:00:00Z", i + 1);
            let content_hash = compute_content_hash(&content, agent_id, &timestamp);
            // `last()` is `None` for the first record, which yields the
            // "no predecessor" chain hash.
            let prev_hash = Some(compute_chain_hash(
                &content_hash,
                records.last().map(|prev| prev.content_hash.as_slice()),
            ));

            records.push(MemoryRecord {
                id: uuid::Uuid::now_v7(),
                agent_id: agent_id.to_string(),
                content,
                memory_type: MemoryType::Episodic,
                scope: Scope::Private,
                importance: 0.5,
                tags: vec![],
                metadata: serde_json::json!({}),
                embedding: None,
                content_hash,
                prev_hash,
                source_type: SourceType::Agent,
                source_id: None,
                consolidation_state: ConsolidationState::Raw,
                access_count: 0,
                org_id: None,
                thread_id: None,
                created_at: timestamp,
                updated_at: "2025-01-01T00:00:00Z".to_string(),
                last_accessed_at: None,
                expires_at: None,
                deleted_at: None,
                decay_rate: None,
                created_by: None,
                version: 1,
                prev_version_id: None,
                quarantined: false,
                quarantine_reason: None,
                decay_function: None,
            });
        }

        records
    }

    #[test]
    fn test_content_hash_deterministic() {
        let h1 = compute_content_hash("hello", "agent-1", "2025-01-01T00:00:00Z");
        let h2 = compute_content_hash("hello", "agent-1", "2025-01-01T00:00:00Z");
        assert_eq!(h1, h2);
        assert_eq!(h1.len(), 32); // SHA-256 = 32 bytes
    }

    #[test]
    fn test_content_hash_differs_with_different_input() {
        let h1 = compute_content_hash("hello", "agent-1", "2025-01-01T00:00:00Z");
        let h2 = compute_content_hash("world", "agent-1", "2025-01-01T00:00:00Z");
        let h3 = compute_content_hash("hello", "agent-2", "2025-01-01T00:00:00Z");
        let h4 = compute_content_hash("hello", "agent-1", "2025-01-02T00:00:00Z");
        assert_ne!(h1, h2);
        assert_ne!(h1, h3);
        assert_ne!(h1, h4);
    }

    #[test]
    fn test_chain_hash_without_prev() {
        let content_hash = compute_content_hash("test", "a", "t");
        let chain = compute_chain_hash(&content_hash, None);
        assert_eq!(chain.len(), 32);
    }

    #[test]
    fn test_chain_hash_with_prev() {
        let h1 = compute_content_hash("first", "a", "t1");
        let h2 = compute_content_hash("second", "a", "t2");
        let chain1 = compute_chain_hash(&h1, None);
        let chain2 = compute_chain_hash(&h2, Some(&chain1));
        assert_ne!(chain1, chain2);
    }

    #[test]
    fn test_verify_chain_empty() {
        let result = verify_chain(&[]);
        assert!(result.valid);
        assert_eq!(result.total_records, 0);
        assert_eq!(result.verified_records, 0);
        assert!(result.first_broken_at.is_none());
        assert!(result.error_message.is_none());
    }

    #[test]
    fn test_verify_chain_valid() {
        let records = build_chain(5);

        let result = verify_chain(&records);
        assert!(result.valid);
        assert_eq!(result.total_records, 5);
        assert_eq!(result.verified_records, 5);
        assert!(result.first_broken_at.is_none());
    }

    #[test]
    fn test_verify_chain_tampered() {
        let mut records = build_chain(3);

        // Tamper with the second record's content (but not its hash)
        records[1].content = "TAMPERED CONTENT".to_string();

        let result = verify_chain(&records);
        assert!(!result.valid);
        assert_eq!(result.first_broken_at, Some(records[1].id));
        assert!(
            result
                .error_message
                .unwrap()
                .contains("content hash mismatch")
        );
    }

    #[test]
    fn test_verify_chain_broken_link() {
        let mut records = build_chain(3);

        // Leave the content intact but replace the third record's stored link.
        records[2].prev_hash = Some(vec![0u8; 32]);

        let result = verify_chain(&records);
        assert!(!result.valid);
        assert_eq!(result.total_records, 3);
        assert_eq!(result.verified_records, 2);
        assert_eq!(result.first_broken_at, Some(records[2].id));
        assert!(
            result
                .error_message
                .unwrap()
                .contains("chain hash mismatch")
        );
    }
}
322                .unwrap()
323                .contains("content hash mismatch")
324        );
325    }
326}