Skip to main content

hydra_compiler/
ast.rs

1//! AST nodes for compiled action sequences.
2
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
/// A node in the action AST.
///
/// The tree is recursive: `Sequence`, `If`, `ForEach` and `StoreResult` each
/// hold further `ActionNode`s, while `Action` is the only variant without
/// child nodes.
///
/// Serialization: `rename_all = "snake_case"` renames the variants to
/// `action`, `sequence`, `if`, `for_each` and `store_result`. No `tag` or
/// `untagged` attribute is set, so serde's default enum representation is used.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ActionNode {
    /// Execute a single tool/action.
    Action {
        /// Name of the tool to run.
        tool: String,
        /// Parameter expressions, keyed by a string (presumably the
        /// parameter name — confirm against the executor).
        params: HashMap<String, ParamExpr>,
    },
    /// Execute a sequence of actions in order.
    Sequence(Vec<ActionNode>),
    /// Conditional execution.
    If {
        /// Condition that selects between the branches.
        condition: ConditionExpr,
        /// Branch for when the condition holds.
        then: Box<ActionNode>,
        /// Optional alternative branch; `None` when there is no else branch.
        /// Named `else_` because `else` is a Rust keyword; the attribute
        /// below restores the plain `else` key in serialized form.
        #[serde(rename = "else")]
        else_: Option<Box<ActionNode>>,
    },
    /// Iterate over a collection.
    ForEach {
        /// Loop variable name (presumably the name under which each element
        /// is made available to `body` — confirm against the executor).
        variable: String,
        /// The collection being iterated.
        collection: CollectionExpr,
        /// Node used as the loop body.
        body: Box<ActionNode>,
    },
    /// Store result of an action for later use.
    StoreResult {
        /// Key under which the result is stored.
        key: String,
        /// The node whose result is stored.
        action: Box<ActionNode>,
    },
}
37
38impl ActionNode {
39    /// Count the number of leaf actions in this AST
40    pub fn action_count(&self) -> usize {
41        match self {
42            Self::Action { .. } => 1,
43            Self::Sequence(nodes) => nodes.iter().map(|n| n.action_count()).sum(),
44            Self::If { then, else_, .. } => {
45                then.action_count() + else_.as_ref().map(|e| e.action_count()).unwrap_or(0)
46            }
47            Self::ForEach { body, .. } => body.action_count(),
48            Self::StoreResult { action, .. } => action.action_count(),
49        }
50    }
51
52    /// Get all tool names referenced in this AST
53    pub fn tool_names(&self) -> Vec<&str> {
54        let mut names = Vec::new();
55        self.collect_tool_names(&mut names);
56        names
57    }
58
59    fn collect_tool_names<'a>(&'a self, names: &mut Vec<&'a str>) {
60        match self {
61            Self::Action { tool, .. } => names.push(tool),
62            Self::Sequence(nodes) => {
63                for node in nodes {
64                    node.collect_tool_names(names);
65                }
66            }
67            Self::If { then, else_, .. } => {
68                then.collect_tool_names(names);
69                if let Some(e) = else_ {
70                    e.collect_tool_names(names);
71                }
72            }
73            Self::ForEach { body, .. } => body.collect_tool_names(names),
74            Self::StoreResult { action, .. } => action.collect_tool_names(names),
75        }
76    }
77}
78
/// Parameter expression — how to compute a parameter value.
///
/// Used as the value type of `ActionNode::Action::params` and as the operand
/// type inside `ComputeRule`, so expressions can nest through `Computed`.
///
/// Serialization: `rename_all = "snake_case"` renames the variants to
/// `literal`, `variable`, `previous_result` and `computed`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ParamExpr {
    /// A fixed literal value, held as arbitrary JSON.
    Literal(serde_json::Value),
    /// A variable extracted from user input; the string names the variable.
    Variable(String),
    /// Result from a previous step; the string identifies that result.
    PreviousResult(String),
    /// A computed transformation described by a `ComputeRule`.
    Computed(ComputeRule),
}
92
/// Condition for `If` nodes.
///
/// Conditions form their own small tree: `And`, `Or` and `Not` contain
/// further `ConditionExpr`s, while `Equals`, `Exists` and `Success` refer to
/// a variable/result by string.
///
/// Serialization: `rename_all = "snake_case"` renames the variants to
/// `equals`, `exists`, `success`, `and`, `or` and `not`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConditionExpr {
    /// Check if a variable/result equals a value.
    Equals {
        /// The variable/result being compared.
        left: String,
        /// The JSON value it is compared against.
        right: serde_json::Value,
    },
    /// Check if a result is not null/empty.
    Exists(String),
    /// Check if a result indicates success.
    Success(String),
    /// Boolean AND of conditions.
    And(Vec<ConditionExpr>),
    /// Boolean OR of conditions.
    Or(Vec<ConditionExpr>),
    /// Negate a condition.
    Not(Box<ConditionExpr>),
}
113
/// Collection expression for `ForEach`: where the iterated items come from.
///
/// Serialization: `rename_all = "snake_case"` renames the variants to
/// `literal`, `from_result` and `from_variable`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CollectionExpr {
    /// Literal array of JSON values, embedded directly in the AST.
    Literal(Vec<serde_json::Value>),
    /// From a previous result (expects array); the string identifies the result.
    FromResult(String),
    /// From a variable (expects array); the string names the variable.
    FromVariable(String),
}
125
/// A transformation rule for computed parameters.
///
/// Carried by `ParamExpr::Computed`. `Concat` and `Format` take `ParamExpr`
/// operands, so rules and parameter expressions can nest in each other.
///
/// Serialization: `rename_all = "snake_case"` renames the variants to
/// `concat`, `format` and `extract`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ComputeRule {
    /// Concatenate strings; operands are listed in order.
    Concat(Vec<ParamExpr>),
    /// Format a template string.
    Format {
        /// The template text (placeholder syntax is not defined in this
        /// file — confirm against the evaluator).
        template: String,
        /// Argument expressions supplied to the template.
        args: Vec<ParamExpr>,
    },
    /// Extract a field from a JSON value; `source` identifies the value and
    /// `field` the field to read.
    Extract { source: String, field: String },
}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Action` leaf that carries no parameters.
    fn leaf(tool: &str) -> ActionNode {
        ActionNode::Action {
            tool: tool.to_string(),
            params: HashMap::new(),
        }
    }

    /// Builds an `Action` leaf that carries exactly one parameter.
    fn leaf_with(tool: &str, key: &str, value: ParamExpr) -> ActionNode {
        let mut params = HashMap::new();
        params.insert(key.to_string(), value);
        ActionNode::Action {
            tool: tool.to_string(),
            params,
        }
    }

    /// Serializes `value` to JSON and parses it back, panicking on any error.
    fn roundtrip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    // --- action_count / tool_names -------------------------------------

    #[test]
    fn test_action_node() {
        let single = leaf_with(
            "git_commit",
            "message",
            ParamExpr::Variable("commit_msg".into()),
        );
        assert_eq!(single.action_count(), 1);
        assert_eq!(single.tool_names(), vec!["git_commit"]);
    }

    #[test]
    fn test_sequence() {
        let steps = vec![
            leaf_with("git_add", "path", ParamExpr::Literal(serde_json::json!("."))),
            leaf_with("git_commit", "message", ParamExpr::Variable("msg".into())),
            leaf("git_push"),
        ];
        let pipeline = ActionNode::Sequence(steps);
        assert_eq!(pipeline.action_count(), 3);
        assert_eq!(
            pipeline.tool_names(),
            vec!["git_add", "git_commit", "git_push"]
        );
    }

    #[test]
    fn test_if_condition() {
        let branch = ActionNode::If {
            condition: ConditionExpr::Success("step_1".into()),
            then: Box::new(leaf("deploy")),
            else_: Some(Box::new(leaf("rollback"))),
        };
        assert_eq!(branch.action_count(), 2);
        assert_eq!(branch.tool_names(), vec!["deploy", "rollback"]);
    }

    #[test]
    fn test_foreach() {
        let files = vec![serde_json::json!("a.rs"), serde_json::json!("b.rs")];
        let each = ActionNode::ForEach {
            variable: "file".into(),
            collection: CollectionExpr::Literal(files),
            body: Box::new(leaf_with(
                "lint",
                "path",
                ParamExpr::Variable("file".into()),
            )),
        };
        // The body is a template: it counts once, not once per element.
        assert_eq!(each.action_count(), 1);
        assert_eq!(each.tool_names(), vec!["lint"]);
    }

    #[test]
    fn test_store_result() {
        let stored = ActionNode::StoreResult {
            key: "branch".into(),
            action: Box::new(leaf("git_branch")),
        };
        assert_eq!(stored.action_count(), 1);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut params = HashMap::new();
        params.insert("a".to_string(), ParamExpr::Literal(serde_json::json!(42)));
        params.insert("b".to_string(), ParamExpr::Variable("input".into()));
        let original = ActionNode::Sequence(vec![ActionNode::Action {
            tool: "test".into(),
            params,
        }]);
        let restored = roundtrip(&original);
        assert_eq!(restored.action_count(), 1);
    }

    #[test]
    fn test_empty_sequence() {
        let empty = ActionNode::Sequence(Vec::new());
        assert_eq!(empty.action_count(), 0);
        assert!(empty.tool_names().is_empty());
    }

    #[test]
    fn test_if_without_else() {
        let branch = ActionNode::If {
            condition: ConditionExpr::Exists("x".into()),
            then: Box::new(leaf("a")),
            else_: None,
        };
        assert_eq!(branch.action_count(), 1);
        assert_eq!(branch.tool_names(), vec!["a"]);
    }

    #[test]
    fn test_nested_sequence() {
        let inner = ActionNode::Sequence(vec![leaf("a"), leaf("b")]);
        let outer = ActionNode::Sequence(vec![inner, leaf("c")]);
        assert_eq!(outer.action_count(), 3);
        assert_eq!(outer.tool_names(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_store_result_tool_names() {
        let wrapped = ActionNode::Sequence(vec![leaf("x"), leaf("y")]);
        let stored = ActionNode::StoreResult {
            key: "result".into(),
            action: Box::new(wrapped),
        };
        assert_eq!(stored.action_count(), 2);
        assert_eq!(stored.tool_names(), vec!["x", "y"]);
    }

    // --- serde round-trips (these only check that parsing succeeds) ------

    #[test]
    fn test_condition_expr_serde() {
        let negated = ConditionExpr::Not(Box::new(ConditionExpr::Success("y".into())));
        let both = ConditionExpr::And(vec![ConditionExpr::Exists("x".into()), negated]);
        let _: ConditionExpr = roundtrip(&both);
    }

    #[test]
    fn test_collection_expr_serde() {
        let source = CollectionExpr::FromResult("step_1".into());
        let _: CollectionExpr = roundtrip(&source);
    }

    #[test]
    fn test_compute_rule_serde() {
        let extract = ComputeRule::Extract {
            source: "result".into(),
            field: "id".into(),
        };
        let _: ComputeRule = roundtrip(&extract);
    }

    #[test]
    fn test_param_expr_previous_result() {
        let previous = ParamExpr::PreviousResult("step_1".into());
        let restored = roundtrip(&previous);
        assert!(matches!(restored, ParamExpr::PreviousResult(_)));
    }

    #[test]
    fn test_param_expr_computed() {
        let parts = vec![
            ParamExpr::Literal(serde_json::json!("hello ")),
            ParamExpr::Variable("name".into()),
        ];
        let computed = ParamExpr::Computed(ComputeRule::Concat(parts));
        let _: ParamExpr = roundtrip(&computed);
    }

    #[test]
    fn test_condition_or_serde() {
        let either = ConditionExpr::Or(vec![
            ConditionExpr::Exists("a".into()),
            ConditionExpr::Exists("b".into()),
        ]);
        let _: ConditionExpr = roundtrip(&either);
    }

    #[test]
    fn test_condition_equals_serde() {
        let equals = ConditionExpr::Equals {
            left: "x".into(),
            right: serde_json::json!(42),
        };
        let _: ConditionExpr = roundtrip(&equals);
    }
}