Skip to main content

jamjet_core/
node.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
/// String identifier used to reference nodes — e.g. branch targets,
/// join inputs and fallback nodes.
pub type NodeId = std::string::String;
5
6/// The lifecycle status of a single node within an execution.
7#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
8#[serde(rename_all = "snake_case")]
9pub enum NodeStatus {
10    Pending,
11    Scheduled,
12    Running,
13    Completed,
14    Failed,
15    Skipped,
16    Cancelled,
17}
18
19impl NodeStatus {
20    pub fn is_terminal(&self) -> bool {
21        matches!(
22            self,
23            Self::Completed | Self::Failed | Self::Skipped | Self::Cancelled
24        )
25    }
26}
27
28/// All node kinds supported by the JamJet runtime.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(tag = "type", rename_all = "snake_case")]
31pub enum NodeKind {
32    /// LLM call with a prompt and structured output.
33    Model {
34        model_ref: String,
35        prompt_ref: String,
36        output_schema: String,
37        system_prompt: Option<String>,
38    },
39
40    /// Python function, HTTP endpoint, or gRPC tool.
41    Tool {
42        tool_ref: String,
43        input_mapping: HashMap<String, String>,
44        output_schema: String,
45    },
46
47    /// Arbitrary Python function executed by a Python worker.
48    PythonFn {
49        module: String,
50        function: String,
51        output_schema: String,
52    },
53
54    /// Router — evaluates expressions and branches.
55    Condition { branches: Vec<ConditionalBranch> },
56
57    /// Fan-out to multiple branches concurrently.
58    Parallel { branches: Vec<NodeId> },
59
60    /// Waits for all parallel branches to complete.
61    Join {
62        wait_for: Vec<NodeId>,
63        merge_strategy: MergeStrategy,
64    },
65
66    /// Pauses workflow for human decision.
67    HumanApproval {
68        description: String,
69        timeout_secs: Option<u64>,
70        fallback_node: Option<NodeId>,
71    },
72
73    /// Suspends until a timer fires or external event arrives.
74    Wait {
75        condition: WaitCondition,
76        correlation_key: Option<String>,
77        timeout_secs: Option<u64>,
78    },
79
80    /// Executes a child workflow.
81    Subgraph {
82        workflow_ref: String,
83        workflow_version: Option<String>,
84        input_mapping: HashMap<String, String>,
85        output_mapping: HashMap<String, String>,
86    },
87
88    /// Retrieves context from a memory/retrieval connector.
89    MemoryRetrieval {
90        connector_ref: String,
91        query_expr: String,
92        output_schema: String,
93    },
94
95    /// Evaluates policy rules; can block or branch on violation.
96    Policy {
97        policy_ref: String,
98        on_violation: ViolationAction,
99    },
100
101    /// Side-effect node (notifications, writes).
102    Finalizer {
103        tool_ref: String,
104        run_on: FinalizerTrigger,
105    },
106
107    // ── Protocol nodes ──────────────────────────────────────────────────
108    /// Delegates to a local JamJet agent.
109    Agent {
110        agent_ref: String,
111        input_mapping: HashMap<String, String>,
112        output_schema: String,
113    },
114
115    /// Invokes a tool from an external MCP server.
116    McpTool {
117        server: String,
118        tool: String,
119        input_mapping: HashMap<String, String>,
120        output_schema: String,
121    },
122
123    /// Delegates a task to an external A2A agent.
124    A2aTask {
125        remote_agent: String,
126        skill: String,
127        input_mapping: HashMap<String, String>,
128        output_schema: String,
129        stream: bool,
130        on_input_required: Option<NodeId>,
131        timeout_secs: Option<u64>,
132    },
133
134    #[deprecated(note = "Use Coordinator node instead")]
135    /// Dynamically discovers and selects an agent at runtime.
136    AgentDiscovery {
137        skill: String,
138        protocol: Option<String>,
139        output_binding: String,
140    },
141
142    /// Dynamic agent routing with structured scoring + LLM tiebreaker.
143    /// Supersedes AgentDiscovery.
144    Coordinator {
145        task: String,
146        required_skills: Vec<String>,
147        #[serde(default)]
148        preferred_skills: Vec<String>,
149        trust_domain: Option<String>,
150        budget: Option<crate::coordinator::CoordinatorBudget>,
151        tiebreaker: Option<crate::coordinator::TiebreakerConfig>,
152        #[serde(default = "default_strategy")]
153        strategy: String,
154        #[serde(default)]
155        weights: crate::coordinator::DimensionWeights,
156        #[serde(default)]
157        input_mapping: HashMap<String, String>,
158        output_key: String,
159    },
160
161    /// Invoke a registered agent as a callable tool.
162    AgentTool {
163        agent: crate::agent_tool::AgentTarget,
164        #[serde(default)]
165        mode: crate::agent_tool::AgentToolMode,
166        #[serde(default)]
167        input_mapping: HashMap<String, String>,
168        output_key: String,
169        timeout_ms: Option<u64>,
170        budget: Option<crate::agent_tool::AgentToolBudget>,
171    },
172
173    /// Evaluates the preceding node's output using configurable scorers.
174    ///
175    /// Supports LLM-judge, deterministic assertions, latency/cost thresholds,
176    /// and custom Python scorer plugins.
177    Eval {
178        /// Ordered list of scorer configurations.
179        scorers: Vec<EvalScorer>,
180        /// Action on overall failure (any scorer below threshold).
181        on_fail: EvalOnFail,
182        /// Maximum retry attempts before propagating failure.
183        #[serde(default)]
184        max_retries: u32,
185        /// Input expression — which state field to evaluate (default: last node output).
186        input_expr: Option<String>,
187    },
188
189    /// Terminal node emitted by strategy compilers when an iteration or cost
190    /// limit is reached. The runtime records the workflow as `LimitExceeded`
191    /// and stops further execution. No fields are required — it is purely a
192    /// marker that carries optional descriptive metadata in the node's
193    /// `description` / `labels`.
194    LimitExceeded,
195}
196
197impl NodeKind {
198    /// Returns the queue type this node should be dispatched to.
199    pub fn queue_type(&self) -> QueueType {
200        match self {
201            Self::Model { .. } => QueueType::Model,
202            Self::Tool { .. } | Self::Finalizer { .. } => QueueType::Tool,
203            Self::PythonFn { .. } => QueueType::PythonTool,
204            Self::MemoryRetrieval { .. } => QueueType::Retrieval,
205            Self::McpTool { .. } | Self::A2aTask { .. } => QueueType::Tool,
206            Self::Agent { .. } => QueueType::General,
207            Self::HumanApproval { .. } | Self::Wait { .. } => QueueType::General,
208            Self::Eval { .. } => QueueType::General,
209            Self::Coordinator { .. } => QueueType::General,
210            Self::AgentTool { .. } => QueueType::General,
211            _ => QueueType::General,
212        }
213    }
214
215    /// Returns true if this node requires durable tracking across crashes.
216    pub fn is_durable(&self) -> bool {
217        #[allow(deprecated)]
218        let is_agent_discovery = matches!(self, Self::AgentDiscovery { .. });
219        !matches!(self, Self::Condition { .. }) && !is_agent_discovery
220    }
221}
222
223/// Which queue a node's work item is dispatched to.
224#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
225#[serde(rename_all = "snake_case")]
226pub enum QueueType {
227    Model,
228    Tool,
229    PythonTool,
230    Retrieval,
231    Privileged,
232    General,
233}
234
235#[derive(Debug, Clone, Serialize, Deserialize)]
236pub struct ConditionalBranch {
237    pub condition: Option<String>, // None = default/else branch
238    pub target: NodeId,
239}
240
241#[derive(Debug, Clone, Serialize, Deserialize)]
242#[serde(rename_all = "snake_case")]
243pub enum MergeStrategy {
244    /// Merge all branch outputs into a list.
245    Collect,
246    /// Take the first completed branch output.
247    First,
248    /// Custom merge function.
249    Custom { function_ref: String },
250}
251
252#[derive(Debug, Clone, Serialize, Deserialize)]
253#[serde(rename_all = "snake_case")]
254pub enum WaitCondition {
255    Timer,
256    ExternalEvent,
257    Either,
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize)]
261#[serde(rename_all = "snake_case")]
262pub enum ViolationAction {
263    Fail,
264    Branch { target: NodeId },
265    Warn,
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
269#[serde(rename_all = "snake_case")]
270pub enum FinalizerTrigger {
271    Success,
272    Failure,
273    Always,
274}
275
276// ── Eval node types ──────────────────────────────────────────────────────────
277
278/// A scorer within an `Eval` node.
279#[derive(Debug, Clone, Serialize, Deserialize)]
280#[serde(tag = "type", rename_all = "snake_case")]
281pub enum EvalScorer {
282    /// LLM-as-judge: sends output to a model with a rubric, expects a score 1-5.
283    LlmJudge {
284        model: String,
285        rubric: String,
286        /// Minimum acceptable score (1-5). Scores below this fail.
287        #[serde(default = "default_min_score")]
288        min_score: u8,
289    },
290    /// Deterministic Python expressions evaluated against the output.
291    Assertion {
292        /// Each check is a Python expression that must evaluate to truthy.
293        checks: Vec<String>,
294    },
295    /// Ensures node execution completed within a latency threshold.
296    Latency {
297        /// Maximum allowed duration in milliseconds.
298        threshold_ms: u64,
299    },
300    /// Ensures the execution cost is within budget.
301    Cost {
302        /// Maximum allowed cost in USD.
303        threshold_usd: f64,
304    },
305    /// Custom Python scorer loaded via entry point or module path.
306    Custom {
307        /// Python dotted path: "my_package.scorers:MyScorer"
308        module: String,
309        /// Optional keyword arguments passed to the scorer.
310        #[serde(default)]
311        kwargs: serde_json::Value,
312    },
313}
314
/// Serde default for `EvalScorer::LlmJudge::min_score`: the midpoint of the 1-5 scale.
fn default_min_score() -> u8 {
    const MIDPOINT: u8 = 3;
    MIDPOINT
}
318
/// Serde default for `NodeKind::Coordinator::strategy` when the field is omitted.
fn default_strategy() -> String {
    String::from("default")
}
322
323/// What the eval node does when one or more scorers fail.
324#[derive(Debug, Clone, Serialize, Deserialize, Default)]
325#[serde(rename_all = "snake_case")]
326pub enum EvalOnFail {
327    /// Feed scorer feedback back to the previous node and retry.
328    RetryWithFeedback,
329    /// Escalate to human (triggers HumanApproval fallback node).
330    Escalate,
331    /// Fail the workflow immediately.
332    #[default]
333    Halt,
334    /// Record the failure but continue the workflow.
335    LogAndContinue,
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_node_dispatches_to_model_queue() {
        let node = NodeKind::Model {
            model_ref: "openai.gpt4".into(),
            prompt_ref: "prompts/summarize.md".into(),
            output_schema: "schemas.Summary".into(),
            system_prompt: None,
        };
        assert_eq!(node.queue_type(), QueueType::Model);
        assert!(node.is_durable());
    }

    #[test]
    fn condition_node_is_not_durable() {
        let node = NodeKind::Condition { branches: vec![] };
        assert!(!node.is_durable());
        assert_eq!(node.queue_type(), QueueType::General);
    }

    #[test]
    fn node_status_terminality() {
        let terminal = [
            NodeStatus::Completed,
            NodeStatus::Failed,
            NodeStatus::Skipped,
            NodeStatus::Cancelled,
        ];
        for status in &terminal {
            assert!(status.is_terminal(), "{status:?} should be terminal");
        }
        let active = [NodeStatus::Pending, NodeStatus::Scheduled, NodeStatus::Running];
        for status in &active {
            assert!(!status.is_terminal(), "{status:?} should not be terminal");
        }
    }

    #[test]
    fn node_status_serializes_as_snake_case() {
        let json = serde_json::to_string(&NodeStatus::Cancelled).unwrap();
        assert_eq!(json, "\"cancelled\"");
        let back: NodeStatus = serde_json::from_str("\"scheduled\"").unwrap();
        assert_eq!(back, NodeStatus::Scheduled);
    }

    #[test]
    fn coordinator_node_round_trip() {
        let node = NodeKind::Coordinator {
            task: "Analyze data".into(),
            required_skills: vec!["data-analysis".into()],
            preferred_skills: vec![],
            trust_domain: Some("internal".into()),
            budget: None,
            tiebreaker: None,
            strategy: "default".into(),
            weights: Default::default(),
            input_mapping: Default::default(),
            output_key: "result".into(),
        };
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: NodeKind = serde_json::from_str(&json).unwrap();
        // Check the fields survive the round trip, not just the variant.
        match deserialized {
            NodeKind::Coordinator {
                task,
                required_skills,
                trust_domain,
                strategy,
                output_key,
                ..
            } => {
                assert_eq!(task, "Analyze data");
                assert_eq!(required_skills, vec!["data-analysis".to_string()]);
                assert_eq!(trust_domain.as_deref(), Some("internal"));
                assert_eq!(strategy, "default");
                assert_eq!(output_key, "result");
            }
            other => panic!("expected Coordinator, got {other:?}"),
        }
        assert_eq!(node.queue_type(), QueueType::General);
        assert!(node.is_durable());
    }

    #[test]
    fn coordinator_optional_fields_fall_back_to_defaults() {
        let json =
            r#"{"type":"coordinator","task":"t","required_skills":["s"],"output_key":"out"}"#;
        let node: NodeKind = serde_json::from_str(json).unwrap();
        match node {
            NodeKind::Coordinator {
                preferred_skills,
                trust_domain,
                budget,
                tiebreaker,
                strategy,
                input_mapping,
                ..
            } => {
                assert!(preferred_skills.is_empty());
                assert!(trust_domain.is_none());
                assert!(budget.is_none());
                assert!(tiebreaker.is_none());
                assert_eq!(strategy, "default");
                assert!(input_mapping.is_empty());
            }
            other => panic!("expected Coordinator, got {other:?}"),
        }
    }

    #[test]
    fn agent_tool_node_round_trip() {
        let node = NodeKind::AgentTool {
            agent: crate::agent_tool::AgentTarget::Explicit("jamjet://org/test".into()),
            mode: crate::agent_tool::AgentToolMode::Sync,
            input_mapping: Default::default(),
            output_key: "result".into(),
            timeout_ms: Some(5000),
            budget: None,
        };
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: NodeKind = serde_json::from_str(&json).unwrap();
        match deserialized {
            NodeKind::AgentTool {
                output_key,
                timeout_ms,
                ..
            } => {
                assert_eq!(output_key, "result");
                assert_eq!(timeout_ms, Some(5000));
            }
            other => panic!("expected AgentTool, got {other:?}"),
        }
        assert_eq!(node.queue_type(), QueueType::General);
        assert!(node.is_durable());
    }

    #[test]
    fn limit_exceeded_serializes_as_bare_tag() {
        let json = serde_json::to_string(&NodeKind::LimitExceeded).unwrap();
        assert_eq!(json, r#"{"type":"limit_exceeded"}"#);
        let back: NodeKind = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, NodeKind::LimitExceeded));
    }

    #[test]
    #[allow(deprecated)] // exercising the deprecated variant on purpose
    fn agent_discovery_is_deprecated_but_functional() {
        let node = NodeKind::AgentDiscovery {
            skill: "data-analysis".into(),
            protocol: None,
            output_binding: "selected_agent".into(),
        };
        // Previously this test only called `queue_type` and discarded the result.
        assert_eq!(node.queue_type(), QueueType::General);
        assert!(!node.is_durable());
        let json = serde_json::to_string(&node).unwrap();
        let back: NodeKind = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, NodeKind::AgentDiscovery { .. }));
    }
}