Skip to main content

swink_agent_eval/
audit.rs

1//! Deterministic audit trails with hash chains for tamper detection.
2
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5
6use crate::types::Invocation;
7
8/// An [`Invocation`] wrapped with a hash chain for tamper detection.
9///
10/// Each turn is hashed individually, then the concatenated hashes are hashed
11/// again to produce a single `chain_hash`. Use [`verify`](Self::verify) to
12/// check integrity after deserialization or storage.
13///
14/// **Note:** `serde_json::Value` map key order is insertion-dependent, so audit
15/// trails verify same-instance integrity, not cross-process reproducibility.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct AuditedInvocation {
18    /// The original invocation trace.
19    pub invocation: Invocation,
20    /// Hex-encoded SHA-256 of each turn's canonical JSON.
21    pub turn_hashes: Vec<String>,
22    /// Hex-encoded SHA-256 of all `turn_hashes` concatenated.
23    pub chain_hash: String,
24}
25
26impl AuditedInvocation {
27    /// Wrap an [`Invocation`] with computed hash chain.
28    #[must_use]
29    pub fn from_invocation(invocation: Invocation) -> Self {
30        let turn_hashes: Vec<String> = invocation
31            .turns
32            .iter()
33            .map(|turn| {
34                let json = serde_json::to_string(turn).expect("TurnRecord is serializable");
35                hex_sha256(json.as_bytes())
36            })
37            .collect();
38
39        let chain_hash = compute_chain_hash(&turn_hashes);
40
41        Self {
42            invocation,
43            turn_hashes,
44            chain_hash,
45        }
46    }
47
48    /// Recompute all hashes and verify they match the stored values.
49    #[must_use]
50    pub fn verify(&self) -> bool {
51        if self.turn_hashes.len() != self.invocation.turns.len() {
52            return false;
53        }
54
55        for (turn, stored_hash) in self.invocation.turns.iter().zip(&self.turn_hashes) {
56            let json = serde_json::to_string(turn).expect("TurnRecord is serializable");
57            let computed = hex_sha256(json.as_bytes());
58            if &computed != stored_hash {
59                return false;
60            }
61        }
62
63        let computed_chain = compute_chain_hash(&self.turn_hashes);
64        computed_chain == self.chain_hash
65    }
66}
67
68fn hex_sha256(data: &[u8]) -> String {
69    let mut hasher = Sha256::new();
70    hasher.update(data);
71    let hash = hasher.finalize();
72    let mut out = String::with_capacity(hash.len() * 2);
73    for byte in hash {
74        use std::fmt::Write as _;
75        let _ = write!(&mut out, "{byte:02x}");
76    }
77    out
78}
79
80fn compute_chain_hash(turn_hashes: &[String]) -> String {
81    let concatenated: String = turn_hashes.concat();
82    hex_sha256(concatenated.as_bytes())
83}
84
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use swink_agent::{AssistantMessage, Cost, ModelSpec, StopReason, Usage};

    use super::*;
    use crate::types::TurnRecord;

    /// Hex SHA-256 of the empty input: the chain hash when there are no turns.
    const EMPTY_SHA256: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

    /// Build an invocation with `num_turns` turns that differ only in
    /// `turn_index`.
    fn minimal_invocation(num_turns: usize) -> Invocation {
        let turns = (0..num_turns)
            .map(|i| TurnRecord {
                turn_index: i,
                assistant_message: AssistantMessage {
                    content: vec![],
                    provider: "test".to_string(),
                    model_id: "test-model".to_string(),
                    usage: Usage::default(),
                    cost: Cost::default(),
                    stop_reason: StopReason::Stop,
                    error_message: None,
                    error_kind: None,
                    timestamp: 0,
                    cache_hint: None,
                },
                tool_calls: vec![],
                tool_results: vec![],
                duration: Duration::from_millis(10),
            })
            .collect();

        Invocation {
            turns,
            total_usage: Usage::default(),
            total_cost: Cost::default(),
            total_duration: Duration::from_millis(10 * num_turns as u64),
            final_response: None,
            stop_reason: StopReason::Stop,
            model: ModelSpec::new("test", "test-model"),
        }
    }

    #[test]
    fn roundtrip_verify() {
        let inv = minimal_invocation(3);
        let audited = AuditedInvocation::from_invocation(inv);

        assert!(audited.verify());
        assert_eq!(audited.turn_hashes.len(), 3);
        for hash in &audited.turn_hashes {
            assert_eq!(hash.len(), 64);
        }
        assert_eq!(audited.chain_hash.len(), 64);
    }

    /// Overwriting a stored per-turn hash must be detected.
    #[test]
    fn tampered_turn_fails_verify() {
        let inv = minimal_invocation(2);
        let mut audited = AuditedInvocation::from_invocation(inv);

        audited.turn_hashes[0] = "0".repeat(64);

        assert!(!audited.verify());
    }

    /// Changing the turn data itself, with the stored hashes left alone, must
    /// be detected.
    #[test]
    fn tampered_turn_content_fails_verify() {
        let inv = minimal_invocation(2);
        let mut audited = AuditedInvocation::from_invocation(inv);

        audited.invocation.turns[0].turn_index = 99;

        assert!(!audited.verify());
    }

    /// Overwriting only the chain hash must be detected even though every
    /// per-turn hash still matches.
    #[test]
    fn tampered_chain_hash_fails_verify() {
        let inv = minimal_invocation(2);
        let mut audited = AuditedInvocation::from_invocation(inv);

        audited.chain_hash = "0".repeat(64);

        assert!(!audited.verify());
    }

    /// Dropping a stored hash so the counts disagree must be detected.
    #[test]
    fn mismatched_hash_count_fails_verify() {
        let inv = minimal_invocation(2);
        let mut audited = AuditedInvocation::from_invocation(inv);

        audited.turn_hashes.pop();

        assert!(!audited.verify());
    }

    #[test]
    fn empty_invocation() {
        let inv = minimal_invocation(0);
        let audited = AuditedInvocation::from_invocation(inv);

        assert!(audited.verify());
        assert!(audited.turn_hashes.is_empty());
        assert_eq!(audited.chain_hash.len(), 64);
        // With no turns the chain hash is the digest of the empty input.
        assert_eq!(audited.chain_hash, EMPTY_SHA256);
    }
}