1use serde::{Deserialize, Serialize};
6use std::any::Any;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
/// A dynamically typed value that can be stored in a workflow context.
///
/// The enum is serialized with serde's `untagged` representation, so the
/// variant name is not written; only the payload is.
///
/// NOTE(review): untagged deserialization tries variants in declaration
/// order and keeps the first one that matches. With the order below, an
/// array of integers in `0..=255` would presumably come back as `Bytes`
/// rather than `List`, and `Json` looks unreachable when deserializing
/// because earlier variants already accept every JSON shape. Confirm whether
/// callers rely on exact round-tripping.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum WorkflowValue {
    /// No value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating point number.
    Float(f64),
    /// An owned UTF-8 string.
    String(String),
    /// A raw byte buffer.
    Bytes(Vec<u8>),
    /// An ordered sequence of nested values.
    List(Vec<WorkflowValue>),
    /// A string-keyed map of nested values.
    Map(HashMap<String, WorkflowValue>),
    /// An arbitrary `serde_json::Value`.
    Json(serde_json::Value),
}
25
26impl WorkflowValue {
27 pub fn is_null(&self) -> bool {
28 matches!(self, WorkflowValue::Null)
29 }
30
31 pub fn as_bool(&self) -> Option<bool> {
32 match self {
33 WorkflowValue::Bool(b) => Some(*b),
34 _ => None,
35 }
36 }
37
38 pub fn as_i64(&self) -> Option<i64> {
39 match self {
40 WorkflowValue::Int(i) => Some(*i),
41 _ => None,
42 }
43 }
44
45 pub fn as_f64(&self) -> Option<f64> {
46 match self {
47 WorkflowValue::Float(f) => Some(*f),
48 WorkflowValue::Int(i) => Some(*i as f64),
49 _ => None,
50 }
51 }
52
53 pub fn as_str(&self) -> Option<&str> {
54 match self {
55 WorkflowValue::String(s) => Some(s),
56 _ => None,
57 }
58 }
59
60 pub fn as_bytes(&self) -> Option<&[u8]> {
61 match self {
62 WorkflowValue::Bytes(b) => Some(b),
63 _ => None,
64 }
65 }
66
67 pub fn as_list(&self) -> Option<&Vec<WorkflowValue>> {
68 match self {
69 WorkflowValue::List(l) => Some(l),
70 _ => None,
71 }
72 }
73
74 pub fn as_map(&self) -> Option<&HashMap<String, WorkflowValue>> {
75 match self {
76 WorkflowValue::Map(m) => Some(m),
77 _ => None,
78 }
79 }
80
81 pub fn as_json(&self) -> Option<&serde_json::Value> {
82 match self {
83 WorkflowValue::Json(j) => Some(j),
84 _ => None,
85 }
86 }
87}
88
89impl From<bool> for WorkflowValue {
90 fn from(v: bool) -> Self {
91 WorkflowValue::Bool(v)
92 }
93}
94
95impl From<i64> for WorkflowValue {
96 fn from(v: i64) -> Self {
97 WorkflowValue::Int(v)
98 }
99}
100
101impl From<i32> for WorkflowValue {
102 fn from(v: i32) -> Self {
103 WorkflowValue::Int(v as i64)
104 }
105}
106
107impl From<f64> for WorkflowValue {
108 fn from(v: f64) -> Self {
109 WorkflowValue::Float(v)
110 }
111}
112
113impl From<String> for WorkflowValue {
114 fn from(v: String) -> Self {
115 WorkflowValue::String(v)
116 }
117}
118
119impl From<&str> for WorkflowValue {
120 fn from(v: &str) -> Self {
121 WorkflowValue::String(v.to_string())
122 }
123}
124
125impl From<Vec<u8>> for WorkflowValue {
126 fn from(v: Vec<u8>) -> Self {
127 WorkflowValue::Bytes(v)
128 }
129}
130
131impl From<serde_json::Value> for WorkflowValue {
132 fn from(v: serde_json::Value) -> Self {
133 WorkflowValue::Json(v)
134 }
135}
136
/// Execution state of a single workflow node.
///
/// `Completed`, `Failed`, `Skipped` and `Cancelled` are the states that
/// `NodeStatus::is_terminal` reports as terminal.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Not yet started. (Exact distinction from `Waiting` is not visible in
    /// this file — TODO confirm with the scheduler.)
    Pending,
    /// Waiting; presumably on dependencies or an external event — TODO confirm.
    Waiting,
    /// Currently executing.
    Running,
    /// Finished; counted as a success by `NodeStatus::is_success`.
    Completed,
    /// Finished with an error; `NodeResult::failed` stores the error message
    /// as the payload.
    Failed(String),
    /// Not executed; counted as a success by `NodeStatus::is_success`.
    Skipped,
    /// Cancelled before finishing.
    Cancelled,
}
155
156impl NodeStatus {
157 pub fn is_terminal(&self) -> bool {
158 matches!(
159 self,
160 NodeStatus::Completed
161 | NodeStatus::Failed(_)
162 | NodeStatus::Skipped
163 | NodeStatus::Cancelled
164 )
165 }
166
167 pub fn is_success(&self) -> bool {
168 matches!(self, NodeStatus::Completed | NodeStatus::Skipped)
169 }
170}
171
/// Execution state of a whole workflow run.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkflowStatus {
    /// The workflow has not been started.
    NotStarted,
    /// The workflow is executing.
    Running,
    /// The workflow is paused.
    Paused,
    /// The workflow finished.
    Completed,
    /// The workflow failed; the payload is presumably an error message —
    /// TODO confirm against the code that constructs it.
    Failed(String),
    /// The workflow was cancelled.
    Cancelled,
}
188
/// Outcome of executing one node.
#[derive(Debug, Clone)]
pub struct NodeResult {
    /// Identifier of the node this result belongs to.
    pub node_id: String,
    /// Status the node ended in.
    pub status: NodeStatus,
    /// Value produced by the node; the `failed` and `skipped` constructors
    /// set this to `WorkflowValue::Null`.
    pub output: WorkflowValue,
    /// Execution time, in milliseconds going by the field name.
    pub duration_ms: u64,
    /// Number of retries; every constructor in this file initializes it to 0.
    pub retry_count: u32,
    /// Error message, set by the `failed` constructor and `None` otherwise.
    pub error: Option<String>,
}
205
206impl NodeResult {
207 pub fn success(node_id: &str, output: WorkflowValue, duration_ms: u64) -> Self {
208 Self {
209 node_id: node_id.to_string(),
210 status: NodeStatus::Completed,
211 output,
212 duration_ms,
213 retry_count: 0,
214 error: None,
215 }
216 }
217
218 pub fn failed(node_id: &str, error: &str, duration_ms: u64) -> Self {
219 Self {
220 node_id: node_id.to_string(),
221 status: NodeStatus::Failed(error.to_string()),
222 output: WorkflowValue::Null,
223 duration_ms,
224 retry_count: 0,
225 error: Some(error.to_string()),
226 }
227 }
228
229 pub fn skipped(node_id: &str) -> Self {
230 Self {
231 node_id: node_id.to_string(),
232 status: NodeStatus::Skipped,
233 output: WorkflowValue::Null,
234 duration_ms: 0,
235 retry_count: 0,
236 error: None,
237 }
238 }
239}
240
/// Shared state for one workflow execution.
///
/// Every mutable piece of state sits behind its own `Arc<RwLock<..>>`. The
/// `Clone` impl in this file copies the `Arc`s, so clones observe and mutate
/// the same underlying data.
pub struct WorkflowContext {
    /// Identifier of the workflow definition.
    pub workflow_id: String,
    /// Identifier of this execution; `new` generates it as a UUIDv7 string.
    pub execution_id: String,
    /// Workflow input value; starts as `WorkflowValue::Null`.
    input: Arc<RwLock<WorkflowValue>>,
    /// Output value recorded per node id.
    node_outputs: Arc<RwLock<HashMap<String, WorkflowValue>>>,
    /// Status recorded per node id.
    node_statuses: Arc<RwLock<HashMap<String, NodeStatus>>>,
    /// Named variables.
    variables: Arc<RwLock<HashMap<String, WorkflowValue>>>,
    /// Type-erased values keyed by string; not included in checkpoints.
    custom_data: Arc<RwLock<HashMap<String, Box<dyn Any + Send + Sync>>>>,
    /// Checkpoints in creation order.
    checkpoints: Arc<RwLock<Vec<CheckpointData>>>,
}
260
261impl WorkflowContext {
262 pub fn new(workflow_id: &str) -> Self {
263 Self {
264 workflow_id: workflow_id.to_string(),
265 execution_id: uuid::Uuid::now_v7().to_string(),
266 input: Arc::new(RwLock::new(WorkflowValue::Null)),
267 node_outputs: Arc::new(RwLock::new(HashMap::new())),
268 node_statuses: Arc::new(RwLock::new(HashMap::new())),
269 variables: Arc::new(RwLock::new(HashMap::new())),
270 custom_data: Arc::new(RwLock::new(HashMap::new())),
271 checkpoints: Arc::new(RwLock::new(Vec::new())),
272 }
273 }
274
275 pub async fn set_input(&self, input: WorkflowValue) {
277 let mut i = self.input.write().await;
278 *i = input;
279 }
280
281 pub async fn get_input(&self) -> WorkflowValue {
283 self.input.read().await.clone()
284 }
285
286 pub async fn set_node_output(&self, node_id: &str, output: WorkflowValue) {
288 let mut outputs = self.node_outputs.write().await;
289 outputs.insert(node_id.to_string(), output);
290 }
291
292 pub async fn get_node_output(&self, node_id: &str) -> Option<WorkflowValue> {
294 let outputs = self.node_outputs.read().await;
295 outputs.get(node_id).cloned()
296 }
297
298 pub async fn get_node_outputs(&self, node_ids: &[&str]) -> HashMap<String, WorkflowValue> {
300 let outputs = self.node_outputs.read().await;
301 node_ids
302 .iter()
303 .filter_map(|id| outputs.get(*id).map(|v| (id.to_string(), v.clone())))
304 .collect()
305 }
306
307 pub async fn set_node_status(&self, node_id: &str, status: NodeStatus) {
309 let mut statuses = self.node_statuses.write().await;
310 statuses.insert(node_id.to_string(), status);
311 }
312
313 pub async fn get_node_status(&self, node_id: &str) -> Option<NodeStatus> {
315 let statuses = self.node_statuses.read().await;
316 statuses.get(node_id).cloned()
317 }
318
319 pub async fn get_all_node_statuses(&self) -> HashMap<String, NodeStatus> {
321 self.node_statuses.read().await.clone()
322 }
323
324 pub async fn set_variable(&self, name: &str, value: WorkflowValue) {
326 let mut vars = self.variables.write().await;
327 vars.insert(name.to_string(), value);
328 }
329
330 pub async fn get_variable(&self, name: &str) -> Option<WorkflowValue> {
332 let vars = self.variables.read().await;
333 vars.get(name).cloned()
334 }
335
336 pub async fn set_custom<T: Send + Sync + 'static>(&self, key: &str, value: T) {
338 let mut data = self.custom_data.write().await;
339 data.insert(key.to_string(), Box::new(value));
340 }
341
342 pub async fn get_custom<T: Clone + Send + Sync + 'static>(&self, key: &str) -> Option<T> {
344 let data = self.custom_data.read().await;
345 data.get(key).and_then(|v| v.downcast_ref::<T>().cloned())
346 }
347
348 pub async fn create_checkpoint(&self, label: &str) {
350 let checkpoint = CheckpointData {
351 label: label.to_string(),
352 timestamp: std::time::SystemTime::now()
353 .duration_since(std::time::UNIX_EPOCH)
354 .unwrap_or_default()
355 .as_millis() as u64,
356 node_outputs: self.node_outputs.read().await.clone(),
357 node_statuses: self.node_statuses.read().await.clone(),
358 variables: self.variables.read().await.clone(),
359 };
360 let mut checkpoints = self.checkpoints.write().await;
361 checkpoints.push(checkpoint);
362 }
363
364 pub async fn restore_checkpoint(&self, label: &str) -> bool {
366 let checkpoints = self.checkpoints.read().await;
367 let checkpoint = checkpoints.iter().rev().find(|c| c.label == label).cloned();
368 drop(checkpoints);
369
370 if let Some(checkpoint) = checkpoint {
371 let mut outputs = self.node_outputs.write().await;
372 *outputs = checkpoint.node_outputs.clone();
373 drop(outputs);
374
375 let mut statuses = self.node_statuses.write().await;
376 *statuses = checkpoint.node_statuses.clone();
377 drop(statuses);
378
379 let mut vars = self.variables.write().await;
380 *vars = checkpoint.variables.clone();
381
382 true
383 } else {
384 false
385 }
386 }
387
388 pub async fn list_checkpoints(&self) -> Vec<String> {
390 let checkpoints = self.checkpoints.read().await;
391 checkpoints.iter().map(|c| c.label.clone()).collect()
392 }
393}
394
395impl Clone for WorkflowContext {
396 fn clone(&self) -> Self {
397 Self {
398 workflow_id: self.workflow_id.clone(),
399 execution_id: self.execution_id.clone(),
400 input: self.input.clone(),
401 node_outputs: self.node_outputs.clone(),
402 node_statuses: self.node_statuses.clone(),
403 variables: self.variables.clone(),
404 custom_data: self.custom_data.clone(),
405 checkpoints: self.checkpoints.clone(),
406 }
407 }
408}
409
/// A named snapshot of the restorable parts of a `WorkflowContext`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointData {
    /// Label the checkpoint was created with; used to look it up on restore.
    pub label: String,
    /// Creation time; `WorkflowContext::create_checkpoint` stores
    /// milliseconds since the Unix epoch.
    pub timestamp: u64,
    /// Node outputs at checkpoint time, keyed by node id.
    pub node_outputs: HashMap<String, WorkflowValue>,
    /// Node statuses at checkpoint time, keyed by node id.
    pub node_statuses: HashMap<String, NodeStatus>,
    /// Variables at checkpoint time, keyed by name.
    pub variables: HashMap<String, WorkflowValue>,
}
424
/// Serializable record of one workflow execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionRecord {
    /// Identifier of the execution.
    pub execution_id: String,
    /// Identifier of the workflow that was executed.
    pub workflow_id: String,
    /// Start time. NOTE(review): unit is not established in this file;
    /// presumably epoch milliseconds like `CheckpointData::timestamp` — confirm.
    pub started_at: u64,
    /// End time in the same unit as `started_at`; `None` presumably means
    /// the execution has not ended yet — confirm.
    pub ended_at: Option<u64>,
    /// Status of the workflow.
    pub status: WorkflowStatus,
    /// Per-node execution records.
    pub node_records: Vec<NodeExecutionRecord>,
}
441
/// Serializable record of one node's execution within a workflow run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeExecutionRecord {
    /// Identifier of the node.
    pub node_id: String,
    /// Start time. NOTE(review): unit is not established in this file;
    /// presumably epoch milliseconds — confirm.
    pub started_at: u64,
    /// End time, in the same unit as `started_at`.
    pub ended_at: u64,
    /// Status of the node.
    pub status: NodeStatus,
    /// Number of retries recorded for the node.
    pub retry_count: u32,
}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips input, node output and a variable through one context,
    /// then checks that restoring a checkpoint rolls the variable back.
    #[tokio::test]
    async fn test_workflow_context() {
        let ctx = WorkflowContext::new("test_workflow");

        ctx.set_input(WorkflowValue::String("test input".to_string()))
            .await;
        let input = ctx.get_input().await;
        assert_eq!(input.as_str(), Some("test input"));

        ctx.set_node_output("node1", WorkflowValue::Int(42)).await;
        let output = ctx.get_node_output("node1").await;
        assert_eq!(output.unwrap().as_i64(), Some(42));

        ctx.set_variable("counter", WorkflowValue::Int(0)).await;
        let var = ctx.get_variable("counter").await;
        assert_eq!(var.unwrap().as_i64(), Some(0));

        ctx.create_checkpoint("before_loop").await;
        ctx.set_variable("counter", WorkflowValue::Int(10)).await;
        // Check the returned flag so a failed restore is reported here, not
        // only indirectly through the variable assertion below.
        assert!(ctx.restore_checkpoint("before_loop").await);
        let var = ctx.get_variable("counter").await;
        assert_eq!(var.unwrap().as_i64(), Some(0));

        // An unknown label reports failure and leaves the state untouched.
        assert!(!ctx.restore_checkpoint("missing").await);
        let var = ctx.get_variable("counter").await;
        assert_eq!(var.unwrap().as_i64(), Some(0));
    }

    /// Checks the `From` conversions against the matching accessors.
    #[test]
    fn test_workflow_value_conversions() {
        let v: WorkflowValue = 42i64.into();
        assert_eq!(v.as_i64(), Some(42));

        let v: WorkflowValue = "hello".into();
        assert_eq!(v.as_str(), Some("hello"));

        let v: WorkflowValue = true.into();
        assert_eq!(v.as_bool(), Some(true));

        // 2.5 is exactly representable and, unlike 3.14, does not trip
        // Clippy's `approx_constant` lint.
        let v: WorkflowValue = 2.5f64.into();
        assert_eq!(v.as_f64(), Some(2.5));
    }
}