claude_agent/client/
recovery.rs

1//! Stream recovery for resumable streaming responses.
2
3use std::time::Instant;
4
5use crate::types::{ContentBlock, Message, Role, ThinkingBlock};
6
/// A thinking block that is still being streamed, built up fragment by fragment.
#[derive(Clone, Debug)]
struct ThinkingBuffer {
    /// Concatenation of every thinking fragment received so far.
    thinking: String,
    /// Concatenated signature fragments; stays `None` until the first one arrives.
    signature: Option<String>,
}
12
/// A tool-use block that is still being streamed; its input JSON arrives in pieces.
#[derive(Clone, Debug)]
struct ToolUseBuffer {
    /// Identifier supplied when the tool-use block started.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Raw input JSON concatenated from the fragments received so far. Until the
    /// block finishes this may be incomplete and therefore not valid JSON.
    partial_json: String,
}
19
20#[derive(Debug, Clone, Default)]
21pub struct StreamRecoveryState {
22    completed_blocks: Vec<ContentBlock>,
23    pending_text: Option<String>,
24    pending_thinking: Option<ThinkingBuffer>,
25    pending_tool_use: Option<ToolUseBuffer>,
26    started_at: Option<Instant>,
27}
28
29impl StreamRecoveryState {
30    pub fn new() -> Self {
31        Self {
32            started_at: Some(Instant::now()),
33            ..Default::default()
34        }
35    }
36
37    pub fn append_text(&mut self, text: &str) {
38        self.pending_text
39            .get_or_insert_with(String::new)
40            .push_str(text);
41    }
42
43    pub fn append_thinking(&mut self, thinking: &str) {
44        match &mut self.pending_thinking {
45            Some(buf) => buf.thinking.push_str(thinking),
46            None => {
47                self.pending_thinking = Some(ThinkingBuffer {
48                    thinking: thinking.to_string(),
49                    signature: None,
50                });
51            }
52        }
53    }
54
55    pub fn append_signature(&mut self, signature: &str) {
56        if let Some(buf) = &mut self.pending_thinking {
57            buf.signature
58                .get_or_insert_with(String::new)
59                .push_str(signature);
60        }
61    }
62
63    pub fn start_tool_use(&mut self, id: String, name: String) {
64        self.pending_tool_use = Some(ToolUseBuffer {
65            id,
66            name,
67            partial_json: String::new(),
68        });
69    }
70
71    pub fn append_tool_json(&mut self, json: &str) {
72        if let Some(buf) = &mut self.pending_tool_use {
73            buf.partial_json.push_str(json);
74        }
75    }
76
77    pub fn complete_text_block(&mut self) {
78        if let Some(text) = self.pending_text.take()
79            && !text.is_empty()
80        {
81            self.completed_blocks.push(ContentBlock::Text {
82                text,
83                citations: None,
84                cache_control: None,
85            });
86        }
87    }
88
89    pub fn complete_thinking_block(&mut self) {
90        if let Some(buf) = self.pending_thinking.take()
91            && !buf.thinking.is_empty()
92        {
93            self.completed_blocks
94                .push(ContentBlock::Thinking(ThinkingBlock {
95                    thinking: buf.thinking,
96                    signature: buf.signature.unwrap_or_default(),
97                }));
98        }
99    }
100
101    pub fn complete_tool_use_block(&mut self) -> Option<crate::types::ToolUseBlock> {
102        let buf = self.pending_tool_use.take()?;
103        let input = match serde_json::from_str(&buf.partial_json) {
104            Ok(v) => v,
105            Err(e) => {
106                tracing::warn!(
107                    tool_name = %buf.name,
108                    tool_id = %buf.id,
109                    partial_json_len = buf.partial_json.len(),
110                    error = %e,
111                    "Stream recovery: failed to parse tool JSON, using empty object"
112                );
113                serde_json::Value::Object(serde_json::Map::new())
114            }
115        };
116        let tool_use = crate::types::ToolUseBlock {
117            id: buf.id,
118            name: buf.name,
119            input,
120        };
121        self.completed_blocks
122            .push(ContentBlock::ToolUse(tool_use.clone()));
123        Some(tool_use)
124    }
125
126    pub fn build_continuation_messages(&self, original: &[Message]) -> Vec<Message> {
127        let mut messages = original.to_vec();
128        let mut content = self.completed_blocks.clone();
129
130        if let Some(text) = &self.pending_text
131            && !text.is_empty()
132        {
133            content.push(ContentBlock::Text {
134                text: text.clone(),
135                citations: None,
136                cache_control: None,
137            });
138        }
139
140        if let Some(buf) = &self.pending_thinking
141            && !buf.thinking.is_empty()
142        {
143            content.push(ContentBlock::Thinking(ThinkingBlock {
144                thinking: buf.thinking.clone(),
145                signature: buf.signature.clone().unwrap_or_default(),
146            }));
147        }
148
149        if let Some(buf) = &self.pending_tool_use {
150            let input = match serde_json::from_str(&buf.partial_json) {
151                Ok(v) => v,
152                Err(e) => {
153                    tracing::warn!(
154                        tool_name = %buf.name,
155                        tool_id = %buf.id,
156                        partial_json_len = buf.partial_json.len(),
157                        error = %e,
158                        "Stream continuation: failed to parse partial tool JSON, using empty object"
159                    );
160                    serde_json::Value::Object(serde_json::Map::new())
161                }
162            };
163            content.push(ContentBlock::ToolUse(crate::types::ToolUseBlock {
164                id: buf.id.clone(),
165                name: buf.name.clone(),
166                input,
167            }));
168        }
169
170        if !content.is_empty() {
171            messages.push(Message {
172                role: Role::Assistant,
173                content,
174            });
175        }
176
177        messages
178    }
179
180    pub fn is_recoverable(&self) -> bool {
181        !self.completed_blocks.is_empty()
182            || self.pending_text.is_some()
183            || self.pending_thinking.is_some()
184            || self.pending_tool_use.is_some()
185    }
186
187    pub fn elapsed(&self) -> Option<std::time::Duration> {
188        self.started_at.map(|t| t.elapsed())
189    }
190
191    pub fn completed_blocks(&self) -> &[ContentBlock] {
192        &self.completed_blocks
193    }
194
195    pub fn has_pending(&self) -> bool {
196        self.pending_text.is_some()
197            || self.pending_thinking.is_some()
198            || self.pending_tool_use.is_some()
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_state() {
        let state = StreamRecoveryState::new();
        assert!(!state.is_recoverable());
        assert!(!state.has_pending());
        assert!(state.completed_blocks().is_empty());
        assert!(state.elapsed().is_some());
    }

    #[test]
    fn test_default_state_has_no_start_time() {
        let state = StreamRecoveryState::default();
        assert!(state.elapsed().is_none());
        assert!(!state.is_recoverable());
    }

    #[test]
    fn test_text_accumulation() {
        let mut state = StreamRecoveryState::new();
        state.append_text("Hello");
        state.append_text(" World");
        state.complete_text_block();

        assert!(state.is_recoverable());
        assert!(!state.has_pending());
        assert_eq!(state.completed_blocks().len(), 1);
        match &state.completed_blocks()[0] {
            ContentBlock::Text { text, .. } => assert_eq!(text.as_str(), "Hello World"),
            other => panic!("expected a text block, got {other:?}"),
        }
    }

    #[test]
    fn test_thinking_accumulation() {
        let mut state = StreamRecoveryState::new();
        state.append_thinking("Let me think");
        state.append_signature("sig123");
        state.complete_thinking_block();

        assert!(state.is_recoverable());
        assert!(!state.has_pending());
        assert_eq!(state.completed_blocks().len(), 1);
        match &state.completed_blocks()[0] {
            ContentBlock::Thinking(block) => {
                assert_eq!(block.thinking, "Let me think");
                assert_eq!(block.signature, "sig123");
            }
            other => panic!("expected a thinking block, got {other:?}"),
        }
    }

    #[test]
    fn test_signature_without_thinking_is_ignored() {
        let mut state = StreamRecoveryState::new();
        state.append_signature("sig123");
        assert!(!state.has_pending());
        assert!(!state.is_recoverable());
    }

    #[test]
    fn test_continuation_messages() {
        let mut state = StreamRecoveryState::new();
        state.append_text("Partial response");

        let original = vec![Message::user("Hello")];
        let continued = state.build_continuation_messages(&original);

        assert_eq!(continued.len(), 2);
        assert_eq!(continued[1].role, Role::Assistant);
        assert_eq!(continued[1].content.len(), 1);
        match &continued[1].content[0] {
            ContentBlock::Text { text, .. } => assert_eq!(text.as_str(), "Partial response"),
            other => panic!("expected a text block, got {other:?}"),
        }
        // Building a continuation must not consume the pending buffer.
        assert!(state.has_pending());
    }

    #[test]
    fn test_continuation_without_output_keeps_original() {
        let state = StreamRecoveryState::new();
        let original = vec![Message::user("Hello")];
        assert_eq!(state.build_continuation_messages(&original).len(), 1);
    }

    #[test]
    fn test_continuation_includes_partial_tool_use() {
        let mut state = StreamRecoveryState::new();
        state.start_tool_use("tool_1".into(), "search".into());
        // Truncated JSON: cannot be parsed, so the input falls back to `{}`.
        state.append_tool_json(r#"{"query":"te"#);

        let continued = state.build_continuation_messages(&[Message::user("Hello")]);

        assert_eq!(continued.len(), 2);
        match &continued[1].content[..] {
            [ContentBlock::ToolUse(tool)] => {
                assert_eq!(tool.id, "tool_1");
                assert_eq!(tool.input, serde_json::json!({}));
            }
            other => panic!("expected a single tool-use block, got {other:?}"),
        }
    }

    #[test]
    fn test_tool_use_accumulation() {
        let mut state = StreamRecoveryState::new();
        state.start_tool_use("tool_1".into(), "search".into());
        // The two fragments concatenate to `{"query":"test"}`.
        state.append_tool_json(r#"{"query":""#);
        state.append_tool_json(r#"test"}"#);
        let tool = state
            .complete_tool_use_block()
            .expect("a tool-use block was pending");

        assert_eq!(tool.id, "tool_1");
        assert_eq!(tool.name, "search");
        assert_eq!(tool.input, serde_json::json!({ "query": "test" }));
        assert!(state.is_recoverable());
        assert!(!state.has_pending());
        assert_eq!(state.completed_blocks().len(), 1);
    }

    #[test]
    fn test_malformed_tool_json_falls_back_to_empty_object() {
        let mut state = StreamRecoveryState::new();
        state.start_tool_use("tool_1".into(), "search".into());
        state.append_tool_json(r#"{"query":"te"#);
        let tool = state
            .complete_tool_use_block()
            .expect("a tool-use block was pending");

        assert_eq!(tool.input, serde_json::json!({}));
        assert_eq!(state.completed_blocks().len(), 1);
    }

    #[test]
    fn test_tool_use_without_input_yields_empty_object() {
        let mut state = StreamRecoveryState::new();
        state.start_tool_use("tool_2".into(), "list".into());
        let tool = state
            .complete_tool_use_block()
            .expect("a tool-use block was pending");

        assert_eq!(tool.input, serde_json::json!({}));
    }

    #[test]
    fn test_complete_tool_use_without_pending_returns_none() {
        let mut state = StreamRecoveryState::new();
        assert!(state.complete_tool_use_block().is_none());
        assert!(state.completed_blocks().is_empty());
    }
}