Skip to main content

punch_runtime/
session_repair.rs

//! Session message history repair.
//!
//! Fixes common issues in message histories that can cause LLM API errors:
//! - Orphaned tool results (tool_result with no matching tool_use)
//! - Empty messages
//! - Consecutive same-role messages that should be merged
//! - Tool uses with no corresponding result (answered with a synthetic error result)
//! - Duplicate tool results for the same tool_use_id
9
10use std::collections::HashSet;
11
12use tracing::{debug, info, warn};
13
14use punch_types::{Message, Role, ToolCallResult};
15
/// Counters describing what a single repair run changed.
///
/// Every counter starts at zero (via `Default`) and is filled in by the
/// individual repair passes.
#[derive(Debug, Clone, Default)]
pub struct RepairStats {
    /// Messages dropped because they had no content, no tool calls and no
    /// tool results.
    pub empty_removed: usize,
    /// Tool results dropped because no assistant message issued a tool call
    /// with the same ID.
    pub orphaned_results_removed: usize,
    /// Synthetic error results added for tool calls that never got a result.
    pub synthetic_results_inserted: usize,
    /// Tool results dropped because an earlier result already used the same
    /// tool_use_id.
    pub duplicate_results_removed: usize,
    /// Messages folded into the immediately preceding message of the same role.
    pub messages_merged: usize,
}
30
31impl RepairStats {
32    /// Whether any repairs were made.
33    pub fn any_repairs(&self) -> bool {
34        self.empty_removed > 0
35            || self.orphaned_results_removed > 0
36            || self.synthetic_results_inserted > 0
37            || self.duplicate_results_removed > 0
38            || self.messages_merged > 0
39    }
40}
41
42impl std::fmt::Display for RepairStats {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(
45            f,
46            "empty_removed={}, orphaned_results={}, synthetic_inserts={}, duplicates={}, merges={}",
47            self.empty_removed,
48            self.orphaned_results_removed,
49            self.synthetic_results_inserted,
50            self.duplicate_results_removed,
51            self.messages_merged,
52        )
53    }
54}
55
56/// Run all repair passes on a message history.
57///
58/// This is idempotent: running repair twice produces the same result.
59/// The passes run in a specific order to handle dependencies correctly.
60pub fn repair_session(messages: &mut Vec<Message>) -> RepairStats {
61    let mut stats = RepairStats::default();
62
63    // Pass 1: Remove empty messages.
64    remove_empty_messages(messages, &mut stats);
65
66    // Pass 2: Remove duplicate tool results.
67    remove_duplicate_tool_results(messages, &mut stats);
68
69    // Pass 3: Fix orphaned tool results (results with no matching tool_use).
70    remove_orphaned_tool_results(messages, &mut stats);
71
72    // Pass 4: Insert synthetic error results for tool_uses with no result.
73    insert_synthetic_results(messages, &mut stats);
74
75    // Pass 5: Merge consecutive same-role messages.
76    merge_consecutive_same_role(messages, &mut stats);
77
78    if stats.any_repairs() {
79        info!(repairs = %stats, "session repair completed");
80    } else {
81        debug!("session repair: no repairs needed");
82    }
83
84    stats
85}
86
87/// Remove messages that have no content, no tool calls, and no tool results.
88fn remove_empty_messages(messages: &mut Vec<Message>, stats: &mut RepairStats) {
89    let before = messages.len();
90
91    messages.retain(|msg| {
92        let is_empty =
93            msg.content.is_empty() && msg.tool_calls.is_empty() && msg.tool_results.is_empty();
94        !is_empty
95    });
96
97    let removed = before - messages.len();
98    if removed > 0 {
99        debug!(count = removed, "removed empty messages");
100        stats.empty_removed = removed;
101    }
102}
103
104/// Remove tool results whose tool_use_id does not match any tool_use in the history.
105fn remove_orphaned_tool_results(messages: &mut Vec<Message>, stats: &mut RepairStats) {
106    // Collect all tool_use IDs from assistant messages.
107    let tool_use_ids: HashSet<String> = messages
108        .iter()
109        .filter(|m| m.role == Role::Assistant)
110        .flat_map(|m| &m.tool_calls)
111        .map(|tc| tc.id.clone())
112        .collect();
113
114    let mut removed = 0;
115
116    for msg in messages.iter_mut() {
117        if msg.role == Role::Tool && !msg.tool_results.is_empty() {
118            let before = msg.tool_results.len();
119            msg.tool_results.retain(|tr| tool_use_ids.contains(&tr.id));
120            let delta = before - msg.tool_results.len();
121            if delta > 0 {
122                warn!(
123                    count = delta,
124                    "removed orphaned tool results (no matching tool_use)"
125                );
126                removed += delta;
127            }
128        }
129    }
130
131    stats.orphaned_results_removed = removed;
132
133    // Also remove any Tool messages that now have zero results and no content.
134    messages.retain(|msg| {
135        if msg.role == Role::Tool {
136            !msg.tool_results.is_empty() || !msg.content.is_empty()
137        } else {
138            true
139        }
140    });
141}
142
143/// Insert synthetic error results for tool_uses that have no corresponding result.
144fn insert_synthetic_results(messages: &mut Vec<Message>, stats: &mut RepairStats) {
145    // Collect all tool_result IDs.
146    let result_ids: HashSet<String> = messages
147        .iter()
148        .filter(|m| m.role == Role::Tool)
149        .flat_map(|m| &m.tool_results)
150        .map(|tr| tr.id.clone())
151        .collect();
152
153    // Find tool_uses that have no result, grouped by their position.
154    // We need to insert results AFTER the assistant message that made the call.
155    let mut insertions: Vec<(usize, Vec<ToolCallResult>)> = Vec::new();
156
157    for (idx, msg) in messages.iter().enumerate() {
158        if msg.role == Role::Assistant && !msg.tool_calls.is_empty() {
159            let missing: Vec<ToolCallResult> = msg
160                .tool_calls
161                .iter()
162                .filter(|tc| !result_ids.contains(&tc.id))
163                .map(|tc| {
164                    warn!(
165                        tool_use_id = %tc.id,
166                        tool_name = %tc.name,
167                        "inserting synthetic error result for orphaned tool_use"
168                    );
169                    ToolCallResult {
170                        id: tc.id.clone(),
171                        content: format!(
172                            "Error: tool execution was interrupted or result was lost (tool: {})",
173                            tc.name
174                        ),
175                        is_error: true,
176                    }
177                })
178                .collect();
179
180            if !missing.is_empty() {
181                insertions.push((idx, missing));
182            }
183        }
184    }
185
186    // Insert in reverse order to preserve indices.
187    let mut inserted = 0;
188    for (idx, results) in insertions.into_iter().rev() {
189        let count = results.len();
190        inserted += count;
191
192        let tool_msg = Message {
193            role: Role::Tool,
194            content: String::new(),
195            tool_calls: Vec::new(),
196            tool_results: results,
197            timestamp: chrono::Utc::now(),
198        };
199
200        // Insert right after the assistant message.
201        let insert_pos = idx + 1;
202        if insert_pos <= messages.len() {
203            messages.insert(insert_pos, tool_msg);
204        } else {
205            messages.push(tool_msg);
206        }
207    }
208
209    stats.synthetic_results_inserted = inserted;
210}
211
212/// Remove duplicate tool results (same tool_use_id appearing more than once).
213fn remove_duplicate_tool_results(messages: &mut [Message], stats: &mut RepairStats) {
214    let mut seen_ids: HashSet<String> = HashSet::new();
215    let mut removed = 0;
216
217    for msg in messages.iter_mut() {
218        if msg.role == Role::Tool && !msg.tool_results.is_empty() {
219            let before = msg.tool_results.len();
220            msg.tool_results.retain(|tr| seen_ids.insert(tr.id.clone()));
221            let delta = before - msg.tool_results.len();
222            if delta > 0 {
223                debug!(count = delta, "removed duplicate tool results");
224                removed += delta;
225            }
226        }
227    }
228
229    stats.duplicate_results_removed = removed;
230}
231
232/// Merge consecutive messages with the same role.
233///
234/// This handles cases where, e.g., two user messages end up adjacent
235/// (which some LLM APIs reject).
236fn merge_consecutive_same_role(messages: &mut Vec<Message>, stats: &mut RepairStats) {
237    if messages.len() < 2 {
238        return;
239    }
240
241    let mut merged = 0;
242    let mut result: Vec<Message> = Vec::with_capacity(messages.len());
243
244    for msg in messages.drain(..) {
245        if let Some(last) = result.last_mut() {
246            // Only merge User-User or Assistant-Assistant.
247            // Tool messages have special structure and should not be merged.
248            if last.role == msg.role && (msg.role == Role::User || msg.role == Role::Assistant) {
249                // Merge content.
250                if !msg.content.is_empty() {
251                    if !last.content.is_empty() {
252                        last.content.push('\n');
253                    }
254                    last.content.push_str(&msg.content);
255                }
256                // Merge tool calls and results.
257                last.tool_calls.extend(msg.tool_calls);
258                last.tool_results.extend(msg.tool_results);
259                // Keep the later timestamp.
260                last.timestamp = msg.timestamp;
261                merged += 1;
262                continue;
263            }
264        }
265        result.push(msg);
266    }
267
268    *messages = result;
269    stats.messages_merged = merged;
270}
271
272// ---------------------------------------------------------------------------
273// Tests
274// ---------------------------------------------------------------------------
275
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::{Message, Role, ToolCall, ToolCallResult};

    fn user_msg(content: &str) -> Message {
        Message::new(Role::User, content)
    }

    fn assistant_msg(content: &str) -> Message {
        Message::new(Role::Assistant, content)
    }

    /// Assistant message with no text and a single tool call.
    fn assistant_with_tool_call(tool_id: &str, tool_name: &str) -> Message {
        Message {
            role: Role::Assistant,
            content: String::new(),
            tool_calls: vec![ToolCall {
                id: tool_id.to_string(),
                name: tool_name.to_string(),
                input: serde_json::json!({}),
            }],
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        }
    }

    /// Tool message carrying one successful result for `id`.
    fn tool_result_msg(id: &str, content: &str) -> Message {
        Message {
            role: Role::Tool,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: vec![ToolCallResult {
                id: id.to_string(),
                content: content.to_string(),
                is_error: false,
            }],
            timestamp: chrono::Utc::now(),
        }
    }

    /// Message of the given role with no content, tool calls or tool results.
    fn empty_msg(role: Role) -> Message {
        Message {
            role,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        }
    }

    /// IDs of the tool calls carried by `msg`, in order.
    fn call_ids(msg: &Message) -> Vec<String> {
        msg.tool_calls.iter().map(|tc| tc.id.clone()).collect()
    }

    /// IDs of the tool results carried by `msg`, in order.
    fn result_ids(msg: &Message) -> Vec<String> {
        msg.tool_results.iter().map(|tr| tr.id.clone()).collect()
    }

    #[test]
    fn test_remove_empty_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            empty_msg(Role::Assistant),
            assistant_msg("world"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.empty_removed, 1);
        assert_eq!(msgs.len(), 2);
    }

    #[test]
    fn test_remove_orphaned_tool_results() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "file contents"),
            // This tool result has no matching tool_use:
            tool_result_msg("call_999", "orphaned result"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.orphaned_results_removed, 1);
        // The orphaned tool message should be fully removed since it has no results left.
        assert_eq!(msgs.len(), 3);
        assert_eq!(result_ids(&msgs[2]), vec!["call_1".to_string()]);
    }

    #[test]
    fn test_insert_synthetic_results() {
        let mut msgs = vec![
            user_msg("do something"),
            assistant_with_tool_call("call_1", "shell_exec"),
            // No tool result for call_1!
            assistant_msg("I ran the command"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.synthetic_results_inserted, 1);

        // Should now have: user, assistant(tool_use), tool(synthetic), assistant
        assert_eq!(msgs.len(), 4);
        assert_eq!(msgs[2].role, Role::Tool);
        assert_eq!(result_ids(&msgs[2]), vec!["call_1".to_string()]);
        assert!(msgs[2].tool_results[0].is_error);
        assert!(msgs[2].tool_results[0].content.contains("interrupted"));
    }

    #[test]
    fn test_remove_duplicate_tool_results() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "first result"),
            tool_result_msg("call_1", "duplicate result"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.duplicate_results_removed, 1);
        // The emptied duplicate Tool message is dropped; the first result wins.
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[2].tool_results.len(), 1);
        assert_eq!(msgs[2].tool_results[0].content, "first result");
    }

    #[test]
    fn test_merge_consecutive_user_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            user_msg("world"),
            assistant_msg("hi there"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.messages_merged, 1);
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].content.contains("hello"));
        assert!(msgs[0].content.contains("world"));
    }

    #[test]
    fn test_merge_chain_of_user_messages() {
        let mut msgs = vec![
            user_msg("a"),
            user_msg("b"),
            user_msg("c"),
            assistant_msg("ok"),
        ];

        let stats = repair_session(&mut msgs);
        // Three adjacent user messages collapse into one via two merges.
        assert_eq!(stats.messages_merged, 2);
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].content, "a\nb\nc");
    }

    #[test]
    fn test_merge_consecutive_assistant_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_msg("part 1"),
            assistant_msg("part 2"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.messages_merged, 1);
        assert_eq!(msgs.len(), 2);
        assert!(msgs[1].content.contains("part 1"));
        assert!(msgs[1].content.contains("part 2"));
    }

    #[test]
    fn test_no_merge_tool_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result 1"),
            assistant_with_tool_call("call_2", "file_read"),
            tool_result_msg("call_2", "result 2"),
            assistant_msg("done"),
        ];

        let stats = repair_session(&mut msgs);
        // Tool messages should not be merged, and the assistant messages with
        // tool calls should not be merged with each other.
        assert_eq!(stats.messages_merged, 0);
        assert_eq!(msgs.len(), 6);
    }

    #[test]
    fn test_clean_session_no_repairs() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result"),
            assistant_msg("done"),
        ];

        let stats = repair_session(&mut msgs);
        assert!(!stats.any_repairs());
        assert_eq!(msgs.len(), 4);
    }

    #[test]
    fn test_stats_display_and_any_repairs() {
        assert!(!RepairStats::default().any_repairs());

        let stats = RepairStats {
            empty_removed: 1,
            orphaned_results_removed: 2,
            synthetic_results_inserted: 3,
            duplicate_results_removed: 4,
            messages_merged: 5,
        };
        assert!(stats.any_repairs());
        assert_eq!(
            stats.to_string(),
            "empty_removed=1, orphaned_results=2, synthetic_inserts=3, duplicates=4, merges=5"
        );
    }

    #[test]
    fn test_idempotent() {
        let mut msgs = vec![
            user_msg("hello"),
            empty_msg(Role::Assistant),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result"),
            tool_result_msg("call_999", "orphaned"),
            user_msg("follow up"),
            user_msg("more"),
        ];

        let stats1 = repair_session(&mut msgs);
        assert!(stats1.any_repairs());
        assert_eq!(stats1.empty_removed, 1);
        assert_eq!(stats1.orphaned_results_removed, 1);
        assert_eq!(stats1.messages_merged, 1);
        assert_eq!(msgs.len(), 4);

        let snapshot = msgs.clone();
        let stats2 = repair_session(&mut msgs);
        assert!(!stats2.any_repairs());
        assert_eq!(msgs.len(), snapshot.len());

        // The second run must leave every message unchanged, not just the count.
        for (after, before) in msgs.iter().zip(&snapshot) {
            assert_eq!(after.role, before.role);
            assert_eq!(after.content, before.content);
            assert_eq!(call_ids(after), call_ids(before));
            assert_eq!(result_ids(after), result_ids(before));
        }
    }
}
470}