Skip to main content

imp_core/agent/
recovery.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use super::{RecoveryCheckpoint, RecoveryCheckpointKind};
6
7#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
8pub struct RecoveryReconciliation {
9    pub turn: u32,
10    pub unsafe_incomplete_tools: Vec<IncompleteToolRecovery>,
11    pub retryable_incomplete_tools: Vec<IncompleteToolRecovery>,
12}
13
14impl RecoveryReconciliation {
15    pub fn is_safe_to_continue(&self) -> bool {
16        self.unsafe_incomplete_tools.is_empty()
17    }
18}
19
20#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
21pub struct IncompleteToolRecovery {
22    pub tool_call_id: String,
23    pub tool_name: Option<String>,
24    pub args_hash: Option<String>,
25    pub retry_safe: bool,
26    pub state: IncompleteToolState,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31pub enum IncompleteToolState {
32    PlannedNotStarted,
33    StartedNotCompleted,
34    CompletedNotAppended,
35}
36
37#[derive(Debug, Clone, Default)]
38pub struct RecoveryLedger {
39    checkpoints: Vec<RecoveryCheckpoint>,
40}
41
42impl RecoveryLedger {
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    pub fn from_checkpoints(checkpoints: Vec<RecoveryCheckpoint>) -> Self {
48        Self { checkpoints }
49    }
50
51    pub fn record(&mut self, checkpoint: RecoveryCheckpoint) {
52        self.checkpoints.push(checkpoint);
53    }
54
55    pub fn checkpoints(&self) -> &[RecoveryCheckpoint] {
56        &self.checkpoints
57    }
58
59    pub fn reconcile_latest_finished_turn(&self) -> Option<RecoveryReconciliation> {
60        let latest_finished_turn = self
61            .checkpoints
62            .iter()
63            .filter_map(|checkpoint| {
64                let turn = checkpoint.turn;
65                let turn_has_tool_checkpoint = self.checkpoints.iter().any(|candidate| {
66                    candidate.turn == turn
67                        && matches!(
68                            candidate.kind,
69                            RecoveryCheckpointKind::AssistantToolCallObserved
70                                | RecoveryCheckpointKind::ToolPlanCreated
71                                | RecoveryCheckpointKind::ToolExecutionStart
72                                | RecoveryCheckpointKind::ToolExecutionEnd
73                                | RecoveryCheckpointKind::ToolResultAddedToContext
74                        )
75                });
76
77                match checkpoint.kind {
78                    RecoveryCheckpointKind::ToolResultAddedToContext => Some(turn),
79                    RecoveryCheckpointKind::AssistantMessageFinalized
80                        if !turn_has_tool_checkpoint =>
81                    {
82                        Some(turn)
83                    }
84                    _ => None,
85                }
86            })
87            .max()?;
88        Some(self.reconcile_turn(latest_finished_turn))
89    }
90
91    pub fn reconcile_turn(&self, turn: u32) -> RecoveryReconciliation {
92        let mut tools: HashMap<String, ToolRecoveryState> = HashMap::new();
93
94        for checkpoint in self
95            .checkpoints
96            .iter()
97            .filter(|checkpoint| checkpoint.turn == turn)
98        {
99            let Some(tool_call_id) = checkpoint.tool_call_id.as_ref() else {
100                continue;
101            };
102            let state = tools.entry(tool_call_id.clone()).or_default();
103            state.tool_name = checkpoint.tool_name.clone().or(state.tool_name.clone());
104            state.args_hash = checkpoint.args_hash.clone().or(state.args_hash.clone());
105
106            match checkpoint.kind {
107                RecoveryCheckpointKind::ToolPlanCreated => {
108                    state.planned = true;
109                    state.retry_safe = checkpoint.success.unwrap_or(false);
110                }
111                RecoveryCheckpointKind::AssistantToolCallObserved => {
112                    state.planned = true;
113                }
114                RecoveryCheckpointKind::ToolExecutionStart => {
115                    state.started = true;
116                }
117                RecoveryCheckpointKind::ToolExecutionEnd => {
118                    state.completed = checkpoint.success.unwrap_or(false);
119                    if checkpoint.success == Some(false) {
120                        state.retry_safe = false;
121                    }
122                }
123                RecoveryCheckpointKind::ToolResultAddedToContext => {
124                    state.appended = true;
125                }
126                RecoveryCheckpointKind::ProviderRequestStart
127                | RecoveryCheckpointKind::AssistantMessageFinalized
128                | RecoveryCheckpointKind::ProviderRequestCompleted => {}
129            }
130        }
131
132        let mut retryable_incomplete_tools = Vec::new();
133        let mut unsafe_incomplete_tools = Vec::new();
134
135        for (tool_call_id, state) in tools {
136            let incomplete_state = if state.appended {
137                None
138            } else if state.completed {
139                Some(IncompleteToolState::CompletedNotAppended)
140            } else if state.started {
141                Some(IncompleteToolState::StartedNotCompleted)
142            } else if state.planned {
143                Some(IncompleteToolState::PlannedNotStarted)
144            } else {
145                None
146            };
147
148            if let Some(incomplete_state) = incomplete_state {
149                let recovery = IncompleteToolRecovery {
150                    tool_call_id,
151                    tool_name: state.tool_name,
152                    args_hash: state.args_hash,
153                    retry_safe: state.retry_safe,
154                    state: incomplete_state,
155                };
156                if recovery.retry_safe {
157                    retryable_incomplete_tools.push(recovery);
158                } else {
159                    unsafe_incomplete_tools.push(recovery);
160                }
161            }
162        }
163
164        retryable_incomplete_tools
165            .sort_by(|left, right| left.tool_call_id.cmp(&right.tool_call_id));
166        unsafe_incomplete_tools.sort_by(|left, right| left.tool_call_id.cmp(&right.tool_call_id));
167
168        RecoveryReconciliation {
169            turn,
170            unsafe_incomplete_tools,
171            retryable_incomplete_tools,
172        }
173    }
174}
175
/// Per-tool-call accumulator that `RecoveryLedger::reconcile_turn` fills in while
/// walking one turn's checkpoints.
#[derive(Default)]
#[derive(Debug, Clone)]
struct ToolRecoveryState {
    /// Latest tool name supplied by any checkpoint of this call.
    tool_name: Option<String>,
    /// Latest argument hash supplied by any checkpoint of this call.
    args_hash: Option<String>,
    /// Set by `ToolPlanCreated` or `AssistantToolCallObserved`.
    planned: bool,
    /// Set from the plan's `success` flag; cleared by a failed execution end.
    retry_safe: bool,
    /// Set by `ToolExecutionStart`.
    started: bool,
    /// Set only when `ToolExecutionEnd` reports `success == Some(true)`.
    completed: bool,
    /// Set by `ToolResultAddedToContext`.
    appended: bool,
}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn-3 checkpoint for `tool_call_id` with fixed metadata
    /// (`tool_name = "tool"`, `args_hash = "abc"`).
    fn checkpoint(
        kind: RecoveryCheckpointKind,
        tool_call_id: &str,
        success: Option<bool>,
    ) -> RecoveryCheckpoint {
        RecoveryCheckpoint {
            version: 1,
            turn: 3,
            kind,
            tool_call_id: Some(tool_call_id.into()),
            tool_name: Some("tool".into()),
            args_hash: Some("abc".into()),
            success,
            error_class: None,
            timestamp: 0,
        }
    }

    /// `"edit"` tool checkpoint (`args_hash = "def"`) for `tool_call_id` in `turn`.
    fn edit_checkpoint(
        turn: u32,
        kind: RecoveryCheckpointKind,
        tool_call_id: &str,
        success: Option<bool>,
    ) -> RecoveryCheckpoint {
        RecoveryCheckpoint {
            version: 1,
            turn,
            kind,
            tool_call_id: Some(tool_call_id.into()),
            tool_name: Some("edit".into()),
            args_hash: Some("def".into()),
            success,
            error_class: None,
            timestamp: 0,
        }
    }

    /// Successful `AssistantMessageFinalized` checkpoint in `turn` carrying no tool data.
    fn assistant_finalized(turn: u32) -> RecoveryCheckpoint {
        RecoveryCheckpoint {
            version: 1,
            turn,
            kind: RecoveryCheckpointKind::AssistantMessageFinalized,
            tool_call_id: None,
            tool_name: None,
            args_hash: None,
            success: Some(true),
            error_class: None,
            timestamp: 0,
        }
    }

    #[test]
    fn latest_finished_turn_ignores_in_progress_next_turn() {
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(
                RecoveryCheckpointKind::ToolPlanCreated,
                "finished",
                Some(false),
            ),
            checkpoint(RecoveryCheckpointKind::ToolExecutionStart, "finished", None),
            checkpoint(
                RecoveryCheckpointKind::ToolExecutionEnd,
                "finished",
                Some(true),
            ),
            checkpoint(
                RecoveryCheckpointKind::ToolResultAddedToContext,
                "finished",
                Some(true),
            ),
            assistant_finalized(3),
            edit_checkpoint(
                4,
                RecoveryCheckpointKind::ToolPlanCreated,
                "in_progress",
                Some(false),
            ),
            edit_checkpoint(
                4,
                RecoveryCheckpointKind::ToolExecutionStart,
                "in_progress",
                None,
            ),
        ]);

        let reconciliation = ledger.reconcile_latest_finished_turn().unwrap();
        assert_eq!(reconciliation.turn, 3);
        assert!(reconciliation.is_safe_to_continue());
    }

    #[test]
    fn later_tool_result_marks_tool_turn_finished() {
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(
                RecoveryCheckpointKind::AssistantMessageFinalized,
                "",
                Some(true),
            ),
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "call", Some(false)),
            checkpoint(RecoveryCheckpointKind::ToolExecutionStart, "call", None),
            checkpoint(RecoveryCheckpointKind::ToolExecutionEnd, "call", Some(true)),
            checkpoint(
                RecoveryCheckpointKind::ToolResultAddedToContext,
                "call",
                Some(true),
            ),
        ]);

        let reconciliation = ledger.reconcile_latest_finished_turn().unwrap();
        assert_eq!(reconciliation.turn, 3);
        assert!(reconciliation.is_safe_to_continue());
        assert!(reconciliation.retryable_incomplete_tools.is_empty());
        assert!(reconciliation.unsafe_incomplete_tools.is_empty());
    }

    #[test]
    fn assistant_finalized_without_tool_result_does_not_mark_tool_turn_finished() {
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(
                RecoveryCheckpointKind::AssistantMessageFinalized,
                "previous",
                Some(true),
            ),
            checkpoint(
                RecoveryCheckpointKind::ToolPlanCreated,
                "previous",
                Some(false),
            ),
            checkpoint(RecoveryCheckpointKind::ToolExecutionStart, "previous", None),
            checkpoint(
                RecoveryCheckpointKind::ToolExecutionEnd,
                "previous",
                Some(true),
            ),
            checkpoint(
                RecoveryCheckpointKind::ToolResultAddedToContext,
                "previous",
                Some(true),
            ),
            assistant_finalized(4),
            edit_checkpoint(
                4,
                RecoveryCheckpointKind::ToolPlanCreated,
                "interrupted",
                Some(false),
            ),
            edit_checkpoint(
                4,
                RecoveryCheckpointKind::ToolExecutionStart,
                "interrupted",
                None,
            ),
        ]);

        let reconciliation = ledger.reconcile_latest_finished_turn().unwrap();
        assert_eq!(reconciliation.turn, 3);
        assert!(reconciliation.is_safe_to_continue());
    }

    #[test]
    fn empty_ledger_has_no_finished_turn() {
        assert!(RecoveryLedger::new()
            .reconcile_latest_finished_turn()
            .is_none());
    }

    #[test]
    fn plain_assistant_turn_is_finished() {
        let ledger = RecoveryLedger::from_checkpoints(vec![assistant_finalized(5)]);

        let reconciliation = ledger.reconcile_latest_finished_turn().unwrap();
        assert_eq!(reconciliation.turn, 5);
        assert!(reconciliation.is_safe_to_continue());
        assert!(reconciliation.retryable_incomplete_tools.is_empty());
        assert!(reconciliation.unsafe_incomplete_tools.is_empty());
    }

    #[test]
    fn appended_tool_is_not_incomplete() {
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "call", Some(true)),
            checkpoint(RecoveryCheckpointKind::ToolExecutionStart, "call", None),
            checkpoint(RecoveryCheckpointKind::ToolExecutionEnd, "call", Some(true)),
            checkpoint(
                RecoveryCheckpointKind::ToolResultAddedToContext,
                "call",
                Some(true),
            ),
        ]);

        let reconciliation = ledger.reconcile_turn(3);
        assert!(reconciliation.is_safe_to_continue());
        assert!(reconciliation.retryable_incomplete_tools.is_empty());
        assert!(reconciliation.unsafe_incomplete_tools.is_empty());
    }

    #[test]
    fn read_only_planned_not_started_is_retryable() {
        let ledger = RecoveryLedger::from_checkpoints(vec![checkpoint(
            RecoveryCheckpointKind::ToolPlanCreated,
            "call",
            Some(true),
        )]);

        let reconciliation = ledger.reconcile_turn(3);
        assert!(reconciliation.is_safe_to_continue());
        assert_eq!(reconciliation.retryable_incomplete_tools.len(), 1);
        assert_eq!(
            reconciliation.retryable_incomplete_tools[0].state,
            IncompleteToolState::PlannedNotStarted
        );
    }

    #[test]
    fn mutable_started_not_completed_is_unsafe() {
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "call", Some(false)),
            checkpoint(RecoveryCheckpointKind::ToolExecutionStart, "call", None),
        ]);

        let reconciliation = ledger.reconcile_turn(3);
        assert!(!reconciliation.is_safe_to_continue());
        assert_eq!(reconciliation.unsafe_incomplete_tools.len(), 1);
        assert_eq!(
            reconciliation.unsafe_incomplete_tools[0].state,
            IncompleteToolState::StartedNotCompleted
        );
    }

    #[test]
    fn completed_not_appended_is_incomplete() {
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "call", Some(false)),
            checkpoint(RecoveryCheckpointKind::ToolExecutionStart, "call", None),
            checkpoint(RecoveryCheckpointKind::ToolExecutionEnd, "call", Some(true)),
        ]);

        let reconciliation = ledger.reconcile_turn(3);
        assert_eq!(reconciliation.unsafe_incomplete_tools.len(), 1);
        assert_eq!(
            reconciliation.unsafe_incomplete_tools[0].state,
            IncompleteToolState::CompletedNotAppended
        );
    }

    #[test]
    fn failed_execution_end_is_unsafe_even_when_planned_retry_safe() {
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "call", Some(true)),
            checkpoint(RecoveryCheckpointKind::ToolExecutionStart, "call", None),
            checkpoint(RecoveryCheckpointKind::ToolExecutionEnd, "call", Some(false)),
        ]);

        let reconciliation = ledger.reconcile_turn(3);
        assert!(!reconciliation.is_safe_to_continue());
        assert!(reconciliation.retryable_incomplete_tools.is_empty());
        assert_eq!(reconciliation.unsafe_incomplete_tools.len(), 1);
        assert_eq!(
            reconciliation.unsafe_incomplete_tools[0].state,
            IncompleteToolState::StartedNotCompleted
        );
        assert!(!reconciliation.unsafe_incomplete_tools[0].retry_safe);
    }

    #[test]
    fn incomplete_tools_are_sorted_by_tool_call_id() {
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "c", Some(true)),
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "a", Some(true)),
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "b", Some(true)),
        ]);

        let reconciliation = ledger.reconcile_turn(3);
        let ids: Vec<&str> = reconciliation
            .retryable_incomplete_tools
            .iter()
            .map(|tool| tool.tool_call_id.as_str())
            .collect();
        assert_eq!(ids, ["a", "b", "c"]);
    }

    #[test]
    fn tool_metadata_survives_checkpoints_without_it() {
        let mut start = checkpoint(RecoveryCheckpointKind::ToolExecutionStart, "call", None);
        start.tool_name = None;
        start.args_hash = None;
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "call", Some(false)),
            start,
        ]);

        let reconciliation = ledger.reconcile_turn(3);
        assert_eq!(reconciliation.unsafe_incomplete_tools.len(), 1);
        let tool = &reconciliation.unsafe_incomplete_tools[0];
        assert_eq!(tool.tool_name.as_deref(), Some("tool"));
        assert_eq!(tool.args_hash.as_deref(), Some("abc"));
    }
}