1#![cfg_attr(test, allow(clippy::expect_used, clippy::unwrap_used))]
20
21use std::collections::HashMap;
22use std::sync::{Mutex, MutexGuard};
23
24use serde::{Deserialize, Serialize};
25use sha2::{Digest, Sha256};
26
27#[derive(Debug, thiserror::Error)]
33pub enum SessionJournalError {
34 #[error("session journal lock poisoned")]
36 LockPoisoned,
37
38 #[error("hash chain integrity violation at entry {index}: expected {expected}, got {actual}")]
40 IntegrityViolation {
41 index: usize,
42 expected: String,
43 actual: String,
44 },
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct JournalEntry {
57 pub sequence: u64,
59 pub prev_hash: String,
62 pub entry_hash: String,
64 pub timestamp_secs: u64,
66 pub tool_name: String,
68 pub server_id: String,
70 pub agent_id: String,
72 pub bytes_read: u64,
74 pub bytes_written: u64,
76 pub delegation_depth: u32,
78 pub allowed: bool,
80}
81
/// Sentinel "previous hash" for the first journal entry: 64 ASCII `'0'`
/// characters, the same length as a hex-encoded SHA-256 digest.
const ZERO_HASH: &str = concat!(
    "00000000000000000000000000000000",
    "00000000000000000000000000000000",
);
84
85fn compute_entry_hash(entry: &JournalEntry) -> String {
87 let mut hasher = Sha256::new();
88 hasher.update(entry.sequence.to_le_bytes());
89 hasher.update(entry.prev_hash.as_bytes());
90 hasher.update(entry.timestamp_secs.to_le_bytes());
91 hasher.update(entry.tool_name.as_bytes());
92 hasher.update(entry.server_id.as_bytes());
93 hasher.update(entry.agent_id.as_bytes());
94 hasher.update(entry.bytes_read.to_le_bytes());
95 hasher.update(entry.bytes_written.to_le_bytes());
96 hasher.update(entry.delegation_depth.to_le_bytes());
97 hasher.update(if entry.allowed { &[1u8] } else { &[0u8] });
98 hex::encode(hasher.finalize())
99}
100
101#[derive(Debug, Clone, Default, Serialize, Deserialize)]
107pub struct CumulativeDataFlow {
108 pub total_bytes_read: u64,
110 pub total_bytes_written: u64,
112 pub total_invocations: u64,
114 pub max_delegation_depth: u32,
116}
117
118#[derive(Debug)]
124struct JournalInner {
125 entries: Vec<JournalEntry>,
127 data_flow: CumulativeDataFlow,
129 tool_sequence: Vec<String>,
131 tool_counts: HashMap<String, u64>,
133}
134
135impl JournalInner {
136 fn new() -> Self {
137 Self {
138 entries: Vec::new(),
139 data_flow: CumulativeDataFlow::default(),
140 tool_sequence: Vec::new(),
141 tool_counts: HashMap::new(),
142 }
143 }
144
145 fn last_hash(&self) -> &str {
146 self.entries
147 .last()
148 .map(|e| e.entry_hash.as_str())
149 .unwrap_or(ZERO_HASH)
150 }
151}
152
153pub struct SessionJournal {
162 inner: Mutex<JournalInner>,
163 session_id: String,
164}
165
166impl std::fmt::Debug for SessionJournal {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.debug_struct("SessionJournal")
169 .field("session_id", &self.session_id)
170 .finish()
171 }
172}
173
/// Caller-supplied description of a single invocation to append to the
/// journal.
///
/// The journal itself fills in the sequence number, timestamp and hashes.
#[derive(Clone, Debug)]
pub struct RecordParams {
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Server identifier to store with the entry.
    pub server_id: String,
    /// Agent identifier to store with the entry.
    pub agent_id: String,
    /// Bytes read by this invocation; added to the session total.
    pub bytes_read: u64,
    /// Bytes written by this invocation; added to the session total.
    pub bytes_written: u64,
    /// Delegation depth of this invocation; the session tracks the maximum.
    pub delegation_depth: u32,
    /// Allow/deny flag to store with the entry.
    pub allowed: bool,
}
192
193impl SessionJournal {
194 fn lock_inner(&self) -> Result<MutexGuard<'_, JournalInner>, SessionJournalError> {
195 self.inner
196 .lock()
197 .map_err(|_| SessionJournalError::LockPoisoned)
198 }
199
200 pub fn new(session_id: String) -> Self {
202 Self {
203 inner: Mutex::new(JournalInner::new()),
204 session_id,
205 }
206 }
207
208 pub fn session_id(&self) -> &str {
210 &self.session_id
211 }
212
213 pub fn record(&self, params: RecordParams) -> Result<u64, SessionJournalError> {
218 let mut inner = self.lock_inner()?;
219
220 let sequence = inner.entries.len() as u64;
221 let prev_hash = inner.last_hash().to_string();
222 let timestamp_secs = std::time::SystemTime::now()
223 .duration_since(std::time::UNIX_EPOCH)
224 .map(|d| d.as_secs())
225 .unwrap_or(0);
226
227 let tool_name = params.tool_name;
228 let mut entry = JournalEntry {
229 sequence,
230 prev_hash,
231 entry_hash: String::new(),
232 timestamp_secs,
233 tool_name: tool_name.clone(),
234 server_id: params.server_id,
235 agent_id: params.agent_id,
236 bytes_read: params.bytes_read,
237 bytes_written: params.bytes_written,
238 delegation_depth: params.delegation_depth,
239 allowed: params.allowed,
240 };
241 entry.entry_hash = compute_entry_hash(&entry);
242
243 inner.data_flow.total_bytes_read = inner
245 .data_flow
246 .total_bytes_read
247 .saturating_add(params.bytes_read);
248 inner.data_flow.total_bytes_written = inner
249 .data_flow
250 .total_bytes_written
251 .saturating_add(params.bytes_written);
252 inner.data_flow.total_invocations = inner.data_flow.total_invocations.saturating_add(1);
253 inner.data_flow.max_delegation_depth = inner
254 .data_flow
255 .max_delegation_depth
256 .max(params.delegation_depth);
257
258 inner.tool_sequence.push(tool_name.clone());
260 let count = inner.tool_counts.entry(tool_name).or_insert(0);
261 *count = count.saturating_add(1);
262
263 inner.entries.push(entry);
264
265 Ok(sequence)
266 }
267
268 pub fn data_flow(&self) -> Result<CumulativeDataFlow, SessionJournalError> {
270 let inner = self.lock_inner()?;
271 Ok(inner.data_flow.clone())
272 }
273
274 pub fn tool_sequence(&self) -> Result<Vec<String>, SessionJournalError> {
276 let inner = self.lock_inner()?;
277 Ok(inner.tool_sequence.clone())
278 }
279
280 pub fn tool_counts(&self) -> Result<HashMap<String, u64>, SessionJournalError> {
282 let inner = self.lock_inner()?;
283 Ok(inner.tool_counts.clone())
284 }
285
286 pub fn len(&self) -> Result<usize, SessionJournalError> {
288 let inner = self.lock_inner()?;
289 Ok(inner.entries.len())
290 }
291
292 pub fn is_empty(&self) -> Result<bool, SessionJournalError> {
294 Ok(self.len()? == 0)
295 }
296
297 pub fn entries(&self) -> Result<Vec<JournalEntry>, SessionJournalError> {
299 let inner = self.lock_inner()?;
300 Ok(inner.entries.clone())
301 }
302
303 pub fn recent_entries(&self, n: usize) -> Result<Vec<JournalEntry>, SessionJournalError> {
305 let inner = self.lock_inner()?;
306 let start = inner.entries.len().saturating_sub(n);
307 Ok(inner.entries[start..].to_vec())
308 }
309
310 pub fn verify_integrity(&self) -> Result<(), SessionJournalError> {
315 let inner = self.lock_inner()?;
316
317 for (index, entry) in inner.entries.iter().enumerate() {
318 let expected_prev = if index == 0 {
320 ZERO_HASH
321 } else {
322 inner.entries[index - 1].entry_hash.as_str()
323 };
324
325 if entry.prev_hash != expected_prev {
326 return Err(SessionJournalError::IntegrityViolation {
327 index,
328 expected: expected_prev.to_string(),
329 actual: entry.prev_hash.clone(),
330 });
331 }
332
333 let recomputed = compute_entry_hash(entry);
335 if entry.entry_hash != recomputed {
336 return Err(SessionJournalError::IntegrityViolation {
337 index,
338 expected: recomputed,
339 actual: entry.entry_hash.clone(),
340 });
341 }
342 }
343
344 Ok(())
345 }
346
347 pub fn head_hash(&self) -> Result<String, SessionJournalError> {
349 let inner = self.lock_inner()?;
350 Ok(inner.last_hash().to_string())
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Default parameters for `tool`: 100 bytes read, 50 written, depth 0,
    /// allowed.
    fn test_params(tool: &str) -> RecordParams {
        RecordParams {
            tool_name: tool.to_string(),
            server_id: "srv-1".to_string(),
            agent_id: "agent-1".to_string(),
            bytes_read: 100,
            bytes_written: 50,
            delegation_depth: 0,
            allowed: true,
        }
    }

    /// Parameters with explicit counters, using the "srv"/"agent" identities.
    fn custom_params(
        tool: &str,
        bytes_read: u64,
        bytes_written: u64,
        delegation_depth: u32,
        allowed: bool,
    ) -> RecordParams {
        RecordParams {
            tool_name: tool.to_string(),
            server_id: "srv".to_string(),
            agent_id: "agent".to_string(),
            bytes_read,
            bytes_written,
            delegation_depth,
            allowed,
        }
    }

    /// Builds a journal and records one default invocation per tool name.
    fn journal_with(session: &str, tools: &[&str]) -> SessionJournal {
        let journal = SessionJournal::new(session.to_string());
        for &tool in tools {
            journal.record(test_params(tool)).unwrap();
        }
        journal
    }

    /// A fixed, un-hashed entry used by the hash-function tests.
    fn sample_entry() -> JournalEntry {
        JournalEntry {
            sequence: 0,
            prev_hash: ZERO_HASH.to_string(),
            entry_hash: String::new(),
            timestamp_secs: 1_700_000_000,
            tool_name: "read_file".to_string(),
            server_id: "srv".to_string(),
            agent_id: "agent".to_string(),
            bytes_read: 100,
            bytes_written: 0,
            delegation_depth: 0,
            allowed: true,
        }
    }

    #[test]
    fn empty_journal() {
        let journal = journal_with("sess-1", &[]);
        assert_eq!(journal.len().unwrap(), 0);
        assert!(journal.is_empty().unwrap());
        assert_eq!(journal.head_hash().unwrap(), ZERO_HASH);
    }

    #[test]
    fn single_entry() {
        let journal = SessionJournal::new("sess-1".to_string());
        let assigned = journal.record(test_params("read_file")).unwrap();
        assert_eq!(assigned, 0);
        assert_eq!(journal.len().unwrap(), 1);
        assert!(!journal.is_empty().unwrap());

        let entries = journal.entries().unwrap();
        let first = &entries[0];
        assert_eq!(first.prev_hash, ZERO_HASH);
        assert!(!first.entry_hash.is_empty());
        assert_eq!(first.tool_name, "read_file");
    }

    #[test]
    fn hash_chain_links() {
        let journal = journal_with("sess-chain", &["read_file", "write_file", "bash"]);

        let entries = journal.entries().unwrap();
        assert!(entries.len() >= 3);
        assert_eq!(entries[0].prev_hash, ZERO_HASH);
        for link in entries[..3].windows(2) {
            assert_eq!(link[1].prev_hash, link[0].entry_hash);
        }
    }

    #[test]
    fn integrity_check_passes() {
        let journal = journal_with(
            "sess-integrity",
            &["read_file", "write_file", "bash", "http_request"],
        );
        assert!(journal.verify_integrity().is_ok());
    }

    #[test]
    fn cumulative_data_flow() {
        let journal = SessionJournal::new("sess-flow".to_string());
        journal
            .record(custom_params("read_file", 200, 0, 0, true))
            .unwrap();
        journal
            .record(custom_params("write_file", 0, 300, 1, true))
            .unwrap();

        let totals = journal.data_flow().unwrap();
        assert_eq!(totals.total_bytes_read, 200);
        assert_eq!(totals.total_bytes_written, 300);
        assert_eq!(totals.total_invocations, 2);
        assert_eq!(totals.max_delegation_depth, 1);
    }

    #[test]
    fn tool_sequence_tracking() {
        let journal = journal_with("sess-seq", &["read_file", "bash", "read_file"]);

        let order = journal.tool_sequence().unwrap();
        assert_eq!(order, ["read_file", "bash", "read_file"]);

        let counts = journal.tool_counts().unwrap();
        assert_eq!(counts.get("read_file").copied(), Some(2));
        assert_eq!(counts.get("bash").copied(), Some(1));
    }

    #[test]
    fn recent_entries_subset() {
        let journal = SessionJournal::new("sess-recent".to_string());
        for i in 0..10 {
            let tool = format!("tool_{i}");
            journal.record(test_params(&tool)).unwrap();
        }

        let recent = journal.recent_entries(3).unwrap();
        let names: Vec<&str> = recent.iter().map(|e| e.tool_name.as_str()).collect();
        assert_eq!(names, ["tool_7", "tool_8", "tool_9"]);
    }

    #[test]
    fn recent_entries_all_when_fewer() {
        let journal = journal_with("sess-few", &["tool_a", "tool_b"]);

        let recent = journal.recent_entries(10).unwrap();
        assert_eq!(recent.len(), 2);
    }

    #[test]
    fn session_id_accessible() {
        let journal = SessionJournal::new(String::from("my-session-42"));
        assert_eq!(journal.session_id(), "my-session-42");
    }

    #[test]
    fn denied_invocations_tracked() {
        let journal = SessionJournal::new("sess-denied".to_string());
        journal
            .record(custom_params("bash", 0, 0, 0, false))
            .unwrap();

        let entries = journal.entries().unwrap();
        assert!(!entries[0].allowed);
        // Denied invocations still count towards the session total.
        let totals = journal.data_flow().unwrap();
        assert_eq!(totals.total_invocations, 1);
    }

    #[test]
    fn entry_hash_determinism() {
        let baseline = sample_entry();
        let twin = baseline.clone();
        assert_eq!(compute_entry_hash(&baseline), compute_entry_hash(&twin));
    }

    #[test]
    fn entry_hash_changes_with_content() {
        let baseline = sample_entry();
        let altered = JournalEntry {
            bytes_read: 999,
            ..baseline.clone()
        };
        assert_ne!(compute_entry_hash(&baseline), compute_entry_hash(&altered));
    }

    #[test]
    fn serde_roundtrip() {
        let journal = journal_with("sess-serde", &["read_file"]);

        let entries = journal.entries().unwrap();
        let encoded = serde_json::to_string_pretty(&entries).unwrap();
        let decoded: Vec<JournalEntry> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(entries.len(), decoded.len());
        assert_eq!(entries[0].entry_hash, decoded[0].entry_hash);
        assert_eq!(entries[0].tool_name, decoded[0].tool_name);
    }
}