// motosan_agent_workflow/node.rs

1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::workflow::Workflow;
7
/// Conversion accepted by the `input_from()` family of builder methods, so a caller can
/// pass either one upstream node ID or a whole collection of them.
///
/// Every implementation yields an owned `Vec<String>` in the same order as the input.
///
/// # Examples
/// ```ignore
/// .input_from("single_node")      // single &str
/// .input_from(["a", "b"])         // array of &str
/// .input_from(vec!["a".into()])   // Vec<String>
/// ```
pub trait IntoInputIds {
    /// Consume `self` and return the node IDs it represents.
    fn into_input_ids(self) -> Vec<String>;
}

impl IntoInputIds for &str {
    fn into_input_ids(self) -> Vec<String> {
        vec![String::from(self)]
    }
}

impl IntoInputIds for String {
    fn into_input_ids(self) -> Vec<String> {
        // Move the existing allocation into the vector; no copy needed.
        vec![self]
    }
}

impl IntoInputIds for &String {
    fn into_input_ids(self) -> Vec<String> {
        // Same as the `&str` case: copy the borrowed text into a new `String`.
        self.as_str().into_input_ids()
    }
}

impl<const N: usize> IntoInputIds for [&str; N] {
    fn into_input_ids(self) -> Vec<String> {
        Vec::from(self.map(String::from))
    }
}

impl<const N: usize> IntoInputIds for [String; N] {
    fn into_input_ids(self) -> Vec<String> {
        Vec::from(self)
    }
}

impl IntoInputIds for Vec<String> {
    fn into_input_ids(self) -> Vec<String> {
        self
    }
}

impl IntoInputIds for Vec<&str> {
    fn into_input_ids(self) -> Vec<String> {
        self.into_iter().map(String::from).collect()
    }
}

62/// What to do when a node exhausts all retries.
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
64pub enum FailureMode {
65    /// Skip this node and continue the workflow (store null output).
66    Skip,
67    /// Abort the entire workflow with an error.
68    Abort,
69    /// Use the given fallback string as the node output and continue.
70    Fallback(String),
71}
72
73/// Configurable retry behavior for a node.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct RetryPolicy {
76    /// Maximum number of retries (0 means no retries, just the initial attempt).
77    pub max_retries: u32,
78    /// Initial backoff in milliseconds before the first retry.
79    pub backoff_ms: u64,
80    /// Multiplier applied to backoff after each retry (exponential backoff).
81    pub backoff_multiplier: f64,
82    /// What to do when all retries are exhausted.
83    pub on_failure: FailureMode,
84}
85
86impl Default for RetryPolicy {
87    fn default() -> Self {
88        Self {
89            max_retries: 2,
90            backoff_ms: 100,
91            backoff_multiplier: 2.0,
92            on_failure: FailureMode::Abort,
93        }
94    }
95}
96
97/// Configuration for an Agent node.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct AgentConfig {
100    pub name: String,
101    pub system_prompt: String,
102    #[serde(default)]
103    pub tools: Vec<String>,
104    #[serde(default)]
105    pub input_from: Vec<String>,
106    pub output_schema: Option<Value>,
107    #[serde(default)]
108    pub skills: Vec<String>,
109}
110
111/// Configuration for a Human gate node.
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct HumanConfig {
114    /// Prompt to show to the human.
115    pub prompt: String,
116    /// Timeout in seconds. If None, waits indefinitely.
117    pub timeout_secs: Option<u64>,
118    /// Available options (e.g., ["approve", "reject", "edit"]).
119    #[serde(default)]
120    pub options: Vec<String>,
121    /// Default action on timeout (e.g., "approve" or "reject").
122    pub timeout_action: Option<String>,
123}
124
125/// A single branch in a ConditionNode: a JSON pointer path, comparison operator, and value.
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct ConditionBranch {
128    /// JSON pointer path to extract from upstream output (e.g., "/confidence").
129    pub path: String,
130    /// Comparison operator: "gt", "gte", "lt", "lte", "eq", "neq".
131    pub op: ConditionOp,
132    /// The value to compare against.
133    pub value: Value,
134    /// Target node ID to route to when this branch matches.
135    pub goto: String,
136}
137
138/// Comparison operators for condition evaluation.
139#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
140pub enum ConditionOp {
141    /// Greater than (>)
142    #[serde(rename = "gt")]
143    Gt,
144    /// Greater than or equal (>=)
145    #[serde(rename = "gte")]
146    Gte,
147    /// Less than (<)
148    #[serde(rename = "lt")]
149    Lt,
150    /// Less than or equal (<=)
151    #[serde(rename = "lte")]
152    Lte,
153    /// Equal (==)
154    #[serde(rename = "eq")]
155    Eq,
156    /// Not equal (!=)
157    #[serde(rename = "neq")]
158    Neq,
159}
160
161/// Evaluate a condition: extract `path` from `data`, compare with `op` against `value`.
162pub fn evaluate_condition(data: &Value, branch: &ConditionBranch) -> bool {
163    let extracted = data.pointer(&branch.path);
164    let extracted = match extracted {
165        Some(v) => v,
166        None => return false,
167    };
168
169    match &branch.op {
170        ConditionOp::Eq => extracted == &branch.value,
171        ConditionOp::Neq => extracted != &branch.value,
172        ConditionOp::Gt | ConditionOp::Gte | ConditionOp::Lt | ConditionOp::Lte => {
173            compare_numeric(extracted, &branch.value, &branch.op)
174        }
175    }
176}
177
178fn compare_numeric(lhs: &Value, rhs: &Value, op: &ConditionOp) -> bool {
179    let lhs_f = value_as_f64(lhs);
180    let rhs_f = value_as_f64(rhs);
181    match (lhs_f, rhs_f) {
182        (Some(l), Some(r)) => match op {
183            ConditionOp::Gt => l > r,
184            ConditionOp::Gte => l >= r,
185            ConditionOp::Lt => l < r,
186            ConditionOp::Lte => l <= r,
187            _ => false,
188        },
189        _ => false,
190    }
191}
192
193fn value_as_f64(v: &Value) -> Option<f64> {
194    v.as_f64().or_else(|| v.as_i64().map(|i| i as f64))
195}
196
197/// Configuration for a Condition node.
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct ConditionConfig {
200    /// IDs of upstream nodes whose output is used for evaluation.
201    pub input_from: Vec<String>,
202    /// Ordered list of branches; first match wins.
203    pub branches: Vec<ConditionBranch>,
204    /// Optional default target if no branch matches (instead of error).
205    pub default_goto: Option<String>,
206}
207
208/// Configuration for a Loop node.
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct LoopConfig {
211    /// Node IDs that form the loop body, executed in order each iteration.
212    pub body: Vec<String>,
213    /// Maximum number of iterations (safety limit).
214    pub max_iterations: usize,
215    /// Until condition: JSON pointer path + op + value checked against the last body node's output.
216    pub until: Option<ConditionBranch>,
217}
218
219/// The type of a synchronous transform function.
220pub type TransformFn = Arc<dyn Fn(&Value) -> std::result::Result<Value, String> + Send + Sync>;
221
222/// The type of an async transform function — supports HTTP calls, file I/O, etc.
223pub type AsyncTransformFn = Arc<
224    dyn Fn(Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = std::result::Result<Value, String>> + Send>>
225        + Send
226        + Sync,
227>;
228
229/// Configuration for a Transform node (pure function, async function, or external script).
230///
231/// Execution priority: `script` > `async_fn` > `transform_fn`.
232#[derive(Clone)]
233pub struct TransformConfig {
234    /// IDs of upstream nodes whose output is used as input.
235    pub input_from: Vec<String>,
236    /// The sync transform function (lowest priority fallback).
237    pub transform_fn: TransformFn,
238    /// Optional async transform function.
239    pub async_fn: Option<AsyncTransformFn>,
240    /// Optional external script path. When set, the runtime spawns the script as a subprocess,
241    /// pipes input JSON via stdin, and captures output JSON from stdout.
242    /// Supports: .mjs, .js (node), .py (python3), .sh (bash).
243    pub script: Option<String>,
244}
245
246impl std::fmt::Debug for TransformConfig {
247    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248        let has_async = self.async_fn.is_some();
249        f.debug_struct("TransformConfig")
250            .field("input_from", &self.input_from)
251            .field("transform_fn", &"<fn>")
252            .field("async_fn", &if has_async { "<async fn>" } else { "<none>" })
253            .field("script", &self.script)
254            .finish()
255    }
256}
257
258/// Configuration for a SubWorkflow node (nested workflow composition).
259#[derive(Debug, Clone)]
260pub struct SubWorkflowConfig {
261    /// The nested workflow to execute.
262    pub workflow: Workflow,
263    /// IDs of upstream nodes whose output is used as input.
264    pub input_from: Vec<String>,
265}
266
267/// The kind of a workflow node.
268#[derive(Debug, Clone)]
269pub enum NodeKind {
270    Agent(AgentConfig),
271    Human(HumanConfig),
272    Condition(ConditionConfig),
273    Loop(LoopConfig),
274    Transform(TransformConfig),
275    SubWorkflow(SubWorkflowConfig),
276}
277
278/// A single node in a workflow DAG.
279#[derive(Debug, Clone)]
280pub struct Node {
281    pub id: String,
282    pub kind: NodeKind,
283    /// Optional retry policy for this node.
284    pub retry_policy: Option<RetryPolicy>,
285}
286
287impl Node {
288    /// Returns a human-readable display name for this node.
289    /// For agent nodes, returns the agent name; for others, returns the node ID.
290    pub fn display_name(&self) -> &str {
291        match &self.kind {
292            NodeKind::Agent(c) => &c.name,
293            _ => &self.id,
294        }
295    }
296
297    /// Returns a string label for the node kind.
298    pub fn kind_str(&self) -> &str {
299        match &self.kind {
300            NodeKind::Agent(_) => "agent",
301            NodeKind::Human(_) => "human",
302            NodeKind::Transform(c) if c.script.is_some() => "script",
303            NodeKind::Transform(_) => "transform",
304            NodeKind::Condition(_) => "condition",
305            NodeKind::Loop(_) => "loop",
306            NodeKind::SubWorkflow(_) => "sub_workflow",
307        }
308    }
309
310    /// Create a new Agent node builder.
311    pub fn agent(id: impl Into<String>) -> NodeBuilder {
312        NodeBuilder {
313            id: id.into(),
314            kind: NodeBuilderKind::Agent {
315                name: None,
316                system_prompt: None,
317                tools: vec![],
318                input_from: vec![],
319                output_schema: None,
320                skills: vec![],
321            },
322            retry_policy: None,
323        }
324    }
325
326    /// Create a new Human node builder.
327    pub fn human(id: impl Into<String>) -> NodeBuilder {
328        NodeBuilder {
329            id: id.into(),
330            kind: NodeBuilderKind::Human {
331                prompt: None,
332                timeout_secs: None,
333                options: vec![],
334                timeout_action: None,
335            },
336            retry_policy: None,
337        }
338    }
339
340    /// Create a new Condition node builder.
341    pub fn condition(id: impl Into<String>) -> NodeBuilder {
342        NodeBuilder {
343            id: id.into(),
344            kind: NodeBuilderKind::Condition {
345                input_from: vec![],
346                branches: vec![],
347                default_goto: None,
348            },
349            retry_policy: None,
350        }
351    }
352
353    /// Create a new Loop node builder.
354    pub fn loop_node(id: impl Into<String>) -> NodeBuilder {
355        NodeBuilder {
356            id: id.into(),
357            kind: NodeBuilderKind::Loop {
358                body: vec![],
359                max_iterations: 10,
360                until: None,
361            },
362            retry_policy: None,
363        }
364    }
365
366    /// Create a new Transform node builder.
367    pub fn transform(id: impl Into<String>) -> NodeBuilder {
368        NodeBuilder {
369            id: id.into(),
370            kind: NodeBuilderKind::Transform {
371                input_from: vec![],
372                transform_fn: None,
373                async_fn: None,
374            },
375            retry_policy: None,
376        }
377    }
378
379    /// Create a new SubWorkflow node builder.
380    pub fn sub_workflow(id: impl Into<String>) -> NodeBuilder {
381        NodeBuilder {
382            id: id.into(),
383            kind: NodeBuilderKind::SubWorkflow {
384                workflow: None,
385                input_from: vec![],
386            },
387            retry_policy: None,
388        }
389    }
390}
391
392enum NodeBuilderKind {
393    Agent {
394        name: Option<String>,
395        system_prompt: Option<String>,
396        tools: Vec<String>,
397        input_from: Vec<String>,
398        output_schema: Option<Value>,
399        skills: Vec<String>,
400    },
401    Human {
402        prompt: Option<String>,
403        timeout_secs: Option<u64>,
404        options: Vec<String>,
405        timeout_action: Option<String>,
406    },
407    Condition {
408        input_from: Vec<String>,
409        branches: Vec<ConditionBranch>,
410        default_goto: Option<String>,
411    },
412    Loop {
413        body: Vec<String>,
414        max_iterations: usize,
415        until: Option<ConditionBranch>,
416    },
417    Transform {
418        input_from: Vec<String>,
419        transform_fn: Option<TransformFn>,
420        async_fn: Option<AsyncTransformFn>,
421    },
422    SubWorkflow {
423        workflow: Option<Workflow>,
424        input_from: Vec<String>,
425    },
426}
427
428pub struct NodeBuilder {
429    id: String,
430    kind: NodeBuilderKind,
431    retry_policy: Option<RetryPolicy>,
432}
433
434impl NodeBuilder {
435    // --- Agent methods ---
436
437    pub fn name(mut self, name: impl Into<String>) -> Self {
438        if let NodeBuilderKind::Agent { name: ref mut n, .. } = self.kind {
439            *n = Some(name.into());
440        }
441        self
442    }
443
444    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
445        if let NodeBuilderKind::Agent {
446            system_prompt: ref mut sp,
447            ..
448        } = self.kind
449        {
450            *sp = Some(prompt.into());
451        }
452        self
453    }
454
455    pub fn tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
456        if let NodeBuilderKind::Agent {
457            tools: ref mut t, ..
458        } = self.kind
459        {
460            *t = tools.into_iter().map(|s| s.into()).collect();
461        }
462        self
463    }
464
465    pub fn input_from(mut self, inputs: impl IntoInputIds) -> Self {
466        let collected: Vec<String> = inputs.into_input_ids();
467        match self.kind {
468            NodeBuilderKind::Agent {
469                ref mut input_from, ..
470            } => {
471                *input_from = collected;
472            }
473            NodeBuilderKind::Transform {
474                ref mut input_from, ..
475            } => {
476                *input_from = collected;
477            }
478            NodeBuilderKind::SubWorkflow {
479                ref mut input_from, ..
480            } => {
481                *input_from = collected;
482            }
483            _ => {}
484        }
485        self
486    }
487
488    pub fn output_schema(mut self, schema: Value) -> Self {
489        if let NodeBuilderKind::Agent {
490            output_schema: ref mut os,
491            ..
492        } = self.kind
493        {
494            *os = Some(schema);
495        }
496        self
497    }
498
499    /// Add a single skill (pushes to the skills vec). Can be called multiple times.
500    pub fn skill(mut self, skill: impl Into<String>) -> Self {
501        if let NodeBuilderKind::Agent {
502            skills: ref mut s, ..
503        } = self.kind
504        {
505            s.push(skill.into());
506        }
507        self
508    }
509
510    /// Set multiple skills at once (replaces any previously added skills).
511    pub fn skills(mut self, skills: impl IntoIterator<Item = impl Into<String>>) -> Self {
512        if let NodeBuilderKind::Agent {
513            skills: ref mut s, ..
514        } = self.kind
515        {
516            *s = skills.into_iter().map(|v| v.into()).collect();
517        }
518        self
519    }
520
521    // --- Human methods ---
522
523    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
524        if let NodeBuilderKind::Human {
525            prompt: ref mut p, ..
526        } = self.kind
527        {
528            *p = Some(prompt.into());
529        }
530        self
531    }
532
533    pub fn timeout_secs(mut self, secs: u64) -> Self {
534        if let NodeBuilderKind::Human {
535            timeout_secs: ref mut ts,
536            ..
537        } = self.kind
538        {
539            *ts = Some(secs);
540        }
541        self
542    }
543
544    pub fn options(mut self, options: impl IntoIterator<Item = impl Into<String>>) -> Self {
545        if let NodeBuilderKind::Human {
546            options: ref mut o, ..
547        } = self.kind
548        {
549            *o = options.into_iter().map(|s| s.into()).collect();
550        }
551        self
552    }
553
554    pub fn timeout_action(mut self, action: impl Into<String>) -> Self {
555        if let NodeBuilderKind::Human {
556            timeout_action: ref mut ta,
557            ..
558        } = self.kind
559        {
560            *ta = Some(action.into());
561        }
562        self
563    }
564
565    // --- Condition methods ---
566
567    /// Add an upstream input source for condition evaluation.
568    pub fn condition_input_from(mut self, inputs: impl IntoInputIds) -> Self {
569        if let NodeBuilderKind::Condition {
570            input_from: ref mut i,
571            ..
572        } = self.kind
573        {
574            *i = inputs.into_input_ids();
575        }
576        self
577    }
578
579    /// Add a condition branch.
580    pub fn branch(mut self, branch: ConditionBranch) -> Self {
581        if let NodeBuilderKind::Condition {
582            branches: ref mut b,
583            ..
584        } = self.kind
585        {
586            b.push(branch);
587        }
588        self
589    }
590
591    /// Set the default goto target when no branch matches.
592    pub fn default_goto(mut self, target: impl Into<String>) -> Self {
593        if let NodeBuilderKind::Condition {
594            default_goto: ref mut d,
595            ..
596        } = self.kind
597        {
598            *d = Some(target.into());
599        }
600        self
601    }
602
603    // --- Loop methods ---
604
605    /// Set the body node IDs for the loop.
606    pub fn body(mut self, nodes: impl IntoIterator<Item = impl Into<String>>) -> Self {
607        if let NodeBuilderKind::Loop {
608            body: ref mut b, ..
609        } = self.kind
610        {
611            *b = nodes.into_iter().map(|s| s.into()).collect();
612        }
613        self
614    }
615
616    /// Set the maximum iterations for the loop.
617    pub fn max_iterations(mut self, max: usize) -> Self {
618        if let NodeBuilderKind::Loop {
619            max_iterations: ref mut m,
620            ..
621        } = self.kind
622        {
623            *m = max;
624        }
625        self
626    }
627
628    /// Set the until condition for the loop.
629    pub fn until(mut self, condition: ConditionBranch) -> Self {
630        if let NodeBuilderKind::Loop {
631            until: ref mut u, ..
632        } = self.kind
633        {
634            *u = Some(condition);
635        }
636        self
637    }
638
639    // --- Transform methods ---
640
641    /// Set the input sources for a Transform node.
642    pub fn transform_input_from(mut self, inputs: impl IntoInputIds) -> Self {
643        if let NodeBuilderKind::Transform {
644            input_from: ref mut i,
645            ..
646        } = self.kind
647        {
648            *i = inputs.into_input_ids();
649        }
650        self
651    }
652
653    /// Set the transform function for a Transform node.
654    pub fn transform_fn(
655        mut self,
656        f: impl Fn(&Value) -> std::result::Result<Value, String> + Send + Sync + 'static,
657    ) -> Self {
658        if let NodeBuilderKind::Transform {
659            transform_fn: ref mut tf,
660            ..
661        } = self.kind
662        {
663            *tf = Some(Arc::new(f));
664        }
665        self
666    }
667
668    /// Set an async transform function. When set, the runtime uses this instead of the sync transform_fn.
669    /// Useful for HTTP calls, file I/O, or any async operation.
670    ///
671    /// # Example
672    ///
673    /// ```rust,ignore
674    /// Node::transform("fetch")
675    ///     .transform_input_from(["upstream"])
676    ///     .async_transform_fn(|input| async move {
677    ///         Ok(serde_json::json!({"result": "ok"}))
678    ///     })
679    ///     .build()
680    /// ```
681    pub fn async_transform_fn<F, Fut>(mut self, f: F) -> Self
682    where
683        F: Fn(Value) -> Fut + Send + Sync + 'static,
684        Fut: std::future::Future<Output = std::result::Result<Value, String>> + Send + 'static,
685    {
686        if let NodeBuilderKind::Transform { ref mut async_fn, .. } = self.kind {
687            *async_fn = Some(Arc::new(move |input: Value| {
688                Box::pin(f(input)) as std::pin::Pin<Box<dyn std::future::Future<Output = std::result::Result<Value, String>> + Send>>
689            }));
690        }
691        self
692    }
693
694    // --- SubWorkflow methods ---
695
696    /// Set the nested workflow for a SubWorkflow node.
697    pub fn workflow(mut self, workflow: Workflow) -> Self {
698        if let NodeBuilderKind::SubWorkflow {
699            workflow: ref mut w, ..
700        } = self.kind
701        {
702            *w = Some(workflow);
703        }
704        self
705    }
706
707    /// Set a retry policy for this node.
708    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
709        self.retry_policy = Some(policy);
710        self
711    }
712
713    /// Set retry parameters directly (convenience method).
714    pub fn retry(
715        mut self,
716        max_retries: u32,
717        backoff_ms: u64,
718        backoff_multiplier: f64,
719        on_failure: FailureMode,
720    ) -> Self {
721        self.retry_policy = Some(RetryPolicy {
722            max_retries,
723            backoff_ms,
724            backoff_multiplier,
725            on_failure,
726        });
727        self
728    }
729
730    pub fn build(self) -> Node {
731        let kind = match self.kind {
732            NodeBuilderKind::Agent {
733                name,
734                system_prompt,
735                tools,
736                input_from,
737                output_schema,
738                skills,
739            } => NodeKind::Agent(AgentConfig {
740                name: name.unwrap_or_else(|| self.id.clone()),
741                system_prompt: system_prompt.unwrap_or_default(),
742                tools,
743                input_from,
744                output_schema,
745                skills,
746            }),
747            NodeBuilderKind::Human {
748                prompt,
749                timeout_secs,
750                options,
751                timeout_action,
752            } => NodeKind::Human(HumanConfig {
753                prompt: prompt.unwrap_or_else(|| "Awaiting human input".to_string()),
754                timeout_secs,
755                options,
756                timeout_action,
757            }),
758            NodeBuilderKind::Condition {
759                input_from,
760                branches,
761                default_goto,
762            } => NodeKind::Condition(ConditionConfig {
763                input_from,
764                branches,
765                default_goto,
766            }),
767            NodeBuilderKind::Loop {
768                body,
769                max_iterations,
770                until,
771            } => NodeKind::Loop(LoopConfig {
772                body,
773                max_iterations,
774                until,
775            }),
776            NodeBuilderKind::Transform {
777                input_from,
778                transform_fn,
779                async_fn,
780            } => {
781                let noop_fn: TransformFn = Arc::new(|_| Ok(Value::Null));
782                NodeKind::Transform(TransformConfig {
783                    input_from,
784                    transform_fn: transform_fn.unwrap_or(noop_fn),
785                    async_fn,
786                    script: None, // Script is set via YAML loader, not builder
787                })
788            }
789            NodeBuilderKind::SubWorkflow {
790                workflow,
791                input_from,
792            } => NodeKind::SubWorkflow(SubWorkflowConfig {
793                workflow: workflow.expect("sub_workflow node requires a workflow"),
794                input_from,
795            }),
796        };
797
798        Node {
799            id: self.id,
800            kind,
801            retry_policy: self.retry_policy,
802        }
803    }
804}