// jamjet_core/node.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
/// Identifier of a node within a workflow graph.
///
/// A plain `String` alias, so any owned string can be used directly as an id.
pub type NodeId = String;
5
6/// The lifecycle status of a single node within an execution.
7#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
8#[serde(rename_all = "snake_case")]
9pub enum NodeStatus {
10    Pending,
11    Scheduled,
12    Running,
13    Completed,
14    Failed,
15    Skipped,
16    Cancelled,
17}
18
19impl NodeStatus {
20    pub fn is_terminal(&self) -> bool {
21        matches!(
22            self,
23            Self::Completed | Self::Failed | Self::Skipped | Self::Cancelled
24        )
25    }
26}
27
28/// All node kinds supported by the JamJet runtime.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(tag = "type", rename_all = "snake_case")]
31pub enum NodeKind {
32    /// LLM call with a prompt and structured output.
33    Model {
34        model_ref: String,
35        prompt_ref: String,
36        output_schema: String,
37        system_prompt: Option<String>,
38    },
39
40    /// Python function, HTTP endpoint, or gRPC tool.
41    Tool {
42        tool_ref: String,
43        input_mapping: HashMap<String, String>,
44        output_schema: String,
45    },
46
47    /// Arbitrary Python function executed by a Python worker.
48    PythonFn {
49        module: String,
50        function: String,
51        output_schema: String,
52    },
53
54    /// Router — evaluates expressions and branches.
55    Condition { branches: Vec<ConditionalBranch> },
56
57    /// Fan-out to multiple branches concurrently.
58    Parallel { branches: Vec<NodeId> },
59
60    /// Waits for all parallel branches to complete.
61    Join {
62        wait_for: Vec<NodeId>,
63        merge_strategy: MergeStrategy,
64    },
65
66    /// Pauses workflow for human decision.
67    HumanApproval {
68        description: String,
69        timeout_secs: Option<u64>,
70        fallback_node: Option<NodeId>,
71    },
72
73    /// Suspends until a timer fires or external event arrives.
74    Wait {
75        condition: WaitCondition,
76        correlation_key: Option<String>,
77        timeout_secs: Option<u64>,
78    },
79
80    /// Executes a child workflow.
81    Subgraph {
82        workflow_ref: String,
83        workflow_version: Option<String>,
84        input_mapping: HashMap<String, String>,
85        output_mapping: HashMap<String, String>,
86    },
87
88    /// Retrieves context from a memory/retrieval connector.
89    MemoryRetrieval {
90        connector_ref: String,
91        query_expr: String,
92        output_schema: String,
93    },
94
95    /// Evaluates policy rules; can block or branch on violation.
96    Policy {
97        policy_ref: String,
98        on_violation: ViolationAction,
99    },
100
101    /// Side-effect node (notifications, writes).
102    Finalizer {
103        tool_ref: String,
104        run_on: FinalizerTrigger,
105    },
106
107    // ── Protocol nodes ──────────────────────────────────────────────────
108    /// Delegates to a local JamJet agent.
109    Agent {
110        agent_ref: String,
111        input_mapping: HashMap<String, String>,
112        output_schema: String,
113    },
114
115    /// Invokes a tool from an external MCP server.
116    McpTool {
117        server: String,
118        tool: String,
119        input_mapping: HashMap<String, String>,
120        output_schema: String,
121    },
122
123    /// Delegates a task to an external A2A agent.
124    A2aTask {
125        remote_agent: String,
126        skill: String,
127        input_mapping: HashMap<String, String>,
128        output_schema: String,
129        stream: bool,
130        on_input_required: Option<NodeId>,
131        timeout_secs: Option<u64>,
132    },
133
134    #[deprecated(note = "Use Coordinator node instead")]
135    /// Dynamically discovers and selects an agent at runtime.
136    AgentDiscovery {
137        skill: String,
138        protocol: Option<String>,
139        output_binding: String,
140    },
141
142    /// Dynamic agent routing with structured scoring + LLM tiebreaker.
143    /// Supersedes AgentDiscovery.
144    Coordinator {
145        task: String,
146        required_skills: Vec<String>,
147        #[serde(default)]
148        preferred_skills: Vec<String>,
149        trust_domain: Option<String>,
150        budget: Option<crate::coordinator::CoordinatorBudget>,
151        tiebreaker: Option<crate::coordinator::TiebreakerConfig>,
152        #[serde(default = "default_strategy")]
153        strategy: String,
154        #[serde(default)]
155        weights: crate::coordinator::DimensionWeights,
156        #[serde(default)]
157        input_mapping: HashMap<String, String>,
158        output_key: String,
159    },
160
161    /// Invoke a registered agent as a callable tool.
162    AgentTool {
163        agent: crate::agent_tool::AgentTarget,
164        #[serde(default)]
165        mode: crate::agent_tool::AgentToolMode,
166        #[serde(default)]
167        input_mapping: HashMap<String, String>,
168        output_key: String,
169        timeout_ms: Option<u64>,
170        budget: Option<crate::agent_tool::AgentToolBudget>,
171    },
172
173    /// Evaluates the preceding node's output using configurable scorers.
174    ///
175    /// Supports LLM-judge, deterministic assertions, latency/cost thresholds,
176    /// and custom Python scorer plugins.
177    Eval {
178        /// Ordered list of scorer configurations.
179        scorers: Vec<EvalScorer>,
180        /// Action on overall failure (any scorer below threshold).
181        on_fail: EvalOnFail,
182        /// Maximum retry attempts before propagating failure.
183        #[serde(default)]
184        max_retries: u32,
185        /// Input expression — which state field to evaluate (default: last node output).
186        input_expr: Option<String>,
187    },
188}
189
190impl NodeKind {
191    /// Returns the queue type this node should be dispatched to.
192    pub fn queue_type(&self) -> QueueType {
193        match self {
194            Self::Model { .. } => QueueType::Model,
195            Self::Tool { .. } | Self::Finalizer { .. } => QueueType::Tool,
196            Self::PythonFn { .. } => QueueType::PythonTool,
197            Self::MemoryRetrieval { .. } => QueueType::Retrieval,
198            Self::McpTool { .. } | Self::A2aTask { .. } => QueueType::Tool,
199            Self::Agent { .. } => QueueType::General,
200            Self::HumanApproval { .. } | Self::Wait { .. } => QueueType::General,
201            Self::Eval { .. } => QueueType::General,
202            Self::Coordinator { .. } => QueueType::General,
203            Self::AgentTool { .. } => QueueType::General,
204            _ => QueueType::General,
205        }
206    }
207
208    /// Returns true if this node requires durable tracking across crashes.
209    pub fn is_durable(&self) -> bool {
210        #[allow(deprecated)]
211        let is_agent_discovery = matches!(self, Self::AgentDiscovery { .. });
212        !matches!(self, Self::Condition { .. }) && !is_agent_discovery
213    }
214}
215
216/// Which queue a node's work item is dispatched to.
217#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
218#[serde(rename_all = "snake_case")]
219pub enum QueueType {
220    Model,
221    Tool,
222    PythonTool,
223    Retrieval,
224    Privileged,
225    General,
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct ConditionalBranch {
230    pub condition: Option<String>, // None = default/else branch
231    pub target: NodeId,
232}
233
234#[derive(Debug, Clone, Serialize, Deserialize)]
235#[serde(rename_all = "snake_case")]
236pub enum MergeStrategy {
237    /// Merge all branch outputs into a list.
238    Collect,
239    /// Take the first completed branch output.
240    First,
241    /// Custom merge function.
242    Custom { function_ref: String },
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
246#[serde(rename_all = "snake_case")]
247pub enum WaitCondition {
248    Timer,
249    ExternalEvent,
250    Either,
251}
252
253#[derive(Debug, Clone, Serialize, Deserialize)]
254#[serde(rename_all = "snake_case")]
255pub enum ViolationAction {
256    Fail,
257    Branch { target: NodeId },
258    Warn,
259}
260
261#[derive(Debug, Clone, Serialize, Deserialize)]
262#[serde(rename_all = "snake_case")]
263pub enum FinalizerTrigger {
264    Success,
265    Failure,
266    Always,
267}
268
// ── Eval node types ──────────────────────────────────────────────────────────
270
271/// A scorer within an `Eval` node.
272#[derive(Debug, Clone, Serialize, Deserialize)]
273#[serde(tag = "type", rename_all = "snake_case")]
274pub enum EvalScorer {
275    /// LLM-as-judge: sends output to a model with a rubric, expects a score 1-5.
276    LlmJudge {
277        model: String,
278        rubric: String,
279        /// Minimum acceptable score (1-5). Scores below this fail.
280        #[serde(default = "default_min_score")]
281        min_score: u8,
282    },
283    /// Deterministic Python expressions evaluated against the output.
284    Assertion {
285        /// Each check is a Python expression that must evaluate to truthy.
286        checks: Vec<String>,
287    },
288    /// Ensures node execution completed within a latency threshold.
289    Latency {
290        /// Maximum allowed duration in milliseconds.
291        threshold_ms: u64,
292    },
293    /// Ensures the execution cost is within budget.
294    Cost {
295        /// Maximum allowed cost in USD.
296        threshold_usd: f64,
297    },
298    /// Custom Python scorer loaded via entry point or module path.
299    Custom {
300        /// Python dotted path: "my_package.scorers:MyScorer"
301        module: String,
302        /// Optional keyword arguments passed to the scorer.
303        #[serde(default)]
304        kwargs: serde_json::Value,
305    },
306}
307
/// Serde default for `EvalScorer::LlmJudge::min_score`: the midpoint of the 1-5 scale.
fn default_min_score() -> u8 {
    const MIDPOINT_OF_ONE_TO_FIVE: u8 = 3;
    MIDPOINT_OF_ONE_TO_FIVE
}
311
/// Serde default for `NodeKind::Coordinator::strategy` when the field is omitted.
fn default_strategy() -> String {
    String::from("default")
}
315
316/// What the eval node does when one or more scorers fail.
317#[derive(Debug, Clone, Serialize, Deserialize, Default)]
318#[serde(rename_all = "snake_case")]
319pub enum EvalOnFail {
320    /// Feed scorer feedback back to the previous node and retry.
321    RetryWithFeedback,
322    /// Escalate to human (triggers HumanApproval fallback node).
323    Escalate,
324    /// Fail the workflow immediately.
325    #[default]
326    Halt,
327    /// Record the failure but continue the workflow.
328    LogAndContinue,
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_node_dispatches_to_model_queue() {
        let node = NodeKind::Model {
            model_ref: "openai.gpt4".into(),
            prompt_ref: "prompts/summarize.md".into(),
            output_schema: "schemas.Summary".into(),
            system_prompt: None,
        };
        assert_eq!(node.queue_type(), QueueType::Model);
        assert!(node.is_durable());
    }

    #[test]
    fn condition_node_is_not_durable() {
        let node = NodeKind::Condition { branches: vec![] };
        assert!(!node.is_durable());
        assert_eq!(node.queue_type(), QueueType::General);
    }

    /// Worker-backed kinds must land on their dedicated queues, not `General`.
    #[test]
    fn worker_nodes_dispatch_to_dedicated_queues() {
        let python = NodeKind::PythonFn {
            module: "pkg.module".into(),
            function: "run".into(),
            output_schema: "schemas.Out".into(),
        };
        assert_eq!(python.queue_type(), QueueType::PythonTool);

        let retrieval = NodeKind::MemoryRetrieval {
            connector_ref: "vector.docs".into(),
            query_expr: "state.question".into(),
            output_schema: "schemas.Docs".into(),
        };
        assert_eq!(retrieval.queue_type(), QueueType::Retrieval);

        let finalizer = NodeKind::Finalizer {
            tool_ref: "notify.slack".into(),
            run_on: FinalizerTrigger::Always,
        };
        assert_eq!(finalizer.queue_type(), QueueType::Tool);
        assert!(finalizer.is_durable());
    }

    #[test]
    fn terminal_statuses_are_classified() {
        for status in [
            NodeStatus::Completed,
            NodeStatus::Failed,
            NodeStatus::Skipped,
            NodeStatus::Cancelled,
        ] {
            assert!(status.is_terminal(), "{status:?} should be terminal");
        }
        for status in [NodeStatus::Pending, NodeStatus::Scheduled, NodeStatus::Running] {
            assert!(!status.is_terminal(), "{status:?} should not be terminal");
        }
    }

    #[test]
    fn node_status_serializes_as_snake_case() {
        assert_eq!(
            serde_json::to_string(&NodeStatus::Running).unwrap(),
            r#""running""#
        );
        let parsed: NodeStatus = serde_json::from_str(r#""cancelled""#).unwrap();
        assert_eq!(parsed, NodeStatus::Cancelled);
    }

    #[test]
    fn coordinator_node_round_trip() {
        let node = NodeKind::Coordinator {
            task: "Analyze data".into(),
            required_skills: vec!["data-analysis".into()],
            preferred_skills: vec![],
            trust_domain: Some("internal".into()),
            budget: None,
            tiebreaker: None,
            strategy: "default".into(),
            weights: Default::default(),
            input_mapping: Default::default(),
            output_key: "result".into(),
        };
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: NodeKind = serde_json::from_str(&json).unwrap();
        // Check field contents, not just the variant, so a lossy round trip is caught.
        match deserialized {
            NodeKind::Coordinator {
                task,
                required_skills,
                trust_domain,
                strategy,
                output_key,
                ..
            } => {
                assert_eq!(task, "Analyze data");
                assert_eq!(required_skills, vec!["data-analysis"]);
                assert_eq!(trust_domain.as_deref(), Some("internal"));
                assert_eq!(strategy, "default");
                assert_eq!(output_key, "result");
            }
            other => panic!("expected Coordinator, got {other:?}"),
        }
        assert_eq!(node.queue_type(), QueueType::General);
        assert!(node.is_durable());
    }

    #[test]
    fn agent_tool_node_round_trip() {
        let node = NodeKind::AgentTool {
            agent: crate::agent_tool::AgentTarget::Explicit("jamjet://org/test".into()),
            mode: crate::agent_tool::AgentToolMode::Sync,
            input_mapping: Default::default(),
            output_key: "result".into(),
            timeout_ms: Some(5000),
            budget: None,
        };
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: NodeKind = serde_json::from_str(&json).unwrap();
        match deserialized {
            NodeKind::AgentTool {
                output_key,
                timeout_ms,
                ..
            } => {
                assert_eq!(output_key, "result");
                assert_eq!(timeout_ms, Some(5000));
            }
            other => panic!("expected AgentTool, got {other:?}"),
        }
        assert_eq!(node.queue_type(), QueueType::General);
        assert!(node.is_durable());
    }

    /// Omitted `max_retries`, `input_expr` and `min_score` fall back to their serde defaults.
    #[test]
    fn eval_node_applies_serde_defaults() {
        let json = r#"{
            "type": "eval",
            "scorers": [{"type": "llm_judge", "model": "judge", "rubric": "Is it correct?"}],
            "on_fail": "halt"
        }"#;
        let node: NodeKind = serde_json::from_str(json).unwrap();
        match node {
            NodeKind::Eval {
                scorers,
                on_fail,
                max_retries,
                input_expr,
            } => {
                assert!(matches!(
                    scorers.as_slice(),
                    [EvalScorer::LlmJudge { min_score: 3, .. }]
                ));
                assert!(matches!(on_fail, EvalOnFail::Halt));
                assert_eq!(max_retries, 0);
                assert!(input_expr.is_none());
            }
            other => panic!("expected Eval, got {other:?}"),
        }
    }

    #[test]
    fn agent_discovery_is_deprecated_but_functional() {
        // Constructing the deprecated variant is what triggers the lint.
        #[allow(deprecated)]
        let node = NodeKind::AgentDiscovery {
            skill: "data-analysis".into(),
            protocol: None,
            output_binding: "selected_agent".into(),
        };
        assert_eq!(node.queue_type(), QueueType::General);
        assert!(!node.is_durable());
    }
}