1use std::collections::HashSet;
11
12use tracing::{debug, info, warn};
13
14use punch_types::{Message, Role, ToolCallResult};
15
/// Counters describing what a session repair run changed.
///
/// Every field starts at zero (via `Default`) and is filled in by the repair
/// pass responsible for it.
#[derive(Debug, Clone, Default)]
pub struct RepairStats {
    /// Messages dropped because they had no content, tool calls, or tool results.
    pub empty_removed: usize,
    /// Tool results dropped because no assistant tool call carried their id.
    pub orphaned_results_removed: usize,
    /// Synthetic error results created for tool calls that had no result.
    pub synthetic_results_inserted: usize,
    /// Tool results dropped because an earlier result already used the same id.
    pub duplicate_results_removed: usize,
    /// Messages folded into the preceding message of the same role.
    pub messages_merged: usize,
}

impl RepairStats {
    /// Returns `true` when at least one counter is non-zero.
    pub fn any_repairs(&self) -> bool {
        let counters = [
            self.empty_removed,
            self.orphaned_results_removed,
            self.synthetic_results_inserted,
            self.duplicate_results_removed,
            self.messages_merged,
        ];
        counters.iter().any(|&count| count != 0)
    }
}

impl std::fmt::Display for RepairStats {
    /// Formats all counters as a single comma-separated `key=value` line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this
        // formatter becomes a compile error rather than a silent omission.
        let Self {
            empty_removed,
            orphaned_results_removed,
            synthetic_results_inserted,
            duplicate_results_removed,
            messages_merged,
        } = self;

        write!(
            f,
            "empty_removed={}, orphaned_results={}, synthetic_inserts={}, duplicates={}, merges={}",
            empty_removed,
            orphaned_results_removed,
            synthetic_results_inserted,
            duplicate_results_removed,
            messages_merged,
        )
    }
}
55
56pub fn repair_session(messages: &mut Vec<Message>) -> RepairStats {
61 let mut stats = RepairStats::default();
62
63 remove_empty_messages(messages, &mut stats);
65
66 remove_duplicate_tool_results(messages, &mut stats);
68
69 remove_orphaned_tool_results(messages, &mut stats);
71
72 insert_synthetic_results(messages, &mut stats);
74
75 merge_consecutive_same_role(messages, &mut stats);
77
78 if stats.any_repairs() {
79 info!(repairs = %stats, "session repair completed");
80 } else {
81 debug!("session repair: no repairs needed");
82 }
83
84 stats
85}
86
87fn remove_empty_messages(messages: &mut Vec<Message>, stats: &mut RepairStats) {
89 let before = messages.len();
90
91 messages.retain(|msg| {
92 let is_empty =
93 msg.content.is_empty() && msg.tool_calls.is_empty() && msg.tool_results.is_empty();
94 !is_empty
95 });
96
97 let removed = before - messages.len();
98 if removed > 0 {
99 debug!(count = removed, "removed empty messages");
100 stats.empty_removed = removed;
101 }
102}
103
104fn remove_orphaned_tool_results(messages: &mut Vec<Message>, stats: &mut RepairStats) {
106 let tool_use_ids: HashSet<String> = messages
108 .iter()
109 .filter(|m| m.role == Role::Assistant)
110 .flat_map(|m| &m.tool_calls)
111 .map(|tc| tc.id.clone())
112 .collect();
113
114 let mut removed = 0;
115
116 for msg in messages.iter_mut() {
117 if msg.role == Role::Tool && !msg.tool_results.is_empty() {
118 let before = msg.tool_results.len();
119 msg.tool_results.retain(|tr| tool_use_ids.contains(&tr.id));
120 let delta = before - msg.tool_results.len();
121 if delta > 0 {
122 warn!(
123 count = delta,
124 "removed orphaned tool results (no matching tool_use)"
125 );
126 removed += delta;
127 }
128 }
129 }
130
131 stats.orphaned_results_removed = removed;
132
133 messages.retain(|msg| {
135 if msg.role == Role::Tool {
136 !msg.tool_results.is_empty() || !msg.content.is_empty()
137 } else {
138 true
139 }
140 });
141}
142
143fn insert_synthetic_results(messages: &mut Vec<Message>, stats: &mut RepairStats) {
145 let result_ids: HashSet<String> = messages
147 .iter()
148 .filter(|m| m.role == Role::Tool)
149 .flat_map(|m| &m.tool_results)
150 .map(|tr| tr.id.clone())
151 .collect();
152
153 let mut insertions: Vec<(usize, Vec<ToolCallResult>)> = Vec::new();
156
157 for (idx, msg) in messages.iter().enumerate() {
158 if msg.role == Role::Assistant && !msg.tool_calls.is_empty() {
159 let missing: Vec<ToolCallResult> = msg
160 .tool_calls
161 .iter()
162 .filter(|tc| !result_ids.contains(&tc.id))
163 .map(|tc| {
164 warn!(
165 tool_use_id = %tc.id,
166 tool_name = %tc.name,
167 "inserting synthetic error result for orphaned tool_use"
168 );
169 ToolCallResult {
170 id: tc.id.clone(),
171 content: format!(
172 "Error: tool execution was interrupted or result was lost (tool: {})",
173 tc.name
174 ),
175 is_error: true,
176 }
177 })
178 .collect();
179
180 if !missing.is_empty() {
181 insertions.push((idx, missing));
182 }
183 }
184 }
185
186 let mut inserted = 0;
188 for (idx, results) in insertions.into_iter().rev() {
189 let count = results.len();
190 inserted += count;
191
192 let tool_msg = Message {
193 role: Role::Tool,
194 content: String::new(),
195 tool_calls: Vec::new(),
196 tool_results: results,
197 timestamp: chrono::Utc::now(),
198 };
199
200 let insert_pos = idx + 1;
202 if insert_pos <= messages.len() {
203 messages.insert(insert_pos, tool_msg);
204 } else {
205 messages.push(tool_msg);
206 }
207 }
208
209 stats.synthetic_results_inserted = inserted;
210}
211
212fn remove_duplicate_tool_results(messages: &mut [Message], stats: &mut RepairStats) {
214 let mut seen_ids: HashSet<String> = HashSet::new();
215 let mut removed = 0;
216
217 for msg in messages.iter_mut() {
218 if msg.role == Role::Tool && !msg.tool_results.is_empty() {
219 let before = msg.tool_results.len();
220 msg.tool_results.retain(|tr| seen_ids.insert(tr.id.clone()));
221 let delta = before - msg.tool_results.len();
222 if delta > 0 {
223 debug!(count = delta, "removed duplicate tool results");
224 removed += delta;
225 }
226 }
227 }
228
229 stats.duplicate_results_removed = removed;
230}
231
232fn merge_consecutive_same_role(messages: &mut Vec<Message>, stats: &mut RepairStats) {
237 if messages.len() < 2 {
238 return;
239 }
240
241 let mut merged = 0;
242 let mut result: Vec<Message> = Vec::with_capacity(messages.len());
243
244 for msg in messages.drain(..) {
245 if let Some(last) = result.last_mut() {
246 if last.role == msg.role && (msg.role == Role::User || msg.role == Role::Assistant) {
249 if !msg.content.is_empty() {
251 if !last.content.is_empty() {
252 last.content.push('\n');
253 }
254 last.content.push_str(&msg.content);
255 }
256 last.tool_calls.extend(msg.tool_calls);
258 last.tool_results.extend(msg.tool_results);
259 last.timestamp = msg.timestamp;
261 merged += 1;
262 continue;
263 }
264 }
265 result.push(msg);
266 }
267
268 *messages = result;
269 stats.messages_merged = merged;
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::{Message, Role, ToolCall, ToolCallResult};

    /// A message of the given role with no content, tool calls, or results.
    fn empty_msg(role: Role) -> Message {
        Message {
            role,
            content: String::new(),
            tool_calls: Vec::new(),
            tool_results: Vec::new(),
            timestamp: chrono::Utc::now(),
        }
    }

    /// A plain user message.
    fn user_msg(content: &str) -> Message {
        Message::new(Role::User, content)
    }

    /// A plain assistant message.
    fn assistant_msg(content: &str) -> Message {
        Message::new(Role::Assistant, content)
    }

    /// An assistant message with no text and a single tool call.
    fn assistant_with_tool_call(tool_id: &str, tool_name: &str) -> Message {
        let call = ToolCall {
            id: tool_id.to_string(),
            name: tool_name.to_string(),
            input: serde_json::json!({}),
        };
        Message {
            tool_calls: vec![call],
            ..empty_msg(Role::Assistant)
        }
    }

    /// A tool message holding a single successful result.
    fn tool_result_msg(id: &str, content: &str) -> Message {
        let result = ToolCallResult {
            id: id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        Message {
            tool_results: vec![result],
            ..empty_msg(Role::Tool)
        }
    }

    #[test]
    fn test_remove_empty_messages() {
        let mut session = vec![
            user_msg("hello"),
            empty_msg(Role::Assistant),
            assistant_msg("world"),
        ];

        let report = repair_session(&mut session);

        assert_eq!(report.empty_removed, 1);
        assert_eq!(session.len(), 2);
    }

    #[test]
    fn test_remove_orphaned_tool_results() {
        let mut session = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "file contents"),
            // No assistant message ever issued `call_999`.
            tool_result_msg("call_999", "orphaned result"),
        ];

        let report = repair_session(&mut session);

        assert_eq!(report.orphaned_results_removed, 1);
        // The emptied tool message is removed as well.
        assert_eq!(session.len(), 3);
    }

    #[test]
    fn test_insert_synthetic_results() {
        let mut session = vec![
            user_msg("do something"),
            assistant_with_tool_call("call_1", "shell_exec"),
            // The tool result for `call_1` is missing here.
            assistant_msg("I ran the command"),
        ];

        let report = repair_session(&mut session);

        assert_eq!(report.synthetic_results_inserted, 1);
        assert_eq!(session.len(), 4);

        // The synthetic tool message sits right after the tool call.
        let synthetic = &session[2];
        assert_eq!(synthetic.role, Role::Tool);
        assert!(synthetic.tool_results[0].is_error);
        assert!(synthetic.tool_results[0].content.contains("interrupted"));
    }

    #[test]
    fn test_remove_duplicate_tool_results() {
        let mut session = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "first result"),
            tool_result_msg("call_1", "duplicate result"),
        ];

        let report = repair_session(&mut session);

        assert_eq!(report.duplicate_results_removed, 1);
    }

    #[test]
    fn test_merge_consecutive_user_messages() {
        let mut session = vec![
            user_msg("hello"),
            user_msg("world"),
            assistant_msg("hi there"),
        ];

        let report = repair_session(&mut session);

        assert_eq!(report.messages_merged, 1);
        assert_eq!(session.len(), 2);
        let merged = &session[0];
        assert!(merged.content.contains("hello"));
        assert!(merged.content.contains("world"));
    }

    #[test]
    fn test_merge_consecutive_assistant_messages() {
        let mut session = vec![
            user_msg("hello"),
            assistant_msg("part 1"),
            assistant_msg("part 2"),
        ];

        let report = repair_session(&mut session);

        assert_eq!(report.messages_merged, 1);
        assert_eq!(session.len(), 2);
        let merged = &session[1];
        assert!(merged.content.contains("part 1"));
        assert!(merged.content.contains("part 2"));
    }

    #[test]
    fn test_no_merge_tool_messages() {
        let mut session = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result 1"),
            assistant_with_tool_call("call_2", "file_read"),
            tool_result_msg("call_2", "result 2"),
            assistant_msg("done"),
        ];

        let report = repair_session(&mut session);

        assert_eq!(report.messages_merged, 0);
        // Roles alternate, so nothing is adjacent to a same-role message.
        assert_eq!(session.len(), 6);
    }

    #[test]
    fn test_clean_session_no_repairs() {
        let mut session = vec![
            user_msg("hello"),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result"),
            assistant_msg("done"),
        ];

        let report = repair_session(&mut session);

        assert!(!report.any_repairs());
        assert_eq!(session.len(), 4);
    }

    #[test]
    fn test_idempotent() {
        let mut session = vec![
            user_msg("hello"),
            empty_msg(Role::Assistant),
            assistant_with_tool_call("call_1", "file_read"),
            tool_result_msg("call_1", "result"),
            tool_result_msg("call_999", "orphaned"),
            user_msg("follow up"),
            user_msg("more"),
        ];

        let first_pass = repair_session(&mut session);
        assert!(first_pass.any_repairs());

        // A second run over the repaired session must find nothing to fix.
        let len_after_first_pass = session.len();
        let second_pass = repair_session(&mut session);
        assert!(!second_pass.any_repairs());
        assert_eq!(session.len(), len_after_first_pass);
    }
}