claude_agent/client/
recovery.rs1use std::time::Instant;
4
5use crate::types::{ContentBlock, Message, Role, ThinkingBlock};
6
/// Thinking-block content accumulated while the block is still streaming.
#[derive(Clone, Debug)]
struct ThinkingBuffer {
    /// Concatenation of every thinking fragment received so far.
    thinking: String,
    /// Concatenation of signature fragments; `None` until the first one arrives.
    signature: Option<String>,
}
12
/// A tool-use block whose JSON input is still arriving in fragments.
#[derive(Clone, Debug)]
struct ToolUseBuffer {
    /// Identifier of the tool-use block.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Raw JSON input fragments concatenated so far; may be incomplete.
    partial_json: String,
}
19
20#[derive(Debug, Clone, Default)]
21pub struct StreamRecoveryState {
22 completed_blocks: Vec<ContentBlock>,
23 pending_text: Option<String>,
24 pending_thinking: Option<ThinkingBuffer>,
25 pending_tool_use: Option<ToolUseBuffer>,
26 started_at: Option<Instant>,
27}
28
29impl StreamRecoveryState {
30 pub fn new() -> Self {
31 Self {
32 started_at: Some(Instant::now()),
33 ..Default::default()
34 }
35 }
36
37 pub fn append_text(&mut self, text: &str) {
38 self.pending_text
39 .get_or_insert_with(String::new)
40 .push_str(text);
41 }
42
43 pub fn append_thinking(&mut self, thinking: &str) {
44 match &mut self.pending_thinking {
45 Some(buf) => buf.thinking.push_str(thinking),
46 None => {
47 self.pending_thinking = Some(ThinkingBuffer {
48 thinking: thinking.to_string(),
49 signature: None,
50 });
51 }
52 }
53 }
54
55 pub fn append_signature(&mut self, signature: &str) {
56 if let Some(buf) = &mut self.pending_thinking {
57 buf.signature
58 .get_or_insert_with(String::new)
59 .push_str(signature);
60 }
61 }
62
63 pub fn start_tool_use(&mut self, id: String, name: String) {
64 self.pending_tool_use = Some(ToolUseBuffer {
65 id,
66 name,
67 partial_json: String::new(),
68 });
69 }
70
71 pub fn append_tool_json(&mut self, json: &str) {
72 if let Some(buf) = &mut self.pending_tool_use {
73 buf.partial_json.push_str(json);
74 }
75 }
76
77 pub fn complete_text_block(&mut self) {
78 if let Some(text) = self.pending_text.take()
79 && !text.is_empty()
80 {
81 self.completed_blocks.push(ContentBlock::Text {
82 text,
83 citations: None,
84 cache_control: None,
85 });
86 }
87 }
88
89 pub fn complete_thinking_block(&mut self) {
90 if let Some(buf) = self.pending_thinking.take()
91 && !buf.thinking.is_empty()
92 {
93 self.completed_blocks
94 .push(ContentBlock::Thinking(ThinkingBlock {
95 thinking: buf.thinking,
96 signature: buf.signature.unwrap_or_default(),
97 }));
98 }
99 }
100
101 pub fn complete_tool_use_block(&mut self) -> Option<crate::types::ToolUseBlock> {
102 let buf = self.pending_tool_use.take()?;
103 let input = match serde_json::from_str(&buf.partial_json) {
104 Ok(v) => v,
105 Err(e) => {
106 tracing::warn!(
107 tool_name = %buf.name,
108 tool_id = %buf.id,
109 partial_json_len = buf.partial_json.len(),
110 error = %e,
111 "Stream recovery: failed to parse tool JSON, using empty object"
112 );
113 serde_json::Value::Object(serde_json::Map::new())
114 }
115 };
116 let tool_use = crate::types::ToolUseBlock {
117 id: buf.id,
118 name: buf.name,
119 input,
120 };
121 self.completed_blocks
122 .push(ContentBlock::ToolUse(tool_use.clone()));
123 Some(tool_use)
124 }
125
126 pub fn build_continuation_messages(&self, original: &[Message]) -> Vec<Message> {
127 let mut messages = original.to_vec();
128 let mut content = self.completed_blocks.clone();
129
130 if let Some(text) = &self.pending_text
131 && !text.is_empty()
132 {
133 content.push(ContentBlock::Text {
134 text: text.clone(),
135 citations: None,
136 cache_control: None,
137 });
138 }
139
140 if let Some(buf) = &self.pending_thinking
141 && !buf.thinking.is_empty()
142 {
143 content.push(ContentBlock::Thinking(ThinkingBlock {
144 thinking: buf.thinking.clone(),
145 signature: buf.signature.clone().unwrap_or_default(),
146 }));
147 }
148
149 if let Some(buf) = &self.pending_tool_use {
150 let input = match serde_json::from_str(&buf.partial_json) {
151 Ok(v) => v,
152 Err(e) => {
153 tracing::warn!(
154 tool_name = %buf.name,
155 tool_id = %buf.id,
156 partial_json_len = buf.partial_json.len(),
157 error = %e,
158 "Stream continuation: failed to parse partial tool JSON, using empty object"
159 );
160 serde_json::Value::Object(serde_json::Map::new())
161 }
162 };
163 content.push(ContentBlock::ToolUse(crate::types::ToolUseBlock {
164 id: buf.id.clone(),
165 name: buf.name.clone(),
166 input,
167 }));
168 }
169
170 if !content.is_empty() {
171 messages.push(Message {
172 role: Role::Assistant,
173 content,
174 });
175 }
176
177 messages
178 }
179
180 pub fn is_recoverable(&self) -> bool {
181 !self.completed_blocks.is_empty()
182 || self.pending_text.is_some()
183 || self.pending_thinking.is_some()
184 || self.pending_tool_use.is_some()
185 }
186
187 pub fn elapsed(&self) -> Option<std::time::Duration> {
188 self.started_at.map(|t| t.elapsed())
189 }
190
191 pub fn completed_blocks(&self) -> &[ContentBlock] {
192 &self.completed_blocks
193 }
194
195 pub fn has_pending(&self) -> bool {
196 self.pending_text.is_some()
197 || self.pending_thinking.is_some()
198 || self.pending_tool_use.is_some()
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_state() {
        let state = StreamRecoveryState::new();
        assert!(!state.is_recoverable());
        assert!(!state.has_pending());
        assert!(state.completed_blocks().is_empty());
        assert!(state.elapsed().is_some());
    }

    #[test]
    fn test_default_state_has_no_start_time() {
        assert!(StreamRecoveryState::default().elapsed().is_none());
    }

    #[test]
    fn test_text_accumulation() {
        let mut state = StreamRecoveryState::new();
        state.append_text("Hello");
        state.append_text(" World");
        assert!(state.has_pending());
        state.complete_text_block();

        assert!(state.is_recoverable());
        assert!(!state.has_pending());
        let [ContentBlock::Text { text, .. }] = state.completed_blocks() else {
            panic!("expected one text block, got {:?}", state.completed_blocks());
        };
        assert_eq!(text, "Hello World");
    }

    #[test]
    fn test_empty_text_is_not_completed() {
        let mut state = StreamRecoveryState::new();
        state.append_text("");
        state.complete_text_block();

        assert!(state.completed_blocks().is_empty());
        assert!(!state.is_recoverable());
    }

    #[test]
    fn test_thinking_accumulation() {
        let mut state = StreamRecoveryState::new();
        state.append_thinking("Let me");
        state.append_thinking(" think");
        state.append_signature("sig");
        state.append_signature("123");
        state.complete_thinking_block();

        assert!(state.is_recoverable());
        let [ContentBlock::Thinking(block)] = state.completed_blocks() else {
            panic!("expected one thinking block, got {:?}", state.completed_blocks());
        };
        assert_eq!(block.thinking, "Let me think");
        assert_eq!(block.signature, "sig123");
    }

    #[test]
    fn test_signature_without_thinking_is_ignored() {
        let mut state = StreamRecoveryState::new();
        state.append_signature("orphan");
        assert!(!state.has_pending());

        state.complete_thinking_block();
        assert!(state.completed_blocks().is_empty());
    }

    #[test]
    fn test_tool_use_accumulation() {
        let mut state = StreamRecoveryState::new();
        state.start_tool_use("tool_1".into(), "search".into());
        // The two fragments must concatenate to valid JSON: {"query":"test"}.
        state.append_tool_json(r#"{"query":"#);
        state.append_tool_json(r#""test"}"#);
        let tool_use = state
            .complete_tool_use_block()
            .expect("a tool use was pending");

        assert_eq!(tool_use.id, "tool_1");
        assert_eq!(tool_use.name, "search");
        assert_eq!(tool_use.input, serde_json::json!({"query": "test"}));

        assert!(state.is_recoverable());
        let [ContentBlock::ToolUse(recorded)] = state.completed_blocks() else {
            panic!("expected one tool-use block, got {:?}", state.completed_blocks());
        };
        assert_eq!(recorded.id, "tool_1");
        assert_eq!(recorded.input, tool_use.input);
    }

    #[test]
    fn test_tool_use_with_malformed_json_falls_back_to_empty_object() {
        let mut state = StreamRecoveryState::new();
        state.start_tool_use("tool_2".into(), "search".into());
        state.append_tool_json(r#"{"query":"#);
        let tool_use = state
            .complete_tool_use_block()
            .expect("a tool use was pending");

        assert_eq!(tool_use.input, serde_json::json!({}));
    }

    #[test]
    fn test_tool_use_without_input_yields_empty_object() {
        let mut state = StreamRecoveryState::new();
        state.start_tool_use("tool_3".into(), "list".into());
        let tool_use = state
            .complete_tool_use_block()
            .expect("a tool use was pending");

        assert_eq!(tool_use.input, serde_json::json!({}));
    }

    #[test]
    fn test_complete_tool_use_without_pending_returns_none() {
        let mut state = StreamRecoveryState::new();
        assert!(state.complete_tool_use_block().is_none());
        assert!(state.completed_blocks().is_empty());
    }

    #[test]
    fn test_continuation_messages() {
        let mut state = StreamRecoveryState::new();
        state.append_text("Partial response");

        let original = vec![Message::user("Hello")];
        let continued = state.build_continuation_messages(&original);

        assert_eq!(continued.len(), 2);
        assert_eq!(continued[1].role, Role::Assistant);
        let [ContentBlock::Text { text, .. }] = continued[1].content.as_slice() else {
            panic!("unexpected continuation content: {:?}", continued[1].content);
        };
        assert_eq!(text, "Partial response");
        // Building a continuation must not consume the pending content.
        assert!(state.has_pending());
    }

    #[test]
    fn test_continuation_includes_completed_then_pending_blocks() {
        let mut state = StreamRecoveryState::new();
        state.append_thinking("plan");
        state.append_signature("sig");
        state.complete_thinking_block();
        state.start_tool_use("tool_1".into(), "search".into());
        state.append_tool_json(r#"{"query":"#);

        let continued = state.build_continuation_messages(&[Message::user("Hello")]);

        assert_eq!(continued.len(), 2);
        let [ContentBlock::Thinking(thinking), ContentBlock::ToolUse(tool_use)] =
            continued[1].content.as_slice()
        else {
            panic!("unexpected continuation content: {:?}", continued[1].content);
        };
        assert_eq!(thinking.thinking, "plan");
        assert_eq!(thinking.signature, "sig");
        assert_eq!(tool_use.id, "tool_1");
        assert_eq!(tool_use.input, serde_json::json!({}));
    }

    #[test]
    fn test_continuation_without_content_returns_original() {
        let state = StreamRecoveryState::new();
        let continued = state.build_continuation_messages(&[Message::user("Hello")]);
        assert_eq!(continued.len(), 1);
    }
}