matrixcode_core/agent/streaming.rs
1//! Agent streaming implementation.
2
3use anyhow::Result;
4use tokio::time::{Duration, sleep, Instant};
5
6use crate::constants::STREAM_DELTA_BUFFER_SIZE;
7use crate::event::AgentEvent;
8use crate::providers::{ChatRequest, ChatResponse, ContentBlock, StopReason, StreamEvent, Usage};
9
10use super::types::Agent;
11
12/// Buffered delta for efficient event emission
13#[derive(Debug)]
14struct BufferedDelta {
15 text: String,
16 thinking: String,
17 last_emit: Instant,
18}
19
20impl Default for BufferedDelta {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl BufferedDelta {
27 fn new() -> Self {
28 Self {
29 text: String::new(),
30 thinking: String::new(),
31 last_emit: Instant::now(),
32 }
33 }
34
35 /// Add text delta to buffer, returns true if should flush
36 fn add_text(&mut self, delta: &str) -> bool {
37 self.text.push_str(delta);
38 self.should_flush_text()
39 }
40
41 /// Add thinking delta to buffer, returns true if should flush
42 fn add_thinking(&mut self, delta: &str) -> bool {
43 self.thinking.push_str(delta);
44 self.should_flush_thinking()
45 }
46
47 fn should_flush_text(&self) -> bool {
48 self.text.len() >= STREAM_DELTA_BUFFER_SIZE
49 }
50
51 fn should_flush_thinking(&self) -> bool {
52 self.thinking.len() >= STREAM_DELTA_BUFFER_SIZE
53 }
54
55 /// Check if buffer needs flush due to time interval
56 fn should_flush_by_time(&self, interval_ms: u64) -> bool {
57 self.last_emit.elapsed().as_millis() >= interval_ms as u128
58 && (!self.text.is_empty() || !self.thinking.is_empty())
59 }
60
61 /// Flush text buffer, returns content if non-empty
62 fn flush_text(&mut self) -> Option<String> {
63 if self.text.is_empty() {
64 return None;
65 }
66 let content = self.text.clone();
67 self.text.clear();
68 self.last_emit = Instant::now();
69 Some(content)
70 }
71
72 /// Flush thinking buffer, returns content if non-empty
73 fn flush_thinking(&mut self) -> Option<String> {
74 if self.thinking.is_empty() {
75 return None;
76 }
77 let content = self.thinking.clone();
78 self.thinking.clear();
79 self.last_emit = Instant::now();
80 Some(content)
81 }
82
83 /// Flush all buffers
84 fn flush_all(&mut self) -> (Option<String>, Option<String>) {
85 let text = self.flush_text();
86 let thinking = self.flush_thinking();
87 (text, thinking)
88 }
89}
90
91/// Wait for cancellation signal, checking periodically.
92async fn wait_for_cancel_stream(token: &crate::cancel::CancellationToken) {
93 while !token.is_cancelled() {
94 sleep(Duration::from_millis(100)).await;
95 }
96}
97
98impl Agent {
99 /// Drain any pending input messages from the channel.
100 /// Called during streaming to collect real-time appended messages.
101 pub(crate) fn drain_pending_inputs(&mut self) {
102 let inputs = self.session.drain_pending_inputs();
103 for msg in inputs {
104 log::info!(
105 "Agent received pending input: {}",
106 msg.chars().take(50).collect::<String>()
107 );
108 self.state.add_pending_input(msg);
109 }
110 }
111
112 /// Check if there are pending inputs waiting to be processed.
113 pub fn has_pending_inputs(&self) -> bool {
114 self.state.has_pending_inputs()
115 }
116
117 /// Get and clear all pending inputs.
118 pub fn take_pending_inputs(&mut self) -> Vec<String> {
119 self.state.take_pending_inputs()
120 }
121
122 /// Call provider with streaming and emit events in real-time.
123 /// Also monitors pending_input_rx for real-time message appending.
124 /// Uses buffered delta emission to reduce event frequency.
125 pub(crate) async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
126 const MAX_RETRIES: u32 = 5;
127 const RETRY_DELAY_MS: u64 = 1000;
128 const FLUSH_INTERVAL_MS: u64 = crate::constants::STREAM_DELTA_FLUSH_INTERVAL_MS;
129
130 let mut attempt = 0;
131
132 loop {
133 attempt += 1;
134 log::info!(
135 "Agent: API call attempt {} with {} messages",
136 attempt,
137 request.messages.len()
138 );
139
140 if self.session.is_cancelled() {
141 return Err(anyhow::anyhow!("Operation cancelled"));
142 }
143
144 log::info!("Agent: calling provider.chat_stream");
145 let rx_result = self.provider.chat_stream(request.clone()).await;
146 log::info!("Agent: provider.chat_stream returned");
147
148 match rx_result {
149 Ok(mut rx) => {
150 let mut response_content: Vec<ContentBlock> = Vec::new();
151 let mut current_text = String::new();
152 let mut current_thinking = String::new();
153 let mut usage = Usage {
154 input_tokens: 0,
155 output_tokens: 0,
156 cache_creation_input_tokens: 0,
157 cache_read_input_tokens: 0,
158 };
159 let mut should_retry = false;
160
161 // Buffered delta for efficient emission
162 let mut buffer = BufferedDelta::new();
163 let mut thinking_started = false;
164 let mut text_started = false;
165
166 loop {
167 // Use biased select! to prioritize stream events over pending input checks
168 // This prevents losing stream events when sleep completes first
169 let event = if let Some(token) = self.session.cancel_token() {
170 tokio::select! {
171 biased;
172
173 // Primary: receive stream event (highest priority)
174 event = rx.recv() => event,
175
176 // Cancellation signal (second priority)
177 _ = wait_for_cancel_stream(token) => {
178 // Flush any pending buffers before cancelling
179 let (text, thinking) = buffer.flush_all();
180 if let Some(t) = thinking {
181 self.emit(AgentEvent::thinking_delta(&t, None))?;
182 }
183 if let Some(t) = text {
184 self.emit(AgentEvent::text_delta(&t))?;
185 }
186 return Err(anyhow::anyhow!("Operation cancelled"));
187 }
188
189 // Check for pending inputs periodically (lowest priority)
190 // Also check for buffer flush by time interval
191 _ = sleep(Duration::from_millis(FLUSH_INTERVAL_MS)) => {
192 self.drain_pending_inputs();
193 // Flush buffers if interval elapsed
194 if buffer.should_flush_by_time(FLUSH_INTERVAL_MS) {
195 if let Some(t) = buffer.flush_thinking() {
196 self.emit(AgentEvent::thinking_delta(&t, None))?;
197 }
198 if let Some(t) = buffer.flush_text() {
199 self.emit(AgentEvent::text_delta(&t))?;
200 }
201 }
202 continue;
203 }
204 }
205 } else {
206 // No cancellation token, but still check pending inputs
207 tokio::select! {
208 biased;
209
210 // Primary: receive stream event (highest priority)
211 event = rx.recv() => event,
212
213 // Check for pending inputs periodically (lower priority)
214 // Also check for buffer flush by time interval
215 _ = sleep(Duration::from_millis(FLUSH_INTERVAL_MS)) => {
216 self.drain_pending_inputs();
217 // Flush buffers if interval elapsed
218 if buffer.should_flush_by_time(FLUSH_INTERVAL_MS) {
219 if let Some(t) = buffer.flush_thinking() {
220 self.emit(AgentEvent::thinking_delta(&t, None))?;
221 }
222 if let Some(t) = buffer.flush_text() {
223 self.emit(AgentEvent::text_delta(&t))?;
224 }
225 }
226 continue;
227 }
228 }
229 };
230
231 match event {
232 None => break,
233 Some(StreamEvent::FirstByte) => {}
234 Some(StreamEvent::ThinkingDelta(delta)) => {
235 // Check cancellation before emitting
236 if self.session.is_cancelled() {
237 return Err(anyhow::anyhow!("Operation cancelled"));
238 }
239 if !thinking_started {
240 self.emit(AgentEvent::thinking_start())?;
241 thinking_started = true;
242 }
243 current_thinking.push_str(&delta);
244 // Buffer the delta and emit if threshold reached
245 if buffer.add_thinking(&delta) {
246 if let Some(t) = buffer.flush_thinking() {
247 self.emit(AgentEvent::thinking_delta(&t, None))?;
248 }
249 }
250 }
251 Some(StreamEvent::TextDelta(delta)) => {
252 // Check cancellation before emitting
253 if self.session.is_cancelled() {
254 return Err(anyhow::anyhow!("Operation cancelled"));
255 }
256 if !text_started {
257 self.emit(AgentEvent::text_start())?;
258 text_started = true;
259 }
260 current_text.push_str(&delta);
261 // Buffer the delta and emit if threshold reached
262 if buffer.add_text(&delta) {
263 if let Some(t) = buffer.flush_text() {
264 self.emit(AgentEvent::text_delta(&t))?;
265 }
266 }
267 }
268 Some(StreamEvent::ToolUseStart { id, name }) => {
269 // Flush any pending buffers before tool use
270 if let Some(t) = buffer.flush_thinking() {
271 self.emit(AgentEvent::thinking_delta(&t, None))?;
272 }
273 if let Some(t) = buffer.flush_text() {
274 self.emit(AgentEvent::text_delta(&t))?;
275 }
276 // Emit events for UI but don't push content blocks
277 // Content will be added from Done event's resp.content
278 if !current_thinking.is_empty() {
279 self.emit(AgentEvent::thinking_end())?;
280 current_thinking.clear();
281 }
282 if !current_text.is_empty() {
283 self.emit(AgentEvent::text_end())?;
284 current_text.clear();
285 }
286 thinking_started = false;
287 text_started = false;
288 self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
289 }
290 Some(StreamEvent::ToolInputDelta { bytes_so_far: _ }) => {}
291 Some(StreamEvent::ToolInputComplete { id, name, input }) => {
292 self.state.mark_tool_input_previewed(id.clone());
293 self.emit(AgentEvent::tool_use_start(&id, &name, Some(input)))?;
294 }
295 Some(StreamEvent::Usage { output_tokens }) => {
296 self.emit(AgentEvent::usage_with_cache(
297 0,
298 output_tokens as u64,
299 0,
300 0,
301 ))?;
302 usage.output_tokens = output_tokens;
303 }
304 Some(StreamEvent::Done(resp)) => {
305 // Check cancellation before processing final response
306 if self.session.is_cancelled() {
307 return Err(anyhow::anyhow!("Operation cancelled"));
308 }
309
310 // Final drain of pending inputs before completing
311 self.drain_pending_inputs();
312
313 // Flush any remaining buffered deltas
314 if let Some(t) = buffer.flush_thinking() {
315 self.emit(AgentEvent::thinking_delta(&t, None))?;
316 }
317 if let Some(t) = buffer.flush_text() {
318 self.emit(AgentEvent::text_delta(&t))?;
319 }
320
321 // IMPORTANT: Add current_thinking/current_text to response_content FIRST
322 // before checking for duplicates from resp.content
323 // This ensures all streamed content is preserved
324 if !current_thinking.is_empty() {
325 self.emit(AgentEvent::thinking_end())?;
326 // Add to response_content with signature from resp if available
327 let signature = resp.content.iter()
328 .find_map(|b| {
329 if let ContentBlock::Thinking { thinking, signature } = b {
330 if thinking == ¤t_thinking {
331 signature.clone()
332 } else {
333 None
334 }
335 } else {
336 None
337 }
338 });
339 response_content.push(ContentBlock::Thinking {
340 thinking: current_thinking.clone(),
341 signature,
342 });
343 current_thinking.clear();
344 }
345 if !current_text.is_empty() {
346 self.emit(AgentEvent::text_end())?;
347 // Add to response_content
348 response_content.push(ContentBlock::Text {
349 text: current_text.clone(),
350 });
351 current_text.clear();
352 }
353
354 // Then add any additional blocks from final response that are NOT duplicates
355 for block in &resp.content {
356 // Smart deduplication: compare content, not entire block
357 let is_duplicate = response_content.iter().any(|b| {
358 match (b, block) {
359 // For Thinking blocks, compare thinking content only (signature may differ)
360 (
361 ContentBlock::Thinking { thinking: t1, .. },
362 ContentBlock::Thinking { thinking: t2, .. },
363 ) => t1 == t2,
364 // For Text blocks, compare text content
365 (
366 ContentBlock::Text { text: t1 },
367 ContentBlock::Text { text: t2 },
368 ) => t1 == t2,
369 // For ToolUse, compare id
370 (
371 ContentBlock::ToolUse { id: id1, .. },
372 ContentBlock::ToolUse { id: id2, .. },
373 ) => id1 == id2,
374 // For ToolResult, compare tool_use_id
375 (
376 ContentBlock::ToolResult {
377 tool_use_id: id1, ..
378 },
379 ContentBlock::ToolResult {
380 tool_use_id: id2, ..
381 },
382 ) => id1 == id2,
383 // Default: exact comparison
384 _ => b == block,
385 }
386 });
387 if !is_duplicate {
388 response_content.push(block.clone());
389 }
390 }
391 usage = resp.usage;
392 }
393 Some(StreamEvent::Error(msg)) => {
394 if attempt < MAX_RETRIES {
395 self.emit(AgentEvent::progress(
396 format!(
397 "⚠️ Stream error, retrying ({}/{}): {}",
398 attempt, MAX_RETRIES, &msg
399 ),
400 None,
401 ))?;
402 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
403 tokio::time::sleep(tokio::time::Duration::from_millis(delay))
404 .await;
405 should_retry = true;
406 break;
407 } else {
408 self.emit(AgentEvent::error(msg.clone(), None, None))?;
409 return Err(anyhow::anyhow!(
410 "Stream error after {} retries: {}",
411 MAX_RETRIES,
412 msg
413 ));
414 }
415 }
416 }
417 }
418
419 if should_retry {
420 continue;
421 }
422
423 return Ok(ChatResponse {
424 content: response_content,
425 stop_reason: StopReason::EndTurn,
426 usage,
427 });
428 }
429 Err(e) => {
430 if attempt < MAX_RETRIES {
431 let error_msg = e.to_string();
432 self.emit(AgentEvent::progress(
433 format!(
434 "⚠️ API error, retrying ({}/{}): {}",
435 attempt, MAX_RETRIES, &error_msg
436 ),
437 None,
438 ))?;
439 let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
440 tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
441 } else {
442 return Err(anyhow::anyhow!(
443 "API error after {} retries: {}",
444 MAX_RETRIES,
445 e
446 ));
447 }
448 }
449 }
450 }
451 }
452}