//! Workflow IR type definitions (`jamjet_ir/workflow.rs`).

use jamjet_core::node::NodeKind;
use jamjet_core::retry::RetryPolicy;
use jamjet_core::timeout::TimeoutConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

7/// The canonical Intermediate Representation for a JamJet workflow.
8///
9/// Both YAML and Python workflow definitions compile to this struct before
10/// being submitted to the runtime. The IR is serializable to JSON and YAML.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct WorkflowIr {
13    /// Unique workflow definition identifier.
14    pub workflow_id: String,
15    /// Semantic version string (e.g., "1.0.0").
16    pub version: String,
17    /// Human-readable name.
18    pub name: Option<String>,
19    /// Optional description.
20    pub description: Option<String>,
21    /// Reference to the Pydantic model or JSON Schema for workflow state.
22    pub state_schema: String,
23    /// The first node to execute.
24    pub start_node: String,
25    /// All nodes in the workflow graph, keyed by node id.
26    pub nodes: HashMap<String, NodeDef>,
27    /// All edges (transitions) between nodes.
28    pub edges: Vec<EdgeDef>,
29    /// Named retry policies referenced by nodes.
30    pub retry_policies: HashMap<String, RetryPolicy>,
31    /// Timeout configuration for this workflow.
32    #[serde(default)]
33    pub timeouts: TimeoutConfig,
34    /// Named model configurations referenced by model nodes.
35    pub models: HashMap<String, ModelConfig>,
36    /// Named tool configurations referenced by tool nodes.
37    pub tools: HashMap<String, ToolConfig>,
38    /// Named MCP server configurations.
39    pub mcp_servers: HashMap<String, McpServerConfig>,
40    /// Named remote A2A agents.
41    pub remote_agents: HashMap<String, RemoteAgentConfig>,
42    /// Observability labels attached to all spans from this workflow.
43    #[serde(default)]
44    pub labels: HashMap<String, String>,
45    /// Workflow-level policy set (overrides global; node-level overrides this).
46    #[serde(default, skip_serializing_if = "Option::is_none")]
47    pub policy: Option<PolicySetIr>,
48    /// Per-execution token budget enforcement.
49    #[serde(default, skip_serializing_if = "Option::is_none")]
50    pub token_budget: Option<TokenBudgetIr>,
51    /// Per-execution cost cap in USD. Execution fails / branches when exceeded.
52    #[serde(default, skip_serializing_if = "Option::is_none")]
53    pub cost_budget_usd: Option<f64>,
54    /// Node to route to when any budget is exceeded (optional).
55    /// If absent, the execution fails with `WorkflowFailed`.
56    #[serde(default, skip_serializing_if = "Option::is_none")]
57    pub on_budget_exceeded: Option<String>,
58    /// Data handling policy (PII redaction + retention controls).
59    #[serde(default, skip_serializing_if = "Option::is_none")]
60    pub data_policy: Option<DataPolicyIr>,
61}
62
63impl WorkflowIr {
64    /// Parse a WorkflowIr from a JSON string.
65    pub fn from_json(s: &str) -> Result<Self, serde_json::Error> {
66        serde_json::from_str(s)
67    }
68
69    /// Parse a WorkflowIr from a YAML string.
70    pub fn from_yaml(s: &str) -> Result<Self, serde_yaml::Error> {
71        serde_yaml::from_str(s)
72    }
73
74    /// Serialize to JSON (pretty-printed).
75    pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
76        serde_json::to_string_pretty(self)
77    }
78
79    /// Look up a node by id.
80    pub fn node(&self, id: &str) -> Option<&NodeDef> {
81        self.nodes.get(id)
82    }
83
84    /// Return all edges originating from a given node.
85    pub fn edges_from(&self, node_id: &str) -> Vec<&EdgeDef> {
86        self.edges.iter().filter(|e| e.from == node_id).collect()
87    }
88
89    /// Return the successors of a node (all nodes it can transition to).
90    pub fn successors(&self, node_id: &str) -> Vec<&str> {
91        self.edges_from(node_id)
92            .into_iter()
93            .map(|e| e.to.as_str())
94            .collect()
95    }
96}
97
98/// A single node definition in the workflow IR.
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct NodeDef {
101    pub id: String,
102    pub kind: NodeKind,
103    /// Reference to a named retry policy in `WorkflowIr::retry_policies`.
104    pub retry_policy: Option<String>,
105    /// Node-level timeout override (overrides workflow-level timeout).
106    pub node_timeout_secs: Option<u64>,
107    /// Human-readable description for display in traces and UI.
108    pub description: Option<String>,
109    /// Extra observability labels for this node's spans.
110    #[serde(default)]
111    pub labels: HashMap<String, String>,
112    /// Node-level policy override (most specific — overrides workflow + global).
113    #[serde(default, skip_serializing_if = "Option::is_none")]
114    pub policy: Option<PolicySetIr>,
115    /// Node-level data policy override.
116    #[serde(default, skip_serializing_if = "Option::is_none")]
117    pub data_policy: Option<DataPolicyIr>,
118}
119
120/// A directed edge between two nodes.
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct EdgeDef {
123    pub from: String,
124    pub to: String,
125    /// Optional condition expression. If None, this is an unconditional edge.
126    /// Expressions are evaluated against the current workflow state +
127    /// the last node's output.
128    pub condition: Option<String>,
129}
130
131/// Configuration for a model provider.
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct ModelConfig {
134    pub provider: String,
135    pub model: String,
136    pub timeout_secs: Option<u64>,
137    pub retry_policy: Option<String>,
138    pub temperature: Option<f32>,
139    pub max_tokens: Option<u32>,
140}
141
142/// Configuration for a tool.
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct ToolConfig {
145    pub kind: ToolKind,
146    /// Python: "module.submodule:function_name"
147    pub reference: String,
148    pub input_schema: Option<String>,
149    pub output_schema: Option<String>,
150    #[serde(default)]
151    pub permissions: Vec<String>,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
155#[serde(rename_all = "snake_case")]
156pub enum ToolKind {
157    Python,
158    Http,
159    Grpc,
160    Mcp,
161}
162
163/// Configuration for an MCP server connection.
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct McpServerConfig {
166    pub transport: McpTransport,
167    /// For stdio transport.
168    pub command: Option<String>,
169    pub args: Vec<String>,
170    /// For http_sse transport.
171    pub url: Option<String>,
172    pub auth: Option<AuthConfig>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176#[serde(rename_all = "snake_case")]
177pub enum McpTransport {
178    Stdio,
179    HttpSse,
180    WebSocket,
181}
182
183/// Configuration for a remote A2A agent.
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct RemoteAgentConfig {
186    pub url: String,
187    pub agent_card_path: Option<String>,
188    pub auth: Option<AuthConfig>,
189}
190
191/// Policy rules embedded in the workflow IR.
192///
193/// This is the serializable form of `PolicySet` — it lives in the IR so that
194/// workflow YAML/JSON can declare policy inline. The policy engine converts
195/// `PolicySetIr` → internal `PolicySet` at evaluation time.
196#[derive(Debug, Clone, Default, Serialize, Deserialize)]
197pub struct PolicySetIr {
198    /// Exact tool names or glob patterns to block (e.g. `"payments.*"`).
199    #[serde(default)]
200    pub blocked_tools: Vec<String>,
201    /// Tool names/patterns that require human approval before execution.
202    #[serde(default)]
203    pub require_approval_for: Vec<String>,
204    /// If non-empty, only models in this list are allowed for model nodes.
205    #[serde(default)]
206    pub model_allowlist: Vec<String>,
207}
208
209/// Per-execution token budget configuration.
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct TokenBudgetIr {
212    /// Maximum input tokens allowed for the full execution.
213    #[serde(default, skip_serializing_if = "Option::is_none")]
214    pub input_tokens: Option<u32>,
215    /// Maximum output tokens allowed for the full execution.
216    #[serde(default, skip_serializing_if = "Option::is_none")]
217    pub output_tokens: Option<u32>,
218    /// Maximum combined input + output tokens for the full execution.
219    #[serde(default, skip_serializing_if = "Option::is_none")]
220    pub total_tokens: Option<u32>,
221}
222
223/// Data handling policy — controls PII redaction and prompt/output retention.
224///
225/// Applied at workflow level and overridable per-node (same layering as `PolicySetIr`).
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct DataPolicyIr {
228    /// JSON path patterns that contain PII and must be redacted.
229    /// Examples: `"$.patient.ssn"`, `"$.user.email"`
230    #[serde(default)]
231    pub pii_fields: Vec<String>,
232
233    /// Built-in PII pattern detectors to enable.
234    /// Values: `"email"`, `"ssn"`, `"credit_card"`, `"phone"`, `"ip_address"`
235    #[serde(default)]
236    pub pii_detectors: Vec<String>,
237
238    /// How to redact PII values. Default: `"mask"`.
239    /// Options: `"mask"` (`[REDACTED]`), `"hash"` (SHA-256), `"remove"` (delete key).
240    #[serde(default = "default_redaction_mode")]
241    pub redaction_mode: String,
242
243    /// Whether to store prompts in the event/audit log. Default: false.
244    #[serde(default)]
245    pub retain_prompts: bool,
246
247    /// Whether to store model outputs/completions. Default: true.
248    #[serde(default = "default_true")]
249    pub retain_outputs: bool,
250
251    /// Retention duration for audit entries in days. None = indefinite.
252    #[serde(default, skip_serializing_if = "Option::is_none")]
253    pub retention_days: Option<u32>,
254}
255
/// Serde default for `DataPolicyIr::redaction_mode`: returns `"mask"`.
fn default_redaction_mode() -> String {
    String::from("mask")
}

/// Serde default for boolean fields that should be `true` when omitted.
fn default_true() -> bool {
    true
}

264impl Default for DataPolicyIr {
265    fn default() -> Self {
266        Self {
267            pii_fields: Vec::new(),
268            pii_detectors: Vec::new(),
269            redaction_mode: default_redaction_mode(),
270            retain_prompts: false,
271            retain_outputs: true,
272            retention_days: None,
273        }
274    }
275}
276
277/// Authentication configuration for external connections.
278#[derive(Debug, Clone, Serialize, Deserialize)]
279#[serde(tag = "type", rename_all = "snake_case")]
280pub enum AuthConfig {
281    Bearer {
282        token_env: String,
283    },
284    ApiKey {
285        header: String,
286        key_env: String,
287    },
288    Oauth2 {
289        client_id_env: String,
290        client_secret_env: String,
291        token_url: String,
292    },
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest valid IR: required strings set, every collection empty,
    /// every optional section absent.
    fn minimal_ir() -> WorkflowIr {
        WorkflowIr {
            workflow_id: "test_workflow".into(),
            version: "0.1.0".into(),
            name: None,
            description: None,
            policy: None,
            token_budget: None,
            cost_budget_usd: None,
            on_budget_exceeded: None,
            data_policy: None,
            state_schema: "schemas.TestState".into(),
            start_node: "start".into(),
            nodes: HashMap::new(),
            edges: vec![],
            retry_policies: HashMap::new(),
            timeouts: TimeoutConfig::default(),
            models: HashMap::new(),
            tools: HashMap::new(),
            mcp_servers: HashMap::new(),
            remote_agents: HashMap::new(),
            labels: HashMap::new(),
        }
    }

    /// Serializing to JSON and parsing back preserves the identifying fields
    /// and leaves the empty graph empty.
    #[test]
    fn ir_roundtrip_json() {
        let ir = minimal_ir();
        let json = ir.to_json_pretty().expect("minimal IR serializes to JSON");
        let parsed = WorkflowIr::from_json(&json).expect("serialized IR parses back");
        assert_eq!(parsed.workflow_id, ir.workflow_id);
        assert_eq!(parsed.version, ir.version);
        assert_eq!(parsed.state_schema, ir.state_schema);
        assert_eq!(parsed.start_node, ir.start_node);
        assert!(parsed.nodes.is_empty());
        assert!(parsed.edges.is_empty());
    }
}