use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Identifier used to refer to a node; a plain alias for `String`.
pub type NodeId = String;
/// Status values a node can carry.
///
/// Serde renders the variant names in snake_case (for example `Pending`
/// becomes `"pending"`). [`NodeStatus::is_terminal`] reports `Completed`,
/// `Failed`, `Skipped` and `Cancelled` as terminal; the remaining variants
/// are non-terminal.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NodeStatus {
Pending,
Scheduled,
Running,
Completed,
Failed,
Skipped,
Cancelled,
}
impl NodeStatus {
    /// Reports whether this status is one of the terminal statuses.
    ///
    /// Returns `true` for `Completed`, `Failed`, `Skipped` and `Cancelled`,
    /// and `false` for `Pending`, `Scheduled` and `Running`.
    pub fn is_terminal(&self) -> bool {
        match self {
            Self::Completed | Self::Failed | Self::Skipped | Self::Cancelled => true,
            Self::Pending | Self::Scheduled | Self::Running => false,
        }
    }
}
/// The kinds of node defined by this module, each with its own configuration.
///
/// Serde representation: internally tagged. The variant name is stored in a
/// `type` field and rendered in snake_case. [`NodeKind::queue_type`] and
/// [`NodeKind::is_durable`] classify the variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum NodeKind {
/// Model node; `queue_type` maps it to `QueueType::Model`.
Model {
model_ref: String,
prompt_ref: String,
output_schema: String,
system_prompt: Option<String>,
},
/// Tool node; `queue_type` maps it to `QueueType::Tool`.
Tool {
tool_ref: String,
input_mapping: HashMap<String, String>,
output_schema: String,
},
/// Python function node identified by `module` and `function`;
/// `queue_type` maps it to `QueueType::PythonTool`.
PythonFn {
module: String,
function: String,
output_schema: String,
},
/// Conditional node holding a list of [`ConditionalBranch`] entries.
/// `is_durable` returns `false` for this variant.
Condition { branches: Vec<ConditionalBranch> },
/// Parallel node holding the ids of its branch nodes.
Parallel { branches: Vec<NodeId> },
/// Join node: the node ids it waits for and the [`MergeStrategy`] to apply.
Join {
wait_for: Vec<NodeId>,
merge_strategy: MergeStrategy,
},
/// Human-approval node with a description, an optional timeout in seconds
/// and an optional fallback node id.
HumanApproval {
description: String,
timeout_secs: Option<u64>,
fallback_node: Option<NodeId>,
},
/// Wait node with a [`WaitCondition`], an optional correlation key and an
/// optional timeout in seconds.
Wait {
condition: WaitCondition,
correlation_key: Option<String>,
timeout_secs: Option<u64>,
},
/// Subgraph node referencing another workflow (optionally a specific
/// version) together with input and output mappings.
Subgraph {
workflow_ref: String,
workflow_version: Option<String>,
input_mapping: HashMap<String, String>,
output_mapping: HashMap<String, String>,
},
/// Memory-retrieval node; `queue_type` maps it to `QueueType::Retrieval`.
MemoryRetrieval {
connector_ref: String,
query_expr: String,
output_schema: String,
},
/// Policy node: a policy reference and the [`ViolationAction`] to take.
Policy {
policy_ref: String,
on_violation: ViolationAction,
},
/// Finalizer node: a tool reference and the [`FinalizerTrigger`] selecting
/// when it runs. `queue_type` maps it to `QueueType::Tool`.
Finalizer {
tool_ref: String,
run_on: FinalizerTrigger,
},
/// Agent node; `queue_type` maps it to `QueueType::General`.
Agent {
agent_ref: String,
input_mapping: HashMap<String, String>,
output_schema: String,
},
/// MCP tool node identified by `server` and `tool`;
/// `queue_type` maps it to `QueueType::Tool`.
McpTool {
server: String,
tool: String,
input_mapping: HashMap<String, String>,
output_schema: String,
},
/// A2A task node; `queue_type` maps it to `QueueType::Tool`.
A2aTask {
remote_agent: String,
skill: String,
input_mapping: HashMap<String, String>,
output_schema: String,
stream: bool,
on_input_required: Option<NodeId>,
timeout_secs: Option<u64>,
},
/// Deprecated agent-discovery node; the deprecation note directs users to
/// the `Coordinator` node. `is_durable` returns `false` for this variant.
#[deprecated(note = "Use Coordinator node instead")]
AgentDiscovery {
skill: String,
protocol: Option<String>,
output_binding: String,
},
/// Coordinator node; `queue_type` maps it to `QueueType::General`.
Coordinator {
task: String,
required_skills: Vec<String>,
/// Defaults to an empty list when absent from the serialized input.
#[serde(default)]
preferred_skills: Vec<String>,
trust_domain: Option<String>,
budget: Option<crate::coordinator::CoordinatorBudget>,
tiebreaker: Option<crate::coordinator::TiebreakerConfig>,
/// Defaults to the value of `default_strategy()` (`"default"`) when absent.
#[serde(default = "default_strategy")]
strategy: String,
/// Defaults to `DimensionWeights::default()` when absent.
#[serde(default)]
weights: crate::coordinator::DimensionWeights,
/// Defaults to an empty map when absent.
#[serde(default)]
input_mapping: HashMap<String, String>,
output_key: String,
},
/// Agent-as-tool node; `queue_type` maps it to `QueueType::General`.
AgentTool {
agent: crate::agent_tool::AgentTarget,
/// Defaults to `AgentToolMode::default()` when absent.
#[serde(default)]
mode: crate::agent_tool::AgentToolMode,
/// Defaults to an empty map when absent.
#[serde(default)]
input_mapping: HashMap<String, String>,
output_key: String,
timeout_ms: Option<u64>,
budget: Option<crate::agent_tool::AgentToolBudget>,
},
/// Evaluation node: a list of [`EvalScorer`]s and the [`EvalOnFail`] action.
/// `queue_type` maps it to `QueueType::General`.
Eval {
scorers: Vec<EvalScorer>,
on_fail: EvalOnFail,
/// Defaults to `0` when absent.
#[serde(default)]
max_retries: u32,
input_expr: Option<String>,
},
/// Unit variant with no configuration.
/// NOTE(review): where and why this variant is produced is not visible in
/// this file — confirm against the code that constructs it.
LimitExceeded,
}
impl NodeKind {
    /// Returns the [`QueueType`] associated with this kind of node.
    ///
    /// * `Model` → `QueueType::Model`
    /// * `Tool`, `Finalizer`, `McpTool`, `A2aTask` → `QueueType::Tool`
    /// * `PythonFn` → `QueueType::PythonTool`
    /// * `MemoryRetrieval` → `QueueType::Retrieval`
    /// * every other variant → `QueueType::General`
    ///
    /// The match is deliberately exhaustive (no `_` arm): adding a variant to
    /// [`NodeKind`] becomes a compile error here until its queue is chosen,
    /// instead of silently falling through to the general queue.
    pub fn queue_type(&self) -> QueueType {
        match self {
            Self::Model { .. } => QueueType::Model,
            Self::Tool { .. }
            | Self::Finalizer { .. }
            | Self::McpTool { .. }
            | Self::A2aTask { .. } => QueueType::Tool,
            Self::PythonFn { .. } => QueueType::PythonTool,
            Self::MemoryRetrieval { .. } => QueueType::Retrieval,
            Self::Agent { .. }
            | Self::HumanApproval { .. }
            | Self::Wait { .. }
            | Self::Eval { .. }
            | Self::Coordinator { .. }
            | Self::AgentTool { .. }
            | Self::Condition { .. }
            | Self::Parallel { .. }
            | Self::Join { .. }
            | Self::Subgraph { .. }
            | Self::Policy { .. }
            | Self::LimitExceeded => QueueType::General,
            // The deprecated variant must still be named to keep the match
            // exhaustive; the lint is silenced for this arm only.
            #[allow(deprecated)]
            Self::AgentDiscovery { .. } => QueueType::General,
        }
    }

    /// Returns `false` for `Condition` and the deprecated `AgentDiscovery`
    /// variant, and `true` for every other variant.
    ///
    /// NOTE(review): what "durable" implies for the executor (for example
    /// persistence or replay) is not visible in this file.
    pub fn is_durable(&self) -> bool {
        // Naming the deprecated variant in a pattern triggers the lint, hence the allow.
        #[allow(deprecated)]
        let non_durable = matches!(
            self,
            Self::Condition { .. } | Self::AgentDiscovery { .. }
        );
        !non_durable
    }
}
/// Queue categories returned by [`NodeKind::queue_type`].
///
/// Serde renders the variant names in snake_case (for example `PythonTool`
/// becomes `"python_tool"`). `Privileged` is never returned by
/// `NodeKind::queue_type` in this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueueType {
Model,
Tool,
PythonTool,
Retrieval,
Privileged,
General,
}
/// One branch of a `NodeKind::Condition` node: an optional condition string
/// and the id of the node the branch targets.
///
/// NOTE(review): how `condition` is evaluated, and what a `None` condition
/// means, is not visible in this file — presumably a default/else branch;
/// confirm against the evaluator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConditionalBranch {
pub condition: Option<String>, pub target: NodeId,
}
/// Merge strategy carried by `NodeKind::Join` in its `merge_strategy` field.
///
/// Serde renders the variant names in snake_case. `Custom` carries a
/// `function_ref` string.
/// NOTE(review): the exact merge behaviour of each strategy is implemented
/// elsewhere and is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MergeStrategy {
Collect,
First,
Custom { function_ref: String },
}
/// Condition carried by `NodeKind::Wait` in its `condition` field.
///
/// Serde renders the variant names in snake_case (for example
/// `ExternalEvent` becomes `"external_event"`).
/// NOTE(review): presumably `Either` means "timer or external event,
/// whichever comes first" — confirm against the code that handles waits.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WaitCondition {
Timer,
ExternalEvent,
Either,
}
/// Action carried by `NodeKind::Policy` in its `on_violation` field.
///
/// Serde renders the variant names in snake_case. `Branch` carries the id of
/// a target node.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ViolationAction {
Fail,
Branch { target: NodeId },
Warn,
}
/// Trigger carried by `NodeKind::Finalizer` in its `run_on` field.
///
/// Serde renders the variant names in snake_case.
/// NOTE(review): the variant names suggest the outcome on which the finalizer
/// runs; the logic that decides this is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinalizerTrigger {
Success,
Failure,
Always,
}
/// Scorer configurations listed in `NodeKind::Eval`'s `scorers` field.
///
/// Serde representation: internally tagged. The variant name is stored in a
/// `type` field and rendered in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EvalScorer {
/// Scorer configured with a model name, a rubric and a minimum score.
LlmJudge {
model: String,
rubric: String,
/// Defaults to the value of `default_min_score()` (`3`) when absent.
#[serde(default = "default_min_score")]
min_score: u8,
},
/// Scorer configured with a list of check strings.
Assertion {
checks: Vec<String>,
},
/// Scorer configured with a threshold in milliseconds.
Latency {
threshold_ms: u64,
},
/// Scorer configured with a threshold in US dollars.
Cost {
threshold_usd: f64,
},
/// Scorer identified by a module string with free-form JSON arguments.
Custom {
module: String,
/// Defaults to `serde_json::Value::default()` when absent.
#[serde(default)]
kwargs: serde_json::Value,
},
}
/// Serde fallback for `EvalScorer::LlmJudge::min_score` when the field is
/// missing from the serialized input.
fn default_min_score() -> u8 {
    const FALLBACK_MIN_SCORE: u8 = 3;
    FALLBACK_MIN_SCORE
}
/// Serde fallback for `NodeKind::Coordinator::strategy` when the field is
/// missing from the serialized input.
fn default_strategy() -> String {
    String::from("default")
}
/// Action carried by `NodeKind::Eval` in its `on_fail` field.
///
/// Serde renders the variant names in snake_case (for example
/// `RetryWithFeedback` becomes `"retry_with_feedback"`). The `Default`
/// implementation yields `Halt`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum EvalOnFail {
RetryWithFeedback,
Escalate,
#[default]
Halt,
LogAndContinue,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Model` node maps to the model queue and is durable.
    #[test]
    fn model_node_dispatches_to_model_queue() {
        let node = NodeKind::Model {
            model_ref: "openai.gpt4".into(),
            prompt_ref: "prompts/summarize.md".into(),
            output_schema: "schemas.Summary".into(),
            system_prompt: None,
        };
        assert_eq!(node.queue_type(), QueueType::Model);
        assert!(node.is_durable());
    }

    /// `Condition` is one of the two variants reported as non-durable.
    #[test]
    fn condition_node_is_not_durable() {
        let node = NodeKind::Condition { branches: vec![] };
        assert!(!node.is_durable());
    }

    /// A `Coordinator` node survives a JSON round trip and keeps its classification.
    #[test]
    fn coordinator_node_round_trip() {
        let node = NodeKind::Coordinator {
            task: "Analyze data".into(),
            required_skills: vec!["data-analysis".into()],
            preferred_skills: vec![],
            trust_domain: Some("internal".into()),
            budget: None,
            tiebreaker: None,
            strategy: "default".into(),
            weights: Default::default(),
            input_mapping: Default::default(),
            output_key: "result".into(),
        };
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: NodeKind = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, NodeKind::Coordinator { .. }));
        assert_eq!(node.queue_type(), QueueType::General);
        assert!(node.is_durable());
    }

    /// An `AgentTool` node survives a JSON round trip and keeps its classification.
    #[test]
    fn agent_tool_node_round_trip() {
        let node = NodeKind::AgentTool {
            agent: crate::agent_tool::AgentTarget::Explicit("jamjet://org/test".into()),
            mode: crate::agent_tool::AgentToolMode::Sync,
            input_mapping: Default::default(),
            output_key: "result".into(),
            timeout_ms: Some(5000),
            budget: None,
        };
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: NodeKind = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, NodeKind::AgentTool { .. }));
        assert_eq!(node.queue_type(), QueueType::General);
        assert!(node.is_durable());
    }

    /// The deprecated `AgentDiscovery` variant can still be built and classified:
    /// it maps to the general queue and is reported as non-durable.
    #[test]
    fn agent_discovery_is_deprecated_but_functional() {
        // Constructing the deprecated variant is the point of this test.
        #[allow(deprecated)]
        let node = NodeKind::AgentDiscovery {
            skill: "data-analysis".into(),
            protocol: None,
            output_binding: "selected_agent".into(),
        };
        // Previously the result was discarded, so the test could not fail on a
        // wrong classification; pin the observable behaviour instead.
        assert_eq!(node.queue_type(), QueueType::General);
        assert!(!node.is_durable());
    }
}