Skip to main content

mofa_foundation/workflow/
builder.rs

1//! 工作流构建器
2//!
3//! 提供流式 API 构建工作流
4
5use super::graph::{EdgeConfig, WorkflowGraph};
6use super::node::{RetryPolicy, WorkflowNode};
7use super::state::WorkflowValue;
8use crate::llm::LLMAgent;
9use std::collections::HashMap;
10use std::future::Future;
11use std::sync::Arc;
12
13/// 工作流构建器
14pub struct WorkflowBuilder {
15    graph: WorkflowGraph,
16    current_node: Option<String>,
17}
18
19impl WorkflowBuilder {
20    /// 创建新的工作流构建器
21    pub fn new(id: &str, name: &str) -> Self {
22        Self {
23            graph: WorkflowGraph::new(id, name),
24            current_node: None,
25        }
26    }
27
28    /// 设置描述
29    pub fn description(mut self, desc: &str) -> Self {
30        self.graph = self.graph.with_description(desc);
31        self
32    }
33
34    /// 添加开始节点
35    pub fn start(mut self) -> Self {
36        let node = WorkflowNode::start("start");
37        self.graph.add_node(node);
38        self.current_node = Some("start".to_string());
39        self
40    }
41
42    /// 添加开始节点(自定义 ID)
43    pub fn start_with_id(mut self, id: &str) -> Self {
44        let node = WorkflowNode::start(id);
45        self.graph.add_node(node);
46        self.current_node = Some(id.to_string());
47        self
48    }
49
50    /// 添加结束节点
51    pub fn end(mut self) -> Self {
52        let node = WorkflowNode::end("end");
53        self.graph.add_node(node);
54
55        // 连接当前节点到结束节点
56        if let Some(ref current) = self.current_node {
57            self.graph.connect(current, "end");
58        }
59
60        self.current_node = Some("end".to_string());
61        self
62    }
63
64    /// 添加结束节点(自定义 ID)
65    pub fn end_with_id(mut self, id: &str) -> Self {
66        let node = WorkflowNode::end(id);
67        self.graph.add_node(node);
68
69        if let Some(ref current) = self.current_node {
70            self.graph.connect(current, id);
71        }
72
73        self.current_node = Some(id.to_string());
74        self
75    }
76
77    /// 添加任务节点
78    pub fn task<F, Fut>(mut self, id: &str, name: &str, executor: F) -> Self
79    where
80        F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
81        Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
82    {
83        let node = WorkflowNode::task(id, name, executor);
84        self.graph.add_node(node);
85
86        // 连接当前节点
87        if let Some(ref current) = self.current_node {
88            self.graph.connect(current, id);
89        }
90
91        self.current_node = Some(id.to_string());
92        self
93    }
94
95    /// 添加任务节点(带配置)
96    pub fn task_with_config<F, Fut>(
97        mut self,
98        id: &str,
99        name: &str,
100        executor: F,
101        retry: RetryPolicy,
102        timeout_ms: u64,
103    ) -> Self
104    where
105        F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
106        Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
107    {
108        let node = WorkflowNode::task(id, name, executor)
109            .with_retry(retry)
110            .with_timeout(timeout_ms);
111        self.graph.add_node(node);
112
113        if let Some(ref current) = self.current_node {
114            self.graph.connect(current, id);
115        }
116
117        self.current_node = Some(id.to_string());
118        self
119    }
120
121    /// 添加智能体节点
122    pub fn agent<F, Fut>(mut self, id: &str, name: &str, agent_fn: F) -> Self
123    where
124        F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
125        Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
126    {
127        let node = WorkflowNode::agent(id, name, agent_fn);
128        self.graph.add_node(node);
129
130        if let Some(ref current) = self.current_node {
131            self.graph.connect(current, id);
132        }
133
134        self.current_node = Some(id.to_string());
135        self
136    }
137
138    /// 添加 LLM 智能体节点(使用 LLMAgent)
139    ///
140    /// 允许在工作流中使用预配置的 LLMAgent。
141    ///
142    /// # 示例
143    ///
144    /// ```rust,ignore
145    /// let agent = LLMAgentBuilder::new()
146    ///     .with_id("my-agent")
147    ///     .with_provider(Arc::new(openai_from_env()?))
148    ///     .with_system_prompt("You are a helpful assistant.")
149    ///     .build()?;
150    ///
151    /// let workflow = WorkflowBuilder::new("test", "Test")
152    ///     .start()
153    ///     .llm_agent("agent1", "LLM Agent", Arc::new(agent))
154    ///     .end()
155    ///     .build();
156    /// ```
157    pub fn llm_agent(mut self, id: &str, name: &str, agent: Arc<LLMAgent>) -> Self {
158        let node = WorkflowNode::llm_agent(id, name, agent);
159        self.graph.add_node(node);
160
161        if let Some(ref current) = self.current_node {
162            self.graph.connect(current, id);
163        }
164
165        self.current_node = Some(id.to_string());
166        self
167    }
168
169    /// 添加 LLM 智能体节点(带 prompt 模板)
170    ///
171    /// 允许使用 Jinja-style 模板格式化输入。
172    ///
173    /// # 示例
174    ///
175    /// ```rust,ignore
176    /// let workflow = WorkflowBuilder::new("test", "Test")
177    ///     .start()
178    ///     .llm_agent_with_template(
179    ///         "agent1",
180    ///         "LLM Agent",
181    ///         Arc::new(agent),
182    ///         "Process this data: {{ input }}".to_string()
183    ///     )
184    ///     .end()
185    ///     .build();
186    /// ```
187    pub fn llm_agent_with_template(
188        mut self,
189        id: &str,
190        name: &str,
191        agent: Arc<LLMAgent>,
192        prompt_template: String,
193    ) -> Self {
194        let node = WorkflowNode::llm_agent_with_template(id, name, agent, prompt_template);
195        self.graph.add_node(node);
196
197        if let Some(ref current) = self.current_node {
198            self.graph.connect(current, id);
199        }
200
201        self.current_node = Some(id.to_string());
202        self
203    }
204
205    /// 添加条件节点
206    pub fn condition<F, Fut>(mut self, id: &str, name: &str, condition_fn: F) -> ConditionBuilder
207    where
208        F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
209        Fut: Future<Output = bool> + Send + 'static,
210    {
211        let node = WorkflowNode::condition(id, name, condition_fn);
212        self.graph.add_node(node);
213
214        if let Some(ref current) = self.current_node {
215            self.graph.connect(current, id);
216        }
217
218        ConditionBuilder {
219            parent: self,
220            condition_node: id.to_string(),
221            true_branch: None,
222            false_branch: None,
223        }
224    }
225
226    /// 添加并行节点
227    pub fn parallel(mut self, id: &str, name: &str) -> ParallelBuilder {
228        let node = WorkflowNode::parallel(id, name, vec![]);
229        self.graph.add_node(node);
230
231        if let Some(ref current) = self.current_node {
232            self.graph.connect(current, id);
233        }
234
235        ParallelBuilder {
236            parent: self,
237            parallel_node: id.to_string(),
238            branches: Vec::new(),
239        }
240    }
241
242    /// 添加循环节点
243    pub fn loop_node<F, Fut, C, CFut>(
244        mut self,
245        id: &str,
246        name: &str,
247        body: F,
248        condition: C,
249        max_iterations: u32,
250    ) -> Self
251    where
252        F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
253        Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
254        C: Fn(super::state::WorkflowContext, WorkflowValue) -> CFut + Send + Sync + 'static,
255        CFut: Future<Output = bool> + Send + 'static,
256    {
257        let node = WorkflowNode::loop_node(id, name, body, condition, max_iterations);
258        self.graph.add_node(node);
259
260        if let Some(ref current) = self.current_node {
261            self.graph.connect(current, id);
262        }
263
264        self.current_node = Some(id.to_string());
265        self
266    }
267
268    /// 添加子工作流节点
269    pub fn sub_workflow(mut self, id: &str, name: &str, sub_workflow_id: &str) -> Self {
270        let node = WorkflowNode::sub_workflow(id, name, sub_workflow_id);
271        self.graph.add_node(node);
272
273        if let Some(ref current) = self.current_node {
274            self.graph.connect(current, id);
275        }
276
277        self.current_node = Some(id.to_string());
278        self
279    }
280
281    /// 添加等待节点
282    pub fn wait(mut self, id: &str, name: &str, event_type: &str) -> Self {
283        let node = WorkflowNode::wait(id, name, event_type);
284        self.graph.add_node(node);
285
286        if let Some(ref current) = self.current_node {
287            self.graph.connect(current, id);
288        }
289
290        self.current_node = Some(id.to_string());
291        self
292    }
293
294    /// 添加数据转换节点
295    pub fn transform<F, Fut>(mut self, id: &str, name: &str, transform_fn: F) -> Self
296    where
297        F: Fn(HashMap<String, WorkflowValue>) -> Fut + Send + Sync + 'static,
298        Fut: Future<Output = WorkflowValue> + Send + 'static,
299    {
300        let node = WorkflowNode::transform(id, name, transform_fn);
301        self.graph.add_node(node);
302
303        if let Some(ref current) = self.current_node {
304            self.graph.connect(current, id);
305        }
306
307        self.current_node = Some(id.to_string());
308        self
309    }
310
311    /// 添加自定义节点
312    pub fn node(mut self, node: WorkflowNode) -> Self {
313        let node_id = node.id().to_string();
314        self.graph.add_node(node);
315
316        if let Some(ref current) = self.current_node {
317            self.graph.connect(current, &node_id);
318        }
319
320        self.current_node = Some(node_id);
321        self
322    }
323
324    /// 添加边(不改变当前节点)
325    pub fn edge(mut self, from: &str, to: &str) -> Self {
326        self.graph.connect(from, to);
327        self
328    }
329
330    /// 添加条件边
331    pub fn conditional_edge(mut self, from: &str, to: &str, condition: &str) -> Self {
332        self.graph.connect_conditional(from, to, condition);
333        self
334    }
335
336    /// 添加错误处理边
337    pub fn error_edge(mut self, from: &str, to: &str) -> Self {
338        self.graph.add_edge(EdgeConfig::error(from, to));
339        self
340    }
341
342    /// 跳转到指定节点(设置当前节点)
343    pub fn goto(mut self, node_id: &str) -> Self {
344        self.current_node = Some(node_id.to_string());
345        self
346    }
347
348    /// 从当前节点连接到指定节点
349    pub fn then(mut self, node_id: &str) -> Self {
350        if let Some(ref current) = self.current_node {
351            self.graph.connect(current, node_id);
352        }
353        self.current_node = Some(node_id.to_string());
354        self
355    }
356
357    /// 构建工作流图
358    pub fn build(self) -> WorkflowGraph {
359        self.graph
360    }
361
362    /// 验证并构建
363    pub fn build_validated(self) -> Result<WorkflowGraph, Vec<String>> {
364        self.graph.validate()?;
365        Ok(self.graph)
366    }
367}
368
369/// 条件构建器
370pub struct ConditionBuilder {
371    parent: WorkflowBuilder,
372    condition_node: String,
373    true_branch: Option<String>,
374    false_branch: Option<String>,
375}
376
377impl ConditionBuilder {
378    /// 设置为真时的分支
379    pub fn on_true<F, Fut>(mut self, id: &str, name: &str, executor: F) -> Self
380    where
381        F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
382        Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
383    {
384        let node = WorkflowNode::task(id, name, executor);
385        self.parent.graph.add_node(node);
386        self.parent
387            .graph
388            .connect_conditional(&self.condition_node, id, "true");
389        self.true_branch = Some(id.to_string());
390        self
391    }
392
393    /// 设置为假时的分支
394    pub fn on_false<F, Fut>(mut self, id: &str, name: &str, executor: F) -> Self
395    where
396        F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
397        Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
398    {
399        let node = WorkflowNode::task(id, name, executor);
400        self.parent.graph.add_node(node);
401        self.parent
402            .graph
403            .connect_conditional(&self.condition_node, id, "false");
404        self.false_branch = Some(id.to_string());
405        self
406    }
407
408    /// 汇聚两个分支
409    pub fn merge(mut self, id: &str, name: &str) -> WorkflowBuilder {
410        let node = WorkflowNode::join(
411            id,
412            name,
413            vec![
414                self.true_branch.as_deref().unwrap_or(""),
415                self.false_branch.as_deref().unwrap_or(""),
416            ]
417            .into_iter()
418            .filter(|s| !s.is_empty())
419            .collect(),
420        );
421        self.parent.graph.add_node(node);
422
423        if let Some(ref true_branch) = self.true_branch {
424            self.parent.graph.connect(true_branch, id);
425        }
426        if let Some(ref false_branch) = self.false_branch {
427            self.parent.graph.connect(false_branch, id);
428        }
429
430        self.parent.current_node = Some(id.to_string());
431        self.parent
432    }
433
434    /// 不汇聚,返回构建器
435    pub fn end_condition(mut self) -> WorkflowBuilder {
436        // 设置当前节点为最后添加的分支
437        self.parent.current_node = self.true_branch.or(self.false_branch);
438        self.parent
439    }
440}
441
442/// 并行构建器
443pub struct ParallelBuilder {
444    parent: WorkflowBuilder,
445    parallel_node: String,
446    branches: Vec<String>,
447}
448
449impl ParallelBuilder {
450    /// 添加分支任务
451    pub fn branch<F, Fut>(mut self, id: &str, name: &str, executor: F) -> Self
452    where
453        F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
454        Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
455    {
456        let node = WorkflowNode::task(id, name, executor);
457        self.parent.graph.add_node(node);
458        self.parent.graph.connect(&self.parallel_node, id);
459        self.branches.push(id.to_string());
460        self
461    }
462
463    /// 添加分支智能体
464    pub fn branch_agent<F, Fut>(mut self, id: &str, name: &str, agent_fn: F) -> Self
465    where
466        F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
467        Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
468    {
469        let node = WorkflowNode::agent(id, name, agent_fn);
470        self.parent.graph.add_node(node);
471        self.parent.graph.connect(&self.parallel_node, id);
472        self.branches.push(id.to_string());
473        self
474    }
475
476    /// 添加 LLM 智能体分支
477    ///
478    /// 允许在并行执行中使用预配置的 LLMAgent。
479    ///
480    /// # 示例
481    ///
482    /// ```rust,ignore
483    /// let workflow = WorkflowBuilder::new("test", "Test")
484    ///     .start()
485    ///     .parallel("fork", "Fork")
486    ///     .llm_agent_branch("agent_a", "Agent A", Arc::new(agent_a))
487    ///     .llm_agent_branch("agent_b", "Agent B", Arc::new(agent_b))
488    ///     .join("join", "Join")
489    ///     .end()
490    ///     .build();
491    /// ```
492    pub fn llm_agent_branch(mut self, id: &str, name: &str, agent: Arc<LLMAgent>) -> Self {
493        let node = WorkflowNode::llm_agent(id, name, agent);
494        self.parent.graph.add_node(node);
495        self.parent.graph.connect(&self.parallel_node, id);
496        self.branches.push(id.to_string());
497        self
498    }
499
500    /// 汇聚所有分支
501    pub fn join(mut self, id: &str, name: &str) -> WorkflowBuilder {
502        let node = WorkflowNode::join(id, name, self.branches.iter().map(|s| s.as_str()).collect());
503        self.parent.graph.add_node(node);
504
505        for branch in &self.branches {
506            self.parent.graph.connect(branch, id);
507        }
508
509        self.parent.current_node = Some(id.to_string());
510        self.parent
511    }
512
513    /// 汇聚并转换
514    pub fn join_with_transform<F, Fut>(
515        mut self,
516        id: &str,
517        name: &str,
518        transform: F,
519    ) -> WorkflowBuilder
520    where
521        F: Fn(HashMap<String, WorkflowValue>) -> Fut + Send + Sync + 'static,
522        Fut: Future<Output = WorkflowValue> + Send + 'static,
523    {
524        let node = WorkflowNode::join_with_transform(
525            id,
526            name,
527            self.branches.iter().map(|s| s.as_str()).collect(),
528            transform,
529        );
530        self.parent.graph.add_node(node);
531
532        for branch in &self.branches {
533            self.parent.graph.connect(branch, id);
534        }
535
536        self.parent.current_node = Some(id.to_string());
537        self.parent
538    }
539}
540
/// Shorthand for `WorkflowBuilder::new(id, name) <chain> .build()`.
///
/// The tokens between the braces are pasted verbatim between the constructor
/// call and `.build()`. They should therefore be a sequence of builder method
/// calls, each starting with `.`:
///
/// ```rust,ignore
/// let graph = workflow!("id", "Name" => {
///     .start()
///     .task("t", "Task", |_ctx, input| async move { Ok(input) })
///     .end()
/// });
/// ```
///
/// The expansion names `WorkflowBuilder` by its bare identifier, so it must be
/// in scope at the call site.
#[macro_export]
macro_rules! workflow {
    ($workflow_id:expr, $workflow_name:expr => { $($chain:tt)* }) => {
        WorkflowBuilder::new($workflow_id, $workflow_name) $($chain)* .build()
    };
}
552
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workflow_builder() {
        let graph = WorkflowBuilder::new("test", "Test Workflow")
            .start()
            .task("task1", "Task 1", |_ctx, input| async move { Ok(input) })
            .task("task2", "Task 2", |_ctx, input| async move { Ok(input) })
            .end()
            .build();

        assert_eq!(graph.node_count(), 4);
        assert_eq!(graph.edge_count(), 3);
    }

    #[test]
    fn test_custom_start_and_end_ids() {
        // begin -> task1 -> finish
        let graph = WorkflowBuilder::new("test", "Custom Ids")
            .start_with_id("begin")
            .task("task1", "Task 1", |_ctx, input| async move { Ok(input) })
            .end_with_id("finish")
            .build();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);
    }

    #[test]
    fn test_goto_starts_new_chain() {
        // start -> a -> end, then a second chain off `start`: start -> b
        let graph = WorkflowBuilder::new("test", "Goto Workflow")
            .start()
            .task("a", "A", |_ctx, input| async move { Ok(input) })
            .end()
            .goto("start")
            .task("b", "B", |_ctx, input| async move { Ok(input) })
            .build();

        assert_eq!(graph.node_count(), 4);
        assert_eq!(graph.edge_count(), 3);
    }

    #[test]
    fn test_condition_builder() {
        let graph = WorkflowBuilder::new("test", "Conditional Workflow")
            .start()
            .condition("check", "Check", |_ctx, input| async move {
                input.as_i64().unwrap_or(0) > 10
            })
            .on_true("high", "High", |_ctx, _input| async move {
                Ok(WorkflowValue::String("high".to_string()))
            })
            .on_false("low", "Low", |_ctx, _input| async move {
                Ok(WorkflowValue::String("low".to_string()))
            })
            .merge("merge", "Merge")
            .end()
            .build();

        assert_eq!(graph.node_count(), 6);
    }

    #[test]
    fn test_condition_without_merge() {
        // start, check, high, low, end — no join node is added.
        let graph = WorkflowBuilder::new("test", "Unmerged Conditional")
            .start()
            .condition("check", "Check", |_ctx, input| async move {
                input.as_i64().unwrap_or(0) > 10
            })
            .on_true("high", "High", |_ctx, _input| async move {
                Ok(WorkflowValue::String("high".to_string()))
            })
            .on_false("low", "Low", |_ctx, _input| async move {
                Ok(WorkflowValue::String("low".to_string()))
            })
            .end_condition()
            .end()
            .build();

        assert_eq!(graph.node_count(), 5);
    }

    #[test]
    fn test_parallel_builder() {
        let graph = WorkflowBuilder::new("test", "Parallel Workflow")
            .start()
            .parallel("fork", "Fork")
            .branch("a", "Branch A", |_ctx, _input| async move {
                Ok(WorkflowValue::String("a".to_string()))
            })
            .branch("b", "Branch B", |_ctx, _input| async move {
                Ok(WorkflowValue::String("b".to_string()))
            })
            .branch("c", "Branch C", |_ctx, _input| async move {
                Ok(WorkflowValue::String("c".to_string()))
            })
            .join("join", "Join")
            .end()
            .build();

        assert_eq!(graph.node_count(), 7);
        // start->fork, fork->{a,b,c}, {a,b,c}->join, join->end
        assert_eq!(graph.edge_count(), 8);
    }
}