use async_trait::async_trait;
use dagrs::graph::event::GraphEvent;
use dagrs::node::action::Action;
use dagrs::node::default_node::DefaultNode;
use dagrs::{
Checkpoint, CheckpointConfig, CheckpointStore, EnvVar, Graph, InChannels,
MemoryCheckpointStore, Node, NodeTable, OutChannels, Output,
};
use std::sync::{Arc, Mutex};
use std::time::Duration;
#[derive(Clone)]
struct CounterAction {
count: Arc<Mutex<usize>>,
}
#[async_trait]
impl Action for CounterAction {
async fn run(&self, _: &mut InChannels, _: &mut OutChannels, _: Arc<EnvVar>) -> Output {
let mut c = self.count.lock().unwrap();
*c += 1;
Output::new(*c)
}
}
#[derive(Clone)]
struct SlowAction {
name: String,
executed: Arc<Mutex<Vec<String>>>,
}
#[async_trait]
impl Action for SlowAction {
async fn run(&self, _: &mut InChannels, _: &mut OutChannels, _: Arc<EnvVar>) -> Output {
tokio::time::sleep(Duration::from_millis(10)).await;
self.executed.lock().unwrap().push(self.name.clone());
Output::empty()
}
}
#[tokio::test]
async fn test_checkpoint_store_configuration() {
let mut graph = Graph::new();
let store = MemoryCheckpointStore::new();
graph.set_checkpoint_store(Box::new(store));
let config = CheckpointConfig::enabled()
.with_node_interval(5)
.with_max_checkpoints(3);
graph.set_checkpoint_config(config);
}
#[tokio::test]
async fn test_manual_checkpoint_save_and_load() {
let mut graph = Graph::new();
let mut table = NodeTable::new();
let node_a = DefaultNode::new("A".to_string(), &mut table);
let node_b = DefaultNode::new("B".to_string(), &mut table);
let id_a = node_a.id();
let id_b = node_b.id();
graph.add_node(node_a);
graph.add_node(node_b);
graph.add_edge(id_a, vec![id_b]);
graph.set_checkpoint_store(Box::new(MemoryCheckpointStore::new()));
graph.async_start().await.unwrap();
let checkpoints = graph.list_checkpoints().await.unwrap();
assert!(checkpoints.is_empty());
}
#[tokio::test]
async fn test_checkpoint_list_and_delete() {
let store = MemoryCheckpointStore::new();
let cp1 = Checkpoint::with_id("test_1", 1, 0);
let cp2 = Checkpoint::with_id("test_2", 2, 1);
let cp3 = Checkpoint::with_id("test_3", 3, 2);
store.save(&cp1).await.unwrap();
store.save(&cp2).await.unwrap();
store.save(&cp3).await.unwrap();
let ids = store.list().await.unwrap();
assert_eq!(ids.len(), 3);
store.delete(&"test_2".to_string()).await.unwrap();
let ids = store.list().await.unwrap();
assert_eq!(ids.len(), 2);
assert!(!ids.contains(&"test_2".to_string()));
store.clear().await.unwrap();
let ids = store.list().await.unwrap();
assert_eq!(ids.len(), 0);
}
#[tokio::test]
async fn test_get_latest_checkpoint() {
let store = MemoryCheckpointStore::new();
let mut cp1 = Checkpoint::with_id("old", 1, 0);
cp1.timestamp = 1000;
let mut cp2 = Checkpoint::with_id("newer", 2, 1);
cp2.timestamp = 2000;
let mut cp3 = Checkpoint::with_id("newest", 3, 2);
cp3.timestamp = 3000;
store.save(&cp1).await.unwrap();
store.save(&cp3).await.unwrap();
store.save(&cp2).await.unwrap();
let latest = store.latest().await.unwrap().unwrap();
assert_eq!(latest.id, "newest");
assert_eq!(latest.pc, 3);
}
#[tokio::test]
async fn test_checkpoint_with_node_states() {
use dagrs::NodeState;
let mut checkpoint = Checkpoint::new(5, 2);
checkpoint.add_node_state(NodeState::completed(1, true));
checkpoint.add_node_state(NodeState::completed(2, false));
checkpoint.add_node_state(NodeState::pending(3));
assert_eq!(checkpoint.node_states.len(), 3);
assert!(checkpoint.node_states.get(&1).unwrap().success);
assert!(!checkpoint.node_states.get(&2).unwrap().success);
assert!(!checkpoint.node_states.get(&3).unwrap().completed);
}
#[tokio::test]
async fn test_checkpoint_metadata() {
let mut checkpoint = Checkpoint::new(0, 0);
checkpoint.add_metadata("graph_name", "test_graph");
checkpoint.add_metadata("version", "1.0");
checkpoint.add_metadata("environment", "test");
assert_eq!(
checkpoint.metadata.get("graph_name"),
Some(&"test_graph".to_string())
);
assert_eq!(checkpoint.metadata.get("version"), Some(&"1.0".to_string()));
assert_eq!(
checkpoint.metadata.get("environment"),
Some(&"test".to_string())
);
}
#[tokio::test]
async fn test_checkpoint_events() {
let mut graph = Graph::new();
let mut table = NodeTable::new();
let count = Arc::new(Mutex::new(0));
let node_a = DefaultNode::with_action(
"A".to_string(),
CounterAction {
count: count.clone(),
},
&mut table,
);
let node_b = DefaultNode::with_action(
"B".to_string(),
CounterAction {
count: count.clone(),
},
&mut table,
);
let id_a = node_a.id();
let id_b = node_b.id();
graph.add_node(node_a);
graph.add_node(node_b);
graph.add_edge(id_a, vec![id_b]);
graph.set_checkpoint_store(Box::new(MemoryCheckpointStore::new()));
let config = CheckpointConfig::enabled()
.with_node_interval(1)
.with_max_checkpoints(10);
graph.set_checkpoint_config(config);
let mut receiver = graph.subscribe();
let events = Arc::new(Mutex::new(Vec::new()));
let events_clone = events.clone();
let collector = tokio::spawn(async move {
let mut collected = Vec::new();
while let Ok(Ok(event)) =
tokio::time::timeout(Duration::from_millis(200), receiver.recv()).await
{
let is_finished = matches!(event, GraphEvent::GraphFinished);
collected.push(event);
if is_finished {
break;
}
}
collected
});
graph.async_start().await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
let collected = collector.await.unwrap();
*events_clone.lock().unwrap() = collected;
let events_list = events.lock().unwrap();
let has_node_start = events_list
.iter()
.any(|e| matches!(e, GraphEvent::NodeStart { .. }));
let has_node_success = events_list
.iter()
.any(|e| matches!(e, GraphEvent::NodeSuccess { .. }));
let has_finished = events_list
.iter()
.any(|e| matches!(e, GraphEvent::GraphFinished));
assert!(has_node_start, "Should have NodeStart events");
assert!(has_node_success, "Should have NodeSuccess events");
assert!(has_finished, "Should have GraphFinished event");
}
#[tokio::test]
async fn test_checkpoint_config_builder() {
let default_config = CheckpointConfig::default();
assert!(!default_config.enabled);
assert_eq!(default_config.interval_nodes, Some(10));
assert!(default_config.on_loop_iteration);
assert!(default_config.before_conditional);
assert_eq!(default_config.max_checkpoints, 5);
let custom_config = CheckpointConfig::enabled()
.with_node_interval(20)
.with_time_interval(300)
.with_loop_checkpoint(false)
.with_max_checkpoints(10);
assert!(custom_config.enabled);
assert_eq!(custom_config.interval_nodes, Some(20));
assert_eq!(custom_config.interval_seconds, Some(300));
assert!(!custom_config.on_loop_iteration);
assert_eq!(custom_config.max_checkpoints, 10);
}
#[tokio::test]
async fn test_checkpoint_serialization() {
use dagrs::NodeState;
let mut checkpoint = Checkpoint::new(5, 3);
checkpoint.active_nodes.insert(1);
checkpoint.active_nodes.insert(2);
checkpoint.active_nodes.insert(3);
checkpoint.add_metadata("key", "value");
checkpoint.add_node_state(NodeState::completed(1, true));
let json = serde_json::to_string(&checkpoint).unwrap();
assert!(json.contains("\"pc\":5"));
assert!(json.contains("\"loop_count\":3"));
let restored: Checkpoint = serde_json::from_str(&json).unwrap();
assert_eq!(restored.pc, 5);
assert_eq!(restored.loop_count, 3);
assert_eq!(restored.active_nodes.len(), 3);
assert_eq!(restored.metadata.get("key"), Some(&"value".to_string()));
}
#[tokio::test]
async fn test_resume_execution_basic() {
let mut graph = Graph::new();
let mut table = NodeTable::new();
let executed = Arc::new(Mutex::new(Vec::new()));
let node_a = DefaultNode::with_action(
"A".to_string(),
SlowAction {
name: "A".to_string(),
executed: executed.clone(),
},
&mut table,
);
let node_b = DefaultNode::with_action(
"B".to_string(),
SlowAction {
name: "B".to_string(),
executed: executed.clone(),
},
&mut table,
);
let id_a = node_a.id();
let id_b = node_b.id();
graph.add_node(node_a);
graph.add_node(node_b);
graph.add_edge(id_a, vec![id_b]);
let store = MemoryCheckpointStore::new();
graph.set_checkpoint_store(Box::new(store));
graph.async_start().await.unwrap();
let exec_log = executed.lock().unwrap();
assert!(exec_log.contains(&"A".to_string()), "Node A should execute");
assert!(exec_log.contains(&"B".to_string()), "Node B should execute");
}
#[tokio::test]
async fn test_file_checkpoint_store_basic() {
use dagrs::FileCheckpointStore;
use std::env::temp_dir;
let test_dir = temp_dir().join("dagrs_checkpoint_test");
let store = FileCheckpointStore::new(&test_dir);
let _ = store.clear().await;
let mut checkpoint = Checkpoint::with_id("test_file_cp", 5, 2);
checkpoint.add_metadata("test_key", "test_value");
checkpoint.active_nodes.insert(1);
store.save(&checkpoint).await.unwrap();
let loaded = store.load(&"test_file_cp".to_string()).await.unwrap();
assert_eq!(loaded.pc, 5);
assert_eq!(loaded.loop_count, 2);
assert!(loaded.active_nodes.contains(&1));
let ids = store.list().await.unwrap();
assert!(ids.contains(&"test_file_cp".to_string()));
let latest = store.latest().await.unwrap();
assert!(latest.is_some());
store.clear().await.unwrap();
let _ = std::fs::remove_dir_all(&test_dir);
}
#[tokio::test]
async fn test_file_checkpoint_store_path_traversal_prevention() {
use dagrs::FileCheckpointStore;
use std::env::temp_dir;
let test_dir = temp_dir().join("dagrs_checkpoint_security_test");
let store = FileCheckpointStore::new(&test_dir);
let malicious_checkpoint = Checkpoint::with_id("../../../etc/passwd", 0, 0);
let result = store.save(&malicious_checkpoint).await;
assert!(
result.is_err(),
"Should reject path traversal attempts with .."
);
let malicious_checkpoint2 = Checkpoint::with_id("foo/bar", 0, 0);
let result2 = store.save(&malicious_checkpoint2).await;
assert!(result2.is_err(), "Should reject checkpoint IDs with /");
let result3 = store.load(&"../secret".to_string()).await;
assert!(result3.is_err(), "Should reject path traversal in load");
let _ = std::fs::remove_dir_all(&test_dir);
}
#[tokio::test]
async fn test_checkpoint_id_generation() {
let cp1 = Checkpoint::new(0, 0);
let cp2 = Checkpoint::new(10, 5);
assert!(cp1.id.starts_with("ckpt_"));
assert!(cp2.id.starts_with("ckpt_"));
assert!(!cp1.id.contains('/'));
assert!(!cp1.id.contains('\\'));
assert!(!cp1.id.contains(".."));
}
#[tokio::test]
async fn test_node_state_builder_api() {
use dagrs::NodeState;
let state = NodeState::completed(1, true)
.with_summary("String(hello)")
.with_output_data(b"\"hello\"".to_vec());
assert_eq!(state.node_id, 1);
assert!(state.completed);
assert!(state.success);
assert_eq!(state.output_summary, Some("String(hello)".to_string()));
assert_eq!(state.output_data, Some(b"\"hello\"".to_vec()));
let state = NodeState::completed(2, false).with_summary("Error: connection timeout");
assert_eq!(state.node_id, 2);
assert!(state.completed);
assert!(!state.success);
assert_eq!(
state.output_summary,
Some("Error: connection timeout".to_string())
);
let state = NodeState::pending(3);
assert_eq!(state.node_id, 3);
assert!(!state.completed);
assert!(!state.success);
}
#[tokio::test]
async fn test_checkpoint_output_data_serialization() {
use dagrs::NodeState;
let mut checkpoint = Checkpoint::new(0, 0);
let output_json = serde_json::to_vec(&"test_output").unwrap();
let state = NodeState::completed(1, true)
.with_output_data(output_json.clone())
.with_summary("String(test_output)");
checkpoint.add_node_state(state);
let json = serde_json::to_string(&checkpoint).unwrap();
let restored: Checkpoint = serde_json::from_str(&json).unwrap();
let restored_state = restored.node_states.get(&1).unwrap();
assert_eq!(restored_state.output_data, Some(output_json));
assert_eq!(
restored_state.output_summary,
Some("String(test_output)".to_string())
);
let output: String =
serde_json::from_slice(restored_state.output_data.as_ref().unwrap()).unwrap();
assert_eq!(output, "test_output");
}