use super::output::{RunMetadata, StepTiming, TokenTotals};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use simple_agent_type::message::Message;
use std::collections::BTreeMap;
/// Serializable snapshot of a workflow run, taken at the node that failed.
///
/// Derives `Serialize`/`Deserialize` so it can be written out (the tests in this
/// file round-trip it through JSON) and read back later.
/// NOTE(review): presumably used to resume the run from `failed_node_id`;
/// confirm against the resume code path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowCheckpoint {
/// Path of the workflow definition for this run (the tests use a YAML file name).
pub workflow_path: String,
/// Id of the node whose execution failed.
pub failed_node_id: String,
/// Ids of the nodes that completed before the failure.
/// NOTE(review): presumably in execution order — confirm with the producer.
pub completed_trace: Vec<String>,
/// Output value of each completed node, keyed by node id (as in the tests).
pub completed_outputs: BTreeMap<String, Value>,
/// Workflow-level global values captured with the checkpoint.
/// NOTE(review): confirm whether this is the state at failure or the initial globals.
pub globals: BTreeMap<String, Value>,
/// Messages the run was originally invoked with (presumably the user input — verify).
pub original_messages: Vec<Message>,
/// Timing records for the steps recorded so far.
pub step_timings: Vec<StepTiming>,
/// Token usage totals recorded so far.
pub token_totals: TokenTotals,
}
/// Output reported for a workflow run that failed partway through.
///
/// Carries the work completed before the failure, the error text, and a
/// [`WorkflowCheckpoint`] describing the same run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartialWorkflowOutput {
/// Identifier of the workflow. NOTE(review): confirm whether this is a run id or a definition id.
pub workflow_id: String,
/// Ids of the nodes that completed before the failure.
/// NOTE(review): appears to duplicate `checkpoint.completed_trace`; confirm they are kept in sync.
pub completed_trace: Vec<String>,
/// Output value of each completed node, keyed by node id.
/// NOTE(review): appears to duplicate `checkpoint.completed_outputs`.
pub completed_outputs: BTreeMap<String, Value>,
/// Id of the node whose execution failed.
/// NOTE(review): appears to duplicate `checkpoint.failed_node_id`.
pub failed_node_id: String,
/// The failure, rendered as text.
pub error: String,
/// Checkpoint for this run.
pub checkpoint: WorkflowCheckpoint,
/// Optional run metadata; the field is left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nerdstats: Option<RunMetadata>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small checkpoint for a run where `node_a` completed and `node_b` failed.
    fn sample_checkpoint() -> WorkflowCheckpoint {
        WorkflowCheckpoint {
            workflow_path: "test.yaml".into(),
            failed_node_id: "node_b".into(),
            completed_trace: vec!["node_a".into()],
            completed_outputs: BTreeMap::from([(
                "node_a".into(),
                serde_json::json!({"result": 1}),
            )]),
            globals: BTreeMap::new(),
            original_messages: vec![Message::user("hello")],
            step_timings: vec![],
            token_totals: TokenTotals::default(),
        }
    }

    #[test]
    fn test_checkpoint_serialization_roundtrip() {
        let cp = sample_checkpoint();
        let json = serde_json::to_string(&cp).unwrap();
        let parsed: WorkflowCheckpoint = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.workflow_path, "test.yaml");
        assert_eq!(parsed.failed_node_id, "node_b");
        assert_eq!(parsed.completed_trace, vec!["node_a"]);
        assert_eq!(parsed.completed_outputs, cp.completed_outputs);
        assert!(parsed.globals.is_empty());
        assert_eq!(parsed.original_messages.len(), 1);
        assert!(parsed.step_timings.is_empty());
        // Re-serializing must reproduce the same JSON. This also covers fields whose
        // types are not compared directly above (messages, token totals) without
        // requiring those types to implement `PartialEq`.
        assert_eq!(serde_json::to_string(&parsed).unwrap(), json);
    }

    #[test]
    fn test_partial_output_omits_absent_nerdstats() {
        let cp = sample_checkpoint();
        let out = PartialWorkflowOutput {
            workflow_id: "wf-1".into(),
            completed_trace: cp.completed_trace.clone(),
            completed_outputs: cp.completed_outputs.clone(),
            failed_node_id: cp.failed_node_id.clone(),
            error: "boom".into(),
            checkpoint: cp,
            nerdstats: None,
        };

        let value = serde_json::to_value(&out).unwrap();
        // `skip_serializing_if = "Option::is_none"` must drop the key entirely.
        assert!(value.get("nerdstats").is_none());

        // The output without the key must still deserialize.
        let parsed: PartialWorkflowOutput = serde_json::from_value(value).unwrap();
        assert!(parsed.nerdstats.is_none());
        assert_eq!(parsed.workflow_id, "wf-1");
        assert_eq!(parsed.error, "boom");
        assert_eq!(parsed.failed_node_id, "node_b");
        assert_eq!(parsed.checkpoint.failed_node_id, "node_b");
        assert_eq!(parsed.checkpoint.completed_trace, vec!["node_a"]);
    }
}