1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::workflow::Workflow;
7
/// Conversion into the list of upstream node ids a node reads its input from.
///
/// Implemented for single ids (`&str`, `String`, `&String`) and for common
/// collections of ids — fixed-size arrays, `Vec`s, and borrowed slices/arrays/`Vec`s —
/// so builder methods such as `NodeBuilder::input_from` accept any of them.
pub trait IntoInputIds {
    /// Consumes `self` and returns the ids as owned `String`s, preserving order.
    fn into_input_ids(self) -> Vec<String>;
}

impl IntoInputIds for &str {
    fn into_input_ids(self) -> Vec<String> {
        vec![self.to_owned()]
    }
}

impl IntoInputIds for String {
    fn into_input_ids(self) -> Vec<String> {
        vec![self]
    }
}

impl IntoInputIds for &String {
    fn into_input_ids(self) -> Vec<String> {
        vec![self.clone()]
    }
}

impl<const N: usize> IntoInputIds for [&str; N] {
    fn into_input_ids(self) -> Vec<String> {
        // `iter()` rather than `into_iter()` so the element type is the same in
        // every edition (pre-2021 `array.into_iter()` yields references).
        self.iter().map(|&id| id.to_owned()).collect()
    }
}

impl<const N: usize> IntoInputIds for [String; N] {
    fn into_input_ids(self) -> Vec<String> {
        Vec::from(self)
    }
}

impl IntoInputIds for Vec<String> {
    fn into_input_ids(self) -> Vec<String> {
        self
    }
}

impl IntoInputIds for Vec<&str> {
    fn into_input_ids(self) -> Vec<String> {
        self.into_iter().map(|id| id.to_owned()).collect()
    }
}

// Borrowed collections: let callers pass id lists they want to keep using
// without cloning the container first.

impl IntoInputIds for &[&str] {
    fn into_input_ids(self) -> Vec<String> {
        self.iter().map(|&id| id.to_owned()).collect()
    }
}

impl IntoInputIds for &[String] {
    fn into_input_ids(self) -> Vec<String> {
        self.to_vec()
    }
}

impl IntoInputIds for &Vec<String> {
    fn into_input_ids(self) -> Vec<String> {
        self.clone()
    }
}

impl<const N: usize> IntoInputIds for &[&str; N] {
    fn into_input_ids(self) -> Vec<String> {
        self.iter().map(|&id| id.to_owned()).collect()
    }
}

impl<const N: usize> IntoInputIds for &[String; N] {
    fn into_input_ids(self) -> Vec<String> {
        self.to_vec()
    }
}
61
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
64pub enum FailureMode {
65 Skip,
67 Abort,
69 Fallback(String),
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct RetryPolicy {
76 pub max_retries: u32,
78 pub backoff_ms: u64,
80 pub backoff_multiplier: f64,
82 pub on_failure: FailureMode,
84}
85
86impl Default for RetryPolicy {
87 fn default() -> Self {
88 Self {
89 max_retries: 2,
90 backoff_ms: 100,
91 backoff_multiplier: 2.0,
92 on_failure: FailureMode::Abort,
93 }
94 }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct AgentConfig {
100 pub name: String,
101 pub system_prompt: String,
102 #[serde(default)]
103 pub tools: Vec<String>,
104 #[serde(default)]
105 pub input_from: Vec<String>,
106 pub output_schema: Option<Value>,
107 #[serde(default)]
108 pub skills: Vec<String>,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct HumanConfig {
114 pub prompt: String,
116 pub timeout_secs: Option<u64>,
118 #[serde(default)]
120 pub options: Vec<String>,
121 pub timeout_action: Option<String>,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct ConditionBranch {
128 pub path: String,
130 pub op: ConditionOp,
132 pub value: Value,
134 pub goto: String,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
140pub enum ConditionOp {
141 #[serde(rename = "gt")]
143 Gt,
144 #[serde(rename = "gte")]
146 Gte,
147 #[serde(rename = "lt")]
149 Lt,
150 #[serde(rename = "lte")]
152 Lte,
153 #[serde(rename = "eq")]
155 Eq,
156 #[serde(rename = "neq")]
158 Neq,
159}
160
161pub fn evaluate_condition(data: &Value, branch: &ConditionBranch) -> bool {
163 let extracted = data.pointer(&branch.path);
164 let extracted = match extracted {
165 Some(v) => v,
166 None => return false,
167 };
168
169 match &branch.op {
170 ConditionOp::Eq => extracted == &branch.value,
171 ConditionOp::Neq => extracted != &branch.value,
172 ConditionOp::Gt | ConditionOp::Gte | ConditionOp::Lt | ConditionOp::Lte => {
173 compare_numeric(extracted, &branch.value, &branch.op)
174 }
175 }
176}
177
178fn compare_numeric(lhs: &Value, rhs: &Value, op: &ConditionOp) -> bool {
179 let lhs_f = value_as_f64(lhs);
180 let rhs_f = value_as_f64(rhs);
181 match (lhs_f, rhs_f) {
182 (Some(l), Some(r)) => match op {
183 ConditionOp::Gt => l > r,
184 ConditionOp::Gte => l >= r,
185 ConditionOp::Lt => l < r,
186 ConditionOp::Lte => l <= r,
187 _ => false,
188 },
189 _ => false,
190 }
191}
192
193fn value_as_f64(v: &Value) -> Option<f64> {
194 v.as_f64().or_else(|| v.as_i64().map(|i| i as f64))
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct ConditionConfig {
200 pub input_from: Vec<String>,
202 pub branches: Vec<ConditionBranch>,
204 pub default_goto: Option<String>,
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct LoopConfig {
211 pub body: Vec<String>,
213 pub max_iterations: usize,
215 pub until: Option<ConditionBranch>,
217}
218
219pub type TransformFn = Arc<dyn Fn(&Value) -> std::result::Result<Value, String> + Send + Sync>;
221
222pub type AsyncTransformFn = Arc<
224 dyn Fn(Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = std::result::Result<Value, String>> + Send>>
225 + Send
226 + Sync,
227>;
228
229#[derive(Clone)]
233pub struct TransformConfig {
234 pub input_from: Vec<String>,
236 pub transform_fn: TransformFn,
238 pub async_fn: Option<AsyncTransformFn>,
240 pub script: Option<String>,
244}
245
246impl std::fmt::Debug for TransformConfig {
247 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248 let has_async = self.async_fn.is_some();
249 f.debug_struct("TransformConfig")
250 .field("input_from", &self.input_from)
251 .field("transform_fn", &"<fn>")
252 .field("async_fn", &if has_async { "<async fn>" } else { "<none>" })
253 .field("script", &self.script)
254 .finish()
255 }
256}
257
258#[derive(Debug, Clone)]
260pub struct SubWorkflowConfig {
261 pub workflow: Workflow,
263 pub input_from: Vec<String>,
265}
266
267#[derive(Debug, Clone)]
269pub enum NodeKind {
270 Agent(AgentConfig),
271 Human(HumanConfig),
272 Condition(ConditionConfig),
273 Loop(LoopConfig),
274 Transform(TransformConfig),
275 SubWorkflow(SubWorkflowConfig),
276}
277
278#[derive(Debug, Clone)]
280pub struct Node {
281 pub id: String,
282 pub kind: NodeKind,
283 pub retry_policy: Option<RetryPolicy>,
285}
286
287impl Node {
288 pub fn display_name(&self) -> &str {
291 match &self.kind {
292 NodeKind::Agent(c) => &c.name,
293 _ => &self.id,
294 }
295 }
296
297 pub fn kind_str(&self) -> &str {
299 match &self.kind {
300 NodeKind::Agent(_) => "agent",
301 NodeKind::Human(_) => "human",
302 NodeKind::Transform(c) if c.script.is_some() => "script",
303 NodeKind::Transform(_) => "transform",
304 NodeKind::Condition(_) => "condition",
305 NodeKind::Loop(_) => "loop",
306 NodeKind::SubWorkflow(_) => "sub_workflow",
307 }
308 }
309
310 pub fn agent(id: impl Into<String>) -> NodeBuilder {
312 NodeBuilder {
313 id: id.into(),
314 kind: NodeBuilderKind::Agent {
315 name: None,
316 system_prompt: None,
317 tools: vec![],
318 input_from: vec![],
319 output_schema: None,
320 skills: vec![],
321 },
322 retry_policy: None,
323 }
324 }
325
326 pub fn human(id: impl Into<String>) -> NodeBuilder {
328 NodeBuilder {
329 id: id.into(),
330 kind: NodeBuilderKind::Human {
331 prompt: None,
332 timeout_secs: None,
333 options: vec![],
334 timeout_action: None,
335 },
336 retry_policy: None,
337 }
338 }
339
340 pub fn condition(id: impl Into<String>) -> NodeBuilder {
342 NodeBuilder {
343 id: id.into(),
344 kind: NodeBuilderKind::Condition {
345 input_from: vec![],
346 branches: vec![],
347 default_goto: None,
348 },
349 retry_policy: None,
350 }
351 }
352
353 pub fn loop_node(id: impl Into<String>) -> NodeBuilder {
355 NodeBuilder {
356 id: id.into(),
357 kind: NodeBuilderKind::Loop {
358 body: vec![],
359 max_iterations: 10,
360 until: None,
361 },
362 retry_policy: None,
363 }
364 }
365
366 pub fn transform(id: impl Into<String>) -> NodeBuilder {
368 NodeBuilder {
369 id: id.into(),
370 kind: NodeBuilderKind::Transform {
371 input_from: vec![],
372 transform_fn: None,
373 async_fn: None,
374 },
375 retry_policy: None,
376 }
377 }
378
379 pub fn sub_workflow(id: impl Into<String>) -> NodeBuilder {
381 NodeBuilder {
382 id: id.into(),
383 kind: NodeBuilderKind::SubWorkflow {
384 workflow: None,
385 input_from: vec![],
386 },
387 retry_policy: None,
388 }
389 }
390}
391
392enum NodeBuilderKind {
393 Agent {
394 name: Option<String>,
395 system_prompt: Option<String>,
396 tools: Vec<String>,
397 input_from: Vec<String>,
398 output_schema: Option<Value>,
399 skills: Vec<String>,
400 },
401 Human {
402 prompt: Option<String>,
403 timeout_secs: Option<u64>,
404 options: Vec<String>,
405 timeout_action: Option<String>,
406 },
407 Condition {
408 input_from: Vec<String>,
409 branches: Vec<ConditionBranch>,
410 default_goto: Option<String>,
411 },
412 Loop {
413 body: Vec<String>,
414 max_iterations: usize,
415 until: Option<ConditionBranch>,
416 },
417 Transform {
418 input_from: Vec<String>,
419 transform_fn: Option<TransformFn>,
420 async_fn: Option<AsyncTransformFn>,
421 },
422 SubWorkflow {
423 workflow: Option<Workflow>,
424 input_from: Vec<String>,
425 },
426}
427
428pub struct NodeBuilder {
429 id: String,
430 kind: NodeBuilderKind,
431 retry_policy: Option<RetryPolicy>,
432}
433
434impl NodeBuilder {
435 pub fn name(mut self, name: impl Into<String>) -> Self {
438 if let NodeBuilderKind::Agent { name: ref mut n, .. } = self.kind {
439 *n = Some(name.into());
440 }
441 self
442 }
443
444 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
445 if let NodeBuilderKind::Agent {
446 system_prompt: ref mut sp,
447 ..
448 } = self.kind
449 {
450 *sp = Some(prompt.into());
451 }
452 self
453 }
454
455 pub fn tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
456 if let NodeBuilderKind::Agent {
457 tools: ref mut t, ..
458 } = self.kind
459 {
460 *t = tools.into_iter().map(|s| s.into()).collect();
461 }
462 self
463 }
464
465 pub fn input_from(mut self, inputs: impl IntoInputIds) -> Self {
466 let collected: Vec<String> = inputs.into_input_ids();
467 match self.kind {
468 NodeBuilderKind::Agent {
469 ref mut input_from, ..
470 } => {
471 *input_from = collected;
472 }
473 NodeBuilderKind::Transform {
474 ref mut input_from, ..
475 } => {
476 *input_from = collected;
477 }
478 NodeBuilderKind::SubWorkflow {
479 ref mut input_from, ..
480 } => {
481 *input_from = collected;
482 }
483 _ => {}
484 }
485 self
486 }
487
488 pub fn output_schema(mut self, schema: Value) -> Self {
489 if let NodeBuilderKind::Agent {
490 output_schema: ref mut os,
491 ..
492 } = self.kind
493 {
494 *os = Some(schema);
495 }
496 self
497 }
498
499 pub fn skill(mut self, skill: impl Into<String>) -> Self {
501 if let NodeBuilderKind::Agent {
502 skills: ref mut s, ..
503 } = self.kind
504 {
505 s.push(skill.into());
506 }
507 self
508 }
509
510 pub fn skills(mut self, skills: impl IntoIterator<Item = impl Into<String>>) -> Self {
512 if let NodeBuilderKind::Agent {
513 skills: ref mut s, ..
514 } = self.kind
515 {
516 *s = skills.into_iter().map(|v| v.into()).collect();
517 }
518 self
519 }
520
521 pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
524 if let NodeBuilderKind::Human {
525 prompt: ref mut p, ..
526 } = self.kind
527 {
528 *p = Some(prompt.into());
529 }
530 self
531 }
532
533 pub fn timeout_secs(mut self, secs: u64) -> Self {
534 if let NodeBuilderKind::Human {
535 timeout_secs: ref mut ts,
536 ..
537 } = self.kind
538 {
539 *ts = Some(secs);
540 }
541 self
542 }
543
544 pub fn options(mut self, options: impl IntoIterator<Item = impl Into<String>>) -> Self {
545 if let NodeBuilderKind::Human {
546 options: ref mut o, ..
547 } = self.kind
548 {
549 *o = options.into_iter().map(|s| s.into()).collect();
550 }
551 self
552 }
553
554 pub fn timeout_action(mut self, action: impl Into<String>) -> Self {
555 if let NodeBuilderKind::Human {
556 timeout_action: ref mut ta,
557 ..
558 } = self.kind
559 {
560 *ta = Some(action.into());
561 }
562 self
563 }
564
565 pub fn condition_input_from(mut self, inputs: impl IntoInputIds) -> Self {
569 if let NodeBuilderKind::Condition {
570 input_from: ref mut i,
571 ..
572 } = self.kind
573 {
574 *i = inputs.into_input_ids();
575 }
576 self
577 }
578
579 pub fn branch(mut self, branch: ConditionBranch) -> Self {
581 if let NodeBuilderKind::Condition {
582 branches: ref mut b,
583 ..
584 } = self.kind
585 {
586 b.push(branch);
587 }
588 self
589 }
590
591 pub fn default_goto(mut self, target: impl Into<String>) -> Self {
593 if let NodeBuilderKind::Condition {
594 default_goto: ref mut d,
595 ..
596 } = self.kind
597 {
598 *d = Some(target.into());
599 }
600 self
601 }
602
603 pub fn body(mut self, nodes: impl IntoIterator<Item = impl Into<String>>) -> Self {
607 if let NodeBuilderKind::Loop {
608 body: ref mut b, ..
609 } = self.kind
610 {
611 *b = nodes.into_iter().map(|s| s.into()).collect();
612 }
613 self
614 }
615
616 pub fn max_iterations(mut self, max: usize) -> Self {
618 if let NodeBuilderKind::Loop {
619 max_iterations: ref mut m,
620 ..
621 } = self.kind
622 {
623 *m = max;
624 }
625 self
626 }
627
628 pub fn until(mut self, condition: ConditionBranch) -> Self {
630 if let NodeBuilderKind::Loop {
631 until: ref mut u, ..
632 } = self.kind
633 {
634 *u = Some(condition);
635 }
636 self
637 }
638
639 pub fn transform_input_from(mut self, inputs: impl IntoInputIds) -> Self {
643 if let NodeBuilderKind::Transform {
644 input_from: ref mut i,
645 ..
646 } = self.kind
647 {
648 *i = inputs.into_input_ids();
649 }
650 self
651 }
652
653 pub fn transform_fn(
655 mut self,
656 f: impl Fn(&Value) -> std::result::Result<Value, String> + Send + Sync + 'static,
657 ) -> Self {
658 if let NodeBuilderKind::Transform {
659 transform_fn: ref mut tf,
660 ..
661 } = self.kind
662 {
663 *tf = Some(Arc::new(f));
664 }
665 self
666 }
667
668 pub fn async_transform_fn<F, Fut>(mut self, f: F) -> Self
682 where
683 F: Fn(Value) -> Fut + Send + Sync + 'static,
684 Fut: std::future::Future<Output = std::result::Result<Value, String>> + Send + 'static,
685 {
686 if let NodeBuilderKind::Transform { ref mut async_fn, .. } = self.kind {
687 *async_fn = Some(Arc::new(move |input: Value| {
688 Box::pin(f(input)) as std::pin::Pin<Box<dyn std::future::Future<Output = std::result::Result<Value, String>> + Send>>
689 }));
690 }
691 self
692 }
693
694 pub fn workflow(mut self, workflow: Workflow) -> Self {
698 if let NodeBuilderKind::SubWorkflow {
699 workflow: ref mut w, ..
700 } = self.kind
701 {
702 *w = Some(workflow);
703 }
704 self
705 }
706
707 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
709 self.retry_policy = Some(policy);
710 self
711 }
712
713 pub fn retry(
715 mut self,
716 max_retries: u32,
717 backoff_ms: u64,
718 backoff_multiplier: f64,
719 on_failure: FailureMode,
720 ) -> Self {
721 self.retry_policy = Some(RetryPolicy {
722 max_retries,
723 backoff_ms,
724 backoff_multiplier,
725 on_failure,
726 });
727 self
728 }
729
730 pub fn build(self) -> Node {
731 let kind = match self.kind {
732 NodeBuilderKind::Agent {
733 name,
734 system_prompt,
735 tools,
736 input_from,
737 output_schema,
738 skills,
739 } => NodeKind::Agent(AgentConfig {
740 name: name.unwrap_or_else(|| self.id.clone()),
741 system_prompt: system_prompt.unwrap_or_default(),
742 tools,
743 input_from,
744 output_schema,
745 skills,
746 }),
747 NodeBuilderKind::Human {
748 prompt,
749 timeout_secs,
750 options,
751 timeout_action,
752 } => NodeKind::Human(HumanConfig {
753 prompt: prompt.unwrap_or_else(|| "Awaiting human input".to_string()),
754 timeout_secs,
755 options,
756 timeout_action,
757 }),
758 NodeBuilderKind::Condition {
759 input_from,
760 branches,
761 default_goto,
762 } => NodeKind::Condition(ConditionConfig {
763 input_from,
764 branches,
765 default_goto,
766 }),
767 NodeBuilderKind::Loop {
768 body,
769 max_iterations,
770 until,
771 } => NodeKind::Loop(LoopConfig {
772 body,
773 max_iterations,
774 until,
775 }),
776 NodeBuilderKind::Transform {
777 input_from,
778 transform_fn,
779 async_fn,
780 } => {
781 let noop_fn: TransformFn = Arc::new(|_| Ok(Value::Null));
782 NodeKind::Transform(TransformConfig {
783 input_from,
784 transform_fn: transform_fn.unwrap_or(noop_fn),
785 async_fn,
786 script: None, })
788 }
789 NodeBuilderKind::SubWorkflow {
790 workflow,
791 input_from,
792 } => NodeKind::SubWorkflow(SubWorkflowConfig {
793 workflow: workflow.expect("sub_workflow node requires a workflow"),
794 input_from,
795 }),
796 };
797
798 Node {
799 id: self.id,
800 kind,
801 retry_policy: self.retry_policy,
802 }
803 }
804}