1use jamjet_core::node::NodeKind;
2use jamjet_core::retry::RetryPolicy;
3use jamjet_core::timeout::TimeoutConfig;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct WorkflowIr {
13 pub workflow_id: String,
15 pub version: String,
17 pub name: Option<String>,
19 pub description: Option<String>,
21 pub state_schema: String,
23 pub start_node: String,
25 pub nodes: HashMap<String, NodeDef>,
27 pub edges: Vec<EdgeDef>,
29 pub retry_policies: HashMap<String, RetryPolicy>,
31 #[serde(default)]
33 pub timeouts: TimeoutConfig,
34 pub models: HashMap<String, ModelConfig>,
36 pub tools: HashMap<String, ToolConfig>,
38 pub mcp_servers: HashMap<String, McpServerConfig>,
40 pub remote_agents: HashMap<String, RemoteAgentConfig>,
42 #[serde(default)]
44 pub labels: HashMap<String, String>,
45 #[serde(default, skip_serializing_if = "Option::is_none")]
47 pub policy: Option<PolicySetIr>,
48 #[serde(default, skip_serializing_if = "Option::is_none")]
50 pub token_budget: Option<TokenBudgetIr>,
51 #[serde(default, skip_serializing_if = "Option::is_none")]
53 pub cost_budget_usd: Option<f64>,
54 #[serde(default, skip_serializing_if = "Option::is_none")]
57 pub on_budget_exceeded: Option<String>,
58 #[serde(default, skip_serializing_if = "Option::is_none")]
60 pub data_policy: Option<DataPolicyIr>,
61}
62
63impl WorkflowIr {
64 pub fn from_json(s: &str) -> Result<Self, serde_json::Error> {
66 serde_json::from_str(s)
67 }
68
69 pub fn from_yaml(s: &str) -> Result<Self, serde_yaml::Error> {
71 serde_yaml::from_str(s)
72 }
73
74 pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
76 serde_json::to_string_pretty(self)
77 }
78
79 pub fn node(&self, id: &str) -> Option<&NodeDef> {
81 self.nodes.get(id)
82 }
83
84 pub fn edges_from(&self, node_id: &str) -> Vec<&EdgeDef> {
86 self.edges.iter().filter(|e| e.from == node_id).collect()
87 }
88
89 pub fn successors(&self, node_id: &str) -> Vec<&str> {
91 self.edges_from(node_id)
92 .into_iter()
93 .map(|e| e.to.as_str())
94 .collect()
95 }
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct NodeDef {
101 pub id: String,
102 pub kind: NodeKind,
103 pub retry_policy: Option<String>,
105 pub node_timeout_secs: Option<u64>,
107 pub description: Option<String>,
109 #[serde(default)]
111 pub labels: HashMap<String, String>,
112 #[serde(default, skip_serializing_if = "Option::is_none")]
114 pub policy: Option<PolicySetIr>,
115 #[serde(default, skip_serializing_if = "Option::is_none")]
117 pub data_policy: Option<DataPolicyIr>,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct EdgeDef {
123 pub from: String,
124 pub to: String,
125 pub condition: Option<String>,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct ModelConfig {
134 pub provider: String,
135 pub model: String,
136 pub timeout_secs: Option<u64>,
137 pub retry_policy: Option<String>,
138 pub temperature: Option<f32>,
139 pub max_tokens: Option<u32>,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct ToolConfig {
145 pub kind: ToolKind,
146 pub reference: String,
148 pub input_schema: Option<String>,
149 pub output_schema: Option<String>,
150 #[serde(default)]
151 pub permissions: Vec<String>,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
155#[serde(rename_all = "snake_case")]
156pub enum ToolKind {
157 Python,
158 Http,
159 Grpc,
160 Mcp,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct McpServerConfig {
166 pub transport: McpTransport,
167 pub command: Option<String>,
169 pub args: Vec<String>,
170 pub url: Option<String>,
172 pub auth: Option<AuthConfig>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176#[serde(rename_all = "snake_case")]
177pub enum McpTransport {
178 Stdio,
179 HttpSse,
180 WebSocket,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct RemoteAgentConfig {
186 pub url: String,
187 pub agent_card_path: Option<String>,
188 pub auth: Option<AuthConfig>,
189}
190
191#[derive(Debug, Clone, Default, Serialize, Deserialize)]
197pub struct PolicySetIr {
198 #[serde(default)]
200 pub blocked_tools: Vec<String>,
201 #[serde(default)]
203 pub require_approval_for: Vec<String>,
204 #[serde(default)]
206 pub model_allowlist: Vec<String>,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct TokenBudgetIr {
212 #[serde(default, skip_serializing_if = "Option::is_none")]
214 pub input_tokens: Option<u32>,
215 #[serde(default, skip_serializing_if = "Option::is_none")]
217 pub output_tokens: Option<u32>,
218 #[serde(default, skip_serializing_if = "Option::is_none")]
220 pub total_tokens: Option<u32>,
221}
222
223#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct DataPolicyIr {
228 #[serde(default)]
231 pub pii_fields: Vec<String>,
232
233 #[serde(default)]
236 pub pii_detectors: Vec<String>,
237
238 #[serde(default = "default_redaction_mode")]
241 pub redaction_mode: String,
242
243 #[serde(default)]
245 pub retain_prompts: bool,
246
247 #[serde(default = "default_true")]
249 pub retain_outputs: bool,
250
251 #[serde(default, skip_serializing_if = "Option::is_none")]
253 pub retention_days: Option<u32>,
254}
255
/// Serde default for `DataPolicyIr::redaction_mode`: the `"mask"` mode.
fn default_redaction_mode() -> String {
    String::from("mask")
}
259
/// Serde default helper that always yields `true` (used for `DataPolicyIr::retain_outputs`).
fn default_true() -> bool {
    true
}
263
264impl Default for DataPolicyIr {
265 fn default() -> Self {
266 Self {
267 pii_fields: Vec::new(),
268 pii_detectors: Vec::new(),
269 redaction_mode: default_redaction_mode(),
270 retain_prompts: false,
271 retain_outputs: true,
272 retention_days: None,
273 }
274 }
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
279#[serde(tag = "type", rename_all = "snake_case")]
280pub enum AuthConfig {
281 Bearer {
282 token_env: String,
283 },
284 ApiKey {
285 header: String,
286 key_env: String,
287 },
288 Oauth2 {
289 client_id_env: String,
290 client_secret_env: String,
291 token_url: String,
292 },
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the smallest IR used by these tests: no nodes, no edges, and every optional
    /// section unset or empty.
    fn minimal_ir() -> WorkflowIr {
        WorkflowIr {
            workflow_id: "test_workflow".into(),
            version: "0.1.0".into(),
            name: None,
            description: None,
            policy: None,
            token_budget: None,
            cost_budget_usd: None,
            on_budget_exceeded: None,
            data_policy: None,
            state_schema: "schemas.TestState".into(),
            start_node: "start".into(),
            nodes: HashMap::new(),
            edges: vec![],
            retry_policies: HashMap::new(),
            timeouts: TimeoutConfig::default(),
            models: HashMap::new(),
            tools: HashMap::new(),
            mcp_servers: HashMap::new(),
            remote_agents: HashMap::new(),
            labels: HashMap::new(),
        }
    }

    #[test]
    fn ir_roundtrip_json() {
        let ir = minimal_ir();
        let json = ir.to_json_pretty().unwrap();
        let parsed = WorkflowIr::from_json(&json).unwrap();
        assert_eq!(parsed.workflow_id, ir.workflow_id);
        assert_eq!(parsed.version, ir.version);
        assert_eq!(parsed.state_schema, ir.state_schema);
        assert_eq!(parsed.start_node, ir.start_node);
        assert!(parsed.policy.is_none());
        // `None` optional sections are skipped on output via `skip_serializing_if`.
        assert!(!json.contains("token_budget"));
        assert!(!json.contains("cost_budget_usd"));
    }

    #[test]
    fn ir_roundtrip_yaml() {
        let ir = minimal_ir();
        let yaml = serde_yaml::to_string(&ir).unwrap();
        let parsed = WorkflowIr::from_yaml(&yaml).unwrap();
        assert_eq!(parsed.workflow_id, ir.workflow_id);
        assert_eq!(parsed.start_node, ir.start_node);
    }

    #[test]
    fn successors_follow_outgoing_edges_in_order() {
        let mut ir = minimal_ir();
        ir.edges = vec![
            EdgeDef { from: "start".into(), to: "a".into(), condition: None },
            EdgeDef { from: "a".into(), to: "b".into(), condition: Some("x".into()) },
            EdgeDef { from: "start".into(), to: "c".into(), condition: None },
        ];
        assert_eq!(ir.successors("start"), vec!["a", "c"]);
        assert_eq!(ir.edges_from("a").len(), 1);
        assert!(ir.successors("missing").is_empty());
        assert!(ir.node("start").is_none());
    }

    #[test]
    fn data_policy_serde_defaults_match_default_impl() {
        let parsed: DataPolicyIr = serde_json::from_str("{}").unwrap();
        let expected = DataPolicyIr::default();
        assert_eq!(parsed.pii_fields, expected.pii_fields);
        assert_eq!(parsed.pii_detectors, expected.pii_detectors);
        assert_eq!(parsed.redaction_mode, "mask");
        assert_eq!(parsed.redaction_mode, expected.redaction_mode);
        assert_eq!(parsed.retain_prompts, expected.retain_prompts);
        assert_eq!(parsed.retain_outputs, expected.retain_outputs);
        assert_eq!(parsed.retention_days, expected.retention_days);
    }
}