Skip to main content

matrixcode_core/agent/
streaming.rs

1//! Agent streaming implementation.
2
3use anyhow::Result;
4use tokio::time::{Duration, sleep};
5
6use crate::event::AgentEvent;
7use crate::providers::{ChatRequest, ChatResponse, ContentBlock, StopReason, StreamEvent, Usage};
8
9use super::types::Agent;
10
11/// Wait for cancellation signal, checking periodically.
12async fn wait_for_cancel_stream(token: &crate::cancel::CancellationToken) {
13    while !token.is_cancelled() {
14        sleep(Duration::from_millis(100)).await;
15    }
16}
17
18impl Agent {
19    /// Drain any pending input messages from the channel.
20    /// Called during streaming to collect real-time appended messages.
21    pub(crate) fn drain_pending_inputs(&mut self) {
22        let inputs = self.session.drain_pending_inputs();
23        for msg in inputs {
24            log::info!(
25                "Agent received pending input: {}",
26                msg.chars().take(50).collect::<String>()
27            );
28            self.state.add_pending_input(msg);
29        }
30    }
31
32    /// Check if there are pending inputs waiting to be processed.
33    pub fn has_pending_inputs(&self) -> bool {
34        self.state.has_pending_inputs()
35    }
36
37    /// Get and clear all pending inputs.
38    pub fn take_pending_inputs(&mut self) -> Vec<String> {
39        self.state.take_pending_inputs()
40    }
41
42    /// Call provider with streaming and emit events in real-time.
43    /// Also monitors pending_input_rx for real-time message appending.
44    pub(crate) async fn call_streaming(&mut self, request: &ChatRequest) -> Result<ChatResponse> {
45        const MAX_RETRIES: u32 = 5;
46        const RETRY_DELAY_MS: u64 = 1000;
47
48        let mut attempt = 0;
49
50        loop {
51            attempt += 1;
52            log::info!(
53                "Agent: API call attempt {} with {} messages",
54                attempt,
55                request.messages.len()
56            );
57
58            if self.session.is_cancelled() {
59                return Err(anyhow::anyhow!("Operation cancelled"));
60            }
61
62            log::info!("Agent: calling provider.chat_stream");
63            let rx_result = self.provider.chat_stream(request.clone()).await;
64            log::info!("Agent: provider.chat_stream returned");
65
66            match rx_result {
67                Ok(mut rx) => {
68                    let mut response_content: Vec<ContentBlock> = Vec::new();
69                    let mut current_text = String::new();
70                    let mut current_thinking = String::new();
71                    let mut usage = Usage {
72                        input_tokens: 0,
73                        output_tokens: 0,
74                        cache_creation_input_tokens: 0,
75                        cache_read_input_tokens: 0,
76                    };
77                    let mut should_retry = false;
78
79                    loop {
80                        // Use biased select! to prioritize stream events over pending input checks
81                        // This prevents losing stream events when sleep completes first
82                        let event = if let Some(token) = self.session.cancel_token() {
83                            tokio::select! {
84                                biased;
85
86                                // Primary: receive stream event (highest priority)
87                                event = rx.recv() => event,
88
89                                // Cancellation signal (second priority)
90                                _ = wait_for_cancel_stream(token) => {
91                                    return Err(anyhow::anyhow!("Operation cancelled"));
92                                }
93
94                                // Check for pending inputs periodically (lowest priority)
95                                // Increased interval to reduce competition with stream events
96                                _ = sleep(Duration::from_millis(200)) => {
97                                    self.drain_pending_inputs();
98                                    continue;
99                                }
100                            }
101                        } else {
102                            // No cancellation token, but still check pending inputs
103                            tokio::select! {
104                                biased;
105
106                                // Primary: receive stream event (highest priority)
107                                event = rx.recv() => event,
108
109                                // Check for pending inputs periodically (lower priority)
110                                _ = sleep(Duration::from_millis(200)) => {
111                                    self.drain_pending_inputs();
112                                    continue;
113                                }
114                            }
115                        };
116
117                        match event {
118                            None => break,
119                            Some(StreamEvent::FirstByte) => {}
120                            Some(StreamEvent::ThinkingDelta(delta)) => {
121                                // Check cancellation before emitting
122                                if self.session.is_cancelled() {
123                                    return Err(anyhow::anyhow!("Operation cancelled"));
124                                }
125                                if current_thinking.is_empty() {
126                                    self.emit(AgentEvent::thinking_start())?;
127                                }
128                                current_thinking.push_str(&delta);
129                                self.emit(AgentEvent::thinking_delta(delta, None))?;
130                            }
131                            Some(StreamEvent::TextDelta(delta)) => {
132                                // Check cancellation before emitting
133                                if self.session.is_cancelled() {
134                                    return Err(anyhow::anyhow!("Operation cancelled"));
135                                }
136                                if current_text.is_empty() {
137                                    self.emit(AgentEvent::text_start())?;
138                                }
139                                current_text.push_str(&delta);
140                                self.emit(AgentEvent::text_delta(delta))?;
141                            }
142                            Some(StreamEvent::ToolUseStart { id, name }) => {
143                                // Emit events for UI but don't push content blocks
144                                // Content will be added from Done event's resp.content
145                                if !current_thinking.is_empty() {
146                                    self.emit(AgentEvent::thinking_end())?;
147                                    current_thinking.clear();
148                                }
149                                if !current_text.is_empty() {
150                                    self.emit(AgentEvent::text_end())?;
151                                    current_text.clear();
152                                }
153                                self.emit(AgentEvent::tool_use_start(&id, &name, None))?;
154                            }
155                            Some(StreamEvent::ToolInputDelta { bytes_so_far: _ }) => {}
156                            Some(StreamEvent::ToolInputComplete { id, name, input }) => {
157                                self.state.mark_tool_input_previewed(id.clone());
158                                self.emit(AgentEvent::tool_use_start(&id, &name, Some(input)))?;
159                            }
160                            Some(StreamEvent::Usage { output_tokens }) => {
161                                self.emit(AgentEvent::usage_with_cache(
162                                    0,
163                                    output_tokens as u64,
164                                    0,
165                                    0,
166                                ))?;
167                                usage.output_tokens = output_tokens;
168                            }
169                            Some(StreamEvent::Done(resp)) => {
170                                // Check cancellation before processing final response
171                                if self.session.is_cancelled() {
172                                    return Err(anyhow::anyhow!("Operation cancelled"));
173                                }
174
175                                // Final drain of pending inputs before completing
176                                self.drain_pending_inputs();
177
178                                // IMPORTANT: Add current_thinking/current_text to response_content FIRST
179                                // before checking for duplicates from resp.content
180                                // This ensures all streamed content is preserved
181                                if !current_thinking.is_empty() {
182                                    self.emit(AgentEvent::thinking_end())?;
183                                    // Add to response_content with signature from resp if available
184                                    let signature = resp.content.iter()
185                                        .find_map(|b| {
186                                            if let ContentBlock::Thinking { thinking, signature } = b {
187                                                if thinking == &current_thinking {
188                                                    signature.clone()
189                                                } else {
190                                                    None
191                                                }
192                                            } else {
193                                                None
194                                            }
195                                        });
196                                    response_content.push(ContentBlock::Thinking {
197                                        thinking: current_thinking.clone(),
198                                        signature,
199                                    });
200                                    current_thinking.clear();
201                                }
202                                if !current_text.is_empty() {
203                                    self.emit(AgentEvent::text_end())?;
204                                    // Add to response_content
205                                    response_content.push(ContentBlock::Text {
206                                        text: current_text.clone(),
207                                    });
208                                    current_text.clear();
209                                }
210
211                                // Then add any additional blocks from final response that are NOT duplicates
212                                for block in &resp.content {
213                                    // Smart deduplication: compare content, not entire block
214                                    let is_duplicate = response_content.iter().any(|b| {
215                                        match (b, block) {
216                                            // For Thinking blocks, compare thinking content only (signature may differ)
217                                            (
218                                                ContentBlock::Thinking { thinking: t1, .. },
219                                                ContentBlock::Thinking { thinking: t2, .. },
220                                            ) => t1 == t2,
221                                            // For Text blocks, compare text content
222                                            (
223                                                ContentBlock::Text { text: t1 },
224                                                ContentBlock::Text { text: t2 },
225                                            ) => t1 == t2,
226                                            // For ToolUse, compare id
227                                            (
228                                                ContentBlock::ToolUse { id: id1, .. },
229                                                ContentBlock::ToolUse { id: id2, .. },
230                                            ) => id1 == id2,
231                                            // For ToolResult, compare tool_use_id
232                                            (
233                                                ContentBlock::ToolResult {
234                                                    tool_use_id: id1, ..
235                                                },
236                                                ContentBlock::ToolResult {
237                                                    tool_use_id: id2, ..
238                                                },
239                                            ) => id1 == id2,
240                                            // Default: exact comparison
241                                            _ => b == block,
242                                        }
243                                    });
244                                    if !is_duplicate {
245                                        response_content.push(block.clone());
246                                    }
247                                }
248                                usage = resp.usage;
249                            }
250                            Some(StreamEvent::Error(msg)) => {
251                                if attempt < MAX_RETRIES {
252                                    self.emit(AgentEvent::progress(
253                                        format!(
254                                            "⚠️ Stream error, retrying ({}/{}): {}",
255                                            attempt, MAX_RETRIES, &msg
256                                        ),
257                                        None,
258                                    ))?;
259                                    let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
260                                    tokio::time::sleep(tokio::time::Duration::from_millis(delay))
261                                        .await;
262                                    should_retry = true;
263                                    break;
264                                } else {
265                                    self.emit(AgentEvent::error(msg.clone(), None, None))?;
266                                    return Err(anyhow::anyhow!(
267                                        "Stream error after {} retries: {}",
268                                        MAX_RETRIES,
269                                        msg
270                                    ));
271                                }
272                            }
273                        }
274                    }
275
276                    if should_retry {
277                        continue;
278                    }
279
280                    return Ok(ChatResponse {
281                        content: response_content,
282                        stop_reason: StopReason::EndTurn,
283                        usage,
284                    });
285                }
286                Err(e) => {
287                    if attempt < MAX_RETRIES {
288                        let error_msg = e.to_string();
289                        self.emit(AgentEvent::progress(
290                            format!(
291                                "⚠️ API error, retrying ({}/{}): {}",
292                                attempt, MAX_RETRIES, &error_msg
293                            ),
294                            None,
295                        ))?;
296                        let delay = RETRY_DELAY_MS * (1 << (attempt - 1));
297                        tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
298                    } else {
299                        return Err(anyhow::anyhow!(
300                            "API error after {} retries: {}",
301                            MAX_RETRIES,
302                            e
303                        ));
304                    }
305                }
306            }
307        }
308    }
309}