matrixcode_core/agent/streaming.rs
1//! Agent streaming implementation.
2
3use anyhow::Result;
4use tokio::time::{Duration, sleep};
5
6use crate::event::AgentEvent;
7use crate::providers::{ChatRequest, ChatResponse, ContentBlock, StopReason, StreamEvent, Usage};
8
9use super::types::Agent;
10
11/// Wait for cancellation signal, checking periodically.
12async fn wait_for_cancel_stream(token: &crate::cancel::CancellationToken) {
13 while !token.is_cancelled() {
14 sleep(Duration::from_millis(100)).await;
15 }
16}
17
18impl Agent {
19 /// Drain any pending input messages from the channel.
20 /// Called during streaming to collect real-time appended messages.
21 fn drain_pending_inputs(&mut self) {
22 if let Some(rx) = &mut self.pending_input_rx {
23 while let Ok(msg) = rx.try_recv() {
24 log::info!("Agent received pending input: {}", msg.chars().take(50).collect::<String>());
25 self.pending_inputs.push(msg);
26 }
27 }
28 }
29
30 /// Check if there are pending inputs waiting to be processed.
31 pub fn has_pending_inputs(&self) -> bool {
32 !self.pending_inputs.is_empty()
33 }
34
35 /// Get and clear all pending inputs.
36 pub fn take_pending_inputs(&mut self) -> Vec<String> {
37 let inputs = self.pending_inputs.clone();
38 self.pending_inputs.clear();
39 inputs
40 }
41
42 /// Call provider with streaming and emit events in real-time.
43 /// Also monitors pending_input_rx for real-time message appending.
44 pub(crate) async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
45 const MAX_RETRIES: u32 = 5;
46 const RETRY_DELAY_MS: u64 = 1000;
47
48 let mut attempt = 0;
49
50 loop {
51 attempt += 1;
52 log::info!("Agent: API call attempt {} with {} messages", attempt, request.messages.len());
53
54 if let Some(token) = &self.cancel_token
55 && token.is_cancelled()
56 {
57 return Err(anyhow::anyhow!("Operation cancelled"));
58 }
59
60 log::info!("Agent: calling provider.chat_stream");
61 let rx_result = self.provider.chat_stream(request.clone()).await;
62 log::info!("Agent: provider.chat_stream returned");
63
64 match rx_result {
65 Ok(mut rx) => {
66 let mut response_content: Vec<ContentBlock> = Vec::new();
67 let mut current_text = String::new();
68 let mut current_thinking = String::new();
69 let mut usage = Usage {
70 input_tokens: 0,
71 output_tokens: 0,
72 cache_creation_input_tokens: 0,
73 cache_read_input_tokens: 0,
74 };
75 let mut should_retry = false;
76
77 loop {
78 // Use select! with cancellation and pending input checks
79 let event = if let Some(token) = &self.cancel_token {
80 tokio::select! {
81 // Primary: receive stream event
82 event = rx.recv() => event,
83 // Check for pending inputs periodically
84 _ = sleep(Duration::from_millis(50)) => {
85 self.drain_pending_inputs();
86 continue;
87 }
88 // Cancellation signal
89 _ = wait_for_cancel_stream(token) => {
90 return Err(anyhow::anyhow!("Operation cancelled"));
91 }
92 }
93 } else {
94 // No cancellation token, but still check pending inputs
95 tokio::select! {
96 event = rx.recv() => event,
97 _ = sleep(Duration::from_millis(50)) => {
98 self.drain_pending_inputs();
99 continue;
100 }
101 }
102 };
103
104 match event {
105 None => break,
106 Some(StreamEvent::FirstByte) => {}
107 Some(StreamEvent::ThinkingDelta(delta)) => {
108 // Check cancellation before emitting
109 if let Some(token) = &self.cancel_token
110 && token.is_cancelled()
111 {
112 return Err(anyhow::anyhow!("Operation cancelled"));
113 }
114 if current_thinking.is_empty() {
115 self.emit(AgentEvent::thinking_start())?;
116 }
117 current_thinking.push_str(&delta);
118 self.emit(AgentEvent::thinking_delta(delta, None))?;
119 }
120 Some(StreamEvent::TextDelta(delta)) => {
121 // Check cancellation before emitting
122 if let Some(token) = &self.cancel_token
123 && token.is_cancelled()
124 {
125 return Err(anyhow::anyhow!("Operation cancelled"));
126 }
127 if current_text.is_empty() {
128 self.emit(AgentEvent::text_start())?;
129 }
130 current_text.push_str(&delta);
131 self.emit(AgentEvent::text_delta(delta))?;
132 }
133 Some(StreamEvent::ToolUseStart { id, name }) => {
134 // Emit events for UI but don't push content blocks
135 // Content will be added from Done event's resp.content
136 if !current_thinking.is_empty() {
137 self.emit(AgentEvent::thinking_end())?;
138 // Don't push - will be added from resp.content
139 }
140 if !current_text.is_empty() {
141 self.emit(AgentEvent::text_end())?;
142 // Don't push - will be added from resp.content
143 }
144 self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
145 }
146 Some(StreamEvent::ToolInputDelta { bytes_so_far: _ }) => {}
147 Some(StreamEvent::Usage { output_tokens }) => {
148 self.emit(AgentEvent::usage_with_cache(
149 0,
150 output_tokens as u64,
151 0,
152 0,
153 ))?;
154 usage.output_tokens = output_tokens;
155 }
156 Some(StreamEvent::Done(resp)) => {
157 // Check cancellation before processing final response
158 if let Some(token) = &self.cancel_token
159 && token.is_cancelled()
160 {
161 return Err(anyhow::anyhow!("Operation cancelled"));
162 }
163
164 // Final drain of pending inputs before completing
165 self.drain_pending_inputs();
166
167 // Don't add current_thinking/current_text here - use resp.content directly
168 // This avoids duplicates since resp.content contains everything
169 // Just emit events for UI updates if we have pending content
170 if !current_thinking.is_empty() {
171 self.emit(AgentEvent::thinking_end())?;
172 // Don't push to response_content - will be added from resp.content
173 }
174 if !current_text.is_empty() {
175 self.emit(AgentEvent::text_end())?;
176 // Don't push to response_content - will be added from resp.content
177 }
178
179 // Add all blocks from final response with smart deduplication
180 for block in &resp.content {
181 // Smart deduplication: compare content, not entire block
182 let is_duplicate = response_content.iter().any(|b| {
183 match (b, block) {
184 // For Thinking blocks, compare thinking content only (signature may differ)
185 (ContentBlock::Thinking { thinking: t1, .. }, ContentBlock::Thinking { thinking: t2, .. }) => {
186 t1 == t2
187 }
188 // For Text blocks, compare text content
189 (ContentBlock::Text { text: t1 }, ContentBlock::Text { text: t2 }) => {
190 t1 == t2
191 }
192 // For ToolUse, compare id
193 (ContentBlock::ToolUse { id: id1, .. }, ContentBlock::ToolUse { id: id2, .. }) => {
194 id1 == id2
195 }
196 // For ToolResult, compare tool_use_id
197 (ContentBlock::ToolResult { tool_use_id: id1, .. }, ContentBlock::ToolResult { tool_use_id: id2, .. }) => {
198 id1 == id2
199 }
200 // Default: exact comparison
201 _ => b == block
202 }
203 });
204 if !is_duplicate {
205 response_content.push(block.clone());
206 }
207 }
208 usage = resp.usage;
209 }
210 Some(StreamEvent::Error(msg)) => {
211 if attempt < MAX_RETRIES {
212 self.emit(AgentEvent::progress(
213 format!(
214 "⚠️ Stream error, retrying ({}/{}): {}",
215 attempt, MAX_RETRIES, &msg
216 ),
217 None,
218 ))?;
219 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
220 tokio::time::sleep(tokio::time::Duration::from_millis(delay))
221 .await;
222 should_retry = true;
223 break;
224 } else {
225 self.emit(AgentEvent::error(msg.clone(), None, None))?;
226 return Err(anyhow::anyhow!(
227 "Stream error after {} retries: {}",
228 MAX_RETRIES,
229 msg
230 ));
231 }
232 }
233 }
234 }
235
236 if should_retry {
237 continue;
238 }
239
240 return Ok(ChatResponse {
241 content: response_content,
242 stop_reason: StopReason::EndTurn,
243 usage,
244 });
245 }
246 Err(e) => {
247 if attempt < MAX_RETRIES {
248 let error_msg = e.to_string();
249 self.emit(AgentEvent::progress(
250 format!(
251 "⚠️ API error, retrying ({}/{}): {}",
252 attempt, MAX_RETRIES, &error_msg
253 ),
254 None,
255 ))?;
256 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
257 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
258 } else {
259 return Err(anyhow::anyhow!(
260 "API error after {} retries: {}",
261 MAX_RETRIES,
262 e
263 ));
264 }
265 }
266 }
267 }
268 }
269}