Skip to main content

deepstrike_core/runtime/
repair.rs

1use crate::context::text::truncate_with_suffix;
2use crate::runtime::session::{ProviderReplay, SessionEvent};
3use crate::types::message::{Content, ContentPart, Message, Role, ToolCall};
4
5/// Sanitize text for recovery paths: ensure valid UTF-8 and apply an optional
6/// byte cap derived from the caller's context config. When `max_bytes` is 0
7/// no cap is applied.
8pub fn sanitize_recovery_text(text: &str) -> String {
9    sanitize_recovery_text_bounded(text, 0)
10}
11
12pub fn sanitize_recovery_text_bounded(text: &str, max_bytes: usize) -> String {
13    if text.is_empty() {
14        return String::new();
15    }
16    if max_bytes > 0 && text.len() > max_bytes {
17        return truncate_with_suffix(text, max_bytes, "… [replay truncated]");
18    }
19    text.to_owned()
20}
21
/// Rough token estimate used to backfill a missing `token_count`.
///
/// Heuristic: one token per four Unicode scalar values (`char`s), never less
/// than 1. Counting chars rather than bytes keeps multi-byte UTF-8 text (such
/// as CJK) from being weighted by its encoded width.
fn estimate_token_count(text: &str) -> u32 {
    // Divide in `usize` first, then saturate: a plain `as u32` on a count above
    // `u32::MAX` would silently wrap to a tiny number.
    let estimate = text.chars().count() / 4;
    let estimate = estimate.min(u32::MAX as usize) as u32;
    estimate.max(1)
}
26
27fn normalize_assistant_message_with_cap(message: &mut Message, max_bytes: usize) {
28    if message.token_count.is_none() {
29        message.token_count = Some(estimate_token_count(
30            message.content.as_text().unwrap_or(""),
31        ));
32    }
33    if let Content::Text(text) = &mut message.content {
34        *text = sanitize_recovery_text_bounded(text, max_bytes);
35    }
36}
37
38/// Normalize a single `LlmCompleted` for recovery (message fields only).
39///
40/// Provider-neutral: the stored `provider_replay` envelope is left untouched.
41/// The core never synthesizes a protocol-specific replay shape — legacy
42/// reconstruction is the responsibility of the target provider in the SDK.
43pub fn repair_llm_completed(message: &mut Message, provider_replay: &mut Option<ProviderReplay>) {
44    repair_llm_completed_with_cap(message, provider_replay, 0);
45}
46
47pub fn repair_llm_completed_with_cap(
48    message: &mut Message,
49    _provider_replay: &mut Option<ProviderReplay>,
50    max_bytes: usize,
51) {
52    normalize_assistant_message_with_cap(message, max_bytes);
53}
54
55/// Repair event log entries in place for recovery minimum set completeness.
56pub fn repair_events(events: Vec<SessionEvent>) -> Vec<SessionEvent> {
57    repair_events_with_cap(events, 0)
58}
59
60pub fn repair_events_with_cap(events: Vec<SessionEvent>, max_bytes: usize) -> Vec<SessionEvent> {
61    events
62        .into_iter()
63        .map(|mut event| {
64            if let SessionEvent::LlmCompleted {
65                ref mut message,
66                ref mut provider_replay,
67                ..
68            } = event
69            {
70                repair_llm_completed_with_cap(message, provider_replay, max_bytes);
71            }
72            event
73        })
74        .collect()
75}
76
77/// Pending tool calls after the last assistant turn in preloaded history.
78pub fn pending_tool_calls_from_messages(messages: &[Message]) -> Vec<ToolCall> {
79    let Some(assistant_idx) = messages
80        .iter()
81        .rposition(|m| m.role == Role::Assistant && !m.tool_calls.is_empty())
82    else {
83        return Vec::new();
84    };
85
86    let assistant = &messages[assistant_idx];
87    let mut completed = std::collections::HashSet::new();
88    for msg in &messages[assistant_idx + 1..] {
89        if msg.role != Role::Tool {
90            continue;
91        }
92        if let Content::Parts(parts) = &msg.content {
93            for part in parts {
94                if let ContentPart::ToolResult { call_id, .. } = part {
95                    completed.insert(call_id.clone());
96                }
97            }
98        }
99    }
100
101    assistant
102        .tool_calls
103        .iter()
104        .filter(|tc| !completed.contains(&tc.id))
105        .cloned()
106        .collect()
107}
108
109/// Reconstructs full messages from a sequence of events.
110/// For `SessionEvent::Compressed` events:
111/// 1. If `archive_ref` is present, it attempts to load the messages using `load_archive`.
112/// 2. If loading succeeds, the reconstructed messages are appended to the history.
113/// 3. If loading fails (returns a `ContextFault::MissingArchive` or another error),
114///    or if `archive_ref` is `None`, it falls back to the embedded `summary` in the `Compressed` event (if present)
115///    as a system message `[Compressed context: turn {turn}]\n{summary}`.
116pub fn reconstruct_messages_with_fallback<F>(
117    events: &[SessionEvent],
118    _session_id: &str,
119    max_bytes: usize,
120    mut load_archive: F,
121) -> Vec<Message>
122where
123    F: FnMut(&str) -> Result<Vec<Message>, crate::context::snapshot::ContextFault>,
124{
125    let mut messages = Vec::new();
126    for event in events {
127        match event {
128            SessionEvent::RunStarted { goal, criteria, .. } => {
129                let user_text = if criteria.is_empty() {
130                    goal.clone()
131                } else {
132                    format!(
133                        "{}\n\nCriteria:\n{}",
134                        goal,
135                        criteria
136                            .iter()
137                            .enumerate()
138                            .map(|(i, c)| format!("{}. {}", i + 1, c))
139                            .collect::<Vec<_>>()
140                            .join("\n")
141                    )
142                };
143                messages.push(Message {
144                    role: Role::User,
145                    content: Content::Text(user_text),
146                    tool_calls: vec![],
147                    token_count: None,
148                });
149            }
150            SessionEvent::LlmCompleted { message, .. } => {
151                let mut msg = message.clone();
152                if let Content::Text(text) = &mut msg.content {
153                    *text = sanitize_recovery_text_bounded(text, max_bytes);
154                }
155                messages.push(msg);
156            }
157            SessionEvent::ToolCompleted { results, .. } => {
158                for r in results {
159                    let output = match &r.output {
160                        Content::Text(t) => sanitize_recovery_text_bounded(t, max_bytes),
161                        Content::Parts(_) => String::new(),
162                    };
163                    messages.push(Message {
164                        role: Role::Tool,
165                        content: Content::Parts(vec![ContentPart::ToolResult {
166                            call_id: r.call_id.clone(),
167                            output,
168                            is_error: r.is_error,
169                        }]),
170                        tool_calls: vec![],
171                        token_count: r.token_count,
172                    });
173                }
174            }
175            SessionEvent::Compressed {
176                turn,
177                summary,
178                archive_ref,
179                ..
180            } => {
181                let mut loaded_successfully = false;
182                if let Some(ref_str) = archive_ref {
183                    if !ref_str.is_empty() {
184                        match load_archive(ref_str) {
185                            Ok(archived_msgs) => {
186                                for mut msg in archived_msgs {
187                                    if let Content::Text(text) = &mut msg.content {
188                                        *text = sanitize_recovery_text_bounded(text, max_bytes);
189                                    }
190                                    messages.push(msg);
191                                }
192                                loaded_successfully = true;
193                            }
194                            Err(_) => {
195                                // Loader failed (e.g. MissingArchive). We degrade and fallback.
196                            }
197                        }
198                    }
199                }
200
201                if !loaded_successfully {
202                    if let Some(sum) = summary {
203                        let system_text = format!("[Compressed context: turn {}]\n{}", turn, sum);
204                        messages.push(Message {
205                            role: Role::System,
206                            content: Content::Text(system_text),
207                            tool_calls: vec![],
208                            token_count: None,
209                        });
210                    }
211                }
212            }
213            SessionEvent::Rollbacked { checkpoint_history_len, .. } => {
214                messages.truncate(*checkpoint_history_len as usize);
215            }
216            _ => {}
217        }
218    }
219    messages
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use compact_str::CompactString;

    #[test]
    fn repair_does_not_synthesize_provider_replay_for_tool_turns() {
        let mut message = Message {
            role: Role::Assistant,
            content: Content::Text("checking".into()),
            tool_calls: vec![ToolCall {
                id: CompactString::new("c1"),
                name: CompactString::new("ping"),
                arguments: serde_json::json!({}),
            }],
            token_count: None,
        };
        let mut replay: Option<ProviderReplay> = None;
        repair_llm_completed(&mut message, &mut replay);
        // Provider-neutral: no fabricated native_blocks.
        assert!(replay.is_none());
        // Message is still normalized (token count backfilled).
        assert!(message.token_count.is_some());
    }

    #[test]
    fn repair_passes_stored_replay_through() {
        let mut message = Message {
            role: Role::Assistant,
            content: Content::Text("x".into()),
            tool_calls: vec![],
            token_count: Some(1),
        };
        let mut replay = Some(ProviderReplay {
            native_blocks: None,
            reasoning_content: Some("trace".into()),
            extra: serde_json::Map::new(),
        });
        repair_llm_completed(&mut message, &mut replay);
        assert_eq!(
            replay.as_ref().and_then(|r| r.reasoning_content.as_deref()),
            Some("trace")
        );
    }

    #[test]
    fn provider_replay_round_trips_unknown_envelope_fields() {
        let json = serde_json::json!({
            "schema_version": 2,
            "provider": "deepseek",
            "protocol": "openai-chat",
            "model": "deepseek-v4-flash",
            "reasoning_content": "trace",
            "reasoning_details": [{"type": "reasoning.text", "text": "trace"}],
            "tool_calls": [{"id": "c1"}]
        });
        let replay: ProviderReplay = serde_json::from_value(json.clone()).expect("parse");
        assert_eq!(replay.reasoning_content.as_deref(), Some("trace"));
        assert_eq!(replay.extra["provider"], "deepseek");
        assert_eq!(replay.extra["protocol"], "openai-chat");
        // Re-serialize: the envelope is preserved verbatim.
        assert_eq!(serde_json::to_value(&replay).expect("serialize"), json);
    }

    #[test]
    fn reconstruct_ignores_categorized_kernel_os_events() {
        use crate::runtime::event_log::KernelEventCategory;
        use crate::runtime::session::SessionEvent;

        let events = vec![
            SessionEvent::RunStarted {
                run_id: "r1".into(),
                goal: "g".into(),
                criteria: vec![],
                agent_id: None,
                system_prompt: None,
            },
            SessionEvent::PageOut {
                turn: 1,
                category: Some(KernelEventCategory::Mm),
                primitive: None,
                action: Some("auto_compact".into()),
                summary: Some("sum".into()),
                tier_hint: Some("durable".into()),
                message_count: 3,
            },
            SessionEvent::SignalDisposed {
                turn: 1,
                category: Some(KernelEventCategory::Ipc),
                primitive: None,
                signal_id: "sig-1".into(),
                disposition: "queue".into(),
                queue_depth: 1,
            },
        ];
        let messages = reconstruct_messages_with_fallback(&events, "s1", 0, |_| {
            Err(crate::context::snapshot::ContextFault::MissingArchive {
                session_id: "s1".into(),
                seq: 0,
            })
        });
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, Role::User);
    }

    #[test]
    fn sanitize_recovery_text_bounded_respects_cjk_boundary() {
        // "你" is 3 bytes in UTF-8; sweeping several caps guarantees some of
        // them land in the middle of a character whatever the suffix budget is.
        let text = "你".repeat(20_000);
        for cap in 298..=302 {
            let out = sanitize_recovery_text_bounded(&text, cap);
            let kept = out
                .strip_suffix("… [replay truncated]")
                .expect("truncation marker appended");
            // A cut inside a multi-byte char would panic or leave replacement
            // characters behind, so every kept char must still be '你'.
            assert!(kept.chars().all(|c| c == '你'), "cap {cap}: {kept:?}");
            assert!(kept.len() < text.len());
        }
    }

    #[test]
    fn pending_tool_calls_excludes_answered_calls() {
        let call = |id: &str| ToolCall {
            id: CompactString::new(id),
            name: CompactString::new("ping"),
            arguments: serde_json::json!({}),
        };
        let messages = vec![
            Message {
                role: Role::Assistant,
                content: Content::Text("run".into()),
                tool_calls: vec![call("c1"), call("c2")],
                token_count: None,
            },
            Message {
                role: Role::Tool,
                content: Content::Parts(vec![ContentPart::ToolResult {
                    call_id: "c1".into(),
                    output: "pong".into(),
                    is_error: false,
                }]),
                tool_calls: vec![],
                token_count: None,
            },
        ];
        let pending = pending_tool_calls_from_messages(&messages);
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].id.as_str(), "c2");

        assert!(pending_tool_calls_from_messages(&[]).is_empty());
    }
}