matrixcode_core/agent/streaming.rs
1//! Agent streaming implementation.
2
3use anyhow::Result;
4use tokio::time::{Duration, sleep};
5
6use crate::event::AgentEvent;
7use crate::providers::{ChatRequest, ChatResponse, ContentBlock, StopReason, StreamEvent, Usage};
8
9use super::types::Agent;
10
11/// Wait for cancellation signal, checking periodically.
12async fn wait_for_cancel_stream(token: &crate::cancel::CancellationToken) {
13 while !token.is_cancelled() {
14 sleep(Duration::from_millis(100)).await;
15 }
16}
17
18impl Agent {
19 /// Drain any pending input messages from the channel.
20 /// Called during streaming to collect real-time appended messages.
21 pub(crate) fn drain_pending_inputs(&mut self) {
22 if let Some(rx) = &mut self.pending_input_rx {
23 while let Ok(msg) = rx.try_recv() {
24 log::info!(
25 "Agent received pending input: {}",
26 msg.chars().take(50).collect::<String>()
27 );
28 self.pending_inputs.push(msg);
29 }
30 }
31 }
32
33 /// Check if there are pending inputs waiting to be processed.
34 pub fn has_pending_inputs(&self) -> bool {
35 !self.pending_inputs.is_empty()
36 }
37
38 /// Get and clear all pending inputs.
39 pub fn take_pending_inputs(&mut self) -> Vec<String> {
40 let inputs = self.pending_inputs.clone();
41 self.pending_inputs.clear();
42 inputs
43 }
44
45 /// Call provider with streaming and emit events in real-time.
46 /// Also monitors pending_input_rx for real-time message appending.
47 pub(crate) async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
48 const MAX_RETRIES: u32 = 5;
49 const RETRY_DELAY_MS: u64 = 1000;
50
51 let mut attempt = 0;
52
53 loop {
54 attempt += 1;
55 log::info!(
56 "Agent: API call attempt {} with {} messages",
57 attempt,
58 request.messages.len()
59 );
60
61 if let Some(token) = &self.cancel_token
62 && token.is_cancelled()
63 {
64 return Err(anyhow::anyhow!("Operation cancelled"));
65 }
66
67 log::info!("Agent: calling provider.chat_stream");
68 let rx_result = self.provider.chat_stream(request.clone()).await;
69 log::info!("Agent: provider.chat_stream returned");
70
71 match rx_result {
72 Ok(mut rx) => {
73 let mut response_content: Vec<ContentBlock> = Vec::new();
74 let mut current_text = String::new();
75 let mut current_thinking = String::new();
76 let mut usage = Usage {
77 input_tokens: 0,
78 output_tokens: 0,
79 cache_creation_input_tokens: 0,
80 cache_read_input_tokens: 0,
81 };
82 let mut should_retry = false;
83
84 loop {
85 // Use select! with cancellation and pending input checks
86 let event = if let Some(token) = &self.cancel_token {
87 tokio::select! {
88 // Primary: receive stream event
89 event = rx.recv() => event,
90 // Check for pending inputs periodically
91 _ = sleep(Duration::from_millis(50)) => {
92 self.drain_pending_inputs();
93 continue;
94 }
95 // Cancellation signal
96 _ = wait_for_cancel_stream(token) => {
97 return Err(anyhow::anyhow!("Operation cancelled"));
98 }
99 }
100 } else {
101 // No cancellation token, but still check pending inputs
102 tokio::select! {
103 event = rx.recv() => event,
104 _ = sleep(Duration::from_millis(50)) => {
105 self.drain_pending_inputs();
106 continue;
107 }
108 }
109 };
110
111 match event {
112 None => break,
113 Some(StreamEvent::FirstByte) => {}
114 Some(StreamEvent::ThinkingDelta(delta)) => {
115 // Check cancellation before emitting
116 if let Some(token) = &self.cancel_token
117 && token.is_cancelled()
118 {
119 return Err(anyhow::anyhow!("Operation cancelled"));
120 }
121 if current_thinking.is_empty() {
122 self.emit(AgentEvent::thinking_start())?;
123 }
124 current_thinking.push_str(&delta);
125 self.emit(AgentEvent::thinking_delta(delta, None))?;
126 }
127 Some(StreamEvent::TextDelta(delta)) => {
128 // Check cancellation before emitting
129 if let Some(token) = &self.cancel_token
130 && token.is_cancelled()
131 {
132 return Err(anyhow::anyhow!("Operation cancelled"));
133 }
134 if current_text.is_empty() {
135 self.emit(AgentEvent::text_start())?;
136 }
137 current_text.push_str(&delta);
138 self.emit(AgentEvent::text_delta(delta))?;
139 }
140 Some(StreamEvent::ToolUseStart { id, name }) => {
141 // Emit events for UI but don't push content blocks
142 // Content will be added from Done event's resp.content
143 if !current_thinking.is_empty() {
144 self.emit(AgentEvent::thinking_end())?;
145 // Don't push - will be added from resp.content
146 }
147 if !current_text.is_empty() {
148 self.emit(AgentEvent::text_end())?;
149 // Don't push - will be added from resp.content
150 }
151 self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
152 }
153 Some(StreamEvent::ToolInputDelta { bytes_so_far: _ }) => {}
154 Some(StreamEvent::ToolInputComplete { id, name, input }) => {
155 self.previewed_tool_inputs.insert(id.clone());
156 self.emit(AgentEvent::tool_use_start(&id, &name, Some(input)))?;
157 }
158 Some(StreamEvent::Usage { output_tokens }) => {
159 self.emit(AgentEvent::usage_with_cache(
160 0,
161 output_tokens as u64,
162 0,
163 0,
164 ))?;
165 usage.output_tokens = output_tokens;
166 }
167 Some(StreamEvent::Done(resp)) => {
168 // Check cancellation before processing final response
169 if let Some(token) = &self.cancel_token
170 && token.is_cancelled()
171 {
172 return Err(anyhow::anyhow!("Operation cancelled"));
173 }
174
175 // Final drain of pending inputs before completing
176 self.drain_pending_inputs();
177
178 // Don't add current_thinking/current_text here - use resp.content directly
179 // This avoids duplicates since resp.content contains everything
180 // Just emit events for UI updates if we have pending content
181 if !current_thinking.is_empty() {
182 self.emit(AgentEvent::thinking_end())?;
183 // Don't push to response_content - will be added from resp.content
184 }
185 if !current_text.is_empty() {
186 self.emit(AgentEvent::text_end())?;
187 // Don't push to response_content - will be added from resp.content
188 }
189
190 // Add all blocks from final response with smart deduplication
191 for block in &resp.content {
192 // Smart deduplication: compare content, not entire block
193 let is_duplicate = response_content.iter().any(|b| {
194 match (b, block) {
195 // For Thinking blocks, compare thinking content only (signature may differ)
196 (
197 ContentBlock::Thinking { thinking: t1, .. },
198 ContentBlock::Thinking { thinking: t2, .. },
199 ) => t1 == t2,
200 // For Text blocks, compare text content
201 (
202 ContentBlock::Text { text: t1 },
203 ContentBlock::Text { text: t2 },
204 ) => t1 == t2,
205 // For ToolUse, compare id
206 (
207 ContentBlock::ToolUse { id: id1, .. },
208 ContentBlock::ToolUse { id: id2, .. },
209 ) => id1 == id2,
210 // For ToolResult, compare tool_use_id
211 (
212 ContentBlock::ToolResult {
213 tool_use_id: id1, ..
214 },
215 ContentBlock::ToolResult {
216 tool_use_id: id2, ..
217 },
218 ) => id1 == id2,
219 // Default: exact comparison
220 _ => b == block,
221 }
222 });
223 if !is_duplicate {
224 response_content.push(block.clone());
225 }
226 }
227 usage = resp.usage;
228 }
229 Some(StreamEvent::Error(msg)) => {
230 if attempt < MAX_RETRIES {
231 self.emit(AgentEvent::progress(
232 format!(
233 "⚠️ Stream error, retrying ({}/{}): {}",
234 attempt, MAX_RETRIES, &msg
235 ),
236 None,
237 ))?;
238 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
239 tokio::time::sleep(tokio::time::Duration::from_millis(delay))
240 .await;
241 should_retry = true;
242 break;
243 } else {
244 self.emit(AgentEvent::error(msg.clone(), None, None))?;
245 return Err(anyhow::anyhow!(
246 "Stream error after {} retries: {}",
247 MAX_RETRIES,
248 msg
249 ));
250 }
251 }
252 }
253 }
254
255 if should_retry {
256 continue;
257 }
258
259 return Ok(ChatResponse {
260 content: response_content,
261 stop_reason: StopReason::EndTurn,
262 usage,
263 });
264 }
265 Err(e) => {
266 if attempt < MAX_RETRIES {
267 let error_msg = e.to_string();
268 self.emit(AgentEvent::progress(
269 format!(
270 "⚠️ API error, retrying ({}/{}): {}",
271 attempt, MAX_RETRIES, &error_msg
272 ),
273 None,
274 ))?;
275 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
276 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
277 } else {
278 return Err(anyhow::anyhow!(
279 "API error after {} retries: {}",
280 MAX_RETRIES,
281 e
282 ));
283 }
284 }
285 }
286 }
287 }
288}