use jamjet_core::node::NodeKind;
use jamjet_core::retry::RetryPolicy;
use jamjet_core::timeout::TimeoutConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Serializable intermediate representation (IR) of a workflow graph.
///
/// Parsed with [`WorkflowIr::from_json`] / [`WorkflowIr::from_yaml`]. The identity fields,
/// `state_schema`, `start_node` and `nodes` are required. Everything else may be omitted:
/// `Option` fields become `None`, and the collection fields (`edges`, the registries,
/// `labels`) and `timeouts` fall back to their empty/default values.
///
/// Field order is significant: serde serializes fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowIr {
    /// Identifier of the workflow.
    pub workflow_id: String,
    /// Version string of this workflow definition (format not validated here).
    pub version: String,
    /// Optional display name.
    pub name: Option<String>,
    /// Optional description text.
    pub description: Option<String>,
    /// Reference to the workflow state schema (e.g. `schemas.TestState` in the tests);
    /// not resolved or validated here.
    pub state_schema: String,
    /// ID of the entry node; presumably a key of `nodes` (not checked here).
    pub start_node: String,
    /// Nodes keyed by node ID (see [`WorkflowIr::node`]).
    pub nodes: HashMap<String, NodeDef>,
    /// Directed edges between nodes; empty when omitted.
    #[serde(default)]
    pub edges: Vec<EdgeDef>,
    /// Named retry policies; empty when omitted. Presumably referenced by name from
    /// `NodeDef::retry_policy` / `ModelConfig::retry_policy`.
    #[serde(default)]
    pub retry_policies: HashMap<String, RetryPolicy>,
    /// Timeout settings; `TimeoutConfig::default()` when omitted.
    #[serde(default)]
    pub timeouts: TimeoutConfig,
    /// Model configurations keyed by name; empty when omitted.
    #[serde(default)]
    pub models: HashMap<String, ModelConfig>,
    /// Tool configurations keyed by name; empty when omitted.
    #[serde(default)]
    pub tools: HashMap<String, ToolConfig>,
    /// MCP server configurations keyed by name; empty when omitted.
    #[serde(default)]
    pub mcp_servers: HashMap<String, McpServerConfig>,
    /// Remote agent configurations keyed by name; empty when omitted.
    #[serde(default)]
    pub remote_agents: HashMap<String, RemoteAgentConfig>,
    /// Free-form key/value labels; empty when omitted.
    #[serde(default)]
    pub labels: HashMap<String, String>,
    /// Workflow-level policy set; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub policy: Option<PolicySetIr>,
    /// Token budget; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token_budget: Option<TokenBudgetIr>,
    /// Cost cap, presumably in US dollars per the field name; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_budget_usd: Option<f64>,
    /// Action to take when a budget is exceeded, as a free-form string (accepted values
    /// are not defined here); omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub on_budget_exceeded: Option<String>,
    /// Workflow-level data-handling policy; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub data_policy: Option<DataPolicyIr>,
}
impl WorkflowIr {
    /// Deserializes a workflow IR from JSON text.
    ///
    /// # Errors
    /// Returns the `serde_json` error for malformed JSON or input that does not match the
    /// IR schema.
    pub fn from_json(s: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<Self>(s)
    }

    /// Deserializes a workflow IR from YAML text.
    ///
    /// # Errors
    /// Returns the `serde_yaml` error for malformed YAML or input that does not match the
    /// IR schema.
    pub fn from_yaml(s: &str) -> Result<Self, serde_yaml::Error> {
        serde_yaml::from_str::<Self>(s)
    }

    /// Serializes this IR to indented, human-readable JSON.
    ///
    /// # Errors
    /// Returns the `serde_json` error if any field fails to serialize.
    pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(self)
    }

    /// Looks up a node by its key in `nodes`.
    pub fn node(&self, id: &str) -> Option<&NodeDef> {
        self.nodes.get(id)
    }

    /// Returns every edge whose `from` equals `node_id`, in declaration order.
    pub fn edges_from(&self, node_id: &str) -> Vec<&EdgeDef> {
        let mut outgoing = Vec::new();
        for edge in &self.edges {
            if edge.from == node_id {
                outgoing.push(edge);
            }
        }
        outgoing
    }

    /// Returns the `to` node IDs of every edge leaving `node_id`, in declaration order.
    ///
    /// Conditions are ignored and duplicates are kept when several edges share a target.
    pub fn successors(&self, node_id: &str) -> Vec<&str> {
        self.edges
            .iter()
            .filter(|edge| edge.from == node_id)
            .map(|edge| edge.to.as_str())
            .collect()
    }
}
/// A single node of the workflow graph, stored in `WorkflowIr::nodes`.
///
/// Field order is significant: serde serializes fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeDef {
/// Node identifier. Presumably equal to this node's key in `WorkflowIr::nodes` — not enforced here.
pub id: String,
/// What the node does; see `jamjet_core::node::NodeKind`.
pub kind: NodeKind,
/// Name of a retry policy, presumably a key of `WorkflowIr::retry_policies` (not validated here).
pub retry_policy: Option<String>,
/// Per-node timeout, in seconds per the field name. How it interacts with `WorkflowIr::timeouts` is decided elsewhere.
pub node_timeout_secs: Option<u64>,
/// Optional description text.
pub description: Option<String>,
/// Free-form key/value labels; empty when omitted.
#[serde(default)]
pub labels: HashMap<String, String>,
/// Node-level policy set; omitted from output when `None`.
/// NOTE(review): how this combines with `WorkflowIr::policy` is not defined here — confirm with the executor.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub policy: Option<PolicySetIr>,
/// Node-level data-handling policy; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub data_policy: Option<DataPolicyIr>,
}
/// A directed edge between two nodes.
///
/// `WorkflowIr::edges_from` matches on `from`, and `WorkflowIr::successors` yields `to`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeDef {
/// ID of the source node.
pub from: String,
/// ID of the target node.
pub to: String,
/// Optional guard condition as an opaque string; it is not parsed or evaluated in this file.
pub condition: Option<String>,
}
/// Configuration for one model entry in `WorkflowIr::models`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
/// Provider name as a free-form string (accepted values are not defined here).
pub provider: String,
/// Model identifier understood by the provider.
pub model: String,
/// Request timeout, in seconds per the field name.
pub timeout_secs: Option<u64>,
/// Name of a retry policy, presumably a key of `WorkflowIr::retry_policies`.
pub retry_policy: Option<String>,
/// Optional sampling temperature passed through to the model.
pub temperature: Option<f32>,
/// Optional cap on generated tokens.
pub max_tokens: Option<u32>,
}
/// Configuration for one tool entry in `WorkflowIr::tools`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolConfig {
/// Implementation kind of the tool.
pub kind: ToolKind,
/// Locator for the tool implementation. Its format presumably depends on `kind` (e.g. a module path or URL) — not validated here.
pub reference: String,
/// Optional input schema reference, kept as an opaque string.
pub input_schema: Option<String>,
/// Optional output schema reference, kept as an opaque string.
pub output_schema: Option<String>,
/// Permission names granted to the tool; empty when omitted.
#[serde(default)]
pub permissions: Vec<String>,
}
/// How a tool is implemented/invoked.
///
/// Serialized in snake_case: `python`, `http`, `grpc`, `mcp`.
/// Variant order is kept as-is because some serde formats encode variants by index.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolKind {
Python,
Http,
Grpc,
Mcp,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
pub transport: McpTransport,
pub command: Option<String>,
pub args: Vec<String>,
pub url: Option<String>,
pub auth: Option<AuthConfig>,
}
/// Transport used to talk to an MCP server.
///
/// Serialized in snake_case: `stdio`, `http_sse`, `web_socket`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum McpTransport {
    Stdio,
    HttpSse,
    /// Serialized as `web_socket` (serde's snake_case of `WebSocket`). The conventional
    /// spelling `websocket` is also accepted when deserializing.
    #[serde(alias = "websocket")]
    WebSocket,
}
/// Configuration for one remote agent in `WorkflowIr::remote_agents`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemoteAgentConfig {
/// Base URL of the remote agent.
pub url: String,
/// Optional path to the agent's card. NOTE(review): presumably resolved relative to `url` — confirm with the client code.
pub agent_card_path: Option<String>,
/// Optional authentication settings.
pub auth: Option<AuthConfig>,
}
/// Declarative policy lists; each list is empty when omitted.
///
/// Enforcement is not implemented in this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PolicySetIr {
/// Presumably tool names (or patterns) that must not be invoked — confirm matching semantics with the enforcer.
#[serde(default)]
pub blocked_tools: Vec<String>,
/// Presumably tool names (or patterns) whose use needs human approval.
#[serde(default)]
pub require_approval_for: Vec<String>,
/// Presumably the models permitted for use. NOTE(review): whether an empty list means "allow all" is decided elsewhere — confirm.
#[serde(default)]
pub model_allowlist: Vec<String>,
}
/// Token limits for a workflow. Every limit is optional and omitted from output when
/// unset.
///
/// Derives `Default` (all limits `None`), like the other optional policy sections
/// (`PolicySetIr`, `DataPolicyIr`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenBudgetIr {
    /// Maximum input (prompt) tokens, if capped.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    /// Maximum output (completion) tokens, if capped.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    /// Maximum combined tokens, if capped.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub total_tokens: Option<u32>,
}
/// Data-handling policy (PII handling and retention).
///
/// Every field has a serde default, so `{}` deserializes to the same value as
/// `DataPolicyIr::default()`: no PII fields/detectors, `redaction_mode = "mask"`,
/// prompts not retained, outputs retained, no retention period.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataPolicyIr {
/// Names of fields treated as PII; empty when omitted.
#[serde(default)]
pub pii_fields: Vec<String>,
/// Names of PII detectors to apply; empty when omitted.
#[serde(default)]
pub pii_detectors: Vec<String>,
/// Redaction mode as a free-form string; defaults to `"mask"`. Other accepted values are not defined in this file.
#[serde(default = "default_redaction_mode")]
pub redaction_mode: String,
/// Whether prompts are kept; defaults to `false`.
#[serde(default)]
pub retain_prompts: bool,
/// Whether outputs are kept; defaults to `true`.
#[serde(default = "default_true")]
pub retain_outputs: bool,
/// Retention period, in days per the field name; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub retention_days: Option<u32>,
}
/// Serde default for [`DataPolicyIr::retain_outputs`]: outputs are kept unless disabled.
fn default_true() -> bool {
    true
}

/// Serde default for [`DataPolicyIr::redaction_mode`]: the `"mask"` mode.
fn default_redaction_mode() -> String {
    String::from("mask")
}
impl Default for DataPolicyIr {
fn default() -> Self {
Self {
pii_fields: Vec::new(),
pii_detectors: Vec::new(),
redaction_mode: default_redaction_mode(),
retain_prompts: false,
retain_outputs: true,
retention_days: None,
}
}
}
/// Authentication settings, internally tagged by a `type` field.
///
/// Tag values are the snake_case variant names: `bearer`, `api_key`, `oauth2`,
/// e.g. `{"type": "bearer", "token_env": "..."}`.
/// The `*_env` fields presumably hold the *names* of environment variables that contain
/// the secrets, not the secrets themselves — confirm with the credential loader.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthConfig {
/// Bearer-token authentication.
Bearer {
token_env: String,
},
/// API key sent in the named request header.
ApiKey {
header: String,
key_env: String,
},
/// OAuth2 client-credentials style settings with a token endpoint URL.
Oauth2 {
client_id_env: String,
client_secret_env: String,
token_url: String,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the smallest `WorkflowIr`: no nodes, edges, registries or optional sections.
    fn minimal_ir() -> WorkflowIr {
        WorkflowIr {
            workflow_id: "test_workflow".into(),
            version: "0.1.0".into(),
            name: None,
            description: None,
            policy: None,
            token_budget: None,
            cost_budget_usd: None,
            on_budget_exceeded: None,
            data_policy: None,
            state_schema: "schemas.TestState".into(),
            start_node: "start".into(),
            nodes: HashMap::new(),
            edges: vec![],
            retry_policies: HashMap::new(),
            timeouts: TimeoutConfig::default(),
            models: HashMap::new(),
            tools: HashMap::new(),
            mcp_servers: HashMap::new(),
            remote_agents: HashMap::new(),
            labels: HashMap::new(),
        }
    }

    #[test]
    fn ir_roundtrip_json() {
        let ir = minimal_ir();
        let json = ir.to_json_pretty().unwrap();
        let parsed = WorkflowIr::from_json(&json).unwrap();
        assert_eq!(parsed.workflow_id, ir.workflow_id);
        assert_eq!(parsed.version, ir.version);
        assert_eq!(parsed.state_schema, ir.state_schema);
        assert_eq!(parsed.start_node, ir.start_node);
        assert!(parsed.name.is_none());
        assert!(parsed.nodes.is_empty());
        assert!(parsed.edges.is_empty());
        assert!(parsed.policy.is_none());
        assert!(parsed.token_budget.is_none());
        assert!(parsed.cost_budget_usd.is_none());
        assert!(parsed.on_budget_exceeded.is_none());
        assert!(parsed.data_policy.is_none());
    }

    #[test]
    fn unset_optional_sections_are_omitted_from_json() {
        // Check top-level keys only; substring checks could be fooled by nested content.
        let value = serde_json::to_value(minimal_ir()).unwrap();
        let obj = value.as_object().expect("WorkflowIr serializes to a JSON object");
        for key in [
            "policy",
            "token_budget",
            "cost_budget_usd",
            "on_budget_exceeded",
            "data_policy",
        ] {
            assert!(!obj.contains_key(key), "unexpected key `{key}` in {value}");
        }
        assert!(obj.contains_key("workflow_id"));
    }

    #[test]
    fn ir_from_yaml_parses_required_fields() {
        let yaml = "workflow_id: wf\n\
                    version: \"0.1.0\"\n\
                    state_schema: schemas.S\n\
                    start_node: start\n\
                    nodes: {}\n\
                    edges: []\n\
                    retry_policies: {}\n\
                    models: {}\n\
                    tools: {}\n\
                    mcp_servers: {}\n\
                    remote_agents: {}\n";
        let ir = WorkflowIr::from_yaml(yaml).unwrap();
        assert_eq!(ir.workflow_id, "wf");
        assert_eq!(ir.version, "0.1.0");
        assert_eq!(ir.start_node, "start");
        assert!(ir.labels.is_empty());
        assert!(ir.policy.is_none());
    }

    #[test]
    fn edges_from_and_successors_follow_declaration_order() {
        let mut ir = minimal_ir();
        ir.edges = vec![
            EdgeDef { from: "start".into(), to: "a".into(), condition: None },
            EdgeDef { from: "a".into(), to: "end".into(), condition: None },
            EdgeDef { from: "start".into(), to: "b".into(), condition: Some("x > 1".into()) },
        ];
        assert_eq!(ir.edges_from("start").len(), 2);
        assert_eq!(ir.successors("start"), vec!["a", "b"]);
        assert_eq!(ir.successors("a"), vec!["end"]);
        assert!(ir.successors("missing").is_empty());
        assert!(ir.edges_from("end").is_empty());
    }

    #[test]
    fn data_policy_serde_defaults_match_default_impl() {
        let parsed: DataPolicyIr = serde_json::from_str("{}").unwrap();
        let expected = DataPolicyIr::default();
        assert_eq!(parsed.pii_fields, expected.pii_fields);
        assert_eq!(parsed.pii_detectors, expected.pii_detectors);
        assert_eq!(parsed.redaction_mode, "mask");
        assert_eq!(parsed.redaction_mode, expected.redaction_mode);
        assert_eq!(parsed.retain_prompts, expected.retain_prompts);
        assert!(!parsed.retain_prompts);
        assert_eq!(parsed.retain_outputs, expected.retain_outputs);
        assert!(parsed.retain_outputs);
        assert_eq!(parsed.retention_days, expected.retention_days);
    }
}