1use super::graph::{EdgeConfig, WorkflowGraph};
6use super::node::{RetryPolicy, WorkflowNode};
7use super::state::WorkflowValue;
8use crate::llm::LLMAgent;
9use std::collections::HashMap;
10use std::future::Future;
11use std::sync::Arc;
12
13pub struct WorkflowBuilder {
15 graph: WorkflowGraph,
16 current_node: Option<String>,
17}
18
19impl WorkflowBuilder {
20 pub fn new(id: &str, name: &str) -> Self {
22 Self {
23 graph: WorkflowGraph::new(id, name),
24 current_node: None,
25 }
26 }
27
28 pub fn description(mut self, desc: &str) -> Self {
30 self.graph = self.graph.with_description(desc);
31 self
32 }
33
34 pub fn start(mut self) -> Self {
36 let node = WorkflowNode::start("start");
37 self.graph.add_node(node);
38 self.current_node = Some("start".to_string());
39 self
40 }
41
42 pub fn start_with_id(mut self, id: &str) -> Self {
44 let node = WorkflowNode::start(id);
45 self.graph.add_node(node);
46 self.current_node = Some(id.to_string());
47 self
48 }
49
50 pub fn end(mut self) -> Self {
52 let node = WorkflowNode::end("end");
53 self.graph.add_node(node);
54
55 if let Some(ref current) = self.current_node {
57 self.graph.connect(current, "end");
58 }
59
60 self.current_node = Some("end".to_string());
61 self
62 }
63
64 pub fn end_with_id(mut self, id: &str) -> Self {
66 let node = WorkflowNode::end(id);
67 self.graph.add_node(node);
68
69 if let Some(ref current) = self.current_node {
70 self.graph.connect(current, id);
71 }
72
73 self.current_node = Some(id.to_string());
74 self
75 }
76
77 pub fn task<F, Fut>(mut self, id: &str, name: &str, executor: F) -> Self
79 where
80 F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
81 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
82 {
83 let node = WorkflowNode::task(id, name, executor);
84 self.graph.add_node(node);
85
86 if let Some(ref current) = self.current_node {
88 self.graph.connect(current, id);
89 }
90
91 self.current_node = Some(id.to_string());
92 self
93 }
94
95 pub fn task_with_config<F, Fut>(
97 mut self,
98 id: &str,
99 name: &str,
100 executor: F,
101 retry: RetryPolicy,
102 timeout_ms: u64,
103 ) -> Self
104 where
105 F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
106 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
107 {
108 let node = WorkflowNode::task(id, name, executor)
109 .with_retry(retry)
110 .with_timeout(timeout_ms);
111 self.graph.add_node(node);
112
113 if let Some(ref current) = self.current_node {
114 self.graph.connect(current, id);
115 }
116
117 self.current_node = Some(id.to_string());
118 self
119 }
120
121 pub fn agent<F, Fut>(mut self, id: &str, name: &str, agent_fn: F) -> Self
123 where
124 F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
125 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
126 {
127 let node = WorkflowNode::agent(id, name, agent_fn);
128 self.graph.add_node(node);
129
130 if let Some(ref current) = self.current_node {
131 self.graph.connect(current, id);
132 }
133
134 self.current_node = Some(id.to_string());
135 self
136 }
137
138 pub fn llm_agent(mut self, id: &str, name: &str, agent: Arc<LLMAgent>) -> Self {
158 let node = WorkflowNode::llm_agent(id, name, agent);
159 self.graph.add_node(node);
160
161 if let Some(ref current) = self.current_node {
162 self.graph.connect(current, id);
163 }
164
165 self.current_node = Some(id.to_string());
166 self
167 }
168
169 pub fn llm_agent_with_template(
188 mut self,
189 id: &str,
190 name: &str,
191 agent: Arc<LLMAgent>,
192 prompt_template: String,
193 ) -> Self {
194 let node = WorkflowNode::llm_agent_with_template(id, name, agent, prompt_template);
195 self.graph.add_node(node);
196
197 if let Some(ref current) = self.current_node {
198 self.graph.connect(current, id);
199 }
200
201 self.current_node = Some(id.to_string());
202 self
203 }
204
205 pub fn condition<F, Fut>(mut self, id: &str, name: &str, condition_fn: F) -> ConditionBuilder
207 where
208 F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
209 Fut: Future<Output = bool> + Send + 'static,
210 {
211 let node = WorkflowNode::condition(id, name, condition_fn);
212 self.graph.add_node(node);
213
214 if let Some(ref current) = self.current_node {
215 self.graph.connect(current, id);
216 }
217
218 ConditionBuilder {
219 parent: self,
220 condition_node: id.to_string(),
221 true_branch: None,
222 false_branch: None,
223 }
224 }
225
226 pub fn parallel(mut self, id: &str, name: &str) -> ParallelBuilder {
228 let node = WorkflowNode::parallel(id, name, vec![]);
229 self.graph.add_node(node);
230
231 if let Some(ref current) = self.current_node {
232 self.graph.connect(current, id);
233 }
234
235 ParallelBuilder {
236 parent: self,
237 parallel_node: id.to_string(),
238 branches: Vec::new(),
239 }
240 }
241
242 pub fn loop_node<F, Fut, C, CFut>(
244 mut self,
245 id: &str,
246 name: &str,
247 body: F,
248 condition: C,
249 max_iterations: u32,
250 ) -> Self
251 where
252 F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
253 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
254 C: Fn(super::state::WorkflowContext, WorkflowValue) -> CFut + Send + Sync + 'static,
255 CFut: Future<Output = bool> + Send + 'static,
256 {
257 let node = WorkflowNode::loop_node(id, name, body, condition, max_iterations);
258 self.graph.add_node(node);
259
260 if let Some(ref current) = self.current_node {
261 self.graph.connect(current, id);
262 }
263
264 self.current_node = Some(id.to_string());
265 self
266 }
267
268 pub fn sub_workflow(mut self, id: &str, name: &str, sub_workflow_id: &str) -> Self {
270 let node = WorkflowNode::sub_workflow(id, name, sub_workflow_id);
271 self.graph.add_node(node);
272
273 if let Some(ref current) = self.current_node {
274 self.graph.connect(current, id);
275 }
276
277 self.current_node = Some(id.to_string());
278 self
279 }
280
281 pub fn wait(mut self, id: &str, name: &str, event_type: &str) -> Self {
283 let node = WorkflowNode::wait(id, name, event_type);
284 self.graph.add_node(node);
285
286 if let Some(ref current) = self.current_node {
287 self.graph.connect(current, id);
288 }
289
290 self.current_node = Some(id.to_string());
291 self
292 }
293
294 pub fn transform<F, Fut>(mut self, id: &str, name: &str, transform_fn: F) -> Self
296 where
297 F: Fn(HashMap<String, WorkflowValue>) -> Fut + Send + Sync + 'static,
298 Fut: Future<Output = WorkflowValue> + Send + 'static,
299 {
300 let node = WorkflowNode::transform(id, name, transform_fn);
301 self.graph.add_node(node);
302
303 if let Some(ref current) = self.current_node {
304 self.graph.connect(current, id);
305 }
306
307 self.current_node = Some(id.to_string());
308 self
309 }
310
311 pub fn node(mut self, node: WorkflowNode) -> Self {
313 let node_id = node.id().to_string();
314 self.graph.add_node(node);
315
316 if let Some(ref current) = self.current_node {
317 self.graph.connect(current, &node_id);
318 }
319
320 self.current_node = Some(node_id);
321 self
322 }
323
324 pub fn edge(mut self, from: &str, to: &str) -> Self {
326 self.graph.connect(from, to);
327 self
328 }
329
330 pub fn conditional_edge(mut self, from: &str, to: &str, condition: &str) -> Self {
332 self.graph.connect_conditional(from, to, condition);
333 self
334 }
335
336 pub fn error_edge(mut self, from: &str, to: &str) -> Self {
338 self.graph.add_edge(EdgeConfig::error(from, to));
339 self
340 }
341
342 pub fn goto(mut self, node_id: &str) -> Self {
344 self.current_node = Some(node_id.to_string());
345 self
346 }
347
348 pub fn then(mut self, node_id: &str) -> Self {
350 if let Some(ref current) = self.current_node {
351 self.graph.connect(current, node_id);
352 }
353 self.current_node = Some(node_id.to_string());
354 self
355 }
356
357 pub fn build(self) -> WorkflowGraph {
359 self.graph
360 }
361
362 pub fn build_validated(self) -> Result<WorkflowGraph, Vec<String>> {
364 self.graph.validate()?;
365 Ok(self.graph)
366 }
367}
368
369pub struct ConditionBuilder {
371 parent: WorkflowBuilder,
372 condition_node: String,
373 true_branch: Option<String>,
374 false_branch: Option<String>,
375}
376
377impl ConditionBuilder {
378 pub fn on_true<F, Fut>(mut self, id: &str, name: &str, executor: F) -> Self
380 where
381 F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
382 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
383 {
384 let node = WorkflowNode::task(id, name, executor);
385 self.parent.graph.add_node(node);
386 self.parent
387 .graph
388 .connect_conditional(&self.condition_node, id, "true");
389 self.true_branch = Some(id.to_string());
390 self
391 }
392
393 pub fn on_false<F, Fut>(mut self, id: &str, name: &str, executor: F) -> Self
395 where
396 F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
397 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
398 {
399 let node = WorkflowNode::task(id, name, executor);
400 self.parent.graph.add_node(node);
401 self.parent
402 .graph
403 .connect_conditional(&self.condition_node, id, "false");
404 self.false_branch = Some(id.to_string());
405 self
406 }
407
408 pub fn merge(mut self, id: &str, name: &str) -> WorkflowBuilder {
410 let node = WorkflowNode::join(
411 id,
412 name,
413 vec![
414 self.true_branch.as_deref().unwrap_or(""),
415 self.false_branch.as_deref().unwrap_or(""),
416 ]
417 .into_iter()
418 .filter(|s| !s.is_empty())
419 .collect(),
420 );
421 self.parent.graph.add_node(node);
422
423 if let Some(ref true_branch) = self.true_branch {
424 self.parent.graph.connect(true_branch, id);
425 }
426 if let Some(ref false_branch) = self.false_branch {
427 self.parent.graph.connect(false_branch, id);
428 }
429
430 self.parent.current_node = Some(id.to_string());
431 self.parent
432 }
433
434 pub fn end_condition(mut self) -> WorkflowBuilder {
436 self.parent.current_node = self.true_branch.or(self.false_branch);
438 self.parent
439 }
440}
441
442pub struct ParallelBuilder {
444 parent: WorkflowBuilder,
445 parallel_node: String,
446 branches: Vec<String>,
447}
448
449impl ParallelBuilder {
450 pub fn branch<F, Fut>(mut self, id: &str, name: &str, executor: F) -> Self
452 where
453 F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
454 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
455 {
456 let node = WorkflowNode::task(id, name, executor);
457 self.parent.graph.add_node(node);
458 self.parent.graph.connect(&self.parallel_node, id);
459 self.branches.push(id.to_string());
460 self
461 }
462
463 pub fn branch_agent<F, Fut>(mut self, id: &str, name: &str, agent_fn: F) -> Self
465 where
466 F: Fn(super::state::WorkflowContext, WorkflowValue) -> Fut + Send + Sync + 'static,
467 Fut: Future<Output = Result<WorkflowValue, String>> + Send + 'static,
468 {
469 let node = WorkflowNode::agent(id, name, agent_fn);
470 self.parent.graph.add_node(node);
471 self.parent.graph.connect(&self.parallel_node, id);
472 self.branches.push(id.to_string());
473 self
474 }
475
476 pub fn llm_agent_branch(mut self, id: &str, name: &str, agent: Arc<LLMAgent>) -> Self {
493 let node = WorkflowNode::llm_agent(id, name, agent);
494 self.parent.graph.add_node(node);
495 self.parent.graph.connect(&self.parallel_node, id);
496 self.branches.push(id.to_string());
497 self
498 }
499
500 pub fn join(mut self, id: &str, name: &str) -> WorkflowBuilder {
502 let node = WorkflowNode::join(id, name, self.branches.iter().map(|s| s.as_str()).collect());
503 self.parent.graph.add_node(node);
504
505 for branch in &self.branches {
506 self.parent.graph.connect(branch, id);
507 }
508
509 self.parent.current_node = Some(id.to_string());
510 self.parent
511 }
512
513 pub fn join_with_transform<F, Fut>(
515 mut self,
516 id: &str,
517 name: &str,
518 transform: F,
519 ) -> WorkflowBuilder
520 where
521 F: Fn(HashMap<String, WorkflowValue>) -> Fut + Send + Sync + 'static,
522 Fut: Future<Output = WorkflowValue> + Send + 'static,
523 {
524 let node = WorkflowNode::join_with_transform(
525 id,
526 name,
527 self.branches.iter().map(|s| s.as_str()).collect(),
528 transform,
529 );
530 self.parent.graph.add_node(node);
531
532 for branch in &self.branches {
533 self.parent.graph.connect(branch, id);
534 }
535
536 self.parent.current_node = Some(id.to_string());
537 self.parent
538 }
539}
540
/// Builds a workflow graph from a chain of `WorkflowBuilder` method calls.
///
/// `workflow!("id", "Name" => { .start().end() })` expands to
/// `WorkflowBuilder::new("id", "Name").start().end().build()`.
///
/// NOTE(review): the expansion names `WorkflowBuilder` unqualified, so the
/// caller must have it in scope. A `$crate::<module path>::WorkflowBuilder`
/// path would lift that requirement once the public path is confirmed.
#[macro_export]
macro_rules! workflow {
    ($workflow_id:expr, $workflow_name:expr => { $($calls:tt)* }) => {
        WorkflowBuilder::new($workflow_id, $workflow_name) $($calls)* .build()
    };
}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// start -> task1 -> task2 -> end: one edge per consecutive pair.
    #[test]
    fn test_workflow_builder() {
        let graph = WorkflowBuilder::new("test", "Test Workflow")
            .start()
            .task("task1", "Task 1", |_ctx, input| async move { Ok(input) })
            .task("task2", "Task 2", |_ctx, input| async move { Ok(input) })
            .end()
            .build();

        assert_eq!(graph.node_count(), 4);
        assert_eq!(graph.edge_count(), 3);
    }

    /// Edges: start->check, check->high ("true"), check->low ("false"),
    /// high->merge, low->merge, merge->end.
    #[test]
    fn test_condition_builder() {
        let graph = WorkflowBuilder::new("test", "Conditional Workflow")
            .start()
            .condition("check", "Check", |_ctx, input| async move {
                input.as_i64().unwrap_or(0) > 10
            })
            .on_true("high", "High", |_ctx, _input| async move {
                Ok(WorkflowValue::String("high".to_string()))
            })
            .on_false("low", "Low", |_ctx, _input| async move {
                Ok(WorkflowValue::String("low".to_string()))
            })
            .merge("merge", "Merge")
            .end()
            .build();

        assert_eq!(graph.node_count(), 6);
        assert_eq!(graph.edge_count(), 6);
    }

    /// Edges: start->fork, fork->{a,b,c}, {a,b,c}->join, join->end.
    #[test]
    fn test_parallel_builder() {
        let graph = WorkflowBuilder::new("test", "Parallel Workflow")
            .start()
            .parallel("fork", "Fork")
            .branch("a", "Branch A", |_ctx, _input| async move {
                Ok(WorkflowValue::String("a".to_string()))
            })
            .branch("b", "Branch B", |_ctx, _input| async move {
                Ok(WorkflowValue::String("b".to_string()))
            })
            .branch("c", "Branch C", |_ctx, _input| async move {
                Ok(WorkflowValue::String("c".to_string()))
            })
            .join("join", "Join")
            .end()
            .build();

        assert_eq!(graph.node_count(), 7);
        assert_eq!(graph.edge_count(), 8);
    }

    /// `goto` moves the cursor without wiring; `then` wires cursor -> existing node.
    /// Edges: start->a, start->b (after goto), b->a (via then).
    #[test]
    fn test_goto_and_then() {
        let graph = WorkflowBuilder::new("test", "Cursor Workflow")
            .start()
            .task("a", "A", |_ctx, input| async move { Ok(input) })
            .goto("start")
            .task("b", "B", |_ctx, input| async move { Ok(input) })
            .then("a")
            .build();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 3);
    }
}