1use crate::{
7 ApprovalConfig, Condition, Edge, FormConfig, LlmConfig, LoopConfig, McpConfig, Node, NodeId,
8 NodeKind, ParallelConfig, RetryConfig, ScriptConfig, SubWorkflowConfig, SwitchConfig,
9 TimeoutConfig, TryCatchConfig, VectorConfig, Workflow,
10};
11
12pub struct WorkflowBuilder {
14 workflow: Workflow,
15 last_node_id: Option<NodeId>,
16}
17
18impl WorkflowBuilder {
19 pub fn new(name: impl Into<String>) -> Self {
21 Self {
22 workflow: Workflow::new(name.into()),
23 last_node_id: None,
24 }
25 }
26
27 pub fn description(mut self, description: impl Into<String>) -> Self {
29 self.workflow.metadata.description = Some(description.into());
30 self
31 }
32
33 pub fn version(mut self, version: impl Into<String>) -> Self {
35 self.workflow.metadata.version = version.into();
36 self
37 }
38
39 pub fn tag(mut self, tag: impl Into<String>) -> Self {
41 self.workflow.metadata.tags.push(tag.into());
42 self
43 }
44
45 pub fn tags(mut self, tags: Vec<String>) -> Self {
47 self.workflow.metadata.tags.extend(tags);
48 self
49 }
50
51 pub fn start(mut self, name: impl Into<String>) -> Self {
53 let node = Node::new(name.into(), NodeKind::Start);
54 self.last_node_id = Some(node.id);
55 self.workflow.add_node(node);
56 self
57 }
58
59 pub fn end(mut self, name: impl Into<String>) -> Self {
61 let node = Node::new(name.into(), NodeKind::End);
62 let node_id = node.id;
63 self.workflow.add_node(node);
64
65 if let Some(from_id) = self.last_node_id {
67 self.workflow.add_edge(Edge::new(from_id, node_id));
68 }
69
70 self.last_node_id = Some(node_id);
71 self
72 }
73
74 pub fn llm(mut self, name: impl Into<String>, config: LlmConfig) -> Self {
76 let node = Node::new(name.into(), NodeKind::LLM(config));
77 let node_id = node.id;
78 self.workflow.add_node(node);
79
80 if let Some(from_id) = self.last_node_id {
82 self.workflow.add_edge(Edge::new(from_id, node_id));
83 }
84
85 self.last_node_id = Some(node_id);
86 self
87 }
88
89 pub fn code(mut self, name: impl Into<String>, config: ScriptConfig) -> Self {
91 let node = Node::new(name.into(), NodeKind::Code(config));
92 let node_id = node.id;
93 self.workflow.add_node(node);
94
95 if let Some(from_id) = self.last_node_id {
97 self.workflow.add_edge(Edge::new(from_id, node_id));
98 }
99
100 self.last_node_id = Some(node_id);
101 self
102 }
103
104 pub fn retriever(mut self, name: impl Into<String>, config: VectorConfig) -> Self {
106 let node = Node::new(name.into(), NodeKind::Retriever(config));
107 let node_id = node.id;
108 self.workflow.add_node(node);
109
110 if let Some(from_id) = self.last_node_id {
112 self.workflow.add_edge(Edge::new(from_id, node_id));
113 }
114
115 self.last_node_id = Some(node_id);
116 self
117 }
118
119 pub fn if_else(mut self, name: impl Into<String>, condition: Condition) -> Self {
121 let node = Node::new(name.into(), NodeKind::IfElse(condition));
122 let node_id = node.id;
123 self.workflow.add_node(node);
124
125 if let Some(from_id) = self.last_node_id {
127 self.workflow.add_edge(Edge::new(from_id, node_id));
128 }
129
130 self.last_node_id = Some(node_id);
131 self
132 }
133
134 pub fn tool(mut self, name: impl Into<String>, config: McpConfig) -> Self {
136 let node = Node::new(name.into(), NodeKind::Tool(config));
137 let node_id = node.id;
138 self.workflow.add_node(node);
139
140 if let Some(from_id) = self.last_node_id {
142 self.workflow.add_edge(Edge::new(from_id, node_id));
143 }
144
145 self.last_node_id = Some(node_id);
146 self
147 }
148
149 pub fn loop_node(mut self, name: impl Into<String>, config: LoopConfig) -> Self {
151 let node = Node::new(name.into(), NodeKind::Loop(config));
152 let node_id = node.id;
153 self.workflow.add_node(node);
154
155 if let Some(from_id) = self.last_node_id {
157 self.workflow.add_edge(Edge::new(from_id, node_id));
158 }
159
160 self.last_node_id = Some(node_id);
161 self
162 }
163
164 pub fn try_catch(mut self, name: impl Into<String>, config: TryCatchConfig) -> Self {
166 let node = Node::new(name.into(), NodeKind::TryCatch(config));
167 let node_id = node.id;
168 self.workflow.add_node(node);
169
170 if let Some(from_id) = self.last_node_id {
172 self.workflow.add_edge(Edge::new(from_id, node_id));
173 }
174
175 self.last_node_id = Some(node_id);
176 self
177 }
178
179 pub fn sub_workflow(mut self, name: impl Into<String>, config: SubWorkflowConfig) -> Self {
181 let node = Node::new(name.into(), NodeKind::SubWorkflow(config));
182 let node_id = node.id;
183 self.workflow.add_node(node);
184
185 if let Some(from_id) = self.last_node_id {
187 self.workflow.add_edge(Edge::new(from_id, node_id));
188 }
189
190 self.last_node_id = Some(node_id);
191 self
192 }
193
194 pub fn switch(mut self, name: impl Into<String>, config: SwitchConfig) -> Self {
196 let node = Node::new(name.into(), NodeKind::Switch(config));
197 let node_id = node.id;
198 self.workflow.add_node(node);
199
200 if let Some(from_id) = self.last_node_id {
202 self.workflow.add_edge(Edge::new(from_id, node_id));
203 }
204
205 self.last_node_id = Some(node_id);
206 self
207 }
208
209 pub fn parallel(mut self, name: impl Into<String>, config: ParallelConfig) -> Self {
211 let node = Node::new(name.into(), NodeKind::Parallel(config));
212 let node_id = node.id;
213 self.workflow.add_node(node);
214
215 if let Some(from_id) = self.last_node_id {
217 self.workflow.add_edge(Edge::new(from_id, node_id));
218 }
219
220 self.last_node_id = Some(node_id);
221 self
222 }
223
224 pub fn approval(mut self, name: impl Into<String>, config: ApprovalConfig) -> Self {
226 let node = Node::new(name.into(), NodeKind::Approval(config));
227 let node_id = node.id;
228 self.workflow.add_node(node);
229
230 if let Some(from_id) = self.last_node_id {
232 self.workflow.add_edge(Edge::new(from_id, node_id));
233 }
234
235 self.last_node_id = Some(node_id);
236 self
237 }
238
239 pub fn form(mut self, name: impl Into<String>, config: FormConfig) -> Self {
241 let node = Node::new(name.into(), NodeKind::Form(config));
242 let node_id = node.id;
243 self.workflow.add_node(node);
244
245 if let Some(from_id) = self.last_node_id {
247 self.workflow.add_edge(Edge::new(from_id, node_id));
248 }
249
250 self.last_node_id = Some(node_id);
251 self
252 }
253
254 pub fn node(mut self, node: Node) -> Self {
256 let node_id = node.id;
257 self.workflow.add_node(node);
258
259 if let Some(from_id) = self.last_node_id {
261 self.workflow.add_edge(Edge::new(from_id, node_id));
262 }
263
264 self.last_node_id = Some(node_id);
265 self
266 }
267
268 pub fn connect(mut self, from_index: usize, to_index: usize) -> Self {
270 if from_index < self.workflow.nodes.len() && to_index < self.workflow.nodes.len() {
271 let from_id = self.workflow.nodes[from_index].id;
272 let to_id = self.workflow.nodes[to_index].id;
273 self.workflow.add_edge(Edge::new(from_id, to_id));
274 }
275 self
276 }
277
278 pub fn connect_ids(mut self, from_id: NodeId, to_id: NodeId) -> Self {
280 self.workflow.add_edge(Edge::new(from_id, to_id));
281 self
282 }
283
284 pub fn last_node_id(&self) -> Option<NodeId> {
286 self.last_node_id
287 }
288
289 pub fn node_id_at(&self, index: usize) -> Option<NodeId> {
291 self.workflow.nodes.get(index).map(|n| n.id)
292 }
293
294 pub fn build(self) -> Workflow {
296 self.workflow
297 }
298}
299
300pub struct NodeBuilder {
302 node: Node,
303}
304
305impl NodeBuilder {
306 pub fn new(name: impl Into<String>, kind: NodeKind) -> Self {
308 Self {
309 node: Node::new(name.into(), kind),
310 }
311 }
312
313 pub fn retry(mut self, config: RetryConfig) -> Self {
315 self.node.retry_config = Some(config);
316 self
317 }
318
319 pub fn timeout(mut self, config: TimeoutConfig) -> Self {
321 self.node.timeout_config = Some(config);
322 self
323 }
324
325 pub fn position(mut self, x: f64, y: f64) -> Self {
327 self.node.position = Some((x, y));
328 self
329 }
330
331 pub fn build(self) -> Node {
333 self.node
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workflow_builder_basic() {
        let workflow = WorkflowBuilder::new("Test Workflow")
            .description("A test workflow")
            .version("1.0.0")
            .tag("test")
            .start("Start")
            .end("End")
            .build();

        assert_eq!(workflow.metadata.name, "Test Workflow");
        assert_eq!(
            workflow.metadata.description,
            Some("A test workflow".to_string())
        );
        assert_eq!(workflow.metadata.version, "1.0.0");
        assert_eq!(workflow.metadata.tags, vec!["test"]);
        assert_eq!(workflow.nodes.len(), 2);
        assert_eq!(workflow.edges.len(), 1);
    }

    #[test]
    fn test_workflow_builder_with_llm() {
        let llm_config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "Hello {{input}}".to_string(),
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: vec![],
            images: vec![],
            extra_params: serde_json::json!({}),
        };

        let workflow = WorkflowBuilder::new("LLM Workflow")
            .start("Start")
            .llm("Generate", llm_config)
            .end("End")
            .build();

        assert_eq!(workflow.nodes.len(), 3);
        assert_eq!(workflow.edges.len(), 2);

        // Nodes are stored in insertion order, so the LLM node is second.
        let llm_node = &workflow.nodes[1];
        assert_eq!(llm_node.name, "Generate");
        assert!(matches!(llm_node.kind, NodeKind::LLM(_)));
    }

    #[test]
    fn test_workflow_builder_with_code() {
        let script_config = ScriptConfig {
            runtime: "rust".to_string(),
            code: "println!(\"Hello\");".to_string(),
            inputs: vec![],
            output: "result".to_string(),
        };

        let workflow = WorkflowBuilder::new("Code Workflow")
            .start("Start")
            .code("Execute", script_config)
            .end("End")
            .build();

        assert_eq!(workflow.nodes.len(), 3);
        assert_eq!(workflow.edges.len(), 2);
    }

    #[test]
    fn test_workflow_builder_custom_connections() {
        let workflow = WorkflowBuilder::new("Custom Connections")
            .start("Start")
            .end("End")
            // Explicitly add a second Start -> End edge on top of the automatic one.
            .connect(0, 1)
            .build();

        // Automatic Start -> End edge plus the explicit one.
        assert_eq!(workflow.edges.len(), 2);
    }

    #[test]
    fn test_workflow_builder_connect_out_of_range_is_ignored() {
        let workflow = WorkflowBuilder::new("Out Of Range")
            .start("Start")
            .end("End")
            .connect(0, 2)
            .connect(5, 1)
            .build();

        // Only the automatic Start -> End edge exists.
        assert_eq!(workflow.nodes.len(), 2);
        assert_eq!(workflow.edges.len(), 1);
    }

    #[test]
    fn test_workflow_builder_connect_ids() {
        let builder = WorkflowBuilder::new("Connect Ids").start("Start").end("End");
        let start_id = builder.node_id_at(0).expect("start node was added");
        let end_id = builder.node_id_at(1).expect("end node was added");

        let workflow = builder.connect_ids(end_id, start_id).build();

        // Automatic Start -> End edge plus the explicit End -> Start edge.
        assert_eq!(workflow.edges.len(), 2);
    }

    #[test]
    fn test_workflow_builder_start_does_not_link_previous() {
        let workflow = WorkflowBuilder::new("Two Starts")
            .start("First")
            .start("Second")
            .build();

        assert_eq!(workflow.nodes.len(), 2);
        assert_eq!(workflow.edges.len(), 0);
    }

    #[test]
    fn test_workflow_builder_empty() {
        let builder = WorkflowBuilder::new("Empty");

        assert!(builder.last_node_id().is_none());
        assert!(builder.node_id_at(0).is_none());

        let workflow = builder.build();
        assert_eq!(workflow.nodes.len(), 0);
        assert_eq!(workflow.edges.len(), 0);
    }

    #[test]
    fn test_workflow_builder_with_prebuilt_node() {
        let custom = NodeBuilder::new("Custom", NodeKind::End)
            .position(10.0, 20.0)
            .build();

        let workflow = WorkflowBuilder::new("Prebuilt Node")
            .start("Start")
            .node(custom)
            .build();

        assert_eq!(workflow.nodes.len(), 2);
        assert_eq!(workflow.edges.len(), 1);
        assert_eq!(workflow.nodes[1].name, "Custom");
        assert_eq!(workflow.nodes[1].position, Some((10.0, 20.0)));
    }

    #[test]
    fn test_node_builder() {
        let retry_config = RetryConfig {
            max_retries: 3,
            initial_delay_ms: 1000,
            backoff_multiplier: 2.0,
            max_delay_ms: 30000,
        };

        let timeout_config = TimeoutConfig {
            execution_timeout_ms: 60000,
            idle_timeout_ms: None,
            timeout_action: crate::TimeoutAction::Fail,
        };

        let node = NodeBuilder::new("Test Node", NodeKind::Start)
            .retry(retry_config)
            .timeout(timeout_config)
            .position(100.0, 200.0)
            .build();

        assert_eq!(node.name, "Test Node");
        assert!(node.retry_config.is_some());
        assert!(node.timeout_config.is_some());
        assert_eq!(node.position, Some((100.0, 200.0)));
    }

    #[test]
    fn test_workflow_builder_multiple_tags() {
        let workflow = WorkflowBuilder::new("Tagged Workflow")
            .tags(vec!["tag1".to_string(), "tag2".to_string()])
            .tag("tag3")
            .build();

        assert_eq!(workflow.metadata.tags.len(), 3);
        assert!(workflow.metadata.tags.contains(&"tag1".to_string()));
        assert!(workflow.metadata.tags.contains(&"tag2".to_string()));
        assert!(workflow.metadata.tags.contains(&"tag3".to_string()));
    }

    #[test]
    fn test_workflow_builder_get_node_ids() {
        let builder = WorkflowBuilder::new("Test").start("Start").end("End");

        assert!(builder.last_node_id().is_some());
        assert!(builder.node_id_at(0).is_some());
        assert!(builder.node_id_at(1).is_some());
        assert!(builder.node_id_at(2).is_none());
    }

    #[test]
    fn test_workflow_builder_if_else() {
        use uuid::Uuid;

        let true_branch_id = Uuid::new_v4();
        let false_branch_id = Uuid::new_v4();

        let condition = Condition {
            expression: "{{value}} > 10".to_string(),
            true_branch: true_branch_id,
            false_branch: false_branch_id,
        };

        let workflow = WorkflowBuilder::new("Conditional Workflow")
            .start("Start")
            .if_else("Check Value", condition)
            .end("End")
            .build();

        assert_eq!(workflow.nodes.len(), 3);
        assert!(matches!(workflow.nodes[1].kind, NodeKind::IfElse(_)));
    }

    #[test]
    fn test_workflow_builder_auto_connect() {
        let llm_config = LlmConfig {
            provider: "openai".to_string(),
            model: "gpt-4".to_string(),
            system_prompt: None,
            prompt_template: "test".to_string(),
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
            extra_params: serde_json::json!({}),
        };

        let workflow = WorkflowBuilder::new("Auto Connect Test")
            .start("Start")
            .llm("LLM1", llm_config.clone())
            .llm("LLM2", llm_config)
            .end("End")
            .build();

        assert_eq!(workflow.nodes.len(), 4);
        // Start -> LLM1 -> LLM2 -> End
        assert_eq!(workflow.edges.len(), 3);
    }
}