Skip to main content

apm_core/
recovery.rs

1use crate::config::{CompletionStrategy, WorkflowConfig};
2use std::collections::HashSet;
3
/// Classification of a single outgoing transition offered as a recovery option.
///
/// See [`classify_recovery_options`] for the exact rule and precedence used to
/// assign each kind.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RecoveryKind {
    /// The destination is the target of a merging-completion transition
    /// (Pr, Merge, or PrOrEpicMerge) somewhere in the workflow.
    RetryMerge,
    /// The destination is entered by a non-spec-writer `command:start` transition.
    ReturnToWorker,
    /// The destination is a terminal state.
    Abandon,
    /// None of the other kinds apply.
    Other,
}

/// One outgoing transition of a state, classified for recovery.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecoveryOption {
    /// Id of the destination state.
    pub to: String,
    /// Display label; `classify_recovery_options` uses the transition's label,
    /// or `to` when that label is empty.
    pub label: String,
    /// How the destination was classified.
    pub kind: RecoveryKind,
}
18
19/// Returns true iff `state_id` is the `on_failure` target of at least one
20/// merging-completion transition (Pr, Merge, or PrOrEpicMerge) anywhere in the
21/// workflow.  Transitions with a missing or empty `on_failure` are skipped.
22pub fn is_merge_failure_state(state_id: &str, workflow: &WorkflowConfig) -> bool {
23    for state in &workflow.states {
24        for t in &state.transitions {
25            if !matches!(
26                t.completion,
27                CompletionStrategy::Pr | CompletionStrategy::Merge | CompletionStrategy::PrOrEpicMerge
28            ) {
29                continue;
30            }
31            if let Some(on_failure) = &t.on_failure {
32                if !on_failure.is_empty() && on_failure == state_id {
33                    return true;
34                }
35            }
36        }
37    }
38    false
39}
40
41/// Classify the outgoing transitions of `state_id` as recovery options.
42///
43/// Each transition is labelled by its kind:
44/// - `RetryMerge`: the to-state is the target of at least one merging-completion
45///   transition anywhere in the workflow (Pr, Merge, or PrOrEpicMerge).
46/// - `ReturnToWorker`: the to-state is the target of at least one non-spec-writer
47///   `command:start` transition anywhere in the workflow.
48/// - `Abandon`: the to-state has `terminal: true`.
49/// - `Other`: none of the above apply.
50///
51/// Results are in declaration order.  Returns an empty vec if `state_id` is not
52/// found in the workflow.
53pub fn classify_recovery_options(state_id: &str, workflow: &WorkflowConfig) -> Vec<RecoveryOption> {
54    let merge_target_ids: HashSet<String> = workflow.states.iter()
55        .flat_map(|s| s.transitions.iter())
56        .filter(|t| matches!(
57            t.completion,
58            CompletionStrategy::Pr | CompletionStrategy::Merge | CompletionStrategy::PrOrEpicMerge
59        ))
60        .map(|t| t.to.clone())
61        .collect();
62
63    let coder_start_ids: HashSet<String> = workflow.states.iter()
64        .flat_map(|s| s.transitions.iter().map(move |t| (s, t)))
65        .filter(|(_, t)| t.trigger == "command:start")
66        .filter(|(_, t)| {
67            let dest_is_spec_writer = workflow.states.iter()
68                .find(|s| s.id == t.to)
69                .and_then(|s| s.worker_profile.as_deref())
70                .map(|wp| wp.ends_with("/spec-writer"))
71                .unwrap_or(false);
72            !dest_is_spec_writer
73        })
74        .map(|(_, t)| t.to.clone())
75        .collect();
76
77    let terminal_ids: HashSet<&str> = workflow.states.iter()
78        .filter(|s| s.terminal)
79        .map(|s| s.id.as_str())
80        .collect();
81
82    let Some(state) = workflow.states.iter().find(|s| s.id == state_id) else {
83        return Vec::new();
84    };
85
86    state.transitions.iter().map(|t| {
87        let kind = if merge_target_ids.contains(&t.to) {
88            RecoveryKind::RetryMerge
89        } else if coder_start_ids.contains(&t.to) {
90            RecoveryKind::ReturnToWorker
91        } else if terminal_ids.contains(t.to.as_str()) {
92            RecoveryKind::Abandon
93        } else {
94            RecoveryKind::Other
95        };
96        let label = if t.label.is_empty() { t.to.clone() } else { t.label.clone() };
97        RecoveryOption { to: t.to.clone(), label, kind }
98    }).collect()
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a TOML document containing only `[[states]]` tables into a
    /// `WorkflowConfig`, leaving every other field at its default.
    fn parse_workflow(toml: &str) -> WorkflowConfig {
        #[derive(serde::Deserialize)]
        struct W { states: Vec<crate::config::StateConfig> }
        let w: W = toml::from_str(toml).unwrap();
        WorkflowConfig { states: w.states, ..Default::default() }
    }

    const DEFAULT_WF: &str = r#"[[states]]
id    = "ready"
label = "Ready"

  [[states.transitions]]
  to      = "in_progress"
  trigger = "command:start"

[[states]]
id             = "in_progress"
label          = "In Progress"
worker_profile = "claude/coder"

  [[states.transitions]]
  to         = "implemented"
  trigger    = "manual"
  completion = "pr_or_epic_merge"
  on_failure = "merge_failed"

[[states]]
id    = "implemented"
label = "Implemented"

[[states]]
id    = "merge_failed"
label = "Merge failed"

  [[states.transitions]]
  to      = "implemented"
  trigger = "manual"

  [[states.transitions]]
  to      = "in_progress"
  trigger = "manual"

[[states]]
id       = "closed"
label    = "Closed"
terminal = true
"#;

    #[test]
    fn test_default_workflow_merge_failed() {
        let wf = parse_workflow(DEFAULT_WF);
        let opts = classify_recovery_options("merge_failed", &wf);
        assert_eq!(opts.len(), 2);
        assert_eq!(opts[0].to, "implemented");
        assert_eq!(opts[0].kind, RecoveryKind::RetryMerge);
        assert_eq!(opts[1].to, "in_progress");
        assert_eq!(opts[1].kind, RecoveryKind::ReturnToWorker);
    }

    #[test]
    fn test_shuffled_order_same_classification() {
        let shuffled = r#"[[states]]
id       = "closed"
label    = "Closed"
terminal = true

[[states]]
id         = "merge_failed"
label      = "Merge failed"

  [[states.transitions]]
  to      = "implemented"
  trigger = "manual"

  [[states.transitions]]
  to      = "in_progress"
  trigger = "manual"

[[states]]
id    = "implemented"
label = "Implemented"

[[states]]
id             = "in_progress"
label          = "In Progress"
worker_profile = "claude/coder"

  [[states.transitions]]
  to         = "implemented"
  trigger    = "manual"
  completion = "pr_or_epic_merge"
  on_failure = "merge_failed"

[[states]]
id    = "ready"
label = "Ready"

  [[states.transitions]]
  to      = "in_progress"
  trigger = "command:start"
"#;
        let wf = parse_workflow(shuffled);
        let opts = classify_recovery_options("merge_failed", &wf);
        assert_eq!(opts.len(), 2);
        assert_eq!(opts[0].to, "implemented");
        assert_eq!(opts[0].kind, RecoveryKind::RetryMerge);
        assert_eq!(opts[1].to, "in_progress");
        assert_eq!(opts[1].kind, RecoveryKind::ReturnToWorker);
    }

    #[test]
    fn test_renamed_merge_target() {
        let renamed = r#"[[states]]
id    = "ready"
label = "Ready"

  [[states.transitions]]
  to      = "in_progress"
  trigger = "command:start"

[[states]]
id             = "in_progress"
label          = "In Progress"
worker_profile = "claude/coder"

  [[states.transitions]]
  to         = "shipped"
  trigger    = "manual"
  completion = "pr_or_epic_merge"
  on_failure = "merge_failed"

[[states]]
id    = "shipped"
label = "Shipped"

[[states]]
id         = "merge_failed"
label      = "Merge failed"

  [[states.transitions]]
  to      = "shipped"
  trigger = "manual"

  [[states.transitions]]
  to      = "in_progress"
  trigger = "manual"
"#;
        let wf = parse_workflow(renamed);
        let opts = classify_recovery_options("merge_failed", &wf);
        assert_eq!(opts.len(), 2);
        assert_eq!(opts[0].to, "shipped");
        assert_eq!(opts[0].kind, RecoveryKind::RetryMerge);
        assert_eq!(opts[1].to, "in_progress");
        assert_eq!(opts[1].kind, RecoveryKind::ReturnToWorker);
    }

    #[test]
    fn test_no_merge_transitions() {
        let no_merge = r#"[[states]]
id    = "some_state"
label = "Some State"

  [[states.transitions]]
  to      = "other"
  trigger = "manual"

[[states]]
id    = "other"
label = "Other"
"#;
        let wf = parse_workflow(no_merge);
        let opts = classify_recovery_options("some_state", &wf);
        assert!(!opts.iter().any(|o| o.kind == RecoveryKind::RetryMerge));
    }

    #[test]
    fn test_unknown_or_transitionless_state_yields_no_options() {
        let wf = parse_workflow(DEFAULT_WF);
        assert!(classify_recovery_options("no_such_state", &wf).is_empty());
        // "implemented" exists but declares no outgoing transitions.
        assert!(classify_recovery_options("implemented", &wf).is_empty());
    }

    #[test]
    fn test_abandon_other_and_label_fallback() {
        let wf_src = r#"[[states]]
id    = "merge_failed"
label = "Merge failed"

  [[states.transitions]]
  to      = "closed"
  trigger = "manual"
  label   = "Give up"

  [[states.transitions]]
  to      = "parked"
  trigger = "manual"

[[states]]
id    = "parked"
label = "Parked"

[[states]]
id       = "closed"
label    = "Closed"
terminal = true
"#;
        let wf = parse_workflow(wf_src);
        let opts = classify_recovery_options("merge_failed", &wf);
        assert_eq!(opts.len(), 2);
        assert_eq!(opts[0].to, "closed");
        assert_eq!(opts[0].label, "Give up");
        assert_eq!(opts[0].kind, RecoveryKind::Abandon);
        // No transition label: the to-state id is used instead.
        assert_eq!(opts[1].to, "parked");
        assert_eq!(opts[1].label, "parked");
        assert_eq!(opts[1].kind, RecoveryKind::Other);
    }

    #[test]
    fn test_spec_writer_start_target_is_not_return_to_worker() {
        let wf_src = r#"[[states]]
id    = "groomed"
label = "Groomed"

  [[states.transitions]]
  to      = "in_design"
  trigger = "command:start"

[[states]]
id             = "in_design"
label          = "In Design"
worker_profile = "claude/spec-writer"

[[states]]
id    = "stuck"
label = "Stuck"

  [[states.transitions]]
  to      = "in_design"
  trigger = "manual"
"#;
        let wf = parse_workflow(wf_src);
        let opts = classify_recovery_options("stuck", &wf);
        assert_eq!(opts.len(), 1);
        assert_eq!(opts[0].to, "in_design");
        assert_eq!(opts[0].kind, RecoveryKind::Other);
    }

    #[test]
    fn test_is_merge_failure_state_default_workflow() {
        let wf = parse_workflow(DEFAULT_WF);
        assert!(is_merge_failure_state("merge_failed", &wf));
        // Includes ids absent from DEFAULT_WF ("new", "groomed", "specd"),
        // which must also report false.
        for state in &["new", "groomed", "specd", "ready", "in_progress", "implemented", "closed"] {
            assert!(
                !is_merge_failure_state(state, &wf),
                "expected false for state: {state}"
            );
        }
    }

    #[test]
    fn test_is_merge_failure_state_renamed() {
        let renamed = r#"[[states]]
id    = "in_progress"
label = "In Progress"

  [[states.transitions]]
  to         = "implemented"
  trigger    = "manual"
  completion = "merge"
  on_failure = "pr_failed"

[[states]]
id    = "implemented"
label = "Implemented"

[[states]]
id    = "pr_failed"
label = "Pr Failed"
"#;
        let wf = parse_workflow(renamed);
        assert!(is_merge_failure_state("pr_failed", &wf));
        assert!(!is_merge_failure_state("merge_failed", &wf));
    }

    #[test]
    fn test_is_merge_failure_state_ignores_non_merge_and_empty_on_failure() {
        let wf_src = r#"[[states]]
id    = "in_progress"
label = "In Progress"

  [[states.transitions]]
  to         = "implemented"
  trigger    = "manual"
  on_failure = "blocked"

  [[states.transitions]]
  to         = "implemented"
  trigger    = "manual"
  completion = "merge"
  on_failure = ""

[[states]]
id    = "implemented"
label = "Implemented"

[[states]]
id    = "blocked"
label = "Blocked"
"#;
        let wf = parse_workflow(wf_src);
        // `on_failure` on a non-merging transition does not count.
        assert!(!is_merge_failure_state("blocked", &wf));
        // An empty `on_failure` is skipped, even when the query id is empty too.
        assert!(!is_merge_failure_state("", &wf));
    }
}
312        let wf = parse_workflow(renamed);
313        assert!(is_merge_failure_state("pr_failed", &wf));
314        assert!(!is_merge_failure_state("merge_failed", &wf));
315    }
316}