Skip to main content

matrixcode_core/agent/
streaming.rs

1//! Agent streaming implementation.
2
3use anyhow::Result;
4use tokio::time::{Duration, sleep, Instant};
5
6use crate::constants::STREAM_DELTA_BUFFER_SIZE;
7use crate::event::AgentEvent;
8use crate::providers::{ChatRequest, ChatResponse, ContentBlock, StopReason, StreamEvent, Usage};
9
10use super::types::Agent;
11
12/// Buffered delta for efficient event emission
13#[derive(Debug)]
14struct BufferedDelta {
15    text: String,
16    thinking: String,
17    last_emit: Instant,
18}
19
20impl Default for BufferedDelta {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl BufferedDelta {
27    fn new() -> Self {
28        Self {
29            text: String::new(),
30            thinking: String::new(),
31            last_emit: Instant::now(),
32        }
33    }
34
35    /// Add text delta to buffer, returns true if should flush
36    fn add_text(&mut self, delta: &str) -> bool {
37        self.text.push_str(delta);
38        self.should_flush_text()
39    }
40
41    /// Add thinking delta to buffer, returns true if should flush
42    fn add_thinking(&mut self, delta: &str) -> bool {
43        self.thinking.push_str(delta);
44        self.should_flush_thinking()
45    }
46
47    fn should_flush_text(&self) -> bool {
48        self.text.len() >= STREAM_DELTA_BUFFER_SIZE
49    }
50
51    fn should_flush_thinking(&self) -> bool {
52        self.thinking.len() >= STREAM_DELTA_BUFFER_SIZE
53    }
54
55    /// Check if buffer needs flush due to time interval
56    fn should_flush_by_time(&self, interval_ms: u64) -> bool {
57        self.last_emit.elapsed().as_millis() >= interval_ms as u128
58            && (!self.text.is_empty() || !self.thinking.is_empty())
59    }
60
61    /// Flush text buffer, returns content if non-empty
62    fn flush_text(&mut self) -> Option<String> {
63        if self.text.is_empty() {
64            return None;
65        }
66        let content = self.text.clone();
67        self.text.clear();
68        self.last_emit = Instant::now();
69        Some(content)
70    }
71
72    /// Flush thinking buffer, returns content if non-empty
73    fn flush_thinking(&mut self) -> Option<String> {
74        if self.thinking.is_empty() {
75            return None;
76        }
77        let content = self.thinking.clone();
78        self.thinking.clear();
79        self.last_emit = Instant::now();
80        Some(content)
81    }
82
83    /// Flush all buffers
84    fn flush_all(&mut self) -> (Option<String>, Option<String>) {
85        let text = self.flush_text();
86        let thinking = self.flush_thinking();
87        (text, thinking)
88    }
89}
90
91/// Wait for cancellation signal, checking periodically.
92async fn wait_for_cancel_stream(token: &crate::cancel::CancellationToken) {
93    while !token.is_cancelled() {
94        sleep(Duration::from_millis(100)).await;
95    }
96}
97
98impl Agent {
99    /// Drain any pending input messages from the channel.
100    /// Called during streaming to collect real-time appended messages.
101    pub(crate) fn drain_pending_inputs(&mut self) {
102        let inputs = self.session.drain_pending_inputs();
103        for msg in inputs {
104            log::info!(
105                "Agent received pending input: {}",
106                msg.chars().take(50).collect::<String>()
107            );
108            self.state.add_pending_input(msg);
109        }
110    }
111
112    /// Check if there are pending inputs waiting to be processed.
113    pub fn has_pending_inputs(&self) -> bool {
114        self.state.has_pending_inputs()
115    }
116
117    /// Get and clear all pending inputs.
118    pub fn take_pending_inputs(&mut self) -> Vec<String> {
119        self.state.take_pending_inputs()
120    }
121
122    /// Call provider with streaming and emit events in real-time.
123    /// Also monitors pending_input_rx for real-time message appending.
124    /// Uses buffered delta emission to reduce event frequency.
125    pub(crate) async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
126        const MAX_RETRIES: u32 = 5;
127        const RETRY_DELAY_MS: u64 = 1000;
128        const FLUSH_INTERVAL_MS: u64 = crate::constants::STREAM_DELTA_FLUSH_INTERVAL_MS;
129
130        let mut attempt = 0;
131
132        loop {
133            attempt += 1;
134            log::info!(
135                "Agent: API call attempt {} with {} messages",
136                attempt,
137                request.messages.len()
138            );
139
140            if self.session.is_cancelled() {
141                return Err(anyhow::anyhow!("Operation cancelled"));
142            }
143
144            log::info!("Agent: calling provider.chat_stream");
145            let rx_result = self.provider.chat_stream(request.clone()).await;
146            log::info!("Agent: provider.chat_stream returned");
147
148            match rx_result {
149                Ok(mut rx) => {
150                    let mut response_content: Vec<ContentBlock> = Vec::new();
151                    let mut current_text = String::new();
152                    let mut current_thinking = String::new();
153                    let mut usage = Usage {
154                        input_tokens: 0,
155                        output_tokens: 0,
156                        cache_creation_input_tokens: 0,
157                        cache_read_input_tokens: 0,
158                    };
159                    let mut should_retry = false;
160
161                    // Buffered delta for efficient emission
162                    let mut buffer = BufferedDelta::new();
163                    let mut thinking_started = false;
164                    let mut text_started = false;
165
166                    loop {
167                        // Use biased select! to prioritize stream events over pending input checks
168                        // This prevents losing stream events when sleep completes first
169                        let event = if let Some(token) = self.session.cancel_token() {
170                            tokio::select! {
171                                biased;
172
173                                // Primary: receive stream event (highest priority)
174                                event = rx.recv() => event,
175
176                                // Cancellation signal (second priority)
177                                _ = wait_for_cancel_stream(token) => {
178                                    // Flush any pending buffers before cancelling
179                                    let (text, thinking) = buffer.flush_all();
180                                    if let Some(t) = thinking {
181                                        self.emit(AgentEvent::thinking_delta(&t, None))?;
182                                    }
183                                    if let Some(t) = text {
184                                        self.emit(AgentEvent::text_delta(&t))?;
185                                    }
186                                    return Err(anyhow::anyhow!("Operation cancelled"));
187                                }
188
189                                // Check for pending inputs periodically (lowest priority)
190                                // Also check for buffer flush by time interval
191                                _ = sleep(Duration::from_millis(FLUSH_INTERVAL_MS)) => {
192                                    self.drain_pending_inputs();
193                                    // Flush buffers if interval elapsed
194                                    if buffer.should_flush_by_time(FLUSH_INTERVAL_MS) {
195                                        if let Some(t) = buffer.flush_thinking() {
196                                            self.emit(AgentEvent::thinking_delta(&t, None))?;
197                                        }
198                                        if let Some(t) = buffer.flush_text() {
199                                            self.emit(AgentEvent::text_delta(&t))?;
200                                        }
201                                    }
202                                    continue;
203                                }
204                            }
205                        } else {
206                            // No cancellation token, but still check pending inputs
207                            tokio::select! {
208                                biased;
209
210                                // Primary: receive stream event (highest priority)
211                                event = rx.recv() => event,
212
213                                // Check for pending inputs periodically (lower priority)
214                                // Also check for buffer flush by time interval
215                                _ = sleep(Duration::from_millis(FLUSH_INTERVAL_MS)) => {
216                                    self.drain_pending_inputs();
217                                    // Flush buffers if interval elapsed
218                                    if buffer.should_flush_by_time(FLUSH_INTERVAL_MS) {
219                                        if let Some(t) = buffer.flush_thinking() {
220                                            self.emit(AgentEvent::thinking_delta(&t, None))?;
221                                        }
222                                        if let Some(t) = buffer.flush_text() {
223                                            self.emit(AgentEvent::text_delta(&t))?;
224                                        }
225                                    }
226                                    continue;
227                                }
228                            }
229                        };
230
231                        match event {
232                            None => break,
233                            Some(StreamEvent::FirstByte) => {}
234                            Some(StreamEvent::ThinkingDelta(delta)) => {
235                                // Check cancellation before emitting
236                                if self.session.is_cancelled() {
237                                    return Err(anyhow::anyhow!("Operation cancelled"));
238                                }
239                                if !thinking_started {
240                                    self.emit(AgentEvent::thinking_start())?;
241                                    thinking_started = true;
242                                }
243                                current_thinking.push_str(&delta);
244                                // Buffer the delta and emit if threshold reached
245                                if buffer.add_thinking(&delta) {
246                                    if let Some(t) = buffer.flush_thinking() {
247                                        self.emit(AgentEvent::thinking_delta(&t, None))?;
248                                    }
249                                }
250                            }
251                            Some(StreamEvent::TextDelta(delta)) => {
252                                // Check cancellation before emitting
253                                if self.session.is_cancelled() {
254                                    return Err(anyhow::anyhow!("Operation cancelled"));
255                                }
256                                if !text_started {
257                                    self.emit(AgentEvent::text_start())?;
258                                    text_started = true;
259                                }
260                                current_text.push_str(&delta);
261                                // Buffer the delta and emit if threshold reached
262                                if buffer.add_text(&delta) {
263                                    if let Some(t) = buffer.flush_text() {
264                                        self.emit(AgentEvent::text_delta(&t))?;
265                                    }
266                                }
267                            }
268                            Some(StreamEvent::ToolUseStart { id, name }) => {
269                                // Flush any pending buffers before tool use
270                                if let Some(t) = buffer.flush_thinking() {
271                                    self.emit(AgentEvent::thinking_delta(&t, None))?;
272                                }
273                                if let Some(t) = buffer.flush_text() {
274                                    self.emit(AgentEvent::text_delta(&t))?;
275                                }
276                                // Emit events for UI but don't push content blocks
277                                // Content will be added from Done event's resp.content
278                                if !current_thinking.is_empty() {
279                                    self.emit(AgentEvent::thinking_end())?;
280                                    current_thinking.clear();
281                                }
282                                if !current_text.is_empty() {
283                                    self.emit(AgentEvent::text_end())?;
284                                    current_text.clear();
285                                }
286                                thinking_started = false;
287                                text_started = false;
288                                self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
289                            }
290                            Some(StreamEvent::ToolInputDelta { bytes_so_far: _ }) => {}
291                            Some(StreamEvent::ToolInputComplete { id, name, input }) => {
292                                self.state.mark_tool_input_previewed(id.clone());
293                                self.emit(AgentEvent::tool_use_start(&id, &name, Some(input)))?;
294                            }
295                            Some(StreamEvent::Usage { output_tokens }) => {
296                                self.emit(AgentEvent::usage_with_cache(
297                                    0,
298                                    output_tokens as u64,
299                                    0,
300                                    0,
301                                ))?;
302                                usage.output_tokens = output_tokens;
303                            }
304                            Some(StreamEvent::Done(resp)) => {
305                                // Check cancellation before processing final response
306                                if self.session.is_cancelled() {
307                                    return Err(anyhow::anyhow!("Operation cancelled"));
308                                }
309
310                                // Final drain of pending inputs before completing
311                                self.drain_pending_inputs();
312
313                                // Flush any remaining buffered deltas
314                                if let Some(t) = buffer.flush_thinking() {
315                                    self.emit(AgentEvent::thinking_delta(&t, None))?;
316                                }
317                                if let Some(t) = buffer.flush_text() {
318                                    self.emit(AgentEvent::text_delta(&t))?;
319                                }
320
321                                // IMPORTANT: Add current_thinking/current_text to response_content FIRST
322                                // before checking for duplicates from resp.content
323                                // This ensures all streamed content is preserved
324                                if !current_thinking.is_empty() {
325                                    self.emit(AgentEvent::thinking_end())?;
326                                    // Add to response_content with signature from resp if available
327                                    let signature = resp.content.iter()
328                                        .find_map(|b| {
329                                            if let ContentBlock::Thinking { thinking, signature } = b {
330                                                if thinking == &current_thinking {
331                                                    signature.clone()
332                                                } else {
333                                                    None
334                                                }
335                                            } else {
336                                                None
337                                            }
338                                        });
339                                    response_content.push(ContentBlock::Thinking {
340                                        thinking: current_thinking.clone(),
341                                        signature,
342                                    });
343                                    current_thinking.clear();
344                                }
345                                if !current_text.is_empty() {
346                                    self.emit(AgentEvent::text_end())?;
347                                    // Add to response_content
348                                    response_content.push(ContentBlock::Text {
349                                        text: current_text.clone(),
350                                    });
351                                    current_text.clear();
352                                }
353
354                                // Then add any additional blocks from final response that are NOT duplicates
355                                for block in &resp.content {
356                                    // Smart deduplication: compare content, not entire block
357                                    let is_duplicate = response_content.iter().any(|b| {
358                                        match (b, block) {
359                                            // For Thinking blocks, compare thinking content only (signature may differ)
360                                            (
361                                                ContentBlock::Thinking { thinking: t1, .. },
362                                                ContentBlock::Thinking { thinking: t2, .. },
363                                            ) => t1 == t2,
364                                            // For Text blocks, compare text content
365                                            (
366                                                ContentBlock::Text { text: t1 },
367                                                ContentBlock::Text { text: t2 },
368                                            ) => t1 == t2,
369                                            // For ToolUse, compare id
370                                            (
371                                                ContentBlock::ToolUse { id: id1, .. },
372                                                ContentBlock::ToolUse { id: id2, .. },
373                                            ) => id1 == id2,
374                                            // For ToolResult, compare tool_use_id
375                                            (
376                                                ContentBlock::ToolResult {
377                                                    tool_use_id: id1, ..
378                                                },
379                                                ContentBlock::ToolResult {
380                                                    tool_use_id: id2, ..
381                                                },
382                                            ) => id1 == id2,
383                                            // Default: exact comparison
384                                            _ => b == block,
385                                        }
386                                    });
387                                    if !is_duplicate {
388                                        response_content.push(block.clone());
389                                    }
390                                }
391                                usage = resp.usage;
392                            }
393                            Some(StreamEvent::Error(msg)) => {
394                                if attempt < MAX_RETRIES {
395                                    self.emit(AgentEvent::progress(
396                                        format!(
397                                            "⚠️ Stream error, retrying ({}/{}): {}",
398                                            attempt, MAX_RETRIES, &msg
399                                        ),
400                                        None,
401                                    ))?;
402                                    let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
403                                    tokio::time::sleep(tokio::time::Duration::from_millis(delay))
404                                        .await;
405                                    should_retry = true;
406                                    break;
407                                } else {
408                                    self.emit(AgentEvent::error(msg.clone(), None, None))?;
409                                    return Err(anyhow::anyhow!(
410                                        "Stream error after {} retries: {}",
411                                        MAX_RETRIES,
412                                        msg
413                                    ));
414                                }
415                            }
416                        }
417                    }
418
419                    if should_retry {
420                        continue;
421                    }
422
423                    return Ok(ChatResponse {
424                        content: response_content,
425                        stop_reason: StopReason::EndTurn,
426                        usage,
427                    });
428                }
429                Err(e) => {
430                    if attempt < MAX_RETRIES {
431                        let error_msg = e.to_string();
432                        self.emit(AgentEvent::progress(
433                            format!(
434                                "⚠️ API error, retrying ({}/{}): {}",
435                                attempt, MAX_RETRIES, &error_msg
436                            ),
437                            None,
438                        ))?;
439                        let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
440                        tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
441                    } else {
442                        return Err(anyhow::anyhow!(
443                            "API error after {} retries: {}",
444                            MAX_RETRIES,
445                            e
446                        ));
447                    }
448                }
449            }
450        }
451    }
452}