Skip to main content

punch_runtime/
session_repair.rs

1//! Session message history repair.
2//!
3//! Fixes common issues in message histories that can cause LLM API errors:
4//! - Orphaned tool results (tool_result with no matching tool_use)
5//! - Empty messages
6//! - Consecutive same-role messages that should be merged
7//! - Tool uses with no corresponding result
8//! - Duplicate tool results for the same tool_use_id
9
10use std::collections::HashSet;
11
12use tracing::{debug, info, warn};
13
14use punch_types::{Message, Role, ToolCallResult};
15
/// Statistics from a repair pass.
///
/// Each counter records how many items one repair pass changed; all are zero
/// when the history needed no repair.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RepairStats {
    /// Number of empty messages removed.
    pub empty_removed: usize,
    /// Number of orphaned tool results removed.
    pub orphaned_results_removed: usize,
    /// Number of synthetic error results inserted for tool_uses without results.
    pub synthetic_results_inserted: usize,
    /// Number of duplicate tool results removed.
    pub duplicate_results_removed: usize,
    /// Number of consecutive same-role message merges performed.
    pub messages_merged: usize,
}

impl RepairStats {
    /// Whether any repairs were made (i.e. at least one counter is non-zero).
    #[must_use]
    pub fn any_repairs(&self) -> bool {
        // Exhaustive destructuring: adding a field to the struct without
        // listing it here is a compile error rather than a silently ignored counter.
        let Self {
            empty_removed,
            orphaned_results_removed,
            synthetic_results_inserted,
            duplicate_results_removed,
            messages_merged,
        } = *self;

        [
            empty_removed,
            orphaned_results_removed,
            synthetic_results_inserted,
            duplicate_results_removed,
            messages_merged,
        ]
        .iter()
        .any(|&count| count > 0)
    }
}

impl std::fmt::Display for RepairStats {
    /// Render all counters as a single `key=value` list, suitable for a log field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructured for the same reason as in `any_repairs`.
        let Self {
            empty_removed,
            orphaned_results_removed,
            synthetic_results_inserted,
            duplicate_results_removed,
            messages_merged,
        } = self;

        write!(
            f,
            "empty_removed={}, orphaned_results={}, synthetic_inserts={}, duplicates={}, merges={}",
            empty_removed,
            orphaned_results_removed,
            synthetic_results_inserted,
            duplicate_results_removed,
            messages_merged,
        )
    }
}
55
56/// Run all repair passes on a message history.
57///
58/// This is idempotent: running repair twice produces the same result.
59/// The passes run in a specific order to handle dependencies correctly.
60pub fn repair_session(messages: &mut Vec<Message>) -> RepairStats {
61    let mut stats = RepairStats::default();
62
63    // Pass 1: Remove empty messages.
64    remove_empty_messages(messages, &mut stats);
65
66    // Pass 2: Remove duplicate tool results.
67    remove_duplicate_tool_results(messages, &mut stats);
68
69    // Pass 3: Fix orphaned tool results (results with no matching tool_use).
70    remove_orphaned_tool_results(messages, &mut stats);
71
72    // Pass 4: Insert synthetic error results for tool_uses with no result.
73    insert_synthetic_results(messages, &mut stats);
74
75    // Pass 5: Merge consecutive same-role messages.
76    merge_consecutive_same_role(messages, &mut stats);
77
78    if stats.any_repairs() {
79        info!(repairs = %stats, "session repair completed");
80    } else {
81        debug!("session repair: no repairs needed");
82    }
83
84    stats
85}
86
87/// Remove messages that have no content, no tool calls, and no tool results.
88fn remove_empty_messages(messages: &mut Vec<Message>, stats: &mut RepairStats) {
89    let before = messages.len();
90
91    messages.retain(|msg| {
92        let is_empty =
93            msg.content.is_empty() && msg.tool_calls.is_empty() && msg.tool_results.is_empty();
94        !is_empty
95    });
96
97    let removed = before - messages.len();
98    if removed > 0 {
99        debug!(count = removed, "removed empty messages");
100        stats.empty_removed = removed;
101    }
102}
103
104/// Remove tool results whose tool_use_id does not match any tool_use in the history.
105fn remove_orphaned_tool_results(messages: &mut Vec<Message>, stats: &mut RepairStats) {
106    // Collect all tool_use IDs from assistant messages.
107    let tool_use_ids: HashSet<String> = messages
108        .iter()
109        .filter(|m| m.role == Role::Assistant)
110        .flat_map(|m| &m.tool_calls)
111        .map(|tc| tc.id.clone())
112        .collect();
113
114    let mut removed = 0;
115
116    for msg in messages.iter_mut() {
117        if msg.role == Role::Tool && !msg.tool_results.is_empty() {
118            let before = msg.tool_results.len();
119            msg.tool_results.retain(|tr| tool_use_ids.contains(&tr.id));
120            let delta = before - msg.tool_results.len();
121            if delta > 0 {
122                warn!(
123                    count = delta,
124                    "removed orphaned tool results (no matching tool_use)"
125                );
126                removed += delta;
127            }
128        }
129    }
130
131    stats.orphaned_results_removed = removed;
132
133    // Also remove any Tool messages that now have zero results and no content.
134    messages.retain(|msg| {
135        if msg.role == Role::Tool {
136            !msg.tool_results.is_empty() || !msg.content.is_empty()
137        } else {
138            true
139        }
140    });
141}
142
143/// Insert synthetic error results for tool_uses that have no corresponding result.
144fn insert_synthetic_results(messages: &mut Vec<Message>, stats: &mut RepairStats) {
145    // Collect all tool_result IDs.
146    let result_ids: HashSet<String> = messages
147        .iter()
148        .filter(|m| m.role == Role::Tool)
149        .flat_map(|m| &m.tool_results)
150        .map(|tr| tr.id.clone())
151        .collect();
152
153    // Find tool_uses that have no result, grouped by their position.
154    // We need to insert results AFTER the assistant message that made the call.
155    let mut insertions: Vec<(usize, Vec<ToolCallResult>)> = Vec::new();
156
157    for (idx, msg) in messages.iter().enumerate() {
158        if msg.role == Role::Assistant && !msg.tool_calls.is_empty() {
159            let missing: Vec<ToolCallResult> = msg
160                .tool_calls
161                .iter()
162                .filter(|tc| !result_ids.contains(&tc.id))
163                .map(|tc| {
164                    warn!(
165                        tool_use_id = %tc.id,
166                        tool_name = %tc.name,
167                        "inserting synthetic error result for orphaned tool_use"
168                    );
169                    ToolCallResult {
170                        id: tc.id.clone(),
171                        content: format!(
172                            "Error: tool execution was interrupted or result was lost (tool: {})",
173                            tc.name
174                        ),
175                        is_error: true,
176                        image: None,
177                    }
178                })
179                .collect();
180
181            if !missing.is_empty() {
182                insertions.push((idx, missing));
183            }
184        }
185    }
186
187    // Insert in reverse order to preserve indices.
188    let mut inserted = 0;
189    for (idx, results) in insertions.into_iter().rev() {
190        let count = results.len();
191        inserted += count;
192
193        let tool_msg = Message {
194            role: Role::Tool,
195            content: String::new(),
196            tool_calls: Vec::new(),
197            tool_results: results,
198            timestamp: chrono::Utc::now(),
199            content_parts: Vec::new(),
200        };
201
202        // Insert right after the assistant message.
203        let insert_pos = idx + 1;
204        if insert_pos <= messages.len() {
205            messages.insert(insert_pos, tool_msg);
206        } else {
207            messages.push(tool_msg);
208        }
209    }
210
211    stats.synthetic_results_inserted = inserted;
212}
213
214/// Remove duplicate tool results (same tool_use_id appearing more than once).
215fn remove_duplicate_tool_results(messages: &mut [Message], stats: &mut RepairStats) {
216    let mut seen_ids: HashSet<String> = HashSet::new();
217    let mut removed = 0;
218
219    for msg in messages.iter_mut() {
220        if msg.role == Role::Tool && !msg.tool_results.is_empty() {
221            let before = msg.tool_results.len();
222            msg.tool_results.retain(|tr| seen_ids.insert(tr.id.clone()));
223            let delta = before - msg.tool_results.len();
224            if delta > 0 {
225                debug!(count = delta, "removed duplicate tool results");
226                removed += delta;
227            }
228        }
229    }
230
231    stats.duplicate_results_removed = removed;
232}
233
234/// Merge consecutive messages with the same role.
235///
236/// This handles cases where, e.g., two user messages end up adjacent
237/// (which some LLM APIs reject).
238fn merge_consecutive_same_role(messages: &mut Vec<Message>, stats: &mut RepairStats) {
239    if messages.len() < 2 {
240        return;
241    }
242
243    let mut merged = 0;
244    let mut result: Vec<Message> = Vec::with_capacity(messages.len());
245
246    for msg in messages.drain(..) {
247        if let Some(last) = result.last_mut() {
248            // Only merge User-User or Assistant-Assistant.
249            // Tool messages have special structure and should not be merged.
250            if last.role == msg.role && (msg.role == Role::User || msg.role == Role::Assistant) {
251                // Merge content.
252                if !msg.content.is_empty() {
253                    if !last.content.is_empty() {
254                        last.content.push('\n');
255                    }
256                    last.content.push_str(&msg.content);
257                }
258                // Merge tool calls and results.
259                last.tool_calls.extend(msg.tool_calls);
260                last.tool_results.extend(msg.tool_results);
261                // Keep the later timestamp.
262                last.timestamp = msg.timestamp;
263                merged += 1;
264                continue;
265            }
266        }
267        result.push(msg);
268    }
269
270    *messages = result;
271    stats.messages_merged = merged;
272}
273
274// ---------------------------------------------------------------------------
275// Tests
276// ---------------------------------------------------------------------------
277
// Unit tests for `repair_session`. Each test builds a small history with the
// helpers below, runs the full repair, and checks both the reported stats and
// the resulting message list.
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::{Message, Role, ToolCall, ToolCallResult};

    fn user_msg(content: &str) -> Message {
        Message::new(Role::User, content)
    }

    fn assistant_msg(content: &str) -> Message {
        Message::new(Role::Assistant, content)
    }

    /// Assistant message with a single tool call and no text.
    fn assistant_with_tool_call(tool_id: &str, tool_name: &str) -> Message {
        Message {
            role: Role::Assistant,
            content: String::new(),
            tool_calls: vec![ToolCall {
                id: tool_id.to_string(),
                name: tool_name.to_string(),
                input: serde_json::json!({}),
            }],
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        }
    }

    /// Tool message carrying one successful result for `id`.
    fn tool_result_msg(id: &str, content: &str) -> Message {
        Message {
            role: Role::Tool,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: vec![ToolCallResult {
                id: id.to_string(),
                content: content.to_string(),
                is_error: false,
                image: None,
            }],
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        }
    }

    /// Message with no content, parts, tool calls or tool results.
    fn empty_msg(role: Role) -> Message {
        Message {
            role,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        }
    }

    #[test]
    fn test_empty_session() {
        let mut msgs: Vec<Message> = Vec::new();

        let stats = repair_session(&mut msgs);
        assert!(!stats.any_repairs());
        assert!(msgs.is_empty());
    }

    #[test]
    fn test_remove_empty_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            empty_msg(Role::Assistant),
            assistant_msg("world"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.empty_removed, 1);
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].content, "hello");
        assert_eq!(msgs[1].content, "world");
    }

    #[test]
    fn test_remove_orphaned_tool_results() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "file contents"),
            // This tool result has no matching tool_use:
            tool_result_msg("call_999", "orphaned result"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.orphaned_results_removed, 1);
        assert_eq!(stats.synthetic_results_inserted, 0);
        // The orphaned tool message should be fully removed since it has no results left.
        assert_eq!(msgs.len(), 3);
        assert!(msgs
            .iter()
            .flat_map(|m| &m.tool_results)
            .all(|tr| tr.id != "call_999"));
    }

    #[test]
    fn test_orphan_removal_then_merge_adjacent_users() {
        // Removing the orphaned tool message leaves two user messages adjacent;
        // the merge pass (which runs afterwards) must fold them together.
        let mut msgs = vec![
            user_msg("a"),
            tool_result_msg("call_999", "orphaned"),
            user_msg("b"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.orphaned_results_removed, 1);
        assert_eq!(stats.messages_merged, 1);
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].content, "a\nb");
    }

    #[test]
    fn test_insert_synthetic_results() {
        let mut msgs = vec![
            user_msg("do something"),
            assistant_with_tool_call("call_1", "shell_exec"),
            // No tool result for call_1!
            assistant_msg("I ran the command"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.synthetic_results_inserted, 1);

        // Should now have: user, assistant(tool_use), tool(synthetic), assistant
        assert_eq!(msgs.len(), 4);
        assert_eq!(msgs[2].role, Role::Tool);
        assert_eq!(msgs[2].tool_results[0].id, "call_1");
        assert!(msgs[2].tool_results[0].is_error);
        assert!(msgs[2].tool_results[0].content.contains("interrupted"));
    }

    #[test]
    fn test_synthetic_results_follow_each_unanswered_call() {
        let mut msgs = vec![
            user_msg("go"),
            assistant_with_tool_call("call_1", "file_read"),
            assistant_with_tool_call("call_2", "shell_exec"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.synthetic_results_inserted, 2);
        // The inserted tool messages separate the assistants, so nothing merges.
        assert_eq!(stats.messages_merged, 0);

        let roles: Vec<&Role> = msgs.iter().map(|m| &m.role).collect();
        assert_eq!(
            roles,
            vec![&Role::User, &Role::Assistant, &Role::Tool, &Role::Assistant, &Role::Tool]
        );
        assert_eq!(msgs[2].tool_results[0].id, "call_1");
        assert_eq!(msgs[4].tool_results[0].id, "call_2");
    }

    #[test]
    fn test_remove_duplicate_tool_results() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "first result"),
            tool_result_msg("call_1", "duplicate result"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.duplicate_results_removed, 1);
        assert_eq!(stats.orphaned_results_removed, 0);
        // The emptied duplicate tool message is dropped; the first result survives.
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[2].tool_results.len(), 1);
        assert_eq!(msgs[2].tool_results[0].content, "first result");
    }

    #[test]
    fn test_merge_consecutive_user_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            user_msg("world"),
            assistant_msg("hi there"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.messages_merged, 1);
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].content, "hello\nworld");
    }

    #[test]
    fn test_merge_consecutive_assistant_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_msg("part 1"),
            assistant_msg("part 2"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.messages_merged, 1);
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[1].content, "part 1\npart 2");
    }

    #[test]
    fn test_no_merge_tool_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result 1"),
            assistant_with_tool_call("call_2", "file_read"),
            tool_result_msg("call_2", "result 2"),
            assistant_msg("done"),
        ];

        let stats = repair_session(&mut msgs);
        // Tool messages should not be merged, and the assistant messages with
        // tool calls should not be merged with each other.
        assert_eq!(stats.messages_merged, 0);
        assert_eq!(msgs.len(), 6);
    }

    #[test]
    fn test_clean_session_no_repairs() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result"),
            assistant_msg("done"),
        ];

        let stats = repair_session(&mut msgs);
        assert!(!stats.any_repairs());
        assert_eq!(msgs.len(), 4);
    }

    #[test]
    fn test_idempotent() {
        let mut msgs = vec![
            user_msg("hello"),
            empty_msg(Role::Assistant),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result"),
            tool_result_msg("call_999", "orphaned"),
            user_msg("follow up"),
            user_msg("more"),
        ];

        let stats1 = repair_session(&mut msgs);
        assert!(stats1.any_repairs());

        let snapshot = msgs.clone();
        let stats2 = repair_session(&mut msgs);
        assert!(!stats2.any_repairs());
        assert_eq!(msgs.len(), snapshot.len());
        // Compare structure, not just length: a second pass must not alter anything.
        for (after, before) in msgs.iter().zip(&snapshot) {
            assert_eq!(after.role, before.role);
            assert_eq!(after.content, before.content);
            assert_eq!(after.tool_calls.len(), before.tool_calls.len());
            assert_eq!(after.tool_results.len(), before.tool_results.len());
        }
    }
}