1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use super::{RecoveryCheckpoint, RecoveryCheckpointKind};
6
7#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
8pub struct RecoveryReconciliation {
9 pub turn: u32,
10 pub unsafe_incomplete_tools: Vec<IncompleteToolRecovery>,
11 pub retryable_incomplete_tools: Vec<IncompleteToolRecovery>,
12}
13
14impl RecoveryReconciliation {
15 pub fn is_safe_to_continue(&self) -> bool {
16 self.unsafe_incomplete_tools.is_empty()
17 }
18}
19
20#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
21pub struct IncompleteToolRecovery {
22 pub tool_call_id: String,
23 pub tool_name: Option<String>,
24 pub args_hash: Option<String>,
25 pub retry_safe: bool,
26 pub state: IncompleteToolState,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31pub enum IncompleteToolState {
32 PlannedNotStarted,
33 StartedNotCompleted,
34 CompletedNotAppended,
35}
36
37#[derive(Debug, Clone, Default)]
38pub struct RecoveryLedger {
39 checkpoints: Vec<RecoveryCheckpoint>,
40}
41
42impl RecoveryLedger {
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn from_checkpoints(checkpoints: Vec<RecoveryCheckpoint>) -> Self {
48 Self { checkpoints }
49 }
50
51 pub fn record(&mut self, checkpoint: RecoveryCheckpoint) {
52 self.checkpoints.push(checkpoint);
53 }
54
55 pub fn checkpoints(&self) -> &[RecoveryCheckpoint] {
56 &self.checkpoints
57 }
58
59 pub fn reconcile_latest_finished_turn(&self) -> Option<RecoveryReconciliation> {
60 let latest_finished_turn = self
61 .checkpoints
62 .iter()
63 .filter_map(|checkpoint| {
64 let turn = checkpoint.turn;
65 let turn_has_tool_checkpoint = self.checkpoints.iter().any(|candidate| {
66 candidate.turn == turn
67 && matches!(
68 candidate.kind,
69 RecoveryCheckpointKind::AssistantToolCallObserved
70 | RecoveryCheckpointKind::ToolPlanCreated
71 | RecoveryCheckpointKind::ToolExecutionStart
72 | RecoveryCheckpointKind::ToolExecutionEnd
73 | RecoveryCheckpointKind::ToolResultAddedToContext
74 )
75 });
76
77 match checkpoint.kind {
78 RecoveryCheckpointKind::ToolResultAddedToContext => Some(turn),
79 RecoveryCheckpointKind::AssistantMessageFinalized
80 if !turn_has_tool_checkpoint =>
81 {
82 Some(turn)
83 }
84 _ => None,
85 }
86 })
87 .max()?;
88 Some(self.reconcile_turn(latest_finished_turn))
89 }
90
91 pub fn reconcile_turn(&self, turn: u32) -> RecoveryReconciliation {
92 let mut tools: HashMap<String, ToolRecoveryState> = HashMap::new();
93
94 for checkpoint in self
95 .checkpoints
96 .iter()
97 .filter(|checkpoint| checkpoint.turn == turn)
98 {
99 let Some(tool_call_id) = checkpoint.tool_call_id.as_ref() else {
100 continue;
101 };
102 let state = tools.entry(tool_call_id.clone()).or_default();
103 state.tool_name = checkpoint.tool_name.clone().or(state.tool_name.clone());
104 state.args_hash = checkpoint.args_hash.clone().or(state.args_hash.clone());
105
106 match checkpoint.kind {
107 RecoveryCheckpointKind::ToolPlanCreated => {
108 state.planned = true;
109 state.retry_safe = checkpoint.success.unwrap_or(false);
110 }
111 RecoveryCheckpointKind::AssistantToolCallObserved => {
112 state.planned = true;
113 }
114 RecoveryCheckpointKind::ToolExecutionStart => {
115 state.started = true;
116 }
117 RecoveryCheckpointKind::ToolExecutionEnd => {
118 state.completed = checkpoint.success.unwrap_or(false);
119 if checkpoint.success == Some(false) {
120 state.retry_safe = false;
121 }
122 }
123 RecoveryCheckpointKind::ToolResultAddedToContext => {
124 state.appended = true;
125 }
126 RecoveryCheckpointKind::ProviderRequestStart
127 | RecoveryCheckpointKind::AssistantMessageFinalized
128 | RecoveryCheckpointKind::ProviderRequestCompleted => {}
129 }
130 }
131
132 let mut retryable_incomplete_tools = Vec::new();
133 let mut unsafe_incomplete_tools = Vec::new();
134
135 for (tool_call_id, state) in tools {
136 let incomplete_state = if state.appended {
137 None
138 } else if state.completed {
139 Some(IncompleteToolState::CompletedNotAppended)
140 } else if state.started {
141 Some(IncompleteToolState::StartedNotCompleted)
142 } else if state.planned {
143 Some(IncompleteToolState::PlannedNotStarted)
144 } else {
145 None
146 };
147
148 if let Some(incomplete_state) = incomplete_state {
149 let recovery = IncompleteToolRecovery {
150 tool_call_id,
151 tool_name: state.tool_name,
152 args_hash: state.args_hash,
153 retry_safe: state.retry_safe,
154 state: incomplete_state,
155 };
156 if recovery.retry_safe {
157 retryable_incomplete_tools.push(recovery);
158 } else {
159 unsafe_incomplete_tools.push(recovery);
160 }
161 }
162 }
163
164 retryable_incomplete_tools
165 .sort_by(|left, right| left.tool_call_id.cmp(&right.tool_call_id));
166 unsafe_incomplete_tools.sort_by(|left, right| left.tool_call_id.cmp(&right.tool_call_id));
167
168 RecoveryReconciliation {
169 turn,
170 unsafe_incomplete_tools,
171 retryable_incomplete_tools,
172 }
173 }
174}
175
/// Scratch summary of one tool call's checkpoints, filled in while a turn is
/// reconciled. The derived `Default` starts with no metadata and every flag
/// cleared.
#[derive(Clone, Debug, Default)]
struct ToolRecoveryState {
    /// Tool name taken from the call's checkpoints, if any carried one.
    tool_name: Option<String>,
    /// Argument hash taken from the call's checkpoints, if any carried one.
    args_hash: Option<String>,
    /// Set once the call has been planned or observed.
    planned: bool,
    /// Whether the call is currently considered safe to retry.
    retry_safe: bool,
    /// Set once an execution start has been seen.
    started: bool,
    /// Whether the recorded execution end reported success.
    completed: bool,
    /// Set once the call's result has been added to the context.
    appended: bool,
}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn used by every fixture built through [`checkpoint`].
    const FIXTURE_TURN: u32 = 3;
    /// Turn used for the follow-up activity some tests add after `FIXTURE_TURN`.
    const NEXT_TURN: u32 = 4;

    /// Lowest-level fixture builder; `tool` is `(tool_call_id, tool_name, args_hash)`.
    fn build(
        turn: u32,
        kind: RecoveryCheckpointKind,
        tool: Option<(&str, &str, &str)>,
        success: Option<bool>,
    ) -> RecoveryCheckpoint {
        let (tool_call_id, tool_name, args_hash) = match tool {
            Some((id, name, hash)) => (
                Some(id.to_owned()),
                Some(name.to_owned()),
                Some(hash.to_owned()),
            ),
            None => (None, None, None),
        };
        RecoveryCheckpoint {
            version: 1,
            turn,
            kind,
            tool_call_id,
            tool_name,
            args_hash,
            success,
            error_class: None,
            timestamp: 0,
        }
    }

    /// Checkpoint in `FIXTURE_TURN` for the generic `"tool"` with hash `"abc"`.
    fn checkpoint(
        kind: RecoveryCheckpointKind,
        tool_call_id: &str,
        success: Option<bool>,
    ) -> RecoveryCheckpoint {
        build(
            FIXTURE_TURN,
            kind,
            Some((tool_call_id, "tool", "abc")),
            success,
        )
    }

    /// Assistant-finalized checkpoint that carries no tool metadata.
    fn finalized_message(turn: u32) -> RecoveryCheckpoint {
        build(
            turn,
            RecoveryCheckpointKind::AssistantMessageFinalized,
            None,
            Some(true),
        )
    }

    /// Plan, start, successful end and appended result for one tool call in
    /// `FIXTURE_TURN`; `retry_safe` is the plan checkpoint's `success` flag.
    fn completed_tool(tool_call_id: &str, retry_safe: bool) -> Vec<RecoveryCheckpoint> {
        vec![
            checkpoint(
                RecoveryCheckpointKind::ToolPlanCreated,
                tool_call_id,
                Some(retry_safe),
            ),
            checkpoint(
                RecoveryCheckpointKind::ToolExecutionStart,
                tool_call_id,
                None,
            ),
            checkpoint(
                RecoveryCheckpointKind::ToolExecutionEnd,
                tool_call_id,
                Some(true),
            ),
            checkpoint(
                RecoveryCheckpointKind::ToolResultAddedToContext,
                tool_call_id,
                Some(true),
            ),
        ]
    }

    /// An `"edit"` call in `NEXT_TURN` that was planned and started but never ended.
    fn started_edit(tool_call_id: &str) -> Vec<RecoveryCheckpoint> {
        let tool = Some((tool_call_id, "edit", "def"));
        vec![
            build(
                NEXT_TURN,
                RecoveryCheckpointKind::ToolPlanCreated,
                tool,
                Some(false),
            ),
            build(
                NEXT_TURN,
                RecoveryCheckpointKind::ToolExecutionStart,
                tool,
                None,
            ),
        ]
    }

    #[test]
    fn latest_finished_turn_ignores_in_progress_next_turn() {
        let mut checkpoints = completed_tool("finished", false);
        checkpoints.push(finalized_message(FIXTURE_TURN));
        checkpoints.extend(started_edit("in_progress"));
        let ledger = RecoveryLedger::from_checkpoints(checkpoints);

        let reconciliation = ledger.reconcile_latest_finished_turn().unwrap();
        assert_eq!(reconciliation.turn, 3);
        assert!(reconciliation.is_safe_to_continue());
    }

    #[test]
    fn later_tool_result_marks_tool_turn_finished() {
        let mut checkpoints = vec![checkpoint(
            RecoveryCheckpointKind::AssistantMessageFinalized,
            "",
            Some(true),
        )];
        checkpoints.extend(completed_tool("call", false));
        let ledger = RecoveryLedger::from_checkpoints(checkpoints);

        let reconciliation = ledger.reconcile_latest_finished_turn().unwrap();
        assert_eq!(reconciliation.turn, 3);
        assert!(reconciliation.is_safe_to_continue());
        assert!(reconciliation.retryable_incomplete_tools.is_empty());
        assert!(reconciliation.unsafe_incomplete_tools.is_empty());
    }

    #[test]
    fn assistant_finalized_without_tool_result_does_not_mark_tool_turn_finished() {
        let mut checkpoints = vec![checkpoint(
            RecoveryCheckpointKind::AssistantMessageFinalized,
            "previous",
            Some(true),
        )];
        checkpoints.extend(completed_tool("previous", false));
        checkpoints.push(finalized_message(NEXT_TURN));
        checkpoints.extend(started_edit("interrupted"));

        let ledger = RecoveryLedger::from_checkpoints(checkpoints);
        let reconciliation = ledger.reconcile_latest_finished_turn().unwrap();
        assert_eq!(reconciliation.turn, 3);
        assert!(reconciliation.is_safe_to_continue());
    }

    #[test]
    fn appended_tool_is_not_incomplete() {
        let ledger = RecoveryLedger::from_checkpoints(completed_tool("call", true));

        let reconciliation = ledger.reconcile_turn(3);
        assert!(reconciliation.is_safe_to_continue());
        assert!(reconciliation.retryable_incomplete_tools.is_empty());
        assert!(reconciliation.unsafe_incomplete_tools.is_empty());
    }

    #[test]
    fn read_only_planned_not_started_is_retryable() {
        let plan_only = checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "call", Some(true));
        let ledger = RecoveryLedger::from_checkpoints(vec![plan_only]);

        let reconciliation = ledger.reconcile_turn(3);
        assert!(reconciliation.is_safe_to_continue());
        assert_eq!(reconciliation.retryable_incomplete_tools.len(), 1);
        assert_eq!(
            reconciliation.retryable_incomplete_tools[0].state,
            IncompleteToolState::PlannedNotStarted
        );
    }

    #[test]
    fn mutable_started_not_completed_is_unsafe() {
        let ledger = RecoveryLedger::from_checkpoints(vec![
            checkpoint(RecoveryCheckpointKind::ToolPlanCreated, "call", Some(false)),
            checkpoint(RecoveryCheckpointKind::ToolExecutionStart, "call", None),
        ]);

        let reconciliation = ledger.reconcile_turn(3);
        assert!(!reconciliation.is_safe_to_continue());
        assert_eq!(reconciliation.unsafe_incomplete_tools.len(), 1);
        assert_eq!(
            reconciliation.unsafe_incomplete_tools[0].state,
            IncompleteToolState::StartedNotCompleted
        );
    }

    #[test]
    fn completed_not_appended_is_incomplete() {
        let mut checkpoints = completed_tool("call", false);
        // Drop the trailing `ToolResultAddedToContext` so the result is never appended.
        checkpoints.pop();
        let ledger = RecoveryLedger::from_checkpoints(checkpoints);

        let reconciliation = ledger.reconcile_turn(3);
        assert_eq!(reconciliation.unsafe_incomplete_tools.len(), 1);
        assert_eq!(
            reconciliation.unsafe_incomplete_tools[0].state,
            IncompleteToolState::CompletedNotAppended
        );
    }
}