use tracing::instrument;
use oxidized_state::storage_traits::{
ContentDigest, RunEvent, RunId, RunLedger, RunStatus as StorageRunStatus,
};
use crate::diff::state_diff::CHECKPOINT_SAVED_KIND;
use crate::domain::{AivcsError, Result};
use crate::metrics::METRICS;
/// Summary of a replayed run, pairing the stored run record's metadata with
/// a content digest computed over the run's full event stream.
#[derive(Debug, Clone)]
pub struct ReplaySummary {
    /// Identifier of the replayed run (string form of the ledger's `RunId`).
    pub run_id: String,
    /// Name of the agent that produced the run, from the run's metadata.
    pub agent_name: String,
    /// Final status of the run as recorded in storage.
    pub status: StorageRunStatus,
    /// Number of events in the replayed stream.
    pub event_count: usize,
    /// Hex digest of the JSON-encoded event stream; identical streams yield
    /// identical digests (see `replay_run`).
    pub replay_digest: String,
    /// Digest of the agent spec the run was created from.
    pub spec_digest: ContentDigest,
}
/// The latest checkpoint found in a run's event stream, plus the event
/// prefix needed to resume execution from it (see `find_resume_point`).
#[derive(Debug, Clone)]
pub struct ResumePoint {
    /// `checkpoint_id` from the checkpoint event's payload; empty string when
    /// the payload lacked the key or it was not a string.
    pub checkpoint_id: String,
    /// Sequence number of the `checkpoint_saved` event.
    pub checkpoint_seq: u64,
    /// `node_id` from the checkpoint event's payload; empty string when the
    /// payload lacked the key or it was not a string.
    pub node_id: String,
    /// All events up to AND including the checkpoint event itself.
    pub events_before: Vec<RunEvent>,
}
/// Confirms that the stored run identified by `run_id_str` was created from
/// the agent spec whose digest is `expected_spec`.
///
/// # Errors
/// - [`AivcsError::StorageError`] when the run record cannot be loaded.
/// - [`AivcsError::DigestMismatch`] when the stored spec digest differs from
///   `expected_spec` (`expected` carries the caller's digest, `actual` the
///   one found in storage).
pub async fn verify_spec_digest(
    ledger: &dyn RunLedger,
    run_id_str: &str,
    expected_spec: &ContentDigest,
) -> Result<()> {
    let record = ledger
        .get_run(&RunId(run_id_str.to_string()))
        .await
        .map_err(|err| AivcsError::StorageError(err.to_string()))?;
    if record.spec_digest == *expected_spec {
        Ok(())
    } else {
        Err(AivcsError::DigestMismatch {
            expected: expected_spec.as_str().to_string(),
            actual: record.spec_digest.as_str().to_string(),
        })
    }
}
pub async fn find_resume_point(
ledger: &dyn RunLedger,
run_id_str: &str,
) -> Result<Option<ResumePoint>> {
let run_id = RunId(run_id_str.to_string());
let events = ledger
.get_events(&run_id)
.await
.map_err(|e| AivcsError::StorageError(e.to_string()))?;
let checkpoint_pos = events.iter().rposition(|e| e.kind == CHECKPOINT_SAVED_KIND);
let Some(pos) = checkpoint_pos else {
return Ok(None);
};
let cp_event = &events[pos];
let checkpoint_id = cp_event
.payload
.get("checkpoint_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let node_id = cp_event
.payload
.get("node_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let events_before = events[..=pos].to_vec();
Ok(Some(ResumePoint {
checkpoint_id,
checkpoint_seq: cp_event.seq,
node_id,
events_before,
}))
}
/// Replays a stored run: loads its record and full event stream, computes a
/// content digest over the JSON-encoded stream, and returns the events
/// alongside a [`ReplaySummary`].
///
/// Two runs with byte-identical event streams produce identical
/// `replay_digest` values, which is the basis for golden-digest comparisons.
///
/// # Errors
/// - [`AivcsError::StorageError`] when the run or its events cannot be loaded.
/// - [`AivcsError::Serialization`] when the event stream fails to encode.
#[instrument(skip(ledger), fields(run_id = %run_id_str))]
pub async fn replay_run(
    ledger: &dyn RunLedger,
    run_id_str: &str,
) -> Result<(Vec<RunEvent>, ReplaySummary)> {
    // Observability: run-scoped span plus a replay counter.
    let _span = crate::obs::RunSpan::enter(run_id_str);
    METRICS.inc_replays();

    let run_id = RunId(run_id_str.to_string());
    let record = ledger
        .get_run(&run_id)
        .await
        .map_err(|err| AivcsError::StorageError(err.to_string()))?;
    let events = ledger
        .get_events(&run_id)
        .await
        .map_err(|err| AivcsError::StorageError(err.to_string()))?;

    // Digest the canonical JSON encoding of the whole stream.
    let encoded = serde_json::to_vec(&events).map_err(AivcsError::Serialization)?;
    let digest = ContentDigest::from_bytes(&encoded);

    let summary = ReplaySummary {
        run_id: record.run_id.to_string(),
        agent_name: record.metadata.agent_name.clone(),
        status: record.status,
        event_count: events.len(),
        replay_digest: digest.as_str().to_string(),
        spec_digest: record.spec_digest,
    };
    Ok((events, summary))
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxidized_state::fakes::MemoryRunLedger;
    use oxidized_state::storage_traits::{ContentDigest, RunMetadata};
    use std::sync::Arc;

    /// Builds and completes a run with a deterministic event stream:
    /// `graph_started`, then a `node_entered`/`node_exited` pair per node,
    /// then `graph_completed` — 2 + 2 * n_nodes events in total. Passing a
    /// fixed `timestamp` keeps the stream byte-stable across ledgers so
    /// replay digests can be compared.
    async fn build_run(
        ledger: &dyn RunLedger,
        n_nodes: u32,
        timestamp: chrono::DateTime<chrono::Utc>,
    ) -> Result<RunId> {
        let spec_digest = ContentDigest::from_bytes(b"test_spec");
        let metadata = RunMetadata {
            git_sha: Some("test_sha".to_string()),
            agent_name: "test_agent".to_string(),
            tags: serde_json::json!({}),
        };
        let run_id = ledger
            .create_run(&spec_digest, metadata)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))?;
        // Sequence numbers start at 1 and increase by one per appended event.
        let mut seq = 1u64;
        let event = RunEvent {
            seq,
            kind: "graph_started".to_string(),
            payload: serde_json::json!({}),
            timestamp,
        };
        ledger
            .append_event(&run_id, event)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))?;
        seq += 1;
        for i in 0..n_nodes {
            let event = RunEvent {
                seq,
                kind: "node_entered".to_string(),
                payload: serde_json::json!({"node_id": format!("node_{}", i)}),
                timestamp,
            };
            ledger
                .append_event(&run_id, event)
                .await
                .map_err(|e| AivcsError::StorageError(e.to_string()))?;
            seq += 1;
            let event = RunEvent {
                seq,
                kind: "node_exited".to_string(),
                payload: serde_json::json!({"node_id": format!("node_{}", i)}),
                timestamp,
            };
            ledger
                .append_event(&run_id, event)
                .await
                .map_err(|e| AivcsError::StorageError(e.to_string()))?;
            seq += 1;
        }
        let event = RunEvent {
            seq,
            kind: "graph_completed".to_string(),
            payload: serde_json::json!({}),
            timestamp,
        };
        ledger
            .append_event(&run_id, event)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))?;
        // At this point `seq` equals the last event's sequence number, which
        // (with 1-based numbering) is also the total event count: 2 + 2*n.
        let summary = oxidized_state::storage_traits::RunSummary {
            total_events: seq,
            final_state_digest: None,
            duration_ms: 1000,
            success: true,
        };
        ledger
            .complete_run(&run_id, summary)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))?;
        Ok(run_id)
    }

    /// Two independent ledgers seeded with identical event streams (same
    /// fixed timestamp) must produce identical replay digests.
    #[tokio::test]
    async fn test_replay_golden_digest_equality() {
        let ledger_a: Arc<dyn RunLedger> = Arc::new(MemoryRunLedger::new());
        let ledger_b: Arc<dyn RunLedger> = Arc::new(MemoryRunLedger::new());
        let fixed_timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-01T00:00:00Z")
            .expect("parse timestamp")
            .with_timezone(&chrono::Utc);
        let run_a = build_run(&*ledger_a, 2, fixed_timestamp)
            .await
            .expect("build_run_a");
        let run_b = build_run(&*ledger_b, 2, fixed_timestamp)
            .await
            .expect("build_run_b");
        let (_events_a, summary_a) = replay_run(&*ledger_a, &run_a.0).await.expect("replay_a");
        let (_events_b, summary_b) = replay_run(&*ledger_b, &run_b.0).await.expect("replay_b");
        assert_eq!(summary_a.replay_digest, summary_b.replay_digest);
        assert_eq!(summary_a.event_count, summary_b.event_count);
    }

    /// Replaying an unknown run id surfaces the ledger failure as a
    /// `StorageError`, not a panic or a silent empty result.
    #[tokio::test]
    async fn test_replay_missing_run_rejection() {
        let ledger: Arc<dyn RunLedger> = Arc::new(MemoryRunLedger::new());
        let result = replay_run(&*ledger, "nonexistent-run-id").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AivcsError::StorageError(_) => { }
            other => panic!("Expected StorageError, got {:?}", other),
        }
    }

    /// Events appended out of order (3, 1, 2) come back from replay sorted by
    /// sequence number — i.e. the ledger (not the caller) owns ordering.
    #[tokio::test]
    async fn test_replay_event_order() {
        let ledger: Arc<dyn RunLedger> = Arc::new(MemoryRunLedger::new());
        let spec_digest = ContentDigest::from_bytes(b"test_spec");
        let metadata = RunMetadata {
            git_sha: Some("test_sha".to_string()),
            agent_name: "test_agent".to_string(),
            tags: serde_json::json!({}),
        };
        let run_id = ledger
            .create_run(&spec_digest, metadata)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))
            .expect("create_run");
        let event3 = RunEvent {
            seq: 3,
            kind: "test_3".to_string(),
            payload: serde_json::json!({}),
            timestamp: chrono::Utc::now(),
        };
        ledger
            .append_event(&run_id, event3)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))
            .expect("append");
        let event1 = RunEvent {
            seq: 1,
            kind: "test_1".to_string(),
            payload: serde_json::json!({}),
            timestamp: chrono::Utc::now(),
        };
        ledger
            .append_event(&run_id, event1)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))
            .expect("append");
        let event2 = RunEvent {
            seq: 2,
            kind: "test_2".to_string(),
            payload: serde_json::json!({}),
            timestamp: chrono::Utc::now(),
        };
        ledger
            .append_event(&run_id, event2)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))
            .expect("append");
        let summary = oxidized_state::storage_traits::RunSummary {
            total_events: 3,
            final_state_digest: None,
            duration_ms: 100,
            success: true,
        };
        ledger
            .complete_run(&run_id, summary)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))
            .expect("complete");
        let (events, _summary) = replay_run(&*ledger, &run_id.0).await.expect("replay");
        assert_eq!(events.len(), 3);
        assert_eq!(events[0].seq, 1);
        assert_eq!(events[1].seq, 2);
        assert_eq!(events[2].seq, 3);
    }

    /// The matching digest passes; a different digest yields DigestMismatch
    /// with `expected` = caller's digest and `actual` = stored digest.
    #[tokio::test]
    async fn test_spec_digest_mismatch_rejected() {
        let ledger: Arc<dyn RunLedger> = Arc::new(MemoryRunLedger::new());
        let spec_a = ContentDigest::from_bytes(b"spec_a");
        let metadata = RunMetadata {
            git_sha: None,
            agent_name: "test_agent".to_string(),
            tags: serde_json::json!({}),
        };
        let run_id = ledger
            .create_run(&spec_a, metadata)
            .await
            .expect("create_run");
        verify_spec_digest(&*ledger, &run_id.0, &spec_a)
            .await
            .expect("correct spec should pass");
        let spec_b = ContentDigest::from_bytes(b"spec_b");
        let result = verify_spec_digest(&*ledger, &run_id.0, &spec_b).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AivcsError::DigestMismatch { expected, actual } => {
                assert_eq!(expected, spec_b.as_str());
                assert_eq!(actual, spec_a.as_str());
            }
            other => panic!("Expected DigestMismatch, got {:?}", other),
        }
    }

    /// With two checkpoints in the stream, the resume point is the LAST one
    /// (cp2 at seq 3), and `events_before` spans the whole prefix through it.
    #[tokio::test]
    async fn test_find_resume_point_returns_latest_checkpoint() {
        let ledger: Arc<dyn RunLedger> = Arc::new(MemoryRunLedger::new());
        let spec = ContentDigest::from_bytes(b"test_spec");
        let metadata = RunMetadata {
            git_sha: None,
            agent_name: "agent".to_string(),
            tags: serde_json::json!({}),
        };
        let run_id = ledger.create_run(&spec, metadata).await.expect("create");
        let ts = chrono::Utc::now();
        ledger
            .append_event(
                &run_id,
                RunEvent {
                    seq: 1,
                    kind: "checkpoint_saved".to_string(),
                    payload: serde_json::json!({ "checkpoint_id": "cp1", "node_id": "node_a" }),
                    timestamp: ts,
                },
            )
            .await
            .expect("append cp1");
        ledger
            .append_event(
                &run_id,
                RunEvent {
                    seq: 2,
                    kind: "node_entered".to_string(),
                    payload: serde_json::json!({ "node_id": "node_b", "iteration": 1 }),
                    timestamp: ts,
                },
            )
            .await
            .expect("append node");
        ledger
            .append_event(
                &run_id,
                RunEvent {
                    seq: 3,
                    kind: "checkpoint_saved".to_string(),
                    payload: serde_json::json!({ "checkpoint_id": "cp2", "node_id": "node_b" }),
                    timestamp: ts,
                },
            )
            .await
            .expect("append cp2");
        let resume = find_resume_point(&*ledger, &run_id.0)
            .await
            .expect("find_resume_point")
            .expect("should find a checkpoint");
        assert_eq!(resume.checkpoint_id, "cp2");
        assert_eq!(resume.node_id, "node_b");
        assert_eq!(resume.checkpoint_seq, 3);
        assert_eq!(resume.events_before.len(), 3);
    }

    /// A run whose stream has no `checkpoint_saved` event yields `None`
    /// (not an error) from `find_resume_point`.
    #[tokio::test]
    async fn test_find_resume_point_no_checkpoint_returns_none() {
        let ledger: Arc<dyn RunLedger> = Arc::new(MemoryRunLedger::new());
        let spec = ContentDigest::from_bytes(b"test_spec");
        let metadata = RunMetadata {
            git_sha: None,
            agent_name: "agent".to_string(),
            tags: serde_json::json!({}),
        };
        let run_id = ledger.create_run(&spec, metadata).await.expect("create");
        ledger
            .append_event(
                &run_id,
                RunEvent {
                    seq: 1,
                    kind: "graph_started".to_string(),
                    payload: serde_json::json!({}),
                    timestamp: chrono::Utc::now(),
                },
            )
            .await
            .expect("append");
        let resume = find_resume_point(&*ledger, &run_id.0)
            .await
            .expect("find_resume_point");
        assert!(resume.is_none());
    }

    /// `events_before` is inclusive: its last element is the checkpoint event
    /// itself, and events appended AFTER the checkpoint are excluded.
    #[tokio::test]
    async fn test_resume_point_events_before_includes_checkpoint() {
        let ledger: Arc<dyn RunLedger> = Arc::new(MemoryRunLedger::new());
        let spec = ContentDigest::from_bytes(b"test_spec");
        let metadata = RunMetadata {
            git_sha: None,
            agent_name: "agent".to_string(),
            tags: serde_json::json!({}),
        };
        let run_id = ledger.create_run(&spec, metadata).await.expect("create");
        let ts = chrono::Utc::now();
        ledger
            .append_event(
                &run_id,
                RunEvent {
                    seq: 1,
                    kind: "graph_started".to_string(),
                    payload: serde_json::json!({}),
                    timestamp: ts,
                },
            )
            .await
            .expect("append");
        ledger
            .append_event(
                &run_id,
                RunEvent {
                    seq: 2,
                    kind: "checkpoint_saved".to_string(),
                    payload: serde_json::json!({ "checkpoint_id": "cp1", "node_id": "node_x" }),
                    timestamp: ts,
                },
            )
            .await
            .expect("append");
        ledger
            .append_event(
                &run_id,
                RunEvent {
                    seq: 3,
                    kind: "node_entered".to_string(),
                    payload: serde_json::json!({ "node_id": "node_y", "iteration": 1 }),
                    timestamp: ts,
                },
            )
            .await
            .expect("append");
        let resume = find_resume_point(&*ledger, &run_id.0)
            .await
            .expect("find")
            .expect("some");
        assert_eq!(resume.events_before.len(), 2);
        let last = resume.events_before.last().expect("last");
        assert_eq!(last.kind, "checkpoint_saved");
        assert_eq!(last.seq, 2);
    }

    /// A completed run with zero events still replays: empty stream, count 0,
    /// and a digest that is a 64-char hex string (the digest of "[]").
    /// NOTE(review): the 64-hex assertion presumes a SHA-256-style digest
    /// from `ContentDigest::from_bytes` — confirm against oxidized_state.
    #[tokio::test]
    async fn test_replay_empty_run() {
        let ledger: Arc<dyn RunLedger> = Arc::new(MemoryRunLedger::new());
        let spec_digest = ContentDigest::from_bytes(b"test_spec");
        let metadata = RunMetadata {
            git_sha: Some("test_sha".to_string()),
            agent_name: "test_agent".to_string(),
            tags: serde_json::json!({}),
        };
        let run_id = ledger
            .create_run(&spec_digest, metadata)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))
            .expect("create_run");
        let summary = oxidized_state::storage_traits::RunSummary {
            total_events: 0,
            final_state_digest: None,
            duration_ms: 0,
            success: true,
        };
        ledger
            .complete_run(&run_id, summary)
            .await
            .map_err(|e| AivcsError::StorageError(e.to_string()))
            .expect("complete");
        let (events, summary) = replay_run(&*ledger, &run_id.0).await.expect("replay");
        assert_eq!(events.len(), 0);
        assert_eq!(summary.event_count, 0);
        assert_eq!(summary.replay_digest.len(), 64);
        assert!(summary.replay_digest.chars().all(|c| c.is_ascii_hexdigit()));
    }
}