1use std::collections::HashSet;
11
12use tracing::{debug, info, warn};
13
14use punch_types::{Message, Role, ToolCallResult};
15
/// Counters describing what [`repair_session`] changed in a message history.
///
/// Each field counts the items touched by one repair pass. A `Default` value
/// (all zeros) means nothing was repaired.
#[derive(Debug, Clone, Default)]
pub struct RepairStats {
    /// Messages dropped for being empty.
    pub empty_removed: usize,
    /// Tool results dropped because no assistant tool call carried the same id.
    pub orphaned_results_removed: usize,
    /// Synthetic error results added for tool calls that never received a result.
    pub synthetic_results_inserted: usize,
    /// Tool results dropped because an earlier result already used the same id.
    pub duplicate_results_removed: usize,
    /// Messages folded into the preceding message of the same role.
    pub messages_merged: usize,
}

impl RepairStats {
    /// Returns `true` when at least one counter is non-zero, i.e. some pass
    /// actually modified the session.
    pub fn any_repairs(&self) -> bool {
        [
            self.empty_removed,
            self.orphaned_results_removed,
            self.synthetic_results_inserted,
            self.duplicate_results_removed,
            self.messages_merged,
        ]
        .iter()
        .any(|&count| count != 0)
    }
}

impl std::fmt::Display for RepairStats {
    /// Renders every counter as a single comma-separated `key=value` line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RepairStats {
            empty_removed,
            orphaned_results_removed,
            synthetic_results_inserted,
            duplicate_results_removed,
            messages_merged,
        } = self;
        write!(
            f,
            "empty_removed={}, orphaned_results={}, synthetic_inserts={}, duplicates={}, merges={}",
            empty_removed,
            orphaned_results_removed,
            synthetic_results_inserted,
            duplicate_results_removed,
            messages_merged,
        )
    }
}
55
56pub fn repair_session(messages: &mut Vec<Message>) -> RepairStats {
61 let mut stats = RepairStats::default();
62
63 remove_empty_messages(messages, &mut stats);
65
66 remove_duplicate_tool_results(messages, &mut stats);
68
69 remove_orphaned_tool_results(messages, &mut stats);
71
72 insert_synthetic_results(messages, &mut stats);
74
75 merge_consecutive_same_role(messages, &mut stats);
77
78 if stats.any_repairs() {
79 info!(repairs = %stats, "session repair completed");
80 } else {
81 debug!("session repair: no repairs needed");
82 }
83
84 stats
85}
86
87fn remove_empty_messages(messages: &mut Vec<Message>, stats: &mut RepairStats) {
89 let before = messages.len();
90
91 messages.retain(|msg| {
92 let is_empty =
93 msg.content.is_empty() && msg.tool_calls.is_empty() && msg.tool_results.is_empty();
94 !is_empty
95 });
96
97 let removed = before - messages.len();
98 if removed > 0 {
99 debug!(count = removed, "removed empty messages");
100 stats.empty_removed = removed;
101 }
102}
103
104fn remove_orphaned_tool_results(messages: &mut Vec<Message>, stats: &mut RepairStats) {
106 let tool_use_ids: HashSet<String> = messages
108 .iter()
109 .filter(|m| m.role == Role::Assistant)
110 .flat_map(|m| &m.tool_calls)
111 .map(|tc| tc.id.clone())
112 .collect();
113
114 let mut removed = 0;
115
116 for msg in messages.iter_mut() {
117 if msg.role == Role::Tool && !msg.tool_results.is_empty() {
118 let before = msg.tool_results.len();
119 msg.tool_results.retain(|tr| tool_use_ids.contains(&tr.id));
120 let delta = before - msg.tool_results.len();
121 if delta > 0 {
122 warn!(
123 count = delta,
124 "removed orphaned tool results (no matching tool_use)"
125 );
126 removed += delta;
127 }
128 }
129 }
130
131 stats.orphaned_results_removed = removed;
132
133 messages.retain(|msg| {
135 if msg.role == Role::Tool {
136 !msg.tool_results.is_empty() || !msg.content.is_empty()
137 } else {
138 true
139 }
140 });
141}
142
143fn insert_synthetic_results(messages: &mut Vec<Message>, stats: &mut RepairStats) {
145 let result_ids: HashSet<String> = messages
147 .iter()
148 .filter(|m| m.role == Role::Tool)
149 .flat_map(|m| &m.tool_results)
150 .map(|tr| tr.id.clone())
151 .collect();
152
153 let mut insertions: Vec<(usize, Vec<ToolCallResult>)> = Vec::new();
156
157 for (idx, msg) in messages.iter().enumerate() {
158 if msg.role == Role::Assistant && !msg.tool_calls.is_empty() {
159 let missing: Vec<ToolCallResult> = msg
160 .tool_calls
161 .iter()
162 .filter(|tc| !result_ids.contains(&tc.id))
163 .map(|tc| {
164 warn!(
165 tool_use_id = %tc.id,
166 tool_name = %tc.name,
167 "inserting synthetic error result for orphaned tool_use"
168 );
169 ToolCallResult {
170 id: tc.id.clone(),
171 content: format!(
172 "Error: tool execution was interrupted or result was lost (tool: {})",
173 tc.name
174 ),
175 is_error: true,
176 image: None,
177 }
178 })
179 .collect();
180
181 if !missing.is_empty() {
182 insertions.push((idx, missing));
183 }
184 }
185 }
186
187 let mut inserted = 0;
189 for (idx, results) in insertions.into_iter().rev() {
190 let count = results.len();
191 inserted += count;
192
193 let tool_msg = Message {
194 role: Role::Tool,
195 content: String::new(),
196 tool_calls: Vec::new(),
197 tool_results: results,
198 timestamp: chrono::Utc::now(),
199 content_parts: Vec::new(),
200 };
201
202 let insert_pos = idx + 1;
204 if insert_pos <= messages.len() {
205 messages.insert(insert_pos, tool_msg);
206 } else {
207 messages.push(tool_msg);
208 }
209 }
210
211 stats.synthetic_results_inserted = inserted;
212}
213
214fn remove_duplicate_tool_results(messages: &mut [Message], stats: &mut RepairStats) {
216 let mut seen_ids: HashSet<String> = HashSet::new();
217 let mut removed = 0;
218
219 for msg in messages.iter_mut() {
220 if msg.role == Role::Tool && !msg.tool_results.is_empty() {
221 let before = msg.tool_results.len();
222 msg.tool_results.retain(|tr| seen_ids.insert(tr.id.clone()));
223 let delta = before - msg.tool_results.len();
224 if delta > 0 {
225 debug!(count = delta, "removed duplicate tool results");
226 removed += delta;
227 }
228 }
229 }
230
231 stats.duplicate_results_removed = removed;
232}
233
234fn merge_consecutive_same_role(messages: &mut Vec<Message>, stats: &mut RepairStats) {
239 if messages.len() < 2 {
240 return;
241 }
242
243 let mut merged = 0;
244 let mut result: Vec<Message> = Vec::with_capacity(messages.len());
245
246 for msg in messages.drain(..) {
247 if let Some(last) = result.last_mut() {
248 if last.role == msg.role && (msg.role == Role::User || msg.role == Role::Assistant) {
251 if !msg.content.is_empty() {
253 if !last.content.is_empty() {
254 last.content.push('\n');
255 }
256 last.content.push_str(&msg.content);
257 }
258 last.tool_calls.extend(msg.tool_calls);
260 last.tool_results.extend(msg.tool_results);
261 last.timestamp = msg.timestamp;
263 merged += 1;
264 continue;
265 }
266 }
267 result.push(msg);
268 }
269
270 *messages = result;
271 stats.messages_merged = merged;
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::{Message, Role, ToolCall, ToolCallResult};

    fn user_msg(content: &str) -> Message {
        Message::new(Role::User, content)
    }

    fn assistant_msg(content: &str) -> Message {
        Message::new(Role::Assistant, content)
    }

    fn assistant_with_tool_call(tool_id: &str, tool_name: &str) -> Message {
        Message {
            role: Role::Assistant,
            content: String::new(),
            tool_calls: vec![ToolCall {
                id: tool_id.to_string(),
                name: tool_name.to_string(),
                input: serde_json::json!({}),
            }],
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        }
    }

    fn tool_result_msg(id: &str, content: &str) -> Message {
        Message {
            role: Role::Tool,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: vec![ToolCallResult {
                id: id.to_string(),
                content: content.to_string(),
                is_error: false,
                image: None,
            }],
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        }
    }

    fn empty_msg(role: Role) -> Message {
        Message {
            role,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
            content_parts: Vec::new(),
        }
    }

    /// Message texts in order, so two sessions can be compared without
    /// requiring `Message: PartialEq`.
    fn contents(messages: &[Message]) -> Vec<String> {
        messages.iter().map(|m| m.content.clone()).collect()
    }

    #[test]
    fn test_empty_session_is_untouched() {
        let mut msgs: Vec<Message> = Vec::new();

        let stats = repair_session(&mut msgs);
        assert!(!stats.any_repairs());
        assert!(msgs.is_empty());
    }

    #[test]
    fn test_remove_empty_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            empty_msg(Role::Assistant),
            assistant_msg("world"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.empty_removed, 1);
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, Role::User);
        assert_eq!(msgs[1].role, Role::Assistant);
    }

    #[test]
    fn test_remove_orphaned_tool_results() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "file contents"),
            tool_result_msg("call_999", "orphaned result"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.orphaned_results_removed, 1);
        assert_eq!(msgs.len(), 3);
        // The surviving tool message is the one answering the real call.
        assert_eq!(msgs[2].role, Role::Tool);
        assert_eq!(msgs[2].tool_results.len(), 1);
        assert_eq!(msgs[2].tool_results[0].id, "call_1");
    }

    #[test]
    fn test_insert_synthetic_results() {
        let mut msgs = vec![
            user_msg("do something"),
            assistant_with_tool_call("call_1", "shell_exec"),
            assistant_msg("I ran the command"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.synthetic_results_inserted, 1);

        // The synthetic tool message sits between the two assistant turns,
        // which also keeps them from being merged.
        assert_eq!(msgs.len(), 4);
        assert_eq!(stats.messages_merged, 0);
        assert_eq!(msgs[2].role, Role::Tool);
        assert_eq!(msgs[2].tool_results.len(), 1);
        assert_eq!(msgs[2].tool_results[0].id, "call_1");
        assert!(msgs[2].tool_results[0].is_error);
        assert!(msgs[2].tool_results[0].content.contains("interrupted"));
    }

    #[test]
    fn test_synthetic_result_for_trailing_tool_call() {
        let mut msgs = vec![
            user_msg("do something"),
            assistant_with_tool_call("call_1", "shell_exec"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.synthetic_results_inserted, 1);
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[2].role, Role::Tool);
        assert_eq!(msgs[2].tool_results[0].id, "call_1");
        assert!(msgs[2].tool_results[0].is_error);
    }

    #[test]
    fn test_remove_duplicate_tool_results() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "first result"),
            tool_result_msg("call_1", "duplicate result"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.duplicate_results_removed, 1);
        assert_eq!(stats.orphaned_results_removed, 0);
        // The tool message emptied by deduplication is dropped, and the first
        // result is the one kept.
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[2].tool_results.len(), 1);
        assert_eq!(msgs[2].tool_results[0].content, "first result");
    }

    #[test]
    fn test_merge_consecutive_user_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            user_msg("world"),
            assistant_msg("hi there"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.messages_merged, 1);
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].content.contains("hello"));
        assert!(msgs[0].content.contains("world"));
    }

    #[test]
    fn test_merge_consecutive_assistant_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_msg("part 1"),
            assistant_msg("part 2"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.messages_merged, 1);
        assert_eq!(msgs.len(), 2);
        assert!(msgs[1].content.contains("part 1"));
        assert!(msgs[1].content.contains("part 2"));
    }

    #[test]
    fn test_merge_keeps_tool_calls() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_msg("part 1"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.messages_merged, 1);
        assert_eq!(stats.synthetic_results_inserted, 0);
        assert_eq!(msgs.len(), 3);
        assert!(msgs[1].content.contains("part 1"));
        assert_eq!(msgs[1].tool_calls.len(), 1);
        assert_eq!(msgs[1].tool_calls[0].id, "call_1");
    }

    #[test]
    fn test_no_merge_tool_messages() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result 1"),
            assistant_with_tool_call("call_2", "file_read"),
            tool_result_msg("call_2", "result 2"),
            assistant_msg("done"),
        ];

        let stats = repair_session(&mut msgs);
        assert_eq!(stats.messages_merged, 0);
        assert!(!stats.any_repairs());
        assert_eq!(msgs.len(), 6);
    }

    #[test]
    fn test_clean_session_no_repairs() {
        let mut msgs = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result"),
            assistant_msg("done"),
        ];

        let stats = repair_session(&mut msgs);
        assert!(!stats.any_repairs());
        assert_eq!(msgs.len(), 4);
    }

    #[test]
    fn test_idempotent() {
        let mut msgs = vec![
            user_msg("hello"),
            empty_msg(Role::Assistant),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result"),
            tool_result_msg("call_999", "orphaned"),
            user_msg("follow up"),
            user_msg("more"),
        ];

        let stats1 = repair_session(&mut msgs);
        assert!(stats1.any_repairs());
        assert_eq!(stats1.empty_removed, 1);
        assert_eq!(stats1.orphaned_results_removed, 1);
        assert_eq!(stats1.messages_merged, 1);
        assert_eq!(msgs.len(), 4);
        assert_eq!(msgs[3].role, Role::User);
        assert!(msgs[3].content.contains("follow up"));
        assert!(msgs[3].content.contains("more"));

        let snapshot = msgs.clone();
        let stats2 = repair_session(&mut msgs);
        assert!(!stats2.any_repairs());
        assert_eq!(msgs.len(), snapshot.len());
        assert_eq!(contents(&msgs), contents(&snapshot));
    }

    #[test]
    fn test_repair_stats_display() {
        let stats = RepairStats {
            empty_removed: 1,
            messages_merged: 2,
            ..Default::default()
        };
        assert_eq!(
            stats.to_string(),
            "empty_removed=1, orphaned_results=0, synthetic_inserts=0, duplicates=0, merges=2"
        );
    }
}
476}