matrixcode_core/agent/streaming.rs
1//! Agent streaming implementation.
2
3use anyhow::Result;
4use tokio::time::{Duration, sleep};
5
6use crate::event::AgentEvent;
7use crate::providers::{ChatRequest, ChatResponse, ContentBlock, StopReason, StreamEvent, Usage};
8
9use super::types::Agent;
10
11/// Wait for cancellation signal, checking periodically.
12async fn wait_for_cancel_stream(token: &crate::cancel::CancellationToken) {
13 while !token.is_cancelled() {
14 sleep(Duration::from_millis(100)).await;
15 }
16}
17
18impl Agent {
19 /// Drain any pending input messages from the channel.
20 /// Called during streaming to collect real-time appended messages.
21 pub(crate) fn drain_pending_inputs(&mut self) {
22 let inputs = self.session.drain_pending_inputs();
23 for msg in inputs {
24 log::info!(
25 "Agent received pending input: {}",
26 msg.chars().take(50).collect::<String>()
27 );
28 self.state.add_pending_input(msg);
29 }
30 }
31
32 /// Check if there are pending inputs waiting to be processed.
33 pub fn has_pending_inputs(&self) -> bool {
34 self.state.has_pending_inputs()
35 }
36
37 /// Get and clear all pending inputs.
38 pub fn take_pending_inputs(&mut self) -> Vec<String> {
39 self.state.take_pending_inputs()
40 }
41
42 /// Call provider with streaming and emit events in real-time.
43 /// Also monitors pending_input_rx for real-time message appending.
44 pub(crate) async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
45 const MAX_RETRIES: u32 = 5;
46 const RETRY_DELAY_MS: u64 = 1000;
47
48 let mut attempt = 0;
49
50 loop {
51 attempt += 1;
52 log::info!(
53 "Agent: API call attempt {} with {} messages",
54 attempt,
55 request.messages.len()
56 );
57
58 if self.session.is_cancelled() {
59 return Err(anyhow::anyhow!("Operation cancelled"));
60 }
61
62 log::info!("Agent: calling provider.chat_stream");
63 let rx_result = self.provider.chat_stream(request.clone()).await;
64 log::info!("Agent: provider.chat_stream returned");
65
66 match rx_result {
67 Ok(mut rx) => {
68 let mut response_content: Vec<ContentBlock> = Vec::new();
69 let mut current_text = String::new();
70 let mut current_thinking = String::new();
71 let mut usage = Usage {
72 input_tokens: 0,
73 output_tokens: 0,
74 cache_creation_input_tokens: 0,
75 cache_read_input_tokens: 0,
76 };
77 let mut should_retry = false;
78
79 loop {
80 // Use biased select! to prioritize stream events over pending input checks
81 // This prevents losing stream events when sleep completes first
82 let event = if let Some(token) = self.session.cancel_token() {
83 tokio::select! {
84 biased;
85
86 // Primary: receive stream event (highest priority)
87 event = rx.recv() => event,
88
89 // Cancellation signal (second priority)
90 _ = wait_for_cancel_stream(token) => {
91 return Err(anyhow::anyhow!("Operation cancelled"));
92 }
93
94 // Check for pending inputs periodically (lowest priority)
95 // Increased interval to reduce competition with stream events
96 _ = sleep(Duration::from_millis(200)) => {
97 self.drain_pending_inputs();
98 continue;
99 }
100 }
101 } else {
102 // No cancellation token, but still check pending inputs
103 tokio::select! {
104 biased;
105
106 // Primary: receive stream event (highest priority)
107 event = rx.recv() => event,
108
109 // Check for pending inputs periodically (lower priority)
110 _ = sleep(Duration::from_millis(200)) => {
111 self.drain_pending_inputs();
112 continue;
113 }
114 }
115 };
116
117 match event {
118 None => break,
119 Some(StreamEvent::FirstByte) => {}
120 Some(StreamEvent::ThinkingDelta(delta)) => {
121 // Check cancellation before emitting
122 if self.session.is_cancelled() {
123 return Err(anyhow::anyhow!("Operation cancelled"));
124 }
125 if current_thinking.is_empty() {
126 self.emit(AgentEvent::thinking_start())?;
127 }
128 current_thinking.push_str(&delta);
129 self.emit(AgentEvent::thinking_delta(delta, None))?;
130 }
131 Some(StreamEvent::TextDelta(delta)) => {
132 // Check cancellation before emitting
133 if self.session.is_cancelled() {
134 return Err(anyhow::anyhow!("Operation cancelled"));
135 }
136 if current_text.is_empty() {
137 self.emit(AgentEvent::text_start())?;
138 }
139 current_text.push_str(&delta);
140 self.emit(AgentEvent::text_delta(delta))?;
141 }
142 Some(StreamEvent::ToolUseStart { id, name }) => {
143 // Emit events for UI but don't push content blocks
144 // Content will be added from Done event's resp.content
145 if !current_thinking.is_empty() {
146 self.emit(AgentEvent::thinking_end())?;
147 current_thinking.clear();
148 }
149 if !current_text.is_empty() {
150 self.emit(AgentEvent::text_end())?;
151 current_text.clear();
152 }
153 self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
154 }
155 Some(StreamEvent::ToolInputDelta { bytes_so_far: _ }) => {}
156 Some(StreamEvent::ToolInputComplete { id, name, input }) => {
157 self.state.mark_tool_input_previewed(id.clone());
158 self.emit(AgentEvent::tool_use_start(&id, &name, Some(input)))?;
159 }
160 Some(StreamEvent::Usage { output_tokens }) => {
161 self.emit(AgentEvent::usage_with_cache(
162 0,
163 output_tokens as u64,
164 0,
165 0,
166 ))?;
167 usage.output_tokens = output_tokens;
168 }
169 Some(StreamEvent::Done(resp)) => {
170 // Check cancellation before processing final response
171 if self.session.is_cancelled() {
172 return Err(anyhow::anyhow!("Operation cancelled"));
173 }
174
175 // Final drain of pending inputs before completing
176 self.drain_pending_inputs();
177
178 // IMPORTANT: Add current_thinking/current_text to response_content FIRST
179 // before checking for duplicates from resp.content
180 // This ensures all streamed content is preserved
181 if !current_thinking.is_empty() {
182 self.emit(AgentEvent::thinking_end())?;
183 // Add to response_content with signature from resp if available
184 let signature = resp.content.iter()
185 .find_map(|b| {
186 if let ContentBlock::Thinking { thinking, signature } = b {
187 if thinking == ¤t_thinking {
188 signature.clone()
189 } else {
190 None
191 }
192 } else {
193 None
194 }
195 });
196 response_content.push(ContentBlock::Thinking {
197 thinking: current_thinking.clone(),
198 signature,
199 });
200 current_thinking.clear();
201 }
202 if !current_text.is_empty() {
203 self.emit(AgentEvent::text_end())?;
204 // Add to response_content
205 response_content.push(ContentBlock::Text {
206 text: current_text.clone(),
207 });
208 current_text.clear();
209 }
210
211 // Then add any additional blocks from final response that are NOT duplicates
212 for block in &resp.content {
213 // Smart deduplication: compare content, not entire block
214 let is_duplicate = response_content.iter().any(|b| {
215 match (b, block) {
216 // For Thinking blocks, compare thinking content only (signature may differ)
217 (
218 ContentBlock::Thinking { thinking: t1, .. },
219 ContentBlock::Thinking { thinking: t2, .. },
220 ) => t1 == t2,
221 // For Text blocks, compare text content
222 (
223 ContentBlock::Text { text: t1 },
224 ContentBlock::Text { text: t2 },
225 ) => t1 == t2,
226 // For ToolUse, compare id
227 (
228 ContentBlock::ToolUse { id: id1, .. },
229 ContentBlock::ToolUse { id: id2, .. },
230 ) => id1 == id2,
231 // For ToolResult, compare tool_use_id
232 (
233 ContentBlock::ToolResult {
234 tool_use_id: id1, ..
235 },
236 ContentBlock::ToolResult {
237 tool_use_id: id2, ..
238 },
239 ) => id1 == id2,
240 // Default: exact comparison
241 _ => b == block,
242 }
243 });
244 if !is_duplicate {
245 response_content.push(block.clone());
246 }
247 }
248 usage = resp.usage;
249 }
250 Some(StreamEvent::Error(msg)) => {
251 if attempt < MAX_RETRIES {
252 self.emit(AgentEvent::progress(
253 format!(
254 "⚠️ Stream error, retrying ({}/{}): {}",
255 attempt, MAX_RETRIES, &msg
256 ),
257 None,
258 ))?;
259 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
260 tokio::time::sleep(tokio::time::Duration::from_millis(delay))
261 .await;
262 should_retry = true;
263 break;
264 } else {
265 self.emit(AgentEvent::error(msg.clone(), None, None))?;
266 return Err(anyhow::anyhow!(
267 "Stream error after {} retries: {}",
268 MAX_RETRIES,
269 msg
270 ));
271 }
272 }
273 }
274 }
275
276 if should_retry {
277 continue;
278 }
279
280 return Ok(ChatResponse {
281 content: response_content,
282 stop_reason: StopReason::EndTurn,
283 usage,
284 });
285 }
286 Err(e) => {
287 if attempt < MAX_RETRIES {
288 let error_msg = e.to_string();
289 self.emit(AgentEvent::progress(
290 format!(
291 "⚠️ API error, retrying ({}/{}): {}",
292 attempt, MAX_RETRIES, &error_msg
293 ),
294 None,
295 ))?;
296 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
297 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
298 } else {
299 return Err(anyhow::anyhow!(
300 "API error after {} retries: {}",
301 MAX_RETRIES,
302 e
303 ));
304 }
305 }
306 }
307 }
308 }
309}