Skip to main content

mofa_foundation/workflow/
state.rs

1//! 工作流状态管理
2//!
3//! 管理工作流执行过程中的状态和数据传递
4
5use serde::{Deserialize, Serialize};
6use std::any::Any;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
/// A dynamically typed data value passed between workflow nodes.
///
/// Serialized with `#[serde(untagged)]`: each variant is written as its bare
/// payload, with no variant tag.
///
/// NOTE(review): for untagged enums serde tries variants in declaration order
/// when deserializing, so the order below is load-bearing and must not be
/// shuffled. A likely consequence: a JSON array whose elements all fit in `u8`
/// (including `[]`) comes back as `Bytes` rather than `List`. `Json` is
/// presumably never produced by deserialization, because `List`/`Map` accept
/// any array/object first. Confirm before relying on round-trips, e.g.
/// through `CheckpointData`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum WorkflowValue {
    /// Absence of a value.
    Null,
    Bool(bool),
    Int(i64),
    Float(f64),
    String(String),
    /// Raw binary data.
    Bytes(Vec<u8>),
    /// Ordered list of nested values.
    List(Vec<WorkflowValue>),
    /// String-keyed map of nested values.
    Map(HashMap<String, WorkflowValue>),
    /// Arbitrary JSON value kept as-is.
    Json(serde_json::Value),
}
25
26impl WorkflowValue {
27    pub fn is_null(&self) -> bool {
28        matches!(self, WorkflowValue::Null)
29    }
30
31    pub fn as_bool(&self) -> Option<bool> {
32        match self {
33            WorkflowValue::Bool(b) => Some(*b),
34            _ => None,
35        }
36    }
37
38    pub fn as_i64(&self) -> Option<i64> {
39        match self {
40            WorkflowValue::Int(i) => Some(*i),
41            _ => None,
42        }
43    }
44
45    pub fn as_f64(&self) -> Option<f64> {
46        match self {
47            WorkflowValue::Float(f) => Some(*f),
48            WorkflowValue::Int(i) => Some(*i as f64),
49            _ => None,
50        }
51    }
52
53    pub fn as_str(&self) -> Option<&str> {
54        match self {
55            WorkflowValue::String(s) => Some(s),
56            _ => None,
57        }
58    }
59
60    pub fn as_bytes(&self) -> Option<&[u8]> {
61        match self {
62            WorkflowValue::Bytes(b) => Some(b),
63            _ => None,
64        }
65    }
66
67    pub fn as_list(&self) -> Option<&Vec<WorkflowValue>> {
68        match self {
69            WorkflowValue::List(l) => Some(l),
70            _ => None,
71        }
72    }
73
74    pub fn as_map(&self) -> Option<&HashMap<String, WorkflowValue>> {
75        match self {
76            WorkflowValue::Map(m) => Some(m),
77            _ => None,
78        }
79    }
80
81    pub fn as_json(&self) -> Option<&serde_json::Value> {
82        match self {
83            WorkflowValue::Json(j) => Some(j),
84            _ => None,
85        }
86    }
87}
88
89impl From<bool> for WorkflowValue {
90    fn from(v: bool) -> Self {
91        WorkflowValue::Bool(v)
92    }
93}
94
95impl From<i64> for WorkflowValue {
96    fn from(v: i64) -> Self {
97        WorkflowValue::Int(v)
98    }
99}
100
101impl From<i32> for WorkflowValue {
102    fn from(v: i32) -> Self {
103        WorkflowValue::Int(v as i64)
104    }
105}
106
107impl From<f64> for WorkflowValue {
108    fn from(v: f64) -> Self {
109        WorkflowValue::Float(v)
110    }
111}
112
113impl From<String> for WorkflowValue {
114    fn from(v: String) -> Self {
115        WorkflowValue::String(v)
116    }
117}
118
119impl From<&str> for WorkflowValue {
120    fn from(v: &str) -> Self {
121        WorkflowValue::String(v.to_string())
122    }
123}
124
125impl From<Vec<u8>> for WorkflowValue {
126    fn from(v: Vec<u8>) -> Self {
127        WorkflowValue::Bytes(v)
128    }
129}
130
131impl From<serde_json::Value> for WorkflowValue {
132    fn from(v: serde_json::Value) -> Self {
133        WorkflowValue::Json(v)
134    }
135}
136
/// Execution state of a single workflow node.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Not yet started.
    Pending,
    /// Waiting for its dependencies to finish.
    Waiting,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error; carries the error message.
    Failed(String),
    /// Skipped because its condition was not met.
    Skipped,
    /// Cancelled.
    Cancelled,
}
155
156impl NodeStatus {
157    pub fn is_terminal(&self) -> bool {
158        matches!(
159            self,
160            NodeStatus::Completed
161                | NodeStatus::Failed(_)
162                | NodeStatus::Skipped
163                | NodeStatus::Cancelled
164        )
165    }
166
167    pub fn is_success(&self) -> bool {
168        matches!(self, NodeStatus::Completed | NodeStatus::Skipped)
169    }
170}
171
/// Overall execution state of a workflow run.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkflowStatus {
    /// Not started yet.
    NotStarted,
    /// Running.
    Running,
    /// Paused.
    Paused,
    /// Completed.
    Completed,
    /// Failed. The `String` presumably holds the error message — confirm with the executor.
    Failed(String),
    /// Cancelled.
    Cancelled,
}
188
/// Result of executing a single node.
#[derive(Debug, Clone)]
pub struct NodeResult {
    /// Node ID.
    pub node_id: String,
    /// Execution status.
    pub status: NodeStatus,
    /// Output data (`WorkflowValue::Null` from the `failed`/`skipped` constructors).
    pub output: WorkflowValue,
    /// Execution duration in milliseconds.
    pub duration_ms: u64,
    /// Retry count (the constructors in this file always start it at 0).
    pub retry_count: u32,
    /// Error message, if any.
    pub error: Option<String>,
}
205
206impl NodeResult {
207    pub fn success(node_id: &str, output: WorkflowValue, duration_ms: u64) -> Self {
208        Self {
209            node_id: node_id.to_string(),
210            status: NodeStatus::Completed,
211            output,
212            duration_ms,
213            retry_count: 0,
214            error: None,
215        }
216    }
217
218    pub fn failed(node_id: &str, error: &str, duration_ms: u64) -> Self {
219        Self {
220            node_id: node_id.to_string(),
221            status: NodeStatus::Failed(error.to_string()),
222            output: WorkflowValue::Null,
223            duration_ms,
224            retry_count: 0,
225            error: Some(error.to_string()),
226        }
227    }
228
229    pub fn skipped(node_id: &str) -> Self {
230        Self {
231            node_id: node_id.to_string(),
232            status: NodeStatus::Skipped,
233            output: WorkflowValue::Null,
234            duration_ms: 0,
235            retry_count: 0,
236            error: None,
237        }
238    }
239}
240
/// Workflow context: shared state used to pass data between nodes.
///
/// Every store sits behind an `Arc<RwLock<..>>` (tokio's async `RwLock`).
/// Cloning a context therefore yields another handle to the same state rather
/// than a deep copy.
pub struct WorkflowContext {
    /// Workflow ID.
    pub workflow_id: String,
    /// Execution ID, unique per run (a fresh UUID v7 is generated in `new`).
    pub execution_id: String,
    /// Workflow input data.
    input: Arc<RwLock<WorkflowValue>>,
    /// Output of each node, keyed by node ID.
    node_outputs: Arc<RwLock<HashMap<String, WorkflowValue>>>,
    /// Status of each node, keyed by node ID.
    node_statuses: Arc<RwLock<HashMap<String, NodeStatus>>>,
    /// Global variables.
    variables: Arc<RwLock<HashMap<String, WorkflowValue>>>,
    /// Arbitrary typed data, read back by downcasting; not captured by checkpoints.
    custom_data: Arc<RwLock<HashMap<String, Box<dyn Any + Send + Sync>>>>,
    /// Saved checkpoints, in creation order.
    checkpoints: Arc<RwLock<Vec<CheckpointData>>>,
}
260
261impl WorkflowContext {
262    pub fn new(workflow_id: &str) -> Self {
263        Self {
264            workflow_id: workflow_id.to_string(),
265            execution_id: uuid::Uuid::now_v7().to_string(),
266            input: Arc::new(RwLock::new(WorkflowValue::Null)),
267            node_outputs: Arc::new(RwLock::new(HashMap::new())),
268            node_statuses: Arc::new(RwLock::new(HashMap::new())),
269            variables: Arc::new(RwLock::new(HashMap::new())),
270            custom_data: Arc::new(RwLock::new(HashMap::new())),
271            checkpoints: Arc::new(RwLock::new(Vec::new())),
272        }
273    }
274
275    /// 设置工作流输入
276    pub async fn set_input(&self, input: WorkflowValue) {
277        let mut i = self.input.write().await;
278        *i = input;
279    }
280
281    /// 获取工作流输入
282    pub async fn get_input(&self) -> WorkflowValue {
283        self.input.read().await.clone()
284    }
285
286    /// 设置节点输出
287    pub async fn set_node_output(&self, node_id: &str, output: WorkflowValue) {
288        let mut outputs = self.node_outputs.write().await;
289        outputs.insert(node_id.to_string(), output);
290    }
291
292    /// 获取节点输出
293    pub async fn get_node_output(&self, node_id: &str) -> Option<WorkflowValue> {
294        let outputs = self.node_outputs.read().await;
295        outputs.get(node_id).cloned()
296    }
297
298    /// 获取多个节点的输出
299    pub async fn get_node_outputs(&self, node_ids: &[&str]) -> HashMap<String, WorkflowValue> {
300        let outputs = self.node_outputs.read().await;
301        node_ids
302            .iter()
303            .filter_map(|id| outputs.get(*id).map(|v| (id.to_string(), v.clone())))
304            .collect()
305    }
306
307    /// 设置节点状态
308    pub async fn set_node_status(&self, node_id: &str, status: NodeStatus) {
309        let mut statuses = self.node_statuses.write().await;
310        statuses.insert(node_id.to_string(), status);
311    }
312
313    /// 获取节点状态
314    pub async fn get_node_status(&self, node_id: &str) -> Option<NodeStatus> {
315        let statuses = self.node_statuses.read().await;
316        statuses.get(node_id).cloned()
317    }
318
319    /// 获取所有节点状态
320    pub async fn get_all_node_statuses(&self) -> HashMap<String, NodeStatus> {
321        self.node_statuses.read().await.clone()
322    }
323
324    /// 设置变量
325    pub async fn set_variable(&self, name: &str, value: WorkflowValue) {
326        let mut vars = self.variables.write().await;
327        vars.insert(name.to_string(), value);
328    }
329
330    /// 获取变量
331    pub async fn get_variable(&self, name: &str) -> Option<WorkflowValue> {
332        let vars = self.variables.read().await;
333        vars.get(name).cloned()
334    }
335
336    /// 设置自定义数据
337    pub async fn set_custom<T: Send + Sync + 'static>(&self, key: &str, value: T) {
338        let mut data = self.custom_data.write().await;
339        data.insert(key.to_string(), Box::new(value));
340    }
341
342    /// 获取自定义数据
343    pub async fn get_custom<T: Clone + Send + Sync + 'static>(&self, key: &str) -> Option<T> {
344        let data = self.custom_data.read().await;
345        data.get(key).and_then(|v| v.downcast_ref::<T>().cloned())
346    }
347
348    /// 创建检查点
349    pub async fn create_checkpoint(&self, label: &str) {
350        let checkpoint = CheckpointData {
351            label: label.to_string(),
352            timestamp: std::time::SystemTime::now()
353                .duration_since(std::time::UNIX_EPOCH)
354                .unwrap_or_default()
355                .as_millis() as u64,
356            node_outputs: self.node_outputs.read().await.clone(),
357            node_statuses: self.node_statuses.read().await.clone(),
358            variables: self.variables.read().await.clone(),
359        };
360        let mut checkpoints = self.checkpoints.write().await;
361        checkpoints.push(checkpoint);
362    }
363
364    /// 恢复到检查点
365    pub async fn restore_checkpoint(&self, label: &str) -> bool {
366        let checkpoints = self.checkpoints.read().await;
367        let checkpoint = checkpoints.iter().rev().find(|c| c.label == label).cloned();
368        drop(checkpoints);
369
370        if let Some(checkpoint) = checkpoint {
371            let mut outputs = self.node_outputs.write().await;
372            *outputs = checkpoint.node_outputs.clone();
373            drop(outputs);
374
375            let mut statuses = self.node_statuses.write().await;
376            *statuses = checkpoint.node_statuses.clone();
377            drop(statuses);
378
379            let mut vars = self.variables.write().await;
380            *vars = checkpoint.variables.clone();
381
382            true
383        } else {
384            false
385        }
386    }
387
388    /// 获取所有检查点标签
389    pub async fn list_checkpoints(&self) -> Vec<String> {
390        let checkpoints = self.checkpoints.read().await;
391        checkpoints.iter().map(|c| c.label.clone()).collect()
392    }
393}
394
395impl Clone for WorkflowContext {
396    fn clone(&self) -> Self {
397        Self {
398            workflow_id: self.workflow_id.clone(),
399            execution_id: self.execution_id.clone(),
400            input: self.input.clone(),
401            node_outputs: self.node_outputs.clone(),
402            node_statuses: self.node_statuses.clone(),
403            variables: self.variables.clone(),
404            custom_data: self.custom_data.clone(),
405            checkpoints: self.checkpoints.clone(),
406        }
407    }
408}
409
/// Checkpoint data: a snapshot of a context's node outputs, node statuses and variables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointData {
    /// Checkpoint label (not required to be unique).
    pub label: String,
    /// Creation timestamp. `WorkflowContext::create_checkpoint` fills it with
    /// milliseconds since the Unix epoch.
    pub timestamp: u64,
    /// Snapshot of node outputs.
    pub node_outputs: HashMap<String, WorkflowValue>,
    /// Snapshot of node statuses.
    pub node_statuses: HashMap<String, NodeStatus>,
    /// Snapshot of variables.
    pub variables: HashMap<String, WorkflowValue>,
}
424
/// History record of one workflow execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionRecord {
    /// Execution ID.
    pub execution_id: String,
    /// Workflow ID.
    pub workflow_id: String,
    /// Start time. NOTE(review): the unit is not fixed here — presumably Unix
    /// epoch milliseconds like `CheckpointData::timestamp`; confirm with the producer.
    pub started_at: u64,
    /// End time; presumably `None` while the execution is still running — confirm.
    pub ended_at: Option<u64>,
    /// Final status.
    pub status: WorkflowStatus,
    /// Per-node execution records.
    pub node_records: Vec<NodeExecutionRecord>,
}
441
/// Execution record of a single node, as stored in `ExecutionRecord::node_records`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeExecutionRecord {
    /// Node ID.
    pub node_id: String,
    /// Start time. NOTE(review): unit not fixed here — presumably epoch milliseconds; confirm.
    pub started_at: u64,
    /// End time (presumably the same unit as `started_at`).
    pub ended_at: u64,
    /// Execution status.
    pub status: NodeStatus,
    /// Retry count.
    pub retry_count: u32,
}
456
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_workflow_context() {
        let ctx = WorkflowContext::new("test_workflow");

        // Input round-trip.
        ctx.set_input(WorkflowValue::String("test input".to_string()))
            .await;
        let input = ctx.get_input().await;
        assert_eq!(input.as_str(), Some("test input"));

        // Node output round-trip.
        ctx.set_node_output("node1", WorkflowValue::Int(42)).await;
        let output = ctx.get_node_output("node1").await;
        assert_eq!(output.unwrap().as_i64(), Some(42));

        // Node status round-trip.
        ctx.set_node_status("node1", NodeStatus::Completed).await;
        assert_eq!(
            ctx.get_node_status("node1").await,
            Some(NodeStatus::Completed)
        );

        // Variable round-trip.
        ctx.set_variable("counter", WorkflowValue::Int(0)).await;
        let var = ctx.get_variable("counter").await;
        assert_eq!(var.unwrap().as_i64(), Some(0));

        // Restoring a checkpoint rolls back variables and node outputs
        // written after it was taken.
        ctx.create_checkpoint("before_loop").await;
        ctx.set_variable("counter", WorkflowValue::Int(10)).await;
        ctx.set_node_output("node2", WorkflowValue::Int(7)).await;
        assert!(ctx.restore_checkpoint("before_loop").await);
        let var = ctx.get_variable("counter").await;
        assert_eq!(var.unwrap().as_i64(), Some(0));
        assert!(ctx.get_node_output("node2").await.is_none());

        // Unknown labels report failure and leave the state untouched.
        assert!(!ctx.restore_checkpoint("missing").await);
        assert_eq!(
            ctx.get_variable("counter").await.unwrap().as_i64(),
            Some(0)
        );
        assert_eq!(
            ctx.list_checkpoints().await,
            vec!["before_loop".to_string()]
        );
    }

    #[tokio::test]
    async fn test_custom_data_downcast() {
        let ctx = WorkflowContext::new("custom");
        ctx.set_custom("answer", 7u32).await;
        assert_eq!(ctx.get_custom::<u32>("answer").await, Some(7));
        // A wrong type and a missing key both yield None.
        assert_eq!(ctx.get_custom::<String>("answer").await, None);
        assert_eq!(ctx.get_custom::<u32>("missing").await, None);
    }

    #[test]
    fn test_workflow_value_conversions() {
        let v: WorkflowValue = 42i64.into();
        assert_eq!(v.as_i64(), Some(42));

        let v: WorkflowValue = 7i32.into();
        assert_eq!(v.as_i64(), Some(7));

        let v: WorkflowValue = "hello".into();
        assert_eq!(v.as_str(), Some("hello"));

        let v: WorkflowValue = true.into();
        assert_eq!(v.as_bool(), Some(true));

        // 2.5 is exactly representable; a literal like 3.14 trips clippy's
        // deny-by-default `approx_constant` lint.
        let v: WorkflowValue = 2.5f64.into();
        assert_eq!(v.as_f64(), Some(2.5));

        // Integers widen to f64, but floats are not narrowed to i64.
        assert_eq!(WorkflowValue::Int(3).as_f64(), Some(3.0));
        assert_eq!(WorkflowValue::Float(3.0).as_i64(), None);
    }

    #[test]
    fn test_node_status_predicates() {
        assert!(NodeStatus::Failed("boom".to_string()).is_terminal());
        assert!(!NodeStatus::Running.is_terminal());
        assert!(NodeStatus::Skipped.is_success());
        assert!(!NodeStatus::Cancelled.is_success());
    }
}